Skip to content
Snippets Groups Projects
Commit c6454af7 authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] PlotMulticlass callback: adds plot_input parameters

parent 3c4a65e0
No related branches found
No related tags found
No related merge requests found
......@@ -394,13 +394,19 @@ class CustomValidation(tf.keras.callbacks.Callback):
class TFSummaryCallback(tf.keras.callbacks.Callback):
def __init__(self, **kwargs):
self.writer = tf.summary.create_file_writer(kwargs.pop("logdir"))
super().__init__(**kwargs)
def __init__(self, logdir=None, **kwargs):
self.writer = tf.summary.create_file_writer(logdir)
class PlotMulticlass(TFSummaryCallback):
def __init__(self, class_names=["signal", "background"], to_file=False, columns=None, **kwargs):
def __init__(
self,
class_names=["signal", "background"],
to_file=False,
columns=None,
plot_inputs=False,
**kwargs,
):
super().__init__(**kwargs)
self.x = kwargs["x"]
self.truth = kwargs["y"]
......@@ -409,16 +415,22 @@ class PlotMulticlass(TFSummaryCallback):
self.columns = columns
self.to_file = to_file
def on_test_begin(self, logs=None):
self.on_train_begin(logs=logs)
def on_train_begin(self, logs=None):
imgs = {}
if self.columns:
inps = self.x
if not isinstance(inps, (list, tuple)):
inps = [inps]
for part, inp in zip(self.columns.keys(), inps):
imgs[f"input_{part}"] = figure_to_image(
figure_multihist(inp, columns=self.columns[part])
if self.plot_inputs:
imgs = {}
if self.columns:
inps = self.x
if not isinstance(inps, (list, tuple)):
inps = [inps]
for part, inp in zip(self.columns.keys(), inps):
imgs[f"inputs_merged_{part}"] = figure_to_image(
figure_multihist(inp, columns=self.columns[part])
)
if self.sample_weight is not None:
)
imgs["processes"] = figure_to_image(figure_y(self.truth, class_names=self.class_names))
imgs["processes_relative"] = figure_to_image(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment