diff --git a/data.py b/data.py
index eab779ffddd7366902af993b7027924704dde741..6044750490ab965826e7dfa2a7646858bbff4959 100644
--- a/data.py
+++ b/data.py
@@ -1,7 +1,7 @@
 import numpy as np
 import tensorflow as tf
 from operator import itemgetter
-from os import listdir, path
+from os import listdir, path, remove
 
 
 class SKDict(dict):
@@ -47,6 +47,11 @@ class SKDict(dict):
         else:
             super(SKDict, self).__setitem__(key, value)
 
+    def __delitem__(self, key):
+        key = self.keyify(key)
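+        # a (partial) key removes every stored item whose key contains it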
+        for k in self[key].keys():
+            super().__delitem__(k | key)
+
     def copy(self):
         return self.__class__(self)
 
@@ -216,6 +221,9 @@ class DSS(SKDict):
             }
         )
 
-    def to_npy(self, dir, sep="_", **kwargs):
+    def to_npy(self, dir, sep="_", clean=True, **kwargs):
+        if clean:
+            # drop stale .npy files so keys removed from the DSS do not linger
+            for fn in listdir(dir):
+                if fn.endswith(".npy"):
+                    remove(path.join(dir, fn))
         for key, value in self.items():
             np.save(path.join(dir, "%s.npy" % sep.join(sorted(key))), value, **kwargs)
diff --git a/evil.py b/evil.py
index 90a9993081692ee856820c73b1a17521f3d20315..8ff29bd677e41fb44216b11d33d1728ffd63f847 100644
--- a/evil.py
+++ b/evil.py
@@ -10,14 +10,17 @@ def ccall(func, pass_name=False):
         key = "word"
         def arguments(): pass
     """
+
     class metaclass(type):
         def __new__(cls, name, a, d):
             d.pop("__module__", 0)
             if pass_name:
                 d["__name__"] = name
             return func(*a[1:], **d)
+
         def __call__(*a, **d):
             return func(*a, **d)
+
     return type.__new__(metaclass, func.__name__, (), {})
 
 
@@ -31,9 +34,13 @@ def pin(loc, *skip):
         these = 2
     """
     loc = dict(loc)
-    loc.pop("self").__dict__.update({k: v for k, v in loc.items() if not any(
-        s is k or
-        s is True and k.startswith("_") or
-        s is v and isinstance(s, object)
-        for s in skip
-    )})
+    loc.pop("self").__dict__.update(
+        {
+            k: v
+            for k, v in loc.items()
+            if not any(
+                s is k or s is True and k.startswith("_") or s is v and isinstance(s, object)
+                for s in skip
+            )
+        }
+    )
diff --git a/keras.py b/keras.py
index e4b333cd1d28e2d504c0807390335ab8a6ffb503..5e6a4f95e081978f1c0e3f52ad334ae0da0143ea 100644
--- a/keras.py
+++ b/keras.py
@@ -11,6 +11,7 @@ from tensorflow.python.keras.callbacks import make_logs
 from tensorflow.python.keras.backend import track_variable
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from operator import itemgetter
+from lbn import LBNLayer as OriginalLBNLayer
 
 from .evil import pin
 from .data import SKDict, DSS
@@ -19,6 +20,11 @@ from .plotting import (
     figure_confusion_matrix,
     figure_activations,
     figure_node_activations,
+    figure_roc_curve,
+    figure_multihist,
+    figure_y,
+    figure_weights,
+    figure_inputs,
 )
 
 # various helper functions
@@ -84,10 +90,10 @@ class KFeed(object):
         return dict(
             zip(["x", "y", "sample_weight"], self.get(bal[self.train])),
             validation_data=self.get(bal[self.valid]),
-            **kwargs
+            **kwargs,
         )
 
-    def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, validation=None, **kwargs):
+    def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, **kwargs):
         """
         Creates a generator for tf.keras' model.fit().
         Requires, that mean weights per process are equal:
@@ -99,7 +105,7 @@ class KFeed(object):
         val[self.w] = self.balance_weights(val[self.w])
         val = val.fuse(*val[self.w].keys())
         val.blen
-        ret = dict(
+        return dict(
             dict(
                 zip(
                     ("x", "steps_per_epoch"),
@@ -108,20 +114,8 @@ class KFeed(object):
             ),
             validation_data=self.get(val),
             workers=0,
-            **kwargs
+            **kwargs,
         )
-        if validation is not None:
-            validation_copy = validation.copy()
-            ret.setdefault("callbacks", []).insert(
-                0,
-                validation_copy.pop("cls", CustomValidation)(
-                    **dict(
-                        zip(("x", "y", "sample_weight"), ret.pop("validation_data")),
-                        **validation_copy
-                    )
-                ),
-            )
-        return ret
 
     def gensteps(self, src, batch_size, rng=np.random, auto_steps=np.max):
         keys = src.mkeys(self.all)
@@ -394,24 +388,100 @@ class CustomValidation(tf.keras.callbacks.Callback):
         logs.update(make_logs(self.model, logs, res, mode=ModeKeys.TEST, prefix="val_"))
 
 
-class PlotMulticlass(CustomValidation):
-    def __init__(self, logdir=None, class_names=["signal", "background"], **kwargs):
+class TFSummaryCallback(tf.keras.callbacks.Callback):
+    def __init__(self, logdir=None, **kwargs):
+        # initialize the base Callback state before attaching the writer
+        super().__init__(**kwargs)
+        self.writer = tf.summary.create_file_writer(logdir)
+
+
+class PlotMulticlass(TFSummaryCallback):
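+    """Renders multiclass diagnostics (input distributions, ROC curves,
+    confusion matrices and activation histograms) as TensorBoard images."""
+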
+    def __init__(
+        self,
+        x,
+        y,
+        sample_weight=None,
+        class_names=["signal", "background"],
+        to_file=False,
+        columns=None,
+        plot_inputs=False,
+        signalvsbkg=False,
+        tag="",
+        **kwargs,
+    ):
         super().__init__(**kwargs)
-        self.predict_values = kwargs["x"]
-        self.truth = kwargs["y"]
-        self.sample_weight = kwargs.get("sample_weight", None)
+        self.x = x
+        self.truth = y
+        self.sample_weight = sample_weight
         self.class_names = class_names
-        self.file_writer = tf.summary.create_file_writer(logdir)
+        self.plot_inputs = plot_inputs
+        self.columns = columns
+        self.to_file = to_file
+        self.signalvsbkg = signalvsbkg
+        self.tag = tag
+
+    def on_test_begin(self, logs=None):
+        self.on_train_begin(logs=logs)
+
+    def on_train_begin(self, logs=None):
+        if self.plot_inputs:
+            imgs = {}
+            if self.columns:
+                inps = self.x
+                if not isinstance(inps, (list, tuple)):
+                    inps = [inps]
+
+                for part, inp in zip(self.columns.keys(), inps):
+                    imgs[f"inp_xmerged_{part}"] = figure_to_image(
+                        figure_multihist(inp, columns=self.columns[part])
+                    )
+                # the per-part weighted figures need both columns and weights,
+                # so they stay inside the columns branch
+                if self.sample_weight is not None:
+                    for part, inp in zip(self.columns.keys(), inps):
+                        imgs[f"inp_x_{part}"] = figure_to_image(
+                            figure_inputs(
+                                inp,
+                                self.truth,
+                                sample_weight=self.sample_weight,
+                                columns=self.columns[part],
+                                class_names=self.class_names,
+                                signalvsbkg=self.signalvsbkg,
+                            )
+                        )
+            if self.sample_weight is not None:
+                imgs["inp_weights"] = figure_to_image(
+                    figure_weights(self.sample_weight, self.truth, class_names=self.class_names)
+                )
+            imgs["inp_y"] = figure_to_image(figure_y(self.truth, class_names=self.class_names))
+            imgs["inp_yrelative"] = figure_to_image(
+                figure_y(self.truth, class_names=self.class_names, relative=True)
+            )
+            for name, img in imgs.items():
+                with self.writer.as_default():
+                    tf.summary.image(f"{name}{self.tag}", img, step=0)
+
+    def on_test_end(self, logs=None):
+        self.on_epoch_end(epoch=0, logs=logs)
 
     def on_epoch_end(self, epoch, logs=None):
         super().on_epoch_end(epoch, logs)
         self.make_plots(epoch, logs)
 
     def make_plots(self, epoch, logs):
-        prediction = self.model.predict(self.predict_values)
+        prediction = self.model.predict(self.x)
         truth = self.truth
-
         imgs = {}
+        imgs["roc_curve"] = figure_to_image(
+            figure_roc_curve(
+                truth, prediction, class_names=self.class_names, sample_weight=self.sample_weight
+            )
+        )
+        imgs["roc_curve_log"] = figure_to_image(
+            figure_roc_curve(
+                truth,
+                prediction,
+                class_names=self.class_names,
+                sample_weight=self.sample_weight,
+                scale="log",
+            )
+        )
         imgs["confusion_matrix_true"] = figure_to_image(
             figure_confusion_matrix(
                 truth,
@@ -433,12 +503,29 @@ class PlotMulticlass(CustomValidation):
         imgs["activation"] = figure_to_image(
             figure_activations(prediction, class_names=self.class_names)
         )
-        imgs["node_activation"] = figure_to_image(
+        imgs["node_activation_unweighted"] = figure_to_image(
             figure_node_activations(prediction, truth, class_names=self.class_names)
         )
+        imgs["node_activation"] = figure_to_image(
+            figure_node_activations(
+                prediction, truth, class_names=self.class_names, sample_weight=self.sample_weight
+            )
+        )
+        imgs["node_activation_disjoint_unweighted"] = figure_to_image(
+            figure_node_activations(prediction, truth, class_names=self.class_names, disjoint=True)
+        )
+        imgs["node_activation_disjoint"] = figure_to_image(
+            figure_node_activations(
+                prediction,
+                truth,
+                class_names=self.class_names,
+                disjoint=True,
+                sample_weight=self.sample_weight,
+            )
+        )
         for name, img in imgs.items():
-            with self.file_writer.as_default():
-                tf.summary.image(name, img, step=epoch)
+            with self.writer.as_default():
+                tf.summary.image(f"{name}{self.tag}", img, step=epoch)
 
 
 class ModelLH(tf.keras.Model):
@@ -487,6 +574,21 @@ class TensorBoard(tf.keras.callbacks.TensorBoard):
         pass
 
 
+class CheckpointModel(tf.keras.callbacks.Callback):
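+    """Saves the full model every `frequency` epochs below `savedir`."""
+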
+    def __init__(self, savedir="tmp", frequency=1, identifier="cp"):
+        pin(locals())
+
+    def get_index(self, epoch):
+        return epoch
+
+    def checkpoint_dir(self, epoch):
+        return f"{self.savedir}/{self.identifier}-{self.get_index(epoch)}"
+
+    def on_epoch_end(self, epoch, logs=None):
+        if epoch % self.frequency == 0:
+            self.model.save(self.checkpoint_dir(epoch))
+
+
 class BestTracker(tf.keras.callbacks.Callback):
     def __init__(
         self, monitor="val_loss", mode="auto", min_delta=0, min_delta_rel=0, baseline=None
@@ -533,7 +635,7 @@ class PatientTracker(BestTracker):
         override=None,
         jank=np.nan,
         jank_last=-np.inf,
-        **kwargs
+        **kwargs,
     ):
         if override is None:
             del override
@@ -723,18 +825,59 @@ class TQES(EarlyStopping):
         self.tqE.close()
 
 
-def classification_metrics():
-    return [
-        tf.keras.metrics.CategoricalAccuracy(name="accuracy"),
-        tf.keras.metrics.CategoricalCrossentropy(name="crossentropy"),
-        tf.keras.metrics.TruePositives(name="tp"),
-        tf.keras.metrics.FalsePositives(name="fp"),
-        tf.keras.metrics.TrueNegatives(name="tn"),
-        tf.keras.metrics.FalseNegatives(name="fn"),
-        tf.keras.metrics.Precision(name="precision"),
-        tf.keras.metrics.Recall(name="recall"),
-        tf.keras.metrics.AUC(name="auc"),
-    ]
+class AUCOneVsAll(tf.keras.metrics.AUC):
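+    """AUC of class `one` against all other classes.
+
+    update_state mirrors tf.keras.metrics.AUC.update_state, but slices the
+    one-hot y_true and y_pred down to column `one` before updating the
+    confusion matrix variables.
+    """
+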
+    def __init__(self, one=0, *args, **kwargs):
+        self.one = one
+        super().__init__(*args, **kwargs)
+
+    def update_state(self, y_true, y_pred, sample_weight=None):
+        from tensorflow.python.framework import ops
+        from tensorflow.python.framework import tensor_shape
+        from tensorflow.python.keras.utils import metrics_utils
+        from tensorflow.python.ops import check_ops
+
+        deps = []
+        if not self._built:
+            self._build(tensor_shape.TensorShape(y_pred.shape))
+
+        if self.multi_label or (self.label_weights is not None):
+            # y_true should have shape (number of examples, number of labels).
+            shapes = [(y_true, ("N", "L"))]
+            if self.multi_label:
+                # TP, TN, FP, and FN should all have shape
+                # (number of thresholds, number of labels).
+                shapes.extend(
+                    [
+                        (self.true_positives, ("T", "L")),
+                        (self.true_negatives, ("T", "L")),
+                        (self.false_positives, ("T", "L")),
+                        (self.false_negatives, ("T", "L")),
+                    ]
+                )
+            if self.label_weights is not None:
+                # label_weights should be of length equal to the number of labels.
+                shapes.append((self.label_weights, ("L",)))
+            deps = [check_ops.assert_shapes(shapes, message="Number of labels is not consistent.")]
+
+        # Only forward label_weights to update_confusion_matrix_variables when
+        # multi_label is False. Otherwise the averaging of individual label AUCs is
+        # handled in AUC.result
+        label_weights = None if self.multi_label else self.label_weights
+        with ops.control_dependencies(deps):
+            return metrics_utils.update_confusion_matrix_variables(
+                {
+                    metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
+                    metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
+                    metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
+                    metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
+                },
+                y_true[:, self.one],
+                y_pred[:, self.one],
+                self.thresholds,
+                sample_weight=sample_weight,
+                multi_label=self.multi_label,
+                label_weights=label_weights,
+            )
 
 
 class DenseLayer(tf.keras.layers.Layer):
@@ -765,25 +908,32 @@ class DenseLayer(tf.keras.layers.Layer):
     """
 
     def __init__(self, nodes=0, activation=None, dropout=0.0, l2=0, batch_norm=False):
-        super().__init__()
+        super().__init__(name="DenseLayer")
+        self.nodes = nodes
+        self.activation = activation
+        self.dropout = dropout
+        self.l2 = l2
+        self.batch_norm = batch_norm
+
+    def build(self, input_shape):
         parts = []
 
-        l2 = tf.keras.regularizers.l2(l2 if l2 else 0.0)
-        weights = tf.keras.layers.Dense(nodes, kernel_regularizer=l2)
+        l2 = tf.keras.regularizers.l2(self.l2)
+        weights = tf.keras.layers.Dense(self.nodes, kernel_regularizer=l2)
         parts.append(weights)
 
-        if batch_norm:
-            dropout = 0.0
+        # use a local rate so build() does not overwrite the configured
+        # dropout that get_config() reports
+        dropout_rate = 0.0 if self.batch_norm else self.dropout
+        if self.batch_norm:
             bn = tf.keras.layers.BatchNormalization()
             parts.append(bn)
 
-        act = tf.keras.layers.Activation(activation)
+        act = tf.keras.layers.Activation(self.activation)
         parts.append(act)
 
-        if activation == "selu":
-            dropout = tf.keras.layers.AlphaDropout(dropout)
+        if self.activation == "selu":
+            dropout = tf.keras.layers.AlphaDropout(dropout_rate)
         else:
-            dropout = tf.keras.layers.Dropout(dropout)
+            dropout = tf.keras.layers.Dropout(dropout_rate)
 
         parts.append(dropout)
         self.parts = parts
@@ -794,6 +944,15 @@ class DenseLayer(tf.keras.layers.Layer):
             x = part(x, training=training)
         return x
 
+    def get_config(self):
+        return {
+            "nodes": self.nodes,
+            "activation": self.activation,
+            "dropout": self.dropout,
+            "l2": self.l2,
+            "batch_norm": self.batch_norm,
+        }
+
 
 class ResNetBlock(tf.keras.layers.Layer):
     """
@@ -810,9 +969,9 @@ class ResNetBlock(tf.keras.layers.Layer):
 
     def __init__(self, config, jump=2, **kwargs):
         super().__init__(name="ResNetBlock")
-
+        self.jump = jump
         layers = []
-        for i in range(jump - 1):
+        for i in range(self.jump - 1):
             layers.append(DenseLayer(**kwargs))
 
         activation = kwargs.pop("activation")
@@ -828,6 +987,9 @@ class ResNetBlock(tf.keras.layers.Layer):
         x = self.out_activation(x)
         return x
 
+    def get_config(self):
+        return {"jump": self.jump}
+
 
 class FullyConnected(tf.keras.layers.Layer):
     """
@@ -835,27 +997,33 @@ class FullyConnected(tf.keras.layers.Layer):
 
     Parameters
     ----------
-    number_layers : int
+    layers : int
         The number of layers.
     kwargs :
         Arguments for DenseLayer.
 
     """
 
-    def __init__(self, number_layers=0, **kwargs):
+    def __init__(self, layers=0, sub_kwargs=None, **kwargs):
         super().__init__(name="FullyConnected")
-
-        layers = []
-        for layer in range(number_layers):
-            layers.append(DenseLayer(**kwargs))
         self.layers = layers
+        self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs
+
+    def build(self, input_shape):
+        network_layers = []
+        for layer in range(self.layers):
+            network_layers.append(DenseLayer(**self.sub_kwargs))
+        self.network_layers = network_layers
 
     def call(self, input_tensor, training=False):
         x = input_tensor
-        for layer in self.layers:
+        for layer in self.network_layers:
             x = layer(x, training=training)
         return x
 
+    def get_config(self):
+        return {"layers": self.layers, "sub_kwargs": self.sub_kwargs}
+
 
 class ResNet(tf.keras.layers.Layer):
     """
@@ -863,23 +1031,95 @@ class ResNet(tf.keras.layers.Layer):
 
     Parameters
     ----------
-    number_layers : int
+    layers : int
         The number of residual blocks.
     kwargs :
         Arguments for ResNetBlock.
 
     """
 
-    def __init__(self, number_layers=1, **kwargs):
+    def __init__(self, layers=1, **kwargs):
         super().__init__(name="ResNet")
-
-        layers = []
-        for i in range(number_layers):
-            layers.append(ResNetBlock(**kwargs))
         self.layers = layers
+        self.kwargs = kwargs
+
+    def build(self, input_shape):
+        _layers = []
+        for i in range(self.layers):
+            _layers.append(ResNetBlock(**self.kwargs))
+        self._layers = _layers
 
     def call(self, input_tensor, training=False):
         x = input_tensor
-        for layer in self.layers:
+        for layer in self._layers:
             x = layer(x, training=training)
         return x
+
+    def get_config(self):
+        return {"layers": self.layers}
+
+
+class RemoveLayer(tf.keras.layers.Layer):
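+    # drops feature columns 4 and 5 from the last axis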
+    def call(self, inputs):
+        return tf.keras.layers.Concatenate()([inputs[:, :, :4], inputs[:, :, 6:]])
+
+
+class SplitHighLow(tf.keras.layers.Layer):
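+    # splits the last axis into the four-vector components (first four
+    # entries) and the remaining high-level features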
+    def call(self, inputs):
+        return inputs[:, :, :4], inputs[:, :, 4:]
+
+
+class LBNLayer(tf.keras.layers.Layer):
+    """
+    Custom implementation of the LBNLayer, with automatic cropping to the
+    low-level four-vector variables before, and batch normalization after,
+    the LBN.
+
+    Parameters
+    ----------
+    args :
+        args for original LBN layer
+    kwargs :
+        kwargs for original LBN layer
+
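+    Example (a sketch; the positional arguments are forwarded to the
+    original lbn.LBNLayer and the input tensors are placeholders):
+
+        layer = LBNLayer(10, boost_mode="pairs")
+        feats = layer([jets_tensor, leptons_tensor])
+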
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(name="LBNLayer")
+        _kwargs = kwargs.copy()
+        if "features" not in _kwargs:
+            _kwargs["features"] = self.all_features
+        self._args = args
+        self._kwargs = _kwargs
+
+    def build(self, input_shape):
+        self.concat = tf.keras.layers.Concatenate(axis=-2)
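+        # every input contributes its particle axis (second-to-last dim);
+        # after cropping, each particle is a plain four-vector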
+        lbn_inp_shape = (sum(_[-2] for _ in input_shape), 4)
+        self.lbn_layer = OriginalLBNLayer(lbn_inp_shape, *self._args, **self._kwargs)
+        self.batch_norm = tf.keras.layers.BatchNormalization()
+
+    @property
+    def all_features(self):
+        return [
+            "E",
+            "px",
+            "py",
+            "pz",
+            "pt",
+            "p",
+            "m",
+            "phi",
+            "eta",
+            "beta",
+            "gamma",
+            "pair_cos",
+            "pair_dr",
+            "pair_ds",
+            "pair_dy",
+        ]
+
+    def call(self, input_tensors, training=False):
+        # keep only the low-level four-vectors, merge them along the particle
+        # axis and pass them through the LBN plus batch normalization
+        ll = [SplitHighLow()(tensor)[0] for tensor in input_tensors]
+        ll = self.concat(ll)
+        feats = self.lbn_layer(ll)
+        feats = self.batch_norm(feats, training=training)
+        return feats
diff --git a/numpy.py b/numpy.py
new file mode 100644
index 0000000000000000000000000000000000000000..356f2a3b0a61f351f5b7f0abbc19f217021ec1ca
--- /dev/null
+++ b/numpy.py
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+
+import numpy as np
+
+
+def one_hot(a, n=None):
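+    # e.g. one_hot(np.array([0, 2]), n=3) -> [[1, 0, 0], [0, 0, 1]]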
+    if n is None:
+        n = a.max() + 1
+    o_h = np.zeros((a.size, n))
+    o_h[np.arange(a.size), a] = 1
+    return o_h
+
+
+def array_to_hist(values, bins):
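+    # mimic the (values, xedges, yedges) tuple of np.histogram2d, with the
+    # given values substituted for the counts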
+    hist = np.histogram2d([0.0], [0.0], bins=bins)
+    hist = list(hist)
+    hist[0] = values
+    hist = tuple(hist)
+    return hist
+
+
+def swap_rows(a, swp, roll=1, swp_to=None):
+    """Swaps rows in the last dimension of a.
+
+    The rows to swap are given by swp. If swp_to is provided, the rows from
+    swp are replaced by the rows from swp_to; otherwise the rows from swp
+    are rolled among themselves by roll.
+
+    Returns:
+        Array with swapped rows.
+    """
+    if swp_to is None:
+        a[..., swp] = a[..., np.roll(swp, roll)]
+    else:
+        a[..., swp] = a[..., swp_to]
+    return a
+
+
+def create_swaps(swap_length, circle, length):
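+    # index pairs that are `swap_length` apart, e.g.
+    # create_swaps(2, False, 5) -> [[0, 2], [1, 3], [2, 4]]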
+    start = -swap_length if circle else 0
+    return np.column_stack(
+        (np.arange(start, length - swap_length), np.arange(start + swap_length, length))
+    )
+
+
+def _axis_clip(ref, axis):
+    assert -ref <= axis < ref
+    axis = ref + axis if axis < 0 else axis
+    return axis, ref - axis - 1
+
+
+def xsel(source, saxis, indices, iaxis=-1):
+    return source[xsel_mask(source, saxis, indices, iaxis)]
+
+
+def xsel_mask(source, saxis, indices, iaxis=-1):
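+    # builds an index grid over source in which axis `saxis` is replaced by
+    # `indices`, broadcast into the surrounding dimensions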
+    saxis, safter = _axis_clip(source.ndim, saxis)
+    iaxis, iafter = _axis_clip(indices.ndim, iaxis)
+    assert iaxis <= saxis
+    assert iafter <= safter
+    indices = indices[((None,) * (saxis - iaxis)) + (Ellipsis,) + ((None,) * (safter - iafter))]
+    grid = np.ogrid[tuple(map(slice, source.shape))]
+    grid[saxis] = indices
+    return tuple(grid)
+
+
+def template_to(inp, template):
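+    # permutation that reorders the last axis of `inp` to follow the rank
+    # order of `template`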
+    return np.argsort(
+        xsel(np.argsort(inp, axis=-1), 1, np.argsort(np.argsort(template, axis=-1), axis=-1), 1),
+        axis=-1,
+    )
diff --git a/plotting.py b/plotting.py
index 45f3e4592b15707aa23648a879d479d514e674ee..d5b1b9fa1c12c428198a495bb4e2f246f1ad4eb7 100644
--- a/plotting.py
+++ b/plotting.py
@@ -1,27 +1,47 @@
 # -*- coding: utf-8 -*-
 
 import io
+import warnings
 
 import tensorflow as tf
 import itertools
 import numpy as np
 from matplotlib import pyplot as plt
-from sklearn.metrics import confusion_matrix
+from sklearn.metrics import confusion_matrix, roc_curve, auc
+import pandas as pd
 
+from .numpy import one_hot
 
-class Quadrature:
+
+class Multiplot:
     def __init__(self, n):
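+        # n is either an explicit (cols, rows) pair or a total plot count,
+        # from which a near-square grid is derived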
-        self.n = n
-        self.cols = np.ceil(np.sqrt(n)).astype(int)
-        self.rows = np.ceil(n / self.cols).astype(int)
+        try:
+            self.cols, self.rows = [i for i in n]
+        except TypeError:
+            self.n = n
+            self.cols = np.ceil(np.sqrt(n)).astype(int)
+            self.rows = np.ceil(n / self.cols).astype(int)
 
     def lenghts(self):
         return self.rows, self.cols
 
-    def index(self, n):
-        row = int(n / self.cols)
-        col = n - row * self.cols
-        return row, col
+    def index(self, i):
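+        # with a single row or column, plt.subplots yields a 1-d axes array,
+        # so a scalar index is returned in that case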
+        row = int(i / self.cols)
+        col = i - row * self.cols
+        if self.rows == 1:
+            return col
+        if self.cols == 1:
+            return row
+        else:
+            return row, col
+
+
+def saveplot(f):
+    def helper(*args, **kwargs):
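+        # close leftover figures so repeated plotting calls do not leak state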
+        plt.close("all")
+        return f(*args, **kwargs)
+
+    return helper
 
 
 def figure_confusion_matrix(
@@ -30,7 +50,7 @@ def figure_confusion_matrix(
     class_names=["signal", "background"],
     sample_weight=None,
     normalize="true",
-    **kwargs
+    **kwargs,
 ):
     assert len(class_names) == truth.shape[-1] == prediction.shape[-1]
     fig, ax = plt.subplots()
@@ -52,8 +72,8 @@ def figure_confusion_matrix(
     for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
         plt.text(j, i, np.around(cm[i, j], decimals=2), horizontalalignment="center", size=7)
 
-    plt.ylabel("True label")
-    plt.xlabel("Predicted label")
+    plt.ylabel("True label" + " (normed)" * (normalize == "true"))
+    plt.xlabel("Predicted label" + " (normed)" * (normalize == "pred"))
     fig.tight_layout()
     return fig
 
@@ -67,7 +87,7 @@ def figure_activations(activations, class_names=None):
         plt.hist(
             activations[:, i],
             bins,
-            histtype=u"step",
+            histtype="step",
             density=True,
             label="%i" % i if class_names is None else class_names[i],
         )
@@ -78,29 +98,262 @@ def figure_activations(activations, class_names=None):
     return fig
 
 
-def figure_node_activations(activations, truth, class_names=None):
+def figure_history(history_csv_path):
+    # DataFrame.plot draws into its own figure; fetch it so the returned
+    # figure actually contains the history subplots
+    axes = pd.read_csv(history_csv_path).plot(subplots=True, figsize=(30, 30), layout=(7, 6))
+    fig = axes.ravel()[0].get_figure()
+    fig.tight_layout()
+    return fig
+
+
+def plot_histories(history_csv_path, path, cut=None, roll=1):
+    pdf = pd.read_csv(history_csv_path)
+    pdf = pdf.set_index("epoch")
+    pdf = pdf.truncate(after=cut)
+    for col in pdf.columns:
+        if col.startswith("val_"):
+            continue
+        fig = plt.figure()
+        if "val_" + col in pdf.columns:
+            ind = [col, "val_" + col]
+        else:
+            ind = col
+        value = pdf[ind]
+        ax = value.rolling(roll, min_periods=1).mean().plot()
+        ax.set_xlabel("Epoch")
+        ax.set_ylabel(col.capitalize())
+        fig.tight_layout()
+        plt.savefig(f"{path}/{col}.pdf")
+        plt.close("all")
+
+
+def figure_weights(w, y, class_names=None, relative=False):
+    n_p = y.shape[1]
+    bins = range(n_p)
+    weights = w[:, None] * y
+    fig = plt.figure()
+    pos = np.sum(weights > 0, axis=0)
+    neg = np.sum(weights < 0, axis=0)
+    values = neg / (pos + neg)
+    plt.bar(bins, values)
+    plt.xticks(bins, class_names, rotation=45)
+    plt.yscale("log")
+    plt.xlabel("Classfication Process")
+    plt.ylabel("Fraction Negative weights")
+    fig.tight_layout()
+    return fig
+
+
+def figure_y(y, class_names=None, relative=False):
+    n_p = y.shape[1]
+    bins = range(n_p)
+
+    fig = plt.figure()
+    values = y.sum(axis=0)
+    if relative:
+        values = values / values.sum()
+    plt.bar(bins, values)
+    plt.xticks(bins, class_names, rotation=45)
+    plt.yscale("log")
+    plt.xlabel("Classfication Process")
+    plt.ylabel("Number of Events" + relative * " (normed)")
+    fig.tight_layout()
+    return fig
+
+
+def figure_node_activations(
+    activations, truth, class_names=None, disjoint=False, sample_weight=None
+):
     n_b, n_p = activations.shape
-    quad = Quadrature(n_p)
+    multiplot = Multiplot(n_p)
 
-    rows, cols = quad.lenghts()
+    rows, cols = multiplot.lenghts()
     fig, ax = plt.subplots(rows, cols, figsize=(15, 15 * rows / cols))
     bins = np.linspace(0, 1.0, 10)
 
     process_activations = []
+    process_weights = []
     for process in range(n_p):
-        process_activations.append(activations[truth[:, process]].swapaxes(0, 1))
+        if disjoint:
+            max_activations = np.argmax(activations, axis=-1)
+            one_hot_max_activations = one_hot(max_activations, n=n_p)
+            values = (activations * one_hot_max_activations)[truth[:, process]].swapaxes(0, 1)
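+            # entries that lost the argmax are pushed far below the plot
+            # range so they do not land in any bin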
+            values[values == 0] = -10000
+        else:
+            values = activations[truth[:, process]].swapaxes(0, 1)
+        process_activations.append(values)
+        if sample_weight is not None:
+            process_weights.append(sample_weight[truth[:, process]])
 
+    old_err = np.seterr(divide="ignore", invalid="ignore")
     for node in range(n_p):
+        ax_index = multiplot.index(node)
         for process in range(n_p):
-            ax[quad.index(node)].hist(
-                process_activations[process][node],
-                bins,
-                histtype=u"step",
+            label = "%i" % process if class_names is None else class_names[process]
+            plot_kwargs = {"histtype": "step", "label": label, "range": (0.0, 1.0)}
+            if len(process_weights) == n_p:
+                plot_kwargs["weights"] = process_weights[process]
+            ax[ax_index].hist(process_activations[process][node], bins, **plot_kwargs)
+
+        ax[ax_index].text(
+            0.95,
+            0.95,
+            f"node {node if class_names is None else class_names[node]}",
+            ha="right",
+            va="top",
+            transform=ax[ax_index].transAxes,
+        )
+        ax[ax_index].set_yscale("log")
+    np.seterr(**old_err)
+    ax[multiplot.index(cols - 1)].legend(
+        title="processes", bbox_to_anchor=(1.05, 1.0), loc="upper left"
+    )
+    fig.tight_layout()
+    return fig
+
+
+def figure_roc_curve(
+    truth, prediction, indices=[0], class_names=None, sample_weight=None, lw=2, scale="linear"
+):
+    fig = plt.figure()
+    for index in indices:
+        fpr, tpr, _ = roc_curve(
+            truth[:, index], prediction[:, index], sample_weight=sample_weight
+        )
+        roc_auc = auc(fpr, tpr)
+        name = index if class_names is None else class_names[index]
+        plt.plot(fpr, tpr, lw=lw, label=f"{name} vs All (area = {roc_auc:.2f})")
+    plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
+    lower = 0.0
+    if scale.endswith("log"):
+        plt.xscale(scale)
+        plt.yscale(scale)
+        lower = 1e-5
+    plt.xlim([lower, 1.0])
+    plt.ylim([lower, 1.05])
+    plt.xlabel("False Positive Rate")
+    plt.ylabel("True Positive Rate")
+    plt.title("Receiver operating characteristic curve")
+    plt.legend(loc="lower right")
+    return fig
+
+
+def figure_inputs(
+    inps, truth, sample_weight=None, columns=None, class_names=None, signalvsbkg=False, bins=20
+):
+    multiplot = Multiplot(inps.shape[1:][::-1])
+    rows, cols = multiplot.lenghts()
+    size = len(columns)
+    fig, ax = plt.subplots(rows, cols, figsize=(size, size * rows / cols))
+
+    inps = inps.reshape(inps.shape[0], -1)
+    # iterate classes in order of decreasing total weight; keep class_names
+    # indexed by the original class position so labels stay attached
+    order = np.argsort(-(sample_weight[:, None] * truth).sum(axis=0))
+    class_names = np.array(class_names)
+    for feat, name in enumerate(columns):
+        ax_index = multiplot.index(feat)
+        if signalvsbkg:
+            mask = np.argmax(truth, axis=-1) != 0
+            # remember the bin edges so signal and background share binning
+            feat_bins = ax[ax_index].hist(
+                inps[:, feat][mask],
+                histtype="stepfilled",
+                bins=bins,
+                weights=sample_weight[mask],
+                label="Background",
                 density=True,
-                label="%i" % process if class_names is None else class_names[process],
+            )[1]
+            mask = np.argmax(truth, axis=-1) == 0
+            ax[ax_index].hist(
+                inps[:, feat][mask],
+                histtype="step",
+                bins=feat_bins,
+                weights=sample_weight[mask],
+                label="HH",
+                density=True,
+                linewidth=2,
+            )
+        else:
+            for i in order:
+                mask = np.argmax(truth, axis=-1) == i
+                ax[ax_index].hist(
+                    inps[:, feat][mask],
+                    histtype="step",
+                    bins=bins,
+                    weights=sample_weight[mask],
+                    label=class_names[i],
+                    density=True,
+                    linewidth=2,
+                )
+        ax[ax_index].set_title(name)
+        ax[ax_index].legend()
+    fig.tight_layout()
+    return fig
+
+
+def figure_weight_study(
+    class_inps, sample_weights=None, columns=None, label=None, log=False, mode="plain", **kwargs
+):
+    multiplot = Multiplot(class_inps[0].shape[1:][::-1])
+    rows, cols = multiplot.lenghts()
+    size = 5
+    fig, ax = plt.subplots(rows, cols, figsize=(cols * size, rows * size))
+
+    class_inps = [inps.reshape(inps.shape[0], -1) for inps in class_inps]
+    for feat, name in enumerate(columns):
+        ax_index = multiplot.index(feat)
+
+        bins = 25
+        ref = None
+
+        for i, inps in enumerate(class_inps):
+            if mode == "plain":
+                val, bins, _ = ax[ax_index].hist(
+                    inps[:, feat],
+                    histtype="step",
+                    bins=bins,
+                    weights=sample_weights[i],
+                    density=True,
+                    label=label[i] if label else None,
+                    **kwargs,
+                )
+            if mode == "rel":
+                val, bins = np.histogram(
+                    inps[:, feat], bins=bins, weights=sample_weights[i], density=True
+                )
+                if ref is None:
+                    ref = val
+                ax[ax_index].bar(
+                    bins[:-1],
+                    height=(val - ref) / ref,
+                    width=bins[1:] - bins[:-1],
+                    align="edge",
+                    label=label[i] if label else None,
+                    alpha=0.5,
+                )
+        if mode == "weight":
+            mask = sample_weights[0] > 0
+            pos_feat, pos_weight = class_inps[0][:, feat][mask], sample_weights[0][mask]
+            neg_feat, neg_weight = class_inps[0][:, feat][~mask], sample_weights[0][~mask]
+
+            val_pos, bins = np.histogram(pos_feat, bins=bins, weights=pos_weight)
+            val_neg, bins = np.histogram(neg_feat, bins=bins, weights=neg_weight)
+            ax[ax_index].bar(
+                bins[:-1],
+                height=np.abs(val_neg) / (val_pos + np.abs(val_neg)),
+                width=bins[1:] - bins[:-1],
+                align="edge",
             )
-        ax[quad.index(node)].set_yscale("log")
-    ax[quad.index(cols - 1)].legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
+
+        if mode in ["weight"]:
+            ax[ax_index].set_yscale("log")
+        ax[ax_index].set_title(name)
+        ax[ax_index].legend()
+    fig.tight_layout()
+    return fig
+
+
+def figure_multihist(data, columns=None):
+    df = pd.DataFrame(np.reshape(data, (data.shape[0], -1)), columns=columns)
+    warnings.simplefilter("ignore")  # temporary fix for outdated pandas
+    # df.hist draws into a fresh figure; fetch it so it is the one returned
+    axes = df.hist(figsize=(20, 20))
+    fig = axes.ravel()[0].get_figure()
     fig.tight_layout()
     return fig
 
diff --git a/tf.py b/tf.py
index 0398d4ab52286132a98cf599a5603d0026883046..fbfad716fca915253cb4c8773d779f7eb1eaee0a 100644
--- a/tf.py
+++ b/tf.py
@@ -5,25 +5,32 @@ from time import time
 from analysis.util import atq
 from io import BytesIO
 import h5py
-from tensorflow.python.keras.engine.saving import save_weights_to_hdf5_group, load_weights_from_hdf5_group
+from tensorflow.python.keras.engine.saving import (
+    save_weights_to_hdf5_group,
+    load_weights_from_hdf5_group,
+)
 
 from evil import pin, ccall
 
 
 from tensorflow.python.util import deprecation
+
 deprecation._PRINT_DEPRECATION_WARNINGS = False
 
 
 def last_map(func, obj, order=1):
-    return tf.concat([
-        obj[..., :-1],
-        tf.expand_dims(
-            last_map(func=func, obj=obj[..., -1], order=order - 1)
-            if 1 < order else
-            func(obj[..., -1]),
-            axis=-1
-        ),
-    ], axis=-1)
+    return tf.concat(
+        [
+            obj[..., :-1],
+            tf.expand_dims(
+                last_map(func=func, obj=obj[..., -1], order=order - 1)
+                if 1 < order
+                else func(obj[..., -1]),
+                axis=-1,
+            ),
+        ],
+        axis=-1,
+    )
 
 
 def tf_meanstd(val, **kwargs):
@@ -34,6 +41,7 @@ def tf_meanstd(val, **kwargs):
 
 # this is mostly unused ...
 
+
 class FD(dict):
     def add(self, data, dtype=None, shape=None, **kwargs):
         if shape is True:
@@ -48,7 +56,17 @@ class FD(dict):
 
 
 class Chain(object):
-    def __init__(self, __name__, loss, sumnbs=(), step=None, opt=None, train=None, ema=(1, 2, 3, 4, 5), **kwargs):
+    def __init__(
+        self,
+        __name__,
+        loss,
+        sumnbs=(),
+        step=None,
+        opt=None,
+        train=None,
+        ema=(1, 2, 3, 4, 5),
+        **kwargs,
+    ):
         if step is None:
             step = tf.Variable(0, trainable=False, name="%s_step" % __name__)
         if opt is None:
@@ -88,9 +106,13 @@ class Chain(object):
             reset_ema = ema_step.initializer
             for i in ema:
                 with tf.name_scope("%s_%d" % (__name__, i)):
-                    ema = tf.train.ExponentialMovingAverage(decay=1. - tf.maximum(
-                        0.1**i, tf.exp((np.log(0.1**i) / 10.) * tf.cast(ema_step, tf.float32))
-                    ))
+                    ema = tf.train.ExponentialMovingAverage(
+                        decay=1.0
+                        - tf.maximum(
+                            0.1 ** i,
+                            tf.exp((np.log(0.1 ** i) / 10.0) * tf.cast(ema_step, tf.float32)),
+                        )
+                    )
                     _ema_ops.append(ema.apply(_ema_val.values()))
                     for key, val in _ema_val.items():
                         val = ema.average(val)
@@ -139,7 +161,9 @@ class Runner(object):
     @contextmanager
     def make(cls, path, gpuOpts={}, flush_secs=20):
         with cls.make_session(gpuOpts=gpuOpts) as sess:
-            with tf.summary.FileWriter(path, session=sess, flush_secs=flush_secs, graph=sess.graph) as writer:
+            with tf.summary.FileWriter(
+                path, session=sess, flush_secs=flush_secs, graph=sess.graph
+            ) as writer:
                 yield cls(sess, writer)
 
     @classmethod
@@ -168,11 +192,12 @@ class Confirmer(object):
                         break
                     if tnext < time():
                         tnext = time() + self.tdelay
-                        pTotal = 100. * i / its
+                        pTotal = 100.0 * i / its
                         self.task.set_progress_percentage(pTotal)
-                        self.task.set_status_message("total=%.1f%% conf=%.1f%% best=%.2e" % (
-                            pTotal, 100. * self.conf / target, self.best
-                        ))
+                        self.task.set_status_message(
+                            "total=%.1f%% conf=%.1f%% best=%.2e"
+                            % (pTotal, 100.0 * self.conf / target, self.best)
+                        )
                 if not (self.conf < target):
                     break
                 self.step += 1
@@ -216,11 +241,7 @@ class Confirmer(object):
         }
 
     def __repr__(self):
-        return "Confirmer(step=%d, best=%.3e, conf=%d)" % (
-            self.step,
-            self.best,
-            self.conf
-        )
+        return "Confirmer(step=%d, best=%.3e, conf=%d)" % (self.step, self.best, self.conf)
 
 
 class Bestie(object):