diff --git a/keras.py b/keras.py
index 2c2fe84c8e20ff8dca8f1b014c3f34d1d30ec772..c211feb25cb1ef26fe27280de5bd30e089d3cce4 100644
--- a/keras.py
+++ b/keras.py
@@ -339,13 +339,25 @@ class ModelLH(tf.keras.Model):
     def __init__(self, *args, **kwargs):
         self.loss_hook = kwargs.pop("loss_hook", None)
         super(ModelLH, self).__init__(*args, **kwargs)
-
+    
+    def _update_sample_weight_modes(self, sample_weights=None):
+        if not self._is_compiled:
+            return
+        if sample_weights and any([s is not None for s in sample_weights]):
+            pass
+#            for endpoint in self._training_endpoints:
+#                endpoint.sample_weight_mode = (
+#                    endpoint.sample_weight_mode or 'samplewise')
+        else:
+            for endpoint in self._training_endpoints:
+                endpoint.sample_weight_mode = None
+    
     def _prepare_total_loss(self, *args, **kwargs):
         orig = [
             (ep, ep.__dict__.copy(), ep.training_target.__dict__.copy())
             for ep in self._training_endpoints
         ]
-
+        
         self.loss_hook(self._training_endpoints.copy())
         ret = super(ModelLH, self)._prepare_total_loss(*args, **kwargs)