Skip to content
Snippets Groups Projects
Commit 6e0fbf29 authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] PlotMulticlass callback: adds tags

parent 5cf61256
No related branches found
No related tags found
No related merge requests found
......@@ -400,6 +400,7 @@ class PlotMulticlass(TFSummaryCallback):
columns=None,
plot_inputs=False,
signalvsbkg=False,
tag="",
**kwargs,
):
super().__init__(**kwargs)
......@@ -411,6 +412,7 @@ class PlotMulticlass(TFSummaryCallback):
self.columns = columns
self.to_file = to_file
self.signalvsbkg = signalvsbkg
self.tag = tag
def on_test_begin(self, logs=None):
self.on_train_begin(logs=logs)
......@@ -449,7 +451,7 @@ class PlotMulticlass(TFSummaryCallback):
)
for name, img in imgs.items():
with self.writer.as_default():
tf.summary.image(name, img, step=0)
tf.summary.image(f"{name}{self.tag}", img, step=0)
def on_test_end(self, logs=None):
self.on_epoch_end(epoch=0, logs=logs)
......@@ -519,7 +521,7 @@ class PlotMulticlass(TFSummaryCallback):
)
for name, img in imgs.items():
with self.writer.as_default():
tf.summary.image(name, img, step=epoch)
tf.summary.image(f"{name}{self.tag}", img, step=epoch)
class ModelLH(tf.keras.Model):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment