diff --git a/keras.py b/keras.py
index b8cebe697938248fbf9c46b4bf1535b9273a4adb..434b30c0ed31b305e447a290d968bef7af88f98f 100644
--- a/keras.py
+++ b/keras.py
@@ -400,6 +400,7 @@ class PlotMulticlass(TFSummaryCallback):
         columns=None,
         plot_inputs=False,
         signalvsbkg=False,
+        tag="",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -411,6 +412,7 @@ class PlotMulticlass(TFSummaryCallback):
         self.columns = columns
         self.to_file = to_file
         self.signalvsbkg = signalvsbkg
+        self.tag = tag
 
     def on_test_begin(self, logs=None):
         self.on_train_begin(logs=logs)
@@ -449,7 +451,7 @@ class PlotMulticlass(TFSummaryCallback):
             )
             for name, img in imgs.items():
                 with self.writer.as_default():
-                    tf.summary.image(name, img, step=0)
+                    tf.summary.image(f"{name}{self.tag}", img, step=0)
 
     def on_test_end(self, logs=None):
         self.on_epoch_end(epoch=0, logs=logs)
@@ -519,7 +521,7 @@ class PlotMulticlass(TFSummaryCallback):
         )
         for name, img in imgs.items():
             with self.writer.as_default():
-                tf.summary.image(name, img, step=epoch)
+                tf.summary.image(f"{name}{self.tag}", img, step=epoch)
 
 
 class ModelLH(tf.keras.Model):