From 3d1dff560c10aaab07700ee7cb250dcafd775c1e Mon Sep 17 00:00:00 2001 From: Dennis Noll <dennis.noll@rwth-aachen.de> Date: Tue, 28 Jul 2020 13:38:29 +0200 Subject: [PATCH] [keras] gfeed: removed validation --- keras.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/keras.py b/keras.py index 0598cd3..0f39c0e 100644 --- a/keras.py +++ b/keras.py @@ -93,7 +93,7 @@ class KFeed(object): **kwargs, ) - def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, val_callbacks=[], **kwargs): + def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, **kwargs): """ Creates a generator for tf.keras' model.fit(). Requires, that mean weights per process are equal: @@ -105,7 +105,7 @@ class KFeed(object): val[self.w] = self.balance_weights(val[self.w]) val = val.fuse(*val[self.w].keys()) val.blen - ret = dict( + return dict( dict( zip( ("x", "steps_per_epoch"), @@ -116,18 +116,6 @@ class KFeed(object): workers=0, **kwargs, ) - for val_callback in val_callbacks: - val_callback_copy = val_callback.copy() - ret.setdefault("callbacks", []).insert( - 0, - val_callback_copy.pop("cls", CustomValidation)( - **dict( - zip(("x", "y", "sample_weight"), ret.get("validation_data")), - **val_callback_copy, - ) - ), - ) - return ret def gensteps(self, src, batch_size, rng=np.random, auto_steps=np.max): keys = src.mkeys(self.all) @@ -415,6 +403,7 @@ class PlotMulticlass(TFSummaryCallback): self.truth = kwargs["y"] self.sample_weight = kwargs.get("sample_weight", None) self.class_names = class_names + self.plot_inputs = plot_inputs self.columns = columns self.to_file = to_file -- GitLab