diff --git a/keras.py b/keras.py index b8cebe697938248fbf9c46b4bf1535b9273a4adb..434b30c0ed31b305e447a290d968bef7af88f98f 100644 --- a/keras.py +++ b/keras.py @@ -400,6 +400,7 @@ class PlotMulticlass(TFSummaryCallback): columns=None, plot_inputs=False, signalvsbkg=False, + tag="", **kwargs, ): super().__init__(**kwargs) @@ -411,6 +412,7 @@ class PlotMulticlass(TFSummaryCallback): self.columns = columns self.to_file = to_file self.signalvsbkg = signalvsbkg + self.tag = tag def on_test_begin(self, logs=None): self.on_train_begin(logs=logs) @@ -449,7 +451,7 @@ class PlotMulticlass(TFSummaryCallback): ) for name, img in imgs.items(): with self.writer.as_default(): - tf.summary.image(name, img, step=0) + tf.summary.image(f"{name}{self.tag}", img, step=0) def on_test_end(self, logs=None): self.on_epoch_end(epoch=0, logs=logs) @@ -519,7 +521,7 @@ class PlotMulticlass(TFSummaryCallback): ) for name, img in imgs.items(): with self.writer.as_default(): - tf.summary.image(name, img, step=epoch) + tf.summary.image(f"{name}{self.tag}", img, step=epoch) class ModelLH(tf.keras.Model):