diff --git a/keras.py b/keras.py
index c211feb25cb1ef26fe27280de5bd30e089d3cce4..b59a31365b8bb2675496ac872d13531057785b27 100644
--- a/keras.py
+++ b/keras.py
@@ -56,12 +56,18 @@ class KFeed(object):
         pin(locals())
 
     @property
-    def get(self):
-        return itemgetter(self.x, self.y, self.w)
+    def xyw(self):
+        return tuple(
+            v if isinstance(v, tuple) else (v,)
+            for v in (self.x, self.y, self.w)
+        )
+
+    def get(self, src):
+        return tuple(itemgetter(*v)(src) for v in self.xyw)
 
     @property
     def all(self):
-        return {self.x, self.y, self.w}
+        return set().union(*self.xyw)
 
     def balance(self, src):
         return src
@@ -88,6 +94,7 @@ class KFeed(object):
     def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, validation=None, **kwargs):
         src = src.only(*self.all)
         val = src[self.valid]
+        assert not isinstance(self.w, tuple)
         val[self.w] = self.balance_weights(val[self.w])
         val = val.fuse(*val[self.w].keys())
         val.blen
@@ -339,25 +346,26 @@ class ModelLH(tf.keras.Model):
     def __init__(self, *args, **kwargs):
         self.loss_hook = kwargs.pop("loss_hook", None)
         super(ModelLH, self).__init__(*args, **kwargs)
-    
+
     def _update_sample_weight_modes(self, sample_weights=None):
         if not self._is_compiled:
             return
         if sample_weights and any([s is not None for s in sample_weights]):
             pass
-#            for endpoint in self._training_endpoints:
-#                endpoint.sample_weight_mode = (
-#                    endpoint.sample_weight_mode or 'samplewise')
+            # don't default sample_weight_mode to "samplewise", it prevents proper function caching
+            # for endpoint in self._training_endpoints:
+            #     endpoint.sample_weight_mode = (
+            #         endpoint.sample_weight_mode or 'samplewise')
         else:
             for endpoint in self._training_endpoints:
                 endpoint.sample_weight_mode = None
-    
+
     def _prepare_total_loss(self, *args, **kwargs):
         orig = [
             (ep, ep.__dict__.copy(), ep.training_target.__dict__.copy())
             for ep in self._training_endpoints
         ]
-        
+
         self.loss_hook(self._training_endpoints.copy())
         ret = super(ModelLH, self)._prepare_total_loss(*args, **kwargs)