Skip to content
Snippets Groups Projects
Commit 4d2fb3a8 authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] gfeed: fixes weighting + adds possibility for cust. val callback

parent 677c6253
No related branches found
No related tags found
No related merge requests found
......@@ -69,9 +69,7 @@ class KFeed(object):
if isinstance(weights, SKDict):
sums = weights.map(np.sum)
ref = np.mean(list(sums.values()))
weights = weights.__class__(
{k: weights[k] * (ref / len(weights[k])) for k, s in sums.items()}
)
weights = weights.__class__({k: weights[k] * (ref / s) for k, s in sums.items()})
return weights
def kfeed(self, src, **kwargs):
......@@ -84,6 +82,11 @@ class KFeed(object):
)
def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, validation=None, **kwargs):
"""
Creates a generator for tf.keras' model.fit().
Requires, that mean weights per process are equal:
dss["weight"] = dss["weight"].map(lambda x: x / np.mean(x))
"""
src = src.only(*self.all)
val = src[self.valid]
assert not isinstance(self.w, tuple)
......@@ -102,11 +105,13 @@ class KFeed(object):
**kwargs
)
if validation is not None:
validation_copy = validation.copy()
ret.setdefault("callbacks", []).insert(
0,
CustomValidation(
validation_copy.pop("cls", CustomValidation)(
**dict(
zip(("x", "y", "sample_weight"), ret.pop("validation_data")), **validation
zip(("x", "y", "sample_weight"), ret.pop("validation_data")),
**validation_copy
)
),
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment