diff --git a/edml/core/client.py b/edml/core/client.py
index 05404ed95f94c91b7392a2139f42f19678172fe4..9802a84b8227421d0dd7c5a1acf17d7968d81d1f 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import itertools
 import time
-from typing import Optional, Tuple, TYPE_CHECKING
+from typing import Optional, Tuple, TYPE_CHECKING, Any
 
 import torch
 from omegaconf import DictConfig
@@ -187,7 +187,9 @@ class DeviceClient:
             )
 
     @check_device_set()
-    def backward_single_batch(self, gradients) -> DiagnosticMetricResultContainer:
+    def backward_single_batch(
+        self, gradients
+    ) -> Tuple[DiagnosticMetricResultContainer, torch.Tensor]:
         torch.cuda.set_device(self._device)
         batch_data, smashed_data, start_time, end_time = (
             self._psl_cache["batch_data"],
@@ -203,7 +205,11 @@ class DeviceClient:
         )  # 2x for backward pass
         gradients = gradients.to(self._device)
         smashed_data.backward(gradients)
-        self._optimizer.step()
+        # The optimizer step is deferred to set_gradient_and_finalize_training().
+        # NOTE(review): smashed_data.grad stays None for a non-leaf tensor unless
+        # retain_grad() was called on it beforehand — TODO confirm the forward pass does so.
+        # Keep a reference to the smashed data so the training step can be finalized later.
+        self._psl_cache["smashed_data"] = smashed_data
 
         end_time_2 = time.time()
 
@@ -215,8 +221,7 @@ class DeviceClient:
         )
         metrics_container = DiagnosticMetricResultContainer([metric])
 
-        self._psl_cache = None
-        return metrics_container
+        return metrics_container, smashed_data.grad
 
     def get_approximated_num_batches(self) -> int:
         return len(self._train_data)
@@ -366,3 +371,9 @@ class DeviceClient:
             )
         )
         return diagnostic_metric_results
+
+    def set_gradient_and_finalize_training(self, gradients: Any):
+        """Attach `gradients` to the cached smashed data, step the optimizer, clear the cache."""
+        self._psl_cache["smashed_data"].grad = gradients
+        self._optimizer.step()
+        self._psl_cache = None
diff --git a/edml/core/device.py b/edml/core/device.py
index 7fc7d8fb1638792077bf04dd2e8d7c5257c5cc5c..694775c1825a7cb51f120a4c743a2498057d66a1 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -31,6 +31,7 @@ from edml.generated.connection_pb2 import (
     SingleBatchTrainingResponse,
     SingleBatchBackwardRequest,
     TrainGlobalParallelSplitLearningResponse,
+    SingleBatchBackwardResponse,
 )
 from edml.generated.connection_pb2_grpc import DeviceServicer, DeviceStub
 from edml.generated.datastructures_pb2 import (
@@ -212,8 +213,25 @@ class Device(ABC):
     def backpropagation_on_client_only_on(self, client_id: str, gradients: Any):
         """"""
 
+    @abstractmethod
+    def set_gradient_and_finalize_training_on_client_only_on(
+        self, client_id: str, gradients: Any
+    ):
+        """Hand `gradients` to the client `client_id` so it can finalize its training step."""
+
 
 class NetworkDevice(Device):
+    @update_battery
+    def set_gradient_and_finalize_training_on_client_only_on(
+        self, client_id: str, gradients: Any
+    ):
+        """Finalize locally when `client_id` is this device; otherwise dispatch the request."""
+        if client_id != self.device_id:
+            return self.request_dispatcher.set_gradient_and_finalize_training_on_client_only(
+                client_id, gradients
+            )
+        self.client.set_gradient_and_finalize_training(gradients)
+
     @update_battery
     @log_execution_time("logger", "train_parallel_split_learning")
     def train_parallel_split_learning(
@@ -977,16 +995,37 @@ class DeviceRequestDispatcher:
 
     def backpropagation_on_client_only(self, device_id, gradients):
         try:
-            response: Empty = self._get_connection(
+            response: SingleBatchBackwardResponse = self._get_connection(
                 device_id
             ).BackwardPropagationSingleBatchOnClient(
                 connection_pb2.SingleBatchBackwardRequest(
                     gradients=Gradients(gradients=tensor_to_proto(gradients))
                 )
             )
-            return response
+            return (
+                response.metrics,
+                response.gradients,
+            )
         except grpc.RpcError:
             self._handle_rpc_error(device_id)
         except KeyError:
             self._handle_unknown_device_id(device_id)
         return False
+
+    def set_gradient_and_finalize_training_on_client_only(
+        self, client_id: str, gradients: Any
+    ):
+        """Ask `client_id` over gRPC to apply `gradients`; False if an RPC/lookup error was handled."""
+        try:
+            # Look up the stub before serializing; the KeyError handler below presumably
+            # covers an unknown device id raised by this lookup — verify in _get_connection.
+            stub = self._get_connection(client_id)
+            request = connection_pb2.SetGradientsRequest(
+                gradients=Gradients(gradients=tensor_to_proto(gradients))
+            )
+            return stub.SetGradientsAndFinalizeTrainingStep(request)
+        except grpc.RpcError:
+            self._handle_rpc_error(client_id)
+        except KeyError:
+            self._handle_unknown_device_id(client_id)
+        return False
diff --git a/edml/core/server.py b/edml/core/server.py
index 885b22fb0d2aafe93ed2da750f43428d9350dbda..bd5628a72401fd449f05f3430904b626d9970b9e 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -219,16 +219,26 @@ class DeviceServer:
         adaptive_learning_threshold: Optional[float] = None,
         optimizer_state: dict[str, Any] = None,
     ):
-        def client_training_job(client_id: str, batch_index: int) -> SLTrainBatchResult:
-            return self.node_device.train_batch_on_client_only_on(
+        def client_training_job(client_id: str, batch_index: int):
+            result = self.node_device.train_batch_on_client_only_on(
                 device_id=client_id, batch_index=batch_index
             )
+            return (client_id, result)
 
         def client_backpropagation_job(client_id: str, gradients: Any):
             return self.node_device.backpropagation_on_client_only_on(
                 client_id=client_id, gradients=gradients
             )
 
+        def client_set_gradient_and_finalize_training_job(
+            client_id: str, gradients: Any
+        ):
+            """Executor job: finalize the training step on `client_id` with `gradients`."""
+            node = self.node_device
+            return node.set_gradient_and_finalize_training_on_client_only_on(
+                client_id=client_id, gradients=gradients
+            )
+
         if optimizer_state is not None:
             self._optimizer.load_state_dict(optimizer_state)
 
@@ -243,27 +253,31 @@ class DeviceServer:
         num_batches = self.node_device.client.get_approximated_num_batches()
         print(f"\n\n:: BATCHES :: {num_batches}\n\n")
         for batch_index in range(num_batches):
-            batches = []
+            client_forward_pass_responses = []
 
             futures = [
                 executor.submit(client_training_job, client_id, batch_index)
                 for client_id in clients
             ]
             for future in concurrent.futures.as_completed(futures):
-                batches.append(future.result())
+                client_forward_pass_responses.append(future.result())
+
+            # We want to split up the responses into a list of client IDs and batches again.
+            client_ids = [b[0] for b in client_forward_pass_responses]
+            client_batches = [b[1] for b in client_forward_pass_responses]
 
-            if _empty_batches(batches):
-                # Only last batch anyway.
+            if _empty_batches(client_batches):
+                # Only the last batch anyway.
                 break
 
-            print(f"\n\n\nBATCHES: {len(batches)}\n\n\n")
+            print(f"\n\n\nBATCHES: {len(client_batches)}\n\n\n")
             # batches2 = [b for b in batches if b is not None]
             # print(f"\n\n\nBATCHES FILTERED: {len(batches)}\n\n\n")
             server_batch = _concat_smashed_data(
-                [b[0].to(self._device) for b in batches]
+                [b[0].to(self._device) for b in client_batches]
             )
             server_labels = _concat_smashed_data(
-                [b[1].to(self._device) for b in batches]
+                [b[1].to(self._device) for b in client_batches]
             )
 
             # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need
@@ -287,30 +301,46 @@ class DeviceServer:
                 self.node_device.log({"adaptive_learning_threshold_applied": True})
                 continue
 
-            num_client_gradients = len(batches)
+            num_client_gradients = len(client_forward_pass_responses)
             print(
                 f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
             )
 
             client_gradients = torch.chunk(server_gradients, num_client_gradients)
-            print(f"::: result shape: {client_gradients[1].shape}")
-            concatenated_client_gradients = torch.stack(client_gradients, dim=0)
-            mean_tensor = torch.mean(concatenated_client_gradients, dim=0)
-            print(f"::: -> {mean_tensor.shape}")
+            # print(f"::: result shape: {client_gradients[1].shape}")
+            # concatenated_client_gradients = torch.stack(client_gradients, dim=0)
+            # mean_tensor = torch.mean(concatenated_client_gradients, dim=0)
+            # print(f"::: -> {mean_tensor.shape}")
 
             futures = [
-                executor.submit(client_backpropagation_job, client_id, mean_tensor)
-                for client_id in clients
+                executor.submit(
+                    client_backpropagation_job, client_id, client_gradients[idx]
+                )
+                for idx, client_id in enumerate(client_ids)  # enumerate yields (index, item)
             ]
-            diagnostic_results = []
+            client_backpropagation_results = []
             for future in concurrent.futures.as_completed(futures):
-                diagnostic_results.append(future.result())
-            # Has to be done outside the loop due to thread-safety.
-            # for diagnostic_result in diagnostic_results:
-            #     diagnostic_metrics.merge(diagnostic_result)
+                client_backpropagation_results.append(future.result())
+
+            client_backpropagation_gradients = [
+                result[1] for result in client_backpropagation_results
+            ]
 
-            # # Resetting batches since we only need them per-batch-iteration.
-            # batches = []
+            # We want to average the client's backpropagation gradients and send them over again to finalize the
+            # current training step.
+            averaged_gradient = _calculate_gradient_mean(
+                client_backpropagation_gradients
+            )
+            futures = [
+                executor.submit(
+                    client_set_gradient_and_finalize_training_job,
+                    client_id,
+                    averaged_gradient,
+                )
+                for client_id in clients
+            ]
+            for future in concurrent.futures.as_completed(futures):
+                future.result()
 
         # Now we have to determine the model metrics for each client.
         for client_id in clients:
@@ -353,6 +383,11 @@ class DeviceServer:
         )
 
 
+def _calculate_gradient_mean(gradients: List[Variable]) -> Variable:
+    """Return the element-wise mean over a list of equally shaped gradients."""
+    return torch.stack(gradients).mean(dim=0)
+
+
 def _concat_smashed_data(data: List[Any]) -> Any:
     """Creates a single batch tensor from a list of tensors."""
     return torch.cat(data, dim=0)
diff --git a/edml/generated/connection_pb2.py b/edml/generated/connection_pb2.py
index d8899e228ac061499659d0d55c7575241674ad32..ce271bbab4fa5d30e0e8c5a0a8b309d96eaccf07 100644
--- a/edml/generated/connection_pb2.py
+++ b/edml/generated/connection_pb2.py
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
 import datastructures_pb2 as datastructures__pb2
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"8\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\"1\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 
\x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 
\x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 
\x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xb1\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"4\n\x13SetGradientsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"j\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12\"\n\tgradients\x18\x02 \x01(\x0b\x32\n.GradientsH\x00\x88\x01\x01\x42\x0c\n\n_gradients\"1\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 
\n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 
\x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 
\x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xf8\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x12\x45\n#SetGradientsAndFinalizeTrainingStep\x12\x14.SetGradientsRequest\x1a\x06.Empty\"\x00\x62\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -22,68 +22,70 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'connection_pb2', _globals)
 if _descriptor._USE_C_DESCRIPTORS == False:
 
   DESCRIPTOR._options = None
-  _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=42
-  _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=95
-  _globals['_SINGLEBATCHBACKWARDREQUEST']._serialized_start=97
-  _globals['_SINGLEBATCHBACKWARDREQUEST']._serialized_end=156
-  _globals['_SINGLEBATCHBACKWARDRESPONSE']._serialized_start=158
-  _globals['_SINGLEBATCHBACKWARDRESPONSE']._serialized_end=214
-  _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_start=216
-  _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_end=265
-  _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_start=268
-  _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_end=396
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_start=399
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=612
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=615
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=880
-  _globals['_TRAINGLOBALREQUEST']._serialized_start=883
-  _globals['_TRAINGLOBALREQUEST']._serialized_end=1017
-  _globals['_TRAINGLOBALRESPONSE']._serialized_start=1020
-  _globals['_TRAINGLOBALRESPONSE']._serialized_end=1264
-  _globals['_SETWEIGHTSREQUEST']._serialized_start=1266
-  _globals['_SETWEIGHTSREQUEST']._serialized_end=1331
-  _globals['_SETWEIGHTSRESPONSE']._serialized_start=1333
-  _globals['_SETWEIGHTSRESPONSE']._serialized_end=1419
-  _globals['_TRAINEPOCHREQUEST']._serialized_start=1421
-  _globals['_TRAINEPOCHREQUEST']._serialized_end=1505
-  _globals['_TRAINEPOCHRESPONSE']._serialized_start=1507
-  _globals['_TRAINEPOCHRESPONSE']._serialized_end=1620
-  _globals['_TRAINBATCHREQUEST']._serialized_start=1622
-  _globals['_TRAINBATCHREQUEST']._serialized_end=1702
-  _globals['_TRAINBATCHRESPONSE']._serialized_start=1705
-  _globals['_TRAINBATCHRESPONSE']._serialized_end=1850
-  _globals['_EVALGLOBALREQUEST']._serialized_start=1852
-  _globals['_EVALGLOBALREQUEST']._serialized_end=1910
-  _globals['_EVALGLOBALRESPONSE']._serialized_start=1912
-  _globals['_EVALGLOBALRESPONSE']._serialized_end=2025
-  _globals['_EVALREQUEST']._serialized_start=2027
-  _globals['_EVALREQUEST']._serialized_end=2089
-  _globals['_EVALRESPONSE']._serialized_start=2091
-  _globals['_EVALRESPONSE']._serialized_end=2171
-  _globals['_EVALBATCHREQUEST']._serialized_start=2173
-  _globals['_EVALBATCHREQUEST']._serialized_end=2252
-  _globals['_EVALBATCHRESPONSE']._serialized_start=2254
-  _globals['_EVALBATCHRESPONSE']._serialized_end=2366
-  _globals['_FULLMODELTRAINREQUEST']._serialized_start=2368
-  _globals['_FULLMODELTRAINREQUEST']._serialized_end=2427
-  _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2430
-  _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2636
-  _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2638
-  _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2662
-  _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2664
-  _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2755
-  _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2757
-  _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2779
-  _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2781
-  _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=2870
-  _globals['_BATTERYSTATUSREQUEST']._serialized_start=2872
-  _globals['_BATTERYSTATUSREQUEST']._serialized_end=2894
-  _globals['_BATTERYSTATUSRESPONSE']._serialized_start=2896
-  _globals['_BATTERYSTATUSRESPONSE']._serialized_end=3017
-  _globals['_DATASETMODELINFOREQUEST']._serialized_start=3019
-  _globals['_DATASETMODELINFOREQUEST']._serialized_end=3044
-  _globals['_DATASETMODELINFORESPONSE']._serialized_start=3047
-  _globals['_DATASETMODELINFORESPONSE']._serialized_end=3246
-  _globals['_DEVICE']._serialized_start=3249
-  _globals['_DEVICE']._serialized_end=4322
+  _globals['_SETGRADIENTSREQUEST']._serialized_start=42
+  _globals['_SETGRADIENTSREQUEST']._serialized_end=94
+  _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=96
+  _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=149
+  _globals['_SINGLEBATCHBACKWARDREQUEST']._serialized_start=151
+  _globals['_SINGLEBATCHBACKWARDREQUEST']._serialized_end=210
+  _globals['_SINGLEBATCHBACKWARDRESPONSE']._serialized_start=212
+  _globals['_SINGLEBATCHBACKWARDRESPONSE']._serialized_end=318
+  _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_start=320
+  _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_end=369
+  _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_start=372
+  _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_end=500
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_start=503
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=716
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=719
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=984
+  _globals['_TRAINGLOBALREQUEST']._serialized_start=987
+  _globals['_TRAINGLOBALREQUEST']._serialized_end=1121
+  _globals['_TRAINGLOBALRESPONSE']._serialized_start=1124
+  _globals['_TRAINGLOBALRESPONSE']._serialized_end=1368
+  _globals['_SETWEIGHTSREQUEST']._serialized_start=1370
+  _globals['_SETWEIGHTSREQUEST']._serialized_end=1435
+  _globals['_SETWEIGHTSRESPONSE']._serialized_start=1437
+  _globals['_SETWEIGHTSRESPONSE']._serialized_end=1523
+  _globals['_TRAINEPOCHREQUEST']._serialized_start=1525
+  _globals['_TRAINEPOCHREQUEST']._serialized_end=1609
+  _globals['_TRAINEPOCHRESPONSE']._serialized_start=1611
+  _globals['_TRAINEPOCHRESPONSE']._serialized_end=1724
+  _globals['_TRAINBATCHREQUEST']._serialized_start=1726
+  _globals['_TRAINBATCHREQUEST']._serialized_end=1806
+  _globals['_TRAINBATCHRESPONSE']._serialized_start=1809
+  _globals['_TRAINBATCHRESPONSE']._serialized_end=1954
+  _globals['_EVALGLOBALREQUEST']._serialized_start=1956
+  _globals['_EVALGLOBALREQUEST']._serialized_end=2014
+  _globals['_EVALGLOBALRESPONSE']._serialized_start=2016
+  _globals['_EVALGLOBALRESPONSE']._serialized_end=2129
+  _globals['_EVALREQUEST']._serialized_start=2131
+  _globals['_EVALREQUEST']._serialized_end=2193
+  _globals['_EVALRESPONSE']._serialized_start=2195
+  _globals['_EVALRESPONSE']._serialized_end=2275
+  _globals['_EVALBATCHREQUEST']._serialized_start=2277
+  _globals['_EVALBATCHREQUEST']._serialized_end=2356
+  _globals['_EVALBATCHRESPONSE']._serialized_start=2358
+  _globals['_EVALBATCHRESPONSE']._serialized_end=2470
+  _globals['_FULLMODELTRAINREQUEST']._serialized_start=2472
+  _globals['_FULLMODELTRAINREQUEST']._serialized_end=2531
+  _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2534
+  _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2740
+  _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2742
+  _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2766
+  _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2768
+  _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2859
+  _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2861
+  _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2883
+  _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2885
+  _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=2974
+  _globals['_BATTERYSTATUSREQUEST']._serialized_start=2976
+  _globals['_BATTERYSTATUSREQUEST']._serialized_end=2998
+  _globals['_BATTERYSTATUSRESPONSE']._serialized_start=3000
+  _globals['_BATTERYSTATUSRESPONSE']._serialized_end=3121
+  _globals['_DATASETMODELINFOREQUEST']._serialized_start=3123
+  _globals['_DATASETMODELINFOREQUEST']._serialized_end=3148
+  _globals['_DATASETMODELINFORESPONSE']._serialized_start=3151
+  _globals['_DATASETMODELINFORESPONSE']._serialized_end=3350
+  _globals['_DEVICE']._serialized_start=3353
+  _globals['_DEVICE']._serialized_end=4497
 # @@protoc_insertion_point(module_scope)
diff --git a/edml/generated/connection_pb2.pyi b/edml/generated/connection_pb2.pyi
index 730bf6e320c91c3d08548a1cbac1607352c85cad..a9735505e808ee35b156158e972ab6b206e90b0a 100644
--- a/edml/generated/connection_pb2.pyi
+++ b/edml/generated/connection_pb2.pyi
@@ -5,6 +5,12 @@ from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Opti
 
 DESCRIPTOR: _descriptor.FileDescriptor
 
+class SetGradientsRequest(_message.Message):
+    __slots__ = ["gradients"]
+    GRADIENTS_FIELD_NUMBER: _ClassVar[int]
+    gradients: _datastructures_pb2.Gradients
+    def __init__(self, gradients: _Optional[_Union[_datastructures_pb2.Gradients, _Mapping]] = ...) -> None: ...
+
 class UpdateWeightsRequest(_message.Message):
     __slots__ = ["gradients"]
     GRADIENTS_FIELD_NUMBER: _ClassVar[int]
@@ -18,10 +24,12 @@ class SingleBatchBackwardRequest(_message.Message):
     def __init__(self, gradients: _Optional[_Union[_datastructures_pb2.Gradients, _Mapping]] = ...) -> None: ...
 
 class SingleBatchBackwardResponse(_message.Message):
-    __slots__ = ["metrics"]
+    __slots__ = ["metrics", "gradients"]
     METRICS_FIELD_NUMBER: _ClassVar[int]
+    GRADIENTS_FIELD_NUMBER: _ClassVar[int]
     metrics: _datastructures_pb2.Metrics
-    def __init__(self, metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ...
+    gradients: _datastructures_pb2.Gradients
+    def __init__(self, metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., gradients: _Optional[_Union[_datastructures_pb2.Gradients, _Mapping]] = ...) -> None: ...
 
 class SingleBatchTrainingRequest(_message.Message):
     __slots__ = ["batch_index"]
diff --git a/edml/generated/connection_pb2_grpc.py b/edml/generated/connection_pb2_grpc.py
index 7d833927b4ebe5c4ae2ab97481bb073510f41a5d..c5b692413f8f3dcb6b06ba9e57d7aa31118ba740 100644
--- a/edml/generated/connection_pb2_grpc.py
+++ b/edml/generated/connection_pb2_grpc.py
@@ -3,6 +3,7 @@
 import grpc
 
 import connection_pb2 as connection__pb2
+import datastructures_pb2 as datastructures__pb2
 
 
 class DeviceStub(object):
@@ -89,6 +90,11 @@ class DeviceStub(object):
                 request_serializer=connection__pb2.SingleBatchBackwardRequest.SerializeToString,
                 response_deserializer=connection__pb2.SingleBatchBackwardResponse.FromString,
                 )
+        self.SetGradientsAndFinalizeTrainingStep = channel.unary_unary(
+                '/Device/SetGradientsAndFinalizeTrainingStep',
+                request_serializer=connection__pb2.SetGradientsRequest.SerializeToString,
+                response_deserializer=datastructures__pb2.Empty.FromString,
+                )
 
 
 class DeviceServicer(object):
@@ -185,6 +191,12 @@ class DeviceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
+    def SetGradientsAndFinalizeTrainingStep(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
 
 def add_DeviceServicer_to_server(servicer, server):
     rpc_method_handlers = {
@@ -263,6 +275,11 @@ def add_DeviceServicer_to_server(servicer, server):
                     request_deserializer=connection__pb2.SingleBatchBackwardRequest.FromString,
                     response_serializer=connection__pb2.SingleBatchBackwardResponse.SerializeToString,
             ),
+            'SetGradientsAndFinalizeTrainingStep': grpc.unary_unary_rpc_method_handler(
+                    servicer.SetGradientsAndFinalizeTrainingStep,
+                    request_deserializer=connection__pb2.SetGradientsRequest.FromString,
+                    response_serializer=datastructures__pb2.Empty.SerializeToString,
+            ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
             'Device', rpc_method_handlers)
@@ -527,3 +544,20 @@ class Device(object):
             connection__pb2.SingleBatchBackwardResponse.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def SetGradientsAndFinalizeTrainingStep(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/Device/SetGradientsAndFinalizeTrainingStep',
+            connection__pb2.SetGradientsRequest.SerializeToString,
+            datastructures__pb2.Empty.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
diff --git a/edml/proto/connection.proto b/edml/proto/connection.proto
index d03d2d1ec39b7f884673de4baf6b46a0643d0cdd..2f12881a1cf7e3ee6ae2e076b2f9005141cfece2 100644
--- a/edml/proto/connection.proto
+++ b/edml/proto/connection.proto
@@ -19,6 +19,11 @@ service Device {
   rpc TrainGlobalParallelSplitLearning (TrainGlobalParallelSplitLearningRequest) returns (TrainGlobalParallelSplitLearningResponse) {}
   rpc TrainSingleBatchOnClient (SingleBatchTrainingRequest) returns (SingleBatchTrainingResponse) {}
   rpc BackwardPropagationSingleBatchOnClient(SingleBatchBackwardRequest) returns (SingleBatchBackwardResponse) {}
+  rpc SetGradientsAndFinalizeTrainingStep(SetGradientsRequest) returns (Empty) {}
+}
+
+message SetGradientsRequest {
+  Gradients gradients = 1;
 }
 
 message UpdateWeightsRequest {
@@ -31,6 +36,7 @@ message SingleBatchBackwardRequest {
 
 message SingleBatchBackwardResponse {
   Metrics metrics = 1;
+  optional Gradients gradients = 2;
 }
 
 message SingleBatchTrainingRequest {