diff --git a/keras.py b/keras.py index 15645ec2a9102c79b0498fcad95353e134408925..16c94fec825dc8325da7cb020e2ef6208dadc527 100644 --- a/keras.py +++ b/keras.py @@ -392,20 +392,25 @@ class TFSummaryCallback(tf.keras.callbacks.Callback): class PlotMulticlass(TFSummaryCallback): def __init__( self, + x, + y, + sample_weight=None, class_names=["signal", "background"], to_file=False, columns=None, plot_inputs=False, + signalvsbkg=False, **kwargs, ): super().__init__(**kwargs) - self.x = kwargs["x"] - self.truth = kwargs["y"] - self.sample_weight = kwargs.get("sample_weight", None) + self.x = x + self.truth = y + self.sample_weight = sample_weight self.class_names = class_names self.plot_inputs = plot_inputs self.columns = columns self.to_file = to_file + self.signalvsbkg = signalvsbkg def on_test_begin(self, logs=None): self.on_train_begin(logs=logs) @@ -419,26 +424,27 @@ class PlotMulticlass(TFSummaryCallback): inps = [inps] for part, inp in zip(self.columns.keys(), inps): - imgs[f"inputs_merged_{part}"] = figure_to_image( + imgs[f"inp_xmerged_{part}"] = figure_to_image( figure_multihist(inp, columns=self.columns[part]) ) if self.sample_weight is not None: for part, inp in zip(self.columns.keys(), inps): - imgs[f"inputs_{part}"] = figure_to_image( + imgs[f"inp_x_{part}"] = figure_to_image( figure_inputs( inp, self.truth, sample_weight=self.sample_weight, columns=self.columns[part], class_names=self.class_names, + signalvsbkg=self.signalvsbkg, ) ) - imgs["weights"] = figure_to_image( + imgs["inp_weights"] = figure_to_image( figure_weights(self.sample_weight, self.truth, class_names=self.class_names) ) - imgs["processes"] = figure_to_image(figure_y(self.truth, class_names=self.class_names)) - imgs["processes_relative"] = figure_to_image( + imgs["inp_y"] = figure_to_image(figure_y(self.truth, class_names=self.class_names)) + imgs["inp_yrelative"] = figure_to_image( figure_y(self.truth, class_names=self.class_names, relative=True) ) for name, img in imgs.items(): diff --git a/plotting.py b/plotting.py index acc8aae946501ca597febf3b783d4a9e5de073e1..1dad79930a3a80d877cbcffdbd344b70b6143127 100644 --- a/plotting.py +++ b/plotting.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import io +import warnings import tensorflow as tf import itertools @@ -12,21 +13,32 @@ import pandas as pd from .numpy import one_hot -class Quadrature: +class Multiplot: def __init__(self, n): - self.n = n - self.cols = np.ceil(np.sqrt(n)).astype(int) - self.rows = np.ceil(n / self.cols).astype(int) + try: + self.cols, self.rows = [i for i in n] + except TypeError: + self.n = n + self.cols = np.ceil(np.sqrt(n)).astype(int) + self.rows = np.ceil(n / self.cols).astype(int) def lenghts(self): return self.rows, self.cols - def index(self, n): - row = int(n / self.cols) - col = n - row * self.cols + def index(self, i): + row = int(i / self.cols) + col = i - row * self.cols return row, col +def saveplot(f): + def helper(*args, **kwargs): + plt.close("all") + return f(*args, **kwargs) + + return helper + + def figure_confusion_matrix( truth, prediction, @@ -149,9 +161,9 @@ def figure_node_activations( activations, truth, class_names=None, disjoint=False, sample_weight=None ): n_b, n_p = activations.shape - quad = Quadrature(n_p) + multiplot = Multiplot(n_p) - rows, cols = quad.lenghts() + rows, cols = multiplot.lenghts() fig, ax = plt.subplots(rows, cols, figsize=(15, 15 * rows / cols)) bins = np.linspace(0, 1.0, 10) @@ -171,7 +183,7 @@ def figure_node_activations( old_err = np.seterr(divide="ignore", invalid="ignore") for node in range(n_p): - ax_index = quad.index(node) + ax_index = multiplot.index(node) for process in range(n_p): label = "%i" % process if class_names is None else class_names[process] plot_kwargs = {"histtype": "step", "label": label, "range": (0.0, 1.0)} @@ -187,9 +199,11 @@ def figure_node_activations( va="top", transform=ax[ax_index].transAxes, ) - ax[quad.index(node)].set_yscale("log") + ax[multiplot.index(node)].set_yscale("log") np.seterr(**old_err) - ax[quad.index(cols - 1)].legend(title="processes", bbox_to_anchor=(1.05, 1.0), loc="upper left") + ax[multiplot.index(cols - 1)].legend( + title="processes", bbox_to_anchor=(1.05, 1.0), loc="upper left" + ) fig.tight_layout() return fig @@ -218,9 +232,11 @@ def figure_roc_curve( return fig -def figure_inputs(inps, truth, sample_weight=None, columns=None, class_names=None): - quad = Quadrature(len(columns)) - rows, cols = quad.lenghts() +def figure_inputs( + inps, truth, sample_weight=None, columns=None, class_names=None, signalvsbkg=False, bins=20 +): + multiplot = Multiplot(inps.shape[1:][::-1]) + rows, cols = multiplot.lenghts() size = len(columns) fig, ax = plt.subplots(rows, cols, figsize=(size, size * rows / cols)) @@ -228,25 +244,100 @@ def figure_inputs(inps, truth, sample_weight=None, columns=None, class_names=Non order = np.argsort(-(sample_weight[:, None] * truth).sum(axis=0)) class_names = np.array(class_names)[order] for feat, name in enumerate(columns): - ax_index = quad.index(feat) - mask = np.argmax(truth, axis=-1) != 0 - bins = ax[ax_index].hist( - inps[:, feat][mask], - histtype="stepfilled", - weights=sample_weight[mask], - label="Background", - density=True, - )[1] - mask = np.argmax(truth, axis=-1) == 0 - ax[ax_index].hist( - inps[:, feat][mask], - histtype="step", - bins=bins, - weights=sample_weight[mask], - label="HH", - density=True, - linewidth=2, - ) + ax_index = multiplot.index(feat) + if signalvsbkg: + mask = np.argmax(truth, axis=-1) != 0 + bins = ax[ax_index].hist( + inps[:, feat][mask], + histtype="stepfilled", + weights=sample_weight[mask], + label="Background", + density=True, + )[1] + mask = np.argmax(truth, axis=-1) == 0 + ax[ax_index].hist( + inps[:, feat][mask], + histtype="step", + bins=bins, + weights=sample_weight[mask], + label="HH", + density=True, + linewidth=2, + ) + else: + for i in range(len(class_names)): + mask = np.argmax(truth, axis=-1) == i + ax[ax_index].hist( + inps[:, feat][mask], + histtype="step", + bins=bins, + weights=sample_weight[mask], + label=class_names[i], + density=True, + linewidth=2, + ) + ax[ax_index].set_title(name) + ax[ax_index].legend() + fig.tight_layout() + return fig + + +def figure_weight_study( + class_inps, sample_weights=None, columns=None, label=None, log=False, mode="plain", **kwargs +): + multiplot = Multiplot(class_inps[0].shape[1:][::-1]) + rows, cols = multiplot.lenghts() + size = 3 + fig, ax = plt.subplots(rows, cols, figsize=(cols * size, rows * size)) + + class_inps = [inps.reshape(inps.shape[0], -1) for inps in class_inps] + for feat, name in enumerate(columns): + ax_index = multiplot.index(feat) + + bins = 25 + ref = None + + for i, inps in enumerate(class_inps): + if mode == "plain": + val, bins, _ = ax[ax_index].hist( + inps[:, feat], + histtype="step", + bins=bins, + weights=sample_weights[i], + density=True, + label=label[i] if label else None, + **kwargs, + ) + if mode == "rel": + val, bins = np.histogram( + inps[:, feat], bins=bins, weights=sample_weights[i], density=True + ) + if ref is None: + ref = val + ax[ax_index].bar( + bins[:-1], + height=(val - ref) / ref, + width=bins[1:] - bins[:-1], + align="edge", + label=label[i] if label else None, + alpha=0.5, + ) + if mode == "weight": + mask = sample_weights[0] > 0 + pos_feat, pos_weight = class_inps[0][:, feat][mask], sample_weights[0][mask] + neg_feat, neg_weight = class_inps[0][:, feat][~mask], sample_weights[0][~mask] + + val_pos, bins = np.histogram(pos_feat, bins=bins, weights=pos_weight) + val_neg, bins = np.histogram(neg_feat, bins=bins, weights=neg_weight) + ax[ax_index].bar( + bins[:-1], + height=np.abs(val_neg) / (val_pos + np.abs(val_neg)), + width=bins[1:] - bins[:-1], + align="edge", + ) + + if mode in ["weight", "rel"]: + ax[ax_index].set_yscale("log") ax[ax_index].set_title(name) ax[ax_index].legend() fig.tight_layout() @@ -256,7 +347,8 @@ def figure_inputs(inps, truth, sample_weight=None, columns=None, class_names=Non def figure_multihist(data, columns=None): fig, ax = plt.subplots() df = pd.DataFrame(np.reshape(data, (data.shape[0], -1)), columns=columns) - df.hist(figsize=(20, 20)) + warnings.simplefilter("ignore"): # temporary fix outdated pandas + df.hist(figsize=(20, 20)) fig.tight_layout() return fig