diff --git a/keras.py b/keras.py
index 8668f2e56344d9feda67b273408298d2ebaa98ed..e80d991f9aa9f3eec01c810e662a4263d0bf7f39 100644
--- a/keras.py
+++ b/keras.py
@@ -7,8 +7,9 @@ from tqdm import tqdm
 from inspect import getargspec
 from warnings import warn
 import re
-#from tensorflow.python.keras.engine.training_utils import make_logs
+from tensorflow.python.keras.callbacks import make_logs
 from tensorflow.python.keras.backend import track_variable
+from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from operator import itemgetter
 
 from .evil import pin
@@ -68,9 +69,9 @@ class KFeed(object):
     def balance_weights(self, weights):
         if isinstance(weights, SKDict):
             sums = weights.map(np.sum)
-            ref = np.mean(sums.values())
+            ref = np.mean(list(sums.values()))
             weights = weights.__class__({
-                k: weights[k] * (ref / s)
+                k: weights[k] * (ref / len(weights[k]))
                 for k, s in sums.items()
             })
         return weights
@@ -92,7 +93,7 @@ class KFeed(object):
         val.blen
         ret = dict(
             dict(zip(
-                ("generator", "steps_per_epoch"),
+                ("x", "steps_per_epoch"),
                 self.gensteps(src[self.train], batch_size, rng=rng, auto_steps=auto_steps)
             )),
             validation_data=self.get(val),
@@ -112,7 +113,7 @@ class KFeed(object):
         keys = src.mkeys(self.all)
         gen = (
             self.get(DSS.zip(*parts).map(np.concatenate))
-            for parts in izip(*[
+            for parts in zip(*[
                 src[k].batch_generator(
                     batch_size // len(keys),
                     rng=np.random.RandomState(rng.randint(1 << 31, size=20))
@@ -331,7 +332,7 @@ class CustomValidation(tf.keras.callbacks.Callback):
         res = self.model.evaluate(**self.kwargs)
         if not isinstance(res, list):
             res = [res]
-        logs.update(make_logs(self.model, res, "train", prefix="val_"))
+        logs.update(make_logs(self.model, logs, res, mode=ModeKeys.TEST, prefix="val_"))
 
 
 class ModelWH(tf.keras.Model):
@@ -573,7 +574,7 @@ class TQES(EarlyStopping):
             (key, logs[key])
             for key in sorted(keys)
             if key in logs
-        ] + filter(None, extra))
+        ] + list(filter(None, extra)))
 
     def on_train_end(self, logs=None):
         super(TQES, self).on_train_end(logs)