From 6e0fbf294f6e3a7e419aacface862d277148d7fd Mon Sep 17 00:00:00 2001 From: Dennis Noll <dennis.noll@rwth-aachen.de> Date: Fri, 14 Aug 2020 11:14:42 +0200 Subject: [PATCH] [keras] PlotMulticlass callback: adds tags --- keras.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/keras.py b/keras.py index b8cebe6..434b30c 100644 --- a/keras.py +++ b/keras.py @@ -400,6 +400,7 @@ class PlotMulticlass(TFSummaryCallback): columns=None, plot_inputs=False, signalvsbkg=False, + tag="", **kwargs, ): super().__init__(**kwargs) @@ -411,6 +412,7 @@ class PlotMulticlass(TFSummaryCallback): self.columns = columns self.to_file = to_file self.signalvsbkg = signalvsbkg + self.tag = tag def on_test_begin(self, logs=None): self.on_train_begin(logs=logs) @@ -449,7 +451,7 @@ class PlotMulticlass(TFSummaryCallback): ) for name, img in imgs.items(): with self.writer.as_default(): - tf.summary.image(name, img, step=0) + tf.summary.image(f"{name}{self.tag}", img, step=0) def on_test_end(self, logs=None): self.on_epoch_end(epoch=0, logs=logs) @@ -519,7 +521,7 @@ class PlotMulticlass(TFSummaryCallback): ) for name, img in imgs.items(): with self.writer.as_default(): - tf.summary.image(name, img, step=epoch) + tf.summary.image(f"{name}{self.tag}", img, step=epoch) class ModelLH(tf.keras.Model): -- GitLab