From 48431ab459028649b6fab9209ad8737207b1f27f Mon Sep 17 00:00:00 2001
From: Dennis Noll <dennis.noll@rwth-aachen.de>
Date: Wed, 15 Jul 2020 09:55:10 +0200
Subject: [PATCH] [keras] PlotMulticlass: add callback for multi-class
 classification plots

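PlotMulticlass subclasses CustomValidation: at the end of every epoch it
compares the model predictions on the supplied validation data with the
one-hot truth and writes confusion matrices (normalized over truth and
over prediction), output activations and per-node activations to
TensorBoard as images. The plotting primitives live in the new
plotting.py module, and classification_metrics() bundles commonly used
Keras classification metrics.

Rough usage sketch (x_val/y_val, x_train/y_train, the class names and the
log directory are placeholders; y_val is assumed to be one-hot encoded,
CustomValidation is assumed to accept the x/y keyword arguments forwarded
here, and the import path depends on the package name):

    from <package>.keras import PlotMulticlass, classification_metrics

    cb = PlotMulticlass(
        logdir="logs/images",
        class_names=["class_a", "class_b", "class_c"],
        x=x_val,
        y=y_val,
    )
    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=classification_metrics(),
    )
    model.fit(x_train, y_train, epochs=10, callbacks=[cb])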
---
 keras.py    |  67 +++++++++++++++++++++++++++++
 plotting.py | 122 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 189 insertions(+)
 create mode 100644 plotting.py

diff --git a/keras.py b/keras.py
index 7003400..ce1bb54 100644
--- a/keras.py
+++ b/keras.py
@@ -14,6 +14,12 @@ from operator import itemgetter
 
 from .evil import pin
 from .data import SKDict, DSS
+from .plotting import (
+    figure_to_image,
+    figure_confusion_matrix,
+    figure_activations,
+    figure_node_activations,
+)
 
 # various helper functions
 
@@ -384,6 +390,53 @@ class CustomValidation(tf.keras.callbacks.Callback):
         logs.update(make_logs(self.model, logs, res, mode=ModeKeys.TEST, prefix="val_"))
 
 
+class PlotMulticlass(CustomValidation):
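+    """Validation callback that plots multi-class diagnostics to TensorBoard.
+
+    After every epoch, predictions on the supplied `x` are compared with the
+    one-hot truth `y`: confusion matrices (normalized over truth and over
+    prediction), output activations and per-node activations are written as
+    images to `logdir`.
+    """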
+    def __init__(self, logdir=None, class_names=["signal", "background"], **kwargs):
+        super().__init__(**kwargs)
+        self.predict_values = kwargs["x"]
+        self.truth = kwargs["y"]
+        self.sample_weight = kwargs.get("sample_weight", None)
+        self.class_names = class_names
+        self.file_writer = tf.summary.create_file_writer(logdir)
+
+    def on_epoch_end(self, epoch, logs=None):
+        super().on_epoch_end(epoch, logs)
+        self.make_plots(epoch, logs)
+
+    def make_plots(self, epoch, logs):
+        prediction = self.model.predict(self.predict_values)
+        truth = self.truth
+
+        imgs = {}
+        imgs["confusion_matrix_true"] = figure_to_image(
+            figure_confusion_matrix(
+                truth,
+                prediction,
+                class_names=self.class_names,
+                sample_weight=self.sample_weight,
+                normalize="true",
+            )
+        )
+        imgs["confusion_matrix_pred"] = figure_to_image(
+            figure_confusion_matrix(
+                truth,
+                prediction,
+                class_names=self.class_names,
+                sample_weight=self.sample_weight,
+                normalize="pred",
+            )
+        )
+        imgs["activation"] = figure_to_image(
+            figure_activations(prediction, class_names=self.class_names)
+        )
+        imgs["node_activation"] = figure_to_image(
+            figure_node_activations(prediction, truth, class_names=self.class_names)
+        )
+        with self.file_writer.as_default():
+            for name, img in imgs.items():
+                tf.summary.image(name, img, step=epoch)
+
+
 class ModelLH(tf.keras.Model):
     def __init__(self, *args, **kwargs):
         self.loss_hook = kwargs.pop("loss_hook", None)
@@ -661,3 +714,17 @@ class TQES(EarlyStopping):
     def on_train_end(self, logs=None):
         super(TQES, self).on_train_end(logs)
         self.tqE.close()
+
+
+def classification_metrics():
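+    """Returns a list of commonly used Keras metrics for classification models."""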
+    return [
+        tf.keras.metrics.CategoricalAccuracy(name="accuracy"),
+        tf.keras.metrics.CategoricalCrossentropy(name="crossentropy"),
+        tf.keras.metrics.TruePositives(name="tp"),
+        tf.keras.metrics.FalsePositives(name="fp"),
+        tf.keras.metrics.TrueNegatives(name="tn"),
+        tf.keras.metrics.FalseNegatives(name="fn"),
+        tf.keras.metrics.Precision(name="precision"),
+        tf.keras.metrics.Recall(name="recall"),
+        tf.keras.metrics.AUC(name="auc"),
+    ]
diff --git a/plotting.py b/plotting.py
new file mode 100644
index 0000000..45f3e45
--- /dev/null
+++ b/plotting.py
@@ -0,0 +1,122 @@
+# -*- coding: utf-8 -*-
+
+import io
+
+import tensorflow as tf
+import itertools
+import numpy as np
+from matplotlib import pyplot as plt
+from sklearn.metrics import confusion_matrix
+
+
+class Quadrature:
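+    """Lays out n plots on a near-square grid and maps flat indices to (row, col)."""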
+    def __init__(self, n):
+        self.n = n
+        self.cols = np.ceil(np.sqrt(n)).astype(int)
+        self.rows = np.ceil(n / self.cols).astype(int)
+
+    def lengths(self):
+        return self.rows, self.cols
+
+    def index(self, n):
+        row = n // self.cols
+        col = n - row * self.cols
+        return row, col
+
+
+def figure_confusion_matrix(
+    truth,
+    prediction,
+    class_names=["signal", "background"],
+    sample_weight=None,
+    normalize="true",
+    **kwargs
+):
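+    """Returns a matplotlib figure with the confusion matrix of the argmax'd
+    truth and prediction arrays, normalized over "true" or "pred"."""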
+    assert len(class_names) == truth.shape[-1] == prediction.shape[-1]
+    fig, ax = plt.subplots()
+    cm = confusion_matrix(
+        np.argmax(truth, axis=-1),
+        np.argmax(prediction, axis=-1),
+        sample_weight=sample_weight,
+        normalize=normalize,
+    )
+    cmap = "plasma" if normalize == "true" else "viridis"
+    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
+    plt.title("Confusion matrix")
+    fig.colorbar(im, ax=ax)
+
+    tick_marks = np.arange(len(class_names))
+    plt.xticks(tick_marks, class_names, rotation=45)
+    plt.yticks(tick_marks, class_names)
+
+    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
+        plt.text(j, i, np.around(cm[i, j], decimals=2), horizontalalignment="center", size=7)
+
+    plt.ylabel("True label")
+    plt.xlabel("Predicted label")
+    fig.tight_layout()
+    return fig
+
+
+def figure_activations(activations, class_names=None):
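+    """Plots normalized (density) step histograms of each output node's activation."""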
+    bins = np.linspace(0, 1.0, 10)
+    n_b, n_p = activations.shape
+    fig = plt.figure()
+
+    for i in range(n_p):
+        plt.hist(
+            activations[:, i],
+            bins,
+            histtype=u"step",
+            density=True,
+            label="%i" % i if class_names is None else class_names[i],
+        )
+
+    plt.yscale("log")
+    plt.legend()
+    fig.tight_layout()
+    return fig
+
+
+def figure_node_activations(activations, truth, class_names=None):
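+    """Plots, for each output node, the activation distribution split by true
+    class on a near-square grid of subplots."""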
+    n_b, n_p = activations.shape
+    quad = Quadrature(n_p)
+
+    rows, cols = quad.lengths()
+    # squeeze=False keeps ax two-dimensional even for a single row or column
+    fig, ax = plt.subplots(rows, cols, figsize=(15, 15 * rows / cols), squeeze=False)
+    bins = np.linspace(0, 1.0, 10)
+
+    process_activations = []
+    for process in range(n_p):
+        # cast the one-hot truth column to bool so it acts as a row mask, not indices
+        process_activations.append(activations[truth[:, process].astype(bool)].swapaxes(0, 1))
+
+    for node in range(n_p):
+        for process in range(n_p):
+            ax[quad.index(node)].hist(
+                process_activations[process][node],
+                bins,
+                histtype=u"step",
+                density=True,
+                label="%i" % process if class_names is None else class_names[process],
+            )
+        ax[quad.index(node)].set_yscale("log")
+    ax[quad.index(cols - 1)].legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
+    fig.tight_layout()
+    return fig
+
+
+def figure_to_image(figure):
+    """Converts the matplotlib plot specified by 'figure' to a PNG image and
+    returns it. The supplied figure is closed and inaccessible after this call."""
+    # Save the plot to a PNG in memory.
+    buf = io.BytesIO()
+    figure.savefig(buf, format="png")
+    # Closing the figure prevents it from being displayed directly inside
+    # the notebook.
+    plt.close(figure)
+    buf.seek(0)
+    # Convert PNG buffer to TF image
+    image = tf.image.decode_png(buf.getvalue(), channels=4)
+    # Add the batch dimension
+    image = tf.expand_dims(image, 0)
+    return image
-- 
GitLab