Commit f9092601 authored by Dennis Noll's avatar Dennis Noll
Browse files

[tasks] sync: now uses coffea hists, includes stat tests+plotting

parent b57e7f1d
# coding: utf-8
from tasks.base import ConfigTask
from functools import cached_property
from utils.util import DotDict
import itertools
import numpy as np
import scipy
import uproot
import law
from law.task.base import ExternalTask
import luigi
import boost_histogram as bh
import hist
from tasks.base import BaseTask, ConfigTask
from tasks.mixins import PGroupMixin, RecipeMixin
from tasks.coffea import CoffeaProcessor
from tasks.plotting import PlotHistsBase
class SyncTask(RecipeMixin, ConfigTask):
def requires(self):
return CoffeaProcessor.req(self, processor="SyncSelection", debug=True)
return CoffeaProcessor.req(self, processor="SyncSelectionExporter", debug=True)
class SyncSelectionUpload(SyncTask):
......@@ -20,95 +33,154 @@ class SyncSelectionUpload(SyncTask):
class SyncSelectionPlots(SyncTask):
class SyncFile(ExternalTask):
filepath = luigi.Parameter()
def output(self):
return self.local_target("plots")
if self.filepath:
return law.LocalFileTarget(self.filepath)
def reduce_along_axis(func, axis, h):
axes = np.array(h.axes, dtype=object)
axes = np.delete(axes, axis)
hnew = hist.Hist(*axes)
hnew.view()[:] = np.apply_along_axis(func, axis, h.view())
return hnew
def _kstest(arr):
return scipy.stats.kstest(arr[:, 0], arr[:, 1]).statistic
def kstest(h1, h2, compare_axis, feature_axis):
assert feature_axis == -1, "Other axes not implemented yet"
axes1 = np.array(h1.axes, dtype=object)
axes2 = np.array(h2.axes, dtype=object)
# currently only works for StrCategory
compared_axis = [
f"{a}_{b}" for (a, b) in itertools.product(axes1[compare_axis], axes2[compare_axis])
new_axes = np.copy(axes1)
new_axes[compare_axis] = type(axes1[compare_axis])(compared_axis)
new_axes = np.delete(new_axes, feature_axis)
hnew = hist.Hist(*new_axes)
a1, a2 = h1.view(), h2.view()
c, d = np.broadcast_arrays(
np.expand_dims(a1, compare_axis), np.expand_dims(a2, compare_axis + 1)
s = np.stack([c, d], axis=-2)
# reshape
o = np.array(s.shape)
o[compare_axis] = o[compare_axis] * o[compare_axis + 1]
o = np.delete(o, compare_axis + 1)
s = np.reshape(s, o)
# perform test
kstest = np.vectorize(_kstest, signature="(2,n)->()")
hnew.view()[:] = kstest(s)
return hnew
class SyncSelection(PlotHistsBase, ConfigTask):
files = law.CSVParameter(default=[])
own = luigi.BoolParameter()
def _files(self):
out = {}
for i, f in enumerate(self.files):
if ":" in f:
name, file = f.split(":")
name, file = str(i), f
out[name] = file
return out
def requires(self):
req = {}
if self.own:
{"own": CoffeaProcessor.req(self, processor="SyncSelectionExporter", debug=True)}
req.update({name: SyncFile.req(self, filepath=file) for name, file in self._files.items()})
return req
def data(self):
inps = self.input()
assert len(inps) > 0, "No inputs defined, use `--own` and/or `--files`"
out = {}
for key, target in inps.items():
tree =["tree"]
keys, size = tree.keys(), len(tree["eventnr_eventnr"].array())
dtype = [(k.decode(), "i4" if self.is_mask_key(k.decode()) else "f4") for k in keys]
arr = np.zeros(shape=(size,), dtype=dtype)
for name in arr.dtype.names:
arr[name] = tree[name].array()
out[key] = arr
return out
def keys(self):
return list([0].dtype.names
def is_mask_key(self, key):
return key.startswith("is_")
def mask_keys(self):
return [k for k in self.keys if self.is_mask_key(k)]
def feature_keys(self):
return [k for k in self.keys if k not in self.mask_keys]
def output(self):
return {
"plots": self.local_target("plots"),
"statistical_tests": self.local_target("statistical_tests.json"),
lut = {"lep0_energy": "lep_pt"}
def run(self):
raise NotImplementedError
# this part of the sync code does currently not run and is not generic
# will be fixed when needed next time
regions = ["sr", "fr"]
el = []
output = "sync.root"
from collections import defaultdict
import json
import law
import uproot
from matplotlib.backends.backend_pdf import PdfPages
intersect = False
tallinn =[f"{self.tree_id}_SR"]
louvain =[f"{self.tree_id}_SR"]
aachen =[f"{self.tree_id}_SR"]
values = defaultdict(list)
sync_eventnr = {}
with PdfPages(f"{self.local}/sync.pdf") as pdf:
for ch in self.channels:
ch = self.lut.get(ch, ch)
louvain_mask = self.mask_channel(louvain, channel=ch)
tallinn_mask = self.mask_channel(tallinn, channel=ch)
aachen_mask = self.mask_channel(aachen, channel=ch)
louvain_eventnr = louvain["event"].array()[louvain_mask]
tallinn_eventnr = tallinn["event"].array()[tallinn_mask]
aachen_eventnr = aachen["event"].array()[aachen_mask]
# intersection
_, aachen_ind_intersect, louvain_ind_intersect = np.intersect1d(
# non-intersection
diff_eventnr = np.setdiff1d(
np.union1d(aachen_eventnr, louvain_eventnr),
np.intersect1d(aachen_eventnr, louvain_eventnr),
_, aachen_ind_diff, __ = np.intersect1d(
statistical_tests = {}
for mask_key in self.mask_keys:
for feature_key in self.feature_keys:
# extract name + binning from variables
var_name = self.lut.get(feature_key, feature_key)
var = self.config_inst.get_variable(var_name)
var_name =
binning = var.binning
except ValueError:
binning = (100, -400, +400)
compare = hist.Hist(
hist.axis.StrCategory(, name="group"),
hist.axis.Regular(*binning, name="variable"),,
print(f"channel {ch}")
stat = (
lambda name, tot, inter: f" {name}: {tot} events ({100 * inter / tot:.2f} %)"
for key, dat in
mask = dat[mask_key].astype(bool)
values = dat[feature_key][mask]
compare.fill(group=key, variable=values)
stack=compare[:1, :],
lines=compare[1:, :],
print(stat("aachen", aachen_eventnr.shape[0], aachen_ind_intersect.shape[0]))
print(stat("louvain", louvain_eventnr.shape[0], louvain_ind_intersect.shape[0]))
for key in aachen.keys():
values, labels = [], []
for data, mask, label in [
(aachen, aachen_mask, "Aachen"),
(louvain, louvain_mask, "Louvain"),
(tallinn, tallinn_mask, "Tallinn"),
v = data[key].array()[mask]
v[v == -9999] = 0
except KeyError:
self.plot(values, pdf, ch=ch, key=key, label=labels)
def plot(self, arrays, pdf, ch="", key="", label=None):
import matplotlib.pyplot as plt
import mplhep as hep
fig, ax = plt.subplots(figsize=(8, 6))
plt.hist(arrays, bins=40, histtype="bar", label=label)
plt.legend(title=f"channel: {ch}")
ax.set_ylabel("# events")
hep.cms.label(llabel="Private Work", ax=ax)
# plot histograms
# statistical tests
hnew = kstest(compare[:1, :], compare[1:, :], 0, -1)
statistical_tests[f"{mask_key}_{feature_key}"] = hnew.view()[0]
class SyncSelectionWrapper(SyncTask):
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment