diff --git a/keras.py b/keras.py index 26056a41330f41d11f5b5b7e04ee2981c9f7e6ab..70034009dd06b85403ea1d3de02f5326d6c701f7 100644 --- a/keras.py +++ b/keras.py @@ -69,9 +69,7 @@ class KFeed(object): if isinstance(weights, SKDict): sums = weights.map(np.sum) ref = np.mean(list(sums.values())) - weights = weights.__class__( - {k: weights[k] * (ref / len(weights[k])) for k, s in sums.items()} - ) + weights = weights.__class__({k: weights[k] * (ref / s) for k, s in sums.items()}) return weights def kfeed(self, src, **kwargs): @@ -84,6 +82,11 @@ class KFeed(object): ) def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, validation=None, **kwargs): + """ + Creates a generator for tf.keras' model.fit(). + Requires that the mean weights per process are equal: + dss["weight"] = dss["weight"].map(lambda x: x / np.mean(x)) + """ src = src.only(*self.all) val = src[self.valid] assert not isinstance(self.w, tuple) @@ -102,11 +105,13 @@ class KFeed(object): **kwargs ) if validation is not None: + validation_copy = validation.copy() ret.setdefault("callbacks", []).insert( 0, - CustomValidation( + validation_copy.pop("cls", CustomValidation)( **dict( - zip(("x", "y", "sample_weight"), ret.pop("validation_data")), **validation + zip(("x", "y", "sample_weight"), ret.pop("validation_data")), + **validation_copy ) ), )