Skip to content
Snippets Groups Projects
Commit c2d89107 authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] PlottingCallback: added data exploration at begin of epoch

parent 62a03057
No related branches found
No related tags found
No related merge requests found
......@@ -399,7 +399,7 @@ class TFSummaryCallback(tf.keras.callbacks.Callback):
super().__init__(**kwargs)
class PlotMulticlass(TFSummaryCallback, CustomValidation):
class PlotMulticlass(TFSummaryCallback):
def __init__(self, class_names=["signal", "background"], to_file=False, columns=None, **kwargs):
super().__init__(**kwargs)
self.x = kwargs["x"]
......@@ -410,18 +410,23 @@ class PlotMulticlass(TFSummaryCallback, CustomValidation):
self.to_file = to_file
def on_train_begin(self, logs=None):
imgs = {}
if self.columns:
inps = self.kwargs["x"]
inps = self.x
if not isinstance(inps, (list, tuple)):
inps = [inps]
imgs = {}
for part, inp in zip(self.columns.keys(), inps):
imgs[f"input_{part}"] = figure_to_image(
figure_multihist(inp, columns=self.columns[part])
)
for name, img in imgs.items():
with self.writer.as_default():
tf.summary.image(name, img, step=0)
imgs["processes"] = figure_to_image(figure_y(self.truth, class_names=self.class_names))
imgs["processes_relative"] = figure_to_image(
figure_y(self.truth, class_names=self.class_names, relative=True)
)
for name, img in imgs.items():
with self.writer.as_default():
tf.summary.image(name, img, step=0)
def on_epoch_end(self, epoch, logs=None):
super().on_epoch_end(epoch, logs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment