# from itertools import izip import itertools from functools import cached_property import gc from collections import OrderedDict, defaultdict import fnmatch import math import datetime import numpy as np import tensorflow as tf from tqdm import tqdm from inspect import getargspec from warnings import warn import re from tensorflow.python.keras.callbacks import make_logs from tensorflow.python.keras.backend import track_variable from tensorflow.python.keras.utils.mode_keys import ModeKeys from operator import itemgetter from lbn import LBNLayer as OriginalLBNLayer from matplotlib import pyplot as plt from .evil import pin from .data import SKDict, DSS from .plotting import ( figure_to_image, figure_confusion_matrix, figure_activations, figure_node_activations, figure_roc_curve, figure_y, figure_weights, figure_correlation, figure_inputs, figure_dict, figure_lbn_weights, ) from .numpy import mean_excl, std_excl # various helper functions def kVar(*args, **kwargs): """produce a keras-tracked tf.Variable from all given parameters""" var = tf.Variable(*args, **kwargs) track_variable(var) return var def kOpt(opt, **kwargs): """instanciate a keras Optimizer with all applicable **kwargs""" if not callable(opt): opt = getattr(tf.keras.optimizers, opt) assert issubclass(opt, tf.keras.optimizers.Optimizer) args = getargspec(opt.__init__).args - {"self"} return opt(**{k: v for k, v in kwargs.items() if k in args}) def kInput(ref, **kwargs): """produce a keras.Input with shape & dtype according to ref""" kwargs.setdefault("shape", ref.shape[1:]) kwargs.setdefault("dtype", ref.dtype) return tf.keras.Input(**kwargs) def keras_register_custom_object(obj): """decorator for globally registering a custom object with keras""" tf.keras.utils.get_custom_objects()[obj.__name__] = obj return obj class KFeed(object): def __init__(self, x, y, w=None, train="train", valid="valid", balance=lambda x: x): pin(locals()) @property def xyw(self): return tuple(v if isinstance(v, tuple) else (v,) for v in (self.x, self.y, self.w)) def get(self, src): return tuple(tuple(src[k] for k in v) for v in self.xyw) @property def all(self): return set().union(*self.xyw) def balance(self, src): return src def kfeed(self, src, **kwargs): src = src.only(*self.all) bal = self.balance(src) return dict( zip(["x", "y", "sample_weight"], self.get(bal[self.train])), validation_data=self.get(bal[self.valid]), **kwargs, ) def balance_weights(self, weights): if isinstance(weights, SKDict): sums = weights.map(np.sum) if len(sums) > 1: ref = np.mean(list(sums.values())) weights = weights.__class__({k: weights[k] * (ref / s) for k, s in sums.items()}) return weights def compact(self, src, keep=()): ret = src.only(*self.all, *keep) val = ret[self.valid] assert not isinstance(self.w, tuple) if isinstance(val[self.w], SKDict): val[self.w] = self.balance_weights(val[self.w]) val = val.fuse(*val[self.w].keys()) ret = ret.only(self.train) ret[self.valid] = val return ret def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, **kwargs): """ Creates a generator for tf.keras' model.fit(). Requires, that mean weights per process are equal: dss["weight"] = dss["weight"].map(lambda x: x / np.mean(x)) """ src = self.compact(src) val = src[self.valid] val.blen # asserts that all first dimensions have equal length return dict( dict( zip( ("x", "steps_per_epoch"), self.gensteps(src[self.train], batch_size, rng=rng, auto_steps=auto_steps), ) ), validation_data=self.get(val), workers=0, **kwargs, ) def gensteps(self, src, batch_size, rng=np.random, auto_steps=np.max): keys = src.mkeys(self.all) gen = ( ( self.get(DSS.zip(*parts).map(np.concatenate)) for parts in zip( *[ src[k].batch_generator( batch_size // len(keys), rng=np.random.RandomState(rng.randint(1 << 31, size=20)), ) for k in keys ] ) ) if len(keys) > 1 else map(self.get, src[list(keys)[0]].batch_generator(batch_size, rng=rng)) ) gs = float(batch_size // len(keys)) steps = int(auto_steps([src[k].blen / gs for k in keys])) or None return gen, steps def generator(self, *args, **kwargs): assert "auto_steps" not in kwargs return self.gensteps(*args, **kwargs)[0] def getShapes(self, src): return tuple(tuple(src[k].shape for k in g) for g in self.xyw) def getShapesK(self, src): strip_first_dim = lambda x: x[1:] if x[0] is None else x shapes = self.getShapes(src) return tuple(tuple(strip_first_dim(a) for a in b) for b in shapes) def mkInputs(self, src, **kwargs): return tuple( tf.keras.Input(shape=ref.shape[1:], dtype=ref.dtype, **kwargs) for ref in (src[x] for x in self.xyw[0]) ) def Normal(ref, const=None, ignore_zeros=False, name=None, **kwargs): """ Normalizing layer according to ref. If given, variables at the indices const will not be normalized. """ if ignore_zeros: mean = mean_excl(ref, value=0, **kwargs) mean = np.nan_to_num(mean) std = std_excl(ref, value=0, **kwargs) std = np.nan_to_num(std) else: mean = ref.mean(**kwargs) std = ref.std(**kwargs) if const is not None: mean[const] = 0 std[const] = 1 std = np.where(std == 0, 1, std) mul = 1.0 / std add = -mean / std return tf.keras.layers.Lambda((lambda x: (x * mul) + add), name=name) def Onehot(index, n, name=None): """ One hot encodes a variable referred to by index. n is the number of different variables. """ def to_onehot(x): # Concat zeros to eye for indices in x equal to n (larger than those encoded by one) eye = tf.concat((tf.eye(n), tf.zeros((1, n))), axis=0) return tf.concat( ( x[..., :index], tf.gather(eye, tf.cast(x[..., index], tf.int64)), x[..., (index + 1) :], ), axis=-1, ) return tf.keras.layers.Lambda(to_onehot, name=name) class Moment(tf.keras.metrics.Mean): def __init__(self, order, label=False, **kwargs): """Metric calculating the order-th moment""" assert order == int(order) assert label == bool(label) kwargs.setdefault("name", "%smom%d" % ("l" if label else "", order)) super(Moment, self).__init__(**kwargs) self.order = order self.label = label def update_state(self, y_true, y_pred, sample_weight=None): y = y_true if self.label else y_pred y = tf.keras.backend.cast(y, self._dtype) if self.order == 0: y = tf.keras.backend.ones_like(y) elif self.order == 1: pass elif self.order == 2: y = tf.keras.backend.square(y) else: y = tf.keras.backend.pow(y, self.order) return super(Moment, self).update_state(y, sample_weight=sample_weight) def get_config(self): return dict(super(Moment, self).get_config(), order=self.order, label=self.label) # lots of call backs def _patfilter(pattern, items): if isinstance(pattern, (list, tuple)): pattern = "|".join(map(fnmatch.translate, pattern)) return filter(re.compile(pattern).search, items) try: import nvidia_smi except ImportError: pass else: class GPUStats(tf.keras.callbacks.Callback): """ conda: conda install -c fastai nvidia-ml-py3 pip : pip install nvidia-ml-py3 """ def __init__(self, idx=0): nvidia_smi.nvmlInit() self.handle = nvidia_smi.nvmlDeviceGetHandleByIndex(idx) def on_x_end(self, x, logs=None): self.mem = nvidia_smi.nvmlDeviceGetMemoryInfo(self.handle) self.res = nvidia_smi.nvmlDeviceGetUtilizationRates(self.handle) logs["GPU-Usage [%]"] = self.res.gpu logs["GPU-vRAM [%]"] = 100 * self.mem.used / self.mem.total logs["GPU-vRAM [MiB]"] = self.mem.used / (1024 ** 2) on_batch_end = on_x_end class Moment2Std(tf.keras.callbacks.Callback): def __init__(self, mom1="mom1", mom2="mom2", std="std"): assert mom1 and mom2 and std pin(locals()) def on_x_end(self, x, logs=None): if logs is None: return for key1, mom1 in list(logs.items()): if not key1.endswith(self.mom1): continue prefix = key1[: -len(self.mom1)] mom2 = logs.get(prefix + self.mom2, None) if mom2 is not None: logs[prefix + self.std] = (mom2 - mom1 ** 2) ** 0.5 on_batch_end = on_x_end on_epoch_end = on_x_end class LogRewrite(tf.keras.callbacks.Callback): def __init__(self, *rewrites, **kwargs): self.collapse_weighted = kwargs.pop("collapse_weighted", False) assert not kwargs self.rewrites = list(rewrites) if self.collapse_weighted: self.rewrites.append(lambda s: s.replace("weighted_", "")) def on_x_end(self, x, logs=None): if logs is None: return shadow = [] for key, val in logs.items(): short = key for rewrite in self.rewrites: short = rewrite(short) if short == key: continue if short in logs: shadow.append(short) logs[short] = val del logs[key] if shadow: warn("%r: shadows the following keys: %s" % (",".join(shadow))) on_batch_end = on_x_end on_epoch_end = on_x_end class LogTransforms(tf.keras.callbacks.Callback): _prefixesRe = re.compile("^(val_|)(.+)") def __init__(self, *funcs, **kwargs): self.transforms = [] for func in funcs: self.add_transform(func) for name, func in kwargs.items(): self.add_transform(func, name) def add_transform(self, func, name=None): if name is None: name = func.__name__ assert re.match(r"[a-zA-Z]\w*$", name) argspec = getargspec(func) assert argspec.keywords or argspec.args self.transforms.append((name, func, True if argspec.keywords else set(argspec.args))) def on_x_end(self, x, logs=None): if logs is None: return plogs = {} for key, value in logs.items(): prefix, suffix = self._prefixesRe.match(key).groups() plogs.setdefault(prefix, {})[suffix] = value for name, func, args in self.transforms: for prefix, log in plogs.items(): if args is True: val = func(**log) else: missing = args.difference(log.keys()) if missing: msg = "%r: %s=%r needs unavailable log values: %s" % ( self, name, func, ", ".join(missing), ) if prefix: warn(msg) continue else: raise RuntimeError(msg) val = func(**{key: val for key, val in log.items() if key in args}) logs[prefix + name] = log[name] = val on_batch_end = on_x_end on_epoch_end = on_x_end class LogEMA(tf.keras.callbacks.Callback): def __init__(self, keys=[], pattern=[], pairs={}, ema=0, ricu=0, ema_fmt="%s_ema"): assert ema < 1 and ricu < 1 and "%s" in ema_fmt if ema < 0: ema = 1 - 10 ** ema assert ema < 1 if ricu < 0: ricu = -1.0 / ricu keys = {k: None for k in keys} keys.update(pairs) pin(locals(), pairs) self.reset() def reset(self): self.ema_val = {} self.run_num = defaultdict(int) def on_epoch_end(self, x, logs=None): assert logs keys = {key: None for key in _patfilter(self.pattern, logs.keys())} keys.update(self.keys) for key, out in keys.items(): if not out: out = self.ema_fmt % key assert key in logs assert out not in keys assert out not in logs val = logs[key] if key in self.ema_val: delta = self.ema_val[key] - val dsign = int(np.sign(delta)) self.run_num[key] += dsign if dsign * np.sign(self.run_num[key]) < -self.ricu < 0: self.run_num[key] *= self.ricu ** 0.5 val += delta * self.ema ** (1 + (self.ricu * self.run_num[key]) ** 4) logs[out] = self.ema_val[key] = val class CustomValidation(tf.keras.callbacks.Callback): def __init__(self, **kwargs): self.kwargs = kwargs def on_epoch_end(self, x, logs=None): if logs is None: return res = self.model.evaluate(**self.kwargs) if not isinstance(res, list): res = [res] logs.update(make_logs(self.model, logs, res, mode=ModeKeys.TEST, prefix="val_")) class TFSummaryCallback(tf.keras.callbacks.Callback): def __init__(self, logdir=None, **kwargs): self.writer = tf.summary.create_file_writer(logdir) class PlottingCallback: def __init__(self, path=None, **kwargs): pin(locals()) super().__init__(**kwargs) def draw(self, figure, dict, key): if self.path is not None: plt.savefig(f"{self.path}/{key}.pdf") image = figure_to_image(figure) dict.update({key: image}) class PlotMulticlass(PlottingCallback, TFSummaryCallback): def __init__( self, x, y, sample_weight=(None,), class_names=["signal", "background"], to_file=False, columns=None, plot_inputs=False, signalvsbkg=False, plot_importance=False, plot_activations=False, tag="", **kwargs, ): pin(locals()) self.y = y[0] self.sample_weight_flat = sample_weight[0] self.comet_experiment = kwargs.get("comet_experiment", None) self.freq = kwargs.get("freq", None) self.file_extension = kwargs.get("file_extension", "pdf") self.verbose = int(kwargs.get("verbose", 1)) super().__init__(**kwargs) def draw(self, figure, dict, key, epoch=None, name=None): if self.path is not None: image_path = f"{self.path}/{key}.{self.file_extension}" if self.verbose: print("PlotMulticlass-callback: Plotting {}".format(key)) plt.savefig(image_path) if self.comet_experiment is not None: self.comet_experiment.log_image( image_path, name=name, image_format=self.file_extension, step=epoch ) image = figure_to_image(figure) dict.update({key: image}) def on_train_begin(self, logs=None): if self.plot_inputs: self.make_input_plots() def on_epoch_end(self, epoch, logs=None): if isinstance(self.freq, (float, int)): if epoch % int(self.freq) == 0 and epoch > 0: self.make_eval_plots(epoch=epoch, name=str(epoch)) def on_train_end(self, logs=None): self.make_eval_plots() def make_input_plots(self): inps = self.x if not isinstance(inps, (list, tuple)): inps = [inps] imgs = {} if self.sample_weight is not None: for part, inp in zip(self.columns.keys(), inps): self.draw( figure_inputs( inp, self.y, sample_weight=self.sample_weight_flat, columns=self.columns[part], class_names=self.class_names, signalvsbkg=self.signalvsbkg, ), imgs, f"inp_x_{part}", name="Inputs", ) self.draw( figure_weights(self.sample_weight_flat, self.y, class_names=self.class_names), imgs, "inp_weights", ) self.draw( figure_correlation( np.concatenate( [inp.reshape(-1, np.prod(inp.shape[1:])) for inp in inps], axis=1 ), label=list(itertools.chain(*self.columns.values())), ), imgs, "inp_correlation", ) self.draw( figure_y(self.y, class_names=self.class_names), imgs, "inp_y", ) for name, img in imgs.items(): with self.writer.as_default(): tf.summary.image(f"{name}{self.tag}", img, step=0) def clear_figure(self, fig): fig.clf() del fig plt.close("all") gc.collect() @cached_property def prediction(self): return self.model.predict(self.x, batch_size=4096) def make_eval_plots(self, epoch=0, name=""): imgs = {} fig = figure_roc_curve( self.y, self.prediction, class_names=self.class_names, sample_weight=self.sample_weight_flat, ) self.draw(fig, imgs, "roc_curve{}".format(name), epoch=epoch, name="ROC") self.clear_figure(fig) fig = figure_confusion_matrix( self.y, self.prediction, class_names=self.class_names, sample_weight=self.sample_weight_flat, normalize="true", ) self.draw( fig, imgs, "confusion_matrix_true{}".format(name), epoch=epoch, name="Confusion-Matrix-True", ) self.clear_figure(fig) fig = figure_confusion_matrix( self.y, self.prediction, class_names=self.class_names, sample_weight=self.sample_weight_flat, normalize="pred", ) self.draw( fig, imgs, "confusion_matrix_pred{}".format(name), epoch=epoch, name="Confusion-Matrix-Pred", ) self.clear_figure(fig) if self.plot_activations: fig = figure_node_activations( self.prediction, self.y, class_names=self.class_names, disjoint=True, sample_weight=self.sample_weight_flat, ) self.draw( fig, imgs, "node_activation_disjoint{}".format(name), epoch=epoch, name="Node-Activation", ) self.clear_figure(fig) if self.plot_importance: for method in ["perm"]: importance = feature_importance( self.model, x=self.x, y=self.y, sample_weight=self.sample_weight, method=method, columns=[c for col in self.columns.values() for c in col], batch_size=4096, ) kinematics = ["px", "py", "pz", "x", "y", "z", "energy"] fig = figure_dict( { key: importance[key] for key in importance.keys() if any(key.endswith(f"_{p}") for p in kinematics) }, xlabel="Importance [AU]", ) self.draw( fig, imgs, f"importance_{method}_kin{name}", epoch=epoch, name="Kinematic Importance", ) self.clear_figure(fig) fig = figure_dict( { key: importance[key] for key in importance.keys() if not any(key.endswith(p) for p in kinematics) }, xlabel="Importance [AU]", ) self.draw( fig, imgs, f"importance_{method}_nonkin{name}", epoch=epoch, name="Non-Kinematic Importance", ) self.clear_figure(fig) for name, img in imgs.items(): with self.writer.as_default(): tf.summary.image(f"{name}{self.tag}", img, step=epoch) del imgs class PlotMulticlassEval(PlotMulticlass): def on_test_end(self, logs=None): self.make_input_plots() self.make_eval_plots(0) class PlotAssignment(PlotMulticlass): def __init__( self, x, y, **kwargs, ): super().__init__(x, y, **kwargs) n_objects = self.y.shape[1] n_flavours = 3 self.y = self.y.reshape(-1, n_flavours) self.sample_weight_flat = (self.y * np.array([1 / 0.81, 1 / 0.15, 1 / 0.037])).sum(axis=-1) @cached_property def prediction(self): return self.model.predict(self.x, batch_size=4096).reshape(-1, 3) class PlotLBN(PlottingCallback, TFSummaryCallback): def __init__(self, lbn_layer=None, inp_particle_names=None, *args, **kwargs): self.lbn = lbn_layer self.inp_particle_names = inp_particle_names super().__init__(*args, **kwargs) def on_train_end(self, logs=None): self.make_plots() def make_plots(self, epoch=0): imgs = {} pkwargs = {"inp_particle_names": self.inp_particle_names} fig = figure_lbn_weights( self.lbn.weights[0].numpy(), name="particles", cmap="OrRd", **pkwargs ) self.draw(fig, imgs, "lbn_particles") fig = figure_lbn_weights( self.lbn.weights[1].numpy(), name="restframes", cmap="YlGn", **pkwargs ) self.draw(fig, imgs, "lbn_restframes") pkwargs["norm"] = True fig = figure_lbn_weights( self.lbn.weights[0].numpy(), name="particles", cmap="OrRd", **pkwargs ) self.draw(fig, imgs, "lbn_particles_normed") fig = figure_lbn_weights( self.lbn.weights[1].numpy(), name="restframes", cmap="YlGn", **pkwargs ) self.draw(fig, imgs, "lbn_restframes_normed") for name, img in imgs.items(): with self.writer.as_default(): tf.summary.image(name, img, step=epoch) class PlotLBNEval(PlotLBN): def on_test_end(self, logs=None): self.make_plots() class ModelLH(tf.keras.Model): def __init__(self, *args, **kwargs): self.loss_hook = kwargs.pop("loss_hook", None) super(ModelLH, self).__init__(*args, **kwargs) def _update_sample_weight_modes(self, sample_weights=None): if not self._is_compiled: return if sample_weights and any([s is not None for s in sample_weights]): pass # don't default sample_weight_mode to "samplewise", it prevents proper function caching # for endpoint in self._training_endpoints: # endpoint.sample_weight_mode = ( # endpoint.sample_weight_mode or 'samplewise') else: for endpoint in self._training_endpoints: endpoint.sample_weight_mode = None def _prepare_total_loss(self, *args, **kwargs): orig = [ (ep, ep.__dict__.copy(), ep.training_target.__dict__.copy()) for ep in self._training_endpoints ] self.loss_hook(self._training_endpoints.copy()) ret = super(ModelLH, self)._prepare_total_loss(*args, **kwargs) for ep, ed, td in orig: ep.__dict__.update(ed) ep.training_target.__dict__.update(td) return ret class GCCallback(tf.keras.callbacks.Callback): def on_epoch_end(self, epoch, logs=None): gc.collect() class TensorBoard(tf.keras.callbacks.TensorBoard): def __init__(self, *args, **kwargs): self.writer = kwargs.pop("writer") super(TensorBoard, self).__init__(*args, **kwargs) def _init_writer(self, model=None): pass def on_train_end(self, logs=None): pass class CheckpointModel(tf.keras.callbacks.Callback): """Gets a dict of targets (checkpoints), if current epoch is in dict, save model to target.""" def __init__(self, checkpoints=None): self.targets = checkpoints.targets def on_epoch_end(self, epoch, logs=None): if epoch in self.targets: target = self.targets[epoch] self.model.save(target.path) class BestTracker(tf.keras.callbacks.Callback): def __init__( self, monitor="val_loss", mode="auto", min_delta=0, min_delta_rel=0, baseline=None, ): pin(locals()) self.reset() @property def mode_multiplier(self): assert self.mode in ("auto", "min", "max") if self.mode == "max" or (self.mode == "auto" and "acc" in self.monitor): return -1 else: return 1 def reset(self): if self.baseline is not None: self.best = self.baseline else: self.best = self.mode_multiplier * np.inf def update_best(self, logs): current = self.get_monitor_value(logs) relative = delta = (current - self.best) * self.mode_multiplier if np.isfinite(self.best): relative /= abs(self.best) or 1.0 update = delta < -self.min_delta and relative < -self.min_delta_rel if update: self.best = current return update def get_monitor_value(self, logs): assert logs assert self.monitor in logs return logs[self.monitor] class PatientTracker(BestTracker): """ Jank: All PatientTrackers can communicate over the jank variable: If one PatientTracker triggers, all other PatientTrackers may not trigger for the number of epochs given by their jank. In order to participate in the jank mechanism, the jank of a PatientTracker must be finite (e.g. jank -1). """ def __init__( self, patience=10, cooldown=0, cooldown0=0, override=None, jank=np.nan, jank_last=-np.inf, **kwargs, ): if override is None: del override cooldown_counter = cooldown if cooldown0 is True else cooldown0 pin(locals(), kwargs) super(PatientTracker, self).__init__(**kwargs) def reset(self): super(PatientTracker, self).reset() self.wait = 0 def override(self, logs): return None def patient_step(self, epoch, logs): # update jank logs["jank"] = min(logs.get("jank", np.inf), epoch - self.jank_last) # test override override = self.override(logs) if override is not None: return override # check jank if logs["jank"] <= self.jank: return if 0 < self.cooldown_counter: self.cooldown_counter -= 1 if self.update_best(logs): self.wait = 0 return "best" elif self.cooldown_counter <= 0: self.wait += 1 if self.patience <= self.wait: self.cooldown_counter = self.cooldown return "good" class ScaleOnPlateau(PatientTracker): def __init__( self, target, factor, min=None, max=None, verbose=0, log_key=None, linearly=False, **kwargs, ): pin(locals(), kwargs) super(ScaleOnPlateau, self).__init__(**kwargs) def on_train_begin(self, logs=None): self.reset() @property def value(self): return tf.keras.backend.get_value(self.target) @value.setter def value(self, new): return tf.keras.backend.set_value(self.target, new) def on_epoch_end(self, epoch, logs=None): logs = logs or {} cur = self.value if self.log_key is not None: logs.setdefault(self.log_key, cur) if self.patient_step(epoch, logs) == "good": if self.linearly: new = cur + self.factor else: new = cur * self.factor if self.min is not None: new = max(new, self.min) if self.max is not None: new = min(new, self.max) if cur != new: self.value = new self.reset() if np.isfinite(self.jank): self.jank_last = epoch if self.verbose > 0: print( "\nEpoch %05d: %s scaling %s to %s." % ( epoch + 1, self.__class__.__name__, self.log_key or self.target, new, ) ) class ReduceLROnPlateau(ScaleOnPlateau): def __init__(self, min_lr=0, **kwargs): super(ReduceLROnPlateau, self).__init__(min=min_lr, target=None, log_key="lr", **kwargs) @property def target(self): return self.model.optimizer.lr @target.setter def target(self, target): assert target is None class EarlyStopping(PatientTracker): def __init__(self, runtime=None, restore_best_weights=False, verbose=0, do_stop=None, **kwargs): stop_time = datetime.datetime.now() + datetime.timedelta(**runtime) if runtime else None pin(locals(), kwargs) super(EarlyStopping, self).__init__(**kwargs) def reset(self): super(EarlyStopping, self).reset() self.best_weights = None self.best_epoch = None def on_train_begin(self, logs=None): self.reset() def on_epoch_end(self, epoch, logs=None): action = self.patient_step(epoch, logs) if action == "best": if self.restore_best_weights: self.best_weights = self.model.get_weights() self.best_epoch = epoch elif action == "good" or ( (self.stop_time is not None) and (datetime.datetime.now() > self.stop_time) ): self.stopped_epoch = epoch self.model.stop_training = True if self.restore_best_weights and self.best_weights is not None: self.model.set_weights(self.best_weights) if self.verbose > 0: print( "EarlyStopping-Callback: Restoring model weights from the end of the best epoch." ) def on_train_end(self, logs=None): hist = self.model.history if self.restore_best_weights and self.best_epoch in hist.epoch: idx = hist.epoch.index(self.best_epoch) else: idx = -1 hist.final_logs = {k: v[idx] for k, v in hist.history.items()} class TQES(EarlyStopping): def __init__(self, log_pattern=None, prog_batch="steps", **kwargs): assert prog_batch in (None, "steps", "samples") pin(locals(), kwargs) super(TQES, self).__init__(**kwargs) @property def use_steps(self): return self.prog_batch == "steps" def on_train_begin(self, logs=None): super(TQES, self).on_train_begin(logs) self.tqE = tqdm(desc=getattr(self.model, "name", None), total=self.params["epochs"]) def on_epoch_begin(self, epoch, logs=None): if self.prog_batch: self.tqB = tqdm( unit=("batch" if self.use_steps else "sample"), total=self.params["steps" if self.use_steps else "samples"], ) def on_batch_end(self, batch, logs=None): if self.tqB: logs = logs or {} self.tqB.set_postfix(self.make_postfix(logs), refresh=False) self.tqB.update( logs.get("num_steps", 1) * (1 if self.use_steps else logs.get("size", 0)) ) def on_epoch_end(self, epoch, logs=None): super(TQES, self).on_epoch_end(epoch, logs) last = self.get_monitor_value(logs) if self.tqB: self.tqB.close() self.tqB = None self.tqE.set_postfix( self.make_postfix( logs, [ ("best", self.best), ("conf", self.wait), ("rdlb", ((last or 0.0) - self.best) / self.best), ], ), refresh=False, ) self.tqE.update(epoch - self.tqE.n) def make_postfix(self, logs, extra=[]): if self.log_pattern is None: keys = self.params["metrics"] else: keys = _patfilter(self.log_pattern, logs.keys()) return OrderedDict( [(key, logs[key]) for key in sorted(keys) if key in logs] + list(filter(None, extra)) ) def on_train_end(self, logs=None): super(TQES, self).on_train_end(logs) self.tqE.close() class AUCOneVsAll(tf.keras.metrics.AUC): def __init__(self, one=0, *args, **kwargs): self.one = one super().__init__(*args, **kwargs) def update_state(self, y_true, y_pred, sample_weight=None): from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.keras.utils import metrics_utils from tensorflow.python.ops import check_ops deps = [] if not self._built: self._build(tensor_shape.TensorShape(y_pred.shape)) if self.multi_label or (self.label_weights is not None): # y_true should have shape (number of examples, number of labels). shapes = [(y_true, ("N", "L"))] if self.multi_label: # TP, TN, FP, and FN should all have shape # (number of thresholds, number of labels). shapes.extend( [ (self.true_positives, ("T", "L")), (self.true_negatives, ("T", "L")), (self.false_positives, ("T", "L")), (self.false_negatives, ("T", "L")), ] ) if self.label_weights is not None: # label_weights should be of length equal to the number of labels. shapes.append((self.label_weights, ("L",))) deps = [check_ops.assert_shapes(shapes, message="Number of labels is not consistent.")] # Only forward label_weights to update_confusion_matrix_variables when # multi_label is False. Otherwise the averaging of individual label AUCs is # handled in AUC.result label_weights = None if self.multi_label else self.label_weights with ops.control_dependencies(deps): return metrics_utils.update_confusion_matrix_variables( { metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, }, y_true[:, self.one], y_pred[:, self.one], self.thresholds, sample_weight=sample_weight, multi_label=self.multi_label, label_weights=label_weights, ) # Layer Definitions class WhereEquals(tf.keras.layers.Layer): def __init__(self, value=0): super(WhereEquals, self).__init__() self.value = value def call(self, inp): return tf.where(inp[:, 0] == self.value) class BaseLayer(tf.keras.layers.Layer): """ The BaseLayer object is a feature collection of different features used in a DNN layer It features: * l2 regu * the weights (the real layer) * batch norm * activation function * dynamically chosen dropout Parameters ---------- nodes : int The number of nodes. activation : str or one of tf.keras.activations The used activation function. dropout : float The used dropout ration. If "selu" is used as activation function, dropout becomes AlphaDropout. l2 : float The used factor of l2 regu. batch_norm : bool Wether to use dropout or not. If selu activation is used, batch_norm is forced off. """ def __init__(self, nodes=0, activation=None, dropout=0.0, l2=0, batch_norm=False, **kwargs): super().__init__(name="BaseLayer") self.nodes = nodes self.activation = activation self.dropout = dropout self.l2 = l2 self.batch_norm = batch_norm def build(self, input_shape): parts = {} if self.activation == "selu": kernel_initializer = "lecun_normal" else: kernel_initializer = "glorot_uniform" l2 = tf.keras.regularizers.l2(self.l2) weights = tf.keras.layers.Dense( self.nodes, kernel_regularizer=l2, kernel_initializer=kernel_initializer, ) parts["weights"] = weights if self.batch_norm and not self.activation == "selu": bn = tf.keras.layers.BatchNormalization() parts["bn"] = bn act = tf.keras.layers.Activation(self.activation) parts["act"] = act if self.activation == "selu": dropout = tf.keras.layers.AlphaDropout(self.dropout) else: dropout = tf.keras.layers.Dropout(self.dropout) parts["dropout"] = dropout self.parts = parts @property def part_order(self): raise NotImplementedError def call(self, input_tensor, training=False): x = input_tensor for part in self.part_order: if part in self.parts: x = self.parts[part](x, training=training) return x def get_config(self): return { "nodes": self.nodes, "activation": self.activation, "dropout": self.dropout, "l2": self.l2, "batch_norm": self.batch_norm, } class DenseLayer(BaseLayer): part_order = ["weights", "bn", "act", "dropout"] def __init__(self, **kwargs): super().__init__(name="DenseLayer", **kwargs) class BNActWeightLayer(BaseLayer): part_order = ["bn", "act", "weights", "dropout"] def __init__(self, **kwargs): super().__init__(name="BNActWeightLayer", **kwargs) class ResNetBlock(tf.keras.layers.Layer): """ The ResNetBlock object is an implementation of one residual DNN block. Parameters ---------- jump : int The number layers to bypass by residual connection. kwargs : Arguments for DenseLayer. """ def __init__(self, jump=2, sub_kwargs=None, **kwargs): super().__init__(name="ResNetBlock") self.jump = jump self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs def build(self, input_shape): layers = [] for i in range(self.jump - 1): layers.append(DenseLayer(**self.sub_kwargs)) kwargs = dict(self.sub_kwargs) activation = kwargs.pop("activation") kwargs.pop("nodes") layers.append(DenseLayer(nodes=input_shape[-1], **kwargs)) self.layers = layers self.out_activation = tf.keras.layers.Activation(activation) def call(self, input_tensor, training=False): x = input_tensor for layer in self.layers: x = layer(x, training=training) x += input_tensor x = self.out_activation(x) return x def get_config(self): return {"jump": self.jump, "sub_kwargs": self.sub_kwargs} class DenseNetBlock(tf.keras.layers.Layer): """ The DenseNetBlock object is an implementation of one DenseNet DNN block. Parameters ---------- block_size : int The size of a DenseNet Block, includes the transition layer. kwargs : Arguments for DenseLayer. """ def __init__(self, block_size=2, sub_kwargs=None, **kwargs): super().__init__(name="DenseNetBlock") self.block_size = block_size self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs def build(self, input_shape): layers = [] for i in range(self.block_size - 1): layers.append(DenseLayer(**self.sub_kwargs)) self.layers = layers self.transition = DenseLayer(**self.sub_kwargs) def call(self, input_tensor, training=False): x = input_tensor for layer in self.layers: y = layer(x, training=training) x = tf.keras.layers.concatenate([x, y], axis=-1) x = self.transition(x) return x def get_config(self): return {"block_size": self.block_size, "sub_kwargs": self.sub_kwargs} class LinearNetwork(tf.keras.layers.Layer): @property def name(self): raise NotImplementedError @property def substructure(self): raise NotImplementedError def __init__(self, layers=0, sub_kwargs=None, **kwargs): super().__init__(name=self.name) self.layers = layers self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs def build(self, input_shape): network_layers = [] for layer in range(self.layers): network_layers.append(self.substructure(**self.sub_kwargs)) self.network_layers = network_layers def call(self, input_tensor, training=False): x = input_tensor for layer in self.network_layers: x = layer(x, training=training) return x def get_config(self): return {"layers": self.layers, "sub_kwargs": self.sub_kwargs} class FullyConnected(LinearNetwork): """ The FullyConnected object is an implementation of a fully connected DNN. Parameters ---------- layers : int The number of layers. kwargs : Arguments for DenseLayer. """ name = "FullyConnected" substructure = DenseLayer class ResNet(LinearNetwork): """ The ResNet object is an implementation of a Residual Neural Network. Parameters ---------- layers : int The number of residual blocks. kwargs : Arguments for ResNetBlock. """ name = "ResNet" substructure = ResNetBlock class DenseNet(LinearNetwork): """ The DenseNet object is an implementation of a DenseNet Neural Network. Parameters ---------- layers : int The number of densely connected (DenseNet) blocks. kwargs : Arguments for DenseNetBlock. """ name = "DenseNet" substructure = DenseNetBlock class Xception1D(tf.keras.layers.Layer): """ The Xception1D object is an implementation of a Xception Neural Network. Because Xception was trained on 2D RGB data, each feature vector for the Xception1D is automatically scaled up to (71, 71, 3). Parameters ---------- kwargs : Arguments for DenseLayers. """ def __init__(self, sub_kwargs=None, **kwargs): super().__init__(name="DenseNet") self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs def build(self, input_shape): lenght = input_shape[-1] assert lenght < 71 # until now only includes small feature spaces pad = math.ceil((71 - lenght) / 2) padded_lenght = 71 + 2 * math.ceil((71 - lenght) / 2) self.up_sample = tf.keras.layers.UpSampling2D(size=(1, lenght)) self.zero_pad = ( tf.keras.layers.ZeroPadding2D(pad) if lenght < 71 else tf.keras.layers.Lambda(lambda x: x) ) self.first_network = tf.keras.applications.Xception( include_top=False, input_shape=(71 + padded_lenght, padded_lenght, 3) ) self.first_network.trainable = False self.flatten = tf.keras.layers.Flatten() self.second_network = FullyConnected(**self.sub_kwargs) def call(self, input_tensor, training=False): x = input_tensor x = tf.keras.backend.expand_dims(x) x = tf.keras.backend.expand_dims(x) x = self.up_sample(x) x = tf.keras.backend.repeat_elements(x, 3, axis=-1) x = self.zero_pad(x) x = self.first_network(x) x = self.flatten(x) x = self.second_network(x) return x def get_config(self): return {"sub_kwargs": self.sub_kwargs} class PermutationInvariantDense1D(tf.keras.layers.Layer): """ The PermutationInvariantDense1D layer is a TimeDistributed(Dense) implementation. Additionally it imposes (if wanted) a permutation inivariant pooling operation. Parameters ---------- name : name of the layer depth : depth of the ResNet layers nfeatures: number of features of Conv1D layer batchnorm: enable batchnorm for ResNet layers pooling_op: permutation invariant pooling operation """ def __init__(self, **kwargs): name = kwargs.pop("name", "PermutationInvariantDense1D") super().__init__(name=name) self.depth = kwargs.pop("depth", 1) self.nfeatures = kwargs.pop("nfeatures", 32) self.batchnorm = kwargs.pop("batchnorm", False) self.pooling_op = kwargs.pop("pooling_op", None) def build(self, input_shape): opts = dict(activation="elu", kernel_initializer="he_normal") self.layer = tf.keras.layers.TimeDistributed( tf.keras.layers.Dense( self.nfeatures, input_shape=input_shape, **opts, ) ) network_layers = [] for layer in range(self.depth - 1): network_layers.append( tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.nfeatures, **opts)) ) if self.batchnorm: network_layers.append(tf.keras.layers.BatchNormalization()) if self.pooling_op: # permutation invariant pooling op network_layers.append(tf.keras.layers.Lambda(lambda x: self.pooling_op(x, axis=1))) self.network_layers = network_layers def call(self, input_tensor, training=False): x = input_tensor for layer in self.network_layers: x = layer(x, training=training) return x class SplitHighLow(tf.keras.layers.Layer): def call(self, inputs): return inputs[:, :, :4], inputs[:, :, 4:] class LL(tf.keras.layers.Layer): def call(self, inputs): return inputs[:, :, :4] class LBNLayer(tf.keras.layers.Layer): """ Custom implementation of the LBNLayer with automatic cropping to low-level variables before and batchnorm after the LBN layer. Parameters ---------- args : args for original LBN layer kwargs : kwargs for original LBN layer """ def __init__(self, *args, **kwargs): super().__init__(name="LBNLayer") _kwargs = kwargs.copy() if "features" not in _kwargs: _kwargs["features"] = self.all_features self._args = args self._kwargs = _kwargs def build(self, input_shape): self.concat = tf.keras.layers.Concatenate(axis=-2) lbn_inp_shape = (sum(_[-2] for _ in input_shape), 4) self.lbn_layer = OriginalLBNLayer(lbn_inp_shape, *self._args, **self._kwargs) self.batch_norm = tf.keras.layers.BatchNormalization() @property def all_features(self): return [ "E", "px", "py", "pz", "pt", "p", "m", "phi", "eta", "beta", "gamma", "pair_cos", "pair_dr", "pair_ds", "pair_dy", ] def call(self, input_tensors, training=False): ll = [SplitHighLow()(tensor)[0] for tensor in input_tensors] ll = self.concat(ll) feats = self.lbn_layer(ll) feats = self.batch_norm(feats) return feats def feature_importance_grad(model, x=None, filter_types=["int32", "int64"], **kwargs): inp = {f"{i}": _x for (i, _x) in enumerate(x)} inp["sample_weight"] = kwargs["sample_weight"] ds = tf.data.Dataset.from_tensor_slices((inp)) ds = ds.batch(256) grad = 0 n_batches = tf.data.experimental.cardinality(ds).numpy() with tqdm(total=n_batches) as pbar: pbar.set_description("Importance grad") for _x in ds: if "sample_weight" in _x: sw = tf.cast(_x.pop("sample_weight"), tf.float32) inp = list(_x.values()) with tf.GradientTape(watch_accessed_variables=False) as tape: [tape.watch(i) for i in inp if i.dtype not in filter_types] pred = model(inp, training=False) ix = tf.keras.layers.Lambda(lambda x: x[:, -1])(tf.argsort(pred, axis=-1)) decision = tf.gather(pred, ix, batch_dims=1) gradients = tape.gradient( decision, [i for i in inp if i.dtype not in filter_types] ) # gradients for decision nodes gradients = [ grad if grad is not None else tf.constant(0.0, dtype=tf.float32) for grad in gradients ] # categorical tensors gradients = [tf.transpose(tf.transpose(g) * sw) for g in gradients] gradients = [tf.math.reduce_mean(tf.math.abs(g), axis=0) for g in gradients] # norm gradients to std normed_gradients = [] for g, i in zip(gradients, inp): if i.dtype in filter_types: val = np.nan * tf.math.reduce_std(i, axis=0) else: val = g * tf.math.reduce_std(i, axis=0) normed_gradients.append(val) grad += tf.concat([tf.keras.backend.flatten(g) for g in normed_gradients], axis=-1) pbar.update() mean_gradients = grad.numpy() / n_batches return mean_gradients / mean_gradients.max() def feature_importance_perm(model, x=None, **kwargs): inp_list = list(x) feat = 0 # loss ref = model.evaluate(x=x, **kwargs)[feat] vals = [] for tensor_index, tensor in enumerate(inp_list): for index in np.ndindex(tensor.shape[1:]): # copy original values of feature original_values = np.copy(tensor[(Ellipsis, *index)]) # randomly permute indexed feature tensor[(Ellipsis, *index)] = np.random.permutation(tensor[(Ellipsis, *index)]) # write permuteted feature to network input (inp) inp = inp_list.copy() inp[tensor_index] = tensor vals.append(model.evaluate(x=inp, **kwargs)[feat]) # cleanup and write back original values of feature tensor[(Ellipsis, *index)] = original_values vals = np.array(vals) return (vals - ref) / ref def feature_importance(*args, method="grad", columns=[], **kwargs): if method == "grad": importance = feature_importance_grad(*args, **kwargs) elif method == "perm": importance = feature_importance_perm(*args, **kwargs) else: raise NotImplementedError("Feature importance method not implemented") return { k: v for k, v in sorted( dict(zip(columns, importance.astype(float))).items(), key=lambda item: item[1], ) if v is not np.nan } # Custom Loss Functions @tf.function def cross_entropy_t( labels, predictions, sample_weight=None, focal_gamma=None, class_weight=None, epsilon=1e-7, ): # get true-negative component predictions = tf.clip_by_value(predictions, epsilon, 1 - epsilon) tn = labels * tf.math.log(predictions) # focal loss? if focal_gamma is not None: tn *= (1 - predictions) ** focal_gamma # convert into loss loss = -tn # apply class weights if class_weight is not None: loss *= class_weight # apply sample weights if sample_weight is not None: loss *= sample_weight[:, tf.newaxis] return tf.reduce_mean(tf.reduce_sum(loss, axis=-1)) @tf.function def grouped_cross_entropy_t( labels, predictions, sample_weight=None, group_ids=None, focal_gamma=None, class_weight=None, epsilon=1e-7, std_xent=True, ): loss_terms = [] # ensure same type for labels and predictions labels = tf.cast(labels, predictions.dtype) if std_xent: loss_terms.append( cross_entropy_t( labels, predictions, sample_weight=sample_weight, focal_gamma=focal_gamma, class_weight=class_weight, epsilon=epsilon, ) ) if group_ids: # create grouped labels and predictions labels_grouped, predictions_grouped = [], [] for _, ids in group_ids: labels_grouped.append( tf.reduce_sum(tf.gather(labels, ids, axis=-1), axis=-1, keepdims=True) ) predictions_grouped.append( tf.reduce_sum(tf.gather(predictions, ids, axis=-1), axis=-1, keepdims=True) ) labels_grouped = ( tf.concat(labels_grouped, axis=-1) if len(labels_grouped) > 1 else labels_grouped[0] ) predictions_grouped = ( tf.concat(predictions_grouped, axis=-1) if len(predictions_grouped) > 1 else predictions_grouped[0] ) loss_terms.append( cross_entropy_t( labels_grouped, predictions_grouped, sample_weight=sample_weight, focal_gamma=focal_gamma, class_weight=tf.constant([w for w, _ in group_ids], tf.float32), epsilon=epsilon, ) ) return sum(loss_terms) / len(loss_terms) class GroupedXEnt(tf.keras.losses.Loss): def __init__( self, group_ids=None, focal_gamma=None, class_weight=None, epsilon=1e-7, std_xent=True, *args, **kwargs, ): super(GroupedXEnt, self).__init__(*args, **kwargs) self.group_ids = group_ids self.focal_gamma = focal_gamma self.class_weight = class_weight self.epsilon = epsilon self.std_xent = std_xent def call(self, y_true, y_pred, sample_weight=None): return grouped_cross_entropy_t( y_true, y_pred, sample_weight=sample_weight, group_ids=self.group_ids, focal_gamma=self.focal_gamma, class_weight=self.class_weight, epsilon=self.epsilon, std_xent=self.std_xent, ) def write_summary(model, target): """write tf keras model summary to law target""" summary = [] model.summary(print_fn=lambda x: summary.append(x)) summary = "\n".join(summary) target.dump(summary) def find_layer(layer, name="LBNLayer"): if layer.name == name: return layer if hasattr(layer, "layers"): sub_layers = layer.layers for sub_layer in sub_layers: layer = find_layer(sub_layer, name=name) if layer is not None: return layer