diff --git a/keras.py b/keras.py
index 5248cbe85cc105bbea6b979f402e518696a564e6..6c5d7bbf706789d1a8ace8b4ff78602d4475e61e 100644
--- a/keras.py
+++ b/keras.py
@@ -410,8 +410,8 @@ class PlotMulticlass(TFSummaryCallback):
     ):
         super().__init__(**kwargs)
         self.x = x
-        self.truth = y
-        self.sample_weight = sample_weight
+        self.truth = y[0]
+        self.sample_weight = sample_weight[0]
         self.class_names = class_names
         self.plot_inputs = plot_inputs
         self.columns = columns