Skip to content
Snippets Groups Projects
Commit 3c4a65e0 authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] gfeed: now uses list of validation callbacks and not only one

parent af94627b
No related branches found
No related tags found
No related merge requests found
......@@ -90,7 +90,7 @@ class KFeed(object):
**kwargs,
)
def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, validation=None, **kwargs):
def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, val_callbacks=[], **kwargs):
"""
Creates a generator for tf.keras' model.fit().
Requires, that mean weights per process are equal:
......@@ -113,14 +113,14 @@ class KFeed(object):
workers=0,
**kwargs,
)
if validation is not None:
validation_copy = validation.copy()
for val_callback in val_callbacks:
val_callback_copy = val_callback.copy()
ret.setdefault("callbacks", []).insert(
0,
validation_copy.pop("cls", CustomValidation)(
val_callback_copy.pop("cls", CustomValidation)(
**dict(
zip(("x", "y", "sample_weight"), ret.pop("validation_data")),
**validation_copy,
zip(("x", "y", "sample_weight"), ret.get("validation_data")),
**val_callback_copy,
)
),
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment