Skip to content
Snippets Groups Projects
Commit a18a0e81 authored by jan.middendorf@rwth-aachen.de
Browse files

Merge branch 'master' of git.rwth-aachen.de:3pia/cms_analyses/tools

parents aa7bd212 bbb0bcc3
No related branches found
No related tags found
No related merge requests found
...@@ -8,7 +8,7 @@ class SKDict(dict): ...@@ -8,7 +8,7 @@ class SKDict(dict):
@staticmethod @staticmethod
def keyify(keyish): def keyify(keyish):
if not isinstance(keyish, (tuple, list, set, frozenset)): if not isinstance(keyish, (tuple, list, set, frozenset)):
keyish = keyish, keyish = (keyish,)
keyish = frozenset(keyish) keyish = frozenset(keyish)
assert not any(isinstance(key, set) for key in keyish) assert not any(isinstance(key, set) for key in keyish)
return keyish return keyish
...@@ -19,7 +19,7 @@ class SKDict(dict): ...@@ -19,7 +19,7 @@ class SKDict(dict):
def update(self, *args, **kwargs): def update(self, *args, **kwargs):
# assert 0 <= len(args) <= 1 # assert 0 <= len(args) <= 1
args += kwargs, args += (kwargs,)
for arg in args: for arg in args:
for k, v in arg.items(): for k, v in arg.items():
self[k] = v self[k] = v
...@@ -34,11 +34,7 @@ class SKDict(dict): ...@@ -34,11 +34,7 @@ class SKDict(dict):
key = self.keyify(key) key = self.keyify(key)
if key in self: if key in self:
return super(SKDict, self).__getitem__(key) return super(SKDict, self).__getitem__(key)
ret = self.__class__({ ret = self.__class__({k - key: v for k, v in self.items() if key <= k})
k - key: v
for k, v in self.items()
if key <= k
})
if not ret: if not ret:
raise KeyError(key) raise KeyError(key)
return ret return ret
...@@ -64,10 +60,7 @@ class SKDict(dict): ...@@ -64,10 +60,7 @@ class SKDict(dict):
assert all(isinstance(inst, cls) for inst in insts) assert all(isinstance(inst, cls) for inst in insts)
keys = set() keys = set()
keys.update(*(inst.keys() for inst in insts)) keys.update(*(inst.keys() for inst in insts))
return cls({ return cls({key: tuple(inst.get(key) for inst in insts) for key in keys})
key: tuple(inst.get(key) for inst in insts)
for key in keys
})
def only(self, *keys): def only(self, *keys):
return self.__class__({key: self[key] for key in keys}) return self.__class__({key: self[key] for key in keys})
...@@ -88,7 +81,7 @@ class SKDict(dict): ...@@ -88,7 +81,7 @@ class SKDict(dict):
assert len(keys) == 1 # bad depth assert len(keys) == 1 # bad depth
return list(keys)[0] return list(keys)[0]
elif ads == {False}: elif ads == {False}:
return (), return ((),)
else: else:
raise RuntimeError("bad depth") raise RuntimeError("bad depth")
...@@ -97,10 +90,7 @@ class SKDict(dict): ...@@ -97,10 +90,7 @@ class SKDict(dict):
@property @property
def pretty(self): def pretty(self):
return { return {"/".join(sorted(map(str, k))): v for k, v in self.items()}
"/".join(sorted(map(str, k))): v
for k, v in self.items()
}
class GetNextSlice(object): class GetNextSlice(object):
...@@ -113,7 +103,7 @@ class GetNextSlice(object): ...@@ -113,7 +103,7 @@ class GetNextSlice(object):
if self.curr is None: if self.curr is None:
self.curr = self.next() self.curr = self.next()
self.pos = 0 self.pos = 0
sli = self.curr[self.pos:self.pos + num] sli = self.curr[self.pos : self.pos + num]
self.pos += num self.pos += num
if len(sli) < num: if len(sli) < num:
del self.curr del self.curr
...@@ -135,39 +125,48 @@ class DSS(SKDict): ...@@ -135,39 +125,48 @@ class DSS(SKDict):
assert len(lens) == 1 assert len(lens) == 1
return lens[0] return lens[0]
@property
def dtype(self):
    """Return the one dtype shared by every stored value.

    An ``AssertionError`` is raised unless all values report the same
    ``dtype``.
    """
    found = list({entry.dtype for entry in self.values()})
    assert len(found) == 1
    return found[0]
@property
def dims(self):
    """Return the number of dimensions shared by every stored value.

    An ``AssertionError`` is raised unless all values report the same
    ``ndim``.
    """
    found = list({entry.ndim for entry in self.values()})
    assert len(found) == 1
    return found[0]
@property
def shape(self):
    """Return the shape common to all stored values.

    When every value has the same shape, that shape is returned as is.
    Otherwise the shapes are merged axis by axis: an axis keeps its size if
    all values agree on it and becomes ``None`` if they do not. Merging
    asserts that all shapes have ``self.dims`` entries.
    """
    distinct = list({entry.shape for entry in self.values()})
    if len(distinct) <= 1:
        return distinct[0]
    assert set(map(len, distinct)) == {self.dims}
    merged = []
    for axis_sizes in zip(*distinct):
        sizes = list(set(axis_sizes))
        merged.append(sizes[0] if len(sizes) == 1 else None)
    return tuple(merged)
def fuse(self, *keys, **kwargs): def fuse(self, *keys, **kwargs):
op = kwargs.pop("op", np.concatenate) op = kwargs.pop("op", np.concatenate)
assert not kwargs assert not kwargs
return self.zip(*( return self.zip(*(self[self.keyify(key)] for key in keys)).map(op)
self[self.keyify(key)] for key in keys
)).map(op)
def split(self, thresh, right=False, rng=np.random): def split(self, thresh, right=False, rng=np.random):
if isinstance(thresh, int): if isinstance(thresh, int):
thresh = np.linspace(0, 1, num=thresh + 1)[1:-1] thresh = np.linspace(0, 1, num=thresh + 1)[1:-1]
if isinstance(thresh, float): if isinstance(thresh, float):
thresh = thresh, thresh = (thresh,)
thresh = np.array(thresh) thresh = np.array(thresh)
assert np.all((0 < thresh) & (thresh < 1)) assert np.all((0 < thresh) & (thresh < 1))
idx = np.digitize(rng.uniform(size=self.blen), thresh, right=right) idx = np.digitize(rng.uniform(size=self.blen), thresh, right=right)
return tuple( return tuple(self.map(itemgetter(idx == i)) for i in range(len(thresh) + 1))
self.map(itemgetter(idx == i))
for i in range(len(thresh) + 1)
)
def shuffle(self, rng=np.random): def shuffle(self, rng=np.random):
return self.map(itemgetter(rng.permutation(self.blen))) return self.map(itemgetter(rng.permutation(self.blen)))
def gen_feed_dict(self, tensor2key, batch_size=1024): def gen_feed_dict(self, tensor2key, batch_size=1024):
for sli in self.batch_slices(batch_size): for sli in self.batch_slices(batch_size):
buf = { buf = {key: self[key][sli] for key in set(tensor2key.values())}
key: self[key][sli] yield {tensor: buf[key] for tensor, key in tensor2key.items()}
for key in set(tensor2key.values())
}
yield {
tensor: buf[key]
for tensor, key in tensor2key.items()
}
def batch_slices(self, batch_size): def batch_slices(self, batch_size):
for i in range(0, self.blen, batch_size): for i in range(0, self.blen, batch_size):
...@@ -191,9 +190,7 @@ class DSS(SKDict): ...@@ -191,9 +190,7 @@ class DSS(SKDict):
getter = itemgetter(x, y, w) getter = itemgetter(x, y, w)
train, valid = self["train"], self["valid"] train, valid = self["train"], self["valid"]
return dict( return dict(
zip(["x", "y", "sample_weight"], getter(train)), zip(["x", "y", "sample_weight"], getter(train)), validation_data=getter(valid), **kwargs
validation_data=getter(valid),
**kwargs
) )
def balanced(self, *keys, **kwargs): def balanced(self, *keys, **kwargs):
...@@ -207,18 +204,17 @@ class DSS(SKDict): ...@@ -207,18 +204,17 @@ class DSS(SKDict):
s = np.sum(s.values()) s = np.sum(s.values())
sums[key] = s sums[key] = s
ref = kref(sums.values()) if callable(kref) else sums[kref] ref = kref(sums.values()) if callable(kref) else sums[kref]
return self.__class__({ return self.__class__({k: self[k].map(lambda x: x * (ref / s)) for k, s in sums.items()})
k: self[k].map(lambda x: x * (ref / s))
for k, s in sums.items()
})
@classmethod @classmethod
def from_npy(cls, dir, sep="_", **kwargs): def from_npy(cls, dir, sep="_", **kwargs):
return cls({ return cls(
tuple(fn[:-4].split(sep)): np.load(path.join(dir, fn), **kwargs) {
for fn in listdir(dir) tuple(fn[:-4].split(sep)): np.load(path.join(dir, fn), **kwargs)
if fn.endswith(".npy") for fn in listdir(dir)
}) if fn.endswith(".npy")
}
)
def to_npy(self, dir, sep="_", **kwargs): def to_npy(self, dir, sep="_", **kwargs):
for key, value in self.items(): for key, value in self.items():
......
This diff is collapsed.
# -*- coding: utf-8 -*-
import io
import tensorflow as tf
import itertools
import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix
class Quadrature:
    """Near-square grid layout for ``n`` items.

    The grid has ``ceil(sqrt(n))`` columns and as many rows as are needed to
    hold ``n`` items when they are placed row by row.
    """

    def __init__(self, n):
        """Compute the grid dimensions for ``n`` items."""
        self.n = n
        self.cols = np.ceil(np.sqrt(n)).astype(int)
        # An empty layout has no columns; dividing by zero would give NaN and
        # an undefined integer cast, so report zero rows explicitly.
        self.rows = np.ceil(n / self.cols).astype(int) if self.cols else np.int_(0)

    def lenghts(self):
        """Return the grid dimensions as ``(rows, cols)``.

        The misspelled name is kept so that existing callers keep working.
        """
        return self.rows, self.cols

    # Correctly spelled alias for new code.
    lengths = lenghts

    def index(self, n):
        """Return the ``(row, col)`` cell of the ``n``-th item.

        ``n`` is a 0-based, non-negative position; cells are numbered row by
        row.
        """
        # Integer divmod avoids the float round trip of ``int(n / cols)``.
        row, col = divmod(n, int(self.cols))
        return row, col
def figure_confusion_matrix(
    truth,
    prediction,
    class_names=("signal", "background"),
    sample_weight=None,
    normalize="true",
    **kwargs
):
    """Plot a confusion matrix and return the figure holding it.

    Args:
        truth: Array whose last axis enumerates the classes; the true class
            of each sample is the argmax over that axis.
        prediction: Array laid out like ``truth``; the predicted class is the
            argmax over the last axis.
        class_names: One tick label per class. Its length must equal the
            size of the last axis of ``truth`` and ``prediction``.
        sample_weight: Passed through to ``confusion_matrix``.
        normalize: Passed through to ``confusion_matrix``. The value
            ``"true"`` selects the "plasma" colour map, anything else
            "viridis".
        **kwargs: Accepted for call compatibility; currently unused.

    Returns:
        The matplotlib figure containing the annotated matrix.
    """
    assert len(class_names) == truth.shape[-1] == prediction.shape[-1]
    fig, ax = plt.subplots()
    cm = confusion_matrix(
        np.argmax(truth, axis=-1),
        np.argmax(prediction, axis=-1),
        sample_weight=sample_weight,
        normalize=normalize,
    )
    cmap = "plasma" if normalize == "true" else "viridis"
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    # Decorate the axes created above directly rather than relying on
    # pyplot's notion of the "current" axes.
    ax.set_title("Confusion matrix")
    fig.colorbar(im, ax=ax)
    tick_marks = np.arange(len(class_names))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(class_names, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(class_names)
    # Write each cell's value, rounded to two decimals, at its centre.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, np.around(cm[i, j], decimals=2), horizontalalignment="center", size=7)
    ax.set_ylabel("True label")
    ax.set_xlabel("Predicted label")
    fig.tight_layout()
    return fig
def figure_activations(activations, class_names=None):
    """Overlay one step histogram per column of ``activations``.

    ``activations`` must be two-dimensional. Each column is histogrammed
    over ten evenly spaced bin edges between 0 and 1, normalised to a
    density and drawn on a logarithmic y axis. Legend entries are the column
    index, or ``class_names[i]`` when ``class_names`` is given.

    Returns the new matplotlib figure.
    """
    edges = np.linspace(0, 1.0, 10)
    _, n_outputs = activations.shape
    fig = plt.figure()
    for col in range(n_outputs):
        if class_names is None:
            label = "%i" % col
        else:
            label = class_names[col]
        plt.hist(
            activations[:, col],
            edges,
            histtype=u"step",
            density=True,
            label=label,
        )
    plt.yscale("log")
    plt.legend()
    fig.tight_layout()
    return fig
def figure_node_activations(activations, truth, class_names=None):
    """Plot, for every output node, its activation histogram per true process.

    One subplot is drawn per column ("node") of the two-dimensional
    ``activations`` array, arranged on a near-square grid. Within each
    subplot there is one normalised step histogram per process, built from
    the rows that ``truth[:, process]`` selects.

    Args:
        activations: 2-D array; columns are the output nodes.
        truth: 2-D array with one column per process, used as a row mask.
            NOTE(review): presumably one-hot labels with as many columns as
            ``activations`` -- confirm against callers.
        class_names: Optional legend labels, indexed by process; the process
            index is used when omitted.

    Returns:
        The matplotlib figure holding the grid of histograms.
    """
    _, n_p = activations.shape
    quad = Quadrature(n_p)
    rows, cols = quad.lenghts()
    # squeeze=False keeps ``ax`` two-dimensional even for a single row or a
    # single cell, so the (row, col) indexing below always works.
    fig, ax = plt.subplots(rows, cols, figsize=(15, 15 * rows / cols), squeeze=False)
    bins = np.linspace(0, 1.0, 10)
    # Cast to bool so 0/1 integer (or float) labels act as a mask instead of
    # being interpreted as row indices.
    process_activations = [
        activations[np.asarray(truth[:, process], dtype=bool)].swapaxes(0, 1)
        for process in range(n_p)
    ]
    for node in range(n_p):
        node_ax = ax[quad.index(node)]
        for process in range(n_p):
            node_ax.hist(
                process_activations[process][node],
                bins,
                histtype=u"step",
                density=True,
                label="%i" % process if class_names is None else class_names[process],
            )
        node_ax.set_yscale("log")
    # A single legend, anchored just outside the last subplot of the first row.
    ax[quad.index(cols - 1)].legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
    fig.tight_layout()
    return fig
def figure_to_image(figure):
    """Convert the matplotlib figure ``figure`` to a PNG image tensor.

    The figure is rendered to an in-memory PNG, closed, and decoded with
    TensorFlow. The supplied figure is closed and inaccessible after this
    call.

    Returns:
        The decoded 4-channel image with a leading batch dimension added.
    """
    buf = io.BytesIO()
    # Render the figure that was passed in; ``plt.savefig`` would instead
    # save whichever figure pyplot currently considers active.
    figure.savefig(buf, format="png")
    # Closing the figure prevents it from being displayed directly inside
    # the notebook.
    plt.close(figure)
    # Convert the PNG bytes to a TF image (getvalue() returns the whole
    # buffer regardless of the stream position).
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension.
    return tf.expand_dims(image, 0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment