diff --git a/keras.py b/keras.py index 8668f2e56344d9feda67b273408298d2ebaa98ed..e80d991f9aa9f3eec01c810e662a4263d0bf7f39 100644 --- a/keras.py +++ b/keras.py @@ -7,8 +7,9 @@ from tqdm import tqdm from inspect import getargspec from warnings import warn import re -#from tensorflow.python.keras.engine.training_utils import make_logs +from tensorflow.python.keras.callbacks import make_logs from tensorflow.python.keras.backend import track_variable +from tensorflow.python.keras.utils.mode_keys import ModeKeys from operator import itemgetter from .evil import pin @@ -68,9 +69,9 @@ class KFeed(object): def balance_weights(self, weights): if isinstance(weights, SKDict): sums = weights.map(np.sum) - ref = np.mean(sums.values()) + ref = np.mean(list(sums.values())) weights = weights.__class__({ - k: weights[k] * (ref / s) + k: weights[k] * (ref / len(weights[k])) for k, s in sums.items() }) return weights @@ -92,7 +93,7 @@ class KFeed(object): val.blen ret = dict( dict(zip( - ("generator", "steps_per_epoch"), + ("x", "steps_per_epoch"), self.gensteps(src[self.train], batch_size, rng=rng, auto_steps=auto_steps) )), validation_data=self.get(val), @@ -112,7 +113,7 @@ class KFeed(object): keys = src.mkeys(self.all) gen = ( self.get(DSS.zip(*parts).map(np.concatenate)) - for parts in izip(*[ + for parts in zip(*[ src[k].batch_generator( batch_size // len(keys), rng=np.random.RandomState(rng.randint(1 << 31, size=20)) @@ -331,7 +332,7 @@ class CustomValidation(tf.keras.callbacks.Callback): res = self.model.evaluate(**self.kwargs) if not isinstance(res, list): res = [res] - logs.update(make_logs(self.model, res, "train", prefix="val_")) + logs.update(make_logs(self.model, logs, res, mode=ModeKeys.TEST, prefix="val_")) class ModelWH(tf.keras.Model): @@ -573,7 +574,7 @@ class TQES(EarlyStopping): (key, logs[key]) for key in sorted(keys) if key in logs - ] + filter(None, extra)) + ] + list(filter(None, extra))) def on_train_end(self, logs=None): super(TQES, self).on_train_end(logs)