# from itertools import izip
import gc
from collections import OrderedDict, defaultdict
import fnmatch
import math
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from inspect import getargspec
from warnings import warn
import re
from tensorflow.python.keras.callbacks import make_logs
from tensorflow.python.keras.backend import track_variable
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from operator import itemgetter
from lbn import LBNLayer as OriginalLBNLayer
from matplotlib import pyplot as plt
from .evil import pin
from .data import SKDict, DSS
from .plotting import (
figure_to_image,
figure_confusion_matrix,
figure_activations,
figure_node_activations,
figure_roc_curve,
figure_multihist,
figure_y,
figure_weights,
figure_inputs,
figure_dict,
)
# various helper functions
def kVar(*args, **kwargs):
""" produce a keras-tracked tf.Variable from all given parameters """
var = tf.Variable(*args, **kwargs)
track_variable(var)
return var
def kOpt(opt, **kwargs):
""" instanciate a keras Optimizer with all applicable **kwargs """
if not callable(opt):
opt = getattr(tf.keras.optimizers, opt)
assert issubclass(opt, tf.keras.optimizers.Optimizer)
    args = set(getargspec(opt.__init__).args) - {"self"}
return opt(**{k: v for k, v in kwargs.items() if k in args})
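# Usage sketch (illustrative values): kOpt builds an optimizer from its name while
# silently dropping keyword arguments the chosen optimizer does not accept, e.g.
#   opt = kOpt("Adam", learning_rate=1e-3, momentum=0.9)  # "momentum" is ignored for Adam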
def kInput(ref, **kwargs):
""" produce a keras.Input with shape & dtype according to ref """
kwargs.setdefault("shape", ref.shape[1:])
kwargs.setdefault("dtype", ref.dtype)
return tf.keras.Input(**kwargs)
def keras_register_custom_object(obj):
""" decorator for globally registering a custom object with keras """
tf.keras.utils.get_custom_objects()[obj.__name__] = obj
return obj
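# Usage sketch: registering a custom object lets tf.keras.models.load_model
# deserialize it without an explicit custom_objects dict, e.g.
#   @keras_register_custom_object
#   class MyMetric(tf.keras.metrics.Mean):  # hypothetical custom class
#       ...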
class KFeed(object):
def __init__(self, x, y, w=None, train="train", valid="valid", balance=lambda x: x):
pin(locals())
@property
def xyw(self):
return tuple(v if isinstance(v, tuple) else (v,) for v in (self.x, self.y, self.w))
def get(self, src):
return tuple(tuple(src[k] for k in v) for v in self.xyw)
@property
def all(self):
return set().union(*self.xyw)
def balance(self, src):
return src
def kfeed(self, src, **kwargs):
src = src.only(*self.all)
bal = self.balance(src)
return dict(
zip(["x", "y", "sample_weight"], self.get(bal[self.train])),
validation_data=self.get(bal[self.valid]),
**kwargs,
)
def balance_weights(self, weights):
if isinstance(weights, SKDict):
sums = weights.map(np.sum)
if len(sums) > 1:
ref = np.mean(list(sums.values()))
weights = weights.__class__({k: weights[k] * (ref / s) for k, s in sums.items()})
return weights
def compact(self, src, keep=()):
ret = src.only(*self.all, *keep)
val = ret[self.valid]
assert not isinstance(self.w, tuple)
if isinstance(val[self.w], SKDict):
val[self.w] = self.balance_weights(val[self.w])
val = val.fuse(*val[self.w].keys())
ret = ret.only(self.train)
ret[self.valid] = val
return ret
def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, **kwargs):
"""
Creates a generator for tf.keras' model.fit().
        Requires that the mean weights per process are equal:
dss["weight"] = dss["weight"].map(lambda x: x / np.mean(x))
"""
src = self.compact(src)
val = src[self.valid]
val.blen # asserts that all first dimensions have equal length
return dict(
dict(
zip(
("x", "steps_per_epoch"),
self.gensteps(src[self.train], batch_size, rng=rng, auto_steps=auto_steps),
)
),
validation_data=self.get(val),
workers=0,
**kwargs,
)
def gensteps(self, src, batch_size, rng=np.random, auto_steps=np.max):
keys = src.mkeys(self.all)
gen = (
(
self.get(DSS.zip(*parts).map(np.concatenate))
for parts in zip(
*[
src[k].batch_generator(
batch_size // len(keys),
rng=np.random.RandomState(rng.randint(1 << 31, size=20)),
)
for k in keys
]
)
)
if len(keys) > 1
else map(self.get, src[list(keys)[0]].batch_generator(batch_size, rng=rng))
)
gs = float(batch_size // len(keys))
steps = int(auto_steps([src[k].blen / gs for k in keys])) or None
return gen, steps
def generator(self, *args, **kwargs):
assert "auto_steps" not in kwargs
return self.gensteps(*args, **kwargs)[0]
def getShapes(self, src):
return tuple(tuple(src[k].shape for k in g) for g in self.xyw)
def getShapesK(self, src):
strip_first_dim = lambda x: x[1:] if x[0] is None else x
shapes = self.getShapes(src)
return tuple(tuple(strip_first_dim(a) for a in b) for b in shapes)
def mkInputs(self, src, **kwargs):
return tuple(
tf.keras.Input(shape=ref.shape[1:], dtype=ref.dtype, **kwargs)
for ref in (src[x] for x in self.xyw[0])
)
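# Usage sketch (dataset keys are hypothetical): a KFeed describes which keys of a
# DSS hold inputs, targets and weights; the returned dicts are unpacked into fit().
#   feed = KFeed(x=("lep", "jet"), y="label", w="weight")
#   model.fit(**feed.kfeed(dss, epochs=10))        # feed arrays directly
#   model.fit(**feed.gfeed(dss, batch_size=1024))  # feed via a balanced generator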
def Normal(ref, const=None, ignore_zeros=False, name=None, **kwargs):
"""
Normalizing layer according to ref.
If given, variables at the indices const will not be normalized.
"""
if ignore_zeros:
mean = np.nanmean(np.where(ref == 0, np.ones_like(ref) * np.nan, ref), **kwargs)
std = np.nanstd(np.where(ref == 0, np.ones_like(ref) * np.nan, ref), **kwargs)
else:
mean = ref.mean(**kwargs)
std = ref.std(**kwargs)
if const is not None:
mean[const] = 0
std[const] = 1
std = np.where(std == 0, 1, std)
mul = 1.0 / std
add = -mean / std
return tf.keras.layers.Lambda((lambda x: (x * mul) + add), name=name)
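# Usage sketch: build a normalization Lambda layer from a reference sample so the
# network sees (approximately) zero-mean, unit-variance inputs.
#   norm = Normal(x_train, axis=0)  # x_train is a hypothetical numpy reference array
#   hidden = norm(inputs)           # inputs is a matching keras tensor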
def Onehot(index, n, name=None):
"""
One hot encodes a variable referred to by index.
n is the number of different variables.
"""
def to_onehot(x):
        # Append a zero row to the identity matrix so indices equal to n (outside the one-hot range) map to all zeros
eye = tf.concat((tf.eye(n), tf.zeros((1, n))), axis=0)
return tf.concat(
(
x[..., :index],
tf.gather(eye, tf.cast(x[..., index], tf.int64)),
x[..., (index + 1) :],
),
axis=-1,
)
return tf.keras.layers.Lambda(to_onehot, name=name)
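# Usage sketch: one-hot encode the categorical column at position `index` in place,
# keeping the surrounding columns; values equal to n are mapped to all zeros.
#   onehot = Onehot(index=3, n=5)  # hypothetical column position / cardinality
#   encoded = onehot(inputs)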
class Moment(tf.keras.metrics.Mean):
def __init__(self, order, label=False, **kwargs):
""" Metric calculating the order-th moment """
assert order == int(order)
assert label == bool(label)
kwargs.setdefault("name", "%smom%d" % ("l" if label else "", order))
super(Moment, self).__init__(**kwargs)
self.order = order
self.label = label
def update_state(self, y_true, y_pred, sample_weight=None):
y = y_true if self.label else y_pred
y = tf.keras.backend.cast(y, self._dtype)
if self.order == 0:
y = tf.keras.backend.ones_like(y)
elif self.order == 1:
pass
elif self.order == 2:
y = tf.keras.backend.square(y)
else:
y = tf.keras.backend.pow(y, self.order)
return super(Moment, self).update_state(y, sample_weight=sample_weight)
def get_config(self):
return dict(super(Moment, self).get_config(), order=self.order, label=self.label)
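# Usage sketch: tracking the first and second moment of the predictions lets the
# Moment2Std callback (defined below) derive a "std" entry from the logged moments.
#   model.compile(optimizer="adam", loss="categorical_crossentropy",
#                 metrics=[Moment(1), Moment(2)])
#   model.fit(x, y, callbacks=[Moment2Std()])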
# lots of callbacks
def _patfilter(pattern, items):
if isinstance(pattern, (list, tuple)):
pattern = "|".join(map(fnmatch.translate, pattern))
return filter(re.compile(pattern).search, items)
class Moment2Std(tf.keras.callbacks.Callback):
def __init__(self, mom1="mom1", mom2="mom2", std="std"):
assert mom1 and mom2 and std
pin(locals())
def on_x_end(self, x, logs=None):
if logs is None:
return
for key1, mom1 in list(logs.items()):
if not key1.endswith(self.mom1):
continue
prefix = key1[: -len(self.mom1)]
mom2 = logs.get(prefix + self.mom2, None)
if mom2 is not None:
logs[prefix + self.std] = (mom2 - mom1 ** 2) ** 0.5
on_batch_end = on_x_end
on_epoch_end = on_x_end
class LogRewrite(tf.keras.callbacks.Callback):
def __init__(self, *rewrites, **kwargs):
self.collapse_weighted = kwargs.pop("collapse_weighted", False)
assert not kwargs
self.rewrites = list(rewrites)
if self.collapse_weighted:
self.rewrites.append(lambda s: s.replace("weighted_", ""))
def on_x_end(self, x, logs=None):
if logs is None:
return
shadow = []
for key, val in logs.items():
short = key
for rewrite in self.rewrites:
short = rewrite(short)
if short == key:
continue
if short in logs:
shadow.append(short)
logs[short] = val
del logs[key]
if shadow:
warn("%r: shadows the following keys: %s" % (",".join(shadow)))
on_batch_end = on_x_end
on_epoch_end = on_x_end
class LogTransforms(tf.keras.callbacks.Callback):
_prefixesRe = re.compile("^(val_|)(.+)")
def __init__(self, *funcs, **kwargs):
self.transforms = []
for func in funcs:
self.add_transform(func)
for name, func in kwargs.items():
self.add_transform(func, name)
def add_transform(self, func, name=None):
if name is None:
name = func.__name__
assert re.match(r"[a-zA-Z]\w*$", name)
argspec = getargspec(func)
assert argspec.keywords or argspec.args
self.transforms.append((name, func, True if argspec.keywords else set(argspec.args)))
def on_x_end(self, x, logs=None):
if logs is None:
return
plogs = {}
for key, value in logs.items():
prefix, suffix = self._prefixesRe.match(key).groups()
plogs.setdefault(prefix, {})[suffix] = value
for name, func, args in self.transforms:
for prefix, log in plogs.items():
if args is True:
val = func(**log)
else:
missing = args.difference(log.keys())
if missing:
msg = "%r: %s=%r needs unavailable log values: %s" % (
self,
name,
func,
", ".join(missing),
)
if prefix:
warn(msg)
continue
else:
raise RuntimeError(msg)
val = func(**{key: val for key, val in log.items() if key in args})
logs[prefix + name] = log[name] = val
on_batch_end = on_x_end
on_epoch_end = on_x_end
class LogEMA(tf.keras.callbacks.Callback):
def __init__(self, keys=[], pattern=[], pairs={}, ema=0, ricu=0, ema_fmt="%s_ema"):
assert ema < 1 and ricu < 1 and "%s" in ema_fmt
if ema < 0:
ema = 1 - 10 ** ema
assert ema < 1
if ricu < 0:
ricu = -1.0 / ricu
keys = {k: None for k in keys}
keys.update(pairs)
pin(locals(), pairs)
self.reset()
def reset(self):
self.ema_val = {}
self.run_num = defaultdict(int)
def on_epoch_end(self, x, logs=None):
assert logs
keys = {key: None for key in _patfilter(self.pattern, logs.keys())}
keys.update(self.keys)
for key, out in keys.items():
if not out:
out = self.ema_fmt % key
assert key in logs
assert out not in keys
assert out not in logs
val = logs[key]
if key in self.ema_val:
delta = self.ema_val[key] - val
dsign = int(np.sign(delta))
self.run_num[key] += dsign
if dsign * np.sign(self.run_num[key]) < -self.ricu < 0:
self.run_num[key] *= self.ricu ** 0.5
val += delta * self.ema ** (1 + (self.ricu * self.run_num[key]) ** 4)
logs[out] = self.ema_val[key] = val
class CustomValidation(tf.keras.callbacks.Callback):
def __init__(self, **kwargs):
self.kwargs = kwargs
def on_epoch_end(self, x, logs=None):
if logs is None:
return
res = self.model.evaluate(**self.kwargs)
if not isinstance(res, list):
res = [res]
logs.update(make_logs(self.model, logs, res, mode=ModeKeys.TEST, prefix="val_"))
class TFSummaryCallback(tf.keras.callbacks.Callback):
def __init__(self, logdir=None, **kwargs):
self.writer = tf.summary.create_file_writer(logdir)
class PlotMulticlass(TFSummaryCallback):
def __init__(
self,
x,
y,
sample_weight=None,
class_names=["signal", "background"],
to_file=False,
columns=None,
plot_inputs=False,
signalvsbkg=False,
plot_importance=False,
plot_activations=False,
tag="",
**kwargs,
):
super().__init__(**kwargs)
self.x = x
self.truth = y[0]
self.sample_weight = sample_weight[0]
self.class_names = class_names
self.plot_inputs = plot_inputs
self.columns = columns
self.to_file = to_file
self.signalvsbkg = signalvsbkg
self.tag = tag
self.plot_importance = plot_importance
self.plot_activations = plot_activations
def on_train_begin(self, logs=None):
self.make_input_plots()
def on_train_end(self, logs=None):
self.make_eval_plots()
def make_input_plots(self):
if self.plot_inputs:
imgs = {}
if self.columns:
inps = self.x
if not isinstance(inps, (list, tuple)):
inps = [inps]
for part, inp in zip(self.columns.keys(), inps):
imgs[f"inp_xmerged_{part}"] = figure_to_image(
figure_multihist(inp, columns=self.columns[part])
)
if self.sample_weight is not None:
for part, inp in zip(self.columns.keys(), inps):
imgs[f"inp_x_{part}"] = figure_to_image(
figure_inputs(
inp,
self.truth,
sample_weight=self.sample_weight,
columns=self.columns[part],
class_names=self.class_names,
signalvsbkg=self.signalvsbkg,
)
)
imgs["inp_weights"] = figure_to_image(
figure_weights(self.sample_weight, self.truth, class_names=self.class_names)
)
imgs["inp_y"] = figure_to_image(figure_y(self.truth, class_names=self.class_names))
imgs["inp_yrelative"] = figure_to_image(
figure_y(self.truth, class_names=self.class_names, relative=True)
)
for name, img in imgs.items():
with self.writer.as_default():
tf.summary.image(f"{name}{self.tag}", img, step=0)
def clear_figure(self, fig):
fig.clf()
del fig
plt.close("all")
gc.collect()
def make_eval_plots(self, epoch=0):
prediction = self.model.predict(self.x, batch_size=4096)
truth = self.truth
imgs = {}
fig = figure_roc_curve(
truth,
prediction,
class_names=self.class_names,
sample_weight=self.sample_weight,
)
imgs["roc_curve"] = figure_to_image(fig)
self.clear_figure(fig)
fig = figure_confusion_matrix(
truth,
prediction,
class_names=self.class_names,
sample_weight=self.sample_weight,
normalize="true",
)
imgs["confusion_matrix_true"] = figure_to_image(fig)
self.clear_figure(fig)
fig = figure_confusion_matrix(
truth,
prediction,
class_names=self.class_names,
sample_weight=self.sample_weight,
normalize="pred",
)
imgs["confusion_matrix_pred"] = figure_to_image(fig)
self.clear_figure(fig)
if self.plot_activations:
fig = figure_node_activations(
prediction,
truth,
class_names=self.class_names,
disjoint=True,
sample_weight=self.sample_weight,
)
imgs["node_activation_disjoint"] = figure_to_image(fig)
self.clear_figure(fig)
if self.plot_importance:
importance = feature_importance(
self.model,
x=[feat[:20000] for feat in self.x],
y=self.truth[:20000],
sample_weight=self.sample_weight[:20000],
method="grad",
columns=[c for col in self.columns.values() for c in col],
)
fig = figure_dict(importance)
imgs["importance"] = figure_to_image(fig)
self.clear_figure(fig)
for name, img in imgs.items():
with self.writer.as_default():
tf.summary.image(f"{name}{self.tag}", img, step=epoch)
del imgs
class PlotMulticlassEval(PlotMulticlass):
def on_test_end(self, logs=None):
self.make_input_plots()
self.make_eval_plots(0)
class ModelLH(tf.keras.Model):
def __init__(self, *args, **kwargs):
self.loss_hook = kwargs.pop("loss_hook", None)
super(ModelLH, self).__init__(*args, **kwargs)
def _update_sample_weight_modes(self, sample_weights=None):
if not self._is_compiled:
return
if sample_weights and any([s is not None for s in sample_weights]):
pass
# don't default sample_weight_mode to "samplewise", it prevents proper function caching
# for endpoint in self._training_endpoints:
# endpoint.sample_weight_mode = (
# endpoint.sample_weight_mode or 'samplewise')
else:
for endpoint in self._training_endpoints:
endpoint.sample_weight_mode = None
def _prepare_total_loss(self, *args, **kwargs):
orig = [
(ep, ep.__dict__.copy(), ep.training_target.__dict__.copy())
for ep in self._training_endpoints
]
self.loss_hook(self._training_endpoints.copy())
ret = super(ModelLH, self)._prepare_total_loss(*args, **kwargs)
for ep, ed, td in orig:
ep.__dict__.update(ed)
ep.training_target.__dict__.update(td)
return ret
class GCCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
gc.collect()
class TensorBoard(tf.keras.callbacks.TensorBoard):
def __init__(self, *args, **kwargs):
self.writer = kwargs.pop("writer")
super(TensorBoard, self).__init__(*args, **kwargs)
def _init_writer(self, model=None):
pass
def on_train_end(self, logs=None):
pass
class CheckpointModel(tf.keras.callbacks.Callback):
"""Gets a dict of targets (checkpoints), if current epoch is in dict, save model to target."""
def __init__(self, checkpoints=None):
self.targets = checkpoints.targets
def on_epoch_end(self, epoch, logs=None):
if epoch in self.targets:
target = self.targets[epoch]
self.model.save(target.path)
class BestTracker(tf.keras.callbacks.Callback):
def __init__(
self,
monitor="val_loss",
mode="auto",
min_delta=0,
min_delta_rel=0,
baseline=None,
):
pin(locals())
self.reset()
@property
def mode_multiplier(self):
assert self.mode in ("auto", "min", "max")
if self.mode == "max" or (self.mode == "auto" and "acc" in self.monitor):
return -1
else:
return 1
def reset(self):
if self.baseline is not None:
self.best = self.baseline
else:
self.best = self.mode_multiplier * np.inf
def update_best(self, logs):
current = self.get_monitor_value(logs)
relative = delta = (current - self.best) * self.mode_multiplier
if np.isfinite(self.best):
relative /= abs(self.best) or 1.0
update = delta < -self.min_delta and relative < -self.min_delta_rel
if update:
self.best = current
return update
def get_monitor_value(self, logs):
assert logs
assert self.monitor in logs
return logs[self.monitor]
class PatientTracker(BestTracker):
def __init__(
self,
patience=10,
cooldown=0,
cooldown0=0,
override=None,
jank=np.nan,
jank_last=-np.inf,
**kwargs,
):
if override is None:
del override
cooldown_counter = cooldown if cooldown0 is True else cooldown0
pin(locals(), kwargs)
super(PatientTracker, self).__init__(**kwargs)
def reset(self):
super(PatientTracker, self).reset()
self.wait = 0
def override(self, logs):
return None
def patient_step(self, epoch, logs):
# update jank
logs["jank"] = min(logs.get("jank", np.inf), epoch - self.jank_last)
# test override
override = self.override(logs)
if override is not None:
return override
# check jank
if logs["jank"] <= self.jank:
return
if 0 < self.cooldown_counter:
self.cooldown_counter -= 1
if self.update_best(logs):
self.wait = 0
return "best"
elif self.cooldown_counter <= 0:
self.wait += 1
if self.patience <= self.wait:
self.cooldown_counter = self.cooldown
return "good"
class ScaleOnPlateau(PatientTracker):
def __init__(
self,
target,
factor,
min=None,
max=None,
verbose=0,
log_key=None,
linearly=False,
**kwargs,
):
pin(locals(), kwargs)
super(ScaleOnPlateau, self).__init__(**kwargs)
def on_train_begin(self, logs=None):
self.reset()
@property
def value(self):
return tf.keras.backend.get_value(self.target)
@value.setter
def value(self, new):
return tf.keras.backend.set_value(self.target, new)
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
cur = self.value
if self.log_key is not None:
logs.setdefault(self.log_key, cur)
if self.patient_step(epoch, logs) == "good":
if self.linearly:
new = cur + self.factor
else:
new = cur * self.factor
if self.min is not None:
new = max(new, self.min)
if self.max is not None:
new = min(new, self.max)
if cur != new:
self.value = new
self.reset()
if np.isfinite(self.jank):
self.jank_last = epoch
if self.verbose > 0:
print(
"\nEpoch %05d: %s scaling %s to %s."
% (
epoch + 1,
self.__class__.__name__,
self.log_key or self.target,
new,
)
)
class ReduceLROnPlateau(ScaleOnPlateau):
def __init__(self, min_lr=0, **kwargs):
super(ReduceLROnPlateau, self).__init__(min=min_lr, target=None, log_key="lr", **kwargs)
@property
def target(self):
return self.model.optimizer.lr
@target.setter
def target(self, target):
assert target is None
class EarlyStopping(PatientTracker):
def __init__(self, restore_best_weights=False, verbose=0, do_stop=None, **kwargs):
pin(locals(), kwargs)
super(EarlyStopping, self).__init__(**kwargs)
def reset(self):
super(EarlyStopping, self).reset()
self.best_weights = None
self.best_epoch = None
def on_train_begin(self, logs=None):
self.reset()
def on_epoch_end(self, epoch, logs=None):
action = self.patient_step(epoch, logs)
if action == "best":
if self.restore_best_weights:
self.best_weights = self.model.get_weights()
self.best_epoch = epoch
elif action == "good":
self.stopped_epoch = epoch
self.model.stop_training = True
if self.restore_best_weights and self.best_weights is not None:
self.model.set_weights(self.best_weights)
if self.verbose > 0:
print("Restoring model weights from the end of the best epoch.")
def on_train_end(self, logs=None):
hist = self.model.history
if self.restore_best_weights and self.best_epoch in hist.epoch:
idx = hist.epoch.index(self.best_epoch)
else:
idx = -1
hist.final_logs = {k: v[idx] for k, v in hist.history.items()}
class TQES(EarlyStopping):
def __init__(self, log_pattern=None, prog_batch="steps", **kwargs):
assert prog_batch in (None, "steps", "samples")
pin(locals(), kwargs)
super(TQES, self).__init__(**kwargs)
@property
def use_steps(self):
return self.prog_batch == "steps"
def on_train_begin(self, logs=None):
super(TQES, self).on_train_begin(logs)
self.tqE = tqdm(desc=getattr(self.model, "name", None), total=self.params["epochs"])
def on_epoch_begin(self, epoch, logs=None):
if self.prog_batch:
self.tqB = tqdm(
unit=("batch" if self.use_steps else "sample"),
total=self.params["steps" if self.use_steps else "samples"],
)
def on_batch_end(self, batch, logs=None):
if self.tqB:
logs = logs or {}
self.tqB.set_postfix(self.make_postfix(logs), refresh=False)
self.tqB.update(
logs.get("num_steps", 1) * (1 if self.use_steps else logs.get("size", 0))
)
def on_epoch_end(self, epoch, logs=None):
super(TQES, self).on_epoch_end(epoch, logs)
last = self.get_monitor_value(logs)
if self.tqB:
self.tqB.close()
self.tqB = None
self.tqE.set_postfix(
self.make_postfix(
logs,
[
("best", self.best),
("conf", self.wait),
("rdlb", ((last or 0.0) - self.best) / self.best),
],
),
refresh=False,
)
self.tqE.update(epoch - self.tqE.n)
def make_postfix(self, logs, extra=[]):
if self.log_pattern is None:
keys = self.params["metrics"]
else:
keys = _patfilter(self.log_pattern, logs.keys())
return OrderedDict(
[(key, logs[key]) for key in sorted(keys) if key in logs] + list(filter(None, extra))
)
def on_train_end(self, logs=None):
super(TQES, self).on_train_end(logs)
self.tqE.close()
class AUCOneVsAll(tf.keras.metrics.AUC):
def __init__(self, one=0, *args, **kwargs):
self.one = one
super().__init__(*args, **kwargs)
def update_state(self, y_true, y_pred, sample_weight=None):
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.ops import check_ops
deps = []
if not self._built:
self._build(tensor_shape.TensorShape(y_pred.shape))
if self.multi_label or (self.label_weights is not None):
# y_true should have shape (number of examples, number of labels).
shapes = [(y_true, ("N", "L"))]
if self.multi_label:
# TP, TN, FP, and FN should all have shape
# (number of thresholds, number of labels).
shapes.extend(
[
(self.true_positives, ("T", "L")),
(self.true_negatives, ("T", "L")),
(self.false_positives, ("T", "L")),
(self.false_negatives, ("T", "L")),
]
)
if self.label_weights is not None:
# label_weights should be of length equal to the number of labels.
shapes.append((self.label_weights, ("L",)))
deps = [check_ops.assert_shapes(shapes, message="Number of labels is not consistent.")]
# Only forward label_weights to update_confusion_matrix_variables when
# multi_label is False. Otherwise the averaging of individual label AUCs is
# handled in AUC.result
label_weights = None if self.multi_label else self.label_weights
with ops.control_dependencies(deps):
return metrics_utils.update_confusion_matrix_variables(
{
metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
},
y_true[:, self.one],
y_pred[:, self.one],
self.thresholds,
sample_weight=sample_weight,
multi_label=self.multi_label,
label_weights=label_weights,
)
# Layer Definitions
class WhereEquals(tf.keras.layers.Layer):
def __init__(self, value=0):
super(WhereEquals, self).__init__()
self.value = value
def call(self, inp):
return tf.where(inp[:, 0] == self.value)
class BaseLayer(tf.keras.layers.Layer):
"""
The BaseLayer object is a feature collection of different features used in a DNN layer
It features:
    * L2 regularization
* the weights (the real layer)
* batch norm
* activation function
* dynamically chosen dropout
Parameters
----------
nodes : int
The number of nodes.
activation : str or one of tf.keras.activations
The used activation function.
dropout : float
        The used dropout ratio.
If "selu" is used as activation function, dropout becomes AlphaDropout.
l2 : float
        The L2 regularization factor.
batch_norm : bool
        Whether to use batch normalization or not.
If selu activation is used, batch_norm is forced off.
"""
def __init__(self, nodes=0, activation=None, dropout=0.0, l2=0, batch_norm=False, **kwargs):
super().__init__(name="BaseLayer")
self.nodes = nodes
self.activation = activation
self.dropout = dropout
self.l2 = l2
self.batch_norm = batch_norm
def build(self, input_shape):
parts = {}
if self.activation == "selu":
kernel_initializer = "lecun_normal"
else:
kernel_initializer = "glorot_uniform"
l2 = tf.keras.regularizers.l2(self.l2)
weights = tf.keras.layers.Dense(
self.nodes,
kernel_regularizer=l2,
kernel_initializer=kernel_initializer,
)
parts["weights"] = weights
if self.batch_norm and not self.activation == "selu":
bn = tf.keras.layers.BatchNormalization()
parts["bn"] = bn
act = tf.keras.layers.Activation(self.activation)
parts["act"] = act
if self.activation == "selu":
dropout = tf.keras.layers.AlphaDropout(self.dropout)
else:
dropout = tf.keras.layers.Dropout(self.dropout)
parts["dropout"] = dropout
self.parts = parts
@property
def part_order(self):
raise NotImplementedError
def call(self, input_tensor, training=False):
x = input_tensor
for part in self.part_order:
if part in self.parts:
x = self.parts[part](x, training=training)
return x
def get_config(self):
return {
"nodes": self.nodes,
"activation": self.activation,
"dropout": self.dropout,
"l2": self.l2,
"batch_norm": self.batch_norm,
}
class DenseLayer(BaseLayer):
part_order = ["weights", "bn", "act", "dropout"]
def __init__(self, **kwargs):
super().__init__(name="DenseLayer", **kwargs)
class BNActWeightLayer(BaseLayer):
part_order = ["bn", "act", "weights", "dropout"]
def __init__(self, **kwargs):
super().__init__(name="BNActWeightLayer", **kwargs)
class ResNetBlock(tf.keras.layers.Layer):
"""
The ResNetBlock object is an implementation of one residual DNN block.
Parameters
----------
jump : int
        The number of layers bypassed by the residual connection.
kwargs :
Arguments for DenseLayer.
"""
def __init__(self, jump=2, sub_kwargs=None, **kwargs):
super().__init__(name="ResNetBlock")
self.jump = jump
self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs
def build(self, input_shape):
layers = []
for i in range(self.jump - 1):
layers.append(DenseLayer(**self.sub_kwargs))
kwargs = dict(self.sub_kwargs)
activation = kwargs.pop("activation")
kwargs.pop("nodes")
layers.append(DenseLayer(nodes=input_shape[-1], **kwargs))
self.layers = layers
self.out_activation = tf.keras.layers.Activation(activation)
def call(self, input_tensor, training=False):
x = input_tensor
for layer in self.layers:
x = layer(x, training=training)
x += input_tensor
x = self.out_activation(x)
return x
def get_config(self):
return {"jump": self.jump, "sub_kwargs": self.sub_kwargs}
class DenseNetBlock(tf.keras.layers.Layer):
"""
The DenseNetBlock object is an implementation of one DenseNet DNN block.
Parameters
----------
block_size : int
        The size of a DenseNet block, including the transition layer.
kwargs :
Arguments for DenseLayer.
"""
def __init__(self, block_size=2, sub_kwargs=None, **kwargs):
super().__init__(name="DenseNetBlock")
self.block_size = block_size
self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs
def build(self, input_shape):
layers = []
for i in range(self.block_size - 1):
layers.append(DenseLayer(**self.sub_kwargs))
self.layers = layers
self.transition = DenseLayer(**self.sub_kwargs)
def call(self, input_tensor, training=False):
x = input_tensor
for layer in self.layers:
y = layer(x, training=training)
x = tf.keras.layers.concatenate([x, y], axis=-1)
x = self.transition(x)
return x
def get_config(self):
return {"block_size": self.block_size, "sub_kwargs": self.sub_kwargs}
class LinearNetwork(tf.keras.layers.Layer):
@property
def name(self):
raise NotImplementedError
@property
def substructure(self):
raise NotImplementedError
def __init__(self, layers=0, sub_kwargs=None, **kwargs):
super().__init__(name=self.name)
self.layers = layers
self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs
def build(self, input_shape):
network_layers = []
for layer in range(self.layers):
network_layers.append(self.substructure(**self.sub_kwargs))
self.network_layers = network_layers
def call(self, input_tensor, training=False):
x = input_tensor
for layer in self.network_layers:
x = layer(x, training=training)
return x
def get_config(self):
return {"layers": self.layers, "sub_kwargs": self.sub_kwargs}
class FullyConnected(LinearNetwork):
"""
The FullyConnected object is an implementation of a fully connected DNN.
Parameters
----------
layers : int
The number of layers.
kwargs :
Arguments for DenseLayer.
"""
name = "FullyConnected"
substructure = DenseLayer
class ResNet(LinearNetwork):
"""
The ResNet object is an implementation of a Residual Neural Network.
Parameters
----------
layers : int
The number of residual blocks.
kwargs :
Arguments for ResNetBlock.
"""
name = "ResNet"
substructure = ResNetBlock
class DenseNet(LinearNetwork):
"""
The DenseNet object is an implementation of a DenseNet Neural Network.
Parameters
----------
layers : int
The number of densely connected (DenseNet) blocks.
kwargs :
Arguments for DenseNetBlock.
"""
name = "DenseNet"
substructure = DenseNetBlock
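# Usage sketch (layer counts and hyperparameters are illustrative): the network
# classes forward their remaining kwargs to the respective substructure layers.
#   features = FullyConnected(layers=4, nodes=256, activation="selu", dropout=0.05)(features)
#   features = ResNet(layers=3, jump=2, nodes=256, activation="elu", batch_norm=True)(features)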
class Xception1D(tf.keras.layers.Layer):
"""
    The Xception1D object is an implementation of an Xception Neural Network.
Because Xception was trained on 2D RGB data, each feature vector for the Xception1D
is automatically scaled up to (71, 71, 3).
Parameters
----------
kwargs :
Arguments for DenseLayers.
"""
def __init__(self, sub_kwargs=None, **kwargs):
super().__init__(name="DenseNet")
self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs
def build(self, input_shape):
        length = input_shape[-1]
        assert length < 71  # so far only feature spaces smaller than the Xception input size are supported
        pad = math.ceil((71 - length) / 2)
        padded_length = length + 2 * pad
        self.up_sample = tf.keras.layers.UpSampling2D(size=(1, length))
        self.zero_pad = (
            tf.keras.layers.ZeroPadding2D(pad)
            if length < 71
            else tf.keras.layers.Lambda(lambda x: x)
        )
        self.first_network = tf.keras.applications.Xception(
            include_top=False, input_shape=(padded_length, padded_length, 3)
        )
self.first_network.trainable = False
self.flatten = tf.keras.layers.Flatten()
self.second_network = FullyConnected(**self.sub_kwargs)
def call(self, input_tensor, training=False):
x = input_tensor
x = tf.keras.backend.expand_dims(x)
x = tf.keras.backend.expand_dims(x)
x = self.up_sample(x)
x = tf.keras.backend.repeat_elements(x, 3, axis=-1)
x = self.zero_pad(x)
x = self.first_network(x)
x = self.flatten(x)
x = self.second_network(x)
return x
def get_config(self):
return {"sub_kwargs": self.sub_kwargs}
class SplitHighLow(tf.keras.layers.Layer):
def call(self, inputs):
return inputs[:, :, :4], inputs[:, :, 4:]
class LL(tf.keras.layers.Layer):
def call(self, inputs):
return inputs[:, :, :4]
class LBNLayer(tf.keras.layers.Layer):
"""
Custom implementation of the LBNLayer with automatic cropping to
low-level variables before and batchnorm after the LBN layer.
Parameters
----------
args :
args for original LBN layer
kwargs :
kwargs for original LBN layer
"""
def __init__(self, *args, **kwargs):
super().__init__(name="LBNLayer")
_kwargs = kwargs.copy()
if "features" not in _kwargs:
_kwargs["features"] = self.all_features
self._args = args
self._kwargs = _kwargs
def build(self, input_shape):
self.concat = tf.keras.layers.Concatenate(axis=-2)
lbn_inp_shape = (sum(_[-2] for _ in input_shape), 4)
self.lbn_layer = OriginalLBNLayer(lbn_inp_shape, *self._args, **self._kwargs)
self.batch_norm = tf.keras.layers.BatchNormalization()
@property
def all_features(self):
return [
"E",
"px",
"py",
"pz",
"pt",
"p",
"m",
"phi",
"eta",
"beta",
"gamma",
"pair_cos",
"pair_dr",
"pair_ds",
"pair_dy",
]
def call(self, input_tensors, training=False):
ll = [SplitHighLow()(tensor)[0] for tensor in input_tensors]
ll = self.concat(ll)
feats = self.lbn_layer(ll)
feats = self.batch_norm(feats)
return feats
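# Usage sketch: the layer takes a list of (batch, n_particles, n_features) tensors,
# keeps only the first four (E, px, py, pz) columns of each, concatenates the
# particles and runs them through the original LBN followed by batch normalization.
#   lbn = LBNLayer(13)               # positional/keyword args are forwarded to the LBN
#   features = lbn([jets, leptons])  # hypothetical four-vector input tensors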
def feature_importance_grad(model, x=None, **kwargs):
inp = [tf.constant(v) for v in x]
with tf.GradientTape() as tape:
tape.watch(inp)
pred = model(inp, training=False)
ix = np.argsort(pred, axis=-1)[:, -1]
decision = tf.gather(pred, ix, batch_dims=1)
gradients = tape.gradient(decision, inp) # gradients for decision nodes
gradients = [grad if grad is not None else 0 for grad in gradients] # categorical tensors
gradients = [
np.array(_g) * _x.std(axis=0) for (_g, _x) in zip(gradients, x)
] # norm to value ranges
if "sample_weight" in kwargs:
gradients = [(_g.T * kwargs["sample_weight"]).T for _g in gradients] # apply ev weight
mean_gradients = np.concatenate([np.abs(g).mean(axis=0).flatten() for g in gradients])
return mean_gradients / mean_gradients.max()
def feature_importance_perm(model, x=None, **kwargs):
inp_list = list(x)
feat = 0 # loss
ref = model.evaluate(x=x, **kwargs)[feat]
vals = []
for index, tensor in enumerate(inp_list):
s = tensor.shape
for i in range(np.prod(s[1:])):
arr = tensor.reshape((-1, np.prod(s[1:])))
slice_before = arr[:, :i]
slice_shuffled = np.random.permutation(arr[:, i : i + 1])
slice_after = arr[:, i + 1 :]
arr_shuffled = np.concatenate([slice_before, slice_shuffled, slice_after], axis=-1)
arr_shuffled_reshaped = arr_shuffled.reshape(s)
valid = inp_list.copy()
valid[index] = arr_shuffled_reshaped
vals.append(model.evaluate(x=valid, **kwargs)[feat])
return np.array(vals) / ref
def feature_importance(*args, method="grad", columns=[], **kwargs):
if method == "grad":
importance = feature_importance_grad(*args, **kwargs)
elif method == "perm":
importance = feature_importance_perm(*args, **kwargs)
else:
raise NotImplementedError("Feature importance method not implemented")
return {
k: v
for k, v in sorted(
dict(zip(columns, importance.astype(float))).items(),
key=lambda item: item[1],
)
}
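# Usage sketch (mirrors the call in PlotMulticlass.make_eval_plots): rank input
# columns either by mean absolute gradient ("grad") or by the relative loss change
# under per-feature permutation ("perm").
#   ranking = feature_importance(model, x=inputs, y=labels, sample_weight=weights,
#                                method="grad", columns=column_names)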
# Custom Loss Functions
@tf.function
def cross_entropy_t(
labels,
predictions,
sample_weight=None,
focal_gamma=None,
class_weight=None,
epsilon=1e-7,
):
    # compute the per-class log-likelihood term (predictions clipped for stability)
predictions = tf.clip_by_value(predictions, epsilon, 1 - epsilon)
tn = labels * tf.math.log(predictions)
# focal loss?
if focal_gamma is not None:
tn *= (1 - predictions) ** focal_gamma
# convert into loss
loss = -tn
# apply class weights
if class_weight is not None:
loss *= class_weight
# apply sample weights
if sample_weight is not None:
loss *= sample_weight[:, tf.newaxis]
return tf.reduce_mean(tf.reduce_sum(loss, axis=-1))
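# In formula terms (sketch): with sample weights w_i, class weights cw_c and an
# optional focal factor gamma, the loss above is
#   L = mean_i( sum_c( -w_i * cw_c * y_ic * (1 - p_ic)^gamma * log(p_ic) ) ),
# where p is clipped to [epsilon, 1 - epsilon] before the logarithm.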
@tf.function
def grouped_cross_entropy_t(
labels,
predictions,
sample_weight=None,
group_ids=None,
focal_gamma=None,
class_weight=None,
epsilon=1e-7,
std_xent=True,
):
loss_terms = []
if std_xent:
loss_terms.append(
cross_entropy_t(
labels,
predictions,
sample_weight=sample_weight,
focal_gamma=focal_gamma,
class_weight=class_weight,
epsilon=epsilon,
)
)
if group_ids:
# create grouped labels and predictions
labels_grouped, predictions_grouped = [], []
for _, ids in group_ids:
labels_grouped.append(
tf.reduce_sum(tf.gather(labels, ids, axis=-1), axis=-1, keepdims=True)
)
predictions_grouped.append(
tf.reduce_sum(tf.gather(predictions, ids, axis=-1), axis=-1, keepdims=True)
)
labels_grouped = (
tf.concat(labels_grouped, axis=-1) if len(labels_grouped) > 1 else labels_grouped[0]
)
predictions_grouped = (
tf.concat(predictions_grouped, axis=-1)
if len(predictions_grouped) > 1
else predictions_grouped[0]
)
loss_terms.append(
cross_entropy_t(
labels_grouped,
predictions_grouped,
sample_weight=sample_weight,
focal_gamma=focal_gamma,
class_weight=tf.constant([w for w, _ in group_ids], tf.float32),
epsilon=epsilon,
)
)
return sum(loss_terms) / len(loss_terms)
class GroupedXEnt(tf.keras.losses.Loss):
def __init__(
self,
group_ids=None,
focal_gamma=None,
class_weight=None,
epsilon=1e-7,
std_xent=True,
*args,
**kwargs,
):
super(GroupedXEnt, self).__init__(*args, **kwargs)
self.group_ids = group_ids
self.focal_gamma = focal_gamma
self.class_weight = class_weight
self.epsilon = epsilon
self.std_xent = std_xent
def call(self, y_true, y_pred, sample_weight=None):
return grouped_cross_entropy_t(
y_true,
y_pred,
sample_weight=sample_weight,
group_ids=self.group_ids,
focal_gamma=self.focal_gamma,
class_weight=self.class_weight,
epsilon=self.epsilon,
std_xent=self.std_xent,
)
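# Usage sketch: group_ids is a list of (group_weight, class_indices) pairs; the loss
# averages the plain cross entropy with a cross entropy evaluated on the group sums.
#   loss = GroupedXEnt(group_ids=[(1.0, [0, 1]), (1.0, [2, 3, 4])])  # illustrative grouping
#   model.compile(optimizer="adam", loss=loss)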
def write_summary(model, target):
"""write tf keras model summary to law target"""
summary = []
model.summary(print_fn=lambda x: summary.append(x))
summary = "\n".join(summary)
target.dump(summary)