diff --git a/keras.py b/keras.py
index 26056a41330f41d11f5b5b7e04ee2981c9f7e6ab..70034009dd06b85403ea1d3de02f5326d6c701f7 100644
--- a/keras.py
+++ b/keras.py
@@ -69,9 +69,7 @@ class KFeed(object):
         if isinstance(weights, SKDict):
             sums = weights.map(np.sum)
             ref = np.mean(list(sums.values()))
-            weights = weights.__class__(
-                {k: weights[k] * (ref / len(weights[k])) for k, s in sums.items()}
-            )
+            weights = weights.__class__({k: weights[k] * (ref / s) for k, s in sums.items()})
         return weights
 
     def kfeed(self, src, **kwargs):
@@ -84,6 +82,11 @@ class KFeed(object):
         )
 
     def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, validation=None, **kwargs):
+        """
+        Creates a generator for tf.keras' model.fit().
+        Requires, that mean weights per process are equal:
+            dss["weight"] = dss["weight"].map(lambda x: x / np.mean(x))
+        """
         src = src.only(*self.all)
         val = src[self.valid]
         assert not isinstance(self.w, tuple)
@@ -102,11 +105,13 @@ class KFeed(object):
             **kwargs
         )
         if validation is not None:
+            validation_copy = validation.copy()
             ret.setdefault("callbacks", []).insert(
                 0,
-                CustomValidation(
+                validation_copy.pop("cls", CustomValidation)(
                     **dict(
-                        zip(("x", "y", "sample_weight"), ret.pop("validation_data")), **validation
+                        zip(("x", "y", "sample_weight"), ret.pop("validation_data")),
+                        **validation_copy
                     )
                 ),
             )