Skip to content
Snippets Groups Projects
Commit 283c73f2 authored by Dennis Noll's avatar Dennis Noll
Browse files

[tools] comet: black

parent d2340589
No related branches found
No related tags found
No related merge requests found
import tensorflow as tf
from tools.keras import PlottingCallback
class MulticlassificationConfusionMatrix(tf.keras.callbacks.Callback):
def __init__(self, comet_experiment, x=None, y=None, sample_weights=None, batch_size=None, labels=None, freq=100, verbose=0):
def __init__(
self,
comet_experiment,
x=None,
y=None,
sample_weights=None,
batch_size=None,
labels=None,
freq=100,
verbose=0,
):
self.comet_experiment = comet_experiment
self.data_x = x
self.data_label = y
self.data_weights = sample_weights
self.data_x = x
self.data_label = y
self.data_weights = sample_weights
self.batch_size = batch_size
self.freq = freq
self.labels = labels
self.verbose = verbose
def on_train_begin(self, logs=None):
if self.verbose:
print("Logging confusion matrix to comet experiment")
pred = self.model.predict(x=self.data_x, batch_size=self.batch_size)
self.comet_experiment.log_confusion_matrix(y_true=self.data_label,
y_predicted=pred,
labels=self.labels,
title="Confusion Start",
epoch=0,
file_name="confusion-matrix-0.json")
self.comet_experiment.log_confusion_matrix(
y_true=self.data_label,
y_predicted=pred,
labels=self.labels,
title="Confusion Start",
epoch=0,
file_name="confusion-matrix-0.json",
)
def on_epoch_end(self, epoch, logs=None):
if epoch % self.freq == 0 and epoch != 0:
if epoch % self.freq == 0 and epoch != 0:
if self.verbose:
print("Logging confusion matrix to comet experiment")
pred = self.model.predict(x=self.data_x, batch_size=self.batch_size)
self.comet_experiment.log_confusion_matrix(y_true=self.data_label,
y_predicted=pred,
labels=self.labels,
title="Confusion Epoch {}".format(epoch),
epoch=epoch,
file_name="confusion-matrix-{}.json".format(epoch))
self.comet_experiment.log_confusion_matrix(
y_true=self.data_label,
y_predicted=pred,
labels=self.labels,
title="Confusion Epoch {}".format(epoch),
epoch=epoch,
file_name="confusion-matrix-{}.json".format(epoch),
)
def on_train_end(self, logs=None):
if self.verbose:
print("Logging confusion matrix to comet experiment")
pred = self.model.predict(x=self.data_x, batch_size=self.batch_size)
self.comet_experiment.log_confusion_matrix(y_true=self.data_label,
y_predicted=pred,
labels=self.labels,
title="Confusion End",
epoch=-1,
file_name="confusion-matrix-end.json")
self.comet_experiment.log_confusion_matrix(
y_true=self.data_label,
y_predicted=pred,
labels=self.labels,
title="Confusion End",
epoch=-1,
file_name="confusion-matrix-end.json",
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment