import gc
from collections import OrderedDict, defaultdict
import fnmatch

import numpy as np
import tensorflow as tf
from tqdm import tqdm
from inspect import getfullargspec
from warnings import warn
import re
from tensorflow.python.keras.callbacks import make_logs
from tensorflow.python.keras.backend import track_variable
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from operator import itemgetter
from lbn import LBNLayer as OriginalLBNLayer
from matplotlib import pyplot as plt

from .evil import pin
from .data import SKDict, DSS
from .plotting import (
    figure_to_image,
    figure_confusion_matrix,
    figure_activations,
    figure_node_activations,
    figure_roc_curve,
    figure_multihist,
    figure_y,
    figure_weights,
    figure_inputs,
    figure_dict,
)

# various helper functions


def kVar(*args, **kwargs):
    """ produce a keras-tracked tf.Variable from all given parameters """
    var = tf.Variable(*args, **kwargs)
    track_variable(var)
    return var


def kOpt(opt, **kwargs):
    """ instanciate a keras Optimizer with all applicable **kwargs """
    if not callable(opt):
        opt = getattr(tf.keras.optimizers, opt)
    assert issubclass(opt, tf.keras.optimizers.Optimizer)
    args = set(getfullargspec(opt.__init__).args) - {"self"}
    return opt(**{k: v for k, v in kwargs.items() if k in args})
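

# Illustrative usage sketch (not part of the original source): because kOpt
# filters kwargs down to the constructor signature, a single config dict can
# drive different optimizers; e.g. "momentum" is ignored by Adam but used by
# SGD.
#
#     cfg = dict(learning_rate=1e-3, momentum=0.9)
#     adam = kOpt("Adam", **cfg)  # momentum silently dropped
#     sgd = kOpt("SGD", **cfg)    # momentum applied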


def kInput(ref, **kwargs):
    """ produce a keras.Input with shape & dtype according to ref """
    kwargs.setdefault("shape", ref.shape[1:])
    kwargs.setdefault("dtype", ref.dtype)
    return tf.keras.Input(**kwargs)


def keras_register_custom_object(obj):
    """ decorator for globally registering a custom object with keras """
    tf.keras.utils.get_custom_objects()[obj.__name__] = obj
    return obj
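

# Usage sketch (illustrative): once registered, keras can deserialize the
# object by name, e.g. via tf.keras.models.load_model, without passing
# custom_objects explicitly.
#
#     @keras_register_custom_object
#     class MyMetric(tf.keras.metrics.Mean):  # hypothetical example class
#         pass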


class KFeed(object):
    def __init__(self, x, y, w=None, train="train", valid="valid", balance=lambda x: x):
        pin(locals())

    @property
    def xyw(self):
        return tuple(v if isinstance(v, tuple) else (v,) for v in (self.x, self.y, self.w))

    def get(self, src):
        return tuple(tuple(src[k] for k in v) for v in self.xyw)

    @property
    def all(self):
        return set().union(*self.xyw)

    def balance(self, src):
        return src

    def kfeed(self, src, **kwargs):
        src = src.only(*self.all)
        bal = self.balance(src)
        return dict(
            zip(["x", "y", "sample_weight"], self.get(bal[self.train])),
            validation_data=self.get(bal[self.valid]),
            **kwargs,
        )

    def balance_weights(self, weights):
        if isinstance(weights, SKDict):
            sums = weights.map(np.sum)
            if len(sums) > 1:
                ref = np.mean(list(sums.values()))
                weights = weights.__class__({k: weights[k] * (ref / s) for k, s in sums.items()})
        return weights

    def compact(self, src, keep=()):
        ret = src.only(*self.all, *keep)
        val = ret[self.valid]
        assert not isinstance(self.w, tuple)
        if isinstance(val[self.w], SKDict):
            val[self.w] = self.balance_weights(val[self.w])
            val = val.fuse(*val[self.w].keys())
        ret = ret.only(self.train)
        ret[self.valid] = val
        return ret

    def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, **kwargs):
        """
        Creates a generator for tf.keras' model.fit().
        Requires, that mean weights per process are equal:
            dss["weight"] = dss["weight"].map(lambda x: x / np.mean(x))
        """
        src = self.compact(src)
        val = src[self.valid]
        val.blen  # asserts that all first dimensions have equal length
        return dict(
            dict(
                zip(
                    ("x", "steps_per_epoch"),
                    self.gensteps(src[self.train], batch_size, rng=rng, auto_steps=auto_steps),
                )
            ),
            validation_data=self.get(val),
            workers=0,
            **kwargs,
        )

    def gensteps(self, src, batch_size, rng=np.random, auto_steps=np.max):
        keys = src.mkeys(self.all)
        gen = (
            (
                self.get(DSS.zip(*parts).map(np.concatenate))
                for parts in zip(
                    *[
                        src[k].batch_generator(
                            batch_size // len(keys),
                            rng=np.random.RandomState(rng.randint(1 << 31, size=20)),
                        )
                        for k in keys
                    ]
                )
            )
            if len(keys) > 1
            else map(self.get, src[list(keys)[0]].batch_generator(batch_size, rng=rng))
        )
        gs = float(batch_size // len(keys))
        steps = int(auto_steps([src[k].blen / gs for k in keys])) or None
        return gen, steps

    def generator(self, *args, **kwargs):
        assert "auto_steps" not in kwargs
        return self.gensteps(*args, **kwargs)[0]

    def getShapes(self, src):
        return tuple(tuple(src[k].shape for k in g) for g in self.xyw)

    def getShapesK(self, src):
        strip_first_dim = lambda x: x[1:] if x[0] is None else x
        shapes = self.getShapes(src)
        return tuple(tuple(strip_first_dim(a) for a in b) for b in shapes)

    def mkInputs(self, src, **kwargs):
        return tuple(
            tf.keras.Input(shape=ref.shape[1:], dtype=ref.dtype, **kwargs)
            for ref in (src[x] for x in self.xyw[0])
        )
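

# Usage sketch (illustrative; key names and the DSS instance `dss` with
# "train"/"valid" splits are assumed):
#
#     feed = KFeed(x=("features",), y=("labels",), w="weight")
#     model.fit(**feed.kfeed(dss, epochs=10))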


def Normal(ref, const=None, ignore_zeros=False, name=None, **kwargs):
    """
    Normalizing layer according to ref.
    If given, variables at the indices const will not be normalized.
    """
    if ignore_zeros:
        mean = np.nanmean(np.where(ref == 0, np.ones_like(ref) * np.nan, ref), **kwargs)
        std = np.nanstd(np.where(ref == 0, np.ones_like(ref) * np.nan, ref), **kwargs)
    else:
        mean = ref.mean(**kwargs)
        std = ref.std(**kwargs)
    if const is not None:
        mean[const] = 0
        std[const] = 1
    std = np.where(std == 0, 1, std)
    mul = 1.0 / std
    add = -mean / std
    return tf.keras.layers.Lambda((lambda x: (x * mul) + add), name=name)
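

# Illustrative sketch (x_train and inp are assumed): normalize per column from
# a reference sample while keeping column 0 (e.g. a categorical flag) untouched.
#
#     norm = Normal(x_train, const=[0], axis=0)
#     normed = norm(inp)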


def Onehot(index, n, name=None):
    """
    One hot encodes a variable referred to by index.
    n is the number of different variables.
    """

    def to_onehot(x):
        # Append a zero row to the identity so values in x equal to n (one past the largest encoded index) map to all zeros
        eye = tf.concat((tf.eye(n), tf.zeros((1, n))), axis=0)
        return tf.concat(
            (
                x[..., :index],
                tf.gather(eye, tf.cast(x[..., index], tf.int64)),
                x[..., (index + 1) :],
            ),
            axis=-1,
        )

    return tf.keras.layers.Lambda(to_onehot, name=name)
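

# Illustrative sketch: one-hot encode the integer in column 2 into n=4
# categories; the feature width grows by n - 1, and a value of exactly 4 maps
# to the appended all-zero row.
#
#     encoded = Onehot(index=2, n=4)(inp)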


class Moment(tf.keras.metrics.Mean):
    def __init__(self, order, label=False, **kwargs):
        """ Metric calculating the order-th moment """
        assert order == int(order)
        assert label == bool(label)
        kwargs.setdefault("name", "%smom%d" % ("l" if label else "", order))
        super(Moment, self).__init__(**kwargs)
        self.order = order
        self.label = label

    def update_state(self, y_true, y_pred, sample_weight=None):
        y = y_true if self.label else y_pred
        y = tf.keras.backend.cast(y, self._dtype)

        if self.order == 0:
            y = tf.keras.backend.ones_like(y)
        elif self.order == 1:
            pass
        elif self.order == 2:
            y = tf.keras.backend.square(y)
        else:
            y = tf.keras.backend.pow(y, self.order)

        return super(Moment, self).update_state(y, sample_weight=sample_weight)

    def get_config(self):
        return dict(super(Moment, self).get_config(), order=self.order, label=self.label)


# various callbacks


def _patfilter(pattern, items):
    if isinstance(pattern, (list, tuple)):
        pattern = "|".join(map(fnmatch.translate, pattern))
    return filter(re.compile(pattern).search, items)
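

# Note: list/tuple patterns are fnmatch-style globs joined into one regex
# (an empty list therefore matches everything); a plain string is used as a
# regular expression directly, e.g.:
#
#     list(_patfilter(["val_*", "*acc*"], logs.keys()))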


class Moment2Std(tf.keras.callbacks.Callback):
    def __init__(self, mom1="mom1", mom2="mom2", std="std"):
        assert mom1 and mom2 and std
        pin(locals())

    def on_x_end(self, x, logs=None):
        if logs is None:
            return
        for key1, mom1 in list(logs.items()):
            if not key1.endswith(self.mom1):
                continue
            prefix = key1[: -len(self.mom1)]
            mom2 = logs.get(prefix + self.mom2, None)
            if mom2 is not None:
                logs[prefix + self.std] = (mom2 - mom1 ** 2) ** 0.5

    on_batch_end = on_x_end
    on_epoch_end = on_x_end
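

# Sketch (illustrative): combining Moment metrics with Moment2Std derives a
# "std" log entry from the tracked first and second moments.
#
#     model.compile(..., metrics=[Moment(1), Moment(2)])
#     model.fit(..., callbacks=[Moment2Std()])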


class LogRewrite(tf.keras.callbacks.Callback):
    def __init__(self, *rewrites, **kwargs):
        self.collapse_weighted = kwargs.pop("collapse_weighted", False)
        assert not kwargs
        self.rewrites = list(rewrites)
        if self.collapse_weighted:
            self.rewrites.append(lambda s: s.replace("weighted_", ""))

    def on_x_end(self, x, logs=None):
        if logs is None:
            return
        shadow = []
        for key, val in logs.items():
            short = key
            for rewrite in self.rewrites:
                short = rewrite(short)
            if short == key:
                continue
            if short in logs:
                shadow.append(short)
            logs[short] = val
            del logs[key]
        if shadow:
            warn("%r shadows the following keys: %s" % (self, ", ".join(shadow)))

    on_batch_end = on_x_end
    on_epoch_end = on_x_end


class LogTransforms(tf.keras.callbacks.Callback):
    _prefixesRe = re.compile("^(val_|)(.+)")

    def __init__(self, *funcs, **kwargs):
        self.transforms = []
        for func in funcs:
            self.add_transform(func)
        for name, func in kwargs.items():
            self.add_transform(func, name)

    def add_transform(self, func, name=None):
        if name is None:
            name = func.__name__
        assert re.match(r"[a-zA-Z]\w*$", name)
        argspec = getfullargspec(func)
        assert argspec.varkw or argspec.args
        self.transforms.append((name, func, True if argspec.varkw else set(argspec.args)))

    def on_x_end(self, x, logs=None):
        if logs is None:
            return
        plogs = {}
        for key, value in logs.items():
            prefix, suffix = self._prefixesRe.match(key).groups()
            plogs.setdefault(prefix, {})[suffix] = value
        for name, func, args in self.transforms:
            for prefix, log in plogs.items():
                if args is True:
                    val = func(**log)
                else:
                    missing = args.difference(log.keys())
                    if missing:
                        msg = "%r: %s=%r needs unavailable log values: %s" % (
                            self,
                            name,
                            func,
                            ", ".join(missing),
                        )
                        if prefix:
                            warn(msg)
                            continue
                        else:
                            raise RuntimeError(msg)
                    val = func(**{key: val for key, val in log.items() if key in args})
                logs[prefix + name] = log[name] = val

    on_batch_end = on_x_end
    on_epoch_end = on_x_end


class LogEMA(tf.keras.callbacks.Callback):
    def __init__(self, keys=[], pattern=[], pairs={}, ema=0, ricu=0, ema_fmt="%s_ema"):
        assert ema < 1 and ricu < 1 and "%s" in ema_fmt
        if ema < 0:
            ema = 1 - 10 ** ema
            assert ema < 1
        if ricu < 0:
            ricu = -1.0 / ricu
        keys = {k: None for k in keys}
        keys.update(pairs)
        pin(locals(), pairs)
        self.reset()

    def reset(self):
        self.ema_val = {}
        self.run_num = defaultdict(int)

    def on_epoch_end(self, x, logs=None):
        assert logs
        keys = {key: None for key in _patfilter(self.pattern, logs.keys())}
        keys.update(self.keys)
        for key, out in keys.items():
            if not out:
                out = self.ema_fmt % key
            assert key in logs
            assert out not in keys
            assert out not in logs
            val = logs[key]
            if key in self.ema_val:
                delta = self.ema_val[key] - val
                dsign = int(np.sign(delta))
                self.run_num[key] += dsign
                if dsign * np.sign(self.run_num[key]) < -self.ricu < 0:
                    self.run_num[key] *= self.ricu ** 0.5
                val += delta * self.ema ** (1 + (self.ricu * self.run_num[key]) ** 4)
            logs[out] = self.ema_val[key] = val


class CustomValidation(tf.keras.callbacks.Callback):
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def on_epoch_end(self, x, logs=None):
        if logs is None:
            return
        res = self.model.evaluate(**self.kwargs)
        if not isinstance(res, list):
            res = [res]
        logs.update(make_logs(self.model, logs, res, mode=ModeKeys.TEST, prefix="val_"))
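

# Usage sketch (x_extra/y_extra are assumed): run an additional evaluation
# pass each epoch and merge its metrics into the logs under a "val_" prefix.
#
#     model.fit(..., callbacks=[CustomValidation(x=x_extra, y=y_extra)])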


class TFSummaryCallback(tf.keras.callbacks.Callback):
    def __init__(self, logdir=None, **kwargs):
        self.writer = tf.summary.create_file_writer(logdir)


class PlotMulticlass(TFSummaryCallback):
    def __init__(
        self,
        x,
        y,
        sample_weight=None,
        class_names=["signal", "background"],
        to_file=False,
        columns=None,
        plot_inputs=False,
        signalvsbkg=False,
        plot_importance=False,
        plot_activations=False,
        tag="",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.x = x
        self.truth = y[0]
        self.sample_weight = None if sample_weight is None else sample_weight[0]
        self.class_names = class_names
        self.plot_inputs = plot_inputs
        self.columns = columns
        self.to_file = to_file
        self.signalvsbkg = signalvsbkg
        self.tag = tag
        self.plot_importance = plot_importance
        self.plot_activations = plot_activations

    def on_train_begin(self, logs=None):
        self.make_input_plots()

    def on_epoch_end(self, epoch, logs=None):
        self.make_eval_plots(epoch)

    def make_input_plots(self):
        if self.plot_inputs:
            imgs = {}
            inps = self.x
            if not isinstance(inps, (list, tuple)):
                inps = [inps]
            if self.columns:
                for part, inp in zip(self.columns.keys(), inps):
                    imgs[f"inp_xmerged_{part}"] = figure_to_image(
                        figure_multihist(inp, columns=self.columns[part])
                    )
            if self.sample_weight is not None:
                if self.columns:
                    for part, inp in zip(self.columns.keys(), inps):
                        imgs[f"inp_x_{part}"] = figure_to_image(
                            figure_inputs(
                                inp,
                                self.truth,
                                sample_weight=self.sample_weight,
                                columns=self.columns[part],
                                class_names=self.class_names,
                                signalvsbkg=self.signalvsbkg,
                            )
                        )

                imgs["inp_weights"] = figure_to_image(
                    figure_weights(self.sample_weight, self.truth, class_names=self.class_names)
                )
            imgs["inp_y"] = figure_to_image(figure_y(self.truth, class_names=self.class_names))
            imgs["inp_yrelative"] = figure_to_image(
                figure_y(self.truth, class_names=self.class_names, relative=True)
            )
            for name, img in imgs.items():
                with self.writer.as_default():
                    tf.summary.image(f"{name}{self.tag}", img, step=0)

    def clear_figure(self, fig):
        fig.clf()
        del fig
        plt.close("all")
        gc.collect()

    def make_eval_plots(self, epoch):
        prediction = self.model.predict(self.x, batch_size=4096)
        truth = self.truth
        imgs = {}

        fig = figure_roc_curve(
            truth,
            prediction,
            class_names=self.class_names,
            sample_weight=self.sample_weight,
        )
        imgs["roc_curve"] = figure_to_image(fig)
        self.clear_figure(fig)

        fig = figure_confusion_matrix(
            truth,
            prediction,
            class_names=self.class_names,
            sample_weight=self.sample_weight,
            normalize="true",
        )
        imgs["confusion_matrix_true"] = figure_to_image(fig)
        self.clear_figure(fig)

        fig = figure_confusion_matrix(
            truth,
            prediction,
            class_names=self.class_names,
            sample_weight=self.sample_weight,
            normalize="pred",
        )
        imgs["confusion_matrix_pred"] = figure_to_image(fig)
        self.clear_figure(fig)

        if self.plot_activations:
            fig = figure_node_activations(
                prediction,
                truth,
                class_names=self.class_names,
                disjoint=True,
                sample_weight=self.sample_weight,
            )
            imgs["node_activation_disjoint"] = figure_to_image(fig)
            self.clear_figure(fig)

        if self.plot_importance:
            importance = feature_importance(
                self.model,
                x=[feat[:5000] for feat in self.x],
                y=self.truth[:5000],
                sample_weight=self.sample_weight[:5000],
                method="grad",
                columns=[c for col in self.columns.values() for c in col],
            )
            fig = figure_dict(importance)
            imgs["importance"] = figure_to_image(fig)
            self.clear_figure(fig)

        for name, img in imgs.items():
            with self.writer.as_default():
                tf.summary.image(f"{name}{self.tag}", img, step=epoch)
        del imgs


class PlotMulticlassEval(PlotMulticlass):
    def on_test_end(self, logs=None):
        self.make_input_plots()
        self.make_eval_plots(0)


class ModelLH(tf.keras.Model):
    def __init__(self, *args, **kwargs):
        self.loss_hook = kwargs.pop("loss_hook", None)
        super(ModelLH, self).__init__(*args, **kwargs)

    def _update_sample_weight_modes(self, sample_weights=None):
        if not self._is_compiled:
            return
        if sample_weights and any([s is not None for s in sample_weights]):
            pass
            # don't default sample_weight_mode to "samplewise", it prevents proper function caching
            # for endpoint in self._training_endpoints:
            #     endpoint.sample_weight_mode = (
            #         endpoint.sample_weight_mode or 'samplewise')
        else:
            for endpoint in self._training_endpoints:
                endpoint.sample_weight_mode = None

    def _prepare_total_loss(self, *args, **kwargs):
        orig = [
            (ep, ep.__dict__.copy(), ep.training_target.__dict__.copy())
            for ep in self._training_endpoints
        ]

        self.loss_hook(self._training_endpoints.copy())
        ret = super(ModelLH, self)._prepare_total_loss(*args, **kwargs)

        for ep, ed, td in orig:
            ep.__dict__.update(ed)
            ep.training_target.__dict__.update(td)

        return ret


class GCCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        gc.collect()


class TensorBoard(tf.keras.callbacks.TensorBoard):
    def __init__(self, *args, **kwargs):
        self.writer = kwargs.pop("writer")
        super(TensorBoard, self).__init__(*args, **kwargs)

    def _init_writer(self, model=None):
        pass

    def on_train_end(self, logs=None):
        pass


class CheckpointModel(tf.keras.callbacks.Callback):
    """Gets a dict of targets (checkpoints), if current epoch is in dict, save model to target."""

    def __init__(self, checkpoints=None):
        self.targets = checkpoints.targets

    def on_epoch_end(self, epoch, logs=None):
        if epoch in self.targets:
            target = self.targets[epoch]
            self.model.save(target.path)


class BestTracker(tf.keras.callbacks.Callback):
    def __init__(
        self,
        monitor="val_loss",
        mode="auto",
        min_delta=0,
        min_delta_rel=0,
        baseline=None,
    ):
        pin(locals())
        self.reset()

    @property
    def mode_multiplier(self):
        assert self.mode in ("auto", "min", "max")
        if self.mode == "max" or (self.mode == "auto" and "acc" in self.monitor):
            return -1
        else:
            return 1

    def reset(self):
        if self.baseline is not None:
            self.best = self.baseline
        else:
            self.best = self.mode_multiplier * np.inf

    def update_best(self, logs):
        current = self.get_monitor_value(logs)
        relative = delta = (current - self.best) * self.mode_multiplier
        if np.isfinite(self.best):
            relative /= abs(self.best) or 1.0
        update = delta < -self.min_delta and relative < -self.min_delta_rel
        if update:
            self.best = current
        return update

    def get_monitor_value(self, logs):
        assert logs
        assert self.monitor in logs
        return logs[self.monitor]


class PatientTracker(BestTracker):
    def __init__(
        self,
        patience=10,
        cooldown=0,
        cooldown0=0,
        override=None,
        jank=np.nan,
        jank_last=-np.inf,
        **kwargs,
    ):
        if override is None:
            del override  # keep the class-level override() method when no callable is given
        cooldown_counter = cooldown if cooldown0 is True else cooldown0
        pin(locals(), kwargs)
        super(PatientTracker, self).__init__(**kwargs)

    def reset(self):
        super(PatientTracker, self).reset()
        self.wait = 0

    def override(self, logs):
        return None

    def patient_step(self, epoch, logs):
        # update jank
        logs["jank"] = min(logs.get("jank", np.inf), epoch - self.jank_last)
        # test override
        override = self.override(logs)
        if override is not None:
            return override
        # check jank
        if logs["jank"] <= self.jank:
            return
        if 0 < self.cooldown_counter:
            self.cooldown_counter -= 1
        if self.update_best(logs):
            self.wait = 0
            return "best"
        elif self.cooldown_counter <= 0:
            self.wait += 1
            if self.patience <= self.wait:
                self.cooldown_counter = self.cooldown
                return "good"


class ScaleOnPlateau(PatientTracker):
    def __init__(
        self,
        target,
        factor,
        min=None,
        max=None,
        verbose=0,
        log_key=None,
        linearly=False,
        **kwargs,
    ):
        pin(locals(), kwargs)
        super(ScaleOnPlateau, self).__init__(**kwargs)

    def on_train_begin(self, logs=None):
        self.reset()

    @property
    def value(self):
        return tf.keras.backend.get_value(self.target)

    @value.setter
    def value(self, new):
        return tf.keras.backend.set_value(self.target, new)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        cur = self.value
        if self.log_key is not None:
            logs.setdefault(self.log_key, cur)
        if self.patient_step(epoch, logs) == "good":
            if self.linearly:
                new = cur + self.factor
            else:
                new = cur * self.factor
            if self.min is not None:
                new = max(new, self.min)
            if self.max is not None:
                new = min(new, self.max)
            if cur != new:
                self.value = new
                self.reset()
                if np.isfinite(self.jank):
                    self.jank_last = epoch
                if self.verbose > 0:
                    print(
                        "\nEpoch %05d: %s scaling %s to %s."
                        % (
                            epoch + 1,
                            self.__class__.__name__,
                            self.log_key or self.target,
                            new,
                        )
                    )
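

# Usage sketch (illustrative): anneal an auxiliary tf.Variable, here an
# assumed loss weight, whenever "val_loss" plateaus.
#
#     loss_weight = kVar(1.0, dtype=tf.float32)
#     cb = ScaleOnPlateau(target=loss_weight, factor=0.5, min=1e-3,
#                         monitor="val_loss", patience=5, log_key="loss_weight")
#     model.fit(..., callbacks=[cb])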


class ReduceLROnPlateau(ScaleOnPlateau):
    def __init__(self, min_lr=0, **kwargs):
        super(ReduceLROnPlateau, self).__init__(min=min_lr, target=None, log_key="lr", **kwargs)

    @property
    def target(self):
        return self.model.optimizer.lr

    @target.setter
    def target(self, target):
        assert target is None


class EarlyStopping(PatientTracker):
    def __init__(self, restore_best_weights=False, verbose=0, do_stop=None, **kwargs):
        pin(locals(), kwargs)
        super(EarlyStopping, self).__init__(**kwargs)

    def reset(self):
        super(EarlyStopping, self).reset()
        self.best_weights = None
        self.best_epoch = None

    def on_train_begin(self, logs=None):
        self.reset()

    def on_epoch_end(self, epoch, logs=None):
        action = self.patient_step(epoch, logs)
        if action == "best":
            if self.restore_best_weights:
                self.best_weights = self.model.get_weights()
                self.best_epoch = epoch
        elif action == "good":
            self.stopped_epoch = epoch
            self.model.stop_training = True
            if self.restore_best_weights and self.best_weights is not None:
                self.model.set_weights(self.best_weights)
                if self.verbose > 0:
                    print("Restoring model weights from the end of the best epoch.")

    def on_train_end(self, logs=None):
        hist = self.model.history
        if self.restore_best_weights and self.best_epoch in hist.epoch:
            idx = hist.epoch.index(self.best_epoch)
        else:
            idx = -1
        hist.final_logs = {k: v[idx] for k, v in hist.history.items()}


class TQES(EarlyStopping):
    def __init__(self, log_pattern=None, prog_batch="steps", **kwargs):
        assert prog_batch in (None, "steps", "samples")
        pin(locals(), kwargs)
        super(TQES, self).__init__(**kwargs)

    @property
    def use_steps(self):
        return self.prog_batch == "steps"

    def on_train_begin(self, logs=None):
        super(TQES, self).on_train_begin(logs)
        self.tqE = tqdm(desc=getattr(self.model, "name", None), total=self.params["epochs"])

    def on_epoch_begin(self, epoch, logs=None):
        self.tqB = None
        if self.prog_batch:
            self.tqB = tqdm(
                unit=("batch" if self.use_steps else "sample"),
                total=self.params["steps" if self.use_steps else "samples"],
            )

    def on_batch_end(self, batch, logs=None):
        if self.tqB:
            logs = logs or {}
            self.tqB.set_postfix(self.make_postfix(logs), refresh=False)
            self.tqB.update(
                logs.get("num_steps", 1) * (1 if self.use_steps else logs.get("size", 0))
            )

    def on_epoch_end(self, epoch, logs=None):
        super(TQES, self).on_epoch_end(epoch, logs)
        last = self.get_monitor_value(logs)
        if self.tqB:
            self.tqB.close()
            self.tqB = None
        self.tqE.set_postfix(
            self.make_postfix(
                logs,
                [
                    ("best", self.best),
                    ("conf", self.wait),
                    ("rdlb", ((last or 0.0) - self.best) / self.best),
                ],
            ),
            refresh=False,
        )
        self.tqE.update(epoch - self.tqE.n)

    def make_postfix(self, logs, extra=[]):
        if self.log_pattern is None:
            keys = self.params["metrics"]
        else:
            keys = _patfilter(self.log_pattern, logs.keys())
        return OrderedDict(
            [(key, logs[key]) for key in sorted(keys) if key in logs] + list(filter(None, extra))
        )

    def on_train_end(self, logs=None):
        super(TQES, self).on_train_end(logs)
        self.tqE.close()


class AUCOneVsAll(tf.keras.metrics.AUC):
    def __init__(self, one=0, *args, **kwargs):
        self.one = one
        super().__init__(*args, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        from tensorflow.python.framework import ops
        from tensorflow.python.framework import tensor_shape
        from tensorflow.python.keras.utils import metrics_utils
        from tensorflow.python.ops import check_ops

        deps = []
        if not self._built:
            self._build(tensor_shape.TensorShape(y_pred.shape))

        if self.multi_label or (self.label_weights is not None):
            # y_true should have shape (number of examples, number of labels).
            shapes = [(y_true, ("N", "L"))]
            if self.multi_label:
                # TP, TN, FP, and FN should all have shape
                # (number of thresholds, number of labels).
                shapes.extend(
                    [
                        (self.true_positives, ("T", "L")),
                        (self.true_negatives, ("T", "L")),
                        (self.false_positives, ("T", "L")),
                        (self.false_negatives, ("T", "L")),
                    ]
                )
            if self.label_weights is not None:
                # label_weights should be of length equal to the number of labels.
                shapes.append((self.label_weights, ("L",)))
            deps = [check_ops.assert_shapes(shapes, message="Number of labels is not consistent.")]

        # Only forward label_weights to update_confusion_matrix_variables when
        # multi_label is False. Otherwise the averaging of individual label AUCs is
        # handled in AUC.result
        label_weights = None if self.multi_label else self.label_weights
        with ops.control_dependencies(deps):
            return metrics_utils.update_confusion_matrix_variables(
                {
                    metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
                    metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
                    metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
                    metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
                },
                y_true[:, self.one],
                y_pred[:, self.one],
                self.thresholds,
                sample_weight=sample_weight,
                multi_label=self.multi_label,
                label_weights=label_weights,
            )
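

# Usage sketch (illustrative): track the one-vs-all AUC of class index 2 in a
# multiclass softmax output.
#
#     model.compile(..., metrics=[AUCOneVsAll(one=2, name="auc_class2")])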


# Layer Definitions


class WhereEquals(tf.keras.layers.Layer):
    def __init__(self, value=0):
        super(WhereEquals, self).__init__()
        self.value = value

    def call(self, inp):
        return tf.where(inp[:, 0] == self.value)


class DenseLayer(tf.keras.layers.Layer):
    """
    The DenseLayer object is an extended implementation of the tf.keras.layers.Dense.
    It features:
        * l2 regu
        * the weights (the real layer)
        * batch norm
        * activation function
        * dynamically chosen dropout

    Parameters
    ----------
    nodes : int
        The number of nodes.
    activation : str or one of tf.keras.activations
        The used activation function.
    dropout : float
        The used dropout ration.
        If "selu" is used as activation function, dropout becomes AlphaDropout.
    l2 : float
        The used factor of l2 regu.
    batch_norm : bool
        Wether to use dropout or not.
        If batch_norm is used, dropout is forced off.

    """

    def __init__(self, nodes=0, activation=None, dropout=0.0, l2=0, batch_norm=False):
        super().__init__(name="DenseLayer")
        self.nodes = nodes
        self.activation = activation
        self.dropout = dropout
        self.l2 = l2
        self.batch_norm = batch_norm

    def build(self, input_shape):
        parts = []

        l2 = tf.keras.regularizers.l2(self.l2)
        weights = tf.keras.layers.Dense(self.nodes, kernel_regularizer=l2)
        parts.append(weights)

        if self.batch_norm:
            self.dropout = 0.0
            bn = tf.keras.layers.BatchNormalization()
            parts.append(bn)

        act = tf.keras.layers.Activation(self.activation)
        parts.append(act)

        if self.activation == "selu":
            dropout = tf.keras.layers.AlphaDropout(self.dropout)
        else:
            dropout = tf.keras.layers.Dropout(self.dropout)

        parts.append(dropout)
        self.parts = parts

    def call(self, input_tensor, training=False):
        x = input_tensor
        for part in self.parts:
            x = part(x, training=training)
        return x

    def get_config(self):
        return {
            "nodes": self.nodes,
            "activation": self.activation,
            "dropout": self.dropout,
            "l2": self.l2,
            "batch_norm": self.batch_norm,
        }
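

# Illustrative sketch: a single Dense -> BatchNorm -> ELU block (dropout is
# forced off because batch_norm is set).
#
#     block = DenseLayer(nodes=128, activation="elu", dropout=0.1, l2=1e-4,
#                        batch_norm=True)
#     out = block(inp)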


class ResNetBlock(tf.keras.layers.Layer):
    """
    The ResNetBlock object is an implementation of one residual DNN block.

    Parameters
    ----------
    jump : int
        The number of layers to bypass.
    kwargs :
        Arguments for DenseLayer.

    """

    def __init__(self, jump=2, **kwargs):
        super().__init__(name="ResNetBlock")
        self.jump = jump
        layers = []
        for i in range(self.jump - 1):
            layers.append(DenseLayer(**kwargs))

        activation = kwargs.pop("activation", None)
        layers.append(DenseLayer(**kwargs))
        self.layers = layers
        self.out_activation = tf.keras.layers.Activation(activation)

    def call(self, input_tensor, training=False):
        x = input_tensor
        for layer in self.layers:
            x = layer(x, training=training)
        x += input_tensor
        x = self.out_activation(x)
        return x

    def get_config(self):
        return {"jump": self.jump}


class FullyConnected(tf.keras.layers.Layer):
    """
    The FullyConnected object is an implementation of a fully connected DNN.

    Parameters
    ----------
    layers : int
        The number of layers.
    kwargs :
        Arguments for DenseLayer.

    """

    def __init__(self, layers=0, sub_kwargs=None, **kwargs):
        super().__init__(name="FullyConnected")
        self.layers = layers
        self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs

    def build(self, input_shape):
        network_layers = []
        for layer in range(self.layers):
            network_layers.append(DenseLayer(**self.sub_kwargs))
        self.network_layers = network_layers

    def call(self, input_tensor, training=False):
        x = input_tensor
        for layer in self.network_layers:
            x = layer(x, training=training)
        return x

    def get_config(self):
        return {"layers": self.layers, "sub_kwargs": self.sub_kwargs}
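

# Illustrative sketch: a four-layer MLP in which every layer shares the same
# DenseLayer arguments.
#
#     net = FullyConnected(layers=4, nodes=256, activation="selu", dropout=0.1)
#     out = net(inp)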


class ResNet(tf.keras.layers.Layer):
    """
    The ResNet object is an implementation of a Residual Neural Network.

    Parameters
    ----------
    layers : int
        The number of residual blocks.
    kwargs :
        Arguments for ResNetBlock.

    """

    def __init__(self, layers=1, **kwargs):
        super().__init__(name="ResNet")
        self.layers = layers
        self.kwargs = kwargs

    def build(self, input_shape):
        _layers = []
        for i in range(self.layers):
            _layers.append(ResNetBlock(**self.kwargs))
        self._layers = _layers

    def call(self, input_tensor, training=False):
        x = input_tensor
        for layer in self._layers:
            x = layer(x, training=training)
        return x

    def get_config(self):
        return {"layers": self.layers}


class SplitHighLow(tf.keras.layers.Layer):
    """Splits inputs into the four-vector (low-level) components and the remaining (high-level) features."""

    def call(self, inputs):
        return inputs[:, :, :4], inputs[:, :, 4:]


class LL(tf.keras.layers.Layer):
    """Keeps only the four-vector (low-level) components."""

    def call(self, inputs):
        return inputs[:, :, :4]


class LBNLayer(tf.keras.layers.Layer):
    """
    Custom implementation of the LBNLayer with automatic cropping to
    low-level variables before and batchnorm after the LBN layer.

    Parameters
    ----------
    args :
        args for original LBN layer
    kwargs :
        kwargs for original LBN layer

    """

    def __init__(self, *args, **kwargs):
        super().__init__(name="LBNLayer")
        _kwargs = kwargs.copy()
        if "features" not in _kwargs:
            _kwargs["features"] = self.all_features
        self._args = args
        self._kwargs = _kwargs

    def build(self, input_shape):
        self.concat = tf.keras.layers.Concatenate(axis=-2)
        lbn_inp_shape = (sum(_[-2] for _ in input_shape), 4)
        self.lbn_layer = OriginalLBNLayer(lbn_inp_shape, *self._args, **self._kwargs)
        self.batch_norm = tf.keras.layers.BatchNormalization()

    @property
    def all_features(self):
        return [
            "E",
            "px",
            "py",
            "pz",
            "pt",
            "p",
            "m",
            "phi",
            "eta",
            "beta",
            "gamma",
            "pair_cos",
            "pair_dr",
            "pair_ds",
            "pair_dy",
        ]

    def call(self, input_tensors, training=False):
        ll = [SplitHighLow()(tensor)[0] for tensor in input_tensors]
        ll = self.concat(ll)
        feats = self.lbn_layer(ll)
        feats = self.batch_norm(feats)
        return feats


def feature_importance_grad(model, x=None, **kwargs):
    inp = [tf.Variable(v) for v in x]
    with tf.GradientTape() as tape:
        pred = model(inp, training=False)
        ix = np.argsort(pred, axis=-1)[:, -1]
        decision = tf.gather(pred, ix, batch_dims=1)

    gradients = tape.gradient(decision, inp)  # gradients for decision nodes
    normed_gradients = [_g * _x for (_g, _x) in zip(gradients, x)]  # scaled by the input values

    mean_gradients = np.concatenate(
        [np.abs(g.numpy()).mean(axis=0).flatten() for g in normed_gradients]
    )
    return mean_gradients / mean_gradients.max()


def feature_importance_perm(model, x=None, **kwargs):
    inp_list = list(x)
    feat = 1  # index of the accuracy metric in the model.evaluate output
    ref = model.evaluate(x=x, **kwargs)[feat]
    accs = []
    for index, tensor in enumerate(inp_list):
        s = tensor.shape
        for i in range(np.prod(s[1:])):
            arr = tensor.reshape((-1, np.prod(s[1:])))

            slice_before = arr[:, :i]
            slice_shuffled = np.random.permutation(arr[:, i : i + 1])
            slice_after = arr[:, i + 1 :]
            arr_shuffled = np.concatenate([slice_before, slice_shuffled, slice_after], axis=-1)

            arr_shuffled_reshaped = arr_shuffled.reshape(s)

            valid = inp_list.copy()
            valid[index] = arr_shuffled_reshaped
            accs.append(model.evaluate(x=valid, **kwargs)[feat])
    return ref / np.array(accs)


def feature_importance(*args, method="grad", columns=[], **kwargs):
    if method == "grad":
        importance = feature_importance_grad(*args, **kwargs)
    elif method == "perm":
        importance = feature_importance_perm(*args, **kwargs)
    else:
        raise NotImplementedError("Feature importance method not implemented")
    return {
        k: v
        for k, v in sorted(
            dict(zip(columns, importance.astype(float))).items(),
            key=lambda item: item[1],
        )
    }
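

# Usage sketch (model, inputs, and column names assumed): rank input columns
# by gradient-based importance; the result maps column name to normalized
# importance, sorted in ascending order.
#
#     imp = feature_importance(model, x=[x1, x2], method="grad",
#                              columns=cols1 + cols2)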


@tf.function
def grouped_cross_entropy_t(
    labels,
    predictions,
    sample_weight=None,
    group_ids=None,
    focal_gamma=None,
    class_weight=None,
    epsilon=1e-7,
):
    assert group_ids is not None
    # get true-negative component
    predictions = tf.clip_by_value(predictions, epsilon, 1 - epsilon)
    tn = labels * tf.math.log(predictions)
    # focal loss?
    if focal_gamma is not None:
        tn *= (1 - predictions) ** focal_gamma
    # convert into loss
    losses = -tn
    # apply class weights
    if class_weight is not None:
        losses *= class_weight
    # apply sample weights
    if sample_weight is not None:
        losses *= sample_weight[:, tf.newaxis]
    # create grouped labels and predictions
    labels_grouped = []
    for _, ids in group_ids:
        labels_grouped.append(
            tf.reduce_sum(tf.gather(labels, ids, axis=-1), axis=-1, keepdims=True)
        )
    labels_grouped = (
        tf.concat(labels_grouped, axis=-1) if len(labels_grouped) > 1 else labels_grouped[0]
    )
    predictions_grouped = []
    for _, ids in group_ids:
        predictions_grouped.append(
            tf.reduce_sum(tf.gather(predictions, ids, axis=-1), axis=-1, keepdims=True)
        )
    predictions_grouped = (
        tf.concat(predictions_grouped, axis=-1)
        if len(predictions_grouped) > 1
        else predictions_grouped[0]
    )

    predictions_grouped = tf.clip_by_value(predictions_grouped, epsilon, 1 - epsilon)
    # grouped true-negative component
    tn_grouped = labels_grouped * tf.math.log(predictions_grouped)
    # focal loss?
    if focal_gamma is not None:
        tn_grouped *= (1 - predictions_grouped) ** focal_gamma
    # convert into loss and apply group weights
    group_weights = tf.constant([w for w, _ in group_ids], tf.float32)
    losses_grouped = -tn_grouped * group_weights
    # apply sample weights
    if sample_weight is not None:
        losses_grouped *= sample_weight[:, tf.newaxis]
    # combine losses
    loss = tf.reduce_mean(
        0.5 * (tf.reduce_sum(losses, axis=-1) + tf.reduce_sum(losses_grouped, axis=-1))
    )
    return loss


# Custom Loss Functions


class GroupedXEnt(tf.keras.losses.Loss):
    def __init__(
        self,
        group_ids=None,
        focal_gamma=None,
        class_weight=None,
        epsilon=1e-7,
        *args,
        **kwargs,
    ):
        super(GroupedXEnt, self).__init__(*args, **kwargs)
        self.group_ids = group_ids
        self.focal_gamma = focal_gamma
        self.class_weight = class_weight
        self.epsilon = epsilon

    def call(self, y_true, y_pred, sample_weight=None):
        return grouped_cross_entropy_t(
            y_true,
            y_pred,
            sample_weight=sample_weight,
            group_ids=self.group_ids,
            focal_gamma=self.focal_gamma,
            class_weight=self.class_weight,
            epsilon=self.epsilon,
        )
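

# Usage sketch (illustrative): group_ids pairs a group weight with the class
# indices forming that group, e.g. two groups over five classes.
#
#     loss = GroupedXEnt(group_ids=[(1.0, [0, 1]), (1.0, [2, 3, 4])])
#     model.compile(loss=loss, optimizer="adam")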


def write_summary(model, target):
    """write tf keras model summary to law target"""
    summary = []
    model.summary(print_fn=lambda x: summary.append(x))
    summary = "\n".join(summary)
    target.dump(summary)