From cee755b73e0ec8b698f3bde1a692ce87a50d2ef1 Mon Sep 17 00:00:00 2001 From: Dennis Noll <dennis.noll@rwth-aachen.de> Date: Fri, 17 Jul 2020 11:23:50 +0200 Subject: [PATCH] [keras] metrics+plotting: added one-vs-all AUC + plotting plotting: - roc_curve (+log) - node_activation (+disjoint) (+weighted) --- keras.py | 105 +++++++++++++++++++++++++++++++++++++++++++++++----- plotting.py | 79 +++++++++++++++++++++++++++++++-------- 2 files changed, 160 insertions(+), 24 deletions(-) diff --git a/keras.py b/keras.py index f2fc84e..361ab65 100644 --- a/keras.py +++ b/keras.py @@ -19,6 +19,7 @@ from .plotting import ( figure_confusion_matrix, figure_activations, figure_node_activations, + figure_roc_curve, ) # various helper functions @@ -84,7 +85,7 @@ class KFeed(object): return dict( zip(["x", "y", "sample_weight"], self.get(bal[self.train])), validation_data=self.get(bal[self.valid]), - **kwargs + **kwargs, ) def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, validation=None, **kwargs): @@ -108,7 +109,7 @@ class KFeed(object): ), validation_data=self.get(val), workers=0, - **kwargs + **kwargs, ) if validation is not None: validation_copy = validation.copy() @@ -117,7 +118,7 @@ class KFeed(object): validation_copy.pop("cls", CustomValidation)( **dict( zip(("x", "y", "sample_weight"), ret.pop("validation_data")), - **validation_copy + **validation_copy, ) ), ) @@ -393,7 +394,7 @@ class CustomValidation(tf.keras.callbacks.Callback): class PlotMulticlass(CustomValidation): def __init__(self, logdir=None, class_names=["signal", "background"], **kwargs): super().__init__(**kwargs) - self.predict_values = kwargs["x"] + self.x = kwargs["x"] self.truth = kwargs["y"] self.sample_weight = kwargs.get("sample_weight", None) self.class_names = class_names @@ -404,10 +405,24 @@ class PlotMulticlass(CustomValidation): self.make_plots(epoch, logs) def make_plots(self, epoch, logs): - prediction = self.model.predict(self.predict_values) + prediction = self.model.predict(self.x) truth = self.truth imgs = {} + imgs["roc_curve"] = figure_to_image( + figure_roc_curve( + truth, prediction, class_names=self.class_names, sample_weight=self.sample_weight + ) + ) + imgs["roc_curve_log"] = figure_to_image( + figure_roc_curve( + truth, + prediction, + class_names=self.class_names, + sample_weight=self.sample_weight, + scale="log", + ) + ) imgs["confusion_matrix_true"] = figure_to_image( figure_confusion_matrix( truth, @@ -429,9 +444,26 @@ class PlotMulticlass(CustomValidation): imgs["activation"] = figure_to_image( figure_activations(prediction, class_names=self.class_names) ) - imgs["node_activation"] = figure_to_image( + imgs["node_activation_unweighted"] = figure_to_image( figure_node_activations(prediction, truth, class_names=self.class_names) ) + imgs["node_activation"] = figure_to_image( + figure_node_activations( + prediction, truth, class_names=self.class_names, sample_weight=self.sample_weight + ) + ) + imgs["node_activation_disjoint_unweighted"] = figure_to_image( + figure_node_activations(prediction, truth, class_names=self.class_names, disjoint=True) + ) + imgs["node_activation_disjoint"] = figure_to_image( + figure_node_activations( + prediction, + truth, + class_names=self.class_names, + disjoint=True, + sample_weight=self.sample_weight, + ) + ) for name, img in imgs.items(): with self.file_writer.as_default(): tf.summary.image(name, img, step=epoch) @@ -529,7 +561,7 @@ class PatientTracker(BestTracker): override=None, jank=np.nan, jank_last=-np.inf, - **kwargs + **kwargs, ): if override is None: del override @@ -716,7 +748,62 @@ class TQES(EarlyStopping): self.tqE.close() -def classification_metrics(): +class AUCOneVsAll(tf.keras.metrics.AUC): + def __init__(self, one=0, *args, **kwargs): + self.one = one + super().__init__(*args, **kwargs) + + def update_state(self, y_true, y_pred, sample_weight=None): + from tensorflow.python.framework import ops + from tensorflow.python.framework import tensor_shape + from tensorflow.python.keras.utils import metrics_utils + from tensorflow.python.ops import check_ops + + deps = [] + if not self._built: + self._build(tensor_shape.TensorShape(y_pred.shape)) + + if self.multi_label or (self.label_weights is not None): + # y_true should have shape (number of examples, number of labels). + shapes = [(y_true, ("N", "L"))] + if self.multi_label: + # TP, TN, FP, and FN should all have shape + # (number of thresholds, number of labels). + shapes.extend( + [ + (self.true_positives, ("T", "L")), + (self.true_negatives, ("T", "L")), + (self.false_positives, ("T", "L")), + (self.false_negatives, ("T", "L")), + ] + ) + if self.label_weights is not None: + # label_weights should be of length equal to the number of labels. + shapes.append((self.label_weights, ("L",))) + deps = [check_ops.assert_shapes(shapes, message="Number of labels is not consistent.")] + + # Only forward label_weights to update_confusion_matrix_variables when + # multi_label is False. Otherwise the averaging of individual label AUCs is + # handled in AUC.result + label_weights = None if self.multi_label else self.label_weights + with ops.control_dependencies(deps): + return metrics_utils.update_confusion_matrix_variables( + { + metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, + metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, + metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, + metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, + }, + y_true[:, self.one], + y_pred[:, self.one], + self.thresholds, + sample_weight=sample_weight, + multi_label=self.multi_label, + label_weights=label_weights, + ) + + +def classification_metrics(classes=2): return [ tf.keras.metrics.CategoricalAccuracy(name="accuracy"), tf.keras.metrics.CategoricalCrossentropy(name="crossentropy"), @@ -727,7 +814,7 @@ def classification_metrics(): tf.keras.metrics.Precision(name="precision"), tf.keras.metrics.Recall(name="recall"), tf.keras.metrics.AUC(name="auc"), - ] + ] + [AUCOneVsAll(one=i, name=f"custom_auc_{i}") for i in range(classes) if classes > 2] class DenseLayer(tf.keras.layers.Layer): diff --git a/plotting.py b/plotting.py index 45f3e45..ceeec28 100644 --- a/plotting.py +++ b/plotting.py @@ -6,7 +6,9 @@ import tensorflow as tf import itertools import numpy as np from matplotlib import pyplot as plt -from sklearn.metrics import confusion_matrix +from sklearn.metrics import confusion_matrix, roc_curve, auc + +from .numpy import one_hot class Quadrature: @@ -30,7 +32,7 @@ def figure_confusion_matrix( class_names=["signal", "background"], sample_weight=None, normalize="true", - **kwargs + **kwargs, ): assert len(class_names) == truth.shape[-1] == prediction.shape[-1] fig, ax = plt.subplots() @@ -52,8 +54,8 @@ def figure_confusion_matrix( for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): plt.text(j, i, np.around(cm[i, j], decimals=2), horizontalalignment="center", size=7) - plt.ylabel("True label") - plt.xlabel("Predicted label") + plt.ylabel("True label" + " (normed)" * (normalize == "true")) + plt.xlabel("Predicted label" + " (normed)" * (normalize == "pred")) fig.tight_layout() return fig @@ -67,7 +69,7 @@ def figure_activations(activations, class_names=None): plt.hist( activations[:, i], bins, - histtype=u"step", + histtype="step", density=True, label="%i" % i if class_names is None else class_names[i], ) @@ -78,7 +80,9 @@ def figure_activations(activations, class_names=None): return fig -def figure_node_activations(activations, truth, class_names=None): +def figure_node_activations( + activations, truth, class_names=None, disjoint=False, sample_weight=None +): n_b, n_p = activations.shape quad = Quadrature(n_p) @@ -87,24 +91,69 @@ def figure_node_activations(activations, truth, class_names=None): bins = np.linspace(0, 1.0, 10) process_activations = [] + process_weights = [] for process in range(n_p): - process_activations.append(activations[truth[:, process]].swapaxes(0, 1)) + if disjoint: + max_activations = np.argmax(activations, axis=-1) + one_hot_max_activations = one_hot(max_activations) + values = (activations * one_hot_max_activations)[truth[:, process]].swapaxes(0, 1) + values[values == 0] = -10000 + else: + values = activations[truth[:, process]].swapaxes(0, 1) + process_activations.append(values) + if sample_weight is not None: + process_weights.append(sample_weight[truth[:, process]]) for node in range(n_p): + ax_index = quad.index(node) for process in range(n_p): - ax[quad.index(node)].hist( - process_activations[process][node], - bins, - histtype=u"step", - density=True, - label="%i" % process if class_names is None else class_names[process], - ) + label = "%i" % process if class_names is None else class_names[process] + plot_kwargs = {"histtype": "step", "density": True, "label": label, "range": (0.0, 1.0)} + if len(process_weights) == n_p: + plot_kwargs["weights"] = process_weights[process] + + old_err = np.seterr(divide="ignore", invalid="ignore") + ax[ax_index].hist(process_activations[process][node], bins, **plot_kwargs) + np.seterr(**old_err) + + ax[ax_index].text( + 0.95, + 0.95, + f"node {class_names[node]}", + ha="right", + va="top", + transform=ax[ax_index].transAxes, + ) ax[quad.index(node)].set_yscale("log") - ax[quad.index(cols - 1)].legend(bbox_to_anchor=(1.05, 1.0), loc="upper left") + ax[quad.index(cols - 1)].legend(title="processes", bbox_to_anchor=(1.05, 1.0), loc="upper left") fig.tight_layout() return fig +def figure_roc_curve( + truth, prediction, indices=[0], class_names=None, sample_weight=None, lw=2, scale="linear" +): + fig = plt.figure() + for index in indices: + fpr, tpr, _ = roc_curve(truth[:, index], prediction[:, index]) + roc_auc = auc(fpr, tpr) + name = index if class_names is None else class_names[index] + plt.plot(fpr, tpr, lw=lw, label=f"{name} vs All (area = {roc_auc:.2f})") + plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--") + lower = 0.0 + if scale.endswith("log"): + plt.xscale(scale) + plt.yscale(scale) + lower = 1e-5 + plt.xlim([lower, 1.0]) + plt.ylim([lower, 1.05]) + plt.xlabel("False Positive Rate") + plt.ylabel("True Positive Rate") + plt.title("Receiver operating characteristic curve") + plt.legend(loc="lower right") + return fig + + def figure_to_image(figure): """Converts the matplotlib plot specified by 'figure' to a PNG image and returns it. The supplied figure is closed and inaccessible after this call.""" -- GitLab