From cee755b73e0ec8b698f3bde1a692ce87a50d2ef1 Mon Sep 17 00:00:00 2001
From: Dennis Noll <dennis.noll@rwth-aachen.de>
Date: Fri, 17 Jul 2020 11:23:50 +0200
Subject: [PATCH] [keras] metrics+plotting: added one-vs-all AUC + plotting
 plotting: - roc_curve (+log) - node_activation (+disjoint) (+weighted)

---
 keras.py    | 105 +++++++++++++++++++++++++++++++++++++++++++++++-----
 plotting.py |  79 +++++++++++++++++++++++++++++++--------
 2 files changed, 160 insertions(+), 24 deletions(-)

diff --git a/keras.py b/keras.py
index f2fc84e..361ab65 100644
--- a/keras.py
+++ b/keras.py
@@ -19,6 +19,7 @@ from .plotting import (
     figure_confusion_matrix,
     figure_activations,
     figure_node_activations,
+    figure_roc_curve,
 )
 
 # various helper functions
@@ -84,7 +85,7 @@ class KFeed(object):
         return dict(
             zip(["x", "y", "sample_weight"], self.get(bal[self.train])),
             validation_data=self.get(bal[self.valid]),
-            **kwargs
+            **kwargs,
         )
 
     def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, validation=None, **kwargs):
@@ -108,7 +109,7 @@ class KFeed(object):
             ),
             validation_data=self.get(val),
             workers=0,
-            **kwargs
+            **kwargs,
         )
         if validation is not None:
             validation_copy = validation.copy()
@@ -117,7 +118,7 @@ class KFeed(object):
                 validation_copy.pop("cls", CustomValidation)(
                     **dict(
                         zip(("x", "y", "sample_weight"), ret.pop("validation_data")),
-                        **validation_copy
+                        **validation_copy,
                     )
                 ),
             )
@@ -393,7 +394,7 @@ class CustomValidation(tf.keras.callbacks.Callback):
 class PlotMulticlass(CustomValidation):
     def __init__(self, logdir=None, class_names=["signal", "background"], **kwargs):
         super().__init__(**kwargs)
-        self.predict_values = kwargs["x"]
+        self.x = kwargs["x"]
         self.truth = kwargs["y"]
         self.sample_weight = kwargs.get("sample_weight", None)
         self.class_names = class_names
@@ -404,10 +405,24 @@ class PlotMulticlass(CustomValidation):
         self.make_plots(epoch, logs)
 
     def make_plots(self, epoch, logs):
-        prediction = self.model.predict(self.predict_values)
+        prediction = self.model.predict(self.x)
         truth = self.truth
 
         imgs = {}
+        imgs["roc_curve"] = figure_to_image(
+            figure_roc_curve(
+                truth, prediction, class_names=self.class_names, sample_weight=self.sample_weight
+            )
+        )
+        imgs["roc_curve_log"] = figure_to_image(
+            figure_roc_curve(
+                truth,
+                prediction,
+                class_names=self.class_names,
+                sample_weight=self.sample_weight,
+                scale="log",
+            )
+        )
         imgs["confusion_matrix_true"] = figure_to_image(
             figure_confusion_matrix(
                 truth,
@@ -429,9 +444,26 @@ class PlotMulticlass(CustomValidation):
         imgs["activation"] = figure_to_image(
             figure_activations(prediction, class_names=self.class_names)
         )
-        imgs["node_activation"] = figure_to_image(
+        imgs["node_activation_unweighted"] = figure_to_image(
             figure_node_activations(prediction, truth, class_names=self.class_names)
         )
+        imgs["node_activation"] = figure_to_image(
+            figure_node_activations(
+                prediction, truth, class_names=self.class_names, sample_weight=self.sample_weight
+            )
+        )
+        imgs["node_activation_disjoint_unweighted"] = figure_to_image(
+            figure_node_activations(prediction, truth, class_names=self.class_names, disjoint=True)
+        )
+        imgs["node_activation_disjoint"] = figure_to_image(
+            figure_node_activations(
+                prediction,
+                truth,
+                class_names=self.class_names,
+                disjoint=True,
+                sample_weight=self.sample_weight,
+            )
+        )
         for name, img in imgs.items():
             with self.file_writer.as_default():
                 tf.summary.image(name, img, step=epoch)
@@ -529,7 +561,7 @@ class PatientTracker(BestTracker):
         override=None,
         jank=np.nan,
         jank_last=-np.inf,
-        **kwargs
+        **kwargs,
     ):
         if override is None:
             del override
@@ -716,7 +748,62 @@ class TQES(EarlyStopping):
         self.tqE.close()
 
 
-def classification_metrics():
+class AUCOneVsAll(tf.keras.metrics.AUC):
+    """AUC of a single class against all other classes of a multiclass output.
+
+    Only column `one` of `y_true` and `y_pred` enters the confusion-matrix
+    accumulators, so the metric is the binary AUC of "class `one` vs. rest".
+
+    Args:
+        one: Index (second axis) of the class treated as the positive class.
+        *args, **kwargs: Forwarded unchanged to `tf.keras.metrics.AUC`.
+
+    NOTE(review): intended for the default `multi_label=False` without
+    `label_weights`; the sliced inputs are one-dimensional, which the
+    multi-label code path of the base class presumably rejects - confirm.
+    """
+
+    def __init__(self, one=0, *args, **kwargs):
+        # Column index of the positive ("one") class.
+        self.one = one
+        super().__init__(*args, **kwargs)
+
+    def update_state(self, y_true, y_pred, sample_weight=None):
+        """Accumulate confusion-matrix statistics for class `one` only.
+
+        Args:
+            y_true: Truth labels; indexed as `y_true[:, self.one]`.
+            y_pred: Predictions; indexed as `y_pred[:, self.one]`.
+            sample_weight: Optional per-example weights, passed through as is.
+
+        Returns:
+            Whatever the base-class `update_state` returns.
+        """
+        # Delegating to the public base-class method instead of re-implementing
+        # it keeps the metric independent of private `tensorflow.python.*`
+        # modules, whose layout differs between TensorFlow releases.
+        return super().update_state(
+            y_true[:, self.one],
+            y_pred[:, self.one],
+            sample_weight=sample_weight,
+        )
+
+    def get_config(self):
+        """Return the serialisable configuration including `one`.
+
+        Without this override the class index would be missing from the config
+        and a metric restored via `from_config` would silently fall back to
+        `one=0`.
+
+        Returns:
+            dict: The base-class config extended by the `"one"` entry.
+        """
+        config = super().get_config()
+        config["one"] = self.one
+        return config
+
+
+def classification_metrics(classes=2):
     return [
         tf.keras.metrics.CategoricalAccuracy(name="accuracy"),
         tf.keras.metrics.CategoricalCrossentropy(name="crossentropy"),
@@ -727,7 +814,7 @@ def classification_metrics():
         tf.keras.metrics.Precision(name="precision"),
         tf.keras.metrics.Recall(name="recall"),
         tf.keras.metrics.AUC(name="auc"),
-    ]
+    ] + [AUCOneVsAll(one=i, name=f"custom_auc_{i}") for i in range(classes) if classes > 2]
 
 
 class DenseLayer(tf.keras.layers.Layer):
diff --git a/plotting.py b/plotting.py
index 45f3e45..ceeec28 100644
--- a/plotting.py
+++ b/plotting.py
@@ -6,7 +6,9 @@ import tensorflow as tf
 import itertools
 import numpy as np
 from matplotlib import pyplot as plt
-from sklearn.metrics import confusion_matrix
+from sklearn.metrics import confusion_matrix, roc_curve, auc
+
+from .numpy import one_hot
 
 
 class Quadrature:
@@ -30,7 +32,7 @@ def figure_confusion_matrix(
     class_names=["signal", "background"],
     sample_weight=None,
     normalize="true",
-    **kwargs
+    **kwargs,
 ):
     assert len(class_names) == truth.shape[-1] == prediction.shape[-1]
     fig, ax = plt.subplots()
@@ -52,8 +54,8 @@ def figure_confusion_matrix(
     for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
         plt.text(j, i, np.around(cm[i, j], decimals=2), horizontalalignment="center", size=7)
 
-    plt.ylabel("True label")
-    plt.xlabel("Predicted label")
+    plt.ylabel("True label" + " (normed)" * (normalize == "true"))
+    plt.xlabel("Predicted label" + " (normed)" * (normalize == "pred"))
     fig.tight_layout()
     return fig
 
@@ -67,7 +69,7 @@ def figure_activations(activations, class_names=None):
         plt.hist(
             activations[:, i],
             bins,
-            histtype=u"step",
+            histtype="step",
             density=True,
             label="%i" % i if class_names is None else class_names[i],
         )
@@ -78,7 +80,9 @@ def figure_activations(activations, class_names=None):
     return fig
 
 
-def figure_node_activations(activations, truth, class_names=None):
+def figure_node_activations(
+    activations, truth, class_names=None, disjoint=False, sample_weight=None
+):
     n_b, n_p = activations.shape
     quad = Quadrature(n_p)
 
@@ -87,24 +91,69 @@ def figure_node_activations(activations, truth, class_names=None):
     bins = np.linspace(0, 1.0, 10)
 
     process_activations = []
+    process_weights = []
     for process in range(n_p):
-        process_activations.append(activations[truth[:, process]].swapaxes(0, 1))
+        if disjoint:
+            max_activations = np.argmax(activations, axis=-1)
+            one_hot_max_activations = one_hot(max_activations)
+            values = (activations * one_hot_max_activations)[truth[:, process]].swapaxes(0, 1)
+            values[values == 0] = -10000
+        else:
+            values = activations[truth[:, process]].swapaxes(0, 1)
+        process_activations.append(values)
+        if sample_weight is not None:
+            process_weights.append(sample_weight[truth[:, process]])
 
     for node in range(n_p):
+        ax_index = quad.index(node)
         for process in range(n_p):
-            ax[quad.index(node)].hist(
-                process_activations[process][node],
-                bins,
-                histtype=u"step",
-                density=True,
-                label="%i" % process if class_names is None else class_names[process],
-            )
+            label = "%i" % process if class_names is None else class_names[process]
+            plot_kwargs = {"histtype": "step", "density": True, "label": label, "range": (0.0, 1.0)}
+            if len(process_weights) == n_p:
+                plot_kwargs["weights"] = process_weights[process]
+
+            old_err = np.seterr(divide="ignore", invalid="ignore")
+            ax[ax_index].hist(process_activations[process][node], bins, **plot_kwargs)
+            np.seterr(**old_err)
+
+        ax[ax_index].text(
+            0.95,
+            0.95,
+            f"node {class_names[node]}",
+            ha="right",
+            va="top",
+            transform=ax[ax_index].transAxes,
+        )
         ax[quad.index(node)].set_yscale("log")
-    ax[quad.index(cols - 1)].legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
+    ax[quad.index(cols - 1)].legend(title="processes", bbox_to_anchor=(1.05, 1.0), loc="upper left")
     fig.tight_layout()
     return fig
 
 
+def figure_roc_curve(
+    truth, prediction, indices=(0,), class_names=None, sample_weight=None, lw=2, scale="linear"
+):
+    """Plot one-vs-all ROC curves (AUC in legend) for the `indices` columns, weighting events by
+    `sample_weight`; a `scale` ending in "log" gives log axes from 1e-5. Returns the figure."""
+    fig = plt.figure()
+    for index in indices:
+        fpr, tpr, _ = roc_curve(truth[:, index], prediction[:, index], sample_weight=sample_weight)
+        name = index if class_names is None else class_names[index]
+        plt.plot(fpr, tpr, lw=lw, label=f"{name} vs All (area = {auc(fpr, tpr):.2f})")
+    plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
+    log_axes = scale.endswith("log")
+    if log_axes:
+        plt.xscale(scale)
+        plt.yscale(scale)
+    plt.xlim([1e-5 if log_axes else 0.0, 1.0])
+    plt.ylim([1e-5 if log_axes else 0.0, 1.05])
+    plt.xlabel("False Positive Rate")
+    plt.ylabel("True Positive Rate")
+    plt.title("Receiver operating characteristic curve")
+    plt.legend(loc="lower right")
+    return fig
+
+
 def figure_to_image(figure):
     """Converts the matplotlib plot specified by 'figure' to a PNG image and
     returns it. The supplied figure is closed and inaccessible after this call."""
-- 
GitLab