From 48431ab459028649b6fab9209ad8737207b1f27f Mon Sep 17 00:00:00 2001
From: Dennis Noll <dennis.noll@rwth-aachen.de>
Date: Wed, 15 Jul 2020 09:55:10 +0200
Subject: [PATCH] [keras] PlotMulticlass: add callback for plots in multiclass classification

---
 keras.py    |  67 +++++++++++++++++++++++++++++
 plotting.py | 122 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 189 insertions(+)
 create mode 100644 plotting.py

diff --git a/keras.py b/keras.py
index 7003400..ce1bb54 100644
--- a/keras.py
+++ b/keras.py
@@ -14,6 +14,12 @@
 from operator import itemgetter
 from .evil import pin
 from .data import SKDict, DSS
+from .plotting import (
+    figure_to_image,
+    figure_confusion_matrix,
+    figure_activations,
+    figure_node_activations,
+)
 
 
 # various helper functions
@@ -384,6 +390,53 @@ class CustomValidation(tf.keras.callbacks.Callback):
         logs.update(make_logs(self.model, logs, res, mode=ModeKeys.TEST, prefix="val_"))
 
 
+class PlotMulticlass(CustomValidation):
+    def __init__(self, logdir=None, class_names=["signal", "background"], **kwargs):
+        super().__init__(**kwargs)
+        self.predict_values = kwargs["x"]
+        self.truth = kwargs["y"]
+        self.sample_weight = kwargs.get("sample_weight", None)
+        self.class_names = class_names
+        self.file_writer = tf.summary.create_file_writer(logdir)
+
+    def on_epoch_end(self, epoch, logs=None):
+        super().on_epoch_end(epoch, logs)
+        self.make_plots(epoch, logs)
+
+    def make_plots(self, epoch, logs):
+        prediction = self.model.predict(self.predict_values)
+        truth = self.truth
+
+        imgs = {}
+        imgs["confusion_matrix_true"] = figure_to_image(
+            figure_confusion_matrix(
+                truth,
+                prediction,
+                class_names=self.class_names,
+                sample_weight=self.sample_weight,
+                normalize="true",
+            )
+        )
+        imgs["confusion_matrix_pred"] = figure_to_image(
+            figure_confusion_matrix(
+                truth,
+                prediction,
+                class_names=self.class_names,
+                sample_weight=self.sample_weight,
+                normalize="pred",
+            )
+        )
+        imgs["activation"] = figure_to_image(
+            figure_activations(prediction, class_names=self.class_names)
+        )
+        imgs["node_activation"] = figure_to_image(
+            figure_node_activations(prediction, truth, class_names=self.class_names)
+        )
+        for name, img in imgs.items():
+            with self.file_writer.as_default():
+                tf.summary.image(name, img, step=epoch)
+
+
 class ModelLH(tf.keras.Model):
     def __init__(self, *args, **kwargs):
         self.loss_hook = kwargs.pop("loss_hook", None)
@@ -661,3 +714,17 @@
     def on_train_end(self, logs=None):
         super(TQES, self).on_train_end(logs)
         self.tqE.close()
+
+
+def classification_metrics():
+    return [
+        tf.keras.metrics.CategoricalAccuracy(name="accuracy"),
+        tf.keras.metrics.CategoricalCrossentropy(name="crossentropy"),
+        tf.keras.metrics.TruePositives(name="tp"),
+        tf.keras.metrics.FalsePositives(name="fp"),
+        tf.keras.metrics.TrueNegatives(name="tn"),
+        tf.keras.metrics.FalseNegatives(name="fn"),
+        tf.keras.metrics.Precision(name="precision"),
+        tf.keras.metrics.Recall(name="recall"),
+        tf.keras.metrics.AUC(name="auc"),
+    ]
diff --git a/plotting.py b/plotting.py
new file mode 100644
index 0000000..45f3e45
--- /dev/null
+++ b/plotting.py
@@ -0,0 +1,122 @@
+# -*- coding: utf-8 -*-
+
+import io
+
+import tensorflow as tf
+import itertools
+import numpy as np
+from matplotlib import pyplot as plt
+from sklearn.metrics import confusion_matrix
+
+
+class Quadrature:
+    def __init__(self, n):
+        self.n = n
+        self.cols = np.ceil(np.sqrt(n)).astype(int)
+        self.rows = np.ceil(n / self.cols).astype(int)
+
+    def lengths(self):
+        return self.rows, self.cols
+
+    def index(self, n):
+        row = int(n / self.cols)
+        col = n - row * self.cols
+        return row, col
+
+
+def figure_confusion_matrix(
+    truth,
+    prediction,
+    class_names=["signal", "background"],
+    sample_weight=None,
+    normalize="true",
+    **kwargs
+):
+    assert len(class_names) == truth.shape[-1] == prediction.shape[-1]
+    fig, ax = plt.subplots()
+    cm = confusion_matrix(
+        np.argmax(truth, axis=-1),
+        np.argmax(prediction, axis=-1),
+        sample_weight=sample_weight,
+        normalize=normalize,
+    )
+    cmap = "plasma" if normalize == "true" else "viridis"
+    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
+    plt.title("Confusion matrix")
+    fig.colorbar(im, ax=ax)
+
+    tick_marks = np.arange(len(class_names))
+    plt.xticks(tick_marks, class_names, rotation=45)
+    plt.yticks(tick_marks, class_names)
+
+    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
+        plt.text(j, i, np.around(cm[i, j], decimals=2), horizontalalignment="center", size=7)
+
+    plt.ylabel("True label")
+    plt.xlabel("Predicted label")
+    fig.tight_layout()
+    return fig
+
+
+def figure_activations(activations, class_names=None):
+    bins = np.linspace(0, 1.0, 10)
+    n_b, n_p = activations.shape
+    fig = plt.figure()
+
+    for i in range(n_p):
+        plt.hist(
+            activations[:, i],
+            bins,
+            histtype=u"step",
+            density=True,
+            label="%i" % i if class_names is None else class_names[i],
+        )
+
+    plt.yscale("log")
+    plt.legend()
+    fig.tight_layout()
+    return fig
+
+
+def figure_node_activations(activations, truth, class_names=None):
+    n_b, n_p = activations.shape
+    quad = Quadrature(n_p)
+
+    rows, cols = quad.lengths()
+    fig, ax = plt.subplots(rows, cols, figsize=(15, 15 * rows / cols), squeeze=False)
+    bins = np.linspace(0, 1.0, 10)
+
+    process_activations = []
+    for process in range(n_p):
+        process_activations.append(activations[truth[:, process].astype(bool)].swapaxes(0, 1))
+
+    for node in range(n_p):
+        for process in range(n_p):
+            ax[quad.index(node)].hist(
+                process_activations[process][node],
+                bins,
+                histtype=u"step",
+                density=True,
+                label="%i" % process if class_names is None else class_names[process],
+            )
+        ax[quad.index(node)].set_yscale("log")
+    ax[quad.index(cols - 1)].legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
+    fig.tight_layout()
+    return fig
+
+
+def figure_to_image(figure):
+    """Converts the matplotlib plot specified by 'figure' to a PNG image and
+    returns it. The supplied figure is closed and inaccessible after this call."""
+    # Save the plot to a PNG in memory.
+    buf = io.BytesIO()
+    figure.savefig(buf, format="png")
+    # Closing the figure prevents it from being displayed directly inside
+    # the notebook.
+    plt.close(figure)
+    buf.seek(0)
+    # Convert PNG buffer to TF image
+    image = tf.image.decode_png(buf.getvalue(), channels=4)
+    # Add the batch dimension
+    image = tf.expand_dims(image, 0)
+    return image
-- 
GitLab
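
For reference, a minimal sketch of how the new PlotMulticlass callback and the
classification_metrics() helper might be wired into a training loop. The import path
(mypackage.keras), the model, and the toy data are placeholders; the x/y keyword
arguments are inferred from this patch and from CustomValidation's constructor, and the
real signature may require additional arguments.

    import numpy as np
    import tensorflow as tf

    # hypothetical package name; in this repo the module is keras.py with relative imports
    from mypackage.keras import PlotMulticlass, classification_metrics

    n_classes = 3
    class_names = ["signal", "background_1", "background_2"]

    # toy data: 10 input features, one-hot encoded targets
    x_train = np.random.rand(1024, 10).astype("float32")
    y_train = tf.keras.utils.to_categorical(np.random.randint(0, n_classes, 1024), n_classes)
    x_val = np.random.rand(256, 10).astype("float32")
    y_val = tf.keras.utils.to_categorical(np.random.randint(0, n_classes, 256), n_classes)

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation="relu", input_shape=(10,)),
        tf.keras.layers.Dense(n_classes, activation="softmax"),
    ])
    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=classification_metrics(),
    )

    # at each epoch end the callback predicts on (x, y) and writes the confusion-matrix
    # and activation figures as TensorBoard image summaries under logdir
    plot_cb = PlotMulticlass(
        logdir="logs/plots",
        class_names=class_names,
        x=x_val,
        y=y_val,
    )

    model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=5, callbacks=[plot_cb])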