diff --git a/keras.py b/keras.py index c211feb25cb1ef26fe27280de5bd30e089d3cce4..b59a31365b8bb2675496ac872d13531057785b27 100644 --- a/keras.py +++ b/keras.py @@ -56,12 +56,18 @@ class KFeed(object): pin(locals()) @property - def get(self): - return itemgetter(self.x, self.y, self.w) + def xyw(self): + return tuple( + v if isinstance(v, tuple) else (v,) + for v in (self.x, self.y, self.w) + ) + + def get(self, src): + return tuple(itemgetter(*v)(src) for v in self.xyw) @property def all(self): - return {self.x, self.y, self.w} + return set().union(*self.xyz) def balance(self, src): return src @@ -88,6 +94,7 @@ class KFeed(object): def gfeed(self, src, batch_size, rng=np.random, auto_steps=np.max, validation=None, **kwargs): src = src.only(*self.all) val = src[self.valid] + assert not isinstance(self.w, tuple) val[self.w] = self.balance_weights(val[self.w]) val = val.fuse(*val[self.w].keys()) val.blen @@ -339,25 +346,26 @@ class ModelLH(tf.keras.Model): def __init__(self, *args, **kwargs): self.loss_hook = kwargs.pop("loss_hook", None) super(ModelLH, self).__init__(*args, **kwargs) - + def _update_sample_weight_modes(self, sample_weights=None): if not self._is_compiled: return if sample_weights and any([s is not None for s in sample_weights]): pass -# for endpoint in self._training_endpoints: -# endpoint.sample_weight_mode = ( -# endpoint.sample_weight_mode or 'samplewise') + # don't default sample_weight_mode to "samplewise", it prevents proper function caching + # for endpoint in self._training_endpoints: + # endpoint.sample_weight_mode = ( + # endpoint.sample_weight_mode or 'samplewise') else: for endpoint in self._training_endpoints: endpoint.sample_weight_mode = None - + def _prepare_total_loss(self, *args, **kwargs): orig = [ (ep, ep.__dict__.copy(), ep.training_target.__dict__.copy()) for ep in self._training_endpoints ] - + self.loss_hook(self._training_endpoints.copy()) ret = super(ModelLH, self)._prepare_total_loss(*args, **kwargs)