import itertools
import logging
from typing import (
    Any,
    List,
    Optional,
    Tuple,
    Union,
)

import matplotlib.cbook as cbook
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.pyplot
import numpy as np
from matplotlib.axes import Axes as MplAxes
from matplotlib.cbook import _reshape_2D
from matplotlib.collections import PathCollection
from matplotlib.container import BarContainer, ErrorbarContainer
from matplotlib.figure import Figure as MplFigure
from matplotlib.lines import Line2D
from matplotlib.patches import Polygon
from mpl_toolkits.mplot3d.art3d import Path3DCollection, Poly3DCollection
from mpl_toolkits.mplot3d.axes3d import Axes3D as MplAxes3D
from numpy import ndarray

from plot_serializer.model import (
    Axis,
    Bar2D,
    BarTrace2D,
    Box,
    BoxTrace2D,
    ErrorBar2DTrace,
    ErrorPoint2D,
    Figure,
    HistDataset,
    HistogramTrace,
    LineTrace2D,
    LineTrace3D,
    PiePlot,
    Plot,
    Plot2D,
    Plot3D,
    Point2D,
    Point3D,
    ScatterTrace2D,
    ScatterTrace3D,
    Slice,
    SurfaceTrace3D,
)
from plot_serializer.proxy import Proxy
from plot_serializer.serializer import Serializer

__all__ = ["MatplotlibSerializer"]

# Names of matplotlib plotting methods. Accessing one of these on an axes proxy that does not
# implement it logs a warning (see the proxies' __getattr__), because its data is not serialized.
PLOTTING_METHODS = [
    "plot",
    "errorbar",
    "hist",
    "scatter",
    "step",
    "loglog",
    "semilogx",
    "semilogy",
    "bar",
    "barh",
    "stem",
    "eventplot",
    "pie",
    "stackplot",
    "broken_barh",
    "fill",
    "acorr",
    "angle_spectrum",
    "cohere",
    "csd",
    "magnitude_spectrum",
    "phase_spectrum",
    "psd",
    "specgram",
    "xcorr",
    "ecdf",
    "boxplot",
    "violinplot",
    "bxp",
    "violin",
    "hexbin",
    "hist2d",
    "contour",
    "contourf",
    "imshow",
    "matshow",
    "pcolor",
    "pcolorfast",
    "pcolormesh",
    "spy",
    "tripcolor",
    "triplot",
    "tricontour",
    "tricontourf",
]

# Appended to the message of every exception raised by matplotlib itself.
_MATPLOTLIB_ERROR_NOTE = " - This error was thrown by Matplotlib and is independent of PlotSerializer!"


def _annotate_matplotlib_error(error: Exception) -> None:
    """Append a note to *error*'s first argument saying that matplotlib raised it.

    The exception is modified in place; the caller is expected to re-raise it. The first
    argument is converted with ``str`` so that exceptions whose first argument is not a string
    (e.g. ``KeyError(3)``) do not trigger a secondary ``TypeError`` that would hide the original.
    """
    if error.args:
        error.args = (str(error.args[0]) + _MATPLOTLIB_ERROR_NOTE,) + error.args[1:]
    else:
        error.args = (_MATPLOTLIB_ERROR_NOTE,)


def _log_serialization_failure(error: Exception) -> None:
    """Log (without raising) that PlotSerializer failed to read plot data."""
    logging.warning(
        "An unexpected error occurred in PlotSerializer when trying to read plot data! "
        + "Parts of the plot will not be serialized!",
        exc_info=error,
    )


def _is_missing(value: Any) -> bool:
    """Return True for ``None`` or an empty sized value (``[]``, ``""``, empty array).

    Unlike ``not value`` this never raises for multi-element NumPy arrays.
    """
    if value is None:
        return True
    try:
        return len(value) == 0
    except TypeError:
        # Unsized values (numbers, 0-d arrays) are never "missing".
        return False


def _repeat_none_if_empty(values: Any) -> Any:
    """Return *values*, or an endless iterator of ``None`` when *values* is missing or empty.

    Used so optional per-item kwargs (labels, explode, ...) can always be zipped with the data.
    """
    return itertools.repeat(None) if _is_missing(values) else values


def _get_color_kwarg(kwargs: dict, name: str, alias: str) -> Any:
    """Return ``kwargs[name]``, falling back to ``kwargs[alias]`` when *name* is absent or None."""
    value = kwargs.get(name)
    return kwargs.get(alias) if value is None else value


def _upcast_err(err: Any) -> Any:
    """Convert an errorbar ``xerr``/``yerr`` spec to an array-like.

    Adapted from the local helper of the same name in matplotlib's ``Axes.errorbar``.
    """
    if np.iterable(err) and len(err) > 0 and isinstance(cbook._safe_first_finite(err), np.ndarray):
        atype = type(cbook._safe_first_finite(err))
        if atype is np.ndarray:
            return np.asarray(err, dtype=object)
        return atype(err)
    return np.asarray(err)


def _errors_per_point(err: Any, n: int) -> Any:
    """Return an iterable of per-point ``(lower, upper)`` errors, or repeated ``None``.

    A scalar, a shape ``(N,)`` or a shape ``(2, N)`` error spec is broadcast to ``(2, N)``
    (as ``Axes.errorbar`` does) and transposed, so each yielded row holds two values.

    Raises:
        ValueError: *err* cannot be broadcast to ``(2, n)``.
    """
    if err is None:
        return itertools.repeat(None)
    if not isinstance(err, np.ndarray):
        err = _upcast_err(err)
    return np.broadcast_to(err, (2, n)).T


def _convert_matplotlib_color(
    self: Any, color_list: Any, length: int, cmap: Any, norm: Any
) -> Tuple[List[Optional[str]], bool]:
    """Convert a matplotlib color specification into a list of hex color strings.

    ``self`` is unused; it is kept so existing ``_convert_matplotlib_color(self, ...)`` calls work.

    Args:
        color_list: One of: ``None`` or an empty value (no color given); a color string; a
            single number (mapped through *cmap*/*norm*); an RGB(A) tuple; a sequence of numbers
            (each mapped through *cmap*/*norm*); or a sequence whose items are color strings,
            RGB(A) tuples/lists, or ``None``. NumPy scalars and arrays are accepted as well.
        length: Number of data points the colors belong to. A single color is repeated to
            this length.
        cmap: Colormap passed to ``matplotlib.cm.ScalarMappable`` for numeric colors.
        norm: Normalization passed to ``matplotlib.cm.ScalarMappable`` for numeric colors.

    Returns:
        ``(colors, cmap_used)``: ``colors`` holds ``#rrggbbaa`` strings or ``None`` entries;
        ``cmap_used`` is True when numbers were mapped through the colormap. When no color is
        given, ``([None], False)`` is returned without repeating it to *length*.

    Raises:
        NotImplementedError: The color specification is not supported.
        ValueError: The number of colors is neither 1 nor *length*.
    """
    cmap_used = False

    # Unwrap NumPy values *before* dispatching on the type, so NumPy scalars and arrays take
    # the same branches as the equivalent Python numbers and lists.
    if isinstance(color_list, np.generic):
        color_list = color_list.item()
    elif isinstance(color_list, np.ndarray):
        color_list = color_list.tolist()

    if _is_missing(color_list):
        return ([None], cmap_used)

    colors: List[Optional[str]]
    if isinstance(color_list, str):
        colors = [mcolors.to_hex(color_list, keep_alpha=True)]
    elif isinstance(color_list, (int, float)) and not isinstance(color_list, bool):
        # A single number is a value to be mapped through the colormap.
        rgba = cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba(color_list)
        colors = [mcolors.to_hex(rgba, keep_alpha=True)]
        cmap_used = True
    elif isinstance(color_list, tuple) and len(color_list) in (3, 4):
        # A single RGB(A) tuple.
        colors = [mcolors.to_hex(color_list, keep_alpha=True)]
    elif isinstance(color_list, (list, tuple)) and all(isinstance(item, (int, float)) for item in color_list):
        # One number per data point, each mapped through the colormap.
        rgba_values = cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba(color_list)
        colors = [mcolors.to_hex(rgba, keep_alpha=True) for rgba in rgba_values]
        cmap_used = True
    elif isinstance(color_list, (list, tuple)):
        colors = []
        for item in color_list:
            # RGB(A) items may be lists, e.g. rows of an (N, 4) array after .tolist().
            if isinstance(item, str) or (isinstance(item, (tuple, list)) and len(item) in (3, 4)):
                colors.append(mcolors.to_hex(item, keep_alpha=True))
            elif item is None:
                colors.append(None)
            else:
                raise NotImplementedError(
                    "Your color is not supported by PlotSerializer, see Documentation for more detail"
                )
    else:
        raise NotImplementedError("Your color is not supported by PlotSerializer, see Documentation for more detail")

    if len(colors) != length:
        if len(colors) != 1:
            raise ValueError("the length of your color array does not match the length of given data")
        # A single color applies to every data point.
        colors = colors * length
    return (colors, cmap_used)


def _convert_cycled_colors(self: Any, color_spec: Any, length: int) -> List[Optional[str]]:
    """Convert *color_spec* to *length* hex colors, cycling a color sequence like matplotlib.

    ``Axes.bar`` and ``Axes.pie`` cycle through a color sequence whose length differs from the
    number of bars/wedges. A non-sequence spec is converted by ``_convert_matplotlib_color``.
    The viridis/linear colormap defaults match the other non-scatter plot methods.
    """
    is_sequence = isinstance(color_spec, (list, tuple)) or (
        isinstance(color_spec, np.ndarray) and color_spec.ndim > 0
    )
    if is_sequence and len(color_spec) > 0:
        colors = _convert_matplotlib_color(self, color_spec, len(color_spec), cmap="viridis", norm="linear")[0]
        return list(itertools.islice(itertools.cycle(colors), length))
    return _convert_matplotlib_color(self, color_spec, length, cmap="viridis", norm="linear")[0]


class _AxesProxy(Proxy[MplAxes]):
    """Proxy around a 2D matplotlib ``Axes`` that records supported plot calls.

    Each supported plotting method forwards the call to matplotlib first; errors raised there
    are annotated and re-raised. The call is then translated into PlotSerializer model objects.
    If that translation fails, a warning is logged and matplotlib's result is still returned.
    """

    def __init__(self, delegate: MplAxes, figure: Figure, serializer: Serializer) -> None:
        super().__init__(delegate)
        self._figure = figure
        self._serializer = serializer
        self._plot: Optional[Plot] = None

    def _add_traces(self, traces: List[Any]) -> None:
        """Append *traces* to this axes' ``Plot2D``, creating the plot on first use.

        Raises:
            NotImplementedError: The axes already holds a non-2D plot (e.g. a pie chart).
        """
        if self._plot is None:
            self._plot = Plot2D(type="2d", x_axis=Axis(), y_axis=Axis(), traces=traces)
            return
        if not isinstance(self._plot, Plot2D):
            raise NotImplementedError("PlotSerializer does not yet support mixing 2d plots with other plots!")
        self._plot.traces.extend(traces)

    def pie(self, x, **kwargs: Any) -> Any:
        """Draw a pie chart via ``Axes.pie`` and record it as a ``PiePlot``.

        A pie chart must be the only plot on its axes; otherwise it is not serialized.
        Returns whatever ``Axes.pie`` returns.
        """
        try:
            result = self.delegate.pie(x, **kwargs)
        except Exception as e:
            _annotate_matplotlib_error(e)
            raise

        try:
            if self._plot is not None:
                raise NotImplementedError("PlotSerializer does not yet support adding multiple plots per axes!")

            x = np.asarray(x)
            radius = kwargs.get("radius") or None
            colors = _convert_cycled_colors(self, _get_color_kwarg(kwargs, "colors", "c"), len(x))
            labels = _repeat_none_if_empty(kwargs.get("labels"))
            explodes = _repeat_none_if_empty(kwargs.get("explode"))

            slices: List[Slice] = [
                Slice(
                    x=xi,
                    radius=radius,
                    explode=explode,
                    label=label,
                    color=colors[index] if index < len(colors) else None,
                )
                for index, (xi, label, explode) in enumerate(zip(x, labels, explodes))
            ]
            self._plot = PiePlot(type="pie", radius=radius, slices=slices)
        except Exception as e:
            _log_serialization_failure(e)
        return result

    def bar(
        self,
        x,
        height,
        **kwargs: Any,
    ) -> BarContainer:
        """Draw bars via ``Axes.bar`` and record them as a ``BarTrace2D``.

        ``color`` (or ``c``) is recorded per bar; a color sequence is cycled like matplotlib does.
        """
        try:
            result = self.delegate.bar(x, height, **kwargs)
        except Exception as e:
            _annotate_matplotlib_error(e)
            raise

        try:
            # matplotlib broadcasts x and height against each other (e.g. one height for all
            # bars); mirror that so zip() below does not drop bars.
            x_values, heights = np.broadcast_arrays(np.atleast_1d(x), np.atleast_1d(height))
            colors = _convert_cycled_colors(self, _get_color_kwarg(kwargs, "color", "c"), len(x_values))

            bars: List[Bar2D] = [
                Bar2D(x_i=xi, height=h, color=colors[index] if index < len(colors) else None)
                for index, (xi, h) in enumerate(zip(x_values, heights))
            ]
            self._add_traces([BarTrace2D(type="bar", datapoints=bars)])
        except Exception as e:
            _log_serialization_failure(e)
        return result

    def plot(self, *args: Any, **kwargs: Any) -> list[Line2D]:
        """Draw lines via ``Axes.plot`` and record one ``LineTrace2D`` per created line.

        Args:
            *args: Positional arguments forwarded to ``Axes.plot``.
            **kwargs: Keyword arguments forwarded to ``Axes.plot``. ``color`` (or ``c``) is
                recorded as the trace color; width, style, marker and label are read from
                the created lines.

        Returns:
            The ``Line2D`` objects created by matplotlib.
        """
        try:
            mpl_lines = self.delegate.plot(*args, **kwargs)
        except Exception as e:
            _annotate_matplotlib_error(e)
            raise

        try:
            color_spec = _get_color_kwarg(kwargs, "color", "c")
            traces: List[LineTrace2D] = []
            for mpl_line in mpl_lines:
                xdata = mpl_line.get_xdata()
                ydata = mpl_line.get_ydata()
                colors = _convert_matplotlib_color(self, color_spec, len(xdata), cmap="viridis", norm="linear")[0]
                traces.append(
                    LineTrace2D(
                        type="line",
                        # An empty line gets an empty color list; don't index into it.
                        color=colors[0] if colors else None,
                        linewidth=mpl_line.get_linewidth(),
                        linestyle=mpl_line.get_linestyle(),
                        label=mpl_line.get_label(),
                        datapoints=[Point2D(x=xi, y=yi) for xi, yi in zip(xdata, ydata)],
                        marker=mpl_line.get_marker(),
                    )
                )
            self._add_traces(traces)
        except Exception as e:
            _log_serialization_failure(e)
        return mpl_lines

    def scatter(
        self,
        x,
        y,
        *args: Any,
        **kwargs: Any,
    ) -> PathCollection:
        """Draw a scatter plot via ``Axes.scatter`` and record it as a ``ScatterTrace2D``.

        ``s`` and ``c`` are read from keyword arguments or from the first two positional
        arguments after ``x``/``y``; ``color`` is used when ``c`` is not given. ``cmap``/``norm``
        are only recorded when numeric colors were mapped through the colormap.
        """
        try:
            path = self.delegate.scatter(x, y, *args, **kwargs)
        except Exception as e:
            _annotate_matplotlib_error(e)
            raise

        try:
            marker = kwargs.get("marker") or "o"
            sizes_arg = kwargs["s"] if "s" in kwargs else (args[0] if len(args) > 0 else None)
            color_arg = kwargs["c"] if "c" in kwargs else (args[1] if len(args) > 1 else None)
            if color_arg is None:
                color_arg = kwargs.get("color")
            cmap = kwargs.get("cmap") or "viridis"
            norm = kwargs.get("norm") or "linear"

            # One offset per input point (matplotlib flattens x/y), so this also handles 2-D input.
            vertices = path.get_offsets().tolist()
            colors, cmap_used = _convert_matplotlib_color(self, color_arg, len(vertices), cmap, norm)

            sizes: Any
            if sizes_arg is None:
                sizes = itertools.repeat(None)
            else:
                sizes = path.get_sizes()
                # A scalar ``s`` yields a single size that applies to every point.
                if len(sizes) == 1:
                    sizes = itertools.repeat(sizes[0])

            datapoints: List[Point2D] = [
                Point2D(
                    x=vertex[0],
                    y=vertex[1],
                    color=colors[index] if index < len(colors) else None,
                    size=size,
                )
                for index, (vertex, size) in enumerate(zip(vertices, sizes))
            ]

            if not cmap_used:
                cmap = None
                norm = None

            self._add_traces(
                [
                    ScatterTrace2D(
                        type="scatter",
                        cmap=cmap,
                        norm=norm,
                        label=str(path.get_label()),
                        datapoints=datapoints,
                        marker=marker,
                    )
                ]
            )
        except Exception as e:
            _log_serialization_failure(e)
        return path

    def boxplot(self, x, *args, **kwargs) -> dict:
        """Draw box plots via ``Axes.boxplot`` and record them as a ``BoxTrace2D``.

        Tick labels are read from ``tick_labels`` or, failing that, from ``labels`` (its name
        before matplotlib 3.9). Masked entries are removed from each dataset.
        """
        try:
            dic = self.delegate.boxplot(x, *args, **kwargs)
        except Exception as e:
            _annotate_matplotlib_error(e)
            raise

        try:
            notch = kwargs.get("notch") or None
            whis = kwargs.get("whis") or None
            bootstrap = kwargs.get("bootstrap")
            tick_labels = kwargs.get("tick_labels")
            if tick_labels is None:
                tick_labels = kwargs.get("labels")
            labels = _repeat_none_if_empty(tick_labels)
            usermedians = _repeat_none_if_empty(kwargs.get("usermedians"))
            conf_intervals = _repeat_none_if_empty(kwargs.get("conf_intervals"))

            boxes: List[Box] = []
            for dataset, label, usermedian, conf_interval in zip(
                _reshape_2D(x, "x"), labels, usermedians, conf_intervals
            ):
                # Drop masked entries per dataset, the way matplotlib's boxplot_stats does.
                values = np.ma.asarray(dataset)
                values = values.data[~np.ma.getmaskarray(values)].ravel()
                boxes.append(
                    Box(
                        x_i=values,
                        tick_label=label,
                        usermedian=usermedian,
                        conf_interval=conf_interval,
                    )
                )

            self._add_traces([BoxTrace2D(type="box", x=boxes, notch=notch, whis=whis, bootstrap=bootstrap)])
        except Exception as e:
            _log_serialization_failure(e)
        return dic

    def errorbar(self, x, y, *args, **kwargs) -> ErrorbarContainer:
        """Draw error bars via ``Axes.errorbar`` and record them as an ``ErrorBar2DTrace``.

        ``yerr``/``xerr`` are read from keyword arguments or from the first two positional
        arguments after ``x``/``y`` (matching ``Axes.errorbar``'s parameter order).
        """
        try:
            container = self.delegate.errorbar(x, y, *args, **kwargs)
        except Exception as e:
            _annotate_matplotlib_error(e)
            raise

        try:
            yerr = kwargs.get("yerr", args[0] if len(args) > 0 else None)
            xerr = kwargs.get("xerr", args[1] if len(args) > 1 else None)
            marker = kwargs.get("marker") or None
            color = _get_color_kwarg(kwargs, "color", "c")
            ecolor = kwargs.get("ecolor")
            label = kwargs.get("label") or None

            if not isinstance(x, np.ndarray):
                x = np.asarray(x, dtype=object)
            if not isinstance(y, np.ndarray):
                y = np.asarray(y, dtype=object)
            x, y = np.atleast_1d(x, y)

            errorpoints: List[ErrorPoint2D] = [
                ErrorPoint2D(x=xi, y=yi, xerr=x_error, yerr=y_error)
                for xi, yi, x_error, y_error in zip(x, y, _errors_per_point(xerr, len(x)), _errors_per_point(yerr, len(y)))
            ]

            trace = ErrorBar2DTrace(
                type="errorbar2d",
                label=label,
                marker=marker,
                datapoints=errorpoints,
                color=mcolors.to_hex(color) if color is not None else None,
                ecolor=mcolors.to_hex(ecolor) if ecolor is not None else None,
            )
            self._add_traces([trace])
        except Exception as e:
            _log_serialization_failure(e)
        return container

    def hist(
        self, x, *args, **kwargs
    ) -> tuple[
        ndarray | list[ndarray],
        ndarray,
        BarContainer | Polygon | list[BarContainer | Polygon],
    ]:
        """Draw histograms via ``Axes.hist`` and record them as a ``HistogramTrace``.

        ``bins`` is read from keyword arguments or the first positional argument after ``x``,
        defaulting to 10. One ``HistDataset`` is recorded per input dataset.
        """
        try:
            ret = self.delegate.hist(x, *args, **kwargs)
        except Exception as e:
            _annotate_matplotlib_error(e)
            raise

        try:
            bins = kwargs.get("bins", args[0] if len(args) > 0 else 10)
            density = kwargs.get("density", False)
            cumulative = kwargs.get("cumulative", False)

            label_arg = kwargs.get("label")
            labels = itertools.repeat(None) if _is_missing(label_arg) else np.atleast_1d(np.asarray(label_arg, str))

            if np.isscalar(x):
                x = [x]
            datasets_raw = _reshape_2D(x, "x")
            colors = _convert_matplotlib_color(
                self, _get_color_kwarg(kwargs, "color", "c"), len(datasets_raw), "viridis", "linear"
            )[0]

            datasets: List[HistDataset] = [
                HistDataset(x_i=element, color=colors[index] if index < len(colors) else None, label=label)
                for index, (element, label) in enumerate(zip(datasets_raw, labels))
            ]

            self._add_traces(
                [
                    HistogramTrace(
                        type="histogram",
                        x=datasets,
                        bins=bins,
                        density=density,
                        cumulative=cumulative,
                    )
                ]
            )
        except Exception as e:
            _log_serialization_failure(e)
        return ret

    def _are_lists_same_length(self, *lists) -> bool:
        """Return True if all given sequences that are not None/empty have the same length."""
        lengths = {len(lst) for lst in lists if not _is_missing(lst)}
        return len(lengths) <= 1

    def _on_collect(self) -> None:
        """Copy title, hidden spines and axis settings into the plot and add it to the figure.

        Registered as a collect action by ``MatplotlibSerializer._create_axes_proxy``. Axis
        limits are only recorded when autoscaling is off for that axis.
        """
        if self._plot is None:
            return

        self._plot.title = self.delegate.get_title()

        if isinstance(self._plot, Plot2D):
            for spine_name in self.delegate.spines:
                if self.delegate.spines[spine_name].get_visible():
                    continue
                if self._plot.spines_removed:
                    self._plot.spines_removed.append(spine_name)
                else:
                    self._plot.spines_removed = [spine_name]

            self._plot.x_axis.label = self.delegate.get_xlabel()
            self._plot.x_axis.scale = self.delegate.get_xscale()
            if not self.delegate.get_autoscalex_on():
                self._plot.x_axis.limit = self.delegate.get_xlim()

            self._plot.y_axis.label = self.delegate.get_ylabel()
            self._plot.y_axis.scale = self.delegate.get_yscale()
            if not self.delegate.get_autoscaley_on():
                self._plot.y_axis.limit = self.delegate.get_ylim()

        self._figure.plots.append(self._plot)

    def __getattr__(self, __name: str) -> Any:
        """Forward unknown attributes, warning when an unsupported plotting method is used."""
        if __name in PLOTTING_METHODS:
            logging.warning(f"{__name} is not supported by PlotSerializer! Data will be lost!")
        return super().__getattr__(__name)


class _AxesProxy3D(Proxy[MplAxes3D]):
    """Proxy around a matplotlib ``Axes3D`` that records supported 3D plot calls.

    Works like the 2D proxy: matplotlib is called first, then the call is translated into
    PlotSerializer model objects; translation failures are logged, not raised.
    """

    def __init__(self, delegate: MplAxes3D, figure: Figure, serializer: Serializer) -> None:
        super().__init__(delegate)
        self._figure = figure
        self._serializer = serializer
        self._plot: Optional[Plot] = None

    def _add_traces(self, traces: List[Any]) -> None:
        """Append *traces* to this axes' ``Plot3D``, creating the plot on first use.

        Raises:
            NotImplementedError: The axes already holds a non-3D plot.
        """
        if self._plot is None:
            self._plot = Plot3D(type="3d", x_axis=Axis(), y_axis=Axis(), z_axis=Axis(), traces=traces)
            return
        if not isinstance(self._plot, Plot3D):
            raise NotImplementedError("PlotSerializer does not yet support mixing 3d plots with other plots!")
        self._plot.traces.extend(traces)

    def scatter(
        self,
        xs,
        ys,
        zs,
        *args: Any,
        **kwargs: Any,
    ) -> Path3DCollection:
        """Draw a 3D scatter plot via ``Axes3D.scatter`` and record it as a ``ScatterTrace3D``.

        Masked/non-finite points are dropped with ``cbook.delete_masked_points`` (keeping sizes
        and colors aligned), as ``Axes3D.scatter`` does.
        """
        try:
            path = self.delegate.scatter(xs, ys, zs, *args, **kwargs)
        except Exception as e:
            _annotate_matplotlib_error(e)
            raise

        try:
            sizes = kwargs.get("s")
            marker = kwargs.get("marker") or "o"
            color_arg = _get_color_kwarg(kwargs, "c", "color")
            cmap = kwargs.get("cmap") or "viridis"
            norm = kwargs.get("norm") or "linear"

            if isinstance(xs, np.generic):
                xs = xs.item()
            if isinstance(xs, (float, int, str)):
                xs = [xs]
            colors, cmap_used = _convert_matplotlib_color(self, color_arg, len(xs), cmap, norm)

            xs, ys, zs = cbook._broadcast_with_masks(xs, ys, zs)
            xs, ys, zs, sizes, colors, _ = cbook.delete_masked_points(
                xs, ys, zs, sizes, colors, kwargs.get("color", None)
            )

            if sizes is None:
                sizes = itertools.repeat(None)
            elif np.ndim(sizes) == 0:
                # A scalar size applies to every point.
                sizes = itertools.repeat(sizes)
            elif len(sizes) == 1:
                # A one-element size sequence is not aligned by delete_masked_points; broadcast it.
                sizes = itertools.repeat(sizes[0])

            datapoints: List[Point3D] = [
                Point3D(
                    x=xi,
                    y=yi,
                    z=zi,
                    color=colors[index] if index < len(colors) else None,
                    size=size,
                )
                for index, (xi, yi, zi, size) in enumerate(zip(xs, ys, zs, sizes))
            ]

            if not cmap_used:
                cmap = None
                norm = None

            self._add_traces(
                [
                    ScatterTrace3D(
                        type="scatter3D",
                        cmap=cmap,
                        norm=norm,
                        label=str(path.get_label()),
                        datapoints=datapoints,
                        marker=marker,
                    )
                ]
            )
        except Exception as e:
            _log_serialization_failure(e)
        return path

    def plot(
        self,
        x_values,
        y_values,
        *args: Any,
        **kwargs: Any,
    ) -> list[Line2D]:
        """Draw a 3D line via ``Axes3D.plot`` and record the first created line as ``LineTrace3D``.

        Returns:
            The ``Line2D`` objects created by matplotlib.
        """
        try:
            lines = self.delegate.plot(x_values, y_values, *args, **kwargs)
        except Exception as e:
            _annotate_matplotlib_error(e)
            raise

        try:
            marker = kwargs.get("marker") or None
            mpl_line = lines[0]
            xdata, ydata, zdata = mpl_line.get_data_3d()
            colors = _convert_matplotlib_color(
                self, _get_color_kwarg(kwargs, "color", "c"), len(xdata), "viridis", "linear"
            )[0]

            self._add_traces(
                [
                    LineTrace3D(
                        type="line3D",
                        # An empty line gets an empty color list; don't index into it.
                        color=colors[0] if colors else None,
                        linewidth=mpl_line.get_linewidth(),
                        linestyle=mpl_line.get_linestyle(),
                        label=mpl_line.get_label(),
                        datapoints=[Point3D(x=xi, y=yi, z=zi) for xi, yi, zi in zip(xdata, ydata, zdata)],
                        marker=marker,
                    )
                ]
            )
        except Exception as e:
            _log_serialization_failure(e)
        return lines

    def plot_surface(
        self,
        x,
        y,
        z,
        *args: Any,
        **kwargs: Any,
    ) -> Poly3DCollection:
        """Draw a surface via ``Axes3D.plot_surface`` and record it as a ``SurfaceTrace3D``.

        ``x``, ``y`` and ``z`` are broadcast together (as matplotlib does); ``length``/``width``
        are the resulting 2-D shape and points are recorded in row-major order.
        """
        try:
            surface = self.delegate.plot_surface(x, y, z, *args, **kwargs)
        except Exception as e:
            _annotate_matplotlib_error(e)
            raise

        try:
            z = cbook._to_unmasked_float_array(z)
            x, y, z = np.broadcast_arrays(x, y, z)
            length, width = z.shape

            color = _get_color_kwarg(kwargs, "color", "c")
            # Store hex like every other trace instead of the raw user-supplied value.
            hex_color = mcolors.to_hex(color, keep_alpha=True) if color is not None else None

            datapoints: List[Point3D] = [
                Point3D(x=xj, y=yj, z=zj, color=hex_color) for xj, yj, zj in zip(x.ravel(), y.ravel(), z.ravel())
            ]

            self._add_traces(
                [
                    SurfaceTrace3D(
                        type="surface3D",
                        length=length,
                        width=width,
                        label=surface.get_label(),
                        datapoints=datapoints,
                    )
                ]
            )
        except Exception as e:
            _log_serialization_failure(e)
        return surface

    def _on_collect(self) -> None:
        """Copy title and x/y/z axis settings into the plot and add it to the figure.

        Registered as a collect action by ``MatplotlibSerializer._create_axes_proxy``. Axis
        limits are only recorded when autoscaling is off for that axis.
        """
        if self._plot is None:
            return

        self._plot.title = self.delegate.get_title()

        if isinstance(self._plot, Plot3D):
            self._plot.x_axis.label = self.delegate.get_xlabel()
            self._plot.x_axis.scale = self.delegate.get_xscale()
            if not self.delegate.get_autoscalex_on():
                self._plot.x_axis.limit = self.delegate.get_xlim()

            self._plot.y_axis.label = self.delegate.get_ylabel()
            self._plot.y_axis.scale = self.delegate.get_yscale()
            if not self.delegate.get_autoscaley_on():
                self._plot.y_axis.limit = self.delegate.get_ylim()

            self._plot.z_axis.label = self.delegate.get_zlabel()
            self._plot.z_axis.scale = self.delegate.get_zscale()
            if not self.delegate.get_autoscalez_on():
                self._plot.z_axis.limit = self.delegate.get_zlim()

        self._figure.plots.append(self._plot)

    def __getattr__(self, __name: str) -> Any:
        """Forward unknown attributes, warning when an unsupported plotting method is used."""
        if __name in PLOTTING_METHODS:
            logging.warning(f"{__name} is not supported by PlotSerializer, the Data will not be saved!")
        return super().__getattr__(__name)


class MatplotlibSerializer(Serializer):
    """
    Serializer specific to matplotlib. Most of the methods on this object mirror the
    matplotlib.pyplot API.

    Args:
        Serializer (_type_): Parent class
    """

    def _create_axes_proxy(self, mpl_axes: Union[MplAxes3D, MplAxes]) -> Union[_AxesProxy, _AxesProxy3D]:
        """Wrap *mpl_axes* in the matching proxy and register its collect action.

        ``Axes3D`` is checked first because it is a subclass of ``Axes``.

        Raises:
            NotImplementedError: *mpl_axes* is neither a 2D nor a 3D matplotlib axes.
        """
        proxy: Any
        if isinstance(mpl_axes, MplAxes3D):
            proxy = _AxesProxy3D(mpl_axes, self._figure, self)
        elif isinstance(mpl_axes, MplAxes):
            proxy = _AxesProxy(mpl_axes, self._figure, self)
        else:
            raise NotImplementedError("The matplotlib adapter only supports plots on 3D and normal axes")
        self._add_collect_action(lambda: proxy._on_collect())
        return proxy

    def subplots(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[MplFigure, Union[MplAxes, MplAxes3D, Any]]:
        """Create a figure and axes like ``matplotlib.pyplot.subplots``, wrapping each axes in a proxy.

        Returns:
            ``(figure, axes)`` where ``axes`` mirrors matplotlib's return value: a single proxy,
            or an object array of proxies with the same shape as matplotlib's axes array.
        """
        figure, axes = matplotlib.pyplot.subplots(*args, **kwargs)
        if not isinstance(axes, np.ndarray):
            return (figure, self._create_axes_proxy(axes))

        # Fill an object array explicitly so NumPy does not try to interpret the proxies as
        # nested sequences or array-likes; iteration is row-major like the original mapping.
        new_axes = np.empty(axes.shape, dtype=object)
        for index, mpl_axes in np.ndenumerate(axes):
            new_axes[index] = self._create_axes_proxy(mpl_axes)
        return (figure, new_axes)

    def show(self, *args: Any, **kwargs: Any) -> None:
        """Forward to ``matplotlib.pyplot.show``."""
        matplotlib.pyplot.show(*args, **kwargs)