diff --git a/keras.py b/keras.py index 2c2fe84c8e20ff8dca8f1b014c3f34d1d30ec772..c211feb25cb1ef26fe27280de5bd30e089d3cce4 100644 --- a/keras.py +++ b/keras.py @@ -339,13 +339,25 @@ class ModelLH(tf.keras.Model): def __init__(self, *args, **kwargs): self.loss_hook = kwargs.pop("loss_hook", None) super(ModelLH, self).__init__(*args, **kwargs) - + + def _update_sample_weight_modes(self, sample_weights=None): + if not self._is_compiled: + return + if sample_weights and any([s is not None for s in sample_weights]): + pass +# for endpoint in self._training_endpoints: +# endpoint.sample_weight_mode = ( +# endpoint.sample_weight_mode or 'samplewise') + else: + for endpoint in self._training_endpoints: + endpoint.sample_weight_mode = None + def _prepare_total_loss(self, *args, **kwargs): orig = [ (ep, ep.__dict__.copy(), ep.training_target.__dict__.copy()) for ep in self._training_endpoints ] - + self.loss_hook(self._training_endpoints.copy()) ret = super(ModelLH, self)._prepare_total_loss(*args, **kwargs)