diff --git a/keras.py b/keras.py
index e80d991f9aa9f3eec01c810e662a4263d0bf7f39..2c2fe84c8e20ff8dca8f1b014c3f34d1d30ec772 100644
--- a/keras.py
+++ b/keras.py
@@ -335,14 +335,24 @@ class CustomValidation(tf.keras.callbacks.Callback):
         logs.update(make_logs(self.model, logs, res, mode=ModeKeys.TEST, prefix="val_"))
 
 
-class ModelWH(tf.keras.Model):
+class ModelLH(tf.keras.Model):
     def __init__(self, *args, **kwargs):
-        self.sample_weight_hook = kwargs.pop("sample_weight_hook", None)
-        super(ModelWH, self).__init__(*args, **kwargs)
+        self.loss_hook = kwargs.pop("loss_hook", None)
+        super(ModelLH, self).__init__(*args, **kwargs)
+
+    def _prepare_total_loss(self, *args, **kwargs):
+        orig = [
+            (ep, ep.__dict__.copy(), ep.training_target.__dict__.copy())
+            for ep in self._training_endpoints
+        ]
+
+        self.loss_hook(self._training_endpoints.copy())
+        ret = super(ModelLH, self)._prepare_total_loss(*args, **kwargs)
+
+        for ep, ed, td in orig:
+            ep.__dict__.update(ed)
+            ep.training_target.__dict__.update(td)
 
-    def _set_sample_weight_attributes(self, *args, **kwargs):
-        ret = super(ModelWH, self)._set_sample_weight_attributes(*args, **kwargs)
-        self.sample_weight_hook(self)
         return ret