Skip to content
Snippets Groups Projects
Commit 283c73f2 authored by Dennis Noll's avatar Dennis Noll
Browse files

[tools] comet: black

parent d2340589
No related branches found
No related tags found
No related merge requests found
import tensorflow as tf import tensorflow as tf
from tools.keras import PlottingCallback from tools.keras import PlottingCallback
class MulticlassificationConfusionMatrix(tf.keras.callbacks.Callback): class MulticlassificationConfusionMatrix(tf.keras.callbacks.Callback):
def __init__(self, comet_experiment, x=None, y=None, sample_weights=None, batch_size=None, labels=None, freq=100, verbose=0): def __init__(
self,
comet_experiment,
x=None,
y=None,
sample_weights=None,
batch_size=None,
labels=None,
freq=100,
verbose=0,
):
self.comet_experiment = comet_experiment self.comet_experiment = comet_experiment
self.data_x = x self.data_x = x
self.data_label = y self.data_label = y
self.data_weights = sample_weights self.data_weights = sample_weights
self.batch_size = batch_size self.batch_size = batch_size
self.freq = freq self.freq = freq
self.labels = labels self.labels = labels
self.verbose = verbose self.verbose = verbose
def on_train_begin(self, logs=None): def on_train_begin(self, logs=None):
if self.verbose: if self.verbose:
print("Logging confusion matrix to comet experiment") print("Logging confusion matrix to comet experiment")
pred = self.model.predict(x=self.data_x, batch_size=self.batch_size) pred = self.model.predict(x=self.data_x, batch_size=self.batch_size)
self.comet_experiment.log_confusion_matrix(y_true=self.data_label, self.comet_experiment.log_confusion_matrix(
y_predicted=pred, y_true=self.data_label,
labels=self.labels, y_predicted=pred,
title="Confusion Start", labels=self.labels,
epoch=0, title="Confusion Start",
file_name="confusion-matrix-0.json") epoch=0,
file_name="confusion-matrix-0.json",
)
def on_epoch_end(self, epoch, logs=None): def on_epoch_end(self, epoch, logs=None):
if epoch % self.freq == 0 and epoch != 0: if epoch % self.freq == 0 and epoch != 0:
if self.verbose: if self.verbose:
print("Logging confusion matrix to comet experiment") print("Logging confusion matrix to comet experiment")
pred = self.model.predict(x=self.data_x, batch_size=self.batch_size) pred = self.model.predict(x=self.data_x, batch_size=self.batch_size)
self.comet_experiment.log_confusion_matrix(y_true=self.data_label, self.comet_experiment.log_confusion_matrix(
y_predicted=pred, y_true=self.data_label,
labels=self.labels, y_predicted=pred,
title="Confusion Epoch {}".format(epoch), labels=self.labels,
epoch=epoch, title="Confusion Epoch {}".format(epoch),
file_name="confusion-matrix-{}.json".format(epoch)) epoch=epoch,
file_name="confusion-matrix-{}.json".format(epoch),
)
def on_train_end(self, logs=None): def on_train_end(self, logs=None):
if self.verbose: if self.verbose:
print("Logging confusion matrix to comet experiment") print("Logging confusion matrix to comet experiment")
pred = self.model.predict(x=self.data_x, batch_size=self.batch_size) pred = self.model.predict(x=self.data_x, batch_size=self.batch_size)
self.comet_experiment.log_confusion_matrix(y_true=self.data_label, self.comet_experiment.log_confusion_matrix(
y_predicted=pred, y_true=self.data_label,
labels=self.labels, y_predicted=pred,
title="Confusion End", labels=self.labels,
epoch=-1, title="Confusion End",
file_name="confusion-matrix-end.json") epoch=-1,
file_name="confusion-matrix-end.json",
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment