Commit fb1793b5 authored by Dennis Noll's avatar Dennis Noll
Browse files

[multiclass] model+tensors: can now use "blackout" to hide parts of tensors

parent e6f14144
......@@ -2,50 +2,52 @@
from typing import List, Mapping
import uuid
import numpy as np
import tensorflow as tf
from tools import keras as tk
from utils import keras as uk
from lbn import LBN
def build_model(
multiclass_inst,
input_shapes: Mapping[str, tuple] = None,
output_shapes=None,
normal_ref: Mapping[str, np.array] = {},
parts: List[str] = [],
) -> tf.keras.Model:
"""Function to build (possibly) complex tf.keras model.
Args:
input_shapes (Mapping[str, tuple], optional): Dict of input shapes, used to define input tensors of model. Defaults to None.
output_shapes ([type], optional): Output shapes, used to define output tensors of model. Defaults to None.
normal_ref (Mapping[str, np.array], optional): Dict of numpy arrays used to normalize input tensors. Defaults to {}.
parts (List[str], optional): List of strings used to identify "particles" which are used in the LBN. Defaults to [].
Raises:
NotImplementedError: Currently only one output tensor is implemented.
Returns:
tf.keras.Model: Built tf.keras model which is ready to be used.
"""
if len(output_shapes) > 1:
raise NotImplementedError
tensors = multiclass_inst.tensors
inputs = {name: tf.keras.layers.Input(shape, name=name) for name, shape in input_shapes.items()}
x = dict(inputs)
used_inputs = {}
for key, val in inputs.items():
used_inputs[key] = tk.Blackout(
ref=normal_ref[key], slic=tensors[key][-1].get("blackout", [])
)(val)
x = dict(used_inputs)
x["lep"] = tk.Normal(ref=normal_ref["lep"], axis=0, const=np.s_[:, [-2, -1]])(x["lep"])
x["lep"] = tf.keras.layers.Concatenate()(
[tk.LL()(x["lep"]), uk.OneHotPDGID()(x["lep"]), uk.OneHotCharge()(x["lep"])]
[
tk.LL()(x["lep"]),
uk.OneHotPDGID()(x["lep"]),
uk.OneHotCharge()(x["lep"]),
]
)
x["jet"] = tk.Normal(ref=normal_ref["jet"], axis=0, ignore_zeros=True)(x["jet"])
x["nu"] = tk.Normal(ref=normal_ref["nu"], axis=0)(x["nu"])
x["met"] = tk.Normal(ref=normal_ref["met"], axis=0)(x["met"])
x["nu"] = tk.Normal(ref=normal_ref["nu"], axis=0, ignore_zeros=True)(x["nu"])
x["met"] = tk.Normal(ref=normal_ref["met"], axis=0, ignore_zeros=True)(x["met"])
x["hl"] = tk.Normal(ref=normal_ref["hl"], axis=0, ignore_zeros=True)(x["hl"])
x["param"] = uk.OneHotYear()(x["param"])
x = [tf.keras.layers.Flatten()(_x) for _x in x.values()]
if multiclass_inst.use_lbn:
particles = {part: inputs[part] for part in parts}
# add dimension for met
particles = {
key: used_inputs[key] for key, val in tensors.items() if val[-1].get("part", [])
}
# add dimension for met and nu
if "met" in particles:
particles["met"] = tf.keras.backend.expand_dims(particles["met"], axis=-2)
if "nu" in particles:
......
......@@ -137,8 +137,8 @@ class Base:
"jet": ("analysis_jets", 6, common + ("btagDeepFlavB",), np.float32, {"groups": ["multiclass", "eval", "input", "part"]}),
"fat": ("good_fat_bjets", 1, common + fat_jet_vars, np.float32, {"groups": ["multiclass", "eval", "input", "part"]}),
"met": ("metp4", 0, common, np.float32, {"groups": ["multiclass", "eval", "input", "part"]}),
"nu": ("neutrino", 0, common, np.float32, {"groups": ["multiclass", "input", "part"]}),
"hl": (None, 0, hl, np.float32, {"groups": ["multiclass", "eval", "input"]}),
"nu": ("neutrino", 0, common, np.float32, {"groups": ["multiclass", "input", "part"], "blackout": np.s_[0:]}),
"hl": (None, 0, hl, np.float32, {"groups": ["multiclass", "eval", "input"], "blackout": np.s_[30:]}),
"param": (None, 0, param, np.float32, {"groups": ["multiclass", "eval", "input"]}),
"eventnr": (None, 0, ["eventnr"], np.int64, {"groups": ["multiclass", "eval", "split"]}),
"procid": (None, 0, ["procid"], np.int32, {"groups": ["multiclass", "class"]}),
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment