Commit a16b6b9d authored by Dennis Noll's avatar Dennis Noll
Browse files

[multiclass] model/default: fixes part

parent fb1793b5
...@@ -5,10 +5,12 @@ import uuid ...@@ -5,10 +5,12 @@ import uuid
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tools import keras as tk
from utils import keras as uk
from lbn import LBN from lbn import LBN
from utils import keras as uk
from tools import keras as tk
def build_model( def build_model(
multiclass_inst, multiclass_inst,
input_shapes: Mapping[str, tuple] = None, input_shapes: Mapping[str, tuple] = None,
...@@ -45,7 +47,9 @@ def build_model( ...@@ -45,7 +47,9 @@ def build_model(
x = [tf.keras.layers.Flatten()(_x) for _x in x.values()] x = [tf.keras.layers.Flatten()(_x) for _x in x.values()]
if multiclass_inst.use_lbn: if multiclass_inst.use_lbn:
particles = { particles = {
key: used_inputs[key] for key, val in tensors.items() if val[-1].get("part", []) key: used_inputs[key]
for key, val in tensors.items()
if "part" in val[-1].get("groups", [])
} }
# add dimension for met and nu # add dimension for met and nu
if "met" in particles: if "met" in particles:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment