Source code for plot_serializer.matplotlib.serializer

from typing import (
    Any,
    Tuple,
    Union,
)

import matplotlib.pyplot
import numpy as np
from matplotlib.axes import Axes as MplAxes
from matplotlib.figure import Figure as MplFigure
from mpl_toolkits.mplot3d.axes3d import Axes3D as MplAxes3D

from plot_serializer.matplotlib.axesproxy import AxesProxy, AxesProxy3D
from plot_serializer.serializer import Serializer


[docs] class MatplotlibSerializer(Serializer): """ Serializer specific to matplotlib. Most of the methods on this object mirror the matplotlib.pyplot api from matplotlib. Args: Serializer (_type_): Parent class """ def _create_axes_proxy(self, mpl_axes: Union[MplAxes3D, MplAxes]) -> Union[AxesProxy, AxesProxy3D]: proxy: Any if isinstance(mpl_axes, MplAxes3D): proxy = AxesProxy3D(mpl_axes, self._figure, self) self._add_collect_action(lambda: proxy._on_collect()) elif isinstance(mpl_axes, MplAxes): proxy = AxesProxy(mpl_axes, self._figure, self) self._add_collect_action(lambda: proxy._on_collect()) else: raise NotImplementedError("The matplotlib adapter only supports plots on 3D and normal axes") return proxy def subplots( self, *args: Any, **kwargs: Any, ) -> Tuple[MplFigure, Union[MplAxes, MplAxes3D, Any]]: figure, axes = matplotlib.pyplot.subplots(*args, **kwargs) new_axes: Any if isinstance(axes, np.ndarray): if isinstance(axes[0], np.ndarray): new_axes = np.array([list(map(self._create_axes_proxy, row)) for row in axes]) else: new_axes = np.array(list(map(self._create_axes_proxy, axes))) else: new_axes = self._create_axes_proxy(axes) return (figure, new_axes) def show(self, *args: Any, **kwargs: Any) -> None: matplotlib.pyplot.show(*args, **kwargs)