diff --git a/keras.py b/keras.py
index b13a8aff9a8cbb12fb2fa0f5a9dc8726780dc5f9..f2ed438af67bc92250a768baab3b6b67ea7ad4ef 100644
--- a/keras.py
+++ b/keras.py
@@ -1,11 +1,11 @@
 # from itertools import izip
 import itertools
+from functools import cached_property
 import gc
 from collections import OrderedDict, defaultdict
 import fnmatch
 import math
 import datetime
-import nvidia_smi
 
 import numpy as np
 import tensorflow as tf
@@ -40,14 +40,14 @@ from .plotting import (
 
 
 def kVar(*args, **kwargs):
-    """ produce a keras-tracked tf.Variable from all given parameters """
+    """produce a keras-tracked tf.Variable from all given parameters"""
     var = tf.Variable(*args, **kwargs)
     track_variable(var)
     return var
 
 
 def kOpt(opt, **kwargs):
-    """ instanciate a keras Optimizer with all applicable **kwargs """
+    """instanciate a keras Optimizer with all applicable **kwargs"""
     if not callable(opt):
         opt = getattr(tf.keras.optimizers, opt)
     assert issubclass(opt, tf.keras.optimizers.Optimizer)
@@ -56,14 +56,14 @@ def kOpt(opt, **kwargs):
 
 
 def kInput(ref, **kwargs):
-    """ produce a keras.Input with shape & dtype according to ref """
+    """produce a keras.Input with shape & dtype according to ref"""
     kwargs.setdefault("shape", ref.shape[1:])
     kwargs.setdefault("dtype", ref.dtype)
     return tf.keras.Input(**kwargs)
 
 
 def keras_register_custom_object(obj):
-    """ decorator for globally registering a custom object with keras """
+    """decorator for globally registering a custom object with keras"""
     tf.keras.utils.get_custom_objects()[obj.__name__] = obj
     return obj
 
@@ -219,7 +219,7 @@ def Onehot(index, n, name=None):
 
 class Moment(tf.keras.metrics.Mean):
     def __init__(self, order, label=False, **kwargs):
-        """ Metric calculating the order-th moment """
+        """Metric calculating the order-th moment"""
         assert order == int(order)
         assert label == bool(label)
         kwargs.setdefault("name", "%smom%d" % ("l" if label else "", order))
@@ -255,24 +255,30 @@ def _patfilter(pattern, items):
     return filter(re.compile(pattern).search, items)
 
 
-class GPUStats(tf.keras.callbacks.Callback):
-    """
-    conda: conda install -c fastai nvidia-ml-py3
-    pip  : pip install nvidia-ml-py3
-    """
+try:
+    import nvidia_smi
+except ImportError:
+    pass
+else:
+
+    class GPUStats(tf.keras.callbacks.Callback):
+        """
+        conda: conda install -c fastai nvidia-ml-py3
+        pip  : pip install nvidia-ml-py3
+        """
 
-    def __init__(self, idx=0):
-        nvidia_smi.nvmlInit()
-        self.handle = nvidia_smi.nvmlDeviceGetHandleByIndex(idx)
+        def __init__(self, idx=0):
+            nvidia_smi.nvmlInit()
+            self.handle = nvidia_smi.nvmlDeviceGetHandleByIndex(idx)
 
-    def on_x_end(self, x, logs=None):
-        self.mem = nvidia_smi.nvmlDeviceGetMemoryInfo(self.handle)
-        self.res = nvidia_smi.nvmlDeviceGetUtilizationRates(self.handle)
-        logs["GPU-Usage [%]"] = self.res.gpu
-        logs["GPU-vRAM [%]"] = 100 * self.mem.used / self.mem.total
-        logs["GPU-vRAM [MiB]"] = self.mem.used / (1024 ** 2)
+        def on_x_end(self, x, logs=None):
+            self.mem = nvidia_smi.nvmlDeviceGetMemoryInfo(self.handle)
+            self.res = nvidia_smi.nvmlDeviceGetUtilizationRates(self.handle)
+            logs["GPU-Usage [%]"] = self.res.gpu
+            logs["GPU-vRAM [%]"] = 100 * self.mem.used / self.mem.total
+            logs["GPU-vRAM [MiB]"] = self.mem.used / (1024 ** 2)
 
-    on_batch_end = on_x_end
+        on_batch_end = on_x_end
 
 
 class Moment2Std(tf.keras.callbacks.Callback):
@@ -447,7 +453,7 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
         self,
         x,
         y,
-        sample_weight=None,
+        sample_weight=(None,),
         class_names=["signal", "background"],
         to_file=False,
         columns=None,
@@ -521,13 +527,16 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
         plt.close("all")
         gc.collect()
 
+    @cached_property
+    def prediction(self):
+        return self.model.predict(self.x, batch_size=4096)
+
     def make_eval_plots(self, epoch=0):
-        prediction = self.model.predict(self.x, batch_size=4096)
         imgs = {}
 
         fig = figure_roc_curve(
             self.y,
-            prediction,
+            self.prediction,
             class_names=self.class_names,
             sample_weight=self.sample_weight_flat,
         )
@@ -536,7 +545,7 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
 
         fig = figure_confusion_matrix(
             self.y,
-            prediction,
+            self.prediction,
             class_names=self.class_names,
             sample_weight=self.sample_weight_flat,
             normalize="true",
@@ -546,7 +555,7 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
 
         fig = figure_confusion_matrix(
             self.y,
-            prediction,
+            self.prediction,
             class_names=self.class_names,
             sample_weight=self.sample_weight_flat,
             normalize="pred",
@@ -556,7 +565,7 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
 
         if self.plot_activations:
             fig = figure_node_activations(
-                prediction,
+                self.prediction,
                 self.y,
                 class_names=self.class_names,
                 disjoint=True,
@@ -610,6 +619,24 @@ class PlotMulticlassEval(PlotMulticlass):
         self.make_eval_plots(0)
 
 
+class PlotAssignment(PlotMulticlass):
+    def __init__(
+        self,
+        x,
+        y,
+        **kwargs,
+    ):
+        super().__init__(x, y, **kwargs)
+        n_objects = self.y.shape[1]
+        n_flavours = 3
+        self.y = self.y.reshape(-1, n_flavours)
+        self.sample_weight_flat = (self.y * np.array([1 / 0.81, 1 / 0.15, 1 / 0.037])).sum(axis=-1)
+
+    @cached_property
+    def prediction(self):
+        return self.model.predict(self.x, batch_size=4096).reshape(-1, 3)
+
+
 class PlotLBN(PlottingCallback, TFSummaryCallback):
     def __init__(self, lbn_layer=None, inp_particle_names=None, *args, **kwargs):
         self.lbn = lbn_layer
diff --git a/numpy.py b/numpy.py
index 356f2a3b0a61f351f5b7f0abbc19f217021ec1ca..dcad6f2bd6da140f09874bb468a15b031d8ad9b4 100644
--- a/numpy.py
+++ b/numpy.py
@@ -3,6 +3,12 @@
 import numpy as np
 
 
+def set_thresh(values, thresh=1e-5):
+    val = values.copy()
+    val[val < thresh] = 0
+    return val
+
+
 def one_hot(a, n=None):
     if n is None:
         n = a.max() + 1