import json
import logging
import math
import os
import sys
import tempfile
from pathlib import Path
from typing import Any, Callable, List, Mapping, Optional, TextIO, Union

from rocrate.rocrate import ROCrate  # type: ignore[import-untyped]

from plot_serializer.model import (
    BarTrace2D,
    BoxTrace2D,
    ErrorBar2DTrace,
    ErrorPoint2D,
    Figure,
    HistogramTrace,
    LineTrace2D,
    LineTrace3D,
    PiePlot,
    Plot3D,
    Point2D,
    Point3D,
    PointTrace,
    PointTraceNoBar,
    ScatterTrace2D,
    ScatterTrace3D,
    SurfaceTrace3D,
    Trace2D,
    Trace3D,
    Xyz,
)
# URL of the JSON specification that files written by this module conform to;
# recorded as the "conformsTo" property when a figure is added to an RO-Crate.
_CURRENT_SPEC = "https://plot-serializer.readthedocs.io/en/latest/matplotlib_json_spec_0.2.0.html"
def write_schema_json(file: Union[TextIO, str, "os.PathLike[str]"]) -> None:
    """
    Writes the JSON schema of the figure model to a file on disk.

    Args:
        file (Union[TextIO, str, os.PathLike]): Either a file path or a writable TextIO object
    """
    # model_json_schema() is the pydantic v2 replacement for the deprecated
    # schema_json(); dumping it with indent=2 yields the same text without the
    # deprecation warning.
    schema = json.dumps(Figure.model_json_schema(), indent=2)
    if isinstance(file, (str, os.PathLike)):
        # JSON is UTF-8 by specification; do not depend on the platform default encoding.
        with open(file, "w", encoding="utf-8") as handle:
            handle.write(schema)
    else:
        file.write(schema)
class Serializer:
    """
    A Serializer is an object that has a subclass for different libraries
    (e.g. MatplotlibSerializer). The Serializer allows you to use a library like
    you would normally, while collecting all the data you specify inside the plotting
    library and providing methods for serializing that information to json.
    """

    def __init__(self) -> None:
        # Model that all collected plot data and custom metadata are stored in.
        self._figure = Figure()
        # Callbacks registered through _add_collect_action.
        self._collect_actions: List[Callable[[], None]] = []
        # Set by write_json_file; afterwards further writes and metadata edits are refused.
        self._written_to_file: bool = False
        # Checked before metadata edits / to_json to decide whether serialized_figure() must run.
        # NOTE(review): never set to True in this class section; presumably serialized_figure() does.
        self._was_collected: bool = False

    def _add_collect_action(self, action: Callable[[], None]) -> None:
        # Internal method to register a function that will be run every time
        # the user accesses the current serializer state.
        self._collect_actions.append(action)

    def _cast_to_datapoint_trace(self, trace: Any) -> PointTrace | None:
        """Return ``trace`` if it is one of the trace types holding ``datapoints``, else None."""
        if isinstance(
            trace,
            (
                ScatterTrace2D,
                ScatterTrace3D,
                LineTrace2D,
                LineTrace3D,
                SurfaceTrace3D,
                BarTrace2D,
                ErrorBar2DTrace,
            ),
        ):
            return trace
        return None

    @staticmethod
    def _point_matches(
        datapoint: Any,
        selector: tuple[float, ...],
        rel_tol: float,
        selector_name: str,
    ) -> bool:
        """
        Check whether a datapoint's coordinates are close to a selector tuple.

        Only Point2D / ErrorPoint2D (x, y) and Point3D (x, y, z) can match; any other
        datapoint type never matches.

        Args:
            datapoint: The datapoint to compare.
            selector: Coordinates to compare against, one entry per dimension.
            rel_tol: Relative tolerance passed to ``math.isclose`` for every coordinate.
            selector_name: Argument name used in the error message.

        Raises:
            ValueError: If the selector length does not match the point's dimensionality.
        """
        coordinates: tuple[float, ...]
        if isinstance(datapoint, (Point2D, ErrorPoint2D)):
            if len(selector) != 2:
                raise ValueError(f"Length of {selector_name} needs to be two when dealing with 2D-points")
            coordinates = (datapoint.x, datapoint.y)
        elif isinstance(datapoint, Point3D):
            if len(selector) != 3:
                raise ValueError(f"Length of {selector_name} needs to be three when dealing with 3D-points")
            coordinates = (datapoint.x, datapoint.y, datapoint.z)
        else:
            return False
        return all(
            math.isclose(value, target, rel_tol=rel_tol) for value, target in zip(coordinates, selector)
        )

    def _find_traces(
        self,
        traces: List[Trace2D] | List[Trace3D],
        trace_selector: tuple[float, float] | tuple[float, float, float],
        trace_rel_tol: float,
    ) -> List[PointTraceNoBar]:
        """
        Return every point-based, non-bar trace that contains at least one point matching
        ``trace_selector`` within ``trace_rel_tol``.

        Each matching trace appears exactly once in the result, no matter how many of
        its points match.
        """
        result_traces: List[PointTraceNoBar] = []
        for trace in traces:
            datapoint_trace = self._cast_to_datapoint_trace(trace)
            if datapoint_trace is None or isinstance(datapoint_trace, BarTrace2D):
                continue
            # any() stops at the first matching point, so a trace is never appended twice.
            if any(
                self._point_matches(datapoint, trace_selector, trace_rel_tol, "trace_selector")
                for datapoint in datapoint_trace.datapoints
            ):
                result_traces.append(datapoint_trace)
        return result_traces

    def _find_points(
        self,
        trace: PointTrace,
        point_selector: tuple[float, float] | tuple[float, float, float],
        point_rel_tolerance: float,
    ) -> List[Point2D | ErrorPoint2D | Point3D]:
        """
        Return all datapoints of ``trace`` whose coordinates match ``point_selector``
        within ``point_rel_tolerance``.

        Raises:
            ValueError: If ``trace`` is a BarTrace2D (callers filter those out beforehand)
                or the selector length does not match the points' dimensionality.
        """
        if isinstance(trace, BarTrace2D):
            raise ValueError("Code Error, this should not be reached. Its relevance is for Mypy errors.")
        result_points: List[Point2D | ErrorPoint2D | Point3D] = [
            datapoint
            for datapoint in trace.datapoints
            if self._point_matches(datapoint, point_selector, point_rel_tolerance, "point_selector")
        ]
        return result_points

    def _update_points_metadata(
        self,
        trace: PointTrace,
        point_selector: int | tuple[float, float] | tuple[float, float, float],
        point_rel_tolerance: float,
        dict: Mapping[str, Union[int, float, str]],
    ) -> int:
        """
        Merge ``dict`` into the metadata of the selected datapoints of ``trace``.

        An int selector picks exactly one datapoint by index. A tuple selector picks
        every datapoint matching those coordinates.

        Returns:
            int: Number of datapoints whose metadata was updated.
        """
        if isinstance(point_selector, int):
            trace.datapoints[point_selector].metadata.update(dict)
            return 1
        datapoints = self._find_points(trace, point_selector, point_rel_tolerance)
        for datapoint in datapoints:
            datapoint.metadata.update(dict)
        return len(datapoints)

    def check_collected_and_written(self) -> None:
        """
        Ensure metadata can still be added to the figure.

        Calls ``serialized_figure()`` if the figure has not been collected yet.

        Raises:
            NotImplementedError: If the figure was already written to a JSON file, since
                later metadata would not appear in that file.
        """
        if self._written_to_file:
            raise NotImplementedError(
                "You have already written your JSON file, added metadata will not be represented in the JSON"
            )
        if not self._was_collected:
            self.serialized_figure()

    def add_custom_metadata_plot(
        self,
        dict: Mapping[str, Union[int, float, str]],
        plot_selector: int = 0,
    ) -> None:
        """
        Merge ``dict`` into the metadata of a plot.

        Args:
            dict: Metadata key/value pairs to add.
            plot_selector (int): Index of the plot within the figure. Defaults to 0.
        """
        self.check_collected_and_written()
        plot = self._figure.plots[plot_selector]
        plot.metadata.update(dict)

    def add_custom_metadata_axis(
        self,
        dict: Mapping[str, Union[int, float, str]],
        axis: Xyz,
        plot_selector: int = 0,
    ) -> None:
        """
        Merge ``dict`` into the metadata of one axis of a plot.

        Args:
            dict: Metadata key/value pairs to add.
            axis: ``"x"``, ``"y"`` or ``"z"`` (``"z"`` only for 3D plots).
            plot_selector (int): Index of the plot within the figure. Defaults to 0.

        Raises:
            ValueError: For pie plots, for ``"z"`` on a non-3D plot, or for any other axis value.
        """
        self.check_collected_and_written()
        plot = self._figure.plots[plot_selector]
        if isinstance(plot, PiePlot):
            raise ValueError("PiePlot has no axis to which metadata can be added")
        if axis == "x":
            plot.x_axis.metadata.update(dict)
        elif axis == "y":
            plot.y_axis.metadata.update(dict)
        elif axis == "z":
            if not isinstance(plot, Plot3D):
                raise ValueError("cannot modify z axis, only x and y axis found, plot is not 3D")
            plot.z_axis.metadata.update(dict)
        else:
            # Previously an unknown axis was silently ignored.
            raise ValueError(f"axis must be 'x', 'y' or 'z', got {axis!r}")

    def add_custom_metadata_trace(
        self,
        dict: Mapping[str, Union[int, float, str]],
        plot_selector: int = 0,
        trace_selector: int | tuple[float, float] | tuple[float, float, float] = 0,
        trace_rel_tol: float = 0.000000001,
    ) -> None:
        """
        Merge ``dict`` into the metadata of one or more traces of a plot.

        Args:
            dict: Metadata key/value pairs to add.
            plot_selector (int): Index of the plot within the figure. Defaults to 0.
            trace_selector: Either the index of a trace, or point coordinates; every
                point-based, non-bar trace containing a matching point is selected.
            trace_rel_tol (float): Relative tolerance for coordinate matching.

        Raises:
            NotImplementedError: If the selected plot is a pie plot.
        """
        self.check_collected_and_written()
        plot = self._figure.plots[plot_selector]
        if isinstance(plot, PiePlot):
            raise NotImplementedError(
                "Pieplot does not have any traces to add metadata to. "
                + "Try add_custom_metadata_datapoints for adding metadata to slices"
            )
        selected_traces: List[Any]
        if isinstance(trace_selector, int):
            selected_traces = [plot.traces[trace_selector]]
        else:
            selected_traces = self._find_traces(plot.traces, trace_selector, trace_rel_tol)
        for trace in selected_traces:
            trace.metadata.update(dict)
        logging.info("In total, %d traces' metadata were updated", len(selected_traces))

    def add_custom_metadata_datapoints(
        self,
        dict: Mapping[str, Union[int, float, str]],
        point_selector: int | tuple[float, float] | tuple[float, float, float],
        trace_selector: int | tuple[float, float] | tuple[float, float, float],
        point_rel_tolerance: float = 0.000000001,
        plot_selector: int = 0,
        trace_rel_tol: float = sys.float_info.max,
    ) -> None:
        """
        Merge ``dict`` into the metadata of selected datapoints (or pie slices).

        Args:
            dict: Metadata key/value pairs to add.
            point_selector: Index of the datapoint, or its coordinates. Pie slices, box,
                histogram and bar traces only support index selection.
            trace_selector: Index of the trace, or point coordinates used to find traces.
            point_rel_tolerance (float): Relative tolerance for matching points.
            plot_selector (int): Index of the plot within the figure. Defaults to 0.
            trace_rel_tol (float): Relative tolerance for matching traces.
                NOTE(review): with the default (sys.float_info.max) every trace that has
                at least one point of matching dimensionality is selected; confirm intended.

        Raises:
            ValueError: If coordinate selection is used where only index selection is
                supported, or the selected trace has no datapoints.
        """
        self.check_collected_and_written()
        plot = self._figure.plots[plot_selector]
        count_points_changed: int = 0
        if isinstance(plot, PiePlot):
            # Pie slices have no coordinates, so they can only be addressed by index.
            if not isinstance(point_selector, int):
                raise ValueError(
                    "Trying to access slices of Pie using tuples, not index. "
                    + "Point selection via tuples only viable for real datapoints."
                )
            plot.slices[point_selector].metadata.update(dict)
            count_points_changed += 1
        elif isinstance(trace_selector, int):
            selected_trace = plot.traces[trace_selector]
            if isinstance(selected_trace, BoxTrace2D):
                if not isinstance(point_selector, int):
                    raise ValueError(
                        "Can not search for point in boxtrace as values might be strings. Try selecting by index."
                    )
                selected_trace.x[point_selector].metadata.update(dict)
                count_points_changed += 1
            elif isinstance(selected_trace, HistogramTrace):
                if not isinstance(point_selector, int):
                    raise ValueError("Can not search for points in histtrace, try selecting by index")
                selected_trace.x[point_selector].metadata.update(dict)
                count_points_changed += 1
            else:
                trace = self._cast_to_datapoint_trace(selected_trace)
                if trace is None:
                    raise ValueError("Selected Plot has no points! Verify plot- and trace-selector arguments.")
                if isinstance(trace, BarTrace2D):
                    if not isinstance(point_selector, int):
                        raise ValueError(
                            "Can not search for point in bartrace as values might be strings. "
                            + "Try searching by index."
                        )
                    trace.datapoints[point_selector].metadata.update(dict)
                    # Count this update too so the logged total is accurate.
                    count_points_changed += 1
                else:
                    count_points_changed += self._update_points_metadata(
                        trace, point_selector, point_rel_tolerance, dict
                    )
        else:
            for found_trace in self._find_traces(plot.traces, trace_selector, trace_rel_tol):
                count_points_changed += self._update_points_metadata(
                    found_trace, point_selector, point_rel_tolerance, dict
                )
        logging.info("In total, %d datapoints' metadata were updated", count_points_changed)

    def add_to_ro_crate(
        self,
        crate_path: Union[str, Path],
        file_path: str,
        *,
        create: bool = True,
        name: Optional[str] = None,
    ) -> None:
        """
        Adds the figure from this serializer to the specified ro-crate as a json file.
        If the specified ro-crate does not exist, by default, a new one will be created.
        If no name is explicitly specified, the name of the figure is used instead.
        If the figure has no name, the name of the file specified in file path is used.

        Note that this writes the figure via ``write_json_file``, so afterwards the
        figure counts as written (no further writes or metadata edits).

        Args:
            crate_path (Union[str, Path]): Path to the root folder of the ro-crate.
            file_path (str): File path within the ro-crate where the file is placed
                (excluding the path to the ro-crate itself).
            create (bool): Whether to create the ro-crate if it doesn't exist. Defaults to True.
            name (Optional[str], optional): Name of the ro-crate. Defaults to None.
        """
        _temporary_file_name = "_temporary_plotserializer_output.json"
        crate_path = Path(crate_path)
        if not file_path.endswith(".json"):
            file_path += ".json"
        if name is None:
            name = self.serialized_figure().title
        if name is None:
            name = Path(file_path).stem

        # Load crate
        if create:
            crate_path.mkdir(parents=True, exist_ok=True)
            try:
                crate = ROCrate(crate_path)
            except ValueError:
                crate = ROCrate(crate_path, init=True)
        else:
            crate = ROCrate(crate_path)

        # Stage the JSON in a private temporary directory rather than the current working
        # directory, so no existing file there is clobbered and cleanup cannot raise when
        # writing failed before the file was created.
        with tempfile.TemporaryDirectory() as temporary_dir:
            temporary_file = os.path.join(temporary_dir, _temporary_file_name)
            self.write_json_file(temporary_file)
            crate.add_file(
                source=temporary_file,
                dest_path=file_path,
                properties={
                    "name": name,
                    "encodingFormat": "application/json",
                    "conformsTo": {
                        "@id": _CURRENT_SPEC,
                    },
                },
            )
            # Write the crate while the staged file still exists (presumably ROCrate
            # copies the source file during write).
            crate.write(crate_path)

    def to_json(self, *, emit_warnings: bool = True) -> str:
        """
        Returns the data that has been collected so far as a json-encoded string.

        Args:
            emit_warnings (bool): If set to True (default), warnings about missing graph properties will be logged

        Returns:
            str: Json string
        """
        if not self._was_collected:
            self.serialized_figure()
        if emit_warnings:
            self._figure.emit_warnings()
        # Fields left at their default values are omitted (exclude_defaults=True).
        return self._figure.model_dump_json(indent=2, exclude_defaults=True)

    def write_json_file(
        self,
        file: Union[TextIO, str, "os.PathLike[str]"],
        *,
        emit_warnings: bool = True,
    ) -> None:
        """
        Writes the collected data as json to a file on disk.

        Args:
            file (Union[TextIO, str, os.PathLike]): Either a file path or a writable TextIO
                object. Missing parent directories of a path are created.
            emit_warnings (bool): If set to True (default), warnings about missing graph properties will be logged

        Raises:
            NotImplementedError: If the figure has already been written once.
        """
        if self._written_to_file:
            raise NotImplementedError("You can only write the figure into the JSON once! Multiple tries were attempted")
        # Serialize first so a failure does not leave a truncated file behind; emit_warnings
        # is honoured for paths as well as for file objects.
        serialized = self.to_json(emit_warnings=emit_warnings)
        if isinstance(file, (str, os.PathLike)):
            directory = os.path.dirname(file)
            if directory:
                os.makedirs(directory, exist_ok=True)
            # JSON is UTF-8 by specification; do not depend on the platform default encoding.
            with open(file, "w", encoding="utf-8") as handle:
                handle.write(serialized)
        else:
            file.write(serialized)
        self._written_to_file = True