Newer
Older
# -*- coding: utf-8 -*-
import io
import warnings
import tensorflow as tf
import itertools
import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, auc
from .numpy import one_hot
class Multiplot:
    """Lay out a number of subplot panels on a ``rows x cols`` grid.

    ``n`` may be:

    * an int: the number of panels. A near-square grid with
      ``cols = ceil(sqrt(n))`` and ``rows = ceil(n / cols)`` is chosen.
    * a 2-sequence ``(cols, rows)``: used verbatim as the grid shape.
    * any other sequence: its product is taken as the panel count, so every
      element of a flattened trailing shape gets a panel. A 1-sequence
      ``(k,)`` therefore means ``k`` panels.
    """

    def __init__(self, n):
        try:
            self.cols, self.rows = n
        except TypeError:
            # scalar panel count (not iterable)
            self._fit_square(n)
        except ValueError:
            # sequence whose length is not 2: cover the product of all entries
            self._fit_square(np.prod(n))

    def _fit_square(self, n):
        """Choose a near-square grid large enough for ``n`` panels (at least 1)."""
        n = max(int(n), 1)
        self.cols = int(np.ceil(np.sqrt(n)))
        self.rows = int(np.ceil(n / self.cols))

    def lengths(self):
        """Return the grid shape as ``(rows, cols)``."""
        return self.rows, self.cols

    # Backward-compatible alias for the original (misspelled) method name.
    lenghts = lengths

    def index(self, i):
        """Map the row-major panel number ``i`` to an index into the axes array.

        With pyplot's default ``squeeze=True``, a single-row or single-column
        grid yields a 1-D axes array that is indexed by a plain int. Any
        other grid is indexed by a ``(row, col)`` tuple.
        """
        row, col = divmod(i, self.cols)
        if self.rows == 1:
            return col
        if self.cols == 1:
            return row
        return row, col
def saveplot(f):
    """Decorator: close all open pyplot figures before every call to ``f``.

    Any figures left open by earlier code are released before ``f`` runs.
    ``functools.wraps`` keeps ``f``'s name, docstring and ``__wrapped__`` on
    the returned wrapper.
    """
    import functools

    @functools.wraps(f)
    def helper(*args, **kwargs):
        plt.close("all")
        return f(*args, **kwargs)

    return helper
def figure_confusion_matrix(
    truth,
    prediction,
    class_names=("signal", "background"),
    sample_weight=None,
    normalize="true",
):
    """Plot the confusion matrix of ``truth`` labels vs. ``prediction`` scores.

    Both inputs are reduced with ``argmax`` over their last axis before being
    passed to ``sklearn.metrics.confusion_matrix``.

    Args:
        truth: array whose last axis has one entry per class.
        prediction: array of per-class scores with the same last axis.
        class_names: tick labels, one per class.
        sample_weight: optional per-sample weights forwarded to sklearn.
        normalize: forwarded to sklearn. The y label gets "(normed)" for
            ``"true"`` and the x label gets "(normed)" for ``"pred"``.

    Returns:
        The matplotlib figure.
    """
    assert len(class_names) == truth.shape[-1] == prediction.shape[-1], (
        "class_names, truth and prediction must agree on the number of classes"
    )
    fig, ax = plt.subplots()
    cm = confusion_matrix(
        np.argmax(truth, axis=-1),
        np.argmax(prediction, axis=-1),
        sample_weight=sample_weight,
        normalize=normalize,
    )
    cmap = "plasma" if normalize == "true" else "viridis"
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    ax.set_title("Confusion matrix")
    fig.colorbar(im, ax=ax)
    tick_marks = np.arange(len(class_names))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(class_names, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(class_names)
    # imshow scales the colormap to [cm.min(), cm.max()]. plasma/viridis run
    # from dark to bright, so cells above the midpoint get dark text.
    half = (cm.min() + cm.max()) / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(
            j,
            i,
            np.around(cm[i, j], decimals=2),
            horizontalalignment="center",
            verticalalignment="center",
            size=7,
            color="black" if cm[i, j] > half else "white",
        )
    ax.set_ylabel("True label" + " (normed)" * (normalize == "true"))
    ax.set_xlabel("Predicted label" + " (normed)" * (normalize == "pred"))
    fig.tight_layout()
    return fig
def figure_activations(activations, class_names=None):
    """Overlay density step histograms of every output column on [0, 1].

    Args:
        activations: 2-D array; each column is histogrammed separately.
        class_names: optional legend labels, one per column. Defaults to the
            column index.

    Returns:
        The matplotlib figure, with a log-scaled y axis.
    """
    edges = np.linspace(0, 1.0, 10)
    _, n_nodes = activations.shape
    fig = plt.figure()
    for node in range(n_nodes):
        if class_names is None:
            legend_name = "%i" % node
        else:
            legend_name = class_names[node]
        plt.hist(
            activations[:, node],
            edges,
            histtype="step",
            density=True,
            label=legend_name,
        )
    plt.yscale("log")
    plt.legend()
    fig.tight_layout()
    return fig
def figure_history(history_csv_path):
    """Plot every column of a history CSV in its own subplot (7x6 layout).

    Args:
        history_csv_path: path to a CSV file readable by ``pandas.read_csv``.

    Returns:
        The figure holding the subplots.
    """
    import pandas as pd

    axes = pd.read_csv(history_csv_path).plot(
        subplots=True, figsize=(30, 30), layout=(7, 6)
    )
    # DataFrame.plot draws into a figure of its own; return that figure
    # rather than a fresh, empty ``plt.figure()``.
    fig = np.ravel(axes)[0].get_figure()
    fig.tight_layout()
    return fig
def plot_histories(history_csv_path, path, cut=None, roll=1):
    """Save one PDF per training metric found in a history CSV.

    Every column ``<col>`` that does not start with ``val_`` is plotted
    against the ``epoch`` index, together with ``val_<col>`` when that column
    exists. Curves are smoothed with a rolling mean of width ``roll`` and
    written to ``{path}/{col}.pdf``.

    Args:
        history_csv_path: CSV with an ``epoch`` column plus metric columns.
        path: output directory for the PDFs.
        cut: accepted for backward compatibility; currently unused.
        roll: rolling-mean window in epochs (1 means no smoothing).
    """
    import pandas as pd

    df = pd.read_csv(history_csv_path).set_index("epoch")
    for col in df.columns:
        if col.startswith("val_"):
            continue  # drawn together with its training counterpart
        val_col = "val_" + col
        ind = [col, val_col] if val_col in df.columns else col
        fig, ax = plt.subplots()
        # pass ``ax`` so pandas draws into ``fig`` instead of a new figure
        df[ind].rolling(roll, min_periods=1).mean().plot(ax=ax)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(col.capitalize())
        fig.tight_layout()
        fig.savefig(f"{path}/{col}.pdf")
        plt.close("all")
def figure_dict(d, xlabel=""):
    """Draw a horizontal bar chart of a mapping, one bar per key.

    Args:
        d: mapping from bar label to bar length.
        xlabel: label for the value axis.

    Returns:
        The matplotlib figure. It is 8 inches wide and 1/3 inch tall per
        entry, with a minimum height of 1 inch.
    """
    # An empty mapping used to give a zero-height figure and a bare
    # ``plt.barh()`` call (TypeError). Keep a minimum height and skip the bars.
    fig = plt.figure(figsize=(8, max(len(d) / 3, 1)))
    if d:
        labels, values = zip(*d.items())
        plt.barh(labels, values)
    plt.xlabel(xlabel)
    return fig
def plot_dict(d, path="tmp.pdf"):
    """Render ``d`` as a horizontal bar chart and save it to ``path``.

    All open pyplot figures are closed afterwards.
    """
    figure_dict(d).savefig(path)
    plt.close("all")
def figure_weights(w, y, class_names=None, relative=False):
    """Bar chart of the fraction of negative event weights per class.

    Args:
        w: 1-D array of per-event weights.
        y: 2-D array (events x classes). ``w[:, None] * y`` spreads each
            weight over the columns where ``y`` is non-zero.
        class_names: optional x tick labels, one per class.
        relative: accepted for signature parity with ``figure_y``; unused.

    Returns:
        The matplotlib figure, with a log-scaled y axis.
    """
    n_p = y.shape[1]
    bins = range(n_p)
    weights = w[:, None] * y
    fig = plt.figure()
    pos = np.sum(weights > 0, axis=0)
    neg = np.sum(weights < 0, axis=0)
    total = pos + neg
    # classes without any non-zero weight would give 0/0; report 0 instead
    values = np.divide(neg, total, out=np.zeros(n_p), where=total > 0)
    plt.bar(bins, values)
    plt.xticks(bins, class_names, rotation=45)
    plt.yscale("log")
    plt.xlabel("Classification Process")
    plt.ylabel("Fraction Negative weights")
    fig.tight_layout()
    return fig
def figure_y(y, class_names=None, relative=False):
    """Bar chart of the column sums of ``y``, i.e. the number of events per class.

    Args:
        y: 2-D array (events x classes); its column sums are plotted.
        class_names: optional x tick labels, one per class.
        relative: if true, divide the sums by their total and mark the y
            label "(normed)".

    Returns:
        The matplotlib figure, with a log-scaled y axis.
    """
    n_classes = y.shape[1]
    ticks = range(n_classes)
    fig = plt.figure()
    counts = y.sum(axis=0)
    if relative:
        counts = counts / counts.sum()
    plt.bar(ticks, counts)
    plt.xticks(ticks, class_names, rotation=45)
    plt.yscale("log")
    plt.xlabel("Classification Process")
    # bool() so any truthy flag appends the suffix exactly once
    plt.ylabel("Number of Events" + " (normed)" * bool(relative))
    fig.tight_layout()
    return fig
def figure_node_activations(
    activations, truth, class_names=None, disjoint=False, sample_weight=None
):
    """Per output node, overlay the activation histograms of every true process.

    One subplot is drawn per node. Inside it, one step histogram is drawn per
    process, where the process's events are selected by the non-zero entries
    of ``truth[:, process]``. Histograms use nine equal bins on [0, 1].

    Args:
        activations: 2-D array (events x nodes).
        truth: 2-D array (events x processes).
        class_names: optional names per node/process. Defaults to the index.
        disjoint: if true, only each event's arg-max activation is kept (via
            the project's ``one_hot`` helper, presumably a one-hot mask —
            confirm). The other activations are moved outside [0, 1] and so
            do not appear in the histograms.
        sample_weight: optional 1-D per-event weights.

    Returns:
        The matplotlib figure.
    """
    _, n_p = activations.shape
    names = [str(i) for i in range(n_p)] if class_names is None else class_names
    multiplot = Multiplot(n_p)
    rows, cols = multiplot.lenghts()
    size = 5
    fig, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
    # a 1x1 grid yields a bare Axes; wrap it so ``ax[index]`` always works
    ax = ax if isinstance(ax, np.ndarray) else np.array([ax])
    bins = np.linspace(0, 1.0, 10)

    if disjoint:
        # independent of the process, so compute once
        max_mask = one_hot(np.argmax(activations, axis=-1), n=n_p)
        masked_activations = activations * max_mask

    process_activations = []
    process_weights = []
    for process in range(n_p):
        this_process = truth[:, process].astype(bool)
        if disjoint:
            values = masked_activations[this_process].swapaxes(0, 1)
            # zeros (non-max nodes) are moved out of the [0, 1] range
            values[values == 0] = -10000
        else:
            values = activations[this_process].swapaxes(0, 1)
        process_activations.append(values)
        if sample_weight is not None:
            process_weights.append(sample_weight[this_process])

    # scoped instead of a bare ``np.seterr`` so numpy's global error state
    # is restored afterwards
    with np.errstate(divide="ignore", invalid="ignore"):
        for node in range(n_p):
            ax_index = multiplot.index(node)
            for process in range(n_p):
                plot_kwargs = {
                    "histtype": "step",
                    "label": names[process],
                    "range": (0.0, 1.0),
                }
                if sample_weight is not None:
                    plot_kwargs["weights"] = process_weights[process]
                ax[ax_index].hist(
                    process_activations[process][node], bins, **plot_kwargs
                )
            ax[ax_index].text(
                0.95,
                0.95,
                f"node {names[node]}",
                ha="right",
                va="top",
                transform=ax[ax_index].transAxes,
            )
            ax[ax_index].set_yscale("log")
        ax[multiplot.index(cols - 1)].legend(
            title="processes", bbox_to_anchor=(1.05, 1.0), loc="upper left"
        )
        fig.tight_layout()
    return fig
def figure_roc_curve(
    truth,
    prediction,
    indices=(0,),
    class_names=None,
    sample_weight=None,
    lw=2,
    scale="linear",
):
    """Plot one-vs-all ROC curves for the selected class columns.

    Args:
        truth: 2-D array (events x classes) of binary labels per column.
        prediction: 2-D array of scores with the same columns.
        indices: class columns to draw a curve for.
        class_names: optional legend names indexed by column. Defaults to
            the column index.
        sample_weight: optional per-event weights, forwarded to ``roc_curve``.
        lw: line width.
        scale: axis scale. Scales ending in ``"log"`` are applied to both
            axes and lower the axis limits to 1e-5.

    Returns:
        The matplotlib figure.
    """
    fig = plt.figure()
    for index in indices:
        fpr, tpr, _ = roc_curve(
            truth[:, index], prediction[:, index], sample_weight=sample_weight
        )
        roc_auc = auc(fpr, tpr)
        name = index if class_names is None else class_names[index]
        plt.plot(fpr, tpr, lw=lw, label=f"{name} vs All (area = {roc_auc:.2f})")
    plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    lower = 0.0
    if scale.endswith("log"):
        plt.xscale(scale)
        plt.yscale(scale)
        lower = 1e-5
    plt.xlim([lower, 1.0])
    plt.ylim([lower, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver operating characteristic curve")
    plt.legend(loc="lower right")
    return fig
def figure_inputs(
    inps,
    truth,
    sample_weight=None,
    columns=None,
    class_names=None,
    signalvsbkg=False,
    bins=20,
    overflow=1000,
    yscale="linear",
):
    """Histogram every input feature, split by true class (density-normalized).

    Args:
        inps: array of shape (events, ...). All trailing axes are flattened
            into one feature axis.
        truth: 2-D array (events x classes); ``argmax`` gives the class.
        sample_weight: optional 1-D per-event weights.
        columns: one subplot title per flattened feature. Defaults to the index.
        class_names: legend label per class. Defaults to the index.
        signalvsbkg: if true, draw class 0 (labelled "HH") against all
            other classes (labelled "Background").
        bins: bin count or edges for the first histogram of each feature.
            Later histograms of that feature reuse its edges.
        overflow: values are clipped to ``[-overflow, overflow]``.
        yscale: y-axis scale for every subplot. The original referenced an
            undefined ``yscale``; "linear" is an assumed default — confirm.

    Returns:
        The matplotlib figure.
    """
    multiplot = Multiplot(inps.shape[1:][::-1])
    rows, cols = multiplot.lenghts()
    size = 5
    fig, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
    ax = ax if isinstance(ax, np.ndarray) else np.array([ax])
    inps = inps.reshape(-1, np.prod(inps.shape[1:]))
    inps = np.clip(inps, -overflow, overflow)
    if columns is None:
        columns = [str(i) for i in range(inps.shape[1])]
    if class_names is None:
        class_names = [str(i) for i in range(truth.shape[-1])]
    labels = np.argmax(truth, axis=-1)  # loop-invariant

    def weights_for(mask):
        # per-event weights restricted to ``mask`` (None when unweighted)
        return None if sample_weight is None else sample_weight[mask]

    for feat, name in enumerate(columns):
        ax_index = multiplot.index(feat)
        _bins = bins  # reset per feature; replaced by the first hist's edges
        if signalvsbkg:
            mask = labels != 0
            _bins = ax[ax_index].hist(
                inps[:, feat][mask],
                histtype="stepfilled",
                bins=_bins,
                weights=weights_for(mask),
                label="Background",
                density=True,
            )[1]
            mask = labels == 0
            ax[ax_index].hist(
                inps[:, feat][mask],
                histtype="step",
                bins=_bins,
                weights=weights_for(mask),
                label="HH",
                density=True,
                linewidth=2,
            )
        else:
            histtype = "stepfilled"  # first class filled, the rest outlined
            for i, class_name in enumerate(class_names):
                mask = labels == i
                _bins = ax[ax_index].hist(
                    inps[:, feat][mask],
                    histtype=histtype,
                    bins=_bins,
                    weights=weights_for(mask),
                    label=class_name,
                    density=True,
                    linewidth=2,
                )[1]
                histtype = "step"
        ax[ax_index].set_yscale(yscale)
        ax[ax_index].set_title(name)
        ax[ax_index].legend()
    fig.tight_layout()
    return fig
def figure_weight_study(
    class_inps,
    sample_weights=None,
    columns=None,
    label=None,
    log=False,
    mode="plain",
    **kwargs,
):
    """Compare per-feature distributions across several input sets.

    ``mode`` selects what is drawn for each flattened feature:

    * ``"plain"``: one density step histogram per entry of ``class_inps``.
      Extra ``kwargs`` are forwarded to ``Axes.hist``.
    * ``"rel"``: bars of ``(val - ref) / ref``, where ``ref`` is the
      density of the first entry. Bins with ``ref == 0`` are NaN.
    * ``"weight"``: for the first entry only, the per-bin fraction
      ``|neg| / (pos + |neg|)`` of summed non-positive vs. positive weights.

    Args:
        class_inps: sequence of arrays shaped (events, ...). Trailing axes
            are flattened into features; the grid is sized from the first.
        sample_weights: per-entry 1-D weight arrays. Required for ``"weight"``.
        columns: one subplot title per flattened feature. Defaults to the index.
        label: optional legend label per entry.
        log: currently unused; the y axis is always log-scaled.
        mode: one of ``"plain"``, ``"rel"``, ``"weight"``.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: for an unknown ``mode``, or ``"weight"`` without weights.
    """
    if mode not in ("plain", "rel", "weight"):
        raise ValueError(f"unknown mode {mode!r}; expected 'plain', 'rel' or 'weight'")
    if mode == "weight" and sample_weights is None:
        raise ValueError("mode 'weight' requires sample_weights")
    multiplot = Multiplot(class_inps[0].shape[1:][::-1])
    rows, cols = multiplot.lenghts()
    size = 5
    fig, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
    ax = ax if isinstance(ax, np.ndarray) else np.array([ax])
    class_inps = [inps.reshape(inps.shape[0], -1) for inps in class_inps]
    if sample_weights is None:
        sample_weights = [None] * len(class_inps)
    if columns is None:
        columns = [str(i) for i in range(class_inps[0].shape[1])]

    for feat, name in enumerate(columns):
        axis = ax[multiplot.index(feat)]
        if mode == "weight":
            # only the first input set is inspected, so draw it once per feature
            feature = class_inps[0][:, feat]
            weights = sample_weights[0]
            mask = weights > 0
            val_pos, edges = np.histogram(
                feature[mask], bins=25, weights=weights[mask]
            )
            val_neg, edges = np.histogram(
                feature[~mask], bins=edges, weights=weights[~mask]
            )
            total = val_pos + np.abs(val_neg)
            fraction = np.divide(
                np.abs(val_neg), total, out=np.full(total.shape, np.nan), where=total != 0
            )
            axis.bar(edges[:-1], height=fraction, width=np.diff(edges), align="edge")
        else:
            bins = 25  # replaced by the first set's edges, shared by the rest
            ref = None
            for i, inps in enumerate(class_inps):
                entry_label = label[i] if label else None
                if mode == "plain":
                    _, bins, _ = axis.hist(
                        inps[:, feat],
                        histtype="step",
                        bins=bins,
                        weights=sample_weights[i],
                        density=True,
                        label=entry_label,
                        **kwargs,
                    )
                else:  # mode == "rel"
                    val, bins = np.histogram(
                        inps[:, feat], bins=bins, weights=sample_weights[i], density=True
                    )
                    if ref is None:
                        ref = val
                    rel = np.divide(
                        val - ref, ref, out=np.full(val.shape, np.nan), where=ref != 0
                    )
                    axis.bar(
                        bins[:-1],
                        height=rel,
                        width=np.diff(bins),
                        align="edge",
                        label=entry_label,
                        alpha=0.5,
                    )
        axis.set_yscale("log")
        axis.set_title(name)
        axis.legend()
    fig.tight_layout()
    return fig
def figure_multihist(data, columns=None):
    """Histogram every flattened feature of ``data`` in a grid of subplots.

    Args:
        data: array of shape (events, ...). Trailing axes are flattened.
        columns: optional column names, one per flattened feature.

    Returns:
        The figure created by ``DataFrame.hist``.
    """
    import pandas as pd

    df = pd.DataFrame(np.reshape(data, (data.shape[0], -1)), columns=columns)
    # Silence warnings from outdated pandas only for this call, instead of
    # disabling all warnings process-wide.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        axes = df.hist(figsize=(20, 20))
        # DataFrame.hist draws into a figure of its own; return that one.
        fig = np.ravel(axes)[0].get_figure()
        fig.tight_layout()
    return fig
def figure_to_image(figure):
    """Convert the matplotlib ``figure`` to a TensorFlow image tensor.

    The figure is rendered to an in-memory PNG and decoded. The supplied
    figure is closed and inaccessible after this call.

    Args:
        figure: the matplotlib figure to render.

    Returns:
        A tensor of shape ``(1, height, width, 4)``: the decoded 4-channel
        PNG pixels with a leading batch axis.
    """
    with io.BytesIO() as buf:
        # Render *this* figure. ``plt.savefig`` would render whichever figure
        # is current, which is not necessarily ``figure``.
        figure.savefig(buf, format="png")
        png = buf.getvalue()
    plt.close(figure)
    image = tf.image.decode_png(png, channels=4)
    # Add the batch dimension
    return tf.expand_dims(image, 0)
def figure_lbn_weights(
    weights,
    name,
    norm=False,
    path="",
    cmap="OrRd",
    inp_particle_names=None,
    **fig_kwargs,
):
    """Draw a 2-D weight matrix as an annotated heat map.

    Rows are shown on the y axis ("Input particle") and columns on the x
    axis. Each cell is annotated with its rounded integer value. The color
    scale is fixed to [0, 100].

    Args:
        weights: 2-D array.
        name: used in the title ``"{name} weights"``.
        norm: if true, take absolute values and rescale each column to sum
            to 100. All-zero columns stay 0.
        path: accepted for backward compatibility; unused.
        cmap: matplotlib colormap name.
        inp_particle_names: optional y tick labels, one per row.
        **fig_kwargs: accepted for backward compatibility; unused.

    Returns:
        The matplotlib figure.
    """
    if norm:
        weights = np.abs(weights)
        col_sums = np.sum(weights, axis=0, keepdims=True)
        # guard all-zero columns: NaN would make ``int(round(...))`` below raise
        weights = (
            np.divide(
                weights,
                col_sums,
                out=np.zeros_like(weights, dtype=float),
                where=col_sums != 0,
            )
            * 100
        )
    ph, pv = weights.shape
    # roughly half an inch per cell
    fig, ax = plt.subplots(1, 1, figsize=((pv + 1) // 2, (ph + 1) // 2))
    ax.imshow(weights, cmap=cmap, vmin=0, vmax=100, origin="lower")
    ax.set_title(f"{name} weights", fontdict={"fontsize": 12})
    ax.set_xticks(list(range(pv)))
    ax.set_ylabel("Input particle")
    ax.set_yticks(list(range(ph)))
    if inp_particle_names:
        ax.set_yticklabels(inp_particle_names)
    # write weights into each bin
    for (i, j), val in np.ndenumerate(weights):
        ax.text(j, i, int(round(val)), fontsize=8, ha="center", va="center", color="k")
    return fig
def figure_correlation(values, label=None, path="tmp.pdf"):
    """Draw the pairwise correlation matrix (pandas ``DataFrame.corr``) of the columns.

    Args:
        values: 2-D array (samples x variables).
        label: optional tick labels, one per variable. Defaults to the index.
        path: accepted for backward compatibility; unused.

    Returns:
        The matplotlib figure.
    """
    import pandas as pd

    n = values.shape[-1]
    if label is None:
        label = [str(i) for i in range(n)]
    # about 1 inch per 6 variables; the floor keeps small inputs from giving
    # a zero-size (n < 6) or unreadably small figure
    side = max(n // 6, 4)
    fig = plt.figure(figsize=(side, side))
    df = pd.DataFrame(values)
    plt.matshow(df.corr(), fignum=0)
    plt.xticks(np.arange(len(label)), label, rotation="vertical", fontsize=7)
    plt.yticks(np.arange(len(label)), label, fontsize=7)
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.tight_layout()
    return fig