From 334c360309d2eca4bda4b744e2c491afde45a478 Mon Sep 17 00:00:00 2001
From: Niclas Eich <niclas.eich@rwth-aachen.de>
Date: Wed, 26 May 2021 15:18:40 +0200
Subject: [PATCH] Added comet file for custom callbacks

---
 comet.py | 39 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)
 create mode 100644 comet.py

diff --git a/comet.py b/comet.py
new file mode 100644
index 0000000..284ace8
--- /dev/null
+++ b/comet.py
@@ -0,0 +1,39 @@
+import tensorflow as tf
+
+
+class MulticlassificationConfusionMatrix(tf.keras.callbacks.Callback):
+
+    def __init__(self, comet_experiment, x=None, y=None, sample_weights=None, batch_size=None, labels=None, freq=100):
+        self.comet_experiment = comet_experiment
+        self.data_x = x 
+        self.data_label = y 
+        self.data_weights = sample_weights 
+        self.batch_size = batch_size
+        self.freq = freq
+        self.labels=labels
+
+
+    def on_train_begin(self, logs=None):
+
+        pred = self.model.predict(x=self.data_x, batch_size=self.batch_size)
+        print("DEBUGGING!")
+        print(pred.shape)
+        print(self.data_label.shape)
+        self.comet_experiment.log_confusion_matrix(y_true=self.data_label,
+                                                   y_predicted=pred,
+                                                   labels=self.labels)
+
+    def on_epoch_end(self, epoch, logs=None):
+        if epoch // self.freq  == 0 and epoch != 0:
+            pred = self.model.predict(x=self.data_x, batch_size=self.batch_size)
+            self.comet_experiment.log_confusion_matrix(y_true=self.data_label,
+                                                       y_predicted=pred,
+                                                       labels=self.labels)
+
+
+    def on_train_end(self, logs=None):
+
+        pred = self.model.predict(x=self.data_x, batch_size=self.batch_size)
+        self.comet_experiment.log_confusion_matrix(y_true=self.data_label,
+                                                   y_predicted=pred,
+                                                   labels=self.labels)
-- 
GitLab