diff --git a/edml/config/controller/swarm.yaml b/edml/config/controller/swarm.yaml
index f31a60a20c1e4cff7f8eec4bd9b1ff4958753ab5..39bb393eef5bf2b73a6aa5c755c9f41ad765809e 100644
--- a/edml/config/controller/swarm.yaml
+++ b/edml/config/controller/swarm.yaml
@@ -3,3 +3,4 @@ _target_: edml.controllers.swarm_controller.SwarmController
 _partial_: true
 defaults:
   - scheduler: sequential
+  - adaptive_threshold_fn: !!null
diff --git a/edml/controllers/swarm_controller.py b/edml/controllers/swarm_controller.py
index 1b61f0aac70ae67a5ff0a7df98077dcc4bba3ec8..19378c4aedfbe5b5b1307751ee4cbffe7519fe66 100644
--- a/edml/controllers/swarm_controller.py
+++ b/edml/controllers/swarm_controller.py
@@ -1,5 +1,9 @@
 from typing import Any
 
+from edml.controllers.adaptive_threshold_mechanism import AdaptiveThresholdFn
+from edml.controllers.adaptive_threshold_mechanism.static import (
+    StaticAdaptiveThresholdFn,
+)
 from edml.controllers.base_controller import BaseController
 from edml.controllers.scheduler.base import NextServerScheduler
 from edml.helpers.config_helpers import get_device_index_by_id
@@ -7,10 +11,16 @@ from edml.helpers.config_helpers import get_device_index_by_id
 
 class SwarmController(BaseController):
 
-    def __init__(self, cfg, scheduler: NextServerScheduler):
+    def __init__(
+        self,
+        cfg,
+        scheduler: NextServerScheduler,
+        adaptive_threshold_fn: AdaptiveThresholdFn = StaticAdaptiveThresholdFn(0.0),
+    ):
         super().__init__(cfg)
         scheduler.initialize(self)
         self._next_server_scheduler = scheduler
+        self._adaptive_threshold_fn = adaptive_threshold_fn
 
     def _train(self):
         client_weights = None
@@ -87,10 +97,13 @@ class SwarmController(BaseController):
             device_id=server_device_id, state_dict=server_weights, on_client=False
         )
 
+        adaptive_threshold = self._adaptive_threshold_fn.invoke(round_no)
+        self.logger.log({"adaptive-threshold": adaptive_threshold})
         training_response = self.request_dispatcher.train_global_on(
             server_device_id,
             epochs=1,
             round_no=round_no,
+            adaptive_learning_threshold=adaptive_threshold,
             optimizer_state=optimizer_state,
         )
 
diff --git a/edml/core/client.py b/edml/core/client.py
index aae9b9a917a6be55c6c2e273021d0b8289458bcc..03437ce4a4ade8472a889586fb3f82981a6278c8 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -303,15 +303,16 @@ class DeviceClient:
                     break
                 server_grad, _server_loss, diagnostic_metrics = train_batch_response
                 diagnostic_metric_container.merge(diagnostic_metrics)
-                self.node_device.battery.update_flops(
-                    self._model_flops["BW"] * len(batch_data)
-                )
-                server_grad = server_grad.to(self._device)
-                if has_autoencoder(self._model):
-                    self._model.trainable_layers_output.backward(server_grad)
-                else:
-                    smashed_data.backward(server_grad)
-                self._optimizer.step()
+                if server_grad is not None:  # otherwise threshold was applied
+                    self.node_device.battery.update_flops(
+                        self._model_flops["BW"] * len(batch_data)
+                    )
+                    server_grad = server_grad.to(self._device)
+                    if has_autoencoder(self._model):
+                        self._model.trainable_layers_output.backward(server_grad)
+                    else:
+                        smashed_data.backward(server_grad)
+                    self._optimizer.step()
 
         client_train_time = (
             time.time() - client_train_start_time - sum(server_train_batch_times)
diff --git a/edml/core/device.py b/edml/core/device.py
index 961ba4b729f2f854cd2e71970dc27875aef2ffbd..78c76b5cf875333b118e745230d06dccf227ad67 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -303,7 +303,11 @@ class NetworkDevice(Device):
     @update_battery
     @log_execution_time("logger", "train_global_time")
     def train_global(
-        self, epochs: int, round_no: int = -1, optimizer_state: dict[str, Any] = None
+        self,
+        epochs: int,
+        round_no: int = -1,
+        adaptive_learning_threshold: Optional[float] = None,
+        optimizer_state: dict[str, Any] = None,
     ) -> Tuple[
         Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer
     ]:
@@ -311,6 +315,7 @@ class NetworkDevice(Device):
             devices=self.__get_device_ids__(),
             epochs=epochs,
             round_no=round_no,
+            adaptive_learning_threshold=adaptive_learning_threshold,
             optimizer_state=optimizer_state,
         )
 
@@ -448,7 +453,12 @@ class RPCDeviceServicer(DeviceServicer):
     def TrainGlobal(self, request, context):
         print(f"Called TrainGlobal on device {self.device.device_id}")
         client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
-            self.device.train_global(request.epochs, request.round_no)
+            self.device.train_global(
+                request.epochs,
+                request.round_no,
+                request.adaptive_learning_threshold,
+                proto_to_state_dict(request.optimizer_state),
+            )
         )
         response = connection_pb2.TrainGlobalResponse(
             client_weights=Weights(weights=state_dict_to_proto(client_weights)),
@@ -568,12 +578,14 @@ class RPCDeviceServicer(DeviceServicer):
         clients = self.device.__get_device_ids__()
         round_no = request.round_no
         adaptive_learning_threshold = request.adaptive_learning_threshold
+        optimizer_state = proto_to_state_dict(request.optimizer_state)
 
         cw, sw, model_metrics, optimizer_state, diagnostic_metrics = (
             self.device.train_parallel_split_learning(
                 clients=clients,
                 round_no=round_no,
                 adaptive_learning_threshold=adaptive_learning_threshold,
+                optimizer_state=optimizer_state,
             )
         )
         response = connection_pb2.TrainGlobalParallelSplitLearningResponse(
@@ -761,6 +773,7 @@ class DeviceRequestDispatcher:
         device_id: str,
         epochs: int,
         round_no: int = -1,
+        adaptive_learning_threshold: Optional[float] = None,
         optimizer_state: dict[str, Any] = None,
     ) -> Union[
         Tuple[
@@ -777,6 +790,7 @@ class DeviceRequestDispatcher:
                 connection_pb2.TrainGlobalRequest(
                     epochs=epochs,
                     round_no=round_no,
+                    adaptive_learning_threshold=adaptive_learning_threshold,
                     optimizer_state=state_dict_to_proto(optimizer_state),
                 )
             )
diff --git a/edml/core/server.py b/edml/core/server.py
index 22b6a599a6d7a067a24f2543bc21470f54ca2bf6..42f26b6481424db18b49886dcb324e91c9195c28 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -51,6 +51,7 @@ class DeviceServer:
         self._cfg = cfg
         self.node_device: Optional[Device] = None
         self.latency_factor = latency_factor
+        self.adaptive_learning_threshold = None
 
     def set_device(self, node_device: Device):
         """Sets the device reference for the server."""
@@ -73,6 +74,7 @@ class DeviceServer:
         devices: List[str],
         epochs: int = 1,
         round_no: int = -1,
+        adaptive_learning_threshold: Optional[float] = None,
         optimizer_state: dict[str, Any] = None,
     ) -> Tuple[
         Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer
@@ -83,13 +85,19 @@ class DeviceServer:
             devices: The devices to train on
             epochs: Optionally, the number of epochs to train.
             round_no: Optionally, the current global epoch number if a learning rate scheduler is used.
+            adaptive_learning_threshold: Optionally, the loss threshold to not send the gradients to the client
             optimizer_state: Optionally, the optimizer_state to proceed from
         """
         client_weights = None
         metrics = ModelMetricResultContainer()
         diagnostic_metric_container = DiagnosticMetricResultContainer()
         if optimizer_state is not None:
+            print(f"apply optimizer state {optimizer_state}")
             self._optimizer.load_state_dict(optimizer_state)
+        else:
+            print("optimizer state is None")
+        if adaptive_learning_threshold is not None:
+            self.adaptive_learning_threshold = adaptive_learning_threshold
         for epoch in range(epochs):
             if self._lr_scheduler is not None:
                 if round_no != -1:
@@ -98,7 +106,7 @@ class DeviceServer:
                     self._lr_scheduler.step()
             for device_id in devices:
                 print(
-                    f"Train epoch {epoch} on client {device_id} with server {self.node_device.device_id}"
+                    f"Train epoch {epoch} on client {device_id} with server {self.node_device.device_id} and threshold {self.adaptive_learning_threshold}"
                 )
                 if client_weights is not None:
                     self.node_device.set_weights_on(
@@ -135,7 +143,7 @@ class DeviceServer:
         )
 
     @simulate_latency_decorator(latency_factor_attr="latency_factor")
-    def train_batch(self, smashed_data, labels) -> Tuple[Variable, float]:
+    def train_batch(self, smashed_data, labels) -> Tuple[Optional[Variable], float]:
         """Train the model on the given batch of data and labels.
         Returns the gradients of the model's parameters."""
         smashed_data, labels = smashed_data.to(self._device), labels.to(self._device)
@@ -161,12 +169,20 @@ class DeviceServer:
         # Capturing training metrics for the current batch.
         self.node_device.log({"loss": loss_train.item()})
         self._metrics.metrics_on_batch(output_train.cpu(), labels.cpu().int())
+
         if has_autoencoder(self._model):
-            return self._model.trainable_layers_input.grad, loss_train.item()
-        return (
-            smashed_data.grad,
-            loss_train.item(),
-        )  # hier sollten beim AE die gradients vom server model vor dem decoder zurückgegeben werden
+            gradients = self._model.trainable_layers_input.grad
+        else:
+            gradients = smashed_data.grad
+        if (
+            self.adaptive_learning_threshold
+            and loss_train.item() < self.adaptive_learning_threshold
+        ):
+            self.node_device.log(
+                {"adaptive_learning_threshold_applied": gradients.size(0)}
+            )
+            return None, loss_train.item()
+        return gradients, loss_train.item()
 
     def _set_model_flops(self, sample):
         """Helper to determine the model flops when smashed data are available for the first time."""
@@ -261,7 +277,8 @@ class DeviceServer:
                 self._lr_scheduler.step(round_no + 1)  # epoch=1
             else:
                 self._lr_scheduler.step()
-
+        if adaptive_learning_threshold is not None:
+            self.adaptive_learning_threshold = adaptive_learning_threshold
         num_threads = len(clients)
         executor = create_executor_with_threads(num_threads)
 
@@ -311,25 +328,9 @@ class DeviceServer:
                 server_gradients, server_loss, server_metrics = (
                     self.node_device.train_batch(server_batch, server_labels)
                 )  # DiagnosticMetricResultContainer
-                # We check if the server should activate the adaptive learning threshold. And if true, we make sure to only
-                # do the client propagation once the current loss value is larger than the threshold.
-                print(
-                    f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
-                )
                 if (
-                    adaptive_learning_threshold
-                    and server_loss < adaptive_learning_threshold
-                ):
-                    print(
-                        f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
-                    )
-                    self.node_device.log(
-                        {
-                            "adaptive_learning_threshold_applied": server_gradients.size(
-                                0
-                            )
-                        }
-                    )
+                    server_gradients is None
+                ):  # loss threshold was reached, skip client backprop
                     continue
 
                 num_client_gradients = len(client_forward_pass_responses)
@@ -409,25 +410,9 @@ class DeviceServer:
                 server_gradients, server_loss, server_metrics = (
                     self.node_device.train_batch(server_batch, server_labels)
                 )  # DiagnosticMetricResultContainer
-                # We check if the server should activate the adaptive learning threshold. And if true, we make sure to only
-                # do the client propagation once the current loss value is larger than the threshold.
-                print(
-                    f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
-                )
                 if (
-                    adaptive_learning_threshold
-                    and server_loss < adaptive_learning_threshold
-                ):
-                    print(
-                        f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
-                    )
-                    self.node_device.log(
-                        {
-                            "adaptive_learning_threshold_applied": server_gradients.size(
-                                0
-                            )
-                        }
-                    )
+                    server_gradients is None
+                ):  # loss threshold was reached, skip client backprop
                     continue
 
                 num_client_gradients = len(client_forward_pass_responses)
diff --git a/edml/generated/connection_pb2.py b/edml/generated/connection_pb2.py
index 3c25ef7670c9e08c02ca4df71de66afa9e55193b..20237aafe29133f1667ddd8a256fd29c2a196d05 100644
--- a/edml/generated/connection_pb2.py
+++ b/edml/generated/connection_pb2.py
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
 import datastructures_pb2 as datastructures__pb2
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"4\n\x13SetGradientsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"j\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12\"\n\tgradients\x18\x02 \x01(\x0b\x32\n.GradientsH\x00\x88\x01\x01\x42\x0c\n\n_gradients\"C\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\x12\x10\n\x08round_no\x18\x02 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xa5\x02\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x17\n\x0f\x63lient_fw_flops\x18\x03 \x01(\x05\x12\x17\n\x0fserver_fw_flops\x18\x04 \x01(\x05\x12\x1c\n\x0f\x63lient_bw_flops\x18\x05 \x01(\x05H\x00\x88\x01\x01\x12\x1c\n\x0fserver_bw_flops\x18\x06 \x01(\x05H\x01\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x07 \x01(\x0b\x32\x08.MetricsH\x02\x88\x01\x01\x42\x12\n\x10_client_bw_flopsB\x12\n\x10_server_bw_flopsB\x15\n\x13_diagnostic_metrics2\xf8\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x12\x45\n#SetGradientsAndFinalizeTrainingStep\x12\x14.SetGradientsRequest\x1a\x06.Empty\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"4\n\x13SetGradientsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"j\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12\"\n\tgradients\x18\x02 \x01(\x0b\x32\n.GradientsH\x00\x88\x01\x01\x42\x0c\n\n_gradients\"C\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\x12\x10\n\x08round_no\x18\x02 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\xd0\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x03 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xa5\x02\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x17\n\x0f\x63lient_fw_flops\x18\x03 \x01(\x05\x12\x17\n\x0fserver_fw_flops\x18\x04 \x01(\x05\x12\x1c\n\x0f\x63lient_bw_flops\x18\x05 \x01(\x05H\x00\x88\x01\x01\x12\x1c\n\x0fserver_bw_flops\x18\x06 \x01(\x05H\x01\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x07 \x01(\x0b\x32\x08.MetricsH\x02\x88\x01\x01\x42\x12\n\x10_client_bw_flopsB\x12\n\x10_server_bw_flopsB\x15\n\x13_diagnostic_metrics2\xf8\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x12\x45\n#SetGradientsAndFinalizeTrainingStep\x12\x14.SetGradientsRequest\x1a\x06.Empty\"\x00\x62\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -39,53 +39,53 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=737
   _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=1002
   _globals['_TRAINGLOBALREQUEST']._serialized_start=1005
-  _globals['_TRAINGLOBALREQUEST']._serialized_end=1139
-  _globals['_TRAINGLOBALRESPONSE']._serialized_start=1142
-  _globals['_TRAINGLOBALRESPONSE']._serialized_end=1386
-  _globals['_SETWEIGHTSREQUEST']._serialized_start=1388
-  _globals['_SETWEIGHTSREQUEST']._serialized_end=1453
-  _globals['_SETWEIGHTSRESPONSE']._serialized_start=1455
-  _globals['_SETWEIGHTSRESPONSE']._serialized_end=1541
-  _globals['_TRAINEPOCHREQUEST']._serialized_start=1543
-  _globals['_TRAINEPOCHREQUEST']._serialized_end=1627
-  _globals['_TRAINEPOCHRESPONSE']._serialized_start=1629
-  _globals['_TRAINEPOCHRESPONSE']._serialized_end=1742
-  _globals['_TRAINBATCHREQUEST']._serialized_start=1744
-  _globals['_TRAINBATCHREQUEST']._serialized_end=1824
-  _globals['_TRAINBATCHRESPONSE']._serialized_start=1827
-  _globals['_TRAINBATCHRESPONSE']._serialized_end=1972
-  _globals['_EVALGLOBALREQUEST']._serialized_start=1974
-  _globals['_EVALGLOBALREQUEST']._serialized_end=2032
-  _globals['_EVALGLOBALRESPONSE']._serialized_start=2034
-  _globals['_EVALGLOBALRESPONSE']._serialized_end=2147
-  _globals['_EVALREQUEST']._serialized_start=2149
-  _globals['_EVALREQUEST']._serialized_end=2211
-  _globals['_EVALRESPONSE']._serialized_start=2213
-  _globals['_EVALRESPONSE']._serialized_end=2293
-  _globals['_EVALBATCHREQUEST']._serialized_start=2295
-  _globals['_EVALBATCHREQUEST']._serialized_end=2374
-  _globals['_EVALBATCHRESPONSE']._serialized_start=2376
-  _globals['_EVALBATCHRESPONSE']._serialized_end=2488
-  _globals['_FULLMODELTRAINREQUEST']._serialized_start=2490
-  _globals['_FULLMODELTRAINREQUEST']._serialized_end=2549
-  _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2552
-  _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2758
-  _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2760
-  _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2784
-  _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2786
-  _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2877
-  _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2879
-  _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2901
-  _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2903
-  _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=2992
-  _globals['_BATTERYSTATUSREQUEST']._serialized_start=2994
-  _globals['_BATTERYSTATUSREQUEST']._serialized_end=3016
-  _globals['_BATTERYSTATUSRESPONSE']._serialized_start=3018
-  _globals['_BATTERYSTATUSRESPONSE']._serialized_end=3139
-  _globals['_DATASETMODELINFOREQUEST']._serialized_start=3141
-  _globals['_DATASETMODELINFOREQUEST']._serialized_end=3166
-  _globals['_DATASETMODELINFORESPONSE']._serialized_start=3169
-  _globals['_DATASETMODELINFORESPONSE']._serialized_end=3462
-  _globals['_DEVICE']._serialized_start=3465
-  _globals['_DEVICE']._serialized_end=4609
+  _globals['_TRAINGLOBALREQUEST']._serialized_end=1213
+  _globals['_TRAINGLOBALRESPONSE']._serialized_start=1216
+  _globals['_TRAINGLOBALRESPONSE']._serialized_end=1460
+  _globals['_SETWEIGHTSREQUEST']._serialized_start=1462
+  _globals['_SETWEIGHTSREQUEST']._serialized_end=1527
+  _globals['_SETWEIGHTSRESPONSE']._serialized_start=1529
+  _globals['_SETWEIGHTSRESPONSE']._serialized_end=1615
+  _globals['_TRAINEPOCHREQUEST']._serialized_start=1617
+  _globals['_TRAINEPOCHREQUEST']._serialized_end=1701
+  _globals['_TRAINEPOCHRESPONSE']._serialized_start=1703
+  _globals['_TRAINEPOCHRESPONSE']._serialized_end=1816
+  _globals['_TRAINBATCHREQUEST']._serialized_start=1818
+  _globals['_TRAINBATCHREQUEST']._serialized_end=1898
+  _globals['_TRAINBATCHRESPONSE']._serialized_start=1901
+  _globals['_TRAINBATCHRESPONSE']._serialized_end=2046
+  _globals['_EVALGLOBALREQUEST']._serialized_start=2048
+  _globals['_EVALGLOBALREQUEST']._serialized_end=2106
+  _globals['_EVALGLOBALRESPONSE']._serialized_start=2108
+  _globals['_EVALGLOBALRESPONSE']._serialized_end=2221
+  _globals['_EVALREQUEST']._serialized_start=2223
+  _globals['_EVALREQUEST']._serialized_end=2285
+  _globals['_EVALRESPONSE']._serialized_start=2287
+  _globals['_EVALRESPONSE']._serialized_end=2367
+  _globals['_EVALBATCHREQUEST']._serialized_start=2369
+  _globals['_EVALBATCHREQUEST']._serialized_end=2448
+  _globals['_EVALBATCHRESPONSE']._serialized_start=2450
+  _globals['_EVALBATCHRESPONSE']._serialized_end=2562
+  _globals['_FULLMODELTRAINREQUEST']._serialized_start=2564
+  _globals['_FULLMODELTRAINREQUEST']._serialized_end=2623
+  _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2626
+  _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2832
+  _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2834
+  _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2858
+  _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2860
+  _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2951
+  _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2953
+  _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2975
+  _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2977
+  _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=3066
+  _globals['_BATTERYSTATUSREQUEST']._serialized_start=3068
+  _globals['_BATTERYSTATUSREQUEST']._serialized_end=3090
+  _globals['_BATTERYSTATUSRESPONSE']._serialized_start=3092
+  _globals['_BATTERYSTATUSRESPONSE']._serialized_end=3213
+  _globals['_DATASETMODELINFOREQUEST']._serialized_start=3215
+  _globals['_DATASETMODELINFOREQUEST']._serialized_end=3240
+  _globals['_DATASETMODELINFORESPONSE']._serialized_start=3243
+  _globals['_DATASETMODELINFORESPONSE']._serialized_end=3536
+  _globals['_DEVICE']._serialized_start=3539
+  _globals['_DEVICE']._serialized_end=4683
 # @@protoc_insertion_point(module_scope)
diff --git a/edml/generated/connection_pb2.pyi b/edml/generated/connection_pb2.pyi
index 11b713d56cdbe69e9f9b68a52bd9aa51693f49b3..cbbaadd8f74a800e77028145de74c967ddec3956 100644
--- a/edml/generated/connection_pb2.pyi
+++ b/edml/generated/connection_pb2.pyi
@@ -72,14 +72,16 @@ class TrainGlobalParallelSplitLearningResponse(_message.Message):
     def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ...
 
 class TrainGlobalRequest(_message.Message):
-    __slots__ = ["epochs", "round_no", "optimizer_state"]
+    __slots__ = ["epochs", "round_no", "adaptive_learning_threshold", "optimizer_state"]
     EPOCHS_FIELD_NUMBER: _ClassVar[int]
     ROUND_NO_FIELD_NUMBER: _ClassVar[int]
+    ADAPTIVE_LEARNING_THRESHOLD_FIELD_NUMBER: _ClassVar[int]
     OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int]
     epochs: int
     round_no: int
+    adaptive_learning_threshold: float
     optimizer_state: _datastructures_pb2.StateDict
-    def __init__(self, epochs: _Optional[int] = ..., round_no: _Optional[int] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
+    def __init__(self, epochs: _Optional[int] = ..., round_no: _Optional[int] = ..., adaptive_learning_threshold: _Optional[float] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
 
 class TrainGlobalResponse(_message.Message):
     __slots__ = ["client_weights", "server_weights", "metrics", "optimizer_state", "diagnostic_metrics"]
diff --git a/edml/proto/connection.proto b/edml/proto/connection.proto
index 6fecb49b933ebdfed642b247322b2a099fb5901e..b0441d99a3397af8b85fc2f4eb99190accb05242 100644
--- a/edml/proto/connection.proto
+++ b/edml/proto/connection.proto
@@ -66,7 +66,8 @@ message TrainGlobalParallelSplitLearningResponse {
 message TrainGlobalRequest {
   int32 epochs = 1;
   optional int32 round_no = 2;
-  optional StateDict optimizer_state = 3;
+  optional double adaptive_learning_threshold = 3;
+  optional StateDict optimizer_state = 4;
 
 }
 
diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py
index 6c025f4b78a2a648b6055751cf5875b4b7224494..4cde758ebf1199882293e115fb479d25a80fecf6 100644
--- a/edml/tests/controllers/swarm_controller_test.py
+++ b/edml/tests/controllers/swarm_controller_test.py
@@ -37,7 +37,9 @@ class SwarmControllerTest(unittest.TestCase):
         )
 
         client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
-            self.swarm_controller._swarm_train_round(None, None, "d1", 0)
+            self.swarm_controller._swarm_train_round(
+                None, None, "d1", 0, optimizer_state={"optimizer_state": 43}
+            )
         )
 
         self.assertEqual(client_weights, {"weights": 42})
@@ -52,7 +54,11 @@ class SwarmControllerTest(unittest.TestCase):
             ]
         )
         self.mock.train_global_on.assert_called_once_with(
-            "d1", epochs=1, round_no=0, optimizer_state=None
+            "d1",
+            epochs=1,
+            round_no=0,
+            adaptive_learning_threshold=0.0,
+            optimizer_state={"optimizer_state": 43},
         )
 
     def test_split_train_round_with_inactive_server_device(self):
@@ -74,7 +80,11 @@ class SwarmControllerTest(unittest.TestCase):
             ]
         )
         self.mock.train_global_on.assert_called_once_with(
-            "d1", epochs=1, round_no=0, optimizer_state=None
+            "d1",
+            epochs=1,
+            round_no=0,
+            adaptive_learning_threshold=0.0,
+            optimizer_state=None,
         )
 
 
diff --git a/edml/tests/core/device_test.py b/edml/tests/core/device_test.py
index a0d2be2eaf7c4f2f1a8e27b3f1182e842658c87d..31371cebea53098c3344cc64b2d2f25614836cb7 100644
--- a/edml/tests/core/device_test.py
+++ b/edml/tests/core/device_test.py
@@ -131,7 +131,12 @@ class RPCDeviceServicerTest(unittest.TestCase):
             {"optimizer_state": 44},
             self.diagnostic_metrics,
         )
-        request = connection_pb2.TrainGlobalRequest(epochs=42)
+        request = connection_pb2.TrainGlobalRequest(
+            epochs=42,
+            round_no=1,
+            adaptive_learning_threshold=3,
+            optimizer_state=state_dict_to_proto({"optimizer_state": 42}),
+        )
 
         response, metadata, code, details = self.make_call("TrainGlobal", request)
 
@@ -147,7 +152,9 @@ class RPCDeviceServicerTest(unittest.TestCase):
             proto_to_state_dict(response.optimizer_state), {"optimizer_state": 44}
         )
         self.assertEqual(code, StatusCode.OK)
-        self.mock_device.train_global.assert_called_once_with(42, 0)
+        self.mock_device.train_global.assert_called_once_with(
+            42, 1, 3, {"optimizer_state": 42}
+        )
         self.assertEqual(
             proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
         )
@@ -357,7 +364,9 @@ class RPCDeviceServicerBatteryEmptyTest(unittest.TestCase):
         self.mock_device.train_global.side_effect = BatteryEmptyException(
             "Battery empty"
         )
-        request = connection_pb2.TrainGlobalRequest()
+        request = connection_pb2.TrainGlobalRequest(
+            optimizer_state=state_dict_to_proto(None)
+        )
         self._test_device_out_of_battery_lets_rpc_fail(request, "TrainGlobal")
 
     def test_stop_at_set_weights(self):
@@ -504,7 +513,7 @@ class RequestDispatcherTest(unittest.TestCase):
         )
 
         client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
-            self.dispatcher.train_global_on("1", 42, 43)
+            self.dispatcher.train_global_on("1", 42, 43, 3, {"optimizer_state": 44})
         )
 
         self.assertEqual(client_weights, self.weights)
@@ -515,19 +524,31 @@ class RequestDispatcherTest(unittest.TestCase):
         self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
         self.mock_stub.TrainGlobal.assert_called_once_with(
             connection_pb2.TrainGlobalRequest(
-                epochs=42, round_no=43, optimizer_state=state_dict_to_proto(None)
+                epochs=42,
+                round_no=43,
+                adaptive_learning_threshold=3,
+                optimizer_state=state_dict_to_proto({"optimizer_state": 44}),
             )
         )
 
     def test_train_global_on_with_error(self):
         self.mock_stub.TrainGlobal.side_effect = grpc.RpcError()
 
-        response = self.dispatcher.train_global_on("1", 42, round_no=43)
+        response = self.dispatcher.train_global_on(
+            "1",
+            42,
+            round_no=43,
+            adaptive_learning_threshold=3,
+            optimizer_state={"optimizer_state": 44},
+        )
 
         self.assertEqual(response, False)
         self.mock_stub.TrainGlobal.assert_called_once_with(
             connection_pb2.TrainGlobalRequest(
-                epochs=42, round_no=43, optimizer_state=state_dict_to_proto(None)
+                epochs=42,
+                round_no=43,
+                adaptive_learning_threshold=3,
+                optimizer_state=state_dict_to_proto({"optimizer_state": 44}),
             )
         )