Commit 7f931a6e authored by Dennis Noll's avatar Dennis Noll
Browse files

[tasks] sync: can now specify categories for sync + group_tree (incl)

parent 61a9dc59
...@@ -21,7 +21,21 @@ from tasks.plotting import PlotHistsBase ...@@ -21,7 +21,21 @@ from tasks.plotting import PlotHistsBase
import utils.hist as histutils import utils.hist as histutils
class SyncSelectionUpload(RecipeMixin, ConfigTask): class Sync(ConfigTask):
@cached_property
def sync_config(self):
return self.config_inst.get_aux("sync", {})
@cached_property
def lookup_set(self):
return self.sync_config.get("lookup_set", {})
@cached_property
def categories(self):
return self.sync_config.get("categories", None)
class SyncSelectionUpload(Sync, RecipeMixin):
upload_identifier = luigi.Parameter(default="") upload_identifier = luigi.Parameter(default="")
def requires(self): def requires(self):
...@@ -30,10 +44,6 @@ class SyncSelectionUpload(RecipeMixin, ConfigTask): ...@@ -30,10 +44,6 @@ class SyncSelectionUpload(RecipeMixin, ConfigTask):
def output(self): def output(self):
return self.wlcg_target(f"{self.upload_identifier}sync.root", fs="wlcg_fs_public") return self.wlcg_target(f"{self.upload_identifier}sync.root", fs="wlcg_fs_public")
@cached_property
def lookup_set(self):
return self.config_inst.get_aux("sync", {})
def lookup(self, old_key): def lookup(self, old_key):
key = old_key key = old_key
for (old, new) in self.lookup_set: for (old, new) in self.lookup_set:
...@@ -60,7 +70,7 @@ class SyncFile(ExternalTask): ...@@ -60,7 +70,7 @@ class SyncFile(ExternalTask):
return law.LocalFileTarget(self.filepath) return law.LocalFileTarget(self.filepath)
class SyncSelection(PlotHistsBase, ConfigTask, PoolMap): class SyncSelection(PlotHistsBase, Sync, PoolMap):
files = law.CSVParameter(default=[]) files = law.CSVParameter(default=[])
own = luigi.BoolParameter() own = luigi.BoolParameter()
...@@ -140,8 +150,6 @@ class SyncSelection(PlotHistsBase, ConfigTask, PoolMap): ...@@ -140,8 +150,6 @@ class SyncSelection(PlotHistsBase, ConfigTask, PoolMap):
# default binning # default binning
binning = (100, -400, +400) binning = (100, -400, +400)
# currently bug in boost-histogram: https://github.com/scikit-hep/boost-histogram/issues/621
# optimally would use one histogram - as patch using two until fixed
compare0 = hist.Hist( compare0 = hist.Hist(
hist.axis.StrCategory(list(self.data.keys())[:1], name="group"), hist.axis.StrCategory(list(self.data.keys())[:1], name="group"),
hist.axis.Regular(*binning, name="variable"), hist.axis.Regular(*binning, name="variable"),
...@@ -203,7 +211,12 @@ class SyncSelection(PlotHistsBase, ConfigTask, PoolMap): ...@@ -203,7 +211,12 @@ class SyncSelection(PlotHistsBase, ConfigTask, PoolMap):
self.mask_keys self.mask_keys
self.feature_keys self.feature_keys
work = [(m, v) for m in self.mask_keys for v in self.feature_keys] work = [
(m, v)
for m in self.mask_keys
for v in self.feature_keys
if self.categories is None or m in self.categories
]
for (mask_key, feature_key), metric in self.pmap( for (mask_key, feature_key), metric in self.pmap(
self.sync, self.sync,
......
...@@ -651,6 +651,33 @@ class TreeExporter(ArrayExporter): ...@@ -651,6 +651,33 @@ class TreeExporter(ArrayExporter):
output = "sync.root" output = "sync.root"
dtype = None dtype = None
groups = {}
@classmethod
def group_tree(cls, outtree):
# add regrouped categories
cats = [c for c in outtree.keys() if c.startswith("is_")]
regroups = []
for dst, pat in cls.groups.items():
if dst is None:
continue
cats_new = set()
for cat in cats:
rem = re.sub(pat, dst, cat)
if rem != cat:
assert rem not in cats, f"new category {rem} already existing"
cats_new.add(rem)
regroups.append((dst, pat, cats.copy()))
cats.extend(sorted(cats_new))
for dst, pat, cats in tqdm(regroups, desc="regroup", leave=False):
for cat in tqdm(cats, unit="category", leave=False):
rem = re.sub(pat, dst, cat)
if rem != cat:
outtree[rem] = np.logical_or(outtree.get(rem, 0), outtree[cat])
return outtree
@classmethod @classmethod
def arrays_to_tree(cls, arrays, target, vars, **kwargs): def arrays_to_tree(cls, arrays, target, vars, **kwargs):
import uproot import uproot
...@@ -658,6 +685,9 @@ class TreeExporter(ArrayExporter): ...@@ -658,6 +685,9 @@ class TreeExporter(ArrayExporter):
with uproot.recreate(target.path) as file: with uproot.recreate(target.path) as file:
tree = defaultdict(list) tree = defaultdict(list)
for cat, cat_array in arrays.items(): for cat, cat_array in arrays.items():
# remove broken categories
if re.search(cls.groups.get(None, "x^"), cat):
continue
for var, _var_array in cat_array.items(): for var, _var_array in cat_array.items():
if var not in cls.tensors.fget(cls).keys(): if var not in cls.tensors.fget(cls).keys():
continue continue
...@@ -679,6 +709,7 @@ class TreeExporter(ArrayExporter): ...@@ -679,6 +709,7 @@ class TreeExporter(ArrayExporter):
outtree = {} outtree = {}
for v, k in tree.items(): for v, k in tree.items():
outtree[v] = np.concatenate(k, axis=-1) outtree[v] = np.concatenate(k, axis=-1)
outtree = cls.group_tree(outtree)
file["tree"] = uproot.newtree({n: v.dtype for n, v in outtree.items()}) file["tree"] = uproot.newtree({n: v.dtype for n, v in outtree.items()})
file["tree"].extend(outtree) file["tree"].extend(outtree)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment