diff --git a/keras.py b/keras.py index e80d991f9aa9f3eec01c810e662a4263d0bf7f39..2c2fe84c8e20ff8dca8f1b014c3f34d1d30ec772 100644 --- a/keras.py +++ b/keras.py @@ -335,14 +335,24 @@ class CustomValidation(tf.keras.callbacks.Callback): logs.update(make_logs(self.model, logs, res, mode=ModeKeys.TEST, prefix="val_")) -class ModelWH(tf.keras.Model): +class ModelLH(tf.keras.Model): def __init__(self, *args, **kwargs): - self.sample_weight_hook = kwargs.pop("sample_weight_hook", None) - super(ModelWH, self).__init__(*args, **kwargs) + self.loss_hook = kwargs.pop("loss_hook", None) + super(ModelLH, self).__init__(*args, **kwargs) + + def _prepare_total_loss(self, *args, **kwargs): + orig = [ + (ep, ep.__dict__.copy(), ep.training_target.__dict__.copy()) + for ep in self._training_endpoints + ] + + self.loss_hook(self._training_endpoints.copy()) + ret = super(ModelLH, self)._prepare_total_loss(*args, **kwargs) + + for ep, ed, td in orig: + ep.__dict__.update(ed) + ep.training_target.__dict__.update(td) - def _set_sample_weight_attributes(self, *args, **kwargs): - ret = super(ModelWH, self)._set_sample_weight_attributes(*args, **kwargs) - self.sample_weight_hook(self) return ret