Skip to content
Snippets Groups Projects
Commit 57e976a8 authored by Julius's avatar Julius
Browse files

intermediate commit for expand_type_support

parent 363ec3f7
Branches
Tags
1 merge request!32Expand type support
......@@ -109,9 +109,9 @@ def _deserialize_linetrace2d(trace: LineTrace2D, ax: MplAxes) -> None:
x,
y,
label=trace.label,
color=trace.line_color,
linewidth=trace.line_thickness,
linestyle=trace.line_style,
color=trace.color,
linewidth=trace.linewidth,
linestyle=trace.linestyle,
marker=trace.marker,
)
......
......
......@@ -173,14 +173,13 @@ class _AxesProxy(Proxy[MplAxes]):
self._serializer = serializer
self._plot: Optional[Plot] = None
def pie(self, size_list, **kwargs: Any) -> Any:
def pie(self, x, **kwargs: Any) -> Any:
try:
result = self.delegate.pie(size_list, **kwargs)
result = self.delegate.pie(x, **kwargs)
except Exception as e:
logging.warning(
"An Error got thrown from Matplotlib's independent from PlotSerializer!",
exc_info=e,
)
add_msg = " - This error was thrown by Matplotlib and is independent of PlotSerializer!"
e.args = (e.args[0] + add_msg,) + e.args[1:] if e.args else (add_msg,)
raise
try:
if self._plot is not None:
......@@ -193,21 +192,21 @@ class _AxesProxy(Proxy[MplAxes]):
radius = kwargs.get("radius") or None
color_list = kwargs.get("colors") or []
color_list = _convert_matplotlib_color(self, color_list, len(size_list), cmap="viridis", norm="linear")[0]
color_list = _convert_matplotlib_color(self, color_list, len(x), cmap="viridis", norm="linear")[0]
size_list = np.asarray(size_list, np.float32)
x = np.asarray(x, np.float32)
if not explode_list:
explode_list = itertools.repeat(None)
if not label_list:
label_list = itertools.repeat(None)
for size, label, explode, color in zip(size_list, label_list, explode_list, color_list):
for xi, label, explode, color in zip(x, label_list, explode_list, color_list):
slices.append(
Slice(
size=size,
xi=xi,
radius=radius,
offset=explode,
name=label,
explode=explode,
label=label,
color=color,
)
)
......@@ -225,10 +224,15 @@ class _AxesProxy(Proxy[MplAxes]):
def bar(
self,
x,
height_list,
height,
**kwargs: Any,
) -> BarContainer:
result = self.delegate.bar(x, height_list, **kwargs)
try:
result = self.delegate.bar(x, height, **kwargs)
except Exception as e:
add_msg = " - This error was thrown by Matplotlib and is independent of PlotSerializer!"
e.args = (e.args[0] + add_msg,) + e.args[1:] if e.args else (add_msg,)
raise
try:
bars: List[Bar2D] = []
......@@ -242,16 +246,14 @@ class _AxesProxy(Proxy[MplAxes]):
x = [x]
else:
x = np.asarray(x)
if not height_list:
pass
if isinstance(height_list, float):
height_list = [height_list]
if isinstance(height, float):
height = [height]
else:
height_list = np.asarray(height_list)
height = np.asarray(height)
for label, height, color in zip(x, height_list, color_list):
bars.append(Bar2D(y=height, label=label, color=color))
for xi, h, color in zip(x, height, color_list):
bars.append(Bar2D(x=xi, height=h, color=color))
trace = BarTrace2D(type="bar", datapoints=bars)
......@@ -272,7 +274,12 @@ class _AxesProxy(Proxy[MplAxes]):
return result
def plot(self, *args: Any, **kwargs: Any) -> list[Line2D]:
try:
mpl_lines = self.delegate.plot(*args, **kwargs)
except Exception as e:
add_msg = " - This error was thrown by Matplotlib and is independent of PlotSerializer!"
e.args = (e.args[0] + add_msg,) + e.args[1:] if e.args else (add_msg,)
raise
try:
traces: List[LineTrace2D] = []
......@@ -297,9 +304,9 @@ class _AxesProxy(Proxy[MplAxes]):
traces.append(
LineTrace2D(
type="line",
line_color=color_list[0],
line_thickness=thickness,
line_style=linestyle,
color=color_list[0],
linewidth=thickness,
linestyle=linestyle,
label=label,
datapoints=points,
marker=marker,
......@@ -323,12 +330,17 @@ class _AxesProxy(Proxy[MplAxes]):
def scatter(
self,
x_values,
y_values,
x,
y,
*args: Any,
**kwargs: Any,
) -> PathCollection:
path = self.delegate.scatter(x_values, y_values, *args, **kwargs)
try:
path = self.delegate.scatter(x, y, *args, **kwargs)
except Exception as e:
add_msg = " - This error was thrown by Matplotlib and is independent of PlotSerializer!"
e.args = (e.args[0] + add_msg,) + e.args[1:] if e.args else (add_msg,)
raise
try:
marker = kwargs.get("marker") or "o"
......@@ -337,7 +349,7 @@ class _AxesProxy(Proxy[MplAxes]):
cmap = kwargs.get("cmap") or "viridis"
norm = kwargs.get("norm") or "linear"
(color_list, cmap_used) = _convert_matplotlib_color(self, color_list, len(x_values), cmap, norm)
(color_list, cmap_used) = _convert_matplotlib_color(self, color_list, len(x), cmap, norm)
if sizes_list:
sizes_list = path.get_sizes()
......@@ -524,7 +536,13 @@ class _AxesProxy(Proxy[MplAxes]):
ndarray,
BarContainer | Polygon | list[BarContainer | Polygon],
]:
try:
ret = self.delegate.hist(x, *args, **kwargs)
except Exception as e:
add_msg = " - This error was thrown by Matplotlib and is independent of PlotSerializer!"
e.args = (e.args[0] + add_msg,) + e.args[1:] if e.args else (add_msg,)
raise
try:
bins = kwargs.get("bins") or 10
density = kwargs.get("density") or False
......@@ -625,13 +643,18 @@ class _AxesProxy3D(Proxy[MplAxes3D]):
def scatter(
self,
x_values: Iterable[float],
y_values: Iterable[float],
z_values: Iterable[float],
xs,
ys,
zs,
*args: Any,
**kwargs: Any,
) -> Path3DCollection:
path = self.delegate.scatter(x_values, y_values, z_values, *args, **kwargs)
try:
path = self.delegate.scatter(xs, ys, zs, *args, **kwargs)
except Exception as e:
add_msg = " - This error was thrown by Matplotlib and is independent of PlotSerializer!"
e.args = (e.args[0] + add_msg,) + e.args[1:] if e.args else (add_msg,)
raise
try:
sizes_list = kwargs.get("s") or []
......@@ -640,38 +663,21 @@ class _AxesProxy3D(Proxy[MplAxes3D]):
color_list = kwargs.get("c") or []
cmap = kwargs.get("cmap") or "viridis"
norm = kwargs.get("norm") or "linear"
(color_list, cmap_used) = _convert_matplotlib_color(self, color_list, len(x_values), cmap, norm)
if isinstance(x_values, float) or isinstance(x_values, int):
x_values = [x_values]
if isinstance(y_values, float) or isinstance(y_values, int):
y_values = [y_values]
if isinstance(z_values, float) or isinstance(z_values, int):
z_values = [z_values]
if isinstance(sizes_list, float) or isinstance(sizes_list, int):
sizes_list = [sizes_list]
(color_list, cmap_used) = _convert_matplotlib_color(self, color_list, len(xs), cmap, norm)
trace: List[ScatterTrace3D] = []
datapoints: List[Point3D] = []
sizes: List[float] = []
if sizes_list:
if not (len(x_values) == len(sizes_list)):
if not (len(sizes_list) - 1):
sizes = [sizes_list[0] for i in range(len(x_values))]
else:
raise ValueError(
"sizes list contains more than one element while not being as long as the x_values array"
xs, ys, zs = cbook._broadcast_with_masks(xs, ys, zs)
xs, ys, zs, sizes_list, color_list, color = cbook.delete_masked_points(
xs, ys, zs, sizes_list, color_list, kwargs.get("color", None)
)
else:
sizes = sizes_list
else:
sizes = [None] * len(x_values)
for i in range(len(x_values)):
c = color_list[i] if i < len(color_list) else None
s = sizes[i]
datapoints.append(Point3D(x=x_values[i], y=y_values[i], z=z_values[i], color=c, size=s))
if not sizes_list:
sizes_list = itertools.repeat(None)
for xi, yi, zi, c, s in zip(xs, ys, zs, color_list, sizes_list):
datapoints.append(Point3D(x=xi, y=yi, z=zi, color=c, size=s))
label = str(path.get_label())
if not cmap_used:
......@@ -705,7 +711,12 @@ class _AxesProxy3D(Proxy[MplAxes3D]):
*args: Any,
**kwargs: Any,
) -> Path3DCollection:
try:
path = self.delegate.plot(x_values, y_values, *args, **kwargs)
except Exception as e:
add_msg = " - This error was thrown by Matplotlib and is independent of PlotSerializer!"
e.args = (e.args[0] + add_msg,) + e.args[1:] if e.args else (add_msg,)
raise
try:
marker = kwargs.get("marker") or None
......@@ -727,9 +738,9 @@ class _AxesProxy3D(Proxy[MplAxes3D]):
trace.append(
LineTrace3D(
type="line3D",
line_color=color_list[0],
line_thickness=thickness,
line_style=linestyle,
color=color_list[0],
linewidth=thickness,
linestyle=linestyle,
label=label,
datapoints=datapoints,
marker=marker,
......@@ -753,20 +764,25 @@ class _AxesProxy3D(Proxy[MplAxes3D]):
def plot_surface(
self,
x_values: list[list[float]],
y_values: list[list[float]],
z_values: list[list[float]],
x: list[list[float]],
y: list[list[float]],
z: list[list[float]],
*args: Any,
**kwargs: Any,
) -> Poly3DCollection:
surface = self.delegate.plot_surface(x_values, y_values, z_values, *args, **kwargs)
try:
surface = self.delegate.plot_surface(x, y, z, *args, **kwargs)
except Exception as e:
add_msg = " - This error was thrown by Matplotlib and is independent of PlotSerializer!"
e.args = (e.args[0] + add_msg,) + e.args[1:] if e.args else (add_msg,)
raise
try:
length = len(x_values)
width = len(x_values[0])
# length = len(x)
# width = len(x[0])
if not length == len(y_values) == len(z_values):
raise ValueError("The x, y and z arrays do not contain the same amount of elements")
z = cbook._to_unmasked_float_array(z)
x, y, z = np.broadcast_arrays(x, y, z)
traces: List[SurfaceTrace3D] = []
datapoints: List[Point3D] = []
......@@ -774,28 +790,22 @@ class _AxesProxy3D(Proxy[MplAxes3D]):
color = kwargs.get("color") or None
label = surface.get_label()
for i in range(length):
if not width == len(x_values[i]) == len(y_values[i]) == len(z_values[i]):
raise ValueError(
f"The x, y and z arrays do not contain the same amount of elements in the second dimension {i}"
)
for j in range(width):
for xi, yi, zi in zip(x, y, z):
for xj, yj, zj in zip(xi, yi, zi):
datapoints.append(
Point3D(
x=x_values[i][j],
y=y_values[i][j],
z=z_values[i][j],
x=xj,
y=yj,
z=zj,
color=color,
# size=s,
)
)
traces.append(
SurfaceTrace3D(
type="surface3D",
length=length,
width=width,
# length=length,
# width=width,
label=label,
datapoints=datapoints,
)
......
......
......@@ -2,6 +2,8 @@ import logging
from re import A
from typing import Annotated, Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
from matplotlib.colors import Colormap, Normalize
from matplotlib.markers import MarkerStyle
from numpy.typing import ArrayLike
from pydantic import BaseModel, Field, model_validator
......@@ -41,14 +43,13 @@ class Axis(BaseModel):
class Point2D(BaseModel):
metadata: Metadata = {}
x: Any # used to be: float
y: Any # used to be: float
color: Optional[str] = None
size: Optional[float] = None
x: Any
y: Any
color: Optional[Color] = None
size: Any = None
def emit_warnings(self) -> None:
msg: List[str] = []
# TODO: Improve the warning system
if len(msg) > 0:
logging.warning("%s is not set for Point2D.", msg)
......@@ -59,12 +60,11 @@ class Point3D(BaseModel):
x: Any # used to be: float
y: Any # used to be: float
z: Any # used to be: float
color: Optional[str] = None
color: Optional[Color] = None
size: Any = None
def emit_warnings(self) -> None:
msg: List[str] = []
# TODO: Improve the warning system
if len(msg) > 0:
logging.warning("%s is not set for Point3D.", msg)
......@@ -73,10 +73,10 @@ class Point3D(BaseModel):
class ScatterTrace2D(BaseModel):
type: Literal["scatter"]
metadata: Metadata = {}
cmap: Any = None
norm: Any = None
cmap: Optional[str | Colormap] = None
norm: Optional[Normalize] = None
label: Optional[str]
marker: Optional[str]
marker: Optional[MarkerStyle]
datapoints: List[Point2D]
def emit_warnings(self) -> None:
......@@ -95,10 +95,10 @@ class ScatterTrace2D(BaseModel):
class ScatterTrace3D(BaseModel):
type: Literal["scatter3D"]
metadata: Metadata = {}
cmap: Any = None
norm: Any = None
cmap: Optional[str | Colormap] = None
norm: Optional[Normalize] = None
label: Optional[str]
marker: Optional[str]
marker: Optional[MarkerStyle]
datapoints: List[Point3D]
def emit_warnings(self) -> None:
......@@ -117,10 +117,10 @@ class ScatterTrace3D(BaseModel):
class LineTrace2D(BaseModel):
type: Literal["line"]
metadata: Metadata = {}
line_color: Optional[str | Tuple[float, float, float] | Tuple[float, float, float, float]] = None
line_thickness: Optional[float] = None
line_style: Optional[str] = None
marker: Optional[str] = None
color: Optional[Color] = None
linewidth: Optional[float] = None
linestyle: Optional[str] = None
marker: Optional[MarkerStyle] = None
label: Optional[str] = None
datapoints: List[Point2D]
......@@ -140,9 +140,9 @@ class LineTrace2D(BaseModel):
class LineTrace3D(BaseModel):
type: Literal["line3D"]
metadata: Metadata = {}
line_color: Optional[str | Tuple[float, float, float] | Tuple[float, float, float, float]] = None
line_thickness: Optional[float] = None
line_style: Optional[str] = None
color: Optional[Color] = None
linewidth: Optional[float] = None
linestyle: Optional[str] = None
marker: Optional[str] = None
label: Optional[str] = None
datapoints: List[Point3D]
......@@ -190,7 +190,7 @@ class SurfaceTrace3D(BaseModel):
class Bar2D(BaseModel):
metadata: Metadata = {}
height: Any
x: Any
xi: Any
color: Optional[str | Tuple[float, float, float] | Tuple[float, float, float, float]] = None
def emit_warnings(self) -> None:
......@@ -213,8 +213,8 @@ class BarTrace2D(BaseModel):
class Box(BaseModel):
metadata: Metadata = {}
data: Any # used to be: List[float]
label: Any = None # used to be: Optional[str]
x: Any
tick_label: Any = None # used to be: Optional[str]
usermedian: Any = None # used to be: Optional[float]
conf_interval: Any = None # used to be: Optional[Tuple[float, float]]
......@@ -229,7 +229,7 @@ class BoxTrace2D(BaseModel):
type: Literal["box"]
metadata: Metadata = {}
notch: Optional[bool] = None
whis: Optional[Union[float, Tuple[float, float]]] = None
whis: Optional[float | ArrayLike] = None
bootstrap: Optional[int] = None
boxes: List[Box]
......@@ -245,10 +245,10 @@ class BoxTrace2D(BaseModel):
class ErrorPoint2D(BaseModel):
metadata: Metadata = {}
x: Any # used to be: float
y: Any # used to be: float
x_error: Optional[Tuple[float, float]]
y_error: Optional[Tuple[float, float]]
xi: Any # used to be: float
yi: Any # used to be: float
xerr: Any # should always be: Optional[Tuple[float, float]], however matplotlib stub does not specify
yerr: Any
def emit_warnings(self) -> None:
msg: List[str] = []
......@@ -261,7 +261,7 @@ class ErrorBar2DTrace(BaseModel):
type: Literal["errorbar2d"]
metadata: Metadata = {}
label: Optional[str] = None
marker: Optional[str] = None
marker: Optional[MarkerStyle] = None
color: Optional[Color] = None
ecolor: Optional[Color] = None
datapoints: List[ErrorPoint2D]
......@@ -278,7 +278,7 @@ class ErrorBar2DTrace(BaseModel):
class HistDataset(BaseModel):
metadata: Metadata = {}
x: List[float]
x: Any # should always be: List[Number], however matplotlib stub does not specify
color: Optional[str]
label: Optional[str]
......@@ -393,16 +393,16 @@ class Plot3D(BaseModel):
class Slice(BaseModel):
metadata: Metadata = {}
x: float
offset: Optional[Any] = None
name: Optional[str] = None
xi: float
explode: Optional[Any] = None
label: Optional[str] = None
color: Optional[Color] = None
def emit_warnings(self) -> None:
msg = []
if self.name is None or len(self.name.lstrip()) == 0:
msg.append("name")
if self.label is None or len(self.label.lstrip()) == 0:
msg.append("label")
if len(msg) > 0:
logging.warning("%s is not set for Slice object.", msg)
......
......
from typing import Any
import numpy as np
import pytest
from plot_serializer.matplotlib.serializer import MatplotlibSerializer
......@@ -41,6 +42,17 @@ from tests import validate_output
"log axis",
None,
),
(
"all_features_arraylike",
"bar_plot_all_features_arraylike_names",
np.array(["a", "b", "c", "d", "e", "f", "g", "h"]),
np.array([10, 20, 30, 40, 50, 60, 70, 80]),
["red", "green", "blue", "orange", "purple", "cyan", "blue", "blue"],
"My amazing bar plot",
"log",
"log axis",
None,
),
(
"different_input_types",
"bar_plot_different_input_types",
......
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment