Skip to content
Snippets Groups Projects
Commit 66f23183 authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] callbacks: added input plotting

parent f56cb424
No related branches found
No related tags found
No related merge requests found
......@@ -20,6 +20,7 @@ from .plotting import (
figure_activations,
figure_node_activations,
figure_roc_curve,
figure_multihist,
)
# various helper functions
......@@ -391,14 +392,35 @@ class CustomValidation(tf.keras.callbacks.Callback):
logs.update(make_logs(self.model, logs, res, mode=ModeKeys.TEST, prefix="val_"))
class PlotMulticlass(CustomValidation):
def __init__(self, logdir=None, class_names=["signal", "background"], **kwargs):
class TFSummaryCallback(tf.keras.callbacks.Callback):
def __init__(self, **kwargs):
self.writer = tf.summary.create_file_writer(kwargs.pop("logdir"))
super().__init__(**kwargs)
class PlotMulticlass(TFSummaryCallback, CustomValidation):
def __init__(self, class_names=["signal", "background"], to_file=False, columns=None, **kwargs):
super().__init__(**kwargs)
self.x = kwargs["x"]
self.truth = kwargs["y"]
self.sample_weight = kwargs.get("sample_weight", None)
self.class_names = class_names
self.file_writer = tf.summary.create_file_writer(logdir)
self.columns = columns
self.to_file = to_file
def on_train_begin(self, logs=None):
if self.columns:
inps = self.kwargs["x"]
if not isinstance(inps, (list, tuple)):
inps = [inps]
imgs = {}
for part, inp in zip(self.columns.keys(), inps):
imgs[f"input_{part}"] = figure_to_image(
figure_multihist(inp, columns=self.columns[part])
)
for name, img in imgs.items():
with self.writer.as_default():
tf.summary.image(name, img, step=0)
def on_epoch_end(self, epoch, logs=None):
super().on_epoch_end(epoch, logs)
......@@ -407,7 +429,6 @@ class PlotMulticlass(CustomValidation):
def make_plots(self, epoch, logs):
prediction = self.model.predict(self.x)
truth = self.truth
imgs = {}
imgs["roc_curve"] = figure_to_image(
figure_roc_curve(
......@@ -465,7 +486,7 @@ class PlotMulticlass(CustomValidation):
)
)
for name, img in imgs.items():
with self.file_writer.as_default():
with self.writer.as_default():
tf.summary.image(name, img, step=epoch)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment