-
Dennis Noll authoredDennis Noll authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
plotting.py 11.60 KiB
# -*- coding: utf-8 -*-
import io
import warnings
import tensorflow as tf
import itertools
import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, auc
import pandas as pd
from .numpy import one_hot
class Multiplot:
    """Grid geometry helper for arranging several panels in one figure.

    The constructor accepts either an iterable ``(cols, rows)`` pair, which
    is used as given, or a single number of panels ``n``, for which a
    near-square grid is derived: ``cols = ceil(sqrt(n))`` and
    ``rows = ceil(n / cols)``.

    Attributes set by the constructor:
        cols: number of grid columns.
        rows: number of grid rows.
        n: the requested number of panels, or ``cols * rows`` when an
            explicit grid shape was passed.
    """

    def __init__(self, n):
        try:
            # An iterable argument is an explicit (cols, rows) grid shape.
            cols, rows = n
        except TypeError:
            # A scalar argument is a panel count; build a near-square grid.
            if n < 1:
                raise ValueError(
                    f"Multiplot needs at least one panel, got {n!r}"
                ) from None
            # Plain ints (not numpy scalars) keep the later arithmetic simple.
            cols = int(np.ceil(np.sqrt(n)))
            rows = int(np.ceil(n / cols))
            self.n = n
        else:
            # Keep ``n`` defined for both construction paths.
            self.n = cols * rows
        self.cols = cols
        self.rows = rows

    def lengths(self):
        """Return the grid shape as ``(rows, cols)``."""
        return self.rows, self.cols

    # Misspelled original name, kept so existing callers continue to work.
    lenghts = lengths

    def index(self, i):
        """Map the flat panel number ``i`` to a ``(row, col)`` grid position.

        Panels are counted row by row, ``self.cols`` panels per row.
        """
        row = int(i / self.cols)
        col = i - row * self.cols
        return row, col
def saveplot(f):
    """Decorate ``f`` so that all open pyplot figures are closed before it runs.

    The returned wrapper forwards every positional and keyword argument to
    ``f`` unchanged and returns whatever ``f`` returns.
    """
    # Local import: the file's import block does not pull in functools.
    from functools import wraps

    @wraps(f)  # keep the wrapped function's __name__/__doc__ intact
    def helper(*args, **kwargs):
        plt.close("all")
        return f(*args, **kwargs)

    return helper
def figure_confusion_matrix(
    truth,
    prediction,
    class_names=["signal", "background"],
    sample_weight=None,
    normalize="true",
    **kwargs,
):
    """Draw a confusion matrix as an annotated heat map and return the figure.

    Args:
        truth: array whose last axis is reduced with ``argmax`` to obtain the
            true class index of every entry.
        prediction: array reduced the same way to obtain the predicted index.
        class_names: tick labels; their count must equal the last-axis length
            of both ``truth`` and ``prediction`` (checked with ``assert``).
        sample_weight: forwarded to ``sklearn.metrics.confusion_matrix``.
        normalize: forwarded to ``confusion_matrix``; also selects the colour
            map and which axis label gets the " (normed)" suffix.
        **kwargs: accepted but not used.

    Returns:
        The matplotlib figure holding the plot.
    """
    assert len(class_names) == truth.shape[-1] == prediction.shape[-1]

    fig, ax = plt.subplots()
    matrix = confusion_matrix(
        np.argmax(truth, axis=-1),
        np.argmax(prediction, axis=-1),
        sample_weight=sample_weight,
        normalize=normalize,
    )

    row_normalised = normalize == "true"
    if row_normalised:
        colour_map = "plasma"
    else:
        colour_map = "viridis"

    image = ax.imshow(matrix, interpolation="nearest", cmap=colour_map)
    plt.title("Confusion matrix")
    fig.colorbar(image, ax=ax)

    positions = np.arange(len(class_names))
    plt.xticks(positions, class_names, rotation=45)
    plt.yticks(positions, class_names)

    # Write every cell value, rounded to two decimals, into its cell.
    for (true_idx, pred_idx), cell in np.ndenumerate(matrix):
        plt.text(
            pred_idx,
            true_idx,
            np.around(cell, decimals=2),
            horizontalalignment="center",
            size=7,
        )

    y_text = "True label"
    if row_normalised:
        y_text += " (normed)"
    x_text = "Predicted label"
    if normalize == "pred":
        x_text += " (normed)"
    plt.ylabel(y_text)
    plt.xlabel(x_text)

    fig.tight_layout()
    return fig
def figure_activations(activations, class_names=None):
    """Overlay one normalised step histogram per column of ``activations``.

    Args:
        activations: two-dimensional array; every column is histogrammed
            separately over ten equally spaced bin edges between 0 and 1.
        class_names: optional legend labels indexed by column; the column
            number is used when omitted.

    Returns:
        The matplotlib figure (logarithmic y axis, with legend).
    """
    edges = np.linspace(0, 1.0, 10)
    # Unpacking into two names requires exactly two dimensions.
    _, n_columns = activations.shape
    fig = plt.figure()
    for column in range(n_columns):
        if class_names is None:
            legend_entry = "%i" % column
        else:
            legend_entry = class_names[column]
        plt.hist(
            activations[:, column],
            edges,
            histtype="step",
            density=True,
            label=legend_entry,
        )
    plt.yscale("log")
    plt.legend()
    fig.tight_layout()
    return fig
def figure_history(history_csv_path, figsize=(30, 30), layout=(7, 6)):
    """Plot every column of a training-history CSV in its own subplot.

    Args:
        history_csv_path: path of the CSV file read with ``pandas.read_csv``.
        figsize: figure size forwarded to ``DataFrame.plot``.
        layout: ``(rows, cols)`` subplot layout forwarded to
            ``DataFrame.plot``.

    Returns:
        The figure that contains the subplots.
    """
    history = pd.read_csv(history_csv_path)
    axes = history.plot(subplots=True, figsize=figsize, layout=layout)
    # ``DataFrame.plot`` draws on a figure of its own. Take that figure from
    # the returned axes rather than opening a fresh (blank) one with
    # ``plt.figure()``, which is what used to be returned.
    fig = np.ravel(axes)[0].get_figure()
    fig.tight_layout()
    return fig
def plot_histories(history_csv_path, path, cut=None, roll=1):
    """Save one PDF per metric of a training-history CSV.

    The CSV is indexed by its "epoch" column. For every column that does not
    start with "val_", the column is plotted, together with its "val_"
    counterpart when one exists, after a rolling mean, and written to
    ``{path}/{column}.pdf``.

    Args:
        history_csv_path: path of the CSV file read with ``pandas.read_csv``.
        path: directory prefix for the written PDF files.
        cut: passed to ``DataFrame.truncate(after=cut)``; ``None`` keeps all
            rows.
        roll: window size of the rolling mean (``min_periods=1``).
    """
    history = pd.read_csv(history_csv_path).set_index("epoch").truncate(after=cut)
    for col in history.columns:
        if col.startswith("val_"):
            continue
        val_col = "val_" + col
        selection = [col, val_col] if val_col in history.columns else col
        smoothed = history[selection].rolling(roll, min_periods=1).mean()

        # Plot onto an explicit axes. Without ``ax=``, ``DataFrame.plot``
        # opens its own figure, so the layout call and the save would not be
        # guaranteed to act on the same figure.
        fig, ax = plt.subplots()
        smoothed.plot(ax=ax)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(col.capitalize())
        fig.tight_layout()
        fig.savefig(f"{path}/{col}.pdf")
        # Same cleanup as before: closes every open pyplot figure.
        plt.close("all")
def figure_weights(w, y, class_names=None, relative=False):
    """Bar chart of the fraction of negatively weighted entries per class.

    Args:
        w: one-dimensional array of weights, broadcast against the columns
            of ``y``.
        y: two-dimensional array with one column per class; entries of
            ``w[:, None] * y`` that are zero count neither as positive nor
            as negative.
        class_names: tick labels for the bars.
        relative: accepted but not used.

    Returns:
        The matplotlib figure (logarithmic y axis).
    """
    positions = range(y.shape[1])
    per_class = w[:, None] * y
    fig = plt.figure()
    n_positive = (per_class > 0).sum(axis=0)
    n_negative = (per_class < 0).sum(axis=0)
    negative_fraction = n_negative / (n_positive + n_negative)
    plt.bar(positions, negative_fraction)
    plt.xticks(positions, class_names, rotation=45)
    plt.yscale("log")
    plt.xlabel("Classfication Process")
    plt.ylabel("Fraction Negative weights")
    fig.tight_layout()
    return fig
def figure_y(y, class_names=None, relative=False):
    """Bar chart of the column sums of ``y``.

    Args:
        y: two-dimensional array with one column per class.
        class_names: tick labels for the bars.
        relative: when true, the column sums are divided by their total and
            the y label receives a " (normed)" suffix.

    Returns:
        The matplotlib figure (logarithmic y axis).
    """
    positions = range(y.shape[1])
    fig = plt.figure()
    column_totals = y.sum(axis=0)
    heights = column_totals / column_totals.sum() if relative else column_totals
    plt.bar(positions, heights)
    plt.xticks(positions, class_names, rotation=45)
    plt.yscale("log")
    plt.xlabel("Classfication Process")
    y_text = "Number of Events" + relative * " (normed)"
    plt.ylabel(y_text)
    fig.tight_layout()
    return fig
def figure_node_activations(
    activations, truth, class_names=None, disjoint=False, sample_weight=None
):
    """Plot, for every output node, its activation histogram split by process.

    One panel is drawn per node (column of ``activations``); inside each
    panel one step histogram is drawn per process (column of ``truth``).

    Args:
        activations: two-dimensional array, one column per node.
        truth: two-dimensional array; column ``p`` is converted to a boolean
            mask selecting the rows that belong to process ``p``.
            NOTE(review): presumably a one-hot / boolean membership matrix -
            confirm against callers.
        class_names: optional names used for the legend and the panel
            captions; the numeric index is used when omitted.
        disjoint: when true, every row only keeps the activation of its
            arg-max node; all other (zero) entries are moved to -10000 so
            they fall outside the histogram range.
        sample_weight: optional per-row weights for the histograms.

    Returns:
        The matplotlib figure.
    """
    _, n_p = activations.shape
    multiplot = Multiplot(n_p)
    rows, cols = multiplot.lenghts()
    # squeeze=False keeps ``ax`` two-dimensional, so (row, col) indexing also
    # works for single-row or single-column grids (e.g. one or two nodes).
    fig, ax = plt.subplots(rows, cols, figsize=(15, 15 * rows / cols), squeeze=False)
    bins = np.linspace(0, 1.0, 10)

    # One fallback for both legend labels and panel captions; the caption
    # used to index ``class_names`` directly and failed for ``None``.
    if class_names is None:
        names = ["%i" % i for i in range(n_p)]
    else:
        names = class_names

    # The disjoint product does not depend on the process: compute it once.
    if disjoint:
        source = activations * one_hot(np.argmax(activations, axis=-1))
    else:
        source = activations

    process_activations = []
    process_weights = []
    for process in range(n_p):
        mask = np.asarray(truth[:, process], dtype=bool)
        # Boolean indexing copies, so the in-place edit below never touches
        # the caller's array.
        values = source[mask].swapaxes(0, 1)
        if disjoint:
            values[values == 0] = -10000
        process_activations.append(values)
        if sample_weight is not None:
            process_weights.append(sample_weight[mask])

    # errstate restores the previous numpy error settings even on exceptions.
    with np.errstate(divide="ignore", invalid="ignore"):
        for node in range(n_p):
            panel = ax[multiplot.index(node)]
            for process in range(n_p):
                plot_kwargs = {
                    "histtype": "step",
                    "label": names[process],
                    "range": (0.0, 1.0),
                }
                if sample_weight is not None:
                    plot_kwargs["weights"] = process_weights[process]
                panel.hist(process_activations[process][node], bins, **plot_kwargs)
            panel.text(
                0.95,
                0.95,
                f"node {names[node]}",
                ha="right",
                va="top",
                transform=panel.transAxes,
            )
            panel.set_yscale("log")

    # The legend is attached to the last panel of the first row.
    ax[multiplot.index(cols - 1)].legend(
        title="processes", bbox_to_anchor=(1.05, 1.0), loc="upper left"
    )
    fig.tight_layout()
    return fig
def figure_roc_curve(
    truth, prediction, indices=[0], class_names=None, sample_weight=None, lw=2, scale="linear"
):
    """Draw one-vs-all ROC curves for the selected columns.

    Args:
        truth: two-dimensional array; column ``index`` is passed to
            ``sklearn.metrics.roc_curve`` as the true labels.
        prediction: two-dimensional array; column ``index`` is passed as the
            scores.
        indices: columns to draw a curve for.
        class_names: optional names for the legend; the column index is used
            when omitted.
        sample_weight: accepted but not used by this function.
        lw: line width of the curves and of the diagonal.
        scale: when it ends in "log" it is applied to both axes and the lower
            axis limit becomes 1e-5 instead of 0.

    Returns:
        The matplotlib figure.
    """
    fig = plt.figure()
    for index in indices:
        false_pos, true_pos, _ = roc_curve(truth[:, index], prediction[:, index])
        area = auc(false_pos, true_pos)
        if class_names is None:
            name = index
        else:
            name = class_names[index]
        plt.plot(false_pos, true_pos, lw=lw, label=f"{name} vs All (area = {area:.2f})")

    # Diagonal reference line.
    plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")

    logarithmic = scale.endswith("log")
    if logarithmic:
        plt.xscale(scale)
        plt.yscale(scale)
    lower = 1e-5 if logarithmic else 0.0

    plt.xlim([lower, 1.0])
    plt.ylim([lower, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver operating characteristic curve")
    plt.legend(loc="lower right")
    return fig
def figure_inputs(
    inps, truth, sample_weight=None, columns=None, class_names=None, signalvsbkg=False, bins=20
):
    """Plot normalised per-class histograms of every input feature.

    Args:
        inps: array whose trailing two axes define the panel grid
            (``shape[1]`` rows, ``shape[2]`` columns); it is flattened to
            ``(n_rows, n_features)`` for plotting.
        truth: two-dimensional array; ``argmax`` over the last axis gives
            the class index of each row.
        sample_weight: optional per-row weights; unit weights when omitted.
        columns: feature names used as panel titles (their count also sets
            the figure size); feature indices are used when omitted.
        class_names: legend names indexed by class; class indices are used
            when omitted.
        signalvsbkg: when true, class 0 (labelled "HH") is drawn against all
            other classes combined (labelled "Background").
        bins: bin specification for the per-class histograms. In
            ``signalvsbkg`` mode the background histogram uses matplotlib's
            default binning and the signal histogram reuses its edges.

    Returns:
        The matplotlib figure.
    """
    multiplot = Multiplot(inps.shape[1:][::-1])
    rows, cols = multiplot.lenghts()
    inps = inps.reshape(inps.shape[0], -1)

    # The ``None`` defaults used to crash on first use; give them real values.
    if columns is None:
        columns = [str(i) for i in range(inps.shape[1])]
    if sample_weight is None:
        sample_weight = np.ones(inps.shape[0])
    if class_names is None:
        class_names = [str(i) for i in range(truth.shape[-1])]

    size = len(columns)
    # squeeze=False keeps ``ax`` two-dimensional for single-row/column grids.
    fig, ax = plt.subplots(rows, cols, figsize=(size, size * rows / cols), squeeze=False)

    labels = np.argmax(truth, axis=-1)
    # Classes sorted by descending summed weight. Iterating over ``order``
    # keeps each histogram paired with its own name; previously the names
    # were reordered while the masks still used the unsorted index, which
    # mislabelled the curves.
    order = np.argsort(-(sample_weight[:, None] * truth).sum(axis=0))

    for feat, name in enumerate(columns):
        panel = ax[multiplot.index(feat)]
        feature = inps[:, feat]
        if signalvsbkg:
            background = labels != 0
            # Kept separate from ``bins`` so the parameter is not clobbered.
            edges = panel.hist(
                feature[background],
                histtype="stepfilled",
                weights=sample_weight[background],
                label="Background",
                density=True,
            )[1]
            signal = ~background
            panel.hist(
                feature[signal],
                histtype="step",
                bins=edges,
                weights=sample_weight[signal],
                label="HH",
                density=True,
                linewidth=2,
            )
        else:
            for cls in order:
                selected = labels == cls
                panel.hist(
                    feature[selected],
                    histtype="step",
                    bins=bins,
                    weights=sample_weight[selected],
                    label=class_names[cls],
                    density=True,
                    linewidth=2,
                )
        panel.set_title(name)
        panel.legend()
    fig.tight_layout()
    return fig
def figure_weight_study(
    class_inps, sample_weights=None, columns=None, label=None, log=False, mode="plain", **kwargs
):
    """Compare feature distributions of several samples, one panel per feature.

    Args:
        class_inps: sequence of arrays, one per sample; the trailing two
            axes of the first array define the panel grid and every array is
            flattened to ``(n_rows, n_features)``.
        sample_weights: sequence of per-row weight arrays parallel to
            ``class_inps``; unweighted when omitted (required for
            ``mode="weight"``).
        columns: feature names used as panel titles; feature indices are
            used when omitted.
        label: optional sequence of legend labels parallel to ``class_inps``.
        log: accepted but not used.
        mode: "plain" draws normalised step histograms; "rel" draws the
            relative difference of each normalised histogram to that of the
            first sample; "weight" draws, for the first sample only, the
            per-bin share of summed negative weight magnitude. Any other
            value leaves the panels empty.
        **kwargs: forwarded to ``Axes.hist`` in "plain" mode.

    Returns:
        The matplotlib figure.

    Raises:
        TypeError: if ``mode="weight"`` is used without ``sample_weights``.
    """
    multiplot = Multiplot(class_inps[0].shape[1:][::-1])
    rows, cols = multiplot.lenghts()
    size = 3
    # squeeze=False keeps ``ax`` two-dimensional for single-row/column grids.
    fig, ax = plt.subplots(rows, cols, figsize=(cols * size, rows * size), squeeze=False)
    class_inps = [inps.reshape(inps.shape[0], -1) for inps in class_inps]

    if sample_weights is None:
        if mode == "weight":
            raise TypeError('mode="weight" requires sample_weights')
        sample_weights = [None] * len(class_inps)
    if columns is None:
        columns = [str(i) for i in range(class_inps[0].shape[1])]

    for feat, name in enumerate(columns):
        panel = ax[multiplot.index(feat)]
        if mode == "weight":
            # Only the first sample enters this mode, so draw it once rather
            # than once per entry of ``class_inps``.
            mask = sample_weights[0] > 0
            pos_feat, pos_weight = class_inps[0][:, feat][mask], sample_weights[0][mask]
            neg_feat, neg_weight = class_inps[0][:, feat][~mask], sample_weights[0][~mask]
            val_pos, bins = np.histogram(pos_feat, bins=25, weights=pos_weight)
            val_neg, bins = np.histogram(neg_feat, bins=bins, weights=neg_weight)
            panel.bar(
                bins[:-1],
                height=np.abs(val_neg) / (val_pos + np.abs(val_neg)),
                width=bins[1:] - bins[:-1],
                align="edge",
            )
        else:
            # The first sample fixes the bin edges; later samples reuse them.
            bins = 25
            ref = None
            for i, inps in enumerate(class_inps):
                entry_label = label[i] if label else None
                if mode == "plain":
                    val, bins, _ = panel.hist(
                        inps[:, feat],
                        histtype="step",
                        bins=bins,
                        weights=sample_weights[i],
                        density=True,
                        label=entry_label,
                        **kwargs,
                    )
                elif mode == "rel":
                    val, bins = np.histogram(
                        inps[:, feat], bins=bins, weights=sample_weights[i], density=True
                    )
                    if ref is None:
                        ref = val
                    panel.bar(
                        bins[:-1],
                        height=(val - ref) / ref,
                        width=bins[1:] - bins[:-1],
                        align="edge",
                        label=entry_label,
                        alpha=0.5,
                    )
        if mode in ["weight", "rel"]:
            panel.set_yscale("log")
        panel.set_title(name)
        panel.legend()
    fig.tight_layout()
    return fig
def figure_multihist(data, columns=None):
    """Histogram every column of ``data`` with ``DataFrame.hist``.

    Args:
        data: array that is flattened to ``(data.shape[0], -1)`` before
            being wrapped in a DataFrame.
        columns: optional column names for the DataFrame.

    Returns:
        The figure that contains the histograms.
    """
    df = pd.DataFrame(np.reshape(data, (data.shape[0], -1)), columns=columns)
    # Scope the suppression to this call instead of changing the
    # process-wide warning filter for good.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # temporary fix outdated pandas
        axes = df.hist(figsize=(20, 20))
        # ``DataFrame.hist`` draws on a figure of its own. Return that one
        # instead of a separately created, blank figure.
        fig = np.ravel(axes)[0].get_figure()
        fig.tight_layout()
    return fig
def figure_to_image(figure):
    """Converts the matplotlib plot specified by 'figure' to a PNG image and
    returns it. The supplied figure is closed and inaccessible after this call.

    Args:
        figure: the matplotlib figure to render.

    Returns:
        The decoded PNG as a TensorFlow tensor with four channels and a
        leading batch dimension of size one.
    """
    # Save the plot to a PNG in memory. Saving through the figure itself
    # renders the figure that was passed in, not whichever figure happens to
    # be pyplot's current one.
    buf = io.BytesIO()
    figure.savefig(buf, format="png")
    # Closing the figure prevents it from being displayed directly inside
    # the notebook.
    plt.close(figure)
    buf.seek(0)
    # Convert PNG buffer to TF image
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension
    image = tf.expand_dims(image, 0)
    return image