diff --git a/comet.py b/comet.py index 284ace8cfe8a19c2bc1165c00e380ec829cb7d0e..51efd685e97fb61e0d26682893746018ee1e138c 100644 --- a/comet.py +++ b/comet.py @@ -21,14 +21,16 @@ class MulticlassificationConfusionMatrix(tf.keras.callbacks.Callback): print(self.data_label.shape) self.comet_experiment.log_confusion_matrix(y_true=self.data_label, y_predicted=pred, - labels=self.labels) + labels=self.labels, + title="Confusion Start") def on_epoch_end(self, epoch, logs=None): if epoch // self.freq == 0 and epoch != 0: pred = self.model.predict(x=self.data_x, batch_size=self.batch_size) self.comet_experiment.log_confusion_matrix(y_true=self.data_label, y_predicted=pred, - labels=self.labels) + labels=self.labels, + title="Confusion Epoch {}".format(epoch)) def on_train_end(self, logs=None): @@ -36,4 +38,5 @@ class MulticlassificationConfusionMatrix(tf.keras.callbacks.Callback): pred = self.model.predict(x=self.data_x, batch_size=self.batch_size) self.comet_experiment.log_confusion_matrix(y_true=self.data_label, y_predicted=pred, - labels=self.labels) + labels=self.labels, + title="Confusion End")