import itertools
from functools import cached_property
import gc
from collections import OrderedDict, defaultdict
import fnmatch
import math
import datetime

import numpy as np
import tensorflow as tf
from tqdm import tqdm
from inspect import getargspec
from warnings import warn
import re
from tensorflow.python.keras.callbacks import make_logs
from tensorflow.python.keras.backend import track_variable
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from operator import itemgetter
from lbn import LBNLayer as OriginalLBNLayer
from matplotlib import pyplot as plt

from .evil import pin
from .data import SKDict, DSS
from .plotting import (
    figure_to_image,
    figure_confusion_matrix,
    figure_activations,
    figure_node_activations,
    figure_roc_curve,
    figure_y,
    figure_weights,
    figure_correlation,
    figure_inputs,
    figure_dict,
    figure_lbn_weights,
)
from .numpy import mean_excl, std_excl

# various helper functions


def kVar(*args, **kwargs):
    """produce a keras-tracked tf.Variable from all given parameters"""
    var = tf.Variable(*args, **kwargs)
    track_variable(var)
    return var


def kOpt(opt, **kwargs):
    """instanciate a keras Optimizer with all applicable **kwargs"""
    if not callable(opt):
        opt = getattr(tf.keras.optimizers, opt)
    assert issubclass(opt, tf.keras.optimizers.Optimizer)
    args = set(getargspec(opt.__init__).args) - {"self"}
    return opt(**{k: v for k, v in kwargs.items() if k in args})
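# Example (sketch): kwargs that the chosen optimizer does not accept are silently dropped,
# so a common config dict can be reused across optimizers:
#   opt = kOpt("SGD", learning_rate=1e-2, momentum=0.9, beta_1=0.99)  # beta_1 is ignored by SGD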


def kInput(ref, **kwargs):
    """produce a keras.Input with shape & dtype according to ref"""
    kwargs.setdefault("shape", ref.shape[1:])
    kwargs.setdefault("dtype", ref.dtype)
    return tf.keras.Input(**kwargs)


def keras_register_custom_object(obj):
    """decorator for globally registering a custom object with keras"""
    tf.keras.utils.get_custom_objects()[obj.__name__] = obj
    return obj
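# Example (sketch): register a custom layer so that model deserialization
# (e.g. tf.keras.models.load_model) can resolve it by class name:
#
#   @keras_register_custom_object
#   class TimesTwo(tf.keras.layers.Layer):
#       def call(self, x):
#           return 2.0 * x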


class KFeed(object):
    def __init__(self, x, y, w=None, train="train", valid="valid", balance=lambda x: x):
        pin(locals())

    @property
    def xyw(self):
        return tuple(v if isinstance(v, tuple) else (v,) for v in (self.x, self.y, self.w))

    def get(self, src):
        return tuple(tuple(src[k] for k in v) for v in self.xyw)

    @property
    def all(self):
        return set().union(*self.xyw)

    def balance(self, src):
        return src

    def kfeed(self, src, **kwargs):
        src = src.only(*self.all)
        bal = self.balance(src)
        return dict(
            zip(["x", "y", "sample_weight"], self.get(bal[self.train])),
            validation_data=self.get(bal[self.valid]),
            **kwargs,
        )

    def balance_weights(self, weights):
        if isinstance(weights, SKDict):
            sums = weights.map(np.sum)
            if len(sums) > 1:
                ref = np.mean(list(sums.values()))
                weights = weights.__class__({k: weights[k] * (ref / s) for k, s in sums.items()})
        return weights

    def compact(self, src, keep=()):
        ret = src.only(*self.all, *keep)
        val = ret[self.valid]
        assert not isinstance(self.w, tuple)
        if isinstance(val[self.w], SKDict):
            val[self.w] = self.balance_weights(val[self.w])
            val = val.fuse(*val[self.w].keys())
        ret = ret.only(self.train)
        ret[self.valid] = val
        return ret

    def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, **kwargs):
        """
        Creates a generator for tf.keras' model.fit().
        Requires that the mean weights per process are equal:
            dss["weight"] = dss["weight"].map(lambda x: x / np.mean(x))
        """
        src = self.compact(src)
        val = src[self.valid]
        val.blen  # asserts that all first dimensions have equal length
        return dict(
            dict(
                zip(
                    ("x", "steps_per_epoch"),
                    self.gensteps(src[self.train], batch_size, rng=rng, auto_steps=auto_steps),
                )
            ),
            validation_data=self.get(val),
            workers=0,
            **kwargs,
        )

    def gensteps(self, src, batch_size, rng=np.random, auto_steps=np.max):
        keys = src.mkeys(self.all)
        gen = (
            (
                self.get(DSS.zip(*parts).map(np.concatenate))
                for parts in zip(
                    *[
                        src[k].batch_generator(
                            batch_size // len(keys),
                            rng=np.random.RandomState(rng.randint(1 << 31, size=20)),
                        )
                        for k in keys
                    ]
                )
            )
            if len(keys) > 1
            else map(self.get, src[list(keys)[0]].batch_generator(batch_size, rng=rng))
        )
        gs = float(batch_size // len(keys))
        steps = int(auto_steps([src[k].blen / gs for k in keys])) or None
        return gen, steps

    def generator(self, *args, **kwargs):
        assert "auto_steps" not in kwargs
        return self.gensteps(*args, **kwargs)[0]

    def getShapes(self, src):
        return tuple(tuple(src[k].shape for k in g) for g in self.xyw)

    def getShapesK(self, src):
        strip_first_dim = lambda x: x[1:] if x[0] is None else x
        shapes = self.getShapes(src)
        return tuple(tuple(strip_first_dim(a) for a in b) for b in shapes)

    def mkInputs(self, src, **kwargs):
        return tuple(
            tf.keras.Input(shape=ref.shape[1:], dtype=ref.dtype, **kwargs)
            for ref in (src[x] for x in self.xyw[0])
        )
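
# Example (sketch, with hypothetical column names): describe which entries of a DSS hold the
# inputs, labels and weights, then expand them into keyword arguments for model.fit:
#   feed = KFeed(x=("jets", "met"), y="label", w="weight")
#   model.fit(**feed.kfeed(dss, epochs=10, batch_size=1024))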


def Normal(ref, const=None, ignore_zeros=False, name=None, **kwargs):
    """
    Normalizing layer according to ref.
    If given, variables at the indices const will not be normalized.
    """
    if ignore_zeros:
        mean = mean_excl(ref, value=0, **kwargs)
        mean = np.nan_to_num(mean)
        std = std_excl(ref, value=0, **kwargs)
        std = np.nan_to_num(std)
    else:
        mean = ref.mean(**kwargs)
        std = ref.std(**kwargs)
    if const is not None:
        mean[const] = 0
        std[const] = 1

    std = np.where(std == 0, 1, std)
    mul = 1.0 / std
    add = -mean / std
    return tf.keras.layers.Lambda((lambda x: (x * mul) + add), name=name)
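
# Example (sketch): build the layer from a reference array so inputs are standardized
# feature-wise (zero mean, unit variance) inside the graph:
#   ref = np.random.normal(3.0, 2.0, size=(1000, 5)).astype(np.float32)
#   norm = Normal(ref, axis=0)
#   standardized = norm(ref[:10])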


def Onehot(index, n, name=None):
    """
    One-hot encodes the variable referred to by index.
    n is the number of distinct categories it can take.
    """

    def to_onehot(x):
        # Append a row of zeros to the identity so indices equal to n (one past the encoded range) map to all zeros
        eye = tf.concat((tf.eye(n), tf.zeros((1, n))), axis=0)
        return tf.concat(
            (
                x[..., :index],
                tf.gather(eye, tf.cast(x[..., index], tf.int64)),
                x[..., (index + 1) :],
            ),
            axis=-1,
        )

    return tf.keras.layers.Lambda(to_onehot, name=name)
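
# Example (sketch): one-hot encode the categorical column at index 2 with n=3 categories;
# a value equal to n maps to an all-zero row:
#   x = tf.constant([[0.5, 1.2, 2.0, 7.0]])   # column 2 holds a category in {0, 1, 2} (or 3 for "missing")
#   Onehot(index=2, n=3)(x)                   # -> [[0.5, 1.2, 0., 0., 1., 7.0]]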


class Moment(tf.keras.metrics.Mean):
    def __init__(self, order, label=False, **kwargs):
        """Metric calculating the order-th moment"""
        assert order == int(order)
        assert label == bool(label)
        kwargs.setdefault("name", "%smom%d" % ("l" if label else "", order))
        super(Moment, self).__init__(**kwargs)
        self.order = order
        self.label = label

    def update_state(self, y_true, y_pred, sample_weight=None):
        y = y_true if self.label else y_pred
        y = tf.keras.backend.cast(y, self._dtype)

        if self.order == 0:
            y = tf.keras.backend.ones_like(y)
        elif self.order == 1:
            pass
        elif self.order == 2:
            y = tf.keras.backend.square(y)
        else:
            y = tf.keras.backend.pow(y, self.order)

        return super(Moment, self).update_state(y, sample_weight=sample_weight)

    def get_config(self):
        return dict(super(Moment, self).get_config(), order=self.order, label=self.label)


# various callbacks


def _patfilter(pattern, items):
    if isinstance(pattern, (list, tuple)):
        pattern = "|".join(map(fnmatch.translate, pattern))
    return filter(re.compile(pattern).search, items)
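
# Example (sketch): a list/tuple is treated as fnmatch-style patterns, a plain string as a regex:
#   list(_patfilter(["val_*", "*acc*"], ["loss", "val_loss", "accuracy"]))  # -> ["val_loss", "accuracy"]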


try:
    import nvidia_smi
except ImportError:
    pass
else:

    class GPUStats(tf.keras.callbacks.Callback):
        """
        conda: conda install -c fastai nvidia-ml-py3
        pip  : pip install nvidia-ml-py3
        """

        def __init__(self, idx=0):
            nvidia_smi.nvmlInit()
            self.handle = nvidia_smi.nvmlDeviceGetHandleByIndex(idx)

        def on_x_end(self, x, logs=None):
            if logs is None:
                return
            self.mem = nvidia_smi.nvmlDeviceGetMemoryInfo(self.handle)
            self.res = nvidia_smi.nvmlDeviceGetUtilizationRates(self.handle)
            logs["GPU-Usage [%]"] = self.res.gpu
            logs["GPU-vRAM [%]"] = 100 * self.mem.used / self.mem.total
            logs["GPU-vRAM [MiB]"] = self.mem.used / (1024 ** 2)

        on_batch_end = on_x_end


class Moment2Std(tf.keras.callbacks.Callback):
    def __init__(self, mom1="mom1", mom2="mom2", std="std"):
        assert mom1 and mom2 and std
        pin(locals())

    def on_x_end(self, x, logs=None):
        if logs is None:
            return
        for key1, mom1 in list(logs.items()):
            if not key1.endswith(self.mom1):
                continue
            prefix = key1[: -len(self.mom1)]
            mom2 = logs.get(prefix + self.mom2, None)
            if mom2 is not None:
                logs[prefix + self.std] = (mom2 - mom1 ** 2) ** 0.5

    on_batch_end = on_x_end
    on_epoch_end = on_x_end
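
# Example (sketch): Moment(1) and Moment(2) track the running means of y and y**2 of the
# predictions (of the labels with label=True); Moment2Std then adds
# std = sqrt(mom2 - mom1**2) to the logs under "<prefix>std":
#   model.compile(..., metrics=[Moment(1), Moment(2)])
#   model.fit(..., callbacks=[Moment2Std()])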


class LogRewrite(tf.keras.callbacks.Callback):
    def __init__(self, *rewrites, **kwargs):
        self.collapse_weighted = kwargs.pop("collapse_weighted", False)
        assert not kwargs
        self.rewrites = list(rewrites)
        if self.collapse_weighted:
            self.rewrites.append(lambda s: s.replace("weighted_", ""))

    def on_x_end(self, x, logs=None):
        if logs is None:
            return
        shadow = []
        for key, val in logs.items():
            short = key
            for rewrite in self.rewrites:
                short = rewrite(short)
            if short == key:
                continue
            if short in logs:
                shadow.append(short)
            logs[short] = val
            del logs[key]
        if shadow:
            warn("%r: shadows the following keys: %s" % (",".join(shadow)))

    on_batch_end = on_x_end
    on_epoch_end = on_x_end


class LogTransforms(tf.keras.callbacks.Callback):
    _prefixesRe = re.compile("^(val_|)(.+)")

    def __init__(self, *funcs, **kwargs):
        self.transforms = []
        for func in funcs:
            self.add_transform(func)
        for name, func in kwargs.items():
            self.add_transform(func, name)

    def add_transform(self, func, name=None):
        if name is None:
            name = func.__name__
        assert re.match(r"[a-zA-Z]\w*$", name)
        argspec = getargspec(func)
        assert argspec.keywords or argspec.args
        self.transforms.append((name, func, True if argspec.keywords else set(argspec.args)))

    def on_x_end(self, x, logs=None):
        if logs is None:
            return
        plogs = {}
        for key, value in logs.items():
            prefix, suffix = self._prefixesRe.match(key).groups()
            plogs.setdefault(prefix, {})[suffix] = value
        for name, func, args in self.transforms:
            for prefix, log in plogs.items():
                if args is True:
                    val = func(**log)
                else:
                    missing = args.difference(log.keys())
                    if missing:
                        msg = "%r: %s=%r needs unavailable log values: %s" % (
                            self,
                            name,
                            func,
                            ", ".join(missing),
                        )
                        if prefix:
                            warn(msg)
                            continue
                        else:
                            raise RuntimeError(msg)
                    val = func(**{key: val for key, val in log.items() if key in args})
                logs[prefix + name] = log[name] = val

    on_batch_end = on_x_end
    on_epoch_end = on_x_end
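
# Example (sketch, assuming "loss" and "accuracy" appear in the logs): derive a new log entry
# per prefix ("" and "val_") from existing values; the transform's argument names select
# which log entries it receives:
#   LogTransforms(loss_per_acc=lambda loss, accuracy: loss / max(accuracy, 1e-6))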


class LogEMA(tf.keras.callbacks.Callback):
    def __init__(self, keys=[], pattern=[], pairs={}, ema=0, ricu=0, ema_fmt="%s_ema"):
        assert ema < 1 and ricu < 1 and "%s" in ema_fmt
        if ema < 0:
            ema = 1 - 10 ** ema
            assert ema < 1
        if ricu < 0:
            ricu = -1.0 / ricu
        keys = {k: None for k in keys}
        keys.update(pairs)
        pin(locals(), pairs)
        self.reset()

    def reset(self):
        self.ema_val = {}
        self.run_num = defaultdict(int)

    def on_epoch_end(self, x, logs=None):
        assert logs
        keys = {key: None for key in _patfilter(self.pattern, logs.keys())}
        keys.update(self.keys)
        for key, out in keys.items():
            if not out:
                out = self.ema_fmt % key
            assert key in logs
            assert out not in keys
            assert out not in logs
            val = logs[key]
            if key in self.ema_val:
                delta = self.ema_val[key] - val
                dsign = int(np.sign(delta))
                self.run_num[key] += dsign
                if dsign * np.sign(self.run_num[key]) < -self.ricu < 0:
                    self.run_num[key] *= self.ricu ** 0.5
                val += delta * self.ema ** (1 + (self.ricu * self.run_num[key]) ** 4)
            logs[out] = self.ema_val[key] = val


class CustomValidation(tf.keras.callbacks.Callback):
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def on_epoch_end(self, x, logs=None):
        if logs is None:
            return
        res = self.model.evaluate(**self.kwargs)
        if not isinstance(res, list):
            res = [res]
        logs.update(make_logs(self.model, logs, res, mode=ModeKeys.TEST, prefix="val_"))


class TFSummaryCallback(tf.keras.callbacks.Callback):
    def __init__(self, logdir=None, **kwargs):
        self.writer = tf.summary.create_file_writer(logdir)


class PlottingCallback:
    def __init__(self, path=None, **kwargs):
        pin(locals())
        super().__init__(**kwargs)

    def draw(self, figure, dict, key):
        if self.path is not None:
            plt.savefig(f"{self.path}/{key}.pdf")
        image = figure_to_image(figure)
        dict.update({key: image})


class PlotMulticlass(PlottingCallback, TFSummaryCallback):
    def __init__(
        self,
        x,
        y,
        sample_weight=(None,),
        class_names=["signal", "background"],
        to_file=False,
        columns=None,
        plot_inputs=False,
        signalvsbkg=False,
        plot_importance=False,
        plot_activations=False,
        tag="",
        **kwargs,
    ):
        pin(locals())
        self.y = y[0]
        self.sample_weight_flat = sample_weight[0]
        self.comet_experiment = kwargs.get("comet_experiment", None)
        self.freq = kwargs.get("freq", None)
        self.file_extension = kwargs.get("file_extension", "pdf")
        self.verbose = int(kwargs.get("verbose", 1))
        super().__init__(**kwargs)

    def draw(self, figure, dict, key, epoch=None, name=None):
        if self.path is not None:
            image_path = f"{self.path}/{key}.{self.file_extension}"
            if self.verbose:
                print("PlotMulticlass-callback: Plotting {}".format(key))
            plt.savefig(image_path)
            if self.comet_experiment is not None:
                self.comet_experiment.log_image(
                    image_path, name=name, image_format=self.file_extension, step=epoch
                )
        image = figure_to_image(figure)
        dict.update({key: image})

    def on_train_begin(self, logs=None):
        if self.plot_inputs:
            self.make_input_plots()

    def on_epoch_end(self, epoch, logs=None):
        if isinstance(self.freq, (float, int)):
            if epoch % int(self.freq) == 0 and epoch > 0:
                self.make_eval_plots(epoch=epoch, name=str(epoch))

    def on_train_end(self, logs=None):
        self.make_eval_plots()

    def make_input_plots(self):
        inps = self.x
        if not isinstance(inps, (list, tuple)):
            inps = [inps]
        imgs = {}
        if self.sample_weight is not None:
            for part, inp in zip(self.columns.keys(), inps):
                self.draw(
                    figure_inputs(
                        inp,
                        self.y,
                        sample_weight=self.sample_weight_flat,
                        columns=self.columns[part],
                        class_names=self.class_names,
                        signalvsbkg=self.signalvsbkg,
                    ),
                    imgs,
                    f"inp_x_{part}",
                    name="Inputs",
                )

            self.draw(
                figure_weights(self.sample_weight_flat, self.y, class_names=self.class_names),
                imgs,
                "inp_weights",
            )
            self.draw(
                figure_correlation(
                    np.concatenate(
                        [inp.reshape(-1, np.prod(inp.shape[1:])) for inp in inps], axis=1
                    ),
                    label=list(itertools.chain(*self.columns.values())),
                ),
                imgs,
                "inp_correlation",
            )
        self.draw(
            figure_y(self.y, class_names=self.class_names),
            imgs,
            "inp_y",
        )

        for name, img in imgs.items():
            with self.writer.as_default():
                tf.summary.image(f"{name}{self.tag}", img, step=0)

    def clear_figure(self, fig):
        fig.clf()
        del fig
        plt.close("all")
        gc.collect()

    @cached_property
    def prediction(self):
        return self.model.predict(self.x, batch_size=4096)

    def make_eval_plots(self, epoch=0, name=""):
        imgs = {}

        fig = figure_roc_curve(
            self.y,
            self.prediction,
            class_names=self.class_names,
            sample_weight=self.sample_weight_flat,
        )
        self.draw(fig, imgs, "roc_curve{}".format(name), epoch=epoch, name="ROC")
        self.clear_figure(fig)

        fig = figure_confusion_matrix(
            self.y,
            self.prediction,
            class_names=self.class_names,
            sample_weight=self.sample_weight_flat,
            normalize="true",
        )
        self.draw(
            fig,
            imgs,
            "confusion_matrix_true{}".format(name),
            epoch=epoch,
            name="Confusion-Matrix-True",
        )
        self.clear_figure(fig)

        fig = figure_confusion_matrix(
            self.y,
            self.prediction,
            class_names=self.class_names,
            sample_weight=self.sample_weight_flat,
            normalize="pred",
        )
        self.draw(
            fig,
            imgs,
            "confusion_matrix_pred{}".format(name),
            epoch=epoch,
            name="Confusion-Matrix-Pred",
        )
        self.clear_figure(fig)

        if self.plot_activations:
            fig = figure_node_activations(
                self.prediction,
                self.y,
                class_names=self.class_names,
                disjoint=True,
                sample_weight=self.sample_weight_flat,
            )
            self.draw(
                fig,
                imgs,
                "node_activation_disjoint{}".format(name),
                epoch=epoch,
                name="Node-Activation",
            )
            self.clear_figure(fig)

        if self.plot_importance:
            for method in ["perm"]:
                importance = feature_importance(
                    self.model,
                    x=self.x,
                    y=self.y,
                    sample_weight=self.sample_weight,
                    method=method,
                    columns=[c for col in self.columns.values() for c in col],
                    batch_size=4096,
                )
                kinematics = ["px", "py", "pz", "x", "y", "z", "energy"]
                fig = figure_dict(
                    {
                        key: importance[key]
                        for key in importance.keys()
                        if any(key.endswith(f"_{p}") for p in kinematics)
                    },
                    xlabel="Importance [AU]",
                )
                self.draw(
                    fig,
                    imgs,
                    f"importance_{method}_kin{name}",
                    epoch=epoch,
                    name="Kinematic Importance",
                )
                self.clear_figure(fig)
                fig = figure_dict(
                    {
                        key: importance[key]
                        for key in importance.keys()
                        if not any(key.endswith(p) for p in kinematics)
                    },
                    xlabel="Importance [AU]",
                )
                self.draw(
                    fig,
                    imgs,
                    f"importance_{method}_nonkin{name}",
                    epoch=epoch,
                    name="Non-Kinematic Importance",
                )
                self.clear_figure(fig)

        for name, img in imgs.items():
            with self.writer.as_default():
                tf.summary.image(f"{name}{self.tag}", img, step=epoch)
        del imgs


class PlotMulticlassEval(PlotMulticlass):
    def on_test_end(self, logs=None):
        self.make_input_plots()
        self.make_eval_plots(0)


class PlotAssignment(PlotMulticlass):
    def __init__(
        self,
        x,
        y,
        **kwargs,
    ):
        super().__init__(x, y, **kwargs)
        n_objects = self.y.shape[1]
        n_flavours = 3
        self.y = self.y.reshape(-1, n_flavours)
        self.sample_weight_flat = (self.y * np.array([1 / 0.81, 1 / 0.15, 1 / 0.037])).sum(axis=-1)

    @cached_property
    def prediction(self):
        return self.model.predict(self.x, batch_size=4096).reshape(-1, 3)


class PlotLBN(PlottingCallback, TFSummaryCallback):
    def __init__(self, lbn_layer=None, inp_particle_names=None, *args, **kwargs):
        self.lbn = lbn_layer
        self.inp_particle_names = inp_particle_names
        super().__init__(*args, **kwargs)

    def on_train_end(self, logs=None):
        self.make_plots()

    def make_plots(self, epoch=0):
        imgs = {}
        pkwargs = {"inp_particle_names": self.inp_particle_names}

        fig = figure_lbn_weights(
            self.lbn.weights[0].numpy(), name="particles", cmap="OrRd", **pkwargs
        )
        self.draw(fig, imgs, "lbn_particles")

        fig = figure_lbn_weights(
            self.lbn.weights[1].numpy(), name="restframes", cmap="YlGn", **pkwargs
        )
        self.draw(fig, imgs, "lbn_restframes")

        pkwargs["norm"] = True
        fig = figure_lbn_weights(
            self.lbn.weights[0].numpy(), name="particles", cmap="OrRd", **pkwargs
        )
        self.draw(fig, imgs, "lbn_particles_normed")

        fig = figure_lbn_weights(
            self.lbn.weights[1].numpy(), name="restframes", cmap="YlGn", **pkwargs
        )
        self.draw(fig, imgs, "lbn_restframes_normed")

        for name, img in imgs.items():
            with self.writer.as_default():
                tf.summary.image(name, img, step=epoch)


class PlotLBNEval(PlotLBN):
    def on_test_end(self, logs=None):
        self.make_plots()


class ModelLH(tf.keras.Model):
    def __init__(self, *args, **kwargs):
        self.loss_hook = kwargs.pop("loss_hook", None)
        super(ModelLH, self).__init__(*args, **kwargs)

    def _update_sample_weight_modes(self, sample_weights=None):
        if not self._is_compiled:
            return
        if sample_weights and any([s is not None for s in sample_weights]):
            pass
            # don't default sample_weight_mode to "samplewise", it prevents proper function caching
            # for endpoint in self._training_endpoints:
            #     endpoint.sample_weight_mode = (
            #         endpoint.sample_weight_mode or 'samplewise')
        else:
            for endpoint in self._training_endpoints:
                endpoint.sample_weight_mode = None

    def _prepare_total_loss(self, *args, **kwargs):
        orig = [
            (ep, ep.__dict__.copy(), ep.training_target.__dict__.copy())
            for ep in self._training_endpoints
        ]

        self.loss_hook(self._training_endpoints.copy())
        ret = super(ModelLH, self)._prepare_total_loss(*args, **kwargs)

        for ep, ed, td in orig:
            ep.__dict__.update(ed)
            ep.training_target.__dict__.update(td)

        return ret


class GCCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        gc.collect()


class TensorBoard(tf.keras.callbacks.TensorBoard):
    def __init__(self, *args, **kwargs):
        self.writer = kwargs.pop("writer")
        super(TensorBoard, self).__init__(*args, **kwargs)

    def _init_writer(self, model=None):
        pass

    def on_train_end(self, logs=None):
        pass


class CheckpointModel(tf.keras.callbacks.Callback):
    """Gets a dict of targets (checkpoints), if current epoch is in dict, save model to target."""

    def __init__(self, checkpoints=None):
        self.targets = checkpoints.targets

    def on_epoch_end(self, epoch, logs=None):
        if epoch in self.targets:
            target = self.targets[epoch]
            self.model.save(target.path)


class BestTracker(tf.keras.callbacks.Callback):
    def __init__(
        self,
        monitor="val_loss",
        mode="auto",
        min_delta=0,
        min_delta_rel=0,
        baseline=None,
    ):
        pin(locals())
        self.reset()

    @property
    def mode_multiplier(self):
        assert self.mode in ("auto", "min", "max")
        if self.mode == "max" or (self.mode == "auto" and "acc" in self.monitor):
            return -1
        else:
            return 1

    def reset(self):
        if self.baseline is not None:
            self.best = self.baseline
        else:
            self.best = self.mode_multiplier * np.inf

    def update_best(self, logs):
        current = self.get_monitor_value(logs)
        relative = delta = (current - self.best) * self.mode_multiplier
        if np.isfinite(self.best):
            relative /= abs(self.best) or 1.0
        update = delta < -self.min_delta and relative < -self.min_delta_rel
        if update:
            self.best = current
        return update

    def get_monitor_value(self, logs):
        assert logs
        assert self.monitor in logs
        return logs[self.monitor]
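
# Example (sketch): with mode="min", min_delta=0 and min_delta_rel=0.01, a new value only
# counts as an improvement if it undercuts the previous best by at least 1% (relative):
# once the best is 1.0, a value of 0.995 is no update, while 0.98 is.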


class PatientTracker(BestTracker):
    """
    Jank:
        All PatientTrackers can communicate over the jank variable:
        If one PatientTracker triggers, all other PatientTrackers may not trigger for the number of epochs given by their jank.
        In order to participate in the jank mechanism, the jank of a PatientTracker must be finite (e.g. jank -1).
    """

    def __init__(
        self,
        patience=10,
        cooldown=0,
        cooldown0=0,
        override=None,
        jank=np.nan,
        jank_last=-np.inf,
        **kwargs,
    ):
        if override is None:
            del override
        cooldown_counter = cooldown if cooldown0 is True else cooldown0
        pin(locals(), kwargs)
        super(PatientTracker, self).__init__(**kwargs)

    def reset(self):
        super(PatientTracker, self).reset()
        self.wait = 0

    def override(self, logs):
        return None

    def patient_step(self, epoch, logs):
        # update jank
        logs["jank"] = min(logs.get("jank", np.inf), epoch - self.jank_last)
        # test override
        override = self.override(logs)
        if override is not None:
            return override
        # check jank
        if logs["jank"] <= self.jank:
            return
        if 0 < self.cooldown_counter:
            self.cooldown_counter -= 1
        if self.update_best(logs):
            self.wait = 0
            return "best"
        elif self.cooldown_counter <= 0:
            self.wait += 1
            if self.patience <= self.wait:
                self.cooldown_counter = self.cooldown
                return "good"


class ScaleOnPlateau(PatientTracker):
    def __init__(
        self,
        target,
        factor,
        min=None,
        max=None,
        verbose=0,
        log_key=None,
        linearly=False,
        **kwargs,
    ):
        pin(locals(), kwargs)
        super(ScaleOnPlateau, self).__init__(**kwargs)

    def on_train_begin(self, logs=None):
        self.reset()

    @property
    def value(self):
        return tf.keras.backend.get_value(self.target)

    @value.setter
    def value(self, new):
        return tf.keras.backend.set_value(self.target, new)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        cur = self.value
        if self.log_key is not None:
            logs.setdefault(self.log_key, cur)
        if self.patient_step(epoch, logs) == "good":
            if self.linearly:
                new = cur + self.factor
            else:
                new = cur * self.factor
            if self.min is not None:
                new = max(new, self.min)
            if self.max is not None:
                new = min(new, self.max)
            if cur != new:
                self.value = new
                self.reset()
                if np.isfinite(self.jank):
                    self.jank_last = epoch
                if self.verbose > 0:
                    print(
                        "\nEpoch %05d: %s scaling %s to %s."
                        % (
                            epoch + 1,
                            self.__class__.__name__,
                            self.log_key or self.target,
                            new,
                        )
                    )


class ReduceLROnPlateau(ScaleOnPlateau):
    def __init__(self, min_lr=0, **kwargs):
        super(ReduceLROnPlateau, self).__init__(min=min_lr, target=None, log_key="lr", **kwargs)

    @property
    def target(self):
        return self.model.optimizer.lr

    @target.setter
    def target(self, target):
        assert target is None
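
# Example (sketch): halve the optimizer learning rate after 5 epochs without sufficient
# improvement of the monitored value, down to a minimum of 1e-6:
#   ReduceLROnPlateau(factor=0.5, patience=5, min_lr=1e-6, monitor="val_loss")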


class EarlyStopping(PatientTracker):
    def __init__(self, runtime=None, restore_best_weights=False, verbose=0, do_stop=None, **kwargs):
        stop_time = datetime.datetime.now() + datetime.timedelta(**runtime) if runtime else None
        pin(locals(), kwargs)
        super(EarlyStopping, self).__init__(**kwargs)

    def reset(self):
        super(EarlyStopping, self).reset()
        self.best_weights = None
        self.best_epoch = None

    def on_train_begin(self, logs=None):
        self.reset()

    def on_epoch_end(self, epoch, logs=None):
        action = self.patient_step(epoch, logs)
        if action == "best":
            if self.restore_best_weights:
                self.best_weights = self.model.get_weights()
                self.best_epoch = epoch
        elif action == "good" or (
            (self.stop_time is not None) and (datetime.datetime.now() > self.stop_time)
        ):
            self.stopped_epoch = epoch
            self.model.stop_training = True
            if self.restore_best_weights and self.best_weights is not None:
                self.model.set_weights(self.best_weights)
                if self.verbose > 0:
                    print(
                        "EarlyStopping-Callback: Restoring model weights from the end of the best epoch."
                    )

    def on_train_end(self, logs=None):
        hist = self.model.history
        if self.restore_best_weights and self.best_epoch in hist.epoch:
            idx = hist.epoch.index(self.best_epoch)
        else:
            idx = -1
        hist.final_logs = {k: v[idx] for k, v in hist.history.items()}


class TQES(EarlyStopping):
    def __init__(self, log_pattern=None, prog_batch="steps", **kwargs):
        assert prog_batch in (None, "steps", "samples")
        pin(locals(), kwargs)
        super(TQES, self).__init__(**kwargs)

    @property
    def use_steps(self):
        return self.prog_batch == "steps"

    def on_train_begin(self, logs=None):
        super(TQES, self).on_train_begin(logs)
        self.tqE = tqdm(desc=getattr(self.model, "name", None), total=self.params["epochs"])

    def on_epoch_begin(self, epoch, logs=None):
        if self.prog_batch:
            self.tqB = tqdm(
                unit=("batch" if self.use_steps else "sample"),
                total=self.params["steps" if self.use_steps else "samples"],
            )

    def on_batch_end(self, batch, logs=None):
        if self.tqB:
            logs = logs or {}
            self.tqB.set_postfix(self.make_postfix(logs), refresh=False)
            self.tqB.update(
                logs.get("num_steps", 1) * (1 if self.use_steps else logs.get("size", 0))
            )

    def on_epoch_end(self, epoch, logs=None):
        super(TQES, self).on_epoch_end(epoch, logs)
        last = self.get_monitor_value(logs)
        if self.tqB:
            self.tqB.close()
            self.tqB = None
        self.tqE.set_postfix(
            self.make_postfix(
                logs,
                [
                    ("best", self.best),
                    ("conf", self.wait),
                    ("rdlb", ((last or 0.0) - self.best) / self.best),
                ],
            ),
            refresh=False,
        )
        self.tqE.update(epoch - self.tqE.n)

    def make_postfix(self, logs, extra=[]):
        if self.log_pattern is None:
            keys = self.params["metrics"]
        else:
            keys = _patfilter(self.log_pattern, logs.keys())
        return OrderedDict(
            [(key, logs[key]) for key in sorted(keys) if key in logs] + list(filter(None, extra))
        )

    def on_train_end(self, logs=None):
        super(TQES, self).on_train_end(logs)
        self.tqE.close()


class AUCOneVsAll(tf.keras.metrics.AUC):
    def __init__(self, one=0, *args, **kwargs):
        self.one = one
        super().__init__(*args, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        from tensorflow.python.framework import ops
        from tensorflow.python.framework import tensor_shape
        from tensorflow.python.keras.utils import metrics_utils
        from tensorflow.python.ops import check_ops

        deps = []
        if not self._built:
            self._build(tensor_shape.TensorShape(y_pred.shape))

        if self.multi_label or (self.label_weights is not None):
            # y_true should have shape (number of examples, number of labels).
            shapes = [(y_true, ("N", "L"))]
            if self.multi_label:
                # TP, TN, FP, and FN should all have shape
                # (number of thresholds, number of labels).
                shapes.extend(
                    [
                        (self.true_positives, ("T", "L")),
                        (self.true_negatives, ("T", "L")),
                        (self.false_positives, ("T", "L")),
                        (self.false_negatives, ("T", "L")),
                    ]
                )
            if self.label_weights is not None:
                # label_weights should be of length equal to the number of labels.
                shapes.append((self.label_weights, ("L",)))
            deps = [check_ops.assert_shapes(shapes, message="Number of labels is not consistent.")]

        # Only forward label_weights to update_confusion_matrix_variables when
        # multi_label is False. Otherwise the averaging of individual label AUCs is
        # handled in AUC.result
        label_weights = None if self.multi_label else self.label_weights
        with ops.control_dependencies(deps):
            return metrics_utils.update_confusion_matrix_variables(
                {
                    metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
                    metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
                    metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
                    metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
                },
                y_true[:, self.one],
                y_pred[:, self.one],
                self.thresholds,
                sample_weight=sample_weight,
                multi_label=self.multi_label,
                label_weights=label_weights,
            )


# Layer Definitions


class WhereEquals(tf.keras.layers.Layer):
    def __init__(self, value=0):
        super(WhereEquals, self).__init__()
        self.value = value

    def call(self, inp):
        return tf.where(inp[:, 0] == self.value)


class BaseLayer(tf.keras.layers.Layer):
    """
    The BaseLayer object is a feature collection of different features used in a DNN layer
    It features:
        * l2 regu
        * the weights (the real layer)
        * batch norm
        * activation function
        * dynamically chosen dropout

    Parameters
    ----------
    nodes : int
        The number of nodes.
    activation : str or one of tf.keras.activations
        The used activation function.
    dropout : float
        The used dropout ration.
        If "selu" is used as activation function, dropout becomes AlphaDropout.
    l2 : float
        The used factor of l2 regu.
    batch_norm : bool
        Wether to use dropout or not.
        If selu activation is used, batch_norm is forced off.
    """

    def __init__(self, nodes=0, activation=None, dropout=0.0, l2=0, batch_norm=False, **kwargs):
        super().__init__(name="BaseLayer")
        self.nodes = nodes
        self.activation = activation
        self.dropout = dropout
        self.l2 = l2
        self.batch_norm = batch_norm

    def build(self, input_shape):
        parts = {}
        if self.activation == "selu":
            kernel_initializer = "lecun_normal"
        else:
            kernel_initializer = "glorot_uniform"

        l2 = tf.keras.regularizers.l2(self.l2)
        weights = tf.keras.layers.Dense(
            self.nodes,
            kernel_regularizer=l2,
            kernel_initializer=kernel_initializer,
        )
        parts["weights"] = weights

        if self.batch_norm and not self.activation == "selu":
            bn = tf.keras.layers.BatchNormalization()
            parts["bn"] = bn

        act = tf.keras.layers.Activation(self.activation)
        parts["act"] = act

        if self.activation == "selu":
            dropout = tf.keras.layers.AlphaDropout(self.dropout)
        else:
            dropout = tf.keras.layers.Dropout(self.dropout)

        parts["dropout"] = dropout
        self.parts = parts

    @property
    def part_order(self):
        raise NotImplementedError

    def call(self, input_tensor, training=False):
        x = input_tensor
        for part in self.part_order:
            if part in self.parts:
                x = self.parts[part](x, training=training)
        return x

    def get_config(self):
        return {
            "nodes": self.nodes,
            "activation": self.activation,
            "dropout": self.dropout,
            "l2": self.l2,
            "batch_norm": self.batch_norm,
        }


class DenseLayer(BaseLayer):
    part_order = ["weights", "bn", "act", "dropout"]

    def __init__(self, **kwargs):
        super().__init__(name="DenseLayer", **kwargs)


class BNActWeightLayer(BaseLayer):
    part_order = ["bn", "act", "weights", "dropout"]

    def __init__(self, **kwargs):
        super().__init__(name="BNActWeightLayer", **kwargs)


class ResNetBlock(tf.keras.layers.Layer):
    """
    The ResNetBlock object is an implementation of one residual DNN block.

    Parameters
    ----------
    jump : int
        The number of layers bypassed by the residual connection.
    kwargs :
        Arguments for DenseLayer.

    """

    def __init__(self, jump=2, sub_kwargs=None, **kwargs):
        super().__init__(name="ResNetBlock")
        self.jump = jump
        self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs

    def build(self, input_shape):
        layers = []
        for i in range(self.jump - 1):
            layers.append(DenseLayer(**self.sub_kwargs))

        kwargs = dict(self.sub_kwargs)
        activation = kwargs.pop("activation")
        kwargs.pop("nodes")
        layers.append(DenseLayer(nodes=input_shape[-1], **kwargs))
        self.layers = layers
        self.out_activation = tf.keras.layers.Activation(activation)

    def call(self, input_tensor, training=False):
        x = input_tensor
        for layer in self.layers:
            x = layer(x, training=training)
        x += input_tensor
        x = self.out_activation(x)
        return x

    def get_config(self):
        return {"jump": self.jump, "sub_kwargs": self.sub_kwargs}


class DenseNetBlock(tf.keras.layers.Layer):
    """
    The DenseNetBlock object is an implementation of one DenseNet DNN block.

    Parameters
    ----------
    block_size : int
        The size of a DenseNet block, including the transition layer.
    kwargs :
        Arguments for DenseLayer.

    """

    def __init__(self, block_size=2, sub_kwargs=None, **kwargs):
        super().__init__(name="DenseNetBlock")
        self.block_size = block_size
        self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs

    def build(self, input_shape):
        layers = []
        for i in range(self.block_size - 1):
            layers.append(DenseLayer(**self.sub_kwargs))
        self.layers = layers
        self.transition = DenseLayer(**self.sub_kwargs)

    def call(self, input_tensor, training=False):
        x = input_tensor
        for layer in self.layers:
            y = layer(x, training=training)
            x = tf.keras.layers.concatenate([x, y], axis=-1)
        x = self.transition(x)
        return x

    def get_config(self):
        return {"block_size": self.block_size, "sub_kwargs": self.sub_kwargs}


class LinearNetwork(tf.keras.layers.Layer):
    @property
    def name(self):
        raise NotImplementedError

    @property
    def substructure(self):
        raise NotImplementedError

    def __init__(self, layers=0, sub_kwargs=None, **kwargs):
        super().__init__(name=self.name)
        self.layers = layers
        self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs

    def build(self, input_shape):
        network_layers = []
        for layer in range(self.layers):
            network_layers.append(self.substructure(**self.sub_kwargs))
        self.network_layers = network_layers

    def call(self, input_tensor, training=False):
        x = input_tensor
        for layer in self.network_layers:
            x = layer(x, training=training)
        return x

    def get_config(self):
        return {"layers": self.layers, "sub_kwargs": self.sub_kwargs}


class FullyConnected(LinearNetwork):
    """
    The FullyConnected object is an implementation of a fully connected DNN.

    Parameters
    ----------
    layers : int
        The number of layers.
    kwargs :
        Arguments for DenseLayer.

    """

    name = "FullyConnected"
    substructure = DenseLayer
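
# Example (sketch): a 3-layer fully connected network; all remaining kwargs are forwarded
# to every DenseLayer:
#   net = FullyConnected(layers=3, nodes=256, activation="selu", dropout=0.05)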


class ResNet(LinearNetwork):
    """
    The ResNet object is an implementation of a Residual Neural Network.

    Parameters
    ----------
    layers : int
        The number of residual blocks.
    kwargs :
        Arguments for ResNetBlock.

    """

    name = "ResNet"
    substructure = ResNetBlock


class DenseNet(LinearNetwork):
    """
    The DenseNet object is an implementation of a DenseNet Neural Network.

    Parameters
    ----------
    layers : int
        The number of densely connected (DenseNet) blocks.
    kwargs :
        Arguments for DenseNetBlock.

    """

    name = "DenseNet"
    substructure = DenseNetBlock


class Xception1D(tf.keras.layers.Layer):
    """
    The Xception1D object is an implementation of an Xception Neural Network.
    Because Xception was trained on 2D RGB data, each feature vector for the Xception1D
    is automatically scaled up to (71, 71, 3).

    Parameters
    ----------
    kwargs :
        Arguments for DenseLayers.
    """

    def __init__(self, sub_kwargs=None, **kwargs):
        super().__init__(name="DenseNet")
        self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs

    def build(self, input_shape):
        length = input_shape[-1]
        assert length < 71  # only small feature spaces are supported so far
        pad = math.ceil((71 - length) / 2)
        padded_length = 71 + 2 * math.ceil((71 - length) / 2)
        self.up_sample = tf.keras.layers.UpSampling2D(size=(1, length))
        self.zero_pad = (
            tf.keras.layers.ZeroPadding2D(pad)
            if length < 71
            else tf.keras.layers.Lambda(lambda x: x)
        )
        self.first_network = tf.keras.applications.Xception(
            include_top=False, input_shape=(71 + padded_length, padded_length, 3)
        )
        self.first_network.trainable = False
        self.flatten = tf.keras.layers.Flatten()
        self.second_network = FullyConnected(**self.sub_kwargs)

    def call(self, input_tensor, training=False):
        x = input_tensor
        x = tf.keras.backend.expand_dims(x)
        x = tf.keras.backend.expand_dims(x)
        x = self.up_sample(x)
        x = tf.keras.backend.repeat_elements(x, 3, axis=-1)
        x = self.zero_pad(x)
        x = self.first_network(x)
        x = self.flatten(x)
        x = self.second_network(x)
        return x

    def get_config(self):
        return {"sub_kwargs": self.sub_kwargs}


class PermutationInvariantDense1D(tf.keras.layers.Layer):
    """
    The PermutationInvariantDense1D layer is a TimeDistributed(Dense) implementation.
    Additionally it imposes (if wanted) a permutation inivariant pooling operation.

    Parameters
    ----------
    name :
        name of the layer
    depth :
        depth of the ResNet layers
    nfeatures:
        number of features of Conv1D layer
    batchnorm:
        enable batchnorm for ResNet layers
    pooling_op:
        permutation invariant pooling operation
    """

    def __init__(self, **kwargs):
        name = kwargs.pop("name", "PermutationInvariantDense1D")
        super().__init__(name=name)
        self.depth = kwargs.pop("depth", 1)
        self.nfeatures = kwargs.pop("nfeatures", 32)
        self.batchnorm = kwargs.pop("batchnorm", False)
        self.pooling_op = kwargs.pop("pooling_op", None)

    def build(self, input_shape):
        opts = dict(activation="elu", kernel_initializer="he_normal")
        self.layer = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(
                self.nfeatures,
                input_shape=input_shape,
                **opts,
            )
        )
        network_layers = []
        for layer in range(self.depth - 1):
            network_layers.append(
                tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.nfeatures, **opts))
            )
            if self.batchnorm:
                network_layers.append(tf.keras.layers.BatchNormalization())
        if self.pooling_op:
            # permutation invariant pooling op
            network_layers.append(tf.keras.layers.Lambda(lambda x: self.pooling_op(x, axis=1)))
        self.network_layers = network_layers

    def call(self, input_tensor, training=False):
        x = input_tensor
        for layer in self.network_layers:
            x = layer(x, training=training)
        return x


class SplitHighLow(tf.keras.layers.Layer):
    def call(self, inputs):
        return inputs[:, :, :4], inputs[:, :, 4:]


class LL(tf.keras.layers.Layer):
    def call(self, inputs):
        return inputs[:, :, :4]


class LBNLayer(tf.keras.layers.Layer):
    """
    Custom implementation of the LBNLayer with automatic cropping to
    low-level variables before and batchnorm after the LBN layer.

    Parameters
    ----------
    args :
        args for original LBN layer
    kwargs :
        kwargs for original LBN layer

    """

    def __init__(self, *args, **kwargs):
        super().__init__(name="LBNLayer")
        _kwargs = kwargs.copy()
        if "features" not in _kwargs:
            _kwargs["features"] = self.all_features
        self._args = args
        self._kwargs = _kwargs

    def build(self, input_shape):
        self.concat = tf.keras.layers.Concatenate(axis=-2)
        lbn_inp_shape = (sum(_[-2] for _ in input_shape), 4)
        self.lbn_layer = OriginalLBNLayer(lbn_inp_shape, *self._args, **self._kwargs)
        self.batch_norm = tf.keras.layers.BatchNormalization()

    @property
    def all_features(self):
        return [
            "E",
            "px",
            "py",
            "pz",
            "pt",
            "p",
            "m",
            "phi",
            "eta",
            "beta",
            "gamma",
            "pair_cos",
            "pair_dr",
            "pair_ds",
            "pair_dy",
        ]

    def call(self, input_tensors, training=False):
        ll = [SplitHighLow()(tensor)[0] for tensor in input_tensors]
        ll = self.concat(ll)
        feats = self.lbn_layer(ll)
        feats = self.batch_norm(feats)
        return feats


def feature_importance_grad(model, x=None, filter_types=["int32", "int64"], **kwargs):
    inp = {f"{i}": _x for (i, _x) in enumerate(x)}
    inp["sample_weight"] = kwargs["sample_weight"]
    ds = tf.data.Dataset.from_tensor_slices((inp))
    ds = ds.batch(256)
    grad = 0
    n_batches = tf.data.experimental.cardinality(ds).numpy()

    with tqdm(total=n_batches) as pbar:
        pbar.set_description("Importance grad")
        for _x in ds:
            if "sample_weight" in _x:
                sw = tf.cast(_x.pop("sample_weight"), tf.float32)
            inp = list(_x.values())
            with tf.GradientTape(watch_accessed_variables=False) as tape:
                [tape.watch(i) for i in inp if i.dtype not in filter_types]
                pred = model(inp, training=False)
                ix = tf.keras.layers.Lambda(lambda x: x[:, -1])(tf.argsort(pred, axis=-1))
                decision = tf.gather(pred, ix, batch_dims=1)
            gradients = tape.gradient(
                decision, [i for i in inp if i.dtype not in filter_types]
            )  # gradients for decision nodes
            gradients = [
                grad if grad is not None else tf.constant(0.0, dtype=tf.float32)
                for grad in gradients
            ]  # categorical tensors
            gradients = [tf.transpose(tf.transpose(g) * sw) for g in gradients]
            gradients = [tf.math.reduce_mean(tf.math.abs(g), axis=0) for g in gradients]

            # norm gradients to std
            normed_gradients = []
            for g, i in zip(gradients, inp):
                if i.dtype in filter_types:
                    val = np.nan * tf.math.reduce_std(i, axis=0)
                else:
                    val = g * tf.math.reduce_std(i, axis=0)
                normed_gradients.append(val)

            grad += tf.concat([tf.keras.backend.flatten(g) for g in normed_gradients], axis=-1)
            pbar.update()

    mean_gradients = grad.numpy() / n_batches

    return mean_gradients / mean_gradients.max()


def feature_importance_perm(model, x=None, **kwargs):
    inp_list = list(x)
    feat = 0  # loss
    ref = model.evaluate(x=x, **kwargs)[feat]
    vals = []
    for tensor_index, tensor in enumerate(inp_list):
        for index in np.ndindex(tensor.shape[1:]):
            # copy original values of feature
            original_values = np.copy(tensor[(Ellipsis, *index)])

            # randomly permute indexed feature
            tensor[(Ellipsis, *index)] = np.random.permutation(tensor[(Ellipsis, *index)])

            # write permuted feature to network input (inp)
            inp = inp_list.copy()
            inp[tensor_index] = tensor
            vals.append(model.evaluate(x=inp, **kwargs)[feat])

            # cleanup and write back original values of feature
            tensor[(Ellipsis, *index)] = original_values

    vals = np.array(vals)
    return (vals - ref) / ref


def feature_importance(*args, method="grad", columns=[], **kwargs):
    if method == "grad":
        importance = feature_importance_grad(*args, **kwargs)
    elif method == "perm":
        importance = feature_importance_perm(*args, **kwargs)
    else:
        raise NotImplementedError("Feature importance method not implemented")
    return {
        k: v
        for k, v in sorted(
            dict(zip(columns, importance.astype(float))).items(),
            key=lambda item: item[1],
        )
        if not np.isnan(v)
    }
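
# Example (sketch, mirroring the call in PlotMulticlass.make_eval_plots): rank the flattened
# input columns by permutation importance of the loss:
#   importance = feature_importance(
#       model, x=x, y=y, sample_weight=w, method="perm", columns=columns, batch_size=4096
#   )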


# Custom Loss Functions
@tf.function
def cross_entropy_t(
    labels,
    predictions,
    sample_weight=None,
    focal_gamma=None,
    class_weight=None,
    epsilon=1e-7,
):
    # cross-entropy term of the true classes (predictions clipped for numerical stability)
    predictions = tf.clip_by_value(predictions, epsilon, 1 - epsilon)
    tn = labels * tf.math.log(predictions)
    # focal loss?
    if focal_gamma is not None:
        tn *= (1 - predictions) ** focal_gamma
    # convert into loss
    loss = -tn
    # apply class weights
    if class_weight is not None:
        loss *= class_weight
    # apply sample weights
    if sample_weight is not None:
        loss *= sample_weight[:, tf.newaxis]

    return tf.reduce_mean(tf.reduce_sum(loss, axis=-1))
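
# In formula form, the loss above is
#   L = -mean_i( sum_c class_weight_c * sample_weight_i * y_ic * (1 - p_ic)^focal_gamma * log(p_ic) )
# i.e. a categorical cross entropy with optional focal term, class weights and sample weights.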


@tf.function
def grouped_cross_entropy_t(
    labels,
    predictions,
    sample_weight=None,
    group_ids=None,
    focal_gamma=None,
    class_weight=None,
    epsilon=1e-7,
    std_xent=True,
):
    loss_terms = []
    # ensure same type for labels and predictions
    labels = tf.cast(labels, predictions.dtype)
    if std_xent:
        loss_terms.append(
            cross_entropy_t(
                labels,
                predictions,
                sample_weight=sample_weight,
                focal_gamma=focal_gamma,
                class_weight=class_weight,
                epsilon=epsilon,
            )
        )

    if group_ids:
        # create grouped labels and predictions
        labels_grouped, predictions_grouped = [], []
        for _, ids in group_ids:
            labels_grouped.append(
                tf.reduce_sum(tf.gather(labels, ids, axis=-1), axis=-1, keepdims=True)
            )
            predictions_grouped.append(
                tf.reduce_sum(tf.gather(predictions, ids, axis=-1), axis=-1, keepdims=True)
            )

        labels_grouped = (
            tf.concat(labels_grouped, axis=-1) if len(labels_grouped) > 1 else labels_grouped[0]
        )
        predictions_grouped = (
            tf.concat(predictions_grouped, axis=-1)
            if len(predictions_grouped) > 1
            else predictions_grouped[0]
        )
        loss_terms.append(
            cross_entropy_t(
                labels_grouped,
                predictions_grouped,
                sample_weight=sample_weight,
                focal_gamma=focal_gamma,
                class_weight=tf.constant([w for w, _ in group_ids], tf.float32),
                epsilon=epsilon,
            )
        )
    return sum(loss_terms) / len(loss_terms)


class GroupedXEnt(tf.keras.losses.Loss):
    def __init__(
        self,
        group_ids=None,
        focal_gamma=None,
        class_weight=None,
        epsilon=1e-7,
        std_xent=True,
        *args,
        **kwargs,
    ):
        super(GroupedXEnt, self).__init__(*args, **kwargs)
        self.group_ids = group_ids
        self.focal_gamma = focal_gamma
        self.class_weight = class_weight
        self.epsilon = epsilon
        self.std_xent = std_xent

    def call(self, y_true, y_pred, sample_weight=None):
        return grouped_cross_entropy_t(
            y_true,
            y_pred,
            sample_weight=sample_weight,
            group_ids=self.group_ids,
            focal_gamma=self.focal_gamma,
            class_weight=self.class_weight,
            epsilon=self.epsilon,
            std_xent=self.std_xent,
        )
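
# Example (sketch, hypothetical grouping): add a second cross-entropy term on summed groups
# of output nodes, each group given as (group_weight, list_of_class_indices):
#   loss = GroupedXEnt(group_ids=[(1.0, [0, 1]), (1.0, [2, 3, 4])])
#   model.compile(optimizer="adam", loss=loss)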


def write_summary(model, target):
    """write tf keras model summary to law target"""
    summary = []
    model.summary(print_fn=lambda x: summary.append(x))
    summary = "\n".join(summary)
    target.dump(summary)


def find_layer(layer, name="LBNLayer"):
    if layer.name == name:
        return layer
    if hasattr(layer, "layers"):
        sub_layers = layer.layers
        for sub_layer in sub_layers:
            layer = find_layer(sub_layer, name=name)
            if layer is not None:
                return layer