diff --git a/comet.py b/comet.py new file mode 100644 index 0000000000000000000000000000000000000000..284ace8cfe8a19c2bc1165c00e380ec829cb7d0e --- /dev/null +++ b/comet.py @@ -0,0 +1,39 @@ +import tensorflow as tf + + +class MulticlassificationConfusionMatrix(tf.keras.callbacks.Callback): + + def __init__(self, comet_experiment, x=None, y=None, sample_weights=None, batch_size=None, labels=None, freq=100): + self.comet_experiment = comet_experiment + self.data_x = x + self.data_label = y + self.data_weights = sample_weights + self.batch_size = batch_size + self.freq = freq + self.labels=labels + + + def on_train_begin(self, logs=None): + + pred = self.model.predict(x=self.data_x, batch_size=self.batch_size) + print("DEBUGGING!") + print(pred.shape) + print(self.data_label.shape) + self.comet_experiment.log_confusion_matrix(y_true=self.data_label, + y_predicted=pred, + labels=self.labels) + + def on_epoch_end(self, epoch, logs=None): + if epoch // self.freq == 0 and epoch != 0: + pred = self.model.predict(x=self.data_x, batch_size=self.batch_size) + self.comet_experiment.log_confusion_matrix(y_true=self.data_label, + y_predicted=pred, + labels=self.labels) + + + def on_train_end(self, logs=None): + + pred = self.model.predict(x=self.data_x, batch_size=self.batch_size) + self.comet_experiment.log_confusion_matrix(y_true=self.data_label, + y_predicted=pred, + labels=self.labels)