-
Authored by Dennis Noll
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
comet.py 2.13 KiB
import tensorflow as tf
from tools.keras import PlottingCallback
class MulticlassificationConfusionMatrix(tf.keras.callbacks.Callback):
    """Keras callback that logs multiclass confusion matrices to a Comet experiment.

    Each time it logs, the callback runs ``self.model.predict`` on the stored
    inputs. It then passes the predictions, together with the stored true
    labels, to ``comet_experiment.log_confusion_matrix``. Logging happens at
    three points:

    * once before training starts (``epoch=0``, file ``confusion-matrix-0.json``);
    * at the end of every epoch whose index is a non-zero multiple of ``freq``
      (``epoch=<index>``, file ``confusion-matrix-<index>.json``);
    * once after training ends (``epoch=-1``, file ``confusion-matrix-end.json``).

    Args:
        comet_experiment: Comet experiment object. Only its
            ``log_confusion_matrix`` method is called.
        x: Inputs passed to ``model.predict``.
        y: True labels passed as ``y_true`` to ``log_confusion_matrix``.
        sample_weights: Stored as ``self.data_weights``. It is currently NOT
            used when building the confusion matrix.
        batch_size: Batch size forwarded to ``model.predict``.
        labels: Class labels forwarded to ``log_confusion_matrix``.
        freq: Log every ``freq`` epochs, based on the epoch index passed to
            ``on_epoch_end``. Must be non-zero, because it is used as a modulus.
        verbose: If truthy, print a message each time a matrix is logged.
    """

    def __init__(
        self,
        comet_experiment,
        x=None,
        y=None,
        sample_weights=None,
        batch_size=None,
        labels=None,
        freq=100,
        verbose=0,
    ):
        # Run the Keras base-class initialiser so that any attributes it sets
        # up are present on this callback.
        super().__init__()
        self.comet_experiment = comet_experiment
        self.data_x = x
        self.data_label = y
        # NOTE: not forwarded to Comet anywhere in this class.
        self.data_weights = sample_weights
        self.batch_size = batch_size
        self.freq = freq
        self.labels = labels
        self.verbose = verbose

    def _log_confusion_matrix(self, title, epoch, file_name):
        """Predict on the stored inputs and log one confusion matrix to Comet."""
        if self.verbose:
            print("Logging confusion matrix to comet experiment")
        pred = self.model.predict(x=self.data_x, batch_size=self.batch_size)
        self.comet_experiment.log_confusion_matrix(
            y_true=self.data_label,
            y_predicted=pred,
            labels=self.labels,
            title=title,
            epoch=epoch,
            file_name=file_name,
        )

    def on_train_begin(self, logs=None):
        """Log the confusion matrix of the model before any training."""
        self._log_confusion_matrix(
            title="Confusion Start",
            epoch=0,
            file_name="confusion-matrix-0.json",
        )

    def on_epoch_end(self, epoch, logs=None):
        """Log a confusion matrix when ``epoch`` is a non-zero multiple of ``freq``."""
        # Epoch index 0 is excluded. This is presumably to avoid reusing the
        # "confusion-matrix-0.json" file name written by on_train_begin.
        if epoch % self.freq == 0 and epoch != 0:
            self._log_confusion_matrix(
                title="Confusion Epoch {}".format(epoch),
                epoch=epoch,
                file_name="confusion-matrix-{}.json".format(epoch),
            )

    def on_train_end(self, logs=None):
        """Log the confusion matrix of the final model, tagged with ``epoch=-1``."""
        self._log_confusion_matrix(
            title="Confusion End",
            epoch=-1,
            file_name="confusion-matrix-end.json",
        )