diff --git a/edml/core/client.py b/edml/core/client.py
index 9802a84b8227421d0dd7c5a1acf17d7968d81d1f..4791c219bb99705d134c4ebc6e023910e0485d78 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -205,7 +205,7 @@ class DeviceClient:
         )  # 2x for backward pass
         gradients = gradients.to(self._device)
         smashed_data.backward(gradients)
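+        # Backpropagate the server-provided gradients through the client-side model.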
-        print(smashed_data.grad)
         # self._optimizer.step()
 
         # We need to store a reference to the smashed_data to make it possible to finalize the training step.
@@ -221,7 +221,15 @@
         )
         metrics_container = DiagnosticMetricResultContainer([metric])
 
-        return metrics_container, smashed_data.grad
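+        # Collect per-parameter gradients, using zeros for parameters without a gradient.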
+        gradients = []
+        for param in self._model.parameters():
+            if param.grad is not None:
+                gradients.append(param.grad)
+            else:
+                gradients.append(torch.zeros_like(param))
+
+        return metrics_container, gradients
 
     def get_approximated_num_batches(self) -> int:
         return len(self._train_data)
@@ -372,8 +380,10 @@
         )
         return diagnostic_metric_results
 
-    def set_gradient_and_finalize_training(self, gradients: Any):
-        smashed_data = self._psl_cache["smashed_data"]
-        smashed_data.grad = gradients
+    def set_gradient_and_finalize_training(self, gradients: Any):
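+        # Assign the received gradients to the matching parameters, moving each to this device.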
+        for param, grad in zip(self._model.parameters(), gradients):
+            param.grad = grad.to(self._device)
+
         self._optimizer.step()
         self._psl_cache = None
diff --git a/edml/core/device.py b/edml/core/device.py
index 694775c1825a7cb51f120a4c743a2498057d66a1..ab03a450e3104aa35e3dc49c6a178fdc9e10faa1 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -15,6 +15,7 @@ from edml.core.client import DeviceClient
 from edml.core.server import DeviceServer
 from edml.generated import connection_pb2
 from edml.generated.connection_pb2 import (
+    SetGradientsRequest,
     SetWeightsRequest,
     TrainBatchRequest,
     TrainGlobalResponse,
@@ -600,6 +601,14 @@ class RPCDeviceServicer(DeviceServicer):
         return connection_pb2.SingleBatchBackwardResponse(
             metrics=metrics_to_proto(metrics)
         )
+
+    def SetGradientsAndFinalizeTrainingStep(
+        self, request: SetGradientsRequest, context
+    ):
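+        # Decode the serialized gradients, then apply them and finish the client's training step.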
+        gradients = proto_to_tensor(request.gradients.gradients)
+        self.device.client.set_gradient_and_finalize_training(gradients=gradients)
+        return connection_pb2.Empty()
 
 
 class DeviceRequestDispatcher:
@@ -1003,8 +1012,9 @@
                 )
             )
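+            # Convert the protobuf payloads back into a metrics container and gradient tensors.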
             return (
-                response.metrics,
-                response.gradients,
+                proto_to_metrics(response.metrics),
+                proto_to_tensor(response.gradients.gradients),
             )
         except grpc.RpcError:
             self._handle_rpc_error(device_id)
diff --git a/edml/core/server.py b/edml/core/server.py
index bc6ac5b2ede8ecbe86e40129b27d686cc1a3a33e..185a616c2bbcf0c41bca54d2e29391cd5ae53338 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -384,8 +384,13 @@ class DeviceServer:
 
 
-def _calculate_gradient_mean(gradients: List[Variable]) -> Variable:
-    """Calculates the mean of a list of gradients."""
-    return torch.mean(torch.stack(gradients), dim=0)
+def _calculate_gradient_mean(gradients: List[List[Variable]]) -> List[Variable]:
+    """Calculates the element-wise mean of the per-device gradient lists."""
+    num_devices = len(gradients)
+
+    return [
+        sum(gradients[i][j] for i in range(num_devices)) / num_devices
+        for j in range(len(gradients[0]))
+    ]
 
 
 def _concat_smashed_data(data: List[Any]) -> Any: