From 6e0fbf294f6e3a7e419aacface862d277148d7fd Mon Sep 17 00:00:00 2001
From: Dennis Noll <dennis.noll@rwth-aachen.de>
Date: Fri, 14 Aug 2020 11:14:42 +0200
Subject: [PATCH] [keras] PlotMulticlass callback: adds tags

---
 keras.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/keras.py b/keras.py
index b8cebe6..434b30c 100644
--- a/keras.py
+++ b/keras.py
@@ -400,6 +400,7 @@ class PlotMulticlass(TFSummaryCallback):
         columns=None,
         plot_inputs=False,
         signalvsbkg=False,
+        tag="",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -411,6 +412,7 @@ class PlotMulticlass(TFSummaryCallback):
         self.columns = columns
         self.to_file = to_file
         self.signalvsbkg = signalvsbkg
+        self.tag = tag
 
     def on_test_begin(self, logs=None):
         self.on_train_begin(logs=logs)
@@ -449,7 +451,7 @@ class PlotMulticlass(TFSummaryCallback):
             )
             for name, img in imgs.items():
                 with self.writer.as_default():
-                    tf.summary.image(name, img, step=0)
+                    tf.summary.image(f"{name}{self.tag}", img, step=0)
 
     def on_test_end(self, logs=None):
         self.on_epoch_end(epoch=0, logs=logs)
@@ -519,7 +521,7 @@ class PlotMulticlass(TFSummaryCallback):
         )
         for name, img in imgs.items():
             with self.writer.as_default():
-                tf.summary.image(name, img, step=epoch)
+                tf.summary.image(f"{name}{self.tag}", img, step=epoch)
 
 
 class ModelLH(tf.keras.Model):
-- 
GitLab