diff --git a/keras.py b/keras.py
index b13a8aff9a8cbb12fb2fa0f5a9dc8726780dc5f9..f2ed438af67bc92250a768baab3b6b67ea7ad4ef 100644
--- a/keras.py
+++ b/keras.py
@@ -1,11 +1,11 @@
 # from itertools import izip
 import itertools
+from functools import cached_property
 import gc
 from collections import OrderedDict, defaultdict
 import fnmatch
 import math
 import datetime
-import nvidia_smi
 
 import numpy as np
 import tensorflow as tf
@@ -40,14 +40,14 @@ from .plotting import (
 
 
 def kVar(*args, **kwargs):
-    """ produce a keras-tracked tf.Variable from all given parameters """
+    """produce a keras-tracked tf.Variable from all given parameters"""
     var = tf.Variable(*args, **kwargs)
     track_variable(var)
     return var
 
 
 def kOpt(opt, **kwargs):
-    """ instanciate a keras Optimizer with all applicable **kwargs """
+    """instantiate a keras Optimizer with all applicable **kwargs"""
     if not callable(opt):
         opt = getattr(tf.keras.optimizers, opt)
         assert issubclass(opt, tf.keras.optimizers.Optimizer)
@@ -56,14 +56,14 @@ def kOpt(opt, **kwargs):
 
 
 def kInput(ref, **kwargs):
-    """ produce a keras.Input with shape & dtype according to ref """
+    """produce a keras.Input with shape & dtype according to ref"""
     kwargs.setdefault("shape", ref.shape[1:])
     kwargs.setdefault("dtype", ref.dtype)
     return tf.keras.Input(**kwargs)
 
 
 def keras_register_custom_object(obj):
-    """ decorator for globally registering a custom object with keras """
+    """decorator for globally registering a custom object with keras"""
     tf.keras.utils.get_custom_objects()[obj.__name__] = obj
     return obj
 
@@ -219,7 +219,7 @@ def Onehot(index, n, name=None):
 
 class Moment(tf.keras.metrics.Mean):
     def __init__(self, order, label=False, **kwargs):
-        """ Metric calculating the order-th moment """
+        """Metric calculating the order-th moment"""
        assert order == int(order)
         assert label == bool(label)
         kwargs.setdefault("name", "%smom%d" % ("l" if label else "", order))
@@ -255,24 +255,30 @@ def _patfilter(pattern, items):
     return filter(re.compile(pattern).search, items)
 
 
-class GPUStats(tf.keras.callbacks.Callback):
-    """
-    conda: conda install -c fastai nvidia-ml-py3
-    pip  : pip install nvidia-ml-py3
-    """
+try:
+    import nvidia_smi
+except ImportError:
+    pass
+else:
+
+    class GPUStats(tf.keras.callbacks.Callback):
+        """
+        conda: conda install -c fastai nvidia-ml-py3
+        pip  : pip install nvidia-ml-py3
+        """
 
-    def __init__(self, idx=0):
-        nvidia_smi.nvmlInit()
-        self.handle = nvidia_smi.nvmlDeviceGetHandleByIndex(idx)
+        def __init__(self, idx=0):
+            nvidia_smi.nvmlInit()
+            self.handle = nvidia_smi.nvmlDeviceGetHandleByIndex(idx)
 
-    def on_x_end(self, x, logs=None):
-        self.mem = nvidia_smi.nvmlDeviceGetMemoryInfo(self.handle)
-        self.res = nvidia_smi.nvmlDeviceGetUtilizationRates(self.handle)
-        logs["GPU-Usage [%]"] = self.res.gpu
-        logs["GPU-vRAM [%]"] = 100 * self.mem.used / self.mem.total
-        logs["GPU-vRAM [MiB]"] = self.mem.used / (1024 ** 2)
+        def on_x_end(self, x, logs=None):
+            self.mem = nvidia_smi.nvmlDeviceGetMemoryInfo(self.handle)
+            self.res = nvidia_smi.nvmlDeviceGetUtilizationRates(self.handle)
+            logs["GPU-Usage [%]"] = self.res.gpu
+            logs["GPU-vRAM [%]"] = 100 * self.mem.used / self.mem.total
+            logs["GPU-vRAM [MiB]"] = self.mem.used / (1024 ** 2)
 
-    on_batch_end = on_x_end
+        on_batch_end = on_x_end
 
 
 class Moment2Std(tf.keras.callbacks.Callback):
@@ -447,7 +453,7 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
         self,
         x,
         y,
-        sample_weight=None,
+        sample_weight=(None,),
         class_names=["signal", "background"],
         to_file=False,
         columns=None,
@@ -521,13 +527,16 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
         plt.close("all")
         gc.collect()
 
+    @cached_property
+    def prediction(self):
+        return self.model.predict(self.x, batch_size=4096)
+
     def make_eval_plots(self, epoch=0):
-        prediction = self.model.predict(self.x, batch_size=4096)
         imgs = {}
 
         fig = figure_roc_curve(
             self.y,
-            prediction,
+            self.prediction,
             class_names=self.class_names,
             sample_weight=self.sample_weight_flat,
         )
@@ -536,7 +545,7 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
 
         fig = figure_confusion_matrix(
             self.y,
-            prediction,
+            self.prediction,
             class_names=self.class_names,
             sample_weight=self.sample_weight_flat,
             normalize="true",
@@ -546,7 +555,7 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
 
         fig = figure_confusion_matrix(
             self.y,
-            prediction,
+            self.prediction,
             class_names=self.class_names,
             sample_weight=self.sample_weight_flat,
             normalize="pred",
@@ -556,7 +565,7 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
 
         if self.plot_activations:
             fig = figure_node_activations(
-                prediction,
+                self.prediction,
                 self.y,
                 class_names=self.class_names,
                 disjoint=True,
@@ -610,6 +619,24 @@ class PlotMulticlassEval(PlotMulticlass):
         self.make_eval_plots(0)
 
 
+class PlotAssignment(PlotMulticlass):
+    def __init__(
+        self,
+        x,
+        y,
+        **kwargs,
+    ):
+        super().__init__(x, y, **kwargs)
+        n_objects = self.y.shape[1]
+        n_flavours = 3
+        self.y = self.y.reshape(-1, n_flavours)
+        self.sample_weight_flat = (self.y * np.array([1 / 0.81, 1 / 0.15, 1 / 0.037])).sum(axis=-1)
+
+    @cached_property
+    def prediction(self):
+        return self.model.predict(self.x, batch_size=4096).reshape(-1, 3)
+
+
 class PlotLBN(PlottingCallback, TFSummaryCallback):
     def __init__(self, lbn_layer=None, inp_particle_names=None, *args, **kwargs):
         self.lbn = lbn_layer
diff --git a/numpy.py b/numpy.py
index 356f2a3b0a61f351f5b7f0abbc19f217021ec1ca..dcad6f2bd6da140f09874bb468a15b031d8ad9b4 100644
--- a/numpy.py
+++ b/numpy.py
@@ -3,6 +3,12 @@
 import numpy as np
 
 
+def set_thresh(values, thresh=1e-5):
+    val = values.copy()
+    val[val < thresh] = 0
+    return val
+
+
 def one_hot(a, n=None):
     if n is None:
         n = a.max() + 1
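
Usage sketch for the now-optional GPUStats callback: since the class is only defined when nvidia_smi imports successfully, downstream code has to guard its own lookup as well. The package path `mypackage.keras`, the toy model, and the random data below are assumptions made purely for illustration, not part of this patch.

import numpy as np
import tensorflow as tf

# toy data and model, stand-ins for whatever is actually trained
x = np.random.rand(256, 8).astype("float32")
y = tf.keras.utils.to_categorical(np.random.randint(0, 2, size=256), 2)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy")

callbacks = []
try:
    # GPUStats only exists when the optional nvidia-ml-py3 dependency is installed,
    # so this from-import raises ImportError otherwise
    from mypackage.keras import GPUStats  # hypothetical package name
    callbacks.append(GPUStats(idx=0))
except ImportError:
    pass  # no GPU telemetry in the logs; training proceeds unchanged

model.fit(x, y, epochs=1, batch_size=32, callbacks=callbacks)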
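Quick check of the new set_thresh helper from the numpy.py hunk: it operates on a copy and zeroes every entry strictly below thresh, which includes negative values. The import path is again an assumption.

import numpy as np
from mypackage.numpy import set_thresh  # hypothetical package name

a = np.array([1e-7, -0.2, 0.3, 2e-5])
b = set_thresh(a)  # default thresh=1e-5
# b has the first two entries (1e-7 and -0.2) zeroed, while 0.3 and 2e-5 survive;
# a itself is unchanged because set_thresh works on a copy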