Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
comet.py 2.13 KiB
import tensorflow as tf
from tools.keras import PlottingCallback


class MulticlassificationConfusionMatrix(tf.keras.callbacks.Callback):
    """Keras callback that logs confusion matrices to a Comet experiment.

    The model is run on a fixed dataset (``x``/``y``) and the resulting
    predictions are logged with ``comet_experiment.log_confusion_matrix``:

    * once when training begins (title "Confusion Start", epoch 0),
    * at the end of every ``freq``-th epoch index (excluding index 0),
    * once when training ends (title "Confusion End", epoch -1).

    Args:
        comet_experiment: Object exposing ``log_confusion_matrix`` --
            presumably a ``comet_ml.Experiment``; confirm against callers.
        x: Inputs passed to ``self.model.predict``.
        y: Ground-truth labels passed as ``y_true``; must be in a format
            Comet's ``log_confusion_matrix`` accepts -- TODO confirm with
            the data pipeline.
        sample_weights: Stored as ``self.data_weights`` but not currently
            used when computing or logging the matrix.
        batch_size: Batch size forwarded to ``self.model.predict``.
        labels: Class names forwarded to ``log_confusion_matrix``.
        freq: Log every ``freq`` epochs. A falsy value (``0``/``None``)
            disables the per-epoch logging; the start/end matrices are
            still logged.
        verbose: If truthy, print a message each time a matrix is logged.
    """

    def __init__(
        self,
        comet_experiment,
        x=None,
        y=None,
        sample_weights=None,
        batch_size=None,
        labels=None,
        freq=100,
        verbose=0,
    ):
        # Let the base Callback set up the attributes Keras expects every
        # callback to have before adding our own.
        super().__init__()
        self.comet_experiment = comet_experiment
        self.data_x = x
        self.data_label = y
        self.data_weights = sample_weights  # NOTE(review): currently unused.
        self.batch_size = batch_size
        self.freq = freq
        self.labels = labels
        self.verbose = verbose

    def _log_confusion_matrix(self, title, epoch, file_name):
        """Predict on the stored data and log one confusion matrix to Comet."""
        if self.verbose:
            print("Logging confusion matrix to comet experiment")
        pred = self.model.predict(x=self.data_x, batch_size=self.batch_size)
        self.comet_experiment.log_confusion_matrix(
            y_true=self.data_label,
            y_predicted=pred,
            labels=self.labels,
            title=title,
            epoch=epoch,
            file_name=file_name,
        )

    def on_train_begin(self, logs=None):
        """Log the confusion matrix of the model as it is when training starts."""
        self._log_confusion_matrix(
            title="Confusion Start",
            epoch=0,
            file_name="confusion-matrix-0.json",
        )

    def on_epoch_end(self, epoch, logs=None):
        """Log a confusion matrix every ``freq`` epochs.

        Epoch index 0 is skipped: its file name would collide with the
        ``confusion-matrix-0.json`` written by ``on_train_begin``.
        """
        # A falsy ``freq`` disables periodic logging instead of raising
        # ZeroDivisionError from ``epoch % 0``.
        if not self.freq:
            return
        if epoch % self.freq == 0 and epoch != 0:
            self._log_confusion_matrix(
                title="Confusion Epoch {}".format(epoch),
                epoch=epoch,
                file_name="confusion-matrix-{}.json".format(epoch),
            )

    def on_train_end(self, logs=None):
        """Log the confusion matrix of the model after training finishes."""
        self._log_confusion_matrix(
            title="Confusion End",
            epoch=-1,
            file_name="confusion-matrix-end.json",
        )