From 334c360309d2eca4bda4b744e2c491afde45a478 Mon Sep 17 00:00:00 2001 From: Niclas Eich <niclas.eich@rwth-aachen.de> Date: Wed, 26 May 2021 15:18:40 +0200 Subject: [PATCH] Added comet file for custom callbacks --- comet.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 comet.py diff --git a/comet.py b/comet.py new file mode 100644 index 0000000..284ace8 --- /dev/null +++ b/comet.py @@ -0,0 +1,39 @@ +import tensorflow as tf + + +class MulticlassificationConfusionMatrix(tf.keras.callbacks.Callback): + + def __init__(self, comet_experiment, x=None, y=None, sample_weights=None, batch_size=None, labels=None, freq=100): + self.comet_experiment = comet_experiment + self.data_x = x + self.data_label = y + self.data_weights = sample_weights + self.batch_size = batch_size + self.freq = freq + self.labels=labels + + + def on_train_begin(self, logs=None): + + pred = self.model.predict(x=self.data_x, batch_size=self.batch_size) + print("DEBUGGING!") + print(pred.shape) + print(self.data_label.shape) + self.comet_experiment.log_confusion_matrix(y_true=self.data_label, + y_predicted=pred, + labels=self.labels) + + def on_epoch_end(self, epoch, logs=None): + if epoch // self.freq == 0 and epoch != 0: + pred = self.model.predict(x=self.data_x, batch_size=self.batch_size) + self.comet_experiment.log_confusion_matrix(y_true=self.data_label, + y_predicted=pred, + labels=self.labels) + + + def on_train_end(self, logs=None): + + pred = self.model.predict(x=self.data_x, batch_size=self.batch_size) + self.comet_experiment.log_confusion_matrix(y_true=self.data_label, + y_predicted=pred, + labels=self.labels) -- GitLab