diff --git a/plot_serializer/matplotlib/serializer.py b/plot_serializer/matplotlib/serializer.py index b895c417a97a08a64240c499bfbbc6127c9458dc..6876f41de7d18aabb65dfe452b5b53b07a612411 100644 --- a/plot_serializer/matplotlib/serializer.py +++ b/plot_serializer/matplotlib/serializer.py @@ -101,6 +101,14 @@ PLOTTING_METHODS = [ ] +def inherit_and_extend_doc(base_class, method_name, additional_doc): + def decorator(func): + func.__doc__ = getattr(base_class, method_name).__doc__ + additional_doc + return func + + return decorator + + def _convert_matplotlib_color( self, color_list: Any, length: int, cmap: Any, norm: Any ) -> Tuple[List[str] | None, bool]: @@ -157,6 +165,7 @@ class _AxesProxy(Proxy[MplAxes]): self._serializer = serializer self._plot: Optional[Plot] = None + @inherit_and_extend_doc(MplAxes, "plot", "\n\n Serialized parameters: x, y, color, marker, label. \n\n") def pie(self, x, **kwargs: Any) -> Any: try: result = self.delegate.pie(x, **kwargs) @@ -169,24 +178,22 @@ class _AxesProxy(Proxy[MplAxes]): if self._plot is not None: raise NotImplementedError("PlotSerializer does not yet support adding multiple plots per axes!") - slices: List[Slice] = [] - explode_list = kwargs.get("explode") label_list = kwargs.get("labels") radius = kwargs.get("radius") or None - color_list = kwargs.get("colors") c = kwargs.get("c") - if c is not None and color_list is None: - color_list = c - color_list = _convert_matplotlib_color(self, color_list, len(x), cmap="viridis", norm="linear")[0] x = np.asarray(x, np.float32) if not explode_list: explode_list = itertools.repeat(None) if not label_list: label_list = itertools.repeat(None) + if c is not None and color_list is None: + color_list = c + color_list = _convert_matplotlib_color(self, color_list, len(x), cmap="viridis", norm="linear")[0] + slices: List[Slice] = [] for index, (xi, label, explode) in enumerate(zip(x, label_list, explode_list)): color = color_list[index] if len(color_list) > index else None slices.append( @@ -200,6 +207,7 @@ class _AxesProxy(Proxy[MplAxes]): ) pie_plot = PiePlot(type="pie", radius=radius, slices=slices) self._plot = pie_plot + except Exception as e: logging.warning( "An unexpected error occurred in PlotSerializer when trying to read plot data! " @@ -209,11 +217,7 @@ class _AxesProxy(Proxy[MplAxes]): return result - pie.__doc__ = ( - "\n\n Serialized parameters: x, labels, colors, explode, radius. \n\n Matplotlib documentation \n\n" - + MplAxes.pie.__doc__ - ) - + @inherit_and_extend_doc(MplAxes, "bar", "\n\n Serialized parameters: x, height, color. \n\n") def bar( self, x, @@ -228,13 +232,10 @@ class _AxesProxy(Proxy[MplAxes]): raise try: - bars: List[Bar2D] = [] - if isinstance(x, float): x = [x] else: x = np.asarray(x) - if isinstance(height, float): height = [height] else: @@ -246,13 +247,11 @@ class _AxesProxy(Proxy[MplAxes]): color_list = c color_list = _convert_matplotlib_color(self, color_list, len(x), cmap="viridis", norm="linear")[0] + bars: List[Bar2D] = [] for index, (xi, h) in enumerate(zip(x, height)): color = color_list[index] if len(color_list) > index else None bars.append(Bar2D(x_i=xi, height=h, color=color)) - # for xi, h, color in zip(x, height, color_list): - # bars.append(Bar2D(x_i=xi, height=h, color=color)) - trace = BarTrace2D(type="bar", datapoints=bars) if self._plot is not None: @@ -271,20 +270,253 @@ class _AxesProxy(Proxy[MplAxes]): return result - """ - Custom wrapper for ax.plot with additional functionality. + def plot(self, *args: Any, **kwargs: Any) -> list[Line2D]: + """ + Serializes x,y, linewidth, linestyle, marker, color, label. - {plt.Axes.plot.__doc__} + Original matplotlib documentation: - Parameters - ---------- - *args : tuple - Positional arguments passed to `ax.plot`. - **kwargs : dict - Keyword arguments passed to `ax.plot`. - """ + Plot y versus x as lines and/or markers. - def plot(self, *args: Any, **kwargs: Any) -> list[Line2D]: + Call signatures:: + + plot([x], y, [fmt], *, data=None, **kwargs) + plot([x], y, [fmt], [x2], y2, [fmt2], ..., **kwargs) + + The coordinates of the points or line nodes are given by *x*, *y*. + + The optional parameter *fmt* is a convenient way for defining basic + formatting like color, marker and linestyle. It's a shortcut string + notation described in the *Notes* section below. + + >>> plot(x, y) # plot x and y using default line style and color + >>> plot(x, y, "bo") # plot x and y using blue circle markers + >>> plot(y) # plot y using x as index array 0..N-1 + >>> plot(y, "r+") # ditto, but with red plusses + + You can use `.Line2D` properties as keyword arguments for more + control on the appearance. Line properties and *fmt* can be mixed. + The following two calls yield identical results: + + >>> plot(x, y, "go--", linewidth=2, markersize=12) + >>> plot(x, y, color="green", marker="o", linestyle="dashed", linewidth=2, markersize=12) + + When conflicting with *fmt*, keyword arguments take precedence. + + + **Plotting labelled data** + + There's a convenient way for plotting objects with labelled data (i.e. + data that can be accessed by index ``obj['y']``). Instead of giving + the data in *x* and *y*, you can provide the object in the *data* + parameter and just give the labels for *x* and *y*:: + + >>> plot("xlabel", "ylabel", data=obj) + + All indexable objects are supported. This could e.g. be a `dict`, a + `pandas.DataFrame` or a structured numpy array. + + + **Plotting multiple sets of data** + + There are various ways to plot multiple sets of data. + + - The most straight forward way is just to call `plot` multiple times. + Example: + + >>> plot(x1, y1, "bo") + >>> plot(x2, y2, "go") + + - If *x* and/or *y* are 2D arrays, a separate data set will be drawn + for every column. If both *x* and *y* are 2D, they must have the + same shape. If only one of them is 2D with shape (N, m) the other + must have length N and will be used for every data set m. + + Example: + + >>> x = [1, 2, 3] + >>> y = np.array([[1, 2], [3, 4], [5, 6]]) + >>> plot(x, y) + + is equivalent to: + + >>> for col in range(y.shape[1]): + ... plot(x, y[:, col]) + + - The third way is to specify multiple sets of *[x]*, *y*, *[fmt]* + groups:: + + >>> plot(x1, y1, "g^", x2, y2, "g-") + + In this case, any additional keyword argument applies to all + datasets. Also, this syntax cannot be combined with the *data* + parameter. + + By default, each line is assigned a different style specified by a + 'style cycle'. The *fmt* and line property parameters are only + necessary if you want explicit deviations from these defaults. + Alternatively, you can also change the style cycle using + :rc:`axes.prop_cycle`. + + + Parameters + ---------- + x, y : array-like or scalar + The horizontal / vertical coordinates of the data points. + *x* values are optional and default to ``range(len(y))``. + + Commonly, these parameters are 1D arrays. + + They can also be scalars, or two-dimensional (in that case, the + columns represent separate data sets). + + These arguments cannot be passed as keywords. + + fmt : str, optional + A format string, e.g. 'ro' for red circles. See the *Notes* + section for a full description of the format strings. + + Format strings are just an abbreviation for quickly setting + basic line properties. All of these and more can also be + controlled by keyword arguments. + + This argument cannot be passed as keyword. + + data : indexable object, optional + An object with labelled data. If given, provide the label names to + plot in *x* and *y*. + + .. note:: + Technically there's a slight ambiguity in calls where the + second label is a valid *fmt*. ``plot('n', 'o', data=obj)`` + could be ``plt(x, y)`` or ``plt(y, fmt)``. In such cases, + the former interpretation is chosen, but a warning is issued. + You may suppress the warning by adding an empty format string + ``plot('n', 'o', '', data=obj)``. + + Returns + ------- + list of `.Line2D` + A list of lines representing the plotted data. + + Other Parameters + ---------------- + scalex, scaley : bool, default: True + These parameters determine if the view limits are adapted to the + data limits. The values are passed on to + `~.axes.Axes.autoscale_view`. + + **kwargs : `~matplotlib.lines.Line2D` properties, optional + *kwargs* are used to specify properties like a line label (for + auto legends), linewidth, antialiasing, marker face color. + Example:: + + >>> plot([1, 2, 3], [1, 2, 3], "go-", label="line 1", linewidth=2) + >>> plot([1, 2, 3], [1, 4, 9], "rs", label="line 2") + + If you specify multiple lines with one plot call, the kwargs apply + to all those lines. In case the label object is iterable, each + element is used as labels for each set of data. + + Here is a list of available `.Line2D` properties: + + %(Line2D:kwdoc)s + + See Also + -------- + scatter : XY scatter plot with markers of varying size and/or color ( + sometimes also called bubble chart). + + Notes + ----- + **Format Strings** + + A format string consists of a part for color, marker and line:: + + fmt = "[marker][line][color]" + + Each of them is optional. If not provided, the value from the style + cycle is used. Exception: If ``line`` is given, but no ``marker``, + the data will be a line without markers. + + Other combinations such as ``[color][marker][line]`` are also + supported, but note that their parsing may be ambiguous. + + **Markers** + + ============= =============================== + character description + ============= =============================== + ``'.'`` point marker + ``','`` pixel marker + ``'o'`` circle marker + ``'v'`` triangle_down marker + ``'^'`` triangle_up marker + ``'<'`` triangle_left marker + ``'>'`` triangle_right marker + ``'1'`` tri_down marker + ``'2'`` tri_up marker + ``'3'`` tri_left marker + ``'4'`` tri_right marker + ``'8'`` octagon marker + ``'s'`` square marker + ``'p'`` pentagon marker + ``'P'`` plus (filled) marker + ``'*'`` star marker + ``'h'`` hexagon1 marker + ``'H'`` hexagon2 marker + ``'+'`` plus marker + ``'x'`` x marker + ``'X'`` x (filled) marker + ``'D'`` diamond marker + ``'d'`` thin_diamond marker + ``'|'`` vline marker + ``'_'`` hline marker + ============= =============================== + + **Line Styles** + + ============= =============================== + character description + ============= =============================== + ``'-'`` solid line style + ``'--'`` dashed line style + ``'-.'`` dash-dot line style + ``':'`` dotted line style + ============= =============================== + + Example format strings:: + + "b" # blue markers with default shape + + "or" # red circles + "-g" # green solid line + "--" # dashed line with default color + "^k:" # black triangle_up markers connected by a dotted line + + **Colors** + + The supported color abbreviations are the single letter codes + + ============= =============================== + character color + ============= =============================== + ``'b'`` blue + ``'g'`` green + ``'r'`` red + ``'c'`` cyan + ``'m'`` magenta + ``'y'`` yellow + ``'k'`` black + ``'w'`` white + ============= =============================== + + and the ``'CN'`` colors that index into the default property cycle. + + If the color is the only part of the format string, you can + additionally use any `matplotlib.colors` spec, e.g. full names + (``'green'``) or hex strings (``'#008000'``). + """ try: mpl_lines = self.delegate.plot(*args, **kwargs) except Exception as e: @@ -296,26 +528,22 @@ class _AxesProxy(Proxy[MplAxes]): traces: List[LineTrace2D] = [] for mpl_line in mpl_lines: + xdata = mpl_line.get_xdata() + ydata = mpl_line.get_ydata() + thickness = mpl_line.get_linewidth() + linestyle = mpl_line.get_linestyle() + marker = mpl_line.get_marker() + label = mpl_line.get_label() color_list = kwargs.get("color") c = kwargs.get("c") if c is not None and color_list is None: color_list = c - - xdata = mpl_line.get_xdata() - ydata = mpl_line.get_ydata() - print(type(xdata)) + color_list = _convert_matplotlib_color(self, color_list, len(xdata), cmap="viridis", norm="linear")[0] points: List[Point2D] = [] - for x, y in zip(xdata, ydata): points.append(Point2D(x=x, y=y)) - label = mpl_line.get_label() - color_list = _convert_matplotlib_color(self, color_list, len(xdata), cmap="viridis", norm="linear")[0] - thickness = mpl_line.get_linewidth() - linestyle = mpl_line.get_linestyle() - marker = mpl_line.get_marker() - traces.append( LineTrace2D( type="line", @@ -358,6 +586,7 @@ class _AxesProxy(Proxy[MplAxes]): raise try: + verteces = path.get_offsets().tolist() marker = kwargs.get("marker") or "o" color_list = kwargs.get("c") color = kwargs.get("color") @@ -366,8 +595,12 @@ class _AxesProxy(Proxy[MplAxes]): sizes_list = kwargs.get("s") cmap = kwargs.get("cmap") or "viridis" norm = kwargs.get("norm") or "linear" + label = str(path.get_label()) (color_list, cmap_used) = _convert_matplotlib_color(self, color_list, len(x), cmap, norm) + if not cmap_used: + cmap = None + norm = None if sizes_list is not None: sizes_list = path.get_sizes() @@ -376,11 +609,7 @@ class _AxesProxy(Proxy[MplAxes]): if isinstance(sizes_list, np.generic): sizes_list = [sizes_list] * len(x) - label = str(path.get_label()) datapoints: List[Point2D] = [] - - verteces = path.get_offsets().tolist() - for index, (vertex, size) in enumerate(zip(verteces, sizes_list)): color = color_list[index] if len(color_list) > index else None datapoints.append( @@ -391,9 +620,6 @@ class _AxesProxy(Proxy[MplAxes]): size=size, ) ) - if not cmap_used: - cmap = None - norm = None trace: List[ScatterTrace2D] = [] trace.append( ScatterTrace2D(type="scatter", cmap=cmap, norm=norm, label=label, datapoints=datapoints, marker=marker) @@ -430,8 +656,6 @@ class _AxesProxy(Proxy[MplAxes]): conf_intervals = kwargs.get("conf_intervals") labels = kwargs.get("tick_labels") - trace: List[BoxTrace2D] = [] - boxes: List[Box] = [] x = _reshape_2D(x, "x") if not labels: @@ -441,6 +665,8 @@ class _AxesProxy(Proxy[MplAxes]): if not conf_intervals: conf_intervals = itertools.repeat(None) + trace: List[BoxTrace2D] = [] + boxes: List[Box] = [] for dataset, label, umedian, cintervals in zip(x, labels, usermedians, conf_intervals): x = np.ma.asarray(x) x = x.data[~x.mask].ravel() @@ -518,14 +744,14 @@ class _AxesProxy(Proxy[MplAxes]): if yerr is None: yerr = itertools.repeat(None) - print(xerr) - print(xerr.ndim) - if xerr.ndim == 0 or xerr.ndim == 1: xerr = np.broadcast_to(xerr, (2, len(x))) if yerr.ndim == 0 or yerr.ndim == 1: yerr = np.broadcast_to(yerr, (2, len(y))) + color = mcolors.to_hex(color) if color else None + ecolor = mcolors.to_hex(ecolor) if ecolor else None + errorpoints: List[ErrorPoint2D] = [] for xi, yi, x_error, y_error in zip(x, y, xerr.T, yerr.T): errorpoints.append( @@ -536,8 +762,6 @@ class _AxesProxy(Proxy[MplAxes]): yerr=y_error, ) ) - color = mcolors.to_hex(color) if color else None - ecolor = mcolors.to_hex(ecolor) if ecolor else None trace = ErrorBar2DTrace( type="errorbar2d", label=label, @@ -594,12 +818,10 @@ class _AxesProxy(Proxy[MplAxes]): if np.isscalar(x): x = [x] x = _reshape_2D(x, "x") - print(x) color_list = _convert_matplotlib_color(self, color_list, len(x), "viridis", "linear")[0] datasets: List[HistDataset] = [] - for index, (element, label) in enumerate(zip(x, label_list)): color = color_list[index] if len(color_list) > index else None datasets.append(HistDataset(x_i=element, color=color, label=label)) @@ -627,13 +849,6 @@ class _AxesProxy(Proxy[MplAxes]): ) return ret - def _are_lists_same_length(self, *lists) -> bool: - non_empty_lists = [lst for lst in lists if lst] - if not non_empty_lists: - return True - length = len(non_empty_lists[0]) - return all(len(lst) == length for lst in non_empty_lists) - def _on_collect(self) -> None: if self._plot is None: return @@ -697,17 +912,18 @@ class _AxesProxy3D(Proxy[MplAxes3D]): try: sizes_list = kwargs.get("s") marker = kwargs.get("marker") or "o" - color_list = kwargs.get("c") color = kwargs.get("color") if color is not None and color_list is None: color_list = color cmap = kwargs.get("cmap") or "viridis" norm = kwargs.get("norm") or "linear" - (color_list, cmap_used) = _convert_matplotlib_color(self, color_list, len(xs), cmap, norm) + label = str(path.get_label()) - trace: List[ScatterTrace3D] = [] - datapoints: List[Point3D] = [] + (color_list, cmap_used) = _convert_matplotlib_color(self, color_list, len(xs), cmap, norm) + if not cmap_used: + cmap = None + norm = None xs, ys, zs = cbook._broadcast_with_masks(xs, ys, zs) xs, ys, zs, sizes_list, color_list, color = cbook.delete_masked_points( @@ -719,14 +935,12 @@ class _AxesProxy3D(Proxy[MplAxes3D]): if isinstance(sizes_list, (np.generic, float, int)): sizes_list = [sizes_list] * len(xs) + trace: List[ScatterTrace3D] = [] + datapoints: List[Point3D] = [] for index, (xi, yi, zi, s) in enumerate(zip(xs, ys, zs, sizes_list)): c = color_list[index] if len(color_list) > index else None datapoints.append(Point3D(x=xi, y=yi, z=zi, color=c, size=s)) - label = str(path.get_label()) - if not cmap_used: - cmap = None - norm = None trace.append( ScatterTrace3D( type="scatter3D", cmap=cmap, norm=norm, label=label, datapoints=datapoints, marker=marker @@ -763,19 +977,18 @@ class _AxesProxy3D(Proxy[MplAxes3D]): raise try: + mpl_line = path[0] + xdata, ydata, zdata = mpl_line.get_data_3d() + label = mpl_line.get_label() + thickness = mpl_line.get_linewidth() + linestyle = mpl_line.get_linestyle() marker = kwargs.get("marker") or None color_list = kwargs.get("color") c = kwargs.get("c") if c is not None and color_list is None: color_list = c - color_list = _convert_matplotlib_color(self, color_list, len(x_values), "viridis", "linear")[0] - - mpl_line = path[0] - xdata, ydata, zdata = mpl_line.get_data_3d() - label = mpl_line.get_label() - thickness = mpl_line.get_linewidth() - linestyle = mpl_line.get_linestyle() + color_list = _convert_matplotlib_color(self, color_list, len(x_values), "viridis", "linear")[0] datapoints: List[Point3D] = [] for i in range(len(xdata)): @@ -825,6 +1038,12 @@ class _AxesProxy3D(Proxy[MplAxes3D]): raise try: + color = kwargs.get("color") + c = kwargs.get("c") + if c is not None and color is None: + color = c + label = surface.get_label() + length = len(x) width = len(x[0]) @@ -833,13 +1052,6 @@ class _AxesProxy3D(Proxy[MplAxes3D]): traces: List[SurfaceTrace3D] = [] datapoints: List[Point3D] = [] - - color = kwargs.get("color") - c = kwargs.get("c") - if c is not None and color is None: - color = c - label = surface.get_label() - for xi, yi, zi in zip(x, y, z): for xj, yj, zj in zip(xi, yi, zi): datapoints.append( @@ -947,7 +1159,7 @@ class MatplotlibSerializer(Serializer): self, *args: Any, **kwargs: Any, - ) -> Tuple[MplFigure, Union[MplAxes, MplAxes3D, Any]]: + ) -> Tuple[MplFigure, Union[MplAxes, MplAxes3D, _AxesProxy, _AxesProxy3D, Any]]: figure, axes = matplotlib.pyplot.subplots(*args, **kwargs) new_axes: Any diff --git a/plot_serializer/matplotlib/serializer.pyi b/plot_serializer/matplotlib/serializer.pyi index 18498d752cd1444da2b35979686bb4cbd7c59d1e..50bc65b99d4ab21988ff833f7bee52033d3d1954 100644 --- a/plot_serializer/matplotlib/serializer.pyi +++ b/plot_serializer/matplotlib/serializer.pyi @@ -12,7 +12,7 @@ from typing import ( from matplotlib.axes import Axes as MplAxes from matplotlib.figure import Figure as MplFigure -from plot_serializer.serializer import Serializer +from plot_serializer.serializer import Serializer, _AxesProxy class MatplotlibSerializer(Serializer): # Fancy way of properly type hinting the subplots method... @@ -30,7 +30,7 @@ class MatplotlibSerializer(Serializer): subplot_kw: None = None, gridspec_kw: Optional[Dict[str, Any]] = None, **fig_kw: Any, - ) -> Tuple[MplFigure, MplAxes]: ... + ) -> Tuple[MplFigure, _AxesProxy]: ... @overload def subplots( self, @@ -45,5 +45,5 @@ class MatplotlibSerializer(Serializer): subplot_kw: Optional[Dict[str, Any]] = None, gridspec_kw: Optional[Dict[str, Any]] = None, **fig_kw: Any, - ) -> Tuple[MplFigure, Any]: ... + ) -> Tuple[MplFigure, _AxesProxy]: ... def show(self, *, block: Optional[bool] = None) -> None: ...