diff --git a/plot_serializer/matplotlib/serializer.py b/plot_serializer/matplotlib/serializer.py
index b895c417a97a08a64240c499bfbbc6127c9458dc..6876f41de7d18aabb65dfe452b5b53b07a612411 100644
--- a/plot_serializer/matplotlib/serializer.py
+++ b/plot_serializer/matplotlib/serializer.py
@@ -101,6 +101,14 @@ PLOTTING_METHODS = [
 ]
 
 
+def inherit_and_extend_doc(base_class, method_name, additional_doc):
+    def decorator(func):
+        func.__doc__ = getattr(base_class, method_name).__doc__ + additional_doc
+        return func
+
+    return decorator
+
+
 def _convert_matplotlib_color(
     self, color_list: Any, length: int, cmap: Any, norm: Any
 ) -> Tuple[List[str] | None, bool]:
@@ -157,6 +165,7 @@ class _AxesProxy(Proxy[MplAxes]):
         self._serializer = serializer
         self._plot: Optional[Plot] = None
 
+    @inherit_and_extend_doc(MplAxes, "plot", "\n\n Serialized parameters: x, y, color, marker, label. \n\n")
     def pie(self, x, **kwargs: Any) -> Any:
         try:
             result = self.delegate.pie(x, **kwargs)
@@ -169,24 +178,22 @@ class _AxesProxy(Proxy[MplAxes]):
             if self._plot is not None:
                 raise NotImplementedError("PlotSerializer does not yet support adding multiple plots per axes!")
 
-            slices: List[Slice] = []
-
             explode_list = kwargs.get("explode")
             label_list = kwargs.get("labels")
             radius = kwargs.get("radius") or None
-
             color_list = kwargs.get("colors")
             c = kwargs.get("c")
-            if c is not None and color_list is None:
-                color_list = c
-            color_list = _convert_matplotlib_color(self, color_list, len(x), cmap="viridis", norm="linear")[0]
 
             x = np.asarray(x, np.float32)
             if not explode_list:
                 explode_list = itertools.repeat(None)
             if not label_list:
                 label_list = itertools.repeat(None)
+            if c is not None and color_list is None:
+                color_list = c
+            color_list = _convert_matplotlib_color(self, color_list, len(x), cmap="viridis", norm="linear")[0]
 
+            slices: List[Slice] = []
             for index, (xi, label, explode) in enumerate(zip(x, label_list, explode_list)):
                 color = color_list[index] if len(color_list) > index else None
                 slices.append(
@@ -200,6 +207,7 @@ class _AxesProxy(Proxy[MplAxes]):
                 )
             pie_plot = PiePlot(type="pie", radius=radius, slices=slices)
             self._plot = pie_plot
+
         except Exception as e:
             logging.warning(
                 "An unexpected error occurred in PlotSerializer when trying to read plot data! "
@@ -209,11 +217,7 @@ class _AxesProxy(Proxy[MplAxes]):
 
         return result
 
-    pie.__doc__ = (
-        "\n\n Serialized parameters: x, labels, colors, explode, radius. \n\n Matplotlib documentation \n\n"
-        + MplAxes.pie.__doc__
-    )
-
+    @inherit_and_extend_doc(MplAxes, "bar", "\n\n Serialized parameters: x, height, color. \n\n")
     def bar(
         self,
         x,
@@ -228,13 +232,10 @@ class _AxesProxy(Proxy[MplAxes]):
             raise
 
         try:
-            bars: List[Bar2D] = []
-
             if isinstance(x, float):
                 x = [x]
             else:
                 x = np.asarray(x)
-
             if isinstance(height, float):
                 height = [height]
             else:
@@ -246,13 +247,11 @@ class _AxesProxy(Proxy[MplAxes]):
                 color_list = c
             color_list = _convert_matplotlib_color(self, color_list, len(x), cmap="viridis", norm="linear")[0]
 
+            bars: List[Bar2D] = []
             for index, (xi, h) in enumerate(zip(x, height)):
                 color = color_list[index] if len(color_list) > index else None
                 bars.append(Bar2D(x_i=xi, height=h, color=color))
 
-            # for xi, h, color in zip(x, height, color_list):
-            #     bars.append(Bar2D(x_i=xi, height=h, color=color))
-
             trace = BarTrace2D(type="bar", datapoints=bars)
 
             if self._plot is not None:
@@ -271,20 +270,253 @@ class _AxesProxy(Proxy[MplAxes]):
 
         return result
 
-    """
-    Custom wrapper for ax.plot with additional functionality.
+    def plot(self, *args: Any, **kwargs: Any) -> list[Line2D]:
+        """
+        Serializes x,y, linewidth, linestyle, marker, color, label.
 
-    {plt.Axes.plot.__doc__}
+        Original matplotlib documentation:
 
-    Parameters
-    ----------
-    *args : tuple
-        Positional arguments passed to `ax.plot`.
-    **kwargs : dict
-        Keyword arguments passed to `ax.plot`.
-    """
+        Plot y versus x as lines and/or markers.
 
-    def plot(self, *args: Any, **kwargs: Any) -> list[Line2D]:
+        Call signatures::
+
+            plot([x], y, [fmt], *, data=None, **kwargs)
+            plot([x], y, [fmt], [x2], y2, [fmt2], ..., **kwargs)
+
+        The coordinates of the points or line nodes are given by *x*, *y*.
+
+        The optional parameter *fmt* is a convenient way for defining basic
+        formatting like color, marker and linestyle. It's a shortcut string
+        notation described in the *Notes* section below.
+
+        >>> plot(x, y)  # plot x and y using default line style and color
+        >>> plot(x, y, "bo")  # plot x and y using blue circle markers
+        >>> plot(y)  # plot y using x as index array 0..N-1
+        >>> plot(y, "r+")  # ditto, but with red plusses
+
+        You can use `.Line2D` properties as keyword arguments for more
+        control on the appearance. Line properties and *fmt* can be mixed.
+        The following two calls yield identical results:
+
+        >>> plot(x, y, "go--", linewidth=2, markersize=12)
+        >>> plot(x, y, color="green", marker="o", linestyle="dashed", linewidth=2, markersize=12)
+
+        When conflicting with *fmt*, keyword arguments take precedence.
+
+
+        **Plotting labelled data**
+
+        There's a convenient way for plotting objects with labelled data (i.e.
+        data that can be accessed by index ``obj['y']``). Instead of giving
+        the data in *x* and *y*, you can provide the object in the *data*
+        parameter and just give the labels for *x* and *y*::
+
+        >>> plot("xlabel", "ylabel", data=obj)
+
+        All indexable objects are supported. This could e.g. be a `dict`, a
+        `pandas.DataFrame` or a structured numpy array.
+
+
+        **Plotting multiple sets of data**
+
+        There are various ways to plot multiple sets of data.
+
+        - The most straight forward way is just to call `plot` multiple times.
+          Example:
+
+          >>> plot(x1, y1, "bo")
+          >>> plot(x2, y2, "go")
+
+        - If *x* and/or *y* are 2D arrays, a separate data set will be drawn
+          for every column. If both *x* and *y* are 2D, they must have the
+          same shape. If only one of them is 2D with shape (N, m) the other
+          must have length N and will be used for every data set m.
+
+          Example:
+
+          >>> x = [1, 2, 3]
+          >>> y = np.array([[1, 2], [3, 4], [5, 6]])
+          >>> plot(x, y)
+
+          is equivalent to:
+
+          >>> for col in range(y.shape[1]):
+          ...     plot(x, y[:, col])
+
+        - The third way is to specify multiple sets of *[x]*, *y*, *[fmt]*
+          groups::
+
+          >>> plot(x1, y1, "g^", x2, y2, "g-")
+
+          In this case, any additional keyword argument applies to all
+          datasets. Also, this syntax cannot be combined with the *data*
+          parameter.
+
+        By default, each line is assigned a different style specified by a
+        'style cycle'. The *fmt* and line property parameters are only
+        necessary if you want explicit deviations from these defaults.
+        Alternatively, you can also change the style cycle using
+        :rc:`axes.prop_cycle`.
+
+
+        Parameters
+        ----------
+        x, y : array-like or scalar
+            The horizontal / vertical coordinates of the data points.
+            *x* values are optional and default to ``range(len(y))``.
+
+            Commonly, these parameters are 1D arrays.
+
+            They can also be scalars, or two-dimensional (in that case, the
+            columns represent separate data sets).
+
+            These arguments cannot be passed as keywords.
+
+        fmt : str, optional
+            A format string, e.g. 'ro' for red circles. See the *Notes*
+            section for a full description of the format strings.
+
+            Format strings are just an abbreviation for quickly setting
+            basic line properties. All of these and more can also be
+            controlled by keyword arguments.
+
+            This argument cannot be passed as keyword.
+
+        data : indexable object, optional
+            An object with labelled data. If given, provide the label names to
+            plot in *x* and *y*.
+
+            .. note::
+                Technically there's a slight ambiguity in calls where the
+                second label is a valid *fmt*. ``plot('n', 'o', data=obj)``
+                could be ``plt(x, y)`` or ``plt(y, fmt)``. In such cases,
+                the former interpretation is chosen, but a warning is issued.
+                You may suppress the warning by adding an empty format string
+                ``plot('n', 'o', '', data=obj)``.
+
+        Returns
+        -------
+        list of `.Line2D`
+            A list of lines representing the plotted data.
+
+        Other Parameters
+        ----------------
+        scalex, scaley : bool, default: True
+            These parameters determine if the view limits are adapted to the
+            data limits. The values are passed on to
+            `~.axes.Axes.autoscale_view`.
+
+        **kwargs : `~matplotlib.lines.Line2D` properties, optional
+            *kwargs* are used to specify properties like a line label (for
+            auto legends), linewidth, antialiasing, marker face color.
+            Example::
+
+            >>> plot([1, 2, 3], [1, 2, 3], "go-", label="line 1", linewidth=2)
+            >>> plot([1, 2, 3], [1, 4, 9], "rs", label="line 2")
+
+            If you specify multiple lines with one plot call, the kwargs apply
+            to all those lines. In case the label object is iterable, each
+            element is used as labels for each set of data.
+
+            Here is a list of available `.Line2D` properties:
+
+            %(Line2D:kwdoc)s
+
+        See Also
+        --------
+        scatter : XY scatter plot with markers of varying size and/or color (
+            sometimes also called bubble chart).
+
+        Notes
+        -----
+        **Format Strings**
+
+        A format string consists of a part for color, marker and line::
+
+            fmt = "[marker][line][color]"
+
+        Each of them is optional. If not provided, the value from the style
+        cycle is used. Exception: If ``line`` is given, but no ``marker``,
+        the data will be a line without markers.
+
+        Other combinations such as ``[color][marker][line]`` are also
+        supported, but note that their parsing may be ambiguous.
+
+        **Markers**
+
+        =============   ===============================
+        character       description
+        =============   ===============================
+        ``'.'``         point marker
+        ``','``         pixel marker
+        ``'o'``         circle marker
+        ``'v'``         triangle_down marker
+        ``'^'``         triangle_up marker
+        ``'<'``         triangle_left marker
+        ``'>'``         triangle_right marker
+        ``'1'``         tri_down marker
+        ``'2'``         tri_up marker
+        ``'3'``         tri_left marker
+        ``'4'``         tri_right marker
+        ``'8'``         octagon marker
+        ``'s'``         square marker
+        ``'p'``         pentagon marker
+        ``'P'``         plus (filled) marker
+        ``'*'``         star marker
+        ``'h'``         hexagon1 marker
+        ``'H'``         hexagon2 marker
+        ``'+'``         plus marker
+        ``'x'``         x marker
+        ``'X'``         x (filled) marker
+        ``'D'``         diamond marker
+        ``'d'``         thin_diamond marker
+        ``'|'``         vline marker
+        ``'_'``         hline marker
+        =============   ===============================
+
+        **Line Styles**
+
+        =============    ===============================
+        character        description
+        =============    ===============================
+        ``'-'``          solid line style
+        ``'--'``         dashed line style
+        ``'-.'``         dash-dot line style
+        ``':'``          dotted line style
+        =============    ===============================
+
+        Example format strings::
+
+            "b"  # blue markers with default shape
+
+            "or"  # red circles
+            "-g"  # green solid line
+            "--"  # dashed line with default color
+            "^k:"  # black triangle_up markers connected by a dotted line
+
+        **Colors**
+
+        The supported color abbreviations are the single letter codes
+
+        =============    ===============================
+        character        color
+        =============    ===============================
+        ``'b'``          blue
+        ``'g'``          green
+        ``'r'``          red
+        ``'c'``          cyan
+        ``'m'``          magenta
+        ``'y'``          yellow
+        ``'k'``          black
+        ``'w'``          white
+        =============    ===============================
+
+        and the ``'CN'`` colors that index into the default property cycle.
+
+        If the color is the only part of the format string, you can
+        additionally use any  `matplotlib.colors` spec, e.g. full names
+        (``'green'``) or hex strings (``'#008000'``).
+        """
         try:
             mpl_lines = self.delegate.plot(*args, **kwargs)
         except Exception as e:
@@ -296,26 +528,22 @@ class _AxesProxy(Proxy[MplAxes]):
             traces: List[LineTrace2D] = []
 
             for mpl_line in mpl_lines:
+                xdata = mpl_line.get_xdata()
+                ydata = mpl_line.get_ydata()
+                thickness = mpl_line.get_linewidth()
+                linestyle = mpl_line.get_linestyle()
+                marker = mpl_line.get_marker()
+                label = mpl_line.get_label()
                 color_list = kwargs.get("color")
                 c = kwargs.get("c")
                 if c is not None and color_list is None:
                     color_list = c
-
-                xdata = mpl_line.get_xdata()
-                ydata = mpl_line.get_ydata()
-                print(type(xdata))
+                color_list = _convert_matplotlib_color(self, color_list, len(xdata), cmap="viridis", norm="linear")[0]
 
                 points: List[Point2D] = []
-
                 for x, y in zip(xdata, ydata):
                     points.append(Point2D(x=x, y=y))
 
-                label = mpl_line.get_label()
-                color_list = _convert_matplotlib_color(self, color_list, len(xdata), cmap="viridis", norm="linear")[0]
-                thickness = mpl_line.get_linewidth()
-                linestyle = mpl_line.get_linestyle()
-                marker = mpl_line.get_marker()
-
                 traces.append(
                     LineTrace2D(
                         type="line",
@@ -358,6 +586,7 @@ class _AxesProxy(Proxy[MplAxes]):
             raise
 
         try:
+            verteces = path.get_offsets().tolist()
             marker = kwargs.get("marker") or "o"
             color_list = kwargs.get("c")
             color = kwargs.get("color")
@@ -366,8 +595,12 @@ class _AxesProxy(Proxy[MplAxes]):
             sizes_list = kwargs.get("s")
             cmap = kwargs.get("cmap") or "viridis"
             norm = kwargs.get("norm") or "linear"
+            label = str(path.get_label())
 
             (color_list, cmap_used) = _convert_matplotlib_color(self, color_list, len(x), cmap, norm)
+            if not cmap_used:
+                cmap = None
+                norm = None
 
             if sizes_list is not None:
                 sizes_list = path.get_sizes()
@@ -376,11 +609,7 @@ class _AxesProxy(Proxy[MplAxes]):
             if isinstance(sizes_list, np.generic):
                 sizes_list = [sizes_list] * len(x)
 
-            label = str(path.get_label())
             datapoints: List[Point2D] = []
-
-            verteces = path.get_offsets().tolist()
-
             for index, (vertex, size) in enumerate(zip(verteces, sizes_list)):
                 color = color_list[index] if len(color_list) > index else None
                 datapoints.append(
@@ -391,9 +620,6 @@ class _AxesProxy(Proxy[MplAxes]):
                         size=size,
                     )
                 )
-            if not cmap_used:
-                cmap = None
-                norm = None
             trace: List[ScatterTrace2D] = []
             trace.append(
                 ScatterTrace2D(type="scatter", cmap=cmap, norm=norm, label=label, datapoints=datapoints, marker=marker)
@@ -430,8 +656,6 @@ class _AxesProxy(Proxy[MplAxes]):
             conf_intervals = kwargs.get("conf_intervals")
             labels = kwargs.get("tick_labels")
 
-            trace: List[BoxTrace2D] = []
-            boxes: List[Box] = []
             x = _reshape_2D(x, "x")
 
             if not labels:
@@ -441,6 +665,8 @@ class _AxesProxy(Proxy[MplAxes]):
             if not conf_intervals:
                 conf_intervals = itertools.repeat(None)
 
+            trace: List[BoxTrace2D] = []
+            boxes: List[Box] = []
             for dataset, label, umedian, cintervals in zip(x, labels, usermedians, conf_intervals):
                 x = np.ma.asarray(x)
                 x = x.data[~x.mask].ravel()
@@ -518,14 +744,14 @@ class _AxesProxy(Proxy[MplAxes]):
             if yerr is None:
                 yerr = itertools.repeat(None)
 
-            print(xerr)
-            print(xerr.ndim)
-
             if xerr.ndim == 0 or xerr.ndim == 1:
                 xerr = np.broadcast_to(xerr, (2, len(x)))
             if yerr.ndim == 0 or yerr.ndim == 1:
                 yerr = np.broadcast_to(yerr, (2, len(y)))
 
+            color = mcolors.to_hex(color) if color else None
+            ecolor = mcolors.to_hex(ecolor) if ecolor else None
+
             errorpoints: List[ErrorPoint2D] = []
             for xi, yi, x_error, y_error in zip(x, y, xerr.T, yerr.T):
                 errorpoints.append(
@@ -536,8 +762,6 @@ class _AxesProxy(Proxy[MplAxes]):
                         yerr=y_error,
                     )
                 )
-            color = mcolors.to_hex(color) if color else None
-            ecolor = mcolors.to_hex(ecolor) if ecolor else None
             trace = ErrorBar2DTrace(
                 type="errorbar2d",
                 label=label,
@@ -594,12 +818,10 @@ class _AxesProxy(Proxy[MplAxes]):
             if np.isscalar(x):
                 x = [x]
             x = _reshape_2D(x, "x")
-            print(x)
 
             color_list = _convert_matplotlib_color(self, color_list, len(x), "viridis", "linear")[0]
 
             datasets: List[HistDataset] = []
-
             for index, (element, label) in enumerate(zip(x, label_list)):
                 color = color_list[index] if len(color_list) > index else None
                 datasets.append(HistDataset(x_i=element, color=color, label=label))
@@ -627,13 +849,6 @@ class _AxesProxy(Proxy[MplAxes]):
             )
         return ret
 
-    def _are_lists_same_length(self, *lists) -> bool:
-        non_empty_lists = [lst for lst in lists if lst]
-        if not non_empty_lists:
-            return True
-        length = len(non_empty_lists[0])
-        return all(len(lst) == length for lst in non_empty_lists)
-
     def _on_collect(self) -> None:
         if self._plot is None:
             return
@@ -697,17 +912,18 @@ class _AxesProxy3D(Proxy[MplAxes3D]):
         try:
             sizes_list = kwargs.get("s")
             marker = kwargs.get("marker") or "o"
-
             color_list = kwargs.get("c")
             color = kwargs.get("color")
             if color is not None and color_list is None:
                 color_list = color
             cmap = kwargs.get("cmap") or "viridis"
             norm = kwargs.get("norm") or "linear"
-            (color_list, cmap_used) = _convert_matplotlib_color(self, color_list, len(xs), cmap, norm)
+            label = str(path.get_label())
 
-            trace: List[ScatterTrace3D] = []
-            datapoints: List[Point3D] = []
+            (color_list, cmap_used) = _convert_matplotlib_color(self, color_list, len(xs), cmap, norm)
+            if not cmap_used:
+                cmap = None
+                norm = None
 
             xs, ys, zs = cbook._broadcast_with_masks(xs, ys, zs)
             xs, ys, zs, sizes_list, color_list, color = cbook.delete_masked_points(
@@ -719,14 +935,12 @@ class _AxesProxy3D(Proxy[MplAxes3D]):
             if isinstance(sizes_list, (np.generic, float, int)):
                 sizes_list = [sizes_list] * len(xs)
 
+            trace: List[ScatterTrace3D] = []
+            datapoints: List[Point3D] = []
             for index, (xi, yi, zi, s) in enumerate(zip(xs, ys, zs, sizes_list)):
                 c = color_list[index] if len(color_list) > index else None
                 datapoints.append(Point3D(x=xi, y=yi, z=zi, color=c, size=s))
 
-            label = str(path.get_label())
-            if not cmap_used:
-                cmap = None
-                norm = None
             trace.append(
                 ScatterTrace3D(
                     type="scatter3D", cmap=cmap, norm=norm, label=label, datapoints=datapoints, marker=marker
@@ -763,19 +977,18 @@ class _AxesProxy3D(Proxy[MplAxes3D]):
             raise
 
         try:
+            mpl_line = path[0]
+            xdata, ydata, zdata = mpl_line.get_data_3d()
+            label = mpl_line.get_label()
+            thickness = mpl_line.get_linewidth()
+            linestyle = mpl_line.get_linestyle()
             marker = kwargs.get("marker") or None
             color_list = kwargs.get("color")
             c = kwargs.get("c")
             if c is not None and color_list is None:
                 color_list = c
-            color_list = _convert_matplotlib_color(self, color_list, len(x_values), "viridis", "linear")[0]
-
-            mpl_line = path[0]
-            xdata, ydata, zdata = mpl_line.get_data_3d()
 
-            label = mpl_line.get_label()
-            thickness = mpl_line.get_linewidth()
-            linestyle = mpl_line.get_linestyle()
+            color_list = _convert_matplotlib_color(self, color_list, len(x_values), "viridis", "linear")[0]
 
             datapoints: List[Point3D] = []
             for i in range(len(xdata)):
@@ -825,6 +1038,12 @@ class _AxesProxy3D(Proxy[MplAxes3D]):
             raise
 
         try:
+            color = kwargs.get("color")
+            c = kwargs.get("c")
+            if c is not None and color is None:
+                color = c
+            label = surface.get_label()
+
             length = len(x)
             width = len(x[0])
 
@@ -833,13 +1052,6 @@ class _AxesProxy3D(Proxy[MplAxes3D]):
 
             traces: List[SurfaceTrace3D] = []
             datapoints: List[Point3D] = []
-
-            color = kwargs.get("color")
-            c = kwargs.get("c")
-            if c is not None and color is None:
-                color = c
-            label = surface.get_label()
-
             for xi, yi, zi in zip(x, y, z):
                 for xj, yj, zj in zip(xi, yi, zi):
                     datapoints.append(
@@ -947,7 +1159,7 @@ class MatplotlibSerializer(Serializer):
         self,
         *args: Any,
         **kwargs: Any,
-    ) -> Tuple[MplFigure, Union[MplAxes, MplAxes3D, Any]]:
+    ) -> Tuple[MplFigure, Union[MplAxes, MplAxes3D, _AxesProxy, _AxesProxy3D, Any]]:
         figure, axes = matplotlib.pyplot.subplots(*args, **kwargs)
 
         new_axes: Any
diff --git a/plot_serializer/matplotlib/serializer.pyi b/plot_serializer/matplotlib/serializer.pyi
index 18498d752cd1444da2b35979686bb4cbd7c59d1e..50bc65b99d4ab21988ff833f7bee52033d3d1954 100644
--- a/plot_serializer/matplotlib/serializer.pyi
+++ b/plot_serializer/matplotlib/serializer.pyi
@@ -12,7 +12,7 @@ from typing import (
 from matplotlib.axes import Axes as MplAxes
 from matplotlib.figure import Figure as MplFigure
 
-from plot_serializer.serializer import Serializer
+from plot_serializer.serializer import Serializer, _AxesProxy
 
 class MatplotlibSerializer(Serializer):
     # Fancy way of properly type hinting the subplots method...
@@ -30,7 +30,7 @@ class MatplotlibSerializer(Serializer):
         subplot_kw: None = None,
         gridspec_kw: Optional[Dict[str, Any]] = None,
         **fig_kw: Any,
-    ) -> Tuple[MplFigure, MplAxes]: ...
+    ) -> Tuple[MplFigure, _AxesProxy]: ...
     @overload
     def subplots(
         self,
@@ -45,5 +45,5 @@ class MatplotlibSerializer(Serializer):
         subplot_kw: Optional[Dict[str, Any]] = None,
         gridspec_kw: Optional[Dict[str, Any]] = None,
         **fig_kw: Any,
-    ) -> Tuple[MplFigure, Any]: ...
+    ) -> Tuple[MplFigure, _AxesProxy]: ...
     def show(self, *, block: Optional[bool] = None) -> None: ...