diff --git a/edml/core/client.py b/edml/core/client.py
index 5e99d875eb155ae165b3dbe750c406f09a4e0e60..75fb2bfbb360eeca936ec4aad9ba69b3fb68f459 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -199,7 +199,7 @@ class DeviceClient:
         self.node_device.battery.update_flops(
             self._model_flops * len(batch_data) * 2
         )  # 2x for backward pass
-        smashed_data.to(self._device)
+        gradients = gradients.to(self._device)
         smashed_data.backward(gradients)
         self._optimizer.step()
 
@@ -280,7 +280,7 @@ class DeviceClient:
                 self.node_device.battery.update_flops(
                     self._model_flops * len(batch_data) * 2
                 )  # 2x for backward pass
-                smashed_data.to(self._device)
+                server_grad = server_grad.to(self._device)
                 smashed_data.backward(server_grad)
                 self._optimizer.step()
 
diff --git a/edml/core/server.py b/edml/core/server.py
index f8703de677bd88a36abbe0aef0f6cd32581e1499..a671ba98e404c6981b35eb90d3e5bb1c17084304 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -205,7 +205,7 @@ class DeviceServer:
     def evaluate_batch(self, smashed_data, labels):
         """Evaluates the model on the given batch of data and labels"""
         with torch.no_grad():
-            smashed_data.to(self._device)
+            smashed_data = smashed_data.to(self._device)
             self._set_model_flops(smashed_data)
             self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
             pred = self._model(smashed_data)
@@ -259,8 +259,12 @@ class DeviceServer:
             print(f"\n\n\nBATCHES: {len(batches)}\n\n\n")
             # batches2 = [b for b in batches if b is not None]
             # print(f"\n\n\nBATCHES FILTERED: {len(batches)}\n\n\n")
-            server_batch = _concat_smashed_data([b[0] for b in batches])
-            server_labels = _concat_smashed_data([b[1] for b in batches])
+            server_batch = _concat_smashed_data(
+                [b[0].to(self._device) for b in batches]
+            )
+            server_labels = _concat_smashed_data(
+                [b[1].to(self._device) for b in batches]
+            )
 
             # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need
             # to split the gradients back into batch-sized tensors to average them before sending them to the client.
diff --git a/edml/tests/controllers/split_controller_test.py b/edml/tests/controllers/split_controller_test.py
index 807749d72d338f3c848837c8d9bd8b359966e08b..8531a2d1fe0aa146494d1ffb45317739a1c01dc1 100644
--- a/edml/tests/controllers/split_controller_test.py
+++ b/edml/tests/controllers/split_controller_test.py
@@ -23,6 +23,7 @@ class SplitControllerTest(unittest.TestCase):
             {"weights": 42},
             {"weights": 43},
             ModelMetricResultContainer(),
+            {"optimizer_state": 44},
             DiagnosticMetricResultContainer(),
         )