diff --git a/comet.py b/comet.py
new file mode 100644
index 0000000000000000000000000000000000000000..284ace8cfe8a19c2bc1165c00e380ec829cb7d0e
--- /dev/null
+++ b/comet.py
@@ -0,0 +1,39 @@
+import tensorflow as tf
+
+
+class MulticlassificationConfusionMatrix(tf.keras.callbacks.Callback):
+
+    def __init__(self, comet_experiment, x=None, y=None, sample_weights=None, batch_size=None, labels=None, freq=100):
+        self.comet_experiment = comet_experiment
+        self.data_x = x 
+        self.data_label = y 
+        self.data_weights = sample_weights 
+        self.batch_size = batch_size
+        self.freq = freq
+        self.labels=labels
+
+
+    def on_train_begin(self, logs=None):
+
+        pred = self.model.predict(x=self.data_x, batch_size=self.batch_size)
+        print("DEBUGGING!")
+        print(pred.shape)
+        print(self.data_label.shape)
+        self.comet_experiment.log_confusion_matrix(y_true=self.data_label,
+                                                   y_predicted=pred,
+                                                   labels=self.labels)
+
+    def on_epoch_end(self, epoch, logs=None):
+        if epoch // self.freq  == 0 and epoch != 0:
+            pred = self.model.predict(x=self.data_x, batch_size=self.batch_size)
+            self.comet_experiment.log_confusion_matrix(y_true=self.data_label,
+                                                       y_predicted=pred,
+                                                       labels=self.labels)
+
+
+    def on_train_end(self, logs=None):
+
+        pred = self.model.predict(x=self.data_x, batch_size=self.batch_size)
+        self.comet_experiment.log_confusion_matrix(y_true=self.data_label,
+                                                   y_predicted=pred,
+                                                   labels=self.labels)