From 3d1dff560c10aaab07700ee7cb250dcafd775c1e Mon Sep 17 00:00:00 2001
From: Dennis Noll <dennis.noll@rwth-aachen.de>
Date: Tue, 28 Jul 2020 13:38:29 +0200
Subject: [PATCH] [keras] gfeed: removed validation

---
 keras.py | 17 +++--------------
 1 file changed, 3 insertions(+), 14 deletions(-)

diff --git a/keras.py b/keras.py
index 0598cd3..0f39c0e 100644
--- a/keras.py
+++ b/keras.py
@@ -93,7 +93,7 @@ class KFeed(object):
             **kwargs,
         )
 
-    def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, val_callbacks=[], **kwargs):
+    def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, **kwargs):
         """
         Creates a generator for tf.keras' model.fit().
         Requires, that mean weights per process are equal:
@@ -105,7 +105,7 @@ class KFeed(object):
         val[self.w] = self.balance_weights(val[self.w])
         val = val.fuse(*val[self.w].keys())
         val.blen
-        ret = dict(
+        return dict(
             dict(
                 zip(
                     ("x", "steps_per_epoch"),
@@ -116,18 +116,6 @@ class KFeed(object):
             workers=0,
             **kwargs,
         )
-        for val_callback in val_callbacks:
-            val_callback_copy = val_callback.copy()
-            ret.setdefault("callbacks", []).insert(
-                0,
-                val_callback_copy.pop("cls", CustomValidation)(
-                    **dict(
-                        zip(("x", "y", "sample_weight"), ret.get("validation_data")),
-                        **val_callback_copy,
-                    )
-                ),
-            )
-        return ret
 
     def gensteps(self, src, batch_size, rng=np.random, auto_steps=np.max):
         keys = src.mkeys(self.all)
@@ -415,6 +403,7 @@ class PlotMulticlass(TFSummaryCallback):
         self.truth = kwargs["y"]
         self.sample_weight = kwargs.get("sample_weight", None)
         self.class_names = class_names
+        self.plot_inputs = plot_inputs
         self.columns = columns
         self.to_file = to_file
 
-- 
GitLab