diff --git a/edml/core/client.py b/edml/core/client.py
index 9211cda8617085bc3189b12904d30d24f51090ba..5e99d875eb155ae165b3dbe750c406f09a4e0e60 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -199,6 +199,7 @@ class DeviceClient:
         self.node_device.battery.update_flops(
             self._model_flops * len(batch_data) * 2
         )  # 2x for backward pass
+        smashed_data.to(self._device)
         smashed_data.backward(gradients)
         self._optimizer.step()
 
@@ -279,6 +280,7 @@ class DeviceClient:
                 self.node_device.battery.update_flops(
                     self._model_flops * len(batch_data) * 2
                 )  # 2x for backward pass
+                smashed_data.to(self._device)
                 smashed_data.backward(server_grad)
                 self._optimizer.step()