Commit 7a3ef5e4 authored by Dennis Noll's avatar Dennis Noll
Browse files

Merge branch 'multiclass_refactoring' into scikit-hep-upgrade

parents e5ecf37e 8712990b
......@@ -15,22 +15,22 @@ from config.util import PrepareConfig
from config.processes import rgb
from config.constants import (
BR_HH_BBTAUTAU,
BR_HH_BBWW_SL,
BR_HH_BBVV_DL,
HHres,
HHxs,
gfHHparams,
gfHHparamsNoSample,
)
from tasks.multiclass import MulticlassConfig
import utils.aci as aci
from config.analysis import analysis
# create the analysis
analysis = analysis.copy(name="diHiggs_sl")
get = analysis.processes.get
# add analysis specific processes
specific_processes = []
......@@ -121,11 +121,8 @@ for i, s in enumerate(signals):
# add all processes now
analysis.processes.extend(specific_processes + specific_signals)
# corrections which are only used for distinct analyses
analysis.aux["non_common_corrections"] = ["Fake", "VJetsCorrections"]
analysis.aux["multiclass_group"] = "mergedinclusive"
analysis.aux["doFakeNonClosureCorrection"] = True
# categories and files used for the sync of yields
......@@ -166,22 +163,6 @@ analysis.aux["process_groups"] = {
],
}
analysis.aux["multiclass_groups"] = {
"mergedinclusive": {
"class_HHGluGlu_NLO": {"groups": ["signal", "constrain", "fit"]},
"class_HHVBF_NLO": {"groups": ["signal", "constrain", "fit"]},
"tt": {"groups": ["background", "constrain"]},
"st": {"groups": ["background"]},
"wjets": {"groups": ["background"]},
"H": {"groups": ["background"]},
"class_other": {"groups": ["background"]},
},
"ggfttbar": {
"ggHH_kl_1_kt_1_2B2WToLNu2J": {"groups": ["signal", "constrain", "fit"]},
"tt": {"groups": ["background", "constrain"]},
},
}
analysis.aux["btag_sf_shifts"] = [
"lf",
"lfstats1",
......@@ -255,6 +236,7 @@ PrepareConfig(
HHReweigthing.update_config(
config=analysis,
root_processes={
"HH_2B2WToLNu2J_GluGlu": BR_HH_BBWW_SL,
"HH_2B2VTo2L2Nu_GluGlu": BR_HH_BBVV_DL,
"HH_2B2Tau_GluGlu": BR_HH_BBTAUTAU,
},
......@@ -276,6 +258,39 @@ HHReweigthing.update_config(
]
),
)
# Parent process grouping the three "_reweight" GGF HH child processes into a
# single class used as the DNN signal node.
class_HHGluGlu_NLO_reweight = aci.Process(
    name="class_HHGluGlu_NLO_reweight",
    id=17789287,  # NOTE(review): id looks arbitrary — confirm it is unique across the analysis
    label="HH(GGF)",
    processes=[
        analysis.processes.get("HH_2B2WToLNu2J_GluGlu_reweight"),
        analysis.processes.get("HH_2B2VTo2L2Nu_GluGlu_reweight"),
        analysis.processes.get("HH_2B2Tau_GluGlu_reweight"),
    ],
)
analysis.processes.extend([class_HHGluGlu_NLO_reweight])
# Multiclass DNN configuration: per-training-setup process groups, the active
# group, and maxn (presumably an event-count cap — confirm in MulticlassConfig).
# Each process entry carries tags: "signal"/"background" for its role,
# "constrain" and "fit" are consumed downstream (e.g. category tagging).
analysis.aux["multiclass"] = MulticlassConfig(
    groups={
        "mergedinclusive": {
            "class_HHGluGlu_NLO_reweight": {"groups": ["signal", "constrain", "fit"]},
            "class_HHVBF_NLO": {"groups": ["signal", "constrain", "fit"]},
            "tt": {"groups": ["background", "constrain"]},
            "st": {"groups": ["background"]},
            "wjets": {"groups": ["background"]},
            "H": {"groups": ["background"]},
            "class_other": {"groups": ["background"]},
        },
        # Reduced two-class setup: single GGF HH sample vs. ttbar.
        "ggfttbar": {
            "ggHH_kl_1_kt_1_2B2WToLNu2J": {"groups": ["signal", "constrain", "fit"]},
            "tt": {"groups": ["background", "constrain"]},
        },
    },
    group="mergedinclusive",
    maxn=2e6,
)
from bbww_sl.config.categories import setup_categories
......
......@@ -13,11 +13,11 @@ def setup_categories(cfg):
("_resolved_1b", "_resolved_2b", "_boosted"): ("_incl", "_incl3"),
(r"_resolved(?!_\db)", "_boosted"): ("_incl", "_incl2"),
}
multiclass_group = cfg.aux["multiclass_groups"][cfg.aux["multiclass_group"]]
multiclass_config = cfg.aux["multiclass"]
multiclass_processes = multiclass_config.groups[multiclass_config.group]
classes = {
f"dnn_node_{cls}": get(cls).label_short if "HH" in cls else get(cls).label
for cls in multiclass_group.keys()
for cls in multiclass_processes.keys()
}
classes.update(
top="top",
......@@ -56,7 +56,7 @@ def setup_categories(cfg):
] + [
f"all_{region}_sr_prompt_dnn_node_{process}"
for region in ["boosted", "resolved_1b", "resolved_2b"]
for process in [p for p, g in multiclass_group.items() if "fit" in g["groups"]]
for process in [p for p, g in multiclass_processes.items() if "fit" in g["groups"]]
]:
cfg.categories.get(c).tags = {"fit"}
if "HH" in c:
......
# -*- coding=utf-8 -*-
from typing import List, Mapping
import uuid
import numpy as np
import tensorflow as tf
from tools import keras as tk
from utils import keras as uk
def build_model(
    multiclass_inst,
    input_shapes: Mapping[str, tuple] = None,
    output_shapes=None,
    normal_ref: Mapping[str, np.ndarray] = None,
    parts: List[str] = None,
) -> tf.keras.Model:
    """Function to build (possibly) complex tf.keras model.

    Args:
        multiclass_inst: Configuration object providing the network
            hyperparameters (``network_architecture``, ``layers``, ``nodes``,
            ``dropout``, ``use_lbn``, ...).
        input_shapes (Mapping[str, tuple], optional): Dict of input shapes, used to
            define input tensors of model. Defaults to None.
        output_shapes ([type], optional): Output shapes, used to define output
            tensors of model. Defaults to None.
        normal_ref (Mapping[str, np.ndarray], optional): Dict of numpy arrays used
            to normalize input tensors. Defaults to None (treated as empty dict).
        parts (List[str], optional): List of strings used to identify "particles"
            which are used in the LBN. Defaults to None (treated as empty list).

    Raises:
        NotImplementedError: Currently only one output tensor is implemented.

    Returns:
        tf.keras.Model: Built tf.keras model which is ready to be used.
    """
    # Fix: avoid mutable default arguments ({} and []) — bind fresh per call.
    normal_ref = {} if normal_ref is None else normal_ref
    parts = [] if parts is None else parts
    if len(output_shapes) > 1:
        raise NotImplementedError
    inputs = {name: tf.keras.layers.Input(shape, name=name) for name, shape in input_shapes.items()}
    x = dict(inputs)
    # Lepton: normalize all but the last two columns (presumably pdgId/charge —
    # confirm), then concatenate log-features with one-hot pdgId and charge.
    x["lep"] = tk.Normal(ref=normal_ref["lep"], axis=0, const=np.s_[:, [-2, -1]])(x["lep"])
    x["lep"] = tf.keras.layers.Concatenate()(
        [tk.LL()(x["lep"]), uk.OneHotPDGID()(x["lep"]), uk.OneHotCharge()(x["lep"])]
    )
    # Remaining inputs: normalize against reference arrays; zero entries of
    # jet/hl are skipped (ignore_zeros), i.e. padded objects stay zero.
    x["jet"] = tk.Normal(ref=normal_ref["jet"], axis=0, ignore_zeros=True)(x["jet"])
    x["nu"] = tk.Normal(ref=normal_ref["nu"], axis=0)(x["nu"])
    x["met"] = tk.Normal(ref=normal_ref["met"], axis=0)(x["met"])
    x["hl"] = tk.Normal(ref=normal_ref["hl"], axis=0, ignore_zeros=True)(x["hl"])
    x["param"] = uk.OneHotYear()(x["param"])
    x = [tf.keras.layers.Flatten()(_x) for _x in x.values()]
    if multiclass_inst.use_lbn:
        # Feed the raw (un-normalized) four-vector inputs to the LBN.
        particles = {part: inputs[part] for part in parts}
        # add dimension for met/nu: single objects need an explicit particle axis
        if "met" in particles:
            particles["met"] = tf.keras.backend.expand_dims(particles["met"], axis=-2)
        if "nu" in particles:
            particles["nu"] = tf.keras.backend.expand_dims(particles["nu"], axis=-2)
        # NOTE(review): `LBN` is not imported in this module — this branch raises
        # NameError unless e.g. `from lbn import LBN` is added at the top of the file.
        x += [
            tk.LBNLayer(
                multiclass_inst.lbn_particles,
                LBN.PAIRS,
                features=["E", "px", "py", "pz", "pt", "p", "m", "pair_cos"],
            )(particles.values())
        ]
    x = x[0] if len(x) == 1 else tf.keras.layers.Concatenate()(x)
    # Body of the network is chosen by name from the tools.keras module.
    Network = getattr(tk, multiclass_inst.network_architecture)
    x = Network(
        block_size=multiclass_inst.block_size,
        jump=multiclass_inst.jump,
        layers=multiclass_inst.layers,
        nodes=multiclass_inst.nodes,
        activation=multiclass_inst.activation,
        batch_norm=multiclass_inst.batch_norm,
        dropout=multiclass_inst.dropout,
        l2=multiclass_inst.l2,
    )(x)
    # Single softmax classification head (only one output supported, see above).
    x = tf.keras.layers.Dense(output_shapes[0][0], activation="softmax", name="output")(x)
    model = tf.keras.Model(
        inputs=inputs.values(),
        outputs=x,
        # Random suffix keeps model names unique across repeated builds.
        name=f"model_{multiclass_inst.split}_{uuid.uuid4().hex[0:10]}",
    )
    return model
......@@ -371,14 +371,11 @@ class DNNBase:
raise NotImplementedError
@property
def dnn_classes(self):
return list(
enumerate(
self.analysis_inst.aux["multiclass_groups"][
self.analysis_inst.aux["multiclass_group"]
].keys()
)
)
def multiclass_processes(self):
multiclass_config = self.analysis_inst.aux["multiclass"]
groups = multiclass_config.groups
group = multiclass_config.group
return groups[group]
def category_variables(self, category):
if "_dnn_node_" in category:
......@@ -405,14 +402,7 @@ class DNNBase:
:, :30
] # HACK: used for running PlotProducer to inspect variables before doing new training
(pm,) = self.dnn(inputs=dnn_inputs).values()
target_shape = (
mask.sum(),
len(
self.analysis_inst.aux["multiclass_groups"][
self.analysis_inst.aux["multiclass_group"]
]
),
)
target_shape = (mask.sum(), len(self.multiclass_processes))
assert pm.shape == target_shape, (pm.shape, target_shape)
assert not np.any(np.isnan(pm)), "DNN produces nan values"
......@@ -420,7 +410,7 @@ class DNNBase:
pred[mask] = pm
pmax = np.argmax(pred, axis=-1)
catkeys = [c for c in categories.keys() if not (set(c.split("_")) & self.skip_cats)]
for i, c in self.dnn_classes:
for i, c in enumerate(self.multiclass_processes.keys()):
m = pmax == i
cn = f"dnn_node_{c}"
selection.add(cn, m)
......
......@@ -929,13 +929,12 @@ class Base(common.NeutrinoBase, common.Base):
eigenvalue1_met = eventshapes_met.eigenvalues[:, 0]
eigenvalue2_met = eventshapes_met.eigenvalues[:, 1]
eigenvalue3_met = eventshapes_met.eigenvalues[:, 2]
year = const_arr(int(self.year))
if dataset_inst.is_mc:
(process,) = dataset_inst.processes.values()
(process,) = dataset_inst.processes.values
procid = const_arr(process.id)
else:
procid = const_arr(Pget("data").id)
procid = const_arr(self.analysis_inst.processes.get("data").id)
yield locals()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment