diff --git a/keras.py b/keras.py
index 15645ec2a9102c79b0498fcad95353e134408925..16c94fec825dc8325da7cb020e2ef6208dadc527 100644
--- a/keras.py
+++ b/keras.py
@@ -392,20 +392,25 @@ class TFSummaryCallback(tf.keras.callbacks.Callback):
 class PlotMulticlass(TFSummaryCallback):
     def __init__(
         self,
+        x,
+        y,
+        sample_weight=None,
         class_names=["signal", "background"],
         to_file=False,
         columns=None,
         plot_inputs=False,
+        signalvsbkg=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.x = kwargs["x"]
-        self.truth = kwargs["y"]
-        self.sample_weight = kwargs.get("sample_weight", None)
+        self.x = x
+        self.truth = y
+        self.sample_weight = sample_weight
         self.class_names = class_names
         self.plot_inputs = plot_inputs
         self.columns = columns
         self.to_file = to_file
+        self.signalvsbkg = signalvsbkg
 
     def on_test_begin(self, logs=None):
         self.on_train_begin(logs=logs)
@@ -419,26 +424,27 @@ class PlotMulticlass(TFSummaryCallback):
                     inps = [inps]
 
                 for part, inp in zip(self.columns.keys(), inps):
-                    imgs[f"inputs_merged_{part}"] = figure_to_image(
+                    imgs[f"inp_xmerged_{part}"] = figure_to_image(
                         figure_multihist(inp, columns=self.columns[part])
                     )
             if self.sample_weight is not None:
                 for part, inp in zip(self.columns.keys(), inps):
-                    imgs[f"inputs_{part}"] = figure_to_image(
+                    imgs[f"inp_x_{part}"] = figure_to_image(
                         figure_inputs(
                             inp,
                             self.truth,
                             sample_weight=self.sample_weight,
                             columns=self.columns[part],
                             class_names=self.class_names,
+                            signalvsbkg=self.signalvsbkg,
                         )
                     )
 
-                imgs["weights"] = figure_to_image(
+                imgs["inp_weights"] = figure_to_image(
                     figure_weights(self.sample_weight, self.truth, class_names=self.class_names)
                 )
-            imgs["processes"] = figure_to_image(figure_y(self.truth, class_names=self.class_names))
-            imgs["processes_relative"] = figure_to_image(
+            imgs["inp_y"] = figure_to_image(figure_y(self.truth, class_names=self.class_names))
+            imgs["inp_yrelative"] = figure_to_image(
                 figure_y(self.truth, class_names=self.class_names, relative=True)
             )
             for name, img in imgs.items():
diff --git a/plotting.py b/plotting.py
index acc8aae946501ca597febf3b783d4a9e5de073e1..1dad79930a3a80d877cbcffdbd344b70b6143127 100644
--- a/plotting.py
+++ b/plotting.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 
 import io
+import warnings
 
 import tensorflow as tf
 import itertools
@@ -12,21 +13,32 @@ import pandas as pd
 from .numpy import one_hot
 
 
-class Quadrature:
+class Multiplot:
     def __init__(self, n):
-        self.n = n
-        self.cols = np.ceil(np.sqrt(n)).astype(int)
-        self.rows = np.ceil(n / self.cols).astype(int)
+        try:
+            self.cols, self.rows = [i for i in n]
+        except TypeError:
+            self.n = n
+            self.cols = np.ceil(np.sqrt(n)).astype(int)
+            self.rows = np.ceil(n / self.cols).astype(int)
 
     def lenghts(self):
         return self.rows, self.cols
 
-    def index(self, n):
-        row = int(n / self.cols)
-        col = n - row * self.cols
+    def index(self, i):
+        row = int(i / self.cols)
+        col = i - row * self.cols
         return row, col
 
 
+def saveplot(f):
+    def helper(*args, **kwargs):
+        plt.close("all")
+        return f(*args, **kwargs)
+
+    return helper
+
+
 def figure_confusion_matrix(
     truth,
     prediction,
@@ -149,9 +161,9 @@ def figure_node_activations(
     activations, truth, class_names=None, disjoint=False, sample_weight=None
 ):
     n_b, n_p = activations.shape
-    quad = Quadrature(n_p)
+    multiplot = Multiplot(n_p)
 
-    rows, cols = quad.lenghts()
+    rows, cols = multiplot.lenghts()
     fig, ax = plt.subplots(rows, cols, figsize=(15, 15 * rows / cols))
     bins = np.linspace(0, 1.0, 10)
 
@@ -171,7 +183,7 @@ def figure_node_activations(
 
     old_err = np.seterr(divide="ignore", invalid="ignore")
     for node in range(n_p):
-        ax_index = quad.index(node)
+        ax_index = multiplot.index(node)
         for process in range(n_p):
             label = "%i" % process if class_names is None else class_names[process]
             plot_kwargs = {"histtype": "step", "label": label, "range": (0.0, 1.0)}
@@ -187,9 +199,11 @@ def figure_node_activations(
             va="top",
             transform=ax[ax_index].transAxes,
         )
-        ax[quad.index(node)].set_yscale("log")
+        ax[multiplot.index(node)].set_yscale("log")
     np.seterr(**old_err)
-    ax[quad.index(cols - 1)].legend(title="processes", bbox_to_anchor=(1.05, 1.0), loc="upper left")
+    ax[multiplot.index(cols - 1)].legend(
+        title="processes", bbox_to_anchor=(1.05, 1.0), loc="upper left"
+    )
     fig.tight_layout()
     return fig
 
@@ -218,9 +232,11 @@ def figure_roc_curve(
     return fig
 
 
-def figure_inputs(inps, truth, sample_weight=None, columns=None, class_names=None):
-    quad = Quadrature(len(columns))
-    rows, cols = quad.lenghts()
+def figure_inputs(
+    inps, truth, sample_weight=None, columns=None, class_names=None, signalvsbkg=False, bins=20
+):
+    multiplot = Multiplot(inps.shape[1:][::-1])
+    rows, cols = multiplot.lenghts()
     size = len(columns)
     fig, ax = plt.subplots(rows, cols, figsize=(size, size * rows / cols))
 
@@ -228,25 +244,100 @@ def figure_inputs(inps, truth, sample_weight=None, columns=None, class_names=Non
     order = np.argsort(-(sample_weight[:, None] * truth).sum(axis=0))
     class_names = np.array(class_names)[order]
     for feat, name in enumerate(columns):
-        ax_index = quad.index(feat)
-        mask = np.argmax(truth, axis=-1) != 0
-        bins = ax[ax_index].hist(
-            inps[:, feat][mask],
-            histtype="stepfilled",
-            weights=sample_weight[mask],
-            label="Background",
-            density=True,
-        )[1]
-        mask = np.argmax(truth, axis=-1) == 0
-        ax[ax_index].hist(
-            inps[:, feat][mask],
-            histtype="step",
-            bins=bins,
-            weights=sample_weight[mask],
-            label="HH",
-            density=True,
-            linewidth=2,
-        )
+        ax_index = multiplot.index(feat)
+        if signalvsbkg:
+            mask = np.argmax(truth, axis=-1) != 0
+            bins = ax[ax_index].hist(
+                inps[:, feat][mask],
+                histtype="stepfilled",
+                weights=sample_weight[mask],
+                label="Background",
+                density=True,
+            )[1]
+            mask = np.argmax(truth, axis=-1) == 0
+            ax[ax_index].hist(
+                inps[:, feat][mask],
+                histtype="step",
+                bins=bins,
+                weights=sample_weight[mask],
+                label="HH",
+                density=True,
+                linewidth=2,
+            )
+        else:
+            for i in range(len(class_names)):
+                mask = np.argmax(truth, axis=-1) == i
+                ax[ax_index].hist(
+                    inps[:, feat][mask],
+                    histtype="step",
+                    bins=bins,
+                    weights=sample_weight[mask],
+                    label=class_names[i],
+                    density=True,
+                    linewidth=2,
+                )
+        ax[ax_index].set_title(name)
+        ax[ax_index].legend()
+    fig.tight_layout()
+    return fig
+
+
+def figure_weight_study(
+    class_inps, sample_weights=None, columns=None, label=None, log=False, mode="plain", **kwargs
+):
+    multiplot = Multiplot(class_inps[0].shape[1:][::-1])
+    rows, cols = multiplot.lenghts()
+    size = 3
+    fig, ax = plt.subplots(rows, cols, figsize=(cols * size, rows * size))
+
+    class_inps = [inps.reshape(inps.shape[0], -1) for inps in class_inps]
+    for feat, name in enumerate(columns):
+        ax_index = multiplot.index(feat)
+
+        bins = 25
+        ref = None
+
+        for i, inps in enumerate(class_inps):
+            if mode == "plain":
+                val, bins, _ = ax[ax_index].hist(
+                    inps[:, feat],
+                    histtype="step",
+                    bins=bins,
+                    weights=sample_weights[i],
+                    density=True,
+                    label=label[i] if label else None,
+                    **kwargs,
+                )
+            if mode == "rel":
+                val, bins = np.histogram(
+                    inps[:, feat], bins=bins, weights=sample_weights[i], density=True
+                )
+                if ref is None:
+                    ref = val
+                ax[ax_index].bar(
+                    bins[:-1],
+                    height=(val - ref) / ref,
+                    width=bins[1:] - bins[:-1],
+                    align="edge",
+                    label=label[i] if label else None,
+                    alpha=0.5,
+                )
+        if mode == "weight":
+            mask = sample_weights[0] > 0
+            pos_feat, pos_weight = class_inps[0][:, feat][mask], sample_weights[0][mask]
+            neg_feat, neg_weight = class_inps[0][:, feat][~mask], sample_weights[0][~mask]
+
+            val_pos, bins = np.histogram(pos_feat, bins=bins, weights=pos_weight)
+            val_neg, bins = np.histogram(neg_feat, bins=bins, weights=neg_weight)
+            ax[ax_index].bar(
+                bins[:-1],
+                height=np.abs(val_neg) / (val_pos + np.abs(val_neg)),
+                width=bins[1:] - bins[:-1],
+                align="edge",
+            )
+
+        if mode in ["weight", "rel"]:
+            ax[ax_index].set_yscale("log")
         ax[ax_index].set_title(name)
         ax[ax_index].legend()
     fig.tight_layout()
@@ -256,7 +347,8 @@ def figure_inputs(inps, truth, sample_weight=None, columns=None, class_names=Non
 def figure_multihist(data, columns=None):
     fig, ax = plt.subplots()
     df = pd.DataFrame(np.reshape(data, (data.shape[0], -1)), columns=columns)
-    df.hist(figsize=(20, 20))
+        warnings.simplefilter("ignore"):  # temporary fix outdated pandas
+        df.hist(figsize=(20, 20))
     fig.tight_layout()
     return fig