Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
plotting.py 11.60 KiB
# -*- coding: utf-8 -*-

import io
import warnings

import tensorflow as tf
import itertools
import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, auc
import pandas as pd

from .numpy import one_hot


class Multiplot:
    """Lay out panels on a grid and map flat panel indices to ``(row, col)``.

    ``n`` is either a panel count, or an explicit ``(cols, rows)`` pair. A panel
    count gives a near-square grid with ``ceil(sqrt(n))`` columns and as many
    rows as needed. A one-element sequence is treated as a panel count.

    Attributes:
        n: Number of panels (``cols * rows`` when an explicit grid is given).
        cols: Number of grid columns.
        rows: Number of grid rows.
    """

    def __init__(self, n):
        try:
            dims = tuple(n)
        except TypeError:
            dims = (n,)

        if len(dims) == 1:
            (self.n,) = dims
            if self.n < 1:
                raise ValueError("Multiplot needs at least one panel")
            self.cols = int(np.ceil(np.sqrt(self.n)))
            self.rows = int(np.ceil(self.n / self.cols))
        else:
            # Explicit grid; any length other than 2 raises ValueError on unpacking.
            self.cols, self.rows = (int(d) for d in dims)
            self.n = self.cols * self.rows

    def lengths(self):
        """Return the grid shape as ``(rows, cols)``."""
        return self.rows, self.cols

    # Original (misspelled) name, kept so existing callers keep working.
    lenghts = lengths

    def index(self, i):
        """Return the ``(row, col)`` cell of flat panel index ``i`` (row-major)."""
        row = i // self.cols
        col = i - row * self.cols
        return row, col


def saveplot(f):
    """Decorate a plotting function so every call starts from a clean pyplot state.

    All open matplotlib figures are closed before ``f`` runs. ``f``'s arguments
    and return value are passed through unchanged. ``functools.wraps`` keeps
    ``f``'s name and docstring on the wrapper.
    """
    import functools

    @functools.wraps(f)
    def helper(*args, **kwargs):
        # Close every open figure so pyplot-state calls inside f cannot draw
        # into a figure left over from an earlier call.
        plt.close("all")
        return f(*args, **kwargs)

    return helper


def figure_confusion_matrix(
    truth,
    prediction,
    class_names=("signal", "background"),
    sample_weight=None,
    normalize="true",
    **kwargs,
):
    """Plot the confusion matrix of arg-max class assignments.

    Args:
        truth: Per-class truth scores / one-hot labels. The true class of each row
            is the ``argmax`` over the last axis.
        prediction: Per-class predicted scores, laid out like ``truth``.
        class_names: One name per class (last-axis entry), used as tick labels.
        sample_weight: Optional per-row weights forwarded to ``confusion_matrix``.
        normalize: Forwarded to ``sklearn.metrics.confusion_matrix``.
        **kwargs: Accepted for call-site compatibility; ignored.

    Returns:
        The matplotlib ``Figure``.
    """
    assert len(class_names) == truth.shape[-1] == prediction.shape[-1]
    n_classes = len(class_names)
    fig, ax = plt.subplots()
    cm = confusion_matrix(
        np.argmax(truth, axis=-1),
        np.argmax(prediction, axis=-1),
        # Fix the label set so the matrix is always n_classes x n_classes. Without
        # it, sklearn drops classes that never occur and the tick labels shift.
        labels=np.arange(n_classes),
        sample_weight=sample_weight,
        normalize=normalize,
    )
    cmap = "plasma" if normalize == "true" else "viridis"
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    ax.set_title("Confusion matrix")
    fig.colorbar(im, ax=ax)

    tick_marks = np.arange(n_classes)
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(list(class_names), rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(list(class_names))

    # Annotate every cell with its value rounded to two decimals.
    for i, j in itertools.product(range(n_classes), repeat=2):
        ax.text(j, i, np.around(cm[i, j], decimals=2), horizontalalignment="center", size=7)

    ax.set_ylabel("True label" + " (normed)" * (normalize == "true"))
    ax.set_xlabel("Predicted label" + " (normed)" * (normalize == "pred"))
    fig.tight_layout()
    return fig


def figure_activations(activations, class_names=None):
    """Overlay normalised step histograms of every output node's activations.

    ``activations`` is unpacked as ``(n_samples, n_nodes)``. Column ``k`` is
    histogrammed over nine equal-width bins on [0, 1] (density-normalised). It
    is labelled ``class_names[k]`` when names are given and by its index
    otherwise. The y-axis is logarithmic.

    Returns:
        The matplotlib ``Figure``.
    """
    _, n_nodes = activations.shape
    edges = np.linspace(0, 1.0, 10)
    fig = plt.figure()
    axes = fig.gca()

    for node in range(n_nodes):
        if class_names is None:
            node_label = "%i" % node
        else:
            node_label = class_names[node]
        axes.hist(activations[:, node], edges, histtype="step", density=True, label=node_label)

    axes.set_yscale("log")
    axes.legend()
    fig.tight_layout()
    return fig


def figure_history(history_csv_path, figsize=(30, 30), layout=(7, 6)):
    """Plot every column of a training-history CSV in its own subplot.

    Args:
        history_csv_path: Path to a CSV readable by ``pandas.read_csv``.
        figsize: Figure size in inches, forwarded to ``DataFrame.plot``.
        layout: ``(rows, cols)`` subplot grid forwarded to ``DataFrame.plot``.
            It must hold one panel per CSV column; pass ``None`` to let pandas
            choose.

    Returns:
        The matplotlib ``Figure`` holding the history subplots.
    """
    axes = pd.read_csv(history_csv_path).plot(subplots=True, figsize=figsize, layout=layout)
    # DataFrame.plot(subplots=True) draws into a figure it creates itself. Return
    # that figure (reached via the returned axes) rather than a new, empty one.
    fig = np.asarray(axes).ravel()[0].get_figure()
    fig.tight_layout()
    return fig


def plot_histories(history_csv_path, path, cut=None, roll=1):
    """Save one PDF per training metric, overlaying its ``val_`` counterpart if present.

    Args:
        history_csv_path: CSV with an ``epoch`` column plus one column per metric.
        path: Output directory. Metric ``m`` is written to ``f"{path}/{m}.pdf"``.
        cut: Last epoch to keep (``DataFrame.truncate(after=cut)``); ``None`` keeps all.
        roll: Rolling-mean window applied before plotting (1 means no smoothing).
    """
    history = pd.read_csv(history_csv_path).set_index("epoch").truncate(after=cut)
    for col in history.columns:
        if col.startswith("val_"):
            continue  # drawn together with its training counterpart
        val_col = "val_" + col
        selection = [col, val_col] if val_col in history.columns else col
        fig, ax = plt.subplots()
        # Pass ax explicitly. Otherwise DataFrame.plot (the two-column case) opens
        # its own figure, leaving ``fig`` empty and tight_layout ineffective.
        history[selection].rolling(roll, min_periods=1).mean().plot(ax=ax)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(col.capitalize())
        fig.tight_layout()
        fig.savefig(f"{path}/{col}.pdf")
        # Close only the figure created here, not figures owned by the caller.
        plt.close(fig)


def figure_weights(w, y, class_names=None, relative=False):
    """Bar chart of the fraction of negative-weight events per class.

    For class ``k``, events whose ``w * y[:, k]`` is positive or negative are
    counted. The bar is ``n_negative / (n_positive + n_negative)``, drawn on a
    log scale.

    Args:
        w: Per-event weights, shape ``(n_events,)``.
        y: Per-event class membership, shape ``(n_events, n_classes)``.
        class_names: Optional tick labels, one per class.
        relative: Accepted but unused.

    Returns:
        The matplotlib ``Figure``.
    """
    n_classes = y.shape[1]
    positions = np.arange(n_classes)
    weights = w[:, None] * y
    n_pos = np.sum(weights > 0, axis=0)
    n_neg = np.sum(weights < 0, axis=0)
    n_total = n_pos + n_neg
    # Classes with no non-zero weight get 0 instead of NaN (and a RuntimeWarning).
    fraction = np.divide(n_neg, n_total, out=np.zeros(n_classes), where=n_total > 0)

    fig, ax = plt.subplots()
    ax.bar(positions, fraction)
    ax.set_xticks(positions)
    if class_names is not None:
        ax.set_xticklabels(class_names)
    ax.tick_params(axis="x", labelrotation=45)
    ax.set_yscale("log")
    ax.set_xlabel("Classification Process")
    ax.set_ylabel("Fraction Negative weights")
    fig.tight_layout()
    return fig


def figure_y(y, class_names=None, relative=False):
    """Bar chart of the column sums of ``y`` (per-class totals), on a log scale.

    Args:
        y: Per-event class membership, shape ``(n_events, n_classes)``.
        class_names: Optional tick labels, one per class.
        relative: If true, divide by the grand total so the bars sum to 1.

    Returns:
        The matplotlib ``Figure``.
    """
    n_classes = y.shape[1]
    positions = np.arange(n_classes)
    counts = y.sum(axis=0)
    if relative:
        counts = counts / counts.sum()

    fig, ax = plt.subplots()
    ax.bar(positions, counts)
    ax.set_xticks(positions)
    if class_names is not None:
        ax.set_xticklabels(class_names)
    ax.tick_params(axis="x", labelrotation=45)
    ax.set_yscale("log")
    ax.set_xlabel("Classification Process")
    # Conditional instead of ``relative * str`` so any truthy ``relative`` works.
    ax.set_ylabel("Number of Events" + (" (normed)" if relative else ""))
    fig.tight_layout()
    return fig


def figure_node_activations(
    activations, truth, class_names=None, disjoint=False, sample_weight=None
):
    """Grid of per-node activation histograms, one curve per true process.

    Panel ``node`` shows ``activations[:, node]`` separately for the events of
    each true process, over nine bins on [0, 1] with a log y-axis.

    Args:
        activations: Network outputs, shape ``(n_events, n_nodes)``. The number of
            processes is taken to equal ``n_nodes``.
        truth: Per-event process membership, shape ``(n_events, n_nodes)``. It is
            used as a boolean mask per process (non-zero = member).
        class_names: Optional name per process/node; indices are used if ``None``.
        disjoint: If true, each event's activation is only counted in the panel of
            its highest-scoring node.
        sample_weight: Optional per-event weights, shape ``(n_events,)``.

    Returns:
        The matplotlib ``Figure``.
    """
    _, n_nodes = activations.shape
    names = ["%i" % k if class_names is None else class_names[k] for k in range(n_nodes)]
    # Interpret truth as a mask. A numeric 0/1 array would otherwise be taken as
    # integer fancy indices and silently select the wrong rows.
    truth = np.asarray(truth).astype(bool)

    multiplot = Multiplot(n_nodes)
    rows, cols = multiplot.lenghts()
    # squeeze=False keeps ``ax`` 2-D even for a single row or panel, so the
    # (row, col) tuples from Multiplot.index always work.
    fig, ax = plt.subplots(rows, cols, figsize=(15, 15 * rows / cols), squeeze=False)
    bins = np.linspace(0, 1.0, 10)

    if disjoint:
        # Keep only each event's maximal activation (others become 0); loop-invariant.
        max_only = activations * one_hot(np.argmax(activations, axis=-1))

    process_activations = []
    process_weights = []
    for process in range(n_nodes):
        member = truth[:, process]
        if disjoint:
            values = max_only[member].swapaxes(0, 1)
            # Push zeroed (non-maximal) entries outside the (0, 1) histogram range
            # so they are not counted.
            values[values == 0] = -10000
        else:
            values = activations[member].swapaxes(0, 1)
        process_activations.append(values)
        if sample_weight is not None:
            process_weights.append(sample_weight[member])

    # errstate restores the previous numpy error settings even if plotting raises.
    with np.errstate(divide="ignore", invalid="ignore"):
        for node in range(n_nodes):
            panel = ax[multiplot.index(node)]
            for process in range(n_nodes):
                hist_kwargs = {"histtype": "step", "label": names[process], "range": (0.0, 1.0)}
                if sample_weight is not None:
                    hist_kwargs["weights"] = process_weights[process]
                panel.hist(process_activations[process][node], bins, **hist_kwargs)

            panel.text(
                0.95,
                0.95,
                f"node {names[node]}",
                ha="right",
                va="top",
                transform=panel.transAxes,
            )
            panel.set_yscale("log")

    # A single legend, placed to the right of the top-right panel.
    ax[multiplot.index(cols - 1)].legend(
        title="processes", bbox_to_anchor=(1.05, 1.0), loc="upper left"
    )
    fig.tight_layout()
    return fig


def figure_roc_curve(
    truth, prediction, indices=(0,), class_names=None, sample_weight=None, lw=2, scale="linear"
):
    """Plot one-vs-rest ROC curves for the selected classes.

    Args:
        truth: Per-class binary truth; column ``index`` is the target for class ``index``.
        prediction: Per-class scores with the same column layout as ``truth``.
        indices: Columns (classes) for which a curve is drawn.
        class_names: Optional legend names, indexed like the columns.
        sample_weight: Optional per-event weights, forwarded to ``roc_curve``.
            NOTE: negative weights can make the FPR sequence non-monotonic, in
            which case ``sklearn.metrics.auc`` raises ``ValueError``.
        lw: Line width of all curves.
        scale: Axis scale. A value ending in ``"log"`` is applied to both axes and
            the lower axis limits are set to 1e-5.

    Returns:
        The matplotlib ``Figure``.
    """
    fig, ax = plt.subplots()
    for index in indices:
        # Forward the weights; previously the argument was accepted but ignored.
        fpr, tpr, _ = roc_curve(
            truth[:, index], prediction[:, index], sample_weight=sample_weight
        )
        roc_auc = auc(fpr, tpr)
        name = index if class_names is None else class_names[index]
        ax.plot(fpr, tpr, lw=lw, label=f"{name} vs All (area = {roc_auc:.2f})")
    # Diagonal of a random classifier, for reference.
    ax.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    lower = 0.0
    if scale.endswith("log"):
        ax.set_xscale(scale)
        ax.set_yscale(scale)
        lower = 1e-5
    ax.set_xlim([lower, 1.0])
    ax.set_ylim([lower, 1.05])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("Receiver operating characteristic curve")
    ax.legend(loc="lower right")
    return fig


def figure_inputs(
    inps, truth, sample_weight=None, columns=None, class_names=None, signalvsbkg=False, bins=20
):
    """Grid of weighted, normalised input-feature histograms split by true class.

    Args:
        inps: Inputs, shape ``(n_events, ...)``. The trailing dimensions define
            the panel grid and are flattened into one feature per panel.
        truth: Per-class truth, shape ``(n_events, n_classes)``. An event's class
            is the ``argmax`` over the last axis.
        sample_weight: Per-event weights; ``None`` means unit weights.
        columns: One panel title per flattened feature; defaults to feature indices.
        class_names: One legend name per class; defaults to class indices.
        signalvsbkg: If true, compare class 0 (labelled "HH") against all other
            classes combined instead of drawing each class.
        bins: Bin count, edges or rule, passed to ``numpy.histogram_bin_edges``.
            The same edges are used for every curve of a feature.

    Returns:
        The matplotlib ``Figure``.
    """
    grid = inps.shape[1:][::-1]
    # An explicit (cols, rows) grid needs two trailing dims; otherwise lay out
    # the flattened feature count on a near-square grid.
    multiplot = Multiplot(grid if len(grid) == 2 else int(np.prod(grid)))
    rows, cols = multiplot.lenghts()

    flat = inps.reshape(inps.shape[0], -1)
    n_classes = truth.shape[-1]
    if sample_weight is None:
        sample_weight = np.ones(flat.shape[0])
    if columns is None:
        columns = [str(k) for k in range(flat.shape[1])]
    if class_names is None:
        class_names = ["%i" % k for k in range(n_classes)]

    size = len(columns)
    # squeeze=False keeps ``ax`` 2-D even for a single row or panel.
    fig, ax = plt.subplots(rows, cols, figsize=(size, size * rows / cols), squeeze=False)

    event_class = np.argmax(truth, axis=-1)
    # Draw classes with the largest total weight first. Iterate over class ids
    # (not a reordered name list) so each curve keeps its own label.
    order = np.argsort(-(sample_weight[:, None] * truth).sum(axis=0))

    for feat, name in enumerate(columns):
        panel = ax[multiplot.index(feat)]
        values = flat[:, feat]
        # One set of edges per feature so all curves are binned identically.
        common = {"bins": np.histogram_bin_edges(values, bins=bins), "density": True}
        if signalvsbkg:
            background = event_class != 0
            panel.hist(
                values[background],
                histtype="stepfilled",
                weights=sample_weight[background],
                label="Background",
                **common,
            )
            signal = ~background
            panel.hist(
                values[signal],
                histtype="step",
                weights=sample_weight[signal],
                label="HH",
                linewidth=2,
                **common,
            )
        else:
            for cls in order:
                mask = event_class == cls
                panel.hist(
                    values[mask],
                    histtype="step",
                    weights=sample_weight[mask],
                    label=class_names[cls],
                    linewidth=2,
                    **common,
                )
        panel.set_title(name)
        panel.legend()
    fig.tight_layout()
    return fig


def figure_weight_study(
    class_inps, sample_weights=None, columns=None, label=None, log=False, mode="plain", **kwargs
):
    """Compare feature distributions of several input sets, one panel per feature.

    Args:
        class_inps: Sequence of input arrays ``(n_events_i, ...)`` with a common
            trailing shape. The first array's trailing shape sets the panel grid.
        sample_weights: Sequence of per-event weight arrays aligned with
            ``class_inps``. ``None`` means unweighted; ``mode="weight"`` requires weights.
        columns: One panel title per flattened feature; defaults to feature indices.
        label: Optional legend labels aligned with ``class_inps``.
        log: Accepted but unused.
        mode: How each panel is drawn:
            ``"plain"`` overlays density histograms (25 bins from the first set).
            ``"rel"`` draws ``(h_i - h_0) / h_0`` per bin relative to the first set.
            ``"weight"`` draws, for the first set only, the per-bin share
            ``|sum(w<=0)| / (sum(w>0) + |sum(w<=0)|)``.
        **kwargs: Extra keyword arguments for ``Axes.hist`` in ``"plain"`` mode.

    Returns:
        The matplotlib ``Figure``.
    """
    if mode == "weight" and sample_weights is None:
        raise ValueError("mode='weight' requires sample_weights")

    grid = class_inps[0].shape[1:][::-1]
    # An explicit (cols, rows) grid needs two trailing dims; otherwise lay out
    # the flattened feature count on a near-square grid.
    multiplot = Multiplot(grid if len(grid) == 2 else int(np.prod(grid)))
    rows, cols = multiplot.lenghts()
    size = 3
    # squeeze=False keeps ``ax`` 2-D even for a single row or panel.
    fig, ax = plt.subplots(rows, cols, figsize=(cols * size, rows * size), squeeze=False)

    class_inps = [inps.reshape(inps.shape[0], -1) for inps in class_inps]
    if columns is None:
        columns = [str(k) for k in range(class_inps[0].shape[1])]
    weights = [None] * len(class_inps) if sample_weights is None else sample_weights
    # len() check instead of truthiness so numpy arrays of labels work too.
    labels = [None] * len(class_inps) if label is None or len(label) == 0 else label

    for feat, name in enumerate(columns):
        panel = ax[multiplot.index(feat)]

        # Bin count for the first set; later sets reuse the first set's edges.
        bins = 25
        ref = None

        for i, inps in enumerate(class_inps):
            if mode == "plain":
                _, bins, _ = panel.hist(
                    inps[:, feat],
                    histtype="step",
                    bins=bins,
                    weights=weights[i],
                    density=True,
                    label=labels[i],
                    **kwargs,
                )
            elif mode == "rel":
                val, bins = np.histogram(
                    inps[:, feat], bins=bins, weights=weights[i], density=True
                )
                if ref is None:
                    ref = val
                # Empty reference bins yield inf/nan bars without warnings.
                with np.errstate(divide="ignore", invalid="ignore"):
                    height = (val - ref) / ref
                panel.bar(
                    bins[:-1],
                    height=height,
                    width=bins[1:] - bins[:-1],
                    align="edge",
                    label=labels[i],
                    alpha=0.5,
                )

        if mode == "weight":
            feat_values = class_inps[0][:, feat]
            positive = weights[0] > 0
            val_pos, bins = np.histogram(
                feat_values[positive], bins=bins, weights=weights[0][positive]
            )
            val_neg, bins = np.histogram(
                feat_values[~positive], bins=bins, weights=weights[0][~positive]
            )
            # Bins with no weight at all yield nan bars without warnings.
            with np.errstate(divide="ignore", invalid="ignore"):
                height = np.abs(val_neg) / (val_pos + np.abs(val_neg))
            panel.bar(
                bins[:-1],
                height=height,
                width=bins[1:] - bins[:-1],
                align="edge",
            )

        if mode in ["weight", "rel"]:
            panel.set_yscale("log")
        panel.set_title(name)
        panel.legend()
    fig.tight_layout()
    return fig


def figure_multihist(data, columns=None):
    """Histogram every flattened feature of ``data`` in one grid figure.

    Args:
        data: Array of shape ``(n_events, ...)``. The trailing dimensions are
            flattened into one DataFrame column per feature.
        columns: Optional column names, one per flattened feature.

    Returns:
        The matplotlib ``Figure`` created by ``DataFrame.hist``.
    """
    df = pd.DataFrame(np.reshape(data, (data.shape[0], -1)), columns=columns)
    # Silence warnings (e.g. from outdated pandas) for this call only, instead of
    # switching every warning off process-wide with a bare simplefilter.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        axes = df.hist(figsize=(20, 20))
        # DataFrame.hist lays out its own figure; return that one rather than a
        # separately created, empty figure.
        fig = np.asarray(axes).ravel()[0].get_figure()
        fig.tight_layout()
    return fig


def figure_to_image(figure):
    """Render ``figure`` to PNG and return it as a batched TensorFlow image.

    The supplied figure is closed and must not be used after this call.

    Returns:
        The PNG decoded with 4 channels (RGBA), with a leading batch axis of size 1.
    """
    buf = io.BytesIO()
    # Save the figure that was passed in. plt.savefig would render whichever
    # figure pyplot currently considers active, which need not be ``figure``.
    figure.savefig(buf, format="png")
    # Closing the figure prevents it from being displayed directly inside a notebook.
    plt.close(figure)
    # getvalue() returns the whole buffer regardless of position, so no seek is needed.
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension.
    return tf.expand_dims(image, 0)