diff --git a/config/scheduler/multistep.yaml b/config/scheduler/multistep.yaml
index 77af7c04aabbd32634c35cebc58bceb7426dc941..6caf71ace59d32e6ecb0eb61e2ad85ec03bbe965 100644
--- a/config/scheduler/multistep.yaml
+++ b/config/scheduler/multistep.yaml
@@ -1,3 +1,3 @@
 _target_: torch.optim.lr_scheduler.MultiStepLR
-milestones: [ 100, 150 ]
+milestones: [ 101, 151 ]
 gamma: 0.1
diff --git a/edml/config/battery/resnet110_cifar100_cost.yaml b/edml/config/battery/resnet110_cifar100_cost.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d1bcdb98d5d0d98d49eedff71d3cf12b9e1bf9d9
--- /dev/null
+++ b/edml/config/battery/resnet110_cifar100_cost.yaml
@@ -0,0 +1,4 @@
+deduction_per_second: 0.005
+deduction_per_mflop: 0.00000005
+deduction_per_mbyte_received: 0.0002
+deduction_per_mbyte_sent: 0.0002
diff --git a/edml/config/topology/equal_batteries_10_devices.yaml b/edml/config/topology/equal_batteries_10_devices.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ea81cd38a17f91a049a4bb82440a87239e00cbbf
--- /dev/null
+++ b/edml/config/topology/equal_batteries_10_devices.yaml
@@ -0,0 +1,62 @@
+devices: [
+    {
+        device_id: "d0",
+        address: "localhost:50051",
+        battery_capacity: 1000000,
+        torch_device: cuda:0
+    },
+    {
+        device_id: "d1",
+        address: "localhost:50052",
+        battery_capacity: 1000000,
+        torch_device: cuda:1
+    },
+    {
+        device_id: "d2",
+        address: "localhost:50053",
+        battery_capacity: 1000000,
+        torch_device: cuda:2
+    },
+    {
+        device_id: "d3",
+        address: "localhost:50054",
+        battery_capacity: 1000000,
+        torch_device: cuda:0
+    },
+    {
+        device_id: "d4",
+        address: "localhost:50055",
+        battery_capacity: 1000000,
+        torch_device: cuda:1
+    },
+    {
+        device_id: "d5",
+        address: "localhost:50056",
+        battery_capacity: 1000000,
+        torch_device: cuda:2
+    },
+    {
+        device_id: "d6",
+        address: "localhost:50057",
+        battery_capacity: 1000000,
+        torch_device: cuda:0
+    },
+    {
+        device_id: "d7",
+        address: "localhost:50058",
+        battery_capacity: 1000000,
+        torch_device: cuda:1
+    },
+    {
+        device_id: "d8",
+        address: "localhost:50059",
+        battery_capacity: 1000000,
+        torch_device: cuda:2
+    },
+    {
+        device_id: "d9",
+        address: "localhost:50060",
+        battery_capacity: 1000000,
+        torch_device: cuda:0
+    }
+]
diff --git a/edml/config/topology/resnet110_cifar100_batteries.yaml b/edml/config/topology/resnet110_cifar100_batteries.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3a81e0087447bd83066bf31840321d1cc2b4413d
--- /dev/null
+++ b/edml/config/topology/resnet110_cifar100_batteries.yaml
@@ -0,0 +1,27 @@
+devices: [
+  {
+    device_id: "d0",
+    address: "localhost:50051",
+    battery_capacity: 750
+  },
+  {
+    device_id: "d1",
+    address: "localhost:50052",
+    battery_capacity: 750
+  },
+  {
+    device_id: "d2",
+    address: "localhost:50053",
+    battery_capacity: 600
+  },
+  {
+    device_id: "d3",
+    address: "localhost:50054",
+    battery_capacity: 600
+  },
+  {
+    device_id: "d4",
+    address: "localhost:50055",
+    battery_capacity: 600
+  }
+]
diff --git a/edml/controllers/strategy_optimization.py b/edml/controllers/strategy_optimization.py
index eddbf5742a526285cf6a3793b208900beda9f3ec..cd7bae8ac0db1c8eb4bcf818713ed8b2f1ed0e47 100644
--- a/edml/controllers/strategy_optimization.py
+++ b/edml/controllers/strategy_optimization.py
@@ -123,13 +123,15 @@ class ServerChoiceOptimizer:
             self.global_params.train_global_time is not None
             and self.global_params.last_server_device_id is not None
         ):
-            return (
+            latency = (
                 self.global_params.train_global_time
                 - self._round_runtime_with_server_no_latency(
                     self.global_params.last_server_device_id
                 )
             )
-        return 0  # latency not known
+            if latency > 0:
+                return latency
+        return 0  # latency not known or runtime was overestimated previously
 
     def _round_runtime_with_server_no_latency(self, server_device_id):
         """
@@ -356,20 +358,7 @@ class EnergySimulator:
             device_params_list, global_params
         )
 
-    def simulate_greedy_selection(self):
-        """
-        Simulates the greedy server choice algorithm.
-        Returns:
-            num_rounds: number of rounds until the first device runs out of battery
-            server_selection_schedule: list of server device ids for each round
-            device_batteries: list of battery levels for each device after the last successful round
-        """
-
-        def __get_device_with_max_battery__(device_battery_list):
-            return max(
-                range(len(device_battery_list)), key=device_battery_list.__getitem__
-            )
-
+    def _simulate_selection(self, selection_callback):
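+        # selection_callback(device_battery_list, num_rounds) must return the index of the
+        # device that acts as the server in the current round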
         def __all_devices_alive__(device_battery_list):
             return all(battery > 0 for battery in device_battery_list)
 
@@ -394,7 +383,9 @@ class EnergySimulator:
         server_selection_schedule = []
         num_rounds = 0
         while all_devices_alive:
-            server_device_idx = __get_device_with_max_battery__(device_batteries)
+            server_device_idx = selection_callback(
+                device_battery_list=device_batteries, num_rounds=num_rounds
+            )
             new_batteries = device_batteries.copy()
             for idx, device in enumerate(self.device_params_list):
                 new_batteries[idx] = new_batteries[idx] - energy[idx][server_device_idx]
@@ -409,6 +400,36 @@ class EnergySimulator:
                 break
         return num_rounds, server_selection_schedule, device_batteries
 
+    def simulate_greedy_selection(self):
+        """
+        Simulates the greedy server choice algorithm.
+        Returns:
+            num_rounds: number of rounds until the first device runs out of battery
+            server_selection_schedule: list of server device ids for each round
+            device_batteries: list of battery levels for each device after the last successful round
+        """
+
+        def __get_device_with_max_battery__(device_battery_list, **kwargs):
+            return max(
+                range(len(device_battery_list)), key=device_battery_list.__getitem__
+            )
+
+        return self._simulate_selection(__get_device_with_max_battery__)
+
+    def simulate_sequential_selection(self):
+        """
+        Simulates the sequential server choice algorithm.
+        Returns:
+            num_rounds: number of rounds until the first device runs out of battery
+            server_selection_schedule: list of server device ids for each round
+            device_batteries: list of battery levels for each device after the last successful round
+        """
+
+        def __sequential_selection__(device_battery_list, num_rounds):
+            return num_rounds % len(device_battery_list)
+
+        return self._simulate_selection(__sequential_selection__)
+
     def simulate_smart_selection(self):
         """
         Simulates the smart server choice algorithm.
@@ -574,47 +595,52 @@ def run_grid_search(
             for partition in partitions:
                 for cost_sec in cost_per_sec:
                     for cost_sent in cost_per_byte_sent:
-                        for cost_received in cost_per_byte_received:
-                            for cost_flop in cost_per_flop:
-                                global_params.cost_per_sec = cost_sec
-                                global_params.cost_per_byte_sent = cost_sent
-                                global_params.cost_per_byte_received = cost_received
-                                global_params.cost_per_flop = cost_flop
-                                for idx, device in enumerate(device_params_list):
-                                    device.current_battery = battery[idx]
-                                    device.comp_latency_factor = latency[idx]
-                                    device.train_samples = (
-                                        partition[idx] * total_train_samples
-                                    )
-                                    device.validation_samples = (
-                                        partition[idx] * total_val_samples
-                                    )
-                                energy_simulator = EnergySimulator(
-                                    device_params_list, global_params
-                                )
-                                num_rounds_smart, _, _ = (
-                                    energy_simulator.simulate_smart_selection()
-                                )
-                                num_rounds_greedy, _, _ = (
-                                    energy_simulator.simulate_greedy_selection()
-                                )
-                                num_rounds_fl, _ = (
-                                    energy_simulator.simulate_federated_learning()
+                        # for cost_received in cost_per_byte_received:
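+                        # send and receive costs are treated as symmetric here to shrink the search space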
+                        cost_received = cost_sent
+                        for cost_flop in cost_per_flop:
+                            global_params.cost_per_sec = cost_sec
+                            global_params.cost_per_byte_sent = cost_sent
+                            global_params.cost_per_byte_received = cost_received
+                            global_params.cost_per_flop = cost_flop
+                            for idx, device in enumerate(device_params_list):
+                                device.current_battery = battery[idx]
+                                device.comp_latency_factor = latency[idx]
+                                device.train_samples = (
+                                    partition[idx] * total_train_samples
                                 )
-                                results.append(
-                                    {
-                                        "battery": battery,
-                                        "latency": latency,
-                                        "partition": partition,
-                                        "cost_per_sec": cost_sec,
-                                        "cost_per_byte_sent": cost_sent,
-                                        "cost_per_byte_received": cost_received,
-                                        "cost_per_flop": cost_flop,
-                                        "num_rounds_smart": num_rounds_smart,
-                                        "num_rounds_greedy": num_rounds_greedy,
-                                        "num_rounds_fl": num_rounds_fl,
-                                    }
+                                device.validation_samples = (
+                                    partition[idx] * total_val_samples
                                 )
+                            energy_simulator = EnergySimulator(
+                                device_params_list, global_params
+                            )
+                            num_rounds_smart, _, _ = (
+                                energy_simulator.simulate_smart_selection()
+                            )
+                            num_rounds_greedy, _, _ = (
+                                energy_simulator.simulate_greedy_selection()
+                            )
+                            num_rounds_seq, _, _ = (
+                                energy_simulator.simulate_sequential_selection()
+                            )
+                            num_rounds_fl, _ = (
+                                energy_simulator.simulate_federated_learning()
+                            )
+                            results.append(
+                                {
+                                    "battery": battery,
+                                    "latency": latency,
+                                    "partition": partition,
+                                    "cost_per_sec": cost_sec,
+                                    "cost_per_byte_sent": cost_sent,
+                                    "cost_per_byte_received": cost_received,
+                                    "cost_per_flop": cost_flop,
+                                    "num_rounds_smart": num_rounds_smart,
+                                    "num_rounds_seq": num_rounds_seq,
+                                    "num_rounds_greedy": num_rounds_greedy,
+                                    "num_rounds_fl": num_rounds_fl,
+                                }
+                            )
     return results
 
 
diff --git a/edml/controllers/test_controller.py b/edml/controllers/test_controller.py
index ba5e7c187a745d13f8dae9673d3518fc9015b9d2..a3f0ca972c4b98e44e0d05d228db5562239ceb57 100644
--- a/edml/controllers/test_controller.py
+++ b/edml/controllers/test_controller.py
@@ -43,7 +43,7 @@ class TestController(BaseController):
 
     def _get_model_with_highest_postfix_number(self, model_save_path):
         """Returns the highest postfix number in the given directory for the configured model_prefix."""
-        model_prefix = self.__model_prefix__()
+        model_prefix = f"{self.__model_prefix__()}_client_"  # assume server and client weights were saved appropriately
         highest_postfix_number = 0
         for file in os.listdir(model_save_path):
             if file.startswith(model_prefix):
diff --git a/edml/core/client.py b/edml/core/client.py
index b3c3e7491e718cf98c9ccde598e00fd4d2054996..d8096c7ca5bc479a2aec458f2daf1cc442db4986 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -18,7 +18,7 @@ from edml.helpers.decorators import (
 from edml.helpers.flops import estimate_model_flops
 from edml.helpers.load_optimizer import get_optimizer_and_scheduler
 from edml.helpers.metrics import DiagnosticMetricResultContainer, DiagnosticMetricResult
-from edml.helpers.types import StateDict, SLTrainBatchResult
+from edml.helpers.types import StateDict
 
 if TYPE_CHECKING:
     from edml.core.device import Device
@@ -136,12 +136,18 @@ class DeviceClient:
 
     @check_device_set()
     def train_single_batch(
-        self, batch_index: int
+        self, batch_index: int, round_no: int = -1
     ) -> Optional[torch.Tensor, torch.Tensor]:
         torch.cuda.set_device(self._device)
         # We have to re-initialize the data loader in the case that we do another epoch.
         if batch_index == 0:
             self._batchable_data_loader = iter(self._train_data)
+            # update the lr scheduler at the beginning of each round
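+            # round_no == -1 means no round index was provided, so fall back to a plain step()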
+            if self._lr_scheduler is not None:
+                if round_no != -1:
+                    self._lr_scheduler.step(round_no)
+                else:
+                    self._lr_scheduler.step()
 
         # Used to measure training time. The problem we have with parallel split learning is that forward- and backward-
         # passes are orchestrated by the current server.
@@ -258,6 +264,11 @@ class DeviceClient:
             that, this approach does not require to deduce server batch processing time after a "traditional"
             measurement.
         """
+        if self._lr_scheduler is not None:
+            if round_no != -1:
+                self._lr_scheduler.step(round_no)
+            else:
+                self._lr_scheduler.step()
         client_train_start_time = time.time()
         server_train_batch_times = (
             []
@@ -295,12 +306,6 @@ class DeviceClient:
                 smashed_data.backward(server_grad)
                 self._optimizer.step()
 
-        if self._lr_scheduler is not None:
-            if round_no != -1:
-                self._lr_scheduler.step(round_no)
-            else:
-                self._lr_scheduler.step()
-
         client_train_time = (
             time.time() - client_train_start_time - sum(server_train_batch_times)
         )
diff --git a/edml/core/device.py b/edml/core/device.py
index 1348bea7e4a4fadb44a6bd304bfc466f44bbba41..94bb18ac35241fc86eec778b8ef2d9145a8ed51f 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -207,7 +207,9 @@ class Device(ABC):
         """Evaluates a batch on the server of the device with the given id"""
 
     @abstractmethod
-    def train_batch_on_client_only_on(self, device_id: str, batch_index: int):
+    def train_batch_on_client_only_on(
+        self, device_id: str, batch_index: int, round_no: int
+    ):
         """"""
 
     @abstractmethod
@@ -265,17 +267,23 @@ class NetworkDevice(Device):
 
     @update_battery
     @log_execution_time("logger", "client_only_batch_train")
-    def train_batch_on_client_only(self, batch_index: int):
-        smashed_data, labels = self.client.train_single_batch(batch_index=batch_index)
+    def train_batch_on_client_only(self, batch_index: int, round_no: int):
+        smashed_data, labels = self.client.train_single_batch(
+            batch_index=batch_index, round_no=round_no
+        )
         return smashed_data, labels
 
     @update_battery
-    def train_batch_on_client_only_on(self, device_id: str, batch_index: int):
+    def train_batch_on_client_only_on(
+        self, device_id: str, batch_index: int, round_no: int
+    ):
         if self.device_id == device_id:
-            return self.train_batch_on_client_only(batch_index=batch_index)
+            return self.train_batch_on_client_only(
+                batch_index=batch_index, round_no=round_no
+            )
         else:
             return self.request_dispatcher.train_batch_on_client_only(
-                device_id=device_id, batch_index=batch_index
+                device_id=device_id, batch_index=batch_index, round_no=round_no
             )
 
     def __init__(
@@ -575,8 +583,11 @@ class RPCDeviceServicer(DeviceServicer):
 
     def TrainSingleBatchOnClient(self, request, context):
         batch_index = request.batch_index
+        round_no = request.round_no
 
-        smashed_data, labels = self.device.client.train_single_batch(batch_index)
+        smashed_data, labels = self.device.client.train_single_batch(
+            batch_index, round_no=round_no
+        )
 
         smashed_data = Activations(activations=tensor_to_proto(smashed_data))
         labels = Labels(labels=tensor_to_proto(labels))
@@ -955,13 +966,15 @@ class DeviceRequestDispatcher:
         return False
 
     def train_batch_on_client_only(
-        self, device_id: str, batch_index: int
+        self, device_id: str, batch_index: int, round_no: int
     ) -> Tuple[Tensor, Tensor] | None:
         try:
             response: SingleBatchTrainingResponse = self._get_connection(
                 device_id
             ).TrainSingleBatchOnClient(
-                connection_pb2.SingleBatchTrainingRequest(batch_index=batch_index)
+                connection_pb2.SingleBatchTrainingRequest(
+                    batch_index=batch_index, round_no=round_no
+                )
             )
 
             # The response can only be None if the last batch was smaller than the configured batch size.
diff --git a/edml/core/server.py b/edml/core/server.py
index 5e2f8235de476c5a998f9c8c040b462ac9d3b198..4b9ef80fd876d618075a98a14d61e03d83369b66 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -90,6 +90,11 @@ class DeviceServer:
         if optimizer_state is not None:
             self._optimizer.load_state_dict(optimizer_state)
         for epoch in range(epochs):
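+            # step the lr scheduler at the start of each epoch so the new rate applies to that epoch's batches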
+            if self._lr_scheduler is not None:
+                if round_no != -1:
+                    self._lr_scheduler.step(round_no + epoch)
+                else:
+                    self._lr_scheduler.step()
             for device_id in devices:
                 print(
                     f"Train epoch {epoch} on client {device_id} with server {self.node_device.device_id}"
@@ -120,11 +125,6 @@ class DeviceServer:
 
                     metrics.add_results(train_metrics)
                     metrics.add_results(val_metrics)
-            if self._lr_scheduler is not None:
-                if round_no != -1:
-                    self._lr_scheduler.step(round_no + epoch)
-                else:
-                    self._lr_scheduler.step()
         return (
             client_weights,
             self.get_weights(),
@@ -220,7 +220,10 @@ class DeviceServer:
     ):
         def client_training_job(client_id: str, batch_index: int):
             result = self.node_device.train_batch_on_client_only_on(
-                device_id=client_id, batch_index=batch_index
+                device_id=client_id,
+                batch_index=batch_index,
+                round_no=round_no,  # round_no comes from the enclosing method's argument
             )
             return (client_id, result)
 
@@ -241,6 +244,12 @@ class DeviceServer:
         if optimizer_state is not None:
             self._optimizer.load_state_dict(optimizer_state)
 
+        if self._lr_scheduler is not None:
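+            # a single epoch per round is assumed here, so the scheduler is advanced to round_no + 1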
+            if round_no != -1:
+                self._lr_scheduler.step(round_no + 1)  # epoch=1
+            else:
+                self._lr_scheduler.step()
+
         num_threads = len(clients)
         executor = create_executor_with_threads(num_threads)
 
@@ -289,7 +298,9 @@ class DeviceServer:
                 print(
                     f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
                 )
-                self.node_device.log({"adaptive_learning_threshold_applied": True})
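+                # log the size of the skipped gradient batch instead of a plain boolean flag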
+                self.node_device.log(
+                    {"adaptive_learning_threshold_applied": server_gradients.size(0)}
+                )
                 continue
 
             num_client_gradients = len(client_forward_pass_responses)
@@ -346,11 +357,6 @@ class DeviceServer:
             model_metrics.add_results(val_metrics)
 
         optimizer_state = self._optimizer.state_dict()
-        if self._lr_scheduler is not None:
-            if round_no != -1:
-                self._lr_scheduler.step(round_no + 1)  # epoch=1
-            else:
-                self._lr_scheduler.step()
         # delete references and free GPU memory manually
         server_batch = None
         server_labels = None
diff --git a/edml/generated/connection_pb2.py b/edml/generated/connection_pb2.py
index ce271bbab4fa5d30e0e8c5a0a8b309d96eaccf07..f3c72442f44b99d587925a030002e1af92924ac8 100644
--- a/edml/generated/connection_pb2.py
+++ b/edml/generated/connection_pb2.py
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
 import datastructures_pb2 as datastructures__pb2
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"4\n\x13SetGradientsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"j\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12\"\n\tgradients\x18\x02 \x01(\x0b\x32\n.GradientsH\x00\x88\x01\x01\x42\x0c\n\n_gradients\"1\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 
\x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xf8\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n 
TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x12\x45\n#SetGradientsAndFinalizeTrainingStep\x12\x14.SetGradientsRequest\x1a\x06.Empty\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"4\n\x13SetGradientsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"j\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12\"\n\tgradients\x18\x02 \x01(\x0b\x32\n.GradientsH\x00\x88\x01\x01\x42\x0c\n\n_gradients\"C\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\x12\x10\n\x08round_no\x18\x02 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 
\x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xf8\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n 
TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x12\x45\n#SetGradientsAndFinalizeTrainingStep\x12\x14.SetGradientsRequest\x1a\x06.Empty\"\x00\x62\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -31,61 +31,61 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _globals['_SINGLEBATCHBACKWARDRESPONSE']._serialized_start=212
   _globals['_SINGLEBATCHBACKWARDRESPONSE']._serialized_end=318
   _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_start=320
-  _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_end=369
-  _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_start=372
-  _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_end=500
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_start=503
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=716
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=719
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=984
-  _globals['_TRAINGLOBALREQUEST']._serialized_start=987
-  _globals['_TRAINGLOBALREQUEST']._serialized_end=1121
-  _globals['_TRAINGLOBALRESPONSE']._serialized_start=1124
-  _globals['_TRAINGLOBALRESPONSE']._serialized_end=1368
-  _globals['_SETWEIGHTSREQUEST']._serialized_start=1370
-  _globals['_SETWEIGHTSREQUEST']._serialized_end=1435
-  _globals['_SETWEIGHTSRESPONSE']._serialized_start=1437
-  _globals['_SETWEIGHTSRESPONSE']._serialized_end=1523
-  _globals['_TRAINEPOCHREQUEST']._serialized_start=1525
-  _globals['_TRAINEPOCHREQUEST']._serialized_end=1609
-  _globals['_TRAINEPOCHRESPONSE']._serialized_start=1611
-  _globals['_TRAINEPOCHRESPONSE']._serialized_end=1724
-  _globals['_TRAINBATCHREQUEST']._serialized_start=1726
-  _globals['_TRAINBATCHREQUEST']._serialized_end=1806
-  _globals['_TRAINBATCHRESPONSE']._serialized_start=1809
-  _globals['_TRAINBATCHRESPONSE']._serialized_end=1954
-  _globals['_EVALGLOBALREQUEST']._serialized_start=1956
-  _globals['_EVALGLOBALREQUEST']._serialized_end=2014
-  _globals['_EVALGLOBALRESPONSE']._serialized_start=2016
-  _globals['_EVALGLOBALRESPONSE']._serialized_end=2129
-  _globals['_EVALREQUEST']._serialized_start=2131
-  _globals['_EVALREQUEST']._serialized_end=2193
-  _globals['_EVALRESPONSE']._serialized_start=2195
-  _globals['_EVALRESPONSE']._serialized_end=2275
-  _globals['_EVALBATCHREQUEST']._serialized_start=2277
-  _globals['_EVALBATCHREQUEST']._serialized_end=2356
-  _globals['_EVALBATCHRESPONSE']._serialized_start=2358
-  _globals['_EVALBATCHRESPONSE']._serialized_end=2470
-  _globals['_FULLMODELTRAINREQUEST']._serialized_start=2472
-  _globals['_FULLMODELTRAINREQUEST']._serialized_end=2531
-  _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2534
-  _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2740
-  _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2742
-  _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2766
-  _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2768
-  _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2859
-  _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2861
-  _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2883
-  _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2885
-  _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=2974
-  _globals['_BATTERYSTATUSREQUEST']._serialized_start=2976
-  _globals['_BATTERYSTATUSREQUEST']._serialized_end=2998
-  _globals['_BATTERYSTATUSRESPONSE']._serialized_start=3000
-  _globals['_BATTERYSTATUSRESPONSE']._serialized_end=3121
-  _globals['_DATASETMODELINFOREQUEST']._serialized_start=3123
-  _globals['_DATASETMODELINFOREQUEST']._serialized_end=3148
-  _globals['_DATASETMODELINFORESPONSE']._serialized_start=3151
-  _globals['_DATASETMODELINFORESPONSE']._serialized_end=3350
-  _globals['_DEVICE']._serialized_start=3353
-  _globals['_DEVICE']._serialized_end=4497
+  _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_end=387
+  _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_start=390
+  _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_end=518
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_start=521
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=734
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=737
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=1002
+  _globals['_TRAINGLOBALREQUEST']._serialized_start=1005
+  _globals['_TRAINGLOBALREQUEST']._serialized_end=1139
+  _globals['_TRAINGLOBALRESPONSE']._serialized_start=1142
+  _globals['_TRAINGLOBALRESPONSE']._serialized_end=1386
+  _globals['_SETWEIGHTSREQUEST']._serialized_start=1388
+  _globals['_SETWEIGHTSREQUEST']._serialized_end=1453
+  _globals['_SETWEIGHTSRESPONSE']._serialized_start=1455
+  _globals['_SETWEIGHTSRESPONSE']._serialized_end=1541
+  _globals['_TRAINEPOCHREQUEST']._serialized_start=1543
+  _globals['_TRAINEPOCHREQUEST']._serialized_end=1627
+  _globals['_TRAINEPOCHRESPONSE']._serialized_start=1629
+  _globals['_TRAINEPOCHRESPONSE']._serialized_end=1742
+  _globals['_TRAINBATCHREQUEST']._serialized_start=1744
+  _globals['_TRAINBATCHREQUEST']._serialized_end=1824
+  _globals['_TRAINBATCHRESPONSE']._serialized_start=1827
+  _globals['_TRAINBATCHRESPONSE']._serialized_end=1972
+  _globals['_EVALGLOBALREQUEST']._serialized_start=1974
+  _globals['_EVALGLOBALREQUEST']._serialized_end=2032
+  _globals['_EVALGLOBALRESPONSE']._serialized_start=2034
+  _globals['_EVALGLOBALRESPONSE']._serialized_end=2147
+  _globals['_EVALREQUEST']._serialized_start=2149
+  _globals['_EVALREQUEST']._serialized_end=2211
+  _globals['_EVALRESPONSE']._serialized_start=2213
+  _globals['_EVALRESPONSE']._serialized_end=2293
+  _globals['_EVALBATCHREQUEST']._serialized_start=2295
+  _globals['_EVALBATCHREQUEST']._serialized_end=2374
+  _globals['_EVALBATCHRESPONSE']._serialized_start=2376
+  _globals['_EVALBATCHRESPONSE']._serialized_end=2488
+  _globals['_FULLMODELTRAINREQUEST']._serialized_start=2490
+  _globals['_FULLMODELTRAINREQUEST']._serialized_end=2549
+  _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2552
+  _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2758
+  _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2760
+  _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2784
+  _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2786
+  _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2877
+  _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2879
+  _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2901
+  _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2903
+  _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=2992
+  _globals['_BATTERYSTATUSREQUEST']._serialized_start=2994
+  _globals['_BATTERYSTATUSREQUEST']._serialized_end=3016
+  _globals['_BATTERYSTATUSRESPONSE']._serialized_start=3018
+  _globals['_BATTERYSTATUSRESPONSE']._serialized_end=3139
+  _globals['_DATASETMODELINFOREQUEST']._serialized_start=3141
+  _globals['_DATASETMODELINFOREQUEST']._serialized_end=3166
+  _globals['_DATASETMODELINFORESPONSE']._serialized_start=3169
+  _globals['_DATASETMODELINFORESPONSE']._serialized_end=3368
+  _globals['_DEVICE']._serialized_start=3371
+  _globals['_DEVICE']._serialized_end=4515
 # @@protoc_insertion_point(module_scope)
diff --git a/edml/generated/connection_pb2.pyi b/edml/generated/connection_pb2.pyi
index a9735505e808ee35b156158e972ab6b206e90b0a..89353343aa1c7e39072bd0ea03c891bd2df7b4df 100644
--- a/edml/generated/connection_pb2.pyi
+++ b/edml/generated/connection_pb2.pyi
@@ -32,10 +32,12 @@ class SingleBatchBackwardResponse(_message.Message):
     def __init__(self, metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., gradients: _Optional[_Union[_datastructures_pb2.Gradients, _Mapping]] = ...) -> None: ...
 
 class SingleBatchTrainingRequest(_message.Message):
-    __slots__ = ["batch_index"]
+    __slots__ = ["batch_index", "round_no"]
     BATCH_INDEX_FIELD_NUMBER: _ClassVar[int]
+    ROUND_NO_FIELD_NUMBER: _ClassVar[int]
     batch_index: int
-    def __init__(self, batch_index: _Optional[int] = ...) -> None: ...
+    round_no: int
+    def __init__(self, batch_index: _Optional[int] = ..., round_no: _Optional[int] = ...) -> None: ...
 
 class SingleBatchTrainingResponse(_message.Message):
     __slots__ = ["smashed_data", "labels"]
diff --git a/edml/proto/connection.proto b/edml/proto/connection.proto
index 2f12881a1cf7e3ee6ae2e076b2f9005141cfece2..5755518d01796bdff68957655d4d339dcf02ab1d 100644
--- a/edml/proto/connection.proto
+++ b/edml/proto/connection.proto
@@ -41,6 +41,7 @@ message SingleBatchBackwardResponse {
 
 message SingleBatchTrainingRequest {
   int32 batch_index = 1;
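+  // index of the current training round; the client uses it to step its lr scheduler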
+  int32 round_no = 2;
 }
 
 message SingleBatchTrainingResponse {