Skip to content
Snippets Groups Projects
Commit 1887a768 authored by Dennis Noll's avatar Dennis Noll
Browse files

[kers] init plot assignment

parent 6565d783
No related branches found
No related tags found
No related merge requests found
...@@ -614,6 +614,24 @@ class PlotMulticlassEval(PlotMulticlass): ...@@ -614,6 +614,24 @@ class PlotMulticlassEval(PlotMulticlass):
self.make_eval_plots(0) self.make_eval_plots(0)
class PlotAssignment(PlotMulticlass):
def __init__(
self,
x,
y,
**kwargs,
):
super().__init__(x, y, **kwargs)
n_objects = self.y.shape[1]
n_flavours = 3
self.y = self.y.reshape(-1, n_flavours)
self.sample_weight_flat = (self.y * np.array([1 / 0.81, 1 / 0.15, 1 / 0.037])).sum(axis=-1)
@cached_property
def prediction(self):
return self.model.predict(self.x, batch_size=4096).reshape(-1, 3)
class PlotLBN(PlottingCallback, TFSummaryCallback): class PlotLBN(PlottingCallback, TFSummaryCallback):
def __init__(self, lbn_layer=None, inp_particle_names=None, *args, **kwargs): def __init__(self, lbn_layer=None, inp_particle_names=None, *args, **kwargs):
self.lbn = lbn_layer self.lbn = lbn_layer
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment