Commit df5d4f52 authored by Dennis Noll's avatar Dennis Noll
Browse files

[multiclass] model: adds default build_model

parent ca8bf2db
# -*- coding=utf-8 -*-
from typing import List, Mapping
import uuid
import numpy as np
import tensorflow as tf
from tools import keras as tk
from utils import keras as uk
def build_model(
multiclass_inst,
input_shapes: Mapping[str, tuple] = None,
output_shapes=None,
normal_ref: Mapping[str, np.array] = {},
parts: List[str] = [],
) -> tf.keras.Model:
"""Function to build (possibly) complex tf.keras model.
Args:
input_shapes (Mapping[str, tuple], optional): Dict of input shapes, used to define input tensors of model. Defaults to None.
output_shapes ([type], optional): Output shapes, used to define output tensors of model. Defaults to None.
normal_ref (Mapping[str, np.array], optional): Dict of numpy arrays used to normalize input tensors. Defaults to {}.
parts (List[str], optional): List of strings used to identify "particles" which are used for in the LBN. Defaults to [].
Raises:
NotImplementedError: Currently only one output tensor is implemented.
Returns:
tf.keras.Model: Built tf.keras model which is ready to be used.
"""
if len(output_shapes) > 1:
raise NotImplementedError
inputs = {name: tf.keras.layers.Input(shape, name=name) for name, shape in input_shapes.items()}
x = dict(inputs)
x["lep"] = tk.Normal(ref=normal_ref["lep"], axis=0, const=np.s_[:, [-2, -1]])(x["lep"])
x["lep"] = tf.keras.layers.Concatenate()(
[tk.LL()(x["lep"]), uk.OneHotPDGID()(x["lep"]), uk.OneHotCharge()(x["lep"])]
)
x["jet"] = tk.Normal(ref=normal_ref["jet"], axis=0, ignore_zeros=True)(x["jet"])
x["nu"] = tk.Normal(ref=normal_ref["nu"], axis=0)(x["nu"])
x["met"] = tk.Normal(ref=normal_ref["met"], axis=0)(x["met"])
x["hl"] = tk.Normal(ref=normal_ref["hl"], axis=0, ignore_zeros=True)(x["hl"])
x["param"] = uk.OneHotYear()(x["param"])
x = [tf.keras.layers.Flatten()(_x) for _x in x.values()]
if multiclass_inst.use_lbn:
particles = {part: inputs[part] for part in parts}
# add dimension for met
if "met" in particles:
particles["met"] = tf.keras.backend.expand_dims(particles["met"], axis=-2)
if "nu" in particles:
particles["nu"] = tf.keras.backend.expand_dims(particles["nu"], axis=-2)
x += [
tk.LBNLayer(
multiclass_inst.lbn_particles,
LBN.PAIRS,
features=["E", "px", "py", "pz", "pt", "p", "m", "pair_cos"],
)(particles.values())
]
x = x[0] if len(x) == 1 else tf.keras.layers.Concatenate()(x)
Network = getattr(tk, multiclass_inst.network_architecture)
x = Network(
block_size=multiclass_inst.block_size,
jump=multiclass_inst.jump,
layers=multiclass_inst.layers,
nodes=multiclass_inst.nodes,
activation=multiclass_inst.activation,
batch_norm=multiclass_inst.batch_norm,
dropout=multiclass_inst.dropout,
l2=multiclass_inst.l2,
)(x)
x = tf.keras.layers.Dense(output_shapes[0][0], activation="softmax", name="output")(x)
model = tf.keras.Model(
inputs=inputs.values(),
outputs=x,
name=f"model_{multiclass_inst.split}_{uuid.uuid4().hex[0:10]}",
)
return model
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment