Skip to content
Snippets Groups Projects
plotting.py 14 KiB
Newer Older

import functools
import io
import itertools
import warnings

import numpy as np
import pandas as pd
import tensorflow as tf
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, auc
        try:
            self.cols, self.rows = [i for i in n]
        except ValueError:
            n = n[0]
            self.cols = np.ceil(np.sqrt(n)).astype(int)
            self.rows = np.ceil(n / self.cols).astype(int)
        except TypeError:
            self.cols = np.ceil(np.sqrt(n)).astype(int)
            self.rows = np.ceil(n / self.cols).astype(int)

    def lenghts(self):
        """Return the grid dimensions as a ``(rows, cols)`` tuple."""
        grid_shape = (self.rows, self.cols)
        return grid_shape

    def index(self, i):
        """Translate the flat, row-major plot number ``i`` into a grid index.

        Returns the bare column when the grid has exactly one row, the bare
        row when it has exactly one column (and more than one row), and a
        ``(row, col)`` tuple otherwise.
        """
        # presumably mirrors the 1-D axes array ``plt.subplots`` hands back for
        # single-row / single-column grids -- verify against callers
        n_cols = self.cols
        grid_row = int(i / n_cols)
        grid_col = i - grid_row * n_cols
        if self.rows == 1:
            return grid_col
        return grid_row if n_cols == 1 else (grid_row, grid_col)
def saveplot(f):
    """Decorate ``f`` so that all open matplotlib figures are closed before it runs.

    The wrapper forwards every positional and keyword argument to ``f`` and
    returns its result unchanged. ``functools.wraps`` keeps the wrapped
    function's ``__name__``, ``__doc__`` and signature metadata intact, which
    the previous implementation silently replaced with those of ``helper``.
    """

    @functools.wraps(f)
    def helper(*args, **kwargs):
        # Start from a clean pyplot state so figures left open by earlier calls
        # are released before the wrapped function runs.
        plt.close("all")
        return f(*args, **kwargs)

    return helper


def figure_confusion_matrix(
    truth,
    prediction,
    class_names=("signal", "background"),
    sample_weight=None,
    normalize="true",
):
    """Draw a confusion matrix for per-class scores and return the figure.

    Args:
        truth: array whose last axis enumerates the classes; the true class of
            each entry is taken as the argmax along that axis.
        prediction: array of the same layout; the predicted class is the argmax
            along the last axis.
        class_names: one display name per class. Its length must equal the
            size of the last axis of ``truth`` and ``prediction``.
        sample_weight: optional per-sample weights forwarded to
            ``sklearn.metrics.confusion_matrix``.
        normalize: normalisation mode forwarded to ``confusion_matrix``
            (``"true"``, ``"pred"``, ``"all"`` or ``None``).

    Returns:
        The matplotlib figure holding the annotated matrix.
    """
    assert len(class_names) == truth.shape[-1] == prediction.shape[-1]
    n_classes = len(class_names)
    fig, ax = plt.subplots()
    cm = confusion_matrix(
        np.argmax(truth, axis=-1),
        np.argmax(prediction, axis=-1),
        # Pin the label set so the matrix keeps one row/column per class even
        # when a class never occurs in the data; otherwise the matrix shrinks
        # and no longer lines up with the tick labels below.
        labels=np.arange(n_classes),
        sample_weight=sample_weight,
        normalize=normalize,
    )
    cmap = "plasma" if normalize == "true" else "viridis"
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    ax.set_title("Confusion matrix")
    fig.colorbar(im, ax=ax)

    tick_marks = np.arange(n_classes)
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(class_names, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(class_names)

    # Cells above the mid-point of the value range get black text, the rest
    # white, to keep the numbers readable against the colormap.
    half = (cm.max() + cm.min()) / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        fontdict = {"color": "black" if cm[i, j] > half else "white"}
        ax.text(
            j,
            i,
            np.around(cm[i, j], decimals=2),
            horizontalalignment="center",
            size=7,
            fontdict=fontdict,
        )
    ax.set_ylabel("True label" + " (normed)" * (normalize == "true"))
    ax.set_xlabel("Predicted label" + " (normed)" * (normalize == "pred"))
    fig.tight_layout()
    return fig


def figure_activations(activations, class_names=None):
    """Overlay one density-normalised histogram per activation column.

    ``activations`` must be two-dimensional; every column is histogrammed with
    ten equally spaced bin edges between 0 and 1 and drawn on a logarithmic
    y-axis. Legend entries are the column numbers unless ``class_names``
    provides one name per column. Returns the figure.
    """
    edges = np.linspace(0, 1.0, 10)
    _, n_columns = activations.shape
    fig = plt.figure()

    for column in range(n_columns):
        if class_names is None:
            legend_entry = "%i" % column
        else:
            legend_entry = class_names[column]
        plt.hist(activations[:, column], edges, density=True, label=legend_entry)

    plt.yscale("log")
    plt.legend()
    fig.tight_layout()
    return fig


def figure_history(history_csv_path):
    """Plot every column of a history CSV in its own subplot and return the figure.

    Args:
        history_csv_path: path of the CSV file to read with ``pandas.read_csv``.

    Returns:
        The figure pandas drew the subplots on. The previous implementation
        created a second, empty figure after plotting and returned that one.
    """
    history = pd.read_csv(history_csv_path)
    n_cols = 6
    # Keep the original 7x6 grid, but grow the number of rows when the file has
    # more than 42 columns (pandas rejects a layout that is too small).
    n_rows = max(7, -(-len(history.columns) // n_cols))
    axes = history.plot(subplots=True, figsize=(30, 30), layout=(n_rows, n_cols))
    fig = np.ravel(axes)[0].get_figure()
    fig.tight_layout()
    return fig


def plot_histories(history_csv_path, path, cut=None, roll=1):
    """Write one PDF per metric of a history CSV, pairing it with its ``val_`` twin.

    Args:
        history_csv_path: CSV file with an ``epoch`` column plus metric columns.
        path: directory the ``<metric>.pdf`` files are written to.
        cut: last epoch (index label) to keep; ``None`` keeps everything.
        roll: window size of the rolling mean applied before plotting.

    Columns starting with ``val_`` are not plotted on their own; they are drawn
    together with the column of the same name without the prefix.
    """
    df = pd.read_csv(history_csv_path)
    df = df.set_index("epoch")
    df = df.sort_index()
    df = df.truncate(after=cut)
    for col in df.columns:
        if col.startswith("val_"):
            continue
        if "val_" + col in df.columns:
            ind = [col, "val_" + col]
        else:
            ind = col
        # Plot on an explicitly created axes: a DataFrame ``.plot()`` without
        # ``ax`` opens its own figure, which previously left an unused empty
        # figure behind and applied ``tight_layout`` to the wrong one.
        fig, ax = plt.subplots()
        df[ind].rolling(roll, min_periods=1).mean().plot(ax=ax)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(col.capitalize())
        fig.tight_layout()
        fig.savefig(f"{path}/{col}.pdf")
        # Kept from the original: releases every open pyplot figure.
        plt.close("all")


def figure_dict(d, xlabel=""):
    """Draw the mapping ``d`` as a horizontal bar chart and return the figure.

    Args:
        d: mapping from bar label to bar length.
        xlabel: label of the x-axis. It was accepted but ignored before.

    Returns:
        The figure; its height grows with the number of entries (one third of
        an inch per entry). An empty mapping yields an empty figure instead of
        raising.
    """
    # ``max(..., 1)`` keeps the height positive for an empty mapping;
    # matplotlib rejects a zero-sized figure.
    fig = plt.figure(figsize=(8, max(len(d), 1) / 3))
    if d:
        plt.barh(list(d.keys()), list(d.values()))
    plt.xlabel(xlabel)
    fig.tight_layout()
    return fig


def plot_dict(d, path="tmp.pdf"):
    """Render ``d`` with ``figure_dict`` and write the chart to ``path``.

    The figure returned by ``figure_dict`` is saved explicitly instead of
    relying on it still being pyplot's current figure. All open pyplot figures
    are closed afterwards, as before.
    """
    fig = figure_dict(d)
    fig.savefig(path)
    plt.close("all")


def figure_weights(w, y, class_names=None, relative=False):
    """Bar-plot, per class, the fraction of events that carry a negative weight.

    Args:
        w: one-dimensional array of event weights.
        y: two-dimensional array with one column per class; it is multiplied
            element-wise with ``w`` to obtain per-class weights.
        class_names: optional tick labels, one per class.
        relative: accepted for signature compatibility; not used here.

    Returns:
        The figure with a logarithmic y-axis.
    """
    n_p = y.shape[1]
    bins = range(n_p)
    weights = w[:, None] * y
    fig = plt.figure()
    pos = np.sum(weights > 0, axis=0)
    neg = np.sum(weights < 0, axis=0)
    total = pos + neg
    # A class without any non-zero weight would give 0/0 (NaN plus a
    # RuntimeWarning); report a fraction of 0 for it instead.
    values = np.divide(
        neg, total, out=np.zeros(np.shape(total), dtype=float), where=total > 0
    )
    plt.bar(bins, values)
    plt.xticks(bins, class_names, rotation=45)
    plt.yscale("log")
    plt.xlabel("Classfication Process")
    plt.ylabel("Fraction Negative weights")
    fig.tight_layout()
    return fig


def figure_y(y, class_names=None, relative=False):
    """Bar-plot the column sums of ``y`` on a logarithmic y-axis.

    ``y`` needs at least two axes; one bar is drawn per entry of the second
    axis. With ``relative`` set, the sums are divided by their total and the
    y-label gains a " (normed)" suffix. ``class_names`` supplies the tick
    labels. Returns the figure.
    """
    positions = range(y.shape[1])

    fig = plt.figure()
    heights = y.sum(axis=0)
    if relative:
        heights = heights / heights.sum()

    y_label = "Number of Events" + relative * " (normed)"
    plt.bar(positions, heights)
    plt.xticks(positions, class_names, rotation=45)
    plt.yscale("log")
    plt.xlabel("Classfication Process")
    plt.ylabel(y_label)
    fig.tight_layout()
    return fig


def figure_node_activations(
    activations, truth, class_names=None, disjoint=False, sample_weight=None
):
    # Draws, on one axes per output node, step histograms of that node's
    # activation for every true process, with a logarithmic y-axis.
    #
    # NOTE(review): this copy of the function is incomplete. ``multiplot``,
    # ``n_p`` and ``process_weights`` are used but never assigned, the loops
    # that bind ``process`` and ``node`` have no visible header, and no
    # ``return`` is visible -- restore the missing lines from version control
    # before relying on this block.
    rows, cols = multiplot.lenghts()
    size = 5
    fig, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
    # ten equally spaced bin edges on [0, 1]
    bins = np.linspace(0, 1.0, 10)

    process_activations = []
    # NOTE(review): a loop header binding ``process`` (presumably over
    # range(n_p)) appears to be missing before the next line -- TODO confirm.
        # boolean mask selecting the events whose truth column is set
        # NOTE(review): the ``np.bool`` alias was removed in NumPy 1.24;
        # builtin ``bool`` is the replacement -- confirm the NumPy version.
        this_process = truth[:, process].astype(np.bool)
        if disjoint:
            # keep only each event's largest activation; ``one_hot`` is a
            # helper defined elsewhere and not visible here
            max_activations = np.argmax(activations, axis=-1)
            one_hot_max_activations = one_hot(max_activations, n=n_p)
            values = (activations * one_hot_max_activations)[this_process].swapaxes(0, 1)
            # push the zeroed entries outside the histogram range (0, 1) used
            # below so they are not counted
            values[values == 0] = -10000
        else:
            values = activations[this_process].swapaxes(0, 1)
        # after swapaxes the first axis runs over nodes, the second over events
        process_activations.append(values)
        if sample_weight is not None:
            process_weights.append(sample_weight[this_process])
    # silence divide/invalid floating-point warnings while histogramming; the
    # previous settings are restored by the ``np.seterr(**old_err)`` call below
    old_err = np.seterr(divide="ignore", invalid="ignore")
    # NOTE(review): loop headers binding ``node`` and (nested) ``process``
    # appear to be missing around the following lines -- TODO confirm.
        ax_index = multiplot.index(node)
            label = "%i" % process if class_names is None else class_names[process]
            plot_kwargs = {"histtype": "step", "label": label, "range": (0.0, 1.0)}
            # weights are only applied when one weight array per process was
            # collected above
            if len(process_weights) == n_p:
                plot_kwargs["weights"] = process_weights[process]
            ax[ax_index].hist(process_activations[process][node], bins, **plot_kwargs)

        # node name in the upper-right corner, in axes coordinates
        # NOTE(review): this indexes ``class_names`` unconditionally, so the
        # ``class_names=None`` default fails here.
        ax[ax_index].text(
            0.95,
            0.95,
            f"node {class_names[node]}",
            ha="right",
            va="top",
            transform=ax[ax_index].transAxes,
        )
        ax[multiplot.index(node)].set_yscale("log")
    np.seterr(**old_err)
    # one shared legend, attached to the axes at flat index ``cols - 1`` and
    # placed outside it to the right
    ax[multiplot.index(cols - 1)].legend(
        title="processes", bbox_to_anchor=(1.05, 1.0), loc="upper left"
    )
Dennis Noll's avatar
Dennis Noll committed
    truth,
    prediction,
    indices=[0],
    class_names=None,
    sample_weight=None,
    lw=2,
    scale="linear",
):
    fig = plt.figure()
    for index in indices:
        fpr, tpr, _ = roc_curve(truth[:, index], prediction[:, index])
        roc_auc = auc(fpr, tpr)
        name = index if class_names is None else class_names[index]
        plt.plot(fpr, tpr, lw=lw, label=f"{name} vs All (area = {roc_auc:.2f})")
    plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    lower = 0.0
    if scale.endswith("log"):
        plt.xscale(scale)
        plt.yscale(scale)
        lower = 1e-5
    plt.xlim([lower, 1.0])
    plt.ylim([lower, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver operating characteristic curve")
    plt.legend(loc="lower right")
    return fig


    inps,
    truth,
    sample_weight=None,
    columns=None,
    class_names=None,
    signalvsbkg=False,
    bins=20,
    overflow=1000,
):
    multiplot = Multiplot(inps.shape[1:][::-1])
    rows, cols = multiplot.lenghts()
    size = 5
    fig, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
    ax = ax if isinstance(ax, np.ndarray) else np.array([ax])
    inps = inps.reshape(-1, np.prod(inps.shape[1:]))
    inps = np.clip(inps, -overflow, overflow)
    for feat, name in enumerate(columns):
        ax_index = multiplot.index(feat)
        if signalvsbkg:
            mask = np.argmax(truth, axis=-1) != 0
            _bins = ax[ax_index].hist(
                inps[:, feat][mask],
                histtype="stepfilled",
                weights=sample_weight[mask],
                label="Background",
                density=True,
            )[1]
            mask = np.argmax(truth, axis=-1) == 0
            ax[ax_index].hist(
                inps[:, feat][mask],
                histtype="step",
                weights=sample_weight[mask],
                label="HH",
                density=True,
                linewidth=2,
            )
        else:
            for i in range(len(class_names)):
                mask = np.argmax(truth, axis=-1) == i
                _bins = ax[ax_index].hist(
                    histtype=histtype,
                    bins=_bins,
                    weights=sample_weight[mask],
                    label=class_names[i],
                    density=True,
                    linewidth=2,
                )[1]
                histtype = "step"

        ax[ax_index].set_yscale(yscale)
        ax[ax_index].set_title(name)
        ax[ax_index].legend()
    fig.tight_layout()
    return fig


def figure_weight_study(
    class_inps,
    sample_weights=None,
    columns=None,
    label=None,
    log=False,
    mode="plain",
    **kwargs,
):
    """Compare per-feature distributions of several samples on a grid of axes.

    Args:
        class_inps: sequence of arrays, one per sample; each is flattened to
            ``(n_events, n_features)``. The trailing shape of the first array
            determines the subplot grid.
        sample_weights: sequence of per-event weight arrays aligned with
            ``class_inps``. May be ``None`` for the ``"plain"`` and ``"rel"``
            modes (unweighted histograms); required for ``"weight"``.
        columns: iterable of feature names, one per flattened feature; used as
            subplot titles. Must be provided -- the ``None`` default is not
            handled.
        label: optional sequence of legend labels, one per sample.
        log: not referenced in the code visible here.
        mode: ``"plain"`` draws density step histograms sharing the binning of
            the first sample; ``"rel"`` draws the relative difference of each
            sample's density to the first sample's; ``"weight"`` draws, per
            bin, the share of negative weight in the first sample.
        **kwargs: forwarded to ``Axes.hist`` in ``"plain"`` mode only.

    Returns:
        The figure.

    Raises:
        ValueError: if ``mode="weight"`` is requested without ``sample_weights``.
    """
    if mode == "weight" and sample_weights is None:
        raise ValueError('mode="weight" requires sample_weights')

    multiplot = Multiplot(class_inps[0].shape[1:][::-1])
    rows, cols = multiplot.lenghts()
    size = 5
    fig, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
    # ``plt.subplots`` returns a bare Axes for a 1x1 grid; wrap it so indexing
    # works uniformly.
    ax = ax if isinstance(ax, np.ndarray) else np.array([ax])

    class_inps = [inps.reshape(inps.shape[0], -1) for inps in class_inps]
    for feat, name in enumerate(columns):
        axis = ax[multiplot.index(feat)]

        # Starts as a bin count; after the first histogram it holds the bin
        # edges so that all samples share one binning.
        bins = 25
        ref = None

        for i, inps in enumerate(class_inps):
            weights = None if sample_weights is None else sample_weights[i]
            entry = label[i] if label else None
            if mode == "plain":
                _, bins, _ = axis.hist(
                    inps[:, feat],
                    histtype="step",
                    bins=bins,
                    weights=weights,
                    density=True,
                    label=entry,
                    **kwargs,
                )
            if mode == "rel":
                val, bins = np.histogram(
                    inps[:, feat], bins=bins, weights=weights, density=True
                )
                if ref is None:
                    ref = val
                # Empty reference bins give 0/0 or x/0; the resulting NaN/inf
                # bars are expected, so only the warnings are suppressed.
                with np.errstate(divide="ignore", invalid="ignore"):
                    rel_diff = (val - ref) / ref
                axis.bar(
                    bins[:-1],
                    height=rel_diff,
                    width=bins[1:] - bins[:-1],
                    align="edge",
                    label=entry,
                    alpha=0.5,
                )
        if mode == "weight":
            mask = sample_weights[0] > 0
            pos_feat, pos_weight = class_inps[0][:, feat][mask], sample_weights[0][mask]
            neg_feat, neg_weight = (
                class_inps[0][:, feat][~mask],
                sample_weights[0][~mask],
            )

            val_pos, bins = np.histogram(pos_feat, bins=bins, weights=pos_weight)
            val_neg, bins = np.histogram(neg_feat, bins=bins, weights=neg_weight)
            with np.errstate(divide="ignore", invalid="ignore"):
                neg_share = np.abs(val_neg) / (val_pos + np.abs(val_neg))
            axis.bar(
                bins[:-1],
                height=neg_share,
                width=bins[1:] - bins[:-1],
                align="edge",
            )

        if mode in ["weight"]:
            # NOTE(review): the body of this branch is absent from the
            # available source; it is kept as an explicit no-op rather than
            # guessed. Restore the original statement from version control.
            pass
        axis.set_title(name)
        axis.legend()
    fig.tight_layout()
    return fig


def figure_multihist(data, columns=None):
    """Histogram every flattened feature of ``data`` and return the figure.

    Args:
        data: array whose first axis enumerates the samples; all remaining
            axes are flattened into feature columns.
        columns: optional column names for the flattened features.

    Returns:
        The 20x20 inch figure pandas drew the histograms on. Previously an
        unrelated, empty figure created up front was returned instead.
    """
    df = pd.DataFrame(np.reshape(data, (data.shape[0], -1)), columns=columns)
    # Scope the warning suppression to this call instead of muting warnings
    # for the rest of the process.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # temporary fix outdated pandas
        axes = df.hist(figsize=(20, 20))
        fig = np.ravel(axes)[0].get_figure()
        fig.tight_layout()
    return fig


def figure_to_image(figure):
    """Convert the matplotlib ``figure`` to a PNG image tensor and return it.

    The supplied figure is closed and inaccessible after this call.

    Args:
        figure: the matplotlib figure to render.

    Returns:
        The tensor produced by decoding the PNG with four channels and adding
        a leading batch dimension of size one.
    """
    # Save the plot to a PNG in memory. Render the figure that was passed in;
    # ``plt.savefig`` would render pyplot's *current* figure, which is not
    # necessarily ``figure``.
    buf = io.BytesIO()
    figure.savefig(buf, format="png")
    plt.close(figure)
    buf.seek(0)
    # Convert PNG buffer to TF image
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension
    image = tf.expand_dims(image, 0)
    return image


def figure_lbn_weights(
    weights,
    name,
    norm=False,
    path="",
    cmap="OrRd",
    inp_particle_names=None,
    **fig_kwargs,
):
    """Draw a two-dimensional weight matrix as an annotated heat map.

    Args:
        weights: two-dimensional array; rows are drawn along the y-axis
            ("Input particle"), columns along the x-axis.
        name: text used in the title and the x-axis label.
        norm: if true, take absolute values and rescale every column so that
            it sums to 100.
        path: accepted for signature compatibility; not used here.
        cmap: colormap passed to ``imshow``.
        inp_particle_names: optional y tick labels, one per row.
        **fig_kwargs: forwarded to ``plt.subplots``; ``figsize`` defaults to a
            size derived from the matrix shape.

    Returns:
        The figure.
    """
    # normalize weight tensor to a sum of 100 per column (sum over axis 0)
    if norm:
        weights = np.abs(weights)
        col_sums = np.sum(weights, axis=0, keepdims=True)
        # An all-zero column stays zero instead of turning into NaN, which
        # would make the integer conversion of the cell labels fail below.
        weights = (
            np.divide(
                weights,
                col_sums,
                out=np.zeros(weights.shape, dtype=float),
                where=col_sums != 0,
            )
            * 100
        )

    # create the figure
    ph, pv = weights.shape
    fig_kwargs.setdefault("figsize", ((pv + 1) // 2, (ph + 1) // 2))
    fig, ax = plt.subplots(1, 1, **fig_kwargs)

    # create and style the plot
    ax.imshow(weights, cmap=cmap, vmin=0, vmax=100, origin="lower")
    ax.set_title(f"{name} weights", fontdict={"fontsize": 12})
    ax.set_xlabel(f"LBN {name} number")
    ax.set_xticks(list(range(pv)))

    ax.set_ylabel("Input particle")
    ax.set_yticks(list(range(ph)))
    # Explicit None/length check: plain truthiness is ambiguous for arrays.
    if inp_particle_names is not None and len(inp_particle_names) > 0:
        ax.set_yticklabels(inp_particle_names)

    # write weights into each bin
    for (i, j), val in np.ndenumerate(weights):
        ax.text(j, i, int(round(val)), fontsize=8, ha="center", va="center", color="k")

    return fig


def figure_correlation(values, label=None, path="tmp.pdf"):
    """Draw the pairwise correlation matrix of the columns of ``values``.

    Args:
        values: two-dimensional data, one column per feature, as accepted by
            ``pandas.DataFrame``.
        label: optional sequence of feature names used as tick labels on both
            axes. With ``None`` the default numeric ticks are kept (the
            previous code failed on ``len(None)``).
        path: accepted for signature compatibility; not used here.

    Returns:
        The square figure holding the matrix and its colorbar.
    """
    n = values.shape[-1]

    # One inch per six features, but at least one inch: matplotlib rejects a
    # zero-sized figure, which is what fewer than six features produced.
    side = max(int(n / 6), 1)
    fig = plt.figure(figsize=(side, side))
    df = pd.DataFrame(values)
    # fignum=0 draws into the figure created above instead of opening a new one
    plt.matshow(df.corr(), fignum=0)

    if label is not None:
        ticks = np.arange(len(label))
        plt.xticks(ticks, label, rotation="vertical", fontsize=7)
        plt.yticks(ticks, label, fontsize=7)

    plt.colorbar(fraction=0.046, pad=0.04)
    plt.tight_layout()

    return fig