diff --git a/edml/core/client.py b/edml/core/client.py
index 4791c219bb99705d134c4ebc6e023910e0485d78..d7ff6723297bc45a9f2b5fd3e3bb002bb3de0111 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -223,7 +223,7 @@ class DeviceClient:
         gradients = []
         for param in self._model.parameters():
             if param is not None:
-                gradients.append(param)
+                gradients.append(param.grad)
             else:
                 gradients.append(torch.zeros_like(param))
 
diff --git a/edml/core/device.py b/edml/core/device.py
index ab03a450e3104aa35e3dc49c6a178fdc9e10faa1..2ba6833226b537118f8c02399a0c6c4a4967e47f 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -597,9 +597,10 @@ class RPCDeviceServicer(DeviceServicer):
     ):
         gradients = proto_to_tensor(request.gradients.gradients)
 
-        metrics = self.device.client.backward_single_batch(gradients=gradients)
+        metrics, gradients = self.device.client.backward_single_batch(gradients=gradients)
         return connection_pb2.SingleBatchBackwardResponse(
-            metrics=metrics_to_proto(metrics)
+            metrics=metrics_to_proto(metrics),
+            gradients=Gradients(gradients=tensor_to_proto(gradients)),
         )
     
     def SetGradientsAndFinalizeTrainingStep(
@@ -607,7 +608,7 @@ class RPCDeviceServicer(DeviceServicer):
     ):
         gradients = proto_to_tensor(request.gradients.gradients)
         self.device.client.set_gradient_and_finalize_training(gradients=gradients)
-        return connection_pb2.Empty()
+        return Empty()
 
 
 class DeviceRequestDispatcher:
@@ -1011,7 +1012,7 @@ class DeviceRequestDispatcher:
                 )
             )
             return (
-                proto_to_metrics(response.metrics),
+                None,
                 proto_to_tensor(response.gradients.gradients),
             )
         except grpc.RpcError:
diff --git a/edml/core/server.py b/edml/core/server.py
index 185a616c2bbcf0c41bca54d2e29391cd5ae53338..b78793678f6f77bc84354c93f9d32064eee657c1 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -383,10 +383,15 @@ class DeviceServer:
         )
 
 
-def _calculate_gradient_mean(gradients: List[Variable]) -> Variable:
+def _calculate_gradient_mean(gradients: List[Variable], device: str = "cpu") -> Variable:
     num_devices = len(gradients)
     weights = [1] * num_devices
 
+    # We need to move all tensors to the same device to do calculations.
+    for i, client_gradients in enumerate(gradients):
+        for j, grad in enumerate(client_gradients):
+            gradients[i][j] = grad.to(device)
+
     return [
         sum(gradients[i][j] * weights[i] for i in range(num_devices))
         for j in range(len(gradients[0]))