#from itertools import izip from collections import OrderedDict, defaultdict import fnmatch import numpy as np import tensorflow as tf from tqdm import tqdm from inspect import getargspec from warnings import warn import re from tensorflow.python.keras.callbacks import make_logs from tensorflow.python.keras.backend import track_variable from tensorflow.python.keras.utils.mode_keys import ModeKeys from operator import itemgetter from .evil import pin from .data import SKDict, DSS # various helper functions def kVar(*args, **kwargs): """ produce a keras-tracked tf.Variable from all given parameters """ var = tf.Variable(*args, **kwargs) track_variable(var) return var def kOpt(opt, **kwargs): """ instanciate a keras Optimizer with all applicable **kwargs """ if not callable(opt): opt = getattr(tf.keras.optimizers, opt) assert issubclass(opt, tf.keras.optimizers.Optimizer) args = getargspec(opt.__init__).args - {"self"} return opt(**{ k: v for k, v in kwargs.items() if k in args }) def kInput(ref, **kwargs): """ produce a keras.Input with shape & dtype according to ref """ kwargs.setdefault("shape", ref.shape[1:]) kwargs.setdefault("dtype", ref.dtype) return tf.keras.Input(**kwargs) def keras_register_custom_object(obj): """ decorator for globally registering a custom object with keras """ tf.keras.utils.get_custom_objects()[obj.__name__] = obj return obj class KFeed(object): def __init__(self, x, y, w=None, train="train", valid="valid", balance=lambda x: x): pin(locals()) @property def xyw(self): return tuple( v if isinstance(v, tuple) else (v,) for v in (self.x, self.y, self.w) ) def get(self, src): return tuple(itemgetter(*v)(src) for v in self.xyw) @property def all(self): return set().union(*self.xyw) def balance(self, src): return src def balance_weights(self, weights): if isinstance(weights, SKDict): sums = weights.map(np.sum) ref = np.mean(list(sums.values())) weights = weights.__class__({ k: weights[k] * (ref / len(weights[k])) for k, s in sums.items() }) return weights def kfeed(self, src, **kwargs): src = src.only(*self.all) bal = self.balance(src) return dict( zip(["x", "y", "sample_weight"], self.get(bal[self.train])), validation_data=self.get(bal[self.valid]), **kwargs ) def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, validation=None, **kwargs): src = src.only(*self.all) val = src[self.valid] assert not isinstance(self.w, tuple) val[self.w] = self.balance_weights(val[self.w]) val = val.fuse(*val[self.w].keys()) val.blen ret = dict( dict(zip( ("x", "steps_per_epoch"), self.gensteps(src[self.train], batch_size, rng=rng, auto_steps=auto_steps) )), validation_data=self.get(val), workers=0, **kwargs ) if validation is not None: ret.setdefault("callbacks", []).insert(0, CustomValidation( **dict(zip( ("x", "y", "sample_weight"), ret.pop("validation_data") ), **validation) )) return ret def gensteps(self, src, batch_size, rng=np.random, auto_steps=np.max): keys = src.mkeys(self.all) gen = ( self.get(DSS.zip(*parts).map(np.concatenate)) for parts in zip(*[ src[k].batch_generator( batch_size // len(keys), rng=np.random.RandomState(rng.randint(1 << 31, size=20)) ) for k in keys ]) ) if len(keys) > 1 else src[keys[0]].batch_generator(batch_size, rng=rng) gs = float(batch_size // len(keys)) steps = int(auto_steps([src[k].blen / gs for k in keys])) or None return gen, steps def generator(self, *args, **kwargs): assert "auto_steps" not in kwargs return self.gensteps(*args, **kwargs)[0] def getShapes(self, src): return tuple( tuple(src[k].shape for k in g) for g in self.xyw ) def mkInputs(self, src, **kwargs): return tuple( tf.keras.Input(shape=ref.shape[1:], dtype=ref.dtype, **kwargs) for ref in ( src[x] for x in self.xyw[0] ) ) def Normal(ref, indices=None, name=None, **kwargs): """ Normalizing layer according to ref. If given, only the variables corresponding to indices will be normalized. """ mean = ref.mean(**kwargs) std = ref.std(**kwargs) if indices: indices = np.array(indices) replace = np.isin(np.arange(len(mean)), indices, invert=True) mean[replace] = 0 std[replace] = 1 mul = 1.0 / std add = -mean / std return tf.keras.layers.Lambda((lambda x: (x * mul) + add), name=name) def Onehot(index, n, name=None): """ One hot encodes a variable referred to by index. n is the number of different variables. """ def to_onehot(x): #Concat zeros to eye for indices in x equal to n (larger than those encoded by one) eye = tf.concat((tf.eye(n), tf.zeros((1, n))), axis=0) return tf.concat(( x[...,:index], tf.gather(eye, tf.cast(x[...,index], tf.int64)), x[...,(index+1):], ), axis=-1) return tf.keras.layers.Lambda(to_onehot, name=name) class Moment(tf.keras.metrics.Mean): def __init__(self, order, label=False, **kwargs): """ Metric calculating the order-th moment """ assert order == int(order) assert label == bool(label) kwargs.setdefault("name", "%smom%d" % ("l" if label else "", order)) super(Moment, self).__init__(**kwargs) self.order = order self.label = label def update_state(self, y_true, y_pred, sample_weight=None): y = y_true if self.label else y_pred y = tf.keras.backend.cast(y, self._dtype) if self.order == 0: y = tf.keras.backend.ones_like(y) elif self.order == 1: pass elif self.order == 2: y = tf.keras.backend.square(y) else: y = tf.keras.backend.pow(y, self.order) return super(Moment, self).update_state(y, sample_weight=sample_weight) def get_config(self): return dict( super(Moment, self).get_config(), order=self.order, label=self.label, ) # lots of call backs def _patfilter(pattern, items): if isinstance(pattern, (list, tuple)): pattern = "|".join(map(fnmatch.translate, pattern)) return filter(re.compile(pattern).search, items) class Moment2Std(tf.keras.callbacks.Callback): def __init__(self, mom1="mom1", mom2="mom2", std="std"): assert mom1 and mom2 and std pin(locals()) def on_x_end(self, x, logs=None): if logs is None: return for key1, mom1 in list(logs.items()): if not key1.endswith(self.mom1): continue prefix = key1[:-len(self.mom1)] mom2 = logs.get(prefix + self.mom2, None) if mom2 is not None: logs[prefix + self.std] = (mom2 - mom1**2)**0.5 on_batch_end = on_x_end on_epoch_end = on_x_end class LogRewrite(tf.keras.callbacks.Callback): def __init__(self, *rewrites, **kwargs): self.collapse_weighted = kwargs.pop("collapse_weighted", False) assert not kwargs self.rewrites = list(rewrites) if self.collapse_weighted: self.rewrites.append(lambda s: s.replace("weighted_", "")) def on_x_end(self, x, logs=None): if logs is None: return shadow = [] for key, val in logs.items(): short = key for rewrite in self.rewrites: short = rewrite(short) if short == key: continue if short in logs: shadow.append(short) logs[short] = val del logs[key] if shadow: warn("%r: shadows the following keys: %s" % (",".join(shadow))) on_batch_end = on_x_end on_epoch_end = on_x_end class LogTransforms(tf.keras.callbacks.Callback): _prefixesRe = re.compile("^(val_|)(.+)") def __init__(self, *funcs, **kwargs): self.transforms = [] for func in funcs: self.add_transform(func) for name, func in kwargs.items(): self.add_transform(func, name) def add_transform(self, func, name=None): if name is None: name = func.__name__ assert re.match(r"[a-zA-Z]\w*$", name) argspec = getargspec(func) assert argspec.keywords or argspec.args self.transforms.append((name, func, True if argspec.keywords else set(argspec.args))) def on_x_end(self, x, logs=None): if logs is None: return plogs = {} for key, value in logs.items(): prefix, suffix = self._prefixesRe.match(key).groups() plogs.setdefault(prefix, {})[suffix] = value for name, func, args in self.transforms: for prefix, log in plogs.items(): if args is True: val = func(**log) else: missing = args.difference(log.keys()) if missing: msg = "%r: %s=%r needs unavailable log values: %s" % ( self, name, func, ", ".join(missing) ) if prefix: warn(msg) continue else: raise RuntimeError(msg) val = func(**{ key: val for key, val in log.items() if key in args }) logs[prefix + name] = log[name] = val on_batch_end = on_x_end on_epoch_end = on_x_end class LogEMA(tf.keras.callbacks.Callback): def __init__(self, keys=[], pattern=[], pairs={}, ema=0, ricu=0, ema_fmt="%s_ema"): assert ema < 1 and ricu < 1 and "%s" in ema_fmt if ema < 0: ema = 1 - 10**ema assert ema < 1 if ricu < 0: ricu = -1. / ricu keys = {k: None for k in keys} keys.update(pairs) pin(locals(), pairs) self.reset() def reset(self): self.ema_val = {} self.run_num = defaultdict(int) def on_epoch_end(self, x, logs=None): assert logs keys = { key: None for key in _patfilter(self.pattern, logs.keys()) } keys.update(self.keys) for key, out in keys.items(): if not out: out = self.ema_fmt % key assert key in logs assert out not in keys assert out not in logs val = logs[key] if key in self.ema_val: delta = self.ema_val[key] - val dsign = int(np.sign(delta)) self.run_num[key] += dsign if dsign * np.sign(self.run_num[key]) < -self.ricu < 0: self.run_num[key] *= self.ricu**.5 val += delta * self.ema**(1 + (self.ricu * self.run_num[key])**4) logs[out] = self.ema_val[key] = val class CustomValidation(tf.keras.callbacks.Callback): def __init__(self, **kwargs): self.kwargs = kwargs def on_epoch_end(self, x, logs=None): if logs is None: return res = self.model.evaluate(**self.kwargs) if not isinstance(res, list): res = [res] logs.update(make_logs(self.model, logs, res, mode=ModeKeys.TEST, prefix="val_")) class ModelLH(tf.keras.Model): def __init__(self, *args, **kwargs): self.loss_hook = kwargs.pop("loss_hook", None) super(ModelLH, self).__init__(*args, **kwargs) def _update_sample_weight_modes(self, sample_weights=None): if not self._is_compiled: return if sample_weights and any([s is not None for s in sample_weights]): pass # don't default sample_weight_mode to "samplewise", it prevents proper function caching # for endpoint in self._training_endpoints: # endpoint.sample_weight_mode = ( # endpoint.sample_weight_mode or 'samplewise') else: for endpoint in self._training_endpoints: endpoint.sample_weight_mode = None def _prepare_total_loss(self, *args, **kwargs): orig = [ (ep, ep.__dict__.copy(), ep.training_target.__dict__.copy()) for ep in self._training_endpoints ] self.loss_hook(self._training_endpoints.copy()) ret = super(ModelLH, self)._prepare_total_loss(*args, **kwargs) for ep, ed, td in orig: ep.__dict__.update(ed) ep.training_target.__dict__.update(td) return ret class TensorBoard(tf.keras.callbacks.TensorBoard): def __init__(self, *args, **kwargs): self.writer = kwargs.pop("writer") super(TensorBoard, self).__init__(*args, **kwargs) def _init_writer(self, model=None): pass def on_train_end(self, logs=None): pass class BestTracker(tf.keras.callbacks.Callback): def __init__(self, monitor="val_loss", mode="auto", min_delta=0, min_delta_rel=0, baseline=None): pin(locals()) self.reset() @property def mode_multiplier(self): assert self.mode in ("auto", "min", "max") if self.mode == "max" or (self.mode == "auto" and "acc" in self.monitor): return -1 else: return 1 def reset(self): if self.baseline is not None: self.best = self.baseline else: self.best = self.mode_multiplier * np.inf def update_best(self, logs): current = self.get_monitor_value(logs) relative = delta = (current - self.best) * self.mode_multiplier if np.isfinite(self.best): relative /= abs(self.best) or 1. update = delta < -self.min_delta and relative < -self.min_delta_rel if update: self.best = current return update def get_monitor_value(self, logs): assert logs assert self.monitor in logs return logs[self.monitor] class PatientTracker(BestTracker): def __init__(self, patience=10, cooldown=0, cooldown0=0, override=None, jank=np.nan, jank_last=-np.inf, **kwargs): if override is None: del override cooldown_counter = cooldown if cooldown0 is True else cooldown0 pin(locals(), kwargs) super(PatientTracker, self).__init__(**kwargs) def reset(self): super(PatientTracker, self).reset() self.wait = 0 def override(self, logs): return None def patient_step(self, epoch, logs): # update jank logs["jank"] = min( logs.get("jank", np.inf), epoch - self.jank_last ) # test override override = self.override(logs) if override is not None: return override # check jank if logs["jank"] <= self.jank: return if 0 < self.cooldown_counter: self.cooldown_counter -= 1 if self.update_best(logs): self.wait = 0 return "best" elif self.cooldown_counter <= 0: self.wait += 1 if self.patience <= self.wait: self.cooldown_counter = self.cooldown return "good" class ScaleOnPlateau(PatientTracker): def __init__(self, target, factor, min=None, max=None, verbose=0, log_key=None, **kwargs): pin(locals(), kwargs) super(ScaleOnPlateau, self).__init__(**kwargs) def on_train_begin(self, logs=None): self.reset() @property def value(self): return tf.keras.backend.get_value(self.target) @value.setter def value(self, new): return tf.keras.backend.set_value(self.target, new) def on_epoch_end(self, epoch, logs=None): logs = logs or {} cur = self.value if self.log_key is not None: logs.setdefault(self.log_key, cur) if self.patient_step(epoch, logs) == "good": new = cur * self.factor if self.min is not None: new = max(new, self.min) if self.max is not None: new = min(new, self.max) if cur != new: self.value = new self.reset() if np.isfinite(self.jank): self.jank_last = epoch if self.verbose > 0: print("\nEpoch %05d: %s scaling %s to %s." % ( epoch + 1, self.__class__.__name__, self.log_key or self.target, new )) class ReduceLROnPlateau(ScaleOnPlateau): def __init__(self, min_lr=0, **kwargs): super(ReduceLROnPlateau, self).__init__(min=min_lr, target=None, log_key="lr", **kwargs) @property def target(self): return self.model.optimizer.lr @target.setter def target(self, target): assert target is None class EarlyStopping(PatientTracker): def __init__(self, restore_best_weights=False, verbose=0, do_stop=None, **kwargs): pin(locals(), kwargs) super(EarlyStopping, self).__init__(**kwargs) def reset(self): super(EarlyStopping, self).reset() self.best_weights = None self.best_epoch = None def on_train_begin(self, logs=None): self.reset() def on_epoch_end(self, epoch, logs=None): action = self.patient_step(epoch, logs) if action == "best": if self.restore_best_weights: self.best_weights = self.model.get_weights() self.best_epoch = epoch elif action == "good": self.stopped_epoch = epoch self.model.stop_training = True if self.restore_best_weights and self.best_weights is not None: self.model.set_weights(self.best_weights) if self.verbose > 0: print('Restoring model weights from the end of the best epoch.') def on_train_end(self, logs=None): hist = self.model.history if self.restore_best_weights and self.best_epoch in hist.epoch: idx = hist.epoch.index(self.best_epoch) else: idx = -1 hist.final_logs = { k: v[idx] for k, v in hist.history.items() } class TQES(EarlyStopping): def __init__(self, log_pattern=None, prog_batch="steps", **kwargs): assert prog_batch in (None, "steps", "samples") pin(locals(), kwargs) super(TQES, self).__init__(**kwargs) @property def use_steps(self): return self.prog_batch == "steps" def on_train_begin(self, logs=None): super(TQES, self).on_train_begin(logs) self.tqE = tqdm(desc=getattr(self.model, "name", None), total=self.params["epochs"]) def on_epoch_begin(self, epoch, logs=None): if self.prog_batch: self.tqB = tqdm( unit=("batch" if self.use_steps else "sample"), total=self.params["steps" if self.use_steps else "samples"] ) def on_batch_end(self, batch, logs=None): if self.tqB: logs = logs or {} self.tqB.set_postfix(self.make_postfix(logs), refresh=False) self.tqB.update(logs.get("num_steps", 1) * ( 1 if self.use_steps else logs.get("size", 0) )) def on_epoch_end(self, epoch, logs=None): super(TQES, self).on_epoch_end(epoch, logs) last = self.get_monitor_value(logs) if self.tqB: self.tqB.close() self.tqB = None self.tqE.set_postfix(self.make_postfix(logs, [ ("best", self.best), ("conf", self.wait), ("rdlb", ((last or 0.) - self.best) / self.best) ]), refresh=False) self.tqE.update(epoch - self.tqE.n) def make_postfix(self, logs, extra=[]): if self.log_pattern is None: keys = self.params["metrics"] else: keys = _patfilter(self.log_pattern, logs.keys()) return OrderedDict([ (key, logs[key]) for key in sorted(keys) if key in logs ] + list(filter(None, extra))) def on_train_end(self, logs=None): super(TQES, self).on_train_end(logs) self.tqE.close()