Newer
Older
# -*- coding: utf-8 -*-
import functools
import io
import itertools
import warnings

import numpy as np
import pandas as pd
import tensorflow as tf
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, auc

from .numpy import one_hot
class Multiplot:
    """Map a flat plot index onto the axes grid of ``plt.subplots(rows, cols)``.

    The grid is either given explicitly as ``(cols, rows)`` or derived from a
    plot count ``n`` as a near-square grid with ``cols = ceil(sqrt(n))`` and
    ``rows = ceil(n / cols)``.
    """

    def __init__(self, n):
        """Build the grid.

        :param n: an integer plot count, a length-2 sequence ``(cols, rows)``,
            or any other non-empty sequence, in which case only ``n[0]`` is
            used as the plot count.
        :raises ValueError: if the derived plot count is smaller than 1.
        """
        try:
            self.cols, self.rows = n
        except ValueError:
            # Not exactly two entries: square grid for n[0] plots.
            self.cols, self.rows = self._square_grid(n[0])
        except TypeError:
            # Not iterable: n already is the plot count.
            self.cols, self.rows = self._square_grid(n)

    @staticmethod
    def _square_grid(n):
        """Return ``(cols, rows)`` of the smallest near-square grid for ``n`` plots."""
        if n < 1:
            raise ValueError(f"need at least one plot, got {n}")
        cols = int(np.ceil(np.sqrt(n)))
        rows = int(np.ceil(n / cols))
        return cols, rows

    def lenghts(self):
        """Return ``(rows, cols)``; misspelled name kept for existing callers."""
        return self.rows, self.cols

    def lengths(self):
        """Return ``(rows, cols)`` of the grid."""
        return self.lenghts()

    def index(self, i):
        """Return the axes-array index of the ``i``-th plot (row-major order).

        ``plt.subplots`` returns a 1-D axes array when there is a single row or
        a single column, so a scalar index is returned in those cases and a
        ``(row, col)`` tuple otherwise.
        """
        row, col = divmod(i, self.cols)
        if self.rows == 1:
            return col
        if self.cols == 1:
            return row
        return row, col
def saveplot(f):
    """Decorator that closes all open matplotlib figures before calling ``f``.

    The wrapped function's return value is passed through unchanged, and its
    metadata (``__name__``, ``__doc__``, ``__wrapped__``) is preserved via
    :func:`functools.wraps`.
    """

    @functools.wraps(f)
    def helper(*args, **kwargs):
        plt.close("all")
        return f(*args, **kwargs)

    return helper
def figure_confusion_matrix(
    truth,
    prediction,
    class_names=("signal", "background"),
    sample_weight=None,
    normalize="true",
):
    """Plot the confusion matrix of ``truth`` against ``prediction``.

    The true and predicted class of every sample is the ``argmax`` over the
    last axis; the matrix is computed with
    :func:`sklearn.metrics.confusion_matrix`.

    :param truth: array whose last axis has one entry per class.
    :param prediction: array whose last axis has one entry per class.
    :param class_names: tick labels, one per class.
    :param sample_weight: optional per-sample weights forwarded to sklearn.
    :param normalize: ``"true"``, ``"pred"``, ``"all"`` or ``None`` (sklearn
        semantics); the matching axis label gets a " (normed)" suffix.
    :returns: the matplotlib figure.
    :raises AssertionError: if ``class_names``, ``truth`` and ``prediction``
        disagree on the number of classes.
    """
    assert len(class_names) == truth.shape[-1] == prediction.shape[-1]
    n_classes = len(class_names)
    cm = confusion_matrix(
        np.argmax(truth, axis=-1),
        np.argmax(prediction, axis=-1),
        # Fix the label set so the matrix is always n x n, even when a class
        # never occurs; otherwise the tick labels would be misaligned.
        labels=np.arange(n_classes),
        sample_weight=sample_weight,
        normalize=normalize,
    )
    fig, ax = plt.subplots()
    cmap = "plasma" if normalize == "true" else "viridis"
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    ax.set_title("Confusion matrix")
    fig.colorbar(im, ax=ax)
    tick_marks = np.arange(n_classes)
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(class_names, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(class_names)
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(
            j,
            i,
            np.around(cm[i, j], decimals=2),
            horizontalalignment="center",
            verticalalignment="center",
            size=7,
        )
    ax.set_ylabel("True label" + " (normed)" * (normalize == "true"))
    ax.set_xlabel("Predicted label" + " (normed)" * (normalize == "pred"))
    fig.tight_layout()
    return fig
def figure_activations(activations, class_names=None):
    """Overlay density histograms of every output node's activations.

    :param activations: 2-D array; column ``k`` holds node ``k``'s activations
        and is histogrammed into 9 equal bins on ``[0, 1]``.
    :param class_names: optional per-node legend labels; defaults to the node
        index.
    :returns: the matplotlib figure (log-scaled y axis).
    """
    edges = np.linspace(0, 1.0, 10)
    _, n_nodes = activations.shape  # unpacking also enforces a 2-D input
    fig = plt.figure()
    for node in range(n_nodes):
        if class_names is None:
            legend_entry = "%i" % node
        else:
            legend_entry = class_names[node]
        plt.hist(activations[:, node], edges, density=True, label=legend_entry)
    plt.yscale("log")
    plt.legend()
    fig.tight_layout()
    return fig
def figure_history(history_csv_path, layout=(7, 6), figsize=(30, 30)):
    """Plot every column of a history CSV in its own subplot.

    :param history_csv_path: path to a CSV readable by ``pandas.read_csv``.
    :param layout: ``(rows, cols)`` subplot grid passed to pandas; it must
        offer at least one cell per column.
    :param figsize: figure size in inches.
    :returns: the figure holding the subplots.
    """
    axes = pd.read_csv(history_csv_path).plot(
        subplots=True, figsize=figsize, layout=layout
    )
    # pandas draws into a figure of its own; return that figure rather than
    # creating (and returning) a new, empty one.
    fig = np.asarray(axes).ravel()[0].get_figure()
    fig.tight_layout()
    return fig
def plot_histories(history_csv_path, path, cut=None, roll=1):
    """Save one PDF per metric of a history CSV to ``{path}/{metric}.pdf``.

    For every column ``col`` that is not itself a ``val_`` column, the metric
    (together with ``val_<col>`` when that column exists) is smoothed with a
    rolling mean over ``roll`` epochs and plotted against the ``epoch`` column.

    NOTE(review): the original body had lost its loop header and data
    selection (``col``/``value`` were undefined, ``continue`` sat outside a
    loop); this restores the structure implied by the surviving lines.
    ``cut`` is not referenced by any surviving line, so it is accepted for
    compatibility but currently unused -- TODO confirm its intended meaning.
    """
    df = pd.read_csv(history_csv_path).set_index("epoch")
    for col in df.columns:
        if col.startswith("val_"):
            continue
        val_col = "val_" + col
        ind = [col, val_col] if val_col in df.columns else col
        fig, ax = plt.subplots()
        df[ind].rolling(roll, min_periods=1).mean().plot(ax=ax)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(col.capitalize())
        fig.tight_layout()
        fig.savefig(f"{path}/{col}.pdf")
        plt.close(fig)
    plt.close("all")
def figure_dict(d):
    """Draw a horizontal bar chart of a mapping.

    Each key labels one bar and its value sets the bar length; the figure is
    8 inches wide and one third of an inch tall per entry.

    :param d: mapping of labels to numeric values.
    :returns: the matplotlib figure.
    """
    height = len(d) / 3
    fig = plt.figure(figsize=(8, height))
    label_and_value_columns = zip(*d.items())
    plt.barh(*label_and_value_columns)
    return fig
def plot_dict(d, path="tmp.pdf"):
    """Save a horizontal bar chart of ``d`` (built by ``figure_dict``) to ``path``.

    All open figures are closed afterwards, also when drawing or saving
    fails, so a failed call does not leave figures behind.
    """
    try:
        figure_dict(d).savefig(path)
    finally:
        plt.close("all")
def figure_weights(w, y, class_names=None, relative=False):
    """Bar chart of the fraction of negatively weighted events per class.

    An event counts towards class ``k`` when ``w * y[:, k]`` is non-zero; the
    plotted value is ``#negative / (#positive + #negative)``.

    :param w: 1-D array of per-event weights.
    :param y: 2-D array ``(n_events, n_classes)``, presumably one-hot labels.
    :param class_names: optional tick labels, one per class.
    :param relative: not used; kept for signature parity with ``figure_y``.
    :returns: the matplotlib figure (log-scaled y axis).
    """
    n_classes = y.shape[1]
    positions = range(n_classes)
    weights = w[:, None] * y
    fig = plt.figure()
    n_pos = np.sum(weights > 0, axis=0)
    n_neg = np.sum(weights < 0, axis=0)
    n_total = n_pos + n_neg
    # A class without any non-zero weight gets 0 instead of 0/0 -> NaN plus a
    # RuntimeWarning (neither bar is visible on the log axis).
    fraction = np.divide(
        n_neg, n_total, out=np.zeros(n_classes, dtype=float), where=n_total > 0
    )
    plt.bar(positions, fraction)
    plt.xticks(positions, class_names, rotation=45)
    plt.yscale("log")
    plt.xlabel("Classification Process")
    plt.ylabel("Fraction Negative weights")
    fig.tight_layout()
    return fig
def figure_y(y, class_names=None, relative=False):
    """Bar chart of the column sums of ``y`` (event counts for one-hot labels).

    :param y: 2-D array ``(n_events, n_classes)``; each column is summed.
    :param class_names: optional tick labels, one per class.
    :param relative: if true, show each class's share of the total instead of
        the absolute sum.
    :returns: the matplotlib figure (log-scaled y axis).
    """
    positions = range(y.shape[1])
    fig = plt.figure()
    totals = y.sum(axis=0)
    if relative:
        totals = totals / totals.sum()
    plt.bar(positions, totals)
    plt.xticks(positions, class_names, rotation=45)
    plt.yscale("log")
    plt.xlabel("Classification Process")
    plt.ylabel("Number of Events" + relative * " (normed)")
    fig.tight_layout()
    return fig
def figure_node_activations(
    activations, truth, class_names=None, disjoint=False, sample_weight=None
):
    """Per-node activation histograms, split by true process.

    One subplot per output node; inside it, one step histogram (9 bins on
    ``[0, 1]``) per process of that node's activations, where the samples of
    process ``p`` are selected by ``truth[:, p]`` cast to bool.

    :param activations: 2-D array ``(n_samples, n_nodes)``.
    :param truth: 2-D array with one column per process (same count as nodes).
    :param class_names: optional process/node names; defaults to the index.
    :param disjoint: if true, only the highest activation of each sample is
        kept; all other entries are set to -10000, outside the plotted range.
    :param sample_weight: optional per-sample weights.
    :returns: the matplotlib figure.
    """
    _, n_nodes = activations.shape  # unpacking also enforces a 2-D input
    names = ["%i" % i for i in range(n_nodes)] if class_names is None else class_names
    multiplot = Multiplot(n_nodes)
    rows, cols = multiplot.lenghts()
    size = 5
    fig, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
    # plt.subplots returns a bare Axes for a 1x1 grid; wrap it so ax[0] works.
    ax = ax if isinstance(ax, np.ndarray) else np.array([ax])
    bins = np.linspace(0, 1.0, 10)

    if disjoint:
        # Loop-invariant: keep only each sample's maximal activation.
        max_only = activations * one_hot(np.argmax(activations, axis=-1), n=n_nodes)

    process_activations = []
    process_weights = []
    for process in range(n_nodes):
        this_process = truth[:, process].astype(bool)
        if disjoint:
            # Boolean indexing copies, so this assignment leaves max_only intact.
            values = max_only[this_process].swapaxes(0, 1)
            values[values == 0] = -10000
        else:
            values = activations[this_process].swapaxes(0, 1)
        process_activations.append(values)
        if sample_weight is not None:
            process_weights.append(sample_weight[this_process])

    # Scoped instead of a global np.seterr that was never restored.
    with np.errstate(divide="ignore", invalid="ignore"):
        for node in range(n_nodes):
            target = ax[multiplot.index(node)]
            for process in range(n_nodes):
                plot_kwargs = {
                    "histtype": "step",
                    "label": names[process],
                    "range": (0.0, 1.0),
                }
                if process_weights:
                    plot_kwargs["weights"] = process_weights[process]
                target.hist(process_activations[process][node], bins, **plot_kwargs)
            target.text(
                0.95,
                0.95,
                f"node {names[node]}",
                ha="right",
                va="top",
                transform=target.transAxes,
            )
            target.set_yscale("log")
        ax[multiplot.index(cols - 1)].legend(
            title="processes", bbox_to_anchor=(1.05, 1.0), loc="upper left"
        )
        fig.tight_layout()
    return fig
def figure_roc_curve(
    truth,
    prediction,
    indices=(0,),
    class_names=None,
    sample_weight=None,
    lw=2,
    scale="linear",
):
    """One-vs-all ROC curves for the selected class columns.

    :param truth: 2-D array of binary labels, one column per class.
    :param prediction: 2-D array of scores, one column per class.
    :param indices: class columns to draw a curve for.
    :param class_names: optional names used in the legend.
    :param sample_weight: optional per-sample weights, forwarded to
        :func:`sklearn.metrics.roc_curve`. NOTE: negative weights can make the
        curve non-monotonic, in which case ``auc`` raises ``ValueError``.
    :param lw: line width.
    :param scale: axis scale; any value ending in "log" also sets the lower
        axis limits to 1e-5.
    :returns: the matplotlib figure.
    """
    fig = plt.figure()
    for index in indices:
        fpr, tpr, _ = roc_curve(
            truth[:, index], prediction[:, index], sample_weight=sample_weight
        )
        roc_auc = auc(fpr, tpr)
        name = index if class_names is None else class_names[index]
        plt.plot(fpr, tpr, lw=lw, label=f"{name} vs All (area = {roc_auc:.2f})")
    plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    lower = 0.0
    if scale.endswith("log"):
        plt.xscale(scale)
        plt.yscale(scale)
        lower = 1e-5
    plt.xlim([lower, 1.0])
    plt.ylim([lower, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver operating characteristic curve")
    plt.legend(loc="lower right")
    return fig
def figure_inputs(
    inps,
    truth,
    sample_weight=None,
    columns=None,
    class_names=None,
    signalvsbkg=False,
    bins=20,
    overflow=1000,
):
    """Normalised per-feature input distributions, split by true class.

    ``inps`` is flattened to ``(n_samples, n_features)`` and clipped to
    ``[-overflow, overflow]``; one subplot is drawn per feature. The true class
    of a sample is ``argmax(truth, axis=-1)``.

    :param inps: array whose first axis is the sample axis; for 3-D input
        ``(n, a, b)`` the grid has ``b`` columns and ``a`` rows.
    :param truth: 2-D ``(n_samples, n_classes)`` labels.
    :param sample_weight: optional per-sample weights; unit weights if omitted.
    :param columns: subplot titles, one per flattened feature; defaults to the
        feature index.
    :param class_names: legend labels per class; defaults to the class index.
    :param signalvsbkg: if true, draw class 0 (labelled "HH") against all
        other classes combined ("Background").
    :param bins: number of bins (or explicit edges) per histogram.
    :param overflow: absolute clipping bound applied before histogramming.
    :returns: the matplotlib figure.
    """
    multiplot = Multiplot(inps.shape[1:][::-1])
    rows, cols = multiplot.lenghts()
    size = 5
    fig, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
    # plt.subplots returns a bare Axes for a 1x1 grid; wrap it so ax[0] works.
    ax = ax if isinstance(ax, np.ndarray) else np.array([ax])
    inps = np.clip(inps.reshape(inps.shape[0], -1), -overflow, overflow)
    n_classes = truth.shape[-1]
    if sample_weight is None:
        sample_weight = np.ones(inps.shape[0])
    if class_names is None:
        class_names = ["%i" % i for i in range(n_classes)]
    if columns is None:
        columns = ["%i" % i for i in range(inps.shape[1])]
    true_class = np.argmax(truth, axis=-1)
    # Draw classes by decreasing total weight. Names are looked up by class
    # index, so every histogram keeps its own label.
    order = np.argsort(-(sample_weight[:, None] * truth).sum(axis=0))
    for feat, name in enumerate(columns):
        target = ax[multiplot.index(feat)]
        values = inps[:, feat]
        if signalvsbkg:
            is_bkg = true_class != 0
            # Honour `bins` here too; the signal histogram reuses these edges.
            edges = target.hist(
                values[is_bkg],
                histtype="stepfilled",
                bins=bins,
                weights=sample_weight[is_bkg],
                label="Background",
                density=True,
            )[1]
            target.hist(
                values[~is_bkg],
                histtype="step",
                bins=edges,
                weights=sample_weight[~is_bkg],
                label="HH",
                density=True,
                linewidth=2,
            )
        else:
            for cls in order:
                mask = true_class == cls
                target.hist(
                    values[mask],
                    histtype="step",
                    bins=bins,
                    weights=sample_weight[mask],
                    label=class_names[cls],
                    density=True,
                    linewidth=2,
                )
        target.set_title(name)
        target.legend()
    fig.tight_layout()
    return fig
def figure_weight_study(
    class_inps,
    sample_weights=None,
    columns=None,
    label=None,
    log=False,
    mode="plain",
    **kwargs,
):
    """Compare per-feature distributions of several input sets.

    Every set is flattened to ``(n_samples, n_features)``; one subplot is
    drawn per feature, laid out from ``class_inps[0].shape[1:]``.

    Modes:

    * ``"plain"``: normalised step histograms of every set (25 bins, edges
      taken from the first set); ``kwargs`` are forwarded to ``Axes.hist``.
    * ``"rel"``: bars of ``(h_i - h_0) / h_0`` with ``h_0`` the first set's
      normalised histogram; empty reference bins give inf/NaN bars.
    * ``"weight"``: for the first set only, the per-bin fraction
      ``|neg| / (pos + |neg|)`` of negative weight, on a log scale.

    :param class_inps: sequence of arrays sharing the feature layout.
    :param sample_weights: optional sequence of per-sample weight arrays, one
        per set (entries may be ``None``); required for the first set in
        ``"weight"`` mode.
    :param columns: subplot titles; defaults to the feature index.
    :param label: optional legend labels, one per set.
    :param log: not referenced by this function; kept for compatibility.
    :param mode: ``"plain"``, ``"rel"`` or ``"weight"``.
    :returns: the matplotlib figure.
    :raises ValueError: for an unknown ``mode`` or missing weights in
        ``"weight"`` mode.
    """
    if mode not in ("plain", "rel", "weight"):
        raise ValueError(f"unknown mode {mode!r}; expected 'plain', 'rel' or 'weight'")
    if sample_weights is None:
        sample_weights = [None] * len(class_inps)
    if mode == "weight" and sample_weights[0] is None:
        raise ValueError("mode='weight' needs sample weights for the first input set")
    multiplot = Multiplot(class_inps[0].shape[1:][::-1])
    rows, cols = multiplot.lenghts()
    size = 5  # inches per subplot panel
    fig, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
    # plt.subplots returns a bare Axes for a 1x1 grid; wrap it so ax[0] works.
    ax = ax if isinstance(ax, np.ndarray) else np.array([ax])
    class_inps = [inps.reshape(inps.shape[0], -1) for inps in class_inps]
    if columns is None:
        columns = ["%i" % i for i in range(class_inps[0].shape[1])]
    # Empty bins make the "rel" and "weight" ratios divide by zero.
    with np.errstate(divide="ignore", invalid="ignore"):
        for feat, name in enumerate(columns):
            target = ax[multiplot.index(feat)]
            if mode == "weight":
                # Only the first set is studied, so draw its bars once per feature.
                x, w = class_inps[0][:, feat], sample_weights[0]
                positive = w > 0
                val_pos, edges = np.histogram(x[positive], bins=25, weights=w[positive])
                val_neg, _ = np.histogram(x[~positive], bins=edges, weights=w[~positive])
                neg = np.abs(val_neg)
                target.bar(
                    edges[:-1],
                    height=neg / (val_pos + neg),
                    width=np.diff(edges),
                    align="edge",
                )
                target.set_yscale("log")
            else:
                edges = 25
                ref = None
                for i, inps in enumerate(class_inps):
                    set_label = label[i] if label else None
                    if mode == "plain":
                        _, edges, _ = target.hist(
                            inps[:, feat],
                            histtype="step",
                            bins=edges,
                            weights=sample_weights[i],
                            density=True,
                            label=set_label,
                            **kwargs,
                        )
                    else:  # "rel"
                        val, edges = np.histogram(
                            inps[:, feat],
                            bins=edges,
                            weights=sample_weights[i],
                            density=True,
                        )
                        if ref is None:
                            ref = val
                        target.bar(
                            edges[:-1],
                            height=(val - ref) / ref,
                            width=np.diff(edges),
                            align="edge",
                            label=set_label,
                            alpha=0.5,
                        )
            target.set_title(name)
            target.legend()
        fig.tight_layout()
    return fig
def figure_multihist(data, columns=None):
    """Histogram every flattened feature of ``data`` in one grid figure.

    :param data: array whose first axis is the sample axis; the remaining
        axes are flattened into columns.
    :param columns: optional column names for the flattened features.
    :returns: the figure drawn by ``DataFrame.hist``.
    """
    df = pd.DataFrame(np.reshape(data, (data.shape[0], -1)), columns=columns)
    # Temporary fix for warnings from outdated pandas; scoped so the rest of
    # the process keeps its warning filters.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        axes = df.hist(figsize=(20, 20))
        # DataFrame.hist creates its own figure; return that one.
        fig = np.asarray(axes).ravel()[0].get_figure()
        fig.tight_layout()
    return fig
def figure_to_image(figure):
    """Convert ``figure`` to a batched RGBA PNG image tensor.

    The figure is rendered to an in-memory PNG and decoded with
    ``tf.image.decode_png`` (4 channels); a leading batch dimension of size 1
    is added. The supplied figure is closed after this call, also if
    rendering fails.
    """
    try:
        with io.BytesIO() as buf:
            # Render this figure explicitly; plt.savefig would render whatever
            # figure happens to be current.
            figure.savefig(buf, format="png")
            png = buf.getvalue()
    finally:
        # Closing the figure prevents it from being displayed directly inside
        # the notebook.
        plt.close(figure)
    image = tf.image.decode_png(png, channels=4)
    # Add the batch dimension.
    return tf.expand_dims(image, 0)
def figure_lbn_weights(
    weights,
    name,
    norm=False,
    path="",
    cmap="OrRd",
    inp_particle_names=None,
    **fig_kwargs,
):
    """Heat map of a 2-D weight matrix with each value written into its cell.

    :param weights: 2-D array; rows are labelled "Input particle".
    :param name: title prefix, rendered as "{name} weights".
    :param norm: if true, take absolute values and rescale every column to
        sum to 100 (all-zero columns stay 0).
    :param path: not used; kept for compatibility.
    :param cmap: matplotlib colormap; the colour range is fixed to [0, 100].
    :param inp_particle_names: optional y tick labels, one per row.
    :param fig_kwargs: forwarded to ``plt.subplots``; ``figsize`` defaults to
        about half an inch per cell.
    :returns: the matplotlib figure.
    """
    if norm:
        # Normalise each column (not row) to a sum of 100.
        weights = np.abs(weights)
        col_sums = np.sum(weights, axis=0, keepdims=True)
        # Avoid 0/0 -> NaN, which would make int(round(...)) below fail.
        weights = (
            np.divide(
                weights,
                col_sums,
                out=np.zeros(weights.shape, dtype=float),
                where=col_sums != 0,
            )
            * 100
        )
    ph, pv = weights.shape
    fig_kwargs.setdefault("figsize", ((pv + 1) // 2, (ph + 1) // 2))
    fig, ax = plt.subplots(1, 1, **fig_kwargs)
    ax.imshow(weights, cmap=cmap, vmin=0, vmax=100, origin="lower")
    ax.set_title(f"{name} weights", fontdict={"fontsize": 12})
    ax.set_xticks(list(range(pv)))
    ax.set_ylabel("Input particle")
    ax.set_yticks(list(range(ph)))
    if inp_particle_names:
        ax.set_yticklabels(inp_particle_names)
    # Write the (rounded) weight into each cell.
    for (i, j), val in np.ndenumerate(weights):
        ax.text(j, i, int(round(val)), fontsize=8, ha="center", va="center", color="k")
    return fig