From edde9250b865343f69a836aaffdf73e6c936e5ad Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Thu, 27 Jun 2024 15:36:58 +0200
Subject: [PATCH] Added sharing of the server optimizer state for SwarmSL-based
 training
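
The server optimizer state is now returned together with the updated
weights at the end of a training round and handed back in with the next
round's request. Stateful optimizers (e.g. momentum SGD or Adam moment
estimates) therefore no longer restart from scratch whenever the server
role moves to another device between rounds. The state travels as an
optional StateDict field on the TrainGlobal and
TrainGlobalParallelSplitLearning request/response messages.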

---
 edml/controllers/parallel_split_controller.py |   8 +-
 edml/controllers/swarm_controller.py          |  32 ++++-
 edml/core/device.py                           |  78 ++++++++---
 edml/core/server.py                           |  32 +++++-
 edml/generated/connection_pb2.py              | 114 +++++++++---------
 edml/generated/connection_pb2.pyi             |  24 ++--
 edml/proto/connection.proto                   |   4 +
 .../controllers/swarm_controller_test.py      |  17 ++-
 edml/tests/core/device_test.py                |  19 ++-
 9 files changed, 224 insertions(+), 104 deletions(-)
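
Note on the data flow (annotation; git am ignores text between the
diffstat and the first diff): the controller keeps the optimizer state
returned by round i and passes it into round i+1; the dispatcher
serializes it via state_dict_to_proto() into the request; the serving
device restores it with load_state_dict() before training and ships its
fresh state_dict() back in the response.

Below is a minimal sketch of the handover this enables, assuming a
plain torch optimizer; the model and variable names are illustrative
and not taken from the repository:

    import torch

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # Round i: one training step populates the momentum buffers.
    model(torch.randn(8, 4)).sum().backward()
    opt.step()
    shared_state = opt.state_dict()  # what the response carries

    # Round i+1, possibly on a different serving device: restore the
    # state before training so the momentum buffers carry over.
    next_opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    next_opt.load_state_dict(shared_state)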

diff --git a/edml/controllers/parallel_split_controller.py b/edml/controllers/parallel_split_controller.py
index a251898..d250d21 100644
--- a/edml/controllers/parallel_split_controller.py
+++ b/edml/controllers/parallel_split_controller.py
@@ -17,6 +17,7 @@ class ParallelSplitController(BaseController):
         client_weights = None
         server_weights = None
         server_device_id = self.cfg.topology.devices[0].device_id
+        optimizer_state = None
 
         for i in range(self.cfg.experiment.max_rounds):
             print(f"=================================Round {i}")
@@ -43,7 +44,10 @@ class ParallelSplitController(BaseController):
 
             # Start parallel training of all client devices.
             training_response = self.request_dispatcher.train_parallel_on_server(
-                server_device_id=server_device_id, epochs=1, round_no=i
+                server_device_id=server_device_id,
+                epochs=1,
+                round_no=i,
+                optimizer_state=optimizer_state,
             )
 
             self._refresh_active_devices()
@@ -62,7 +66,7 @@ class ParallelSplitController(BaseController):
             if training_response is False:  # server device unavailable
                 break
             else:
-                cw, server_weights, metrics, _ = training_response
+                cw, server_weights, metrics, _, optimizer_state = training_response
 
                 self._aggregate_and_log_metrics(metrics, i)
 
diff --git a/edml/controllers/swarm_controller.py b/edml/controllers/swarm_controller.py
index feb9990..f58a378 100644
--- a/edml/controllers/swarm_controller.py
+++ b/edml/controllers/swarm_controller.py
@@ -1,3 +1,5 @@
+from typing import Any, Optional
+
 from edml.controllers.base_controller import BaseController
 from edml.controllers.scheduler.base import NextServerScheduler
 from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler
@@ -31,6 +33,7 @@ class SwarmController(BaseController):
         server_weights = None
         server_device_id = None
         diagnostic_metric_container = None
+        optimizer_state = None
         for i in range(self.cfg.experiment.max_rounds):
             print(f"Round {i}")
             self._update_devices_battery_status()
@@ -46,10 +49,18 @@ class SwarmController(BaseController):
                 print("No active client devices left.")
                 break
 
-            client_weights, server_weights, metrics, diagnostic_metric_container = (
-                self._swarm_train_round(
-                    client_weights, server_weights, server_device_id, round_no=i
-                )
+            (
+                client_weights,
+                server_weights,
+                metrics,
+                diagnostic_metric_container,
+                optimizer_state,
+            ) = self._swarm_train_round(
+                client_weights,
+                server_weights,
+                server_device_id,
+                round_no=i,
+                optimizer_state=optimizer_state,
             )
 
             self._refresh_active_devices()
@@ -75,7 +86,12 @@ class SwarmController(BaseController):
                 self._save_weights(client_weights, server_weights, i)
 
     def _swarm_train_round(
-        self, client_weights, server_weights, server_device_id, round_no: int = -1
+        self,
+        client_weights,
+        server_weights,
+        server_device_id,
+        round_no: int = -1,
+        optimizer_state: Optional[dict[str, Any]] = None,
     ):
         self._refresh_active_devices()
         # set latest client weights on first device to train on
@@ -88,7 +104,10 @@ class SwarmController(BaseController):
         )
 
         training_response = self.request_dispatcher.train_global_on(
-            server_device_id, epochs=1, round_no=round_no
+            server_device_id,
+            epochs=1,
+            round_no=round_no,
+            optimizer_state=optimizer_state,
         )
 
         if training_response is not False:  # server device unavailable
@@ -99,6 +118,7 @@ class SwarmController(BaseController):
                 server_weights,
                 None,
                 None,
+                optimizer_state,
             )  # return most recent weights and no metrics
 
     def _select_server_device(
diff --git a/edml/core/device.py b/edml/core/device.py
index 939ea0b..ee12f48 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -216,9 +216,11 @@ class Device(ABC):
 class NetworkDevice(Device):
     @update_battery
     @log_execution_time("logger", "train_parallel_split_learning")
-    def train_parallel_split_learning(self, clients: list[str], round_no: int):
+    def train_parallel_split_learning(
+        self, clients: list[str], round_no: int, optimizer_state: dict[str, Any] = None
+    ):
         return self.server.train_parallel_split_learning(
-            clients=clients, round_no=round_no
+            clients=clients, round_no=round_no, optimizer_state=optimizer_state
         )
 
     @update_battery
@@ -268,10 +270,15 @@ class NetworkDevice(Device):
     @update_battery
     @log_execution_time("logger", "train_global_time")
     def train_global(
-        self, epochs: int, round_no: int = -1
-    ) -> Tuple[Any, Any, ModelMetricResultContainer, DiagnosticMetricResultContainer]:
+        self, epochs: int, round_no: int = -1, optimizer_state: dict[str, Any] = None
+    ) -> Tuple[
+        Any, Any, ModelMetricResultContainer, DiagnosticMetricResultContainer, Any
+    ]:
         return self.server.train(
-            devices=self.__get_device_ids__(), epochs=epochs, round_no=round_no
+            devices=self.__get_device_ids__(),
+            epochs=epochs,
+            round_no=round_no,
+            optimizer_state=optimizer_state,
         )
 
     def __get_device_ids__(self) -> List[str]:
@@ -389,8 +396,8 @@ class NetworkDevice(Device):
         Any, Any, int, ModelMetricResultContainer, DiagnosticMetricResultContainer
     ]:
         """Returns client and server weights, the number of samples used for training and metrics"""
-        client_weights, server_weights, metrics, diagnostic_metrics = self.server.train(
-            devices=[self.device_id], epochs=1, round_no=round_no
+        client_weights, server_weights, metrics, diagnostic_metrics, _ = (
+            self.server.train(devices=[self.device_id], epochs=1, round_no=round_no)
         )
         num_samples = self.client.get_num_samples()
         return client_weights, server_weights, num_samples, metrics, diagnostic_metrics
@@ -406,15 +413,23 @@ class RPCDeviceServicer(DeviceServicer):
 
     def TrainGlobal(self, request, context):
         print(f"Called TrainGlobal on device {self.device.device_id}")
-        client_weights, server_weights, metrics, diagnostic_metrics = (
-            self.device.train_global(request.epochs, request.round_no)
-        )
+        # Restore the optimizer state shipped with the request, if any.
+        received_optimizer_state = None
+        if request.HasField("optimizer_state"):
+            received_optimizer_state = proto_to_state_dict(request.optimizer_state)
+        client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = (
+            self.device.train_global(
+                request.epochs, request.round_no, received_optimizer_state
+            )
+        )
-        return connection_pb2.TrainGlobalResponse(
+        response = connection_pb2.TrainGlobalResponse(
             client_weights=Weights(weights=state_dict_to_proto(client_weights)),
             server_weights=Weights(weights=state_dict_to_proto(server_weights)),
             metrics=metrics_to_proto(metrics),
             diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
+            optimizer_state=state_dict_to_proto(optimizer_state),
         )
+        return response
 
     def SetWeights(self, request, context):
         print(f"Called SetWeights on device {self.device.device_id}")
@@ -520,17 +535,26 @@ class RPCDeviceServicer(DeviceServicer):
         clients = self.device.__get_device_ids__()
         round_no = request.round_no
+        # Restore the optimizer state shipped with the request, if any.
+        received_optimizer_state = None
+        if request.HasField("optimizer_state"):
+            received_optimizer_state = proto_to_state_dict(request.optimizer_state)
 
-        cw, sw, model_metrics, diagnostic_metrics = (
+        cw, sw, model_metrics, diagnostic_metrics, optimizer_state = (
             self.device.train_parallel_split_learning(
-                clients=clients, round_no=round_no
+                clients=clients,
+                round_no=round_no,
+                optimizer_state=received_optimizer_state,
             )
         )
-        return connection_pb2.TrainGlobalParallelSplitLearningResponse(
+        response = connection_pb2.TrainGlobalParallelSplitLearningResponse(
             client_weights=Weights(weights=state_dict_to_proto(cw)),
             server_weights=Weights(weights=state_dict_to_proto(sw)),
             metrics=metrics_to_proto(model_metrics),
             diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
         )
+        if optimizer_state is not None:
+            response.optimizer_state.CopyFrom(state_dict_to_proto(optimizer_state))
+        return response
 
     def TrainSingleBatchOnClient(self, request, context):
         batch_index = request.batch_index
@@ -588,14 +612,19 @@ class DeviceRequestDispatcher:
         return None
 
     def train_parallel_on_server(
-        self, server_device_id: str, epochs: int, round_no: int
+        self,
+        server_device_id: str,
+        epochs: int,
+        round_no: int,
+        optimizer_state: dict[str, Any] = None,
     ):
         try:
             response: TrainGlobalParallelSplitLearningResponse = self._get_connection(
                 server_device_id
             ).TrainGlobalParallelSplitLearning(
                 connection_pb2.TrainGlobalParallelSplitLearningRequest(
-                    round_no=round_no
+                    round_no=round_no,
+                    optimizer_state=state_dict_to_proto(optimizer_state),
                 )
             )
             return (
@@ -603,6 +620,7 @@ class DeviceRequestDispatcher:
                 proto_to_weights(response.server_weights),
                 proto_to_metrics(response.metrics),
                 self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
+                proto_to_state_dict(response.optimizer_state),
             )
         except grpc.RpcError:
             self._handle_rpc_error(server_device_id)
@@ -685,24 +715,36 @@ class DeviceRequestDispatcher:
             diagnostic_metrics.merge(_proto_size_per_field(request, device_id))
         return diagnostic_metrics
 
-    def train_global_on(self, device_id: str, epochs: int, round_no: int = -1) -> Union[
+    def train_global_on(
+        self,
+        device_id: str,
+        epochs: int,
+        round_no: int = -1,
+        optimizer_state: dict[str, Any] = None,
+    ) -> Union[
         Tuple[
             Dict[str, Any],
             Dict[str, Any],
             ModelMetricResultContainer,
             DiagnosticMetricResultContainer,
+            Dict[str, Any],
         ],
         bool,
     ]:
         try:
             response: TrainGlobalResponse = self._get_connection(device_id).TrainGlobal(
-                connection_pb2.TrainGlobalRequest(epochs=epochs, round_no=round_no)
+                connection_pb2.TrainGlobalRequest(
+                    epochs=epochs,
+                    round_no=round_no,
+                    optimizer_state=state_dict_to_proto(optimizer_state),
+                )
             )
             return (
                 proto_to_weights(response.client_weights),
                 proto_to_weights(response.server_weights),
                 proto_to_metrics(response.metrics),
                 self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
+                proto_to_state_dict(response.optimizer_state),
             )
         except grpc.RpcError:
             self._handle_rpc_error(device_id)
diff --git a/edml/core/server.py b/edml/core/server.py
index 3273722..02033d4 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -69,18 +69,28 @@ class DeviceServer:
 
     @check_device_set()
     def train(
-        self, devices: List[str], epochs: int = 1, round_no: int = -1
-    ) -> Tuple[Any, Any, ModelMetricResultContainer, DiagnosticMetricResultContainer]:
+        self,
+        devices: List[str],
+        epochs: int = 1,
+        round_no: int = -1,
+        optimizer_state: dict[str, Any] = None,
+    ) -> Tuple[
+        Any, Any, ModelMetricResultContainer, DiagnosticMetricResultContainer, Any
+    ]:
         """Train the model on the given devices for the given number of epochs.
         Shares the weights among clients and saves the final weights to the configured paths.
         Args:
             devices: The devices to train on
             epochs: Optionally, the number of epochs to train.
             round_no: Optionally, the current global epoch number if a learning rate scheduler is used.
+            optimizer_state: Optionally, the optimizer state to resume training from
         """
         client_weights = None
         metrics = ModelMetricResultContainer()
         diagnostic_metric_container = DiagnosticMetricResultContainer()
+        if optimizer_state is not None:
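+            # Resume from the previous round's optimizer state.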
+            self._optimizer.load_state_dict(optimizer_state)
         for epoch in range(epochs):
             for device_id in devices:
                 print(
@@ -117,7 +126,13 @@ class DeviceServer:
                     self._lr_scheduler.step(round_no + epoch)
                 else:
                     self._lr_scheduler.step()
-        return client_weights, self.get_weights(), metrics, diagnostic_metric_container
+        return (
+            client_weights,
+            self.get_weights(),
+            metrics,
+            diagnostic_metric_container,
+            self._optimizer.state_dict(),
+        )
 
     @simulate_latency_decorator(latency_factor_attr="latency_factor")
     def train_batch(self, smashed_data, labels) -> Variable:
@@ -197,7 +213,9 @@ class DeviceServer:
         self._metrics.metrics_on_batch(pred.cpu(), labels.cpu().int())
 
     @simulate_latency_decorator(latency_factor_attr="latency_factor")
-    def train_parallel_split_learning(self, clients: List[str], round_no: int):
+    def train_parallel_split_learning(
+        self, clients: List[str], round_no: int, optimizer_state: dict[str, Any] = None
+    ):
         def client_training_job(client_id: str, batch_index: int) -> SLTrainBatchResult:
             return self.node_device.train_batch_on_client_only_on(
                 device_id=client_id, batch_index=batch_index
@@ -208,6 +226,10 @@ class DeviceServer:
                 client_id=client_id, gradients=gradients
             )
 
+        if optimizer_state is not None:
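+            # Resume from the previous round's optimizer state.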
+            self._optimizer.load_state_dict(optimizer_state)
+
         num_threads = len(clients)
         executor = create_executor_with_threads(num_threads)
 
@@ -285,11 +307,13 @@ class DeviceServer:
             model_metrics.add_results(train_metrics)
             model_metrics.add_results(val_metrics)
 
+        optimizer_state = self._optimizer.state_dict()
         return (
             self.node_device.client.get_weights(),
             self.get_weights(),
             model_metrics,
             diagnostic_metrics,
+            optimizer_state,
         )
 
 
diff --git a/edml/generated/connection_pb2.py b/edml/generated/connection_pb2.py
index b011e8f..73fe17b 100644
--- a/edml/generated/connection_pb2.py
+++ b/edml/generated/connection_pb2.py
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
 import datastructures_pb2 as datastructures__pb2
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"8\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\"1\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\"Z\n\x1bSingleBatchTrainingResponse\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"M\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xcb\x01\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x04 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"H\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xb6\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x04 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"u\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 
\x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xb1\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"8\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\"1\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\x8b\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x02 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x04 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x05 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x12\n\x10_optimizer_state\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x04 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x05 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x12\n\x10_optimizer_state\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"u\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 
\x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xb1\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x62\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -30,60 +30,60 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _globals['_SINGLEBATCHBACKWARDRESPONSE']._serialized_end=214
   _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_start=216
   _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_end=265
-  _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_start=267
-  _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_end=357
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_start=359
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=436
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=439
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=642
-  _globals['_TRAINGLOBALREQUEST']._serialized_start=644
-  _globals['_TRAINGLOBALREQUEST']._serialized_end=716
-  _globals['_TRAINGLOBALRESPONSE']._serialized_start=719
-  _globals['_TRAINGLOBALRESPONSE']._serialized_end=901
-  _globals['_SETWEIGHTSREQUEST']._serialized_start=903
-  _globals['_SETWEIGHTSREQUEST']._serialized_end=968
-  _globals['_SETWEIGHTSRESPONSE']._serialized_start=970
-  _globals['_SETWEIGHTSRESPONSE']._serialized_end=1056
-  _globals['_TRAINEPOCHREQUEST']._serialized_start=1058
-  _globals['_TRAINEPOCHREQUEST']._serialized_end=1142
-  _globals['_TRAINEPOCHRESPONSE']._serialized_start=1144
-  _globals['_TRAINEPOCHRESPONSE']._serialized_end=1257
-  _globals['_TRAINBATCHREQUEST']._serialized_start=1259
-  _globals['_TRAINBATCHREQUEST']._serialized_end=1339
-  _globals['_TRAINBATCHRESPONSE']._serialized_start=1341
-  _globals['_TRAINBATCHRESPONSE']._serialized_end=1458
-  _globals['_EVALGLOBALREQUEST']._serialized_start=1460
-  _globals['_EVALGLOBALREQUEST']._serialized_end=1518
-  _globals['_EVALGLOBALRESPONSE']._serialized_start=1520
-  _globals['_EVALGLOBALRESPONSE']._serialized_end=1633
-  _globals['_EVALREQUEST']._serialized_start=1635
-  _globals['_EVALREQUEST']._serialized_end=1697
-  _globals['_EVALRESPONSE']._serialized_start=1699
-  _globals['_EVALRESPONSE']._serialized_end=1779
-  _globals['_EVALBATCHREQUEST']._serialized_start=1781
-  _globals['_EVALBATCHREQUEST']._serialized_end=1860
-  _globals['_EVALBATCHRESPONSE']._serialized_start=1862
-  _globals['_EVALBATCHRESPONSE']._serialized_end=1974
-  _globals['_FULLMODELTRAINREQUEST']._serialized_start=1976
-  _globals['_FULLMODELTRAINREQUEST']._serialized_end=2035
-  _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2038
-  _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2244
-  _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2246
-  _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2270
-  _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2272
-  _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2363
-  _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2365
-  _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2387
-  _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2389
-  _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=2478
-  _globals['_BATTERYSTATUSREQUEST']._serialized_start=2480
-  _globals['_BATTERYSTATUSREQUEST']._serialized_end=2502
-  _globals['_BATTERYSTATUSRESPONSE']._serialized_start=2504
-  _globals['_BATTERYSTATUSRESPONSE']._serialized_end=2625
-  _globals['_DATASETMODELINFOREQUEST']._serialized_start=2627
-  _globals['_DATASETMODELINFOREQUEST']._serialized_end=2652
-  _globals['_DATASETMODELINFORESPONSE']._serialized_start=2655
-  _globals['_DATASETMODELINFORESPONSE']._serialized_end=2854
-  _globals['_DEVICE']._serialized_start=2857
-  _globals['_DEVICE']._serialized_end=3930
+  _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_start=268
+  _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_end=396
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_start=399
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=538
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=541
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=806
+  _globals['_TRAINGLOBALREQUEST']._serialized_start=809
+  _globals['_TRAINGLOBALREQUEST']._serialized_end=943
+  _globals['_TRAINGLOBALRESPONSE']._serialized_start=946
+  _globals['_TRAINGLOBALRESPONSE']._serialized_end=1190
+  _globals['_SETWEIGHTSREQUEST']._serialized_start=1192
+  _globals['_SETWEIGHTSREQUEST']._serialized_end=1257
+  _globals['_SETWEIGHTSRESPONSE']._serialized_start=1259
+  _globals['_SETWEIGHTSRESPONSE']._serialized_end=1345
+  _globals['_TRAINEPOCHREQUEST']._serialized_start=1347
+  _globals['_TRAINEPOCHREQUEST']._serialized_end=1431
+  _globals['_TRAINEPOCHRESPONSE']._serialized_start=1433
+  _globals['_TRAINEPOCHRESPONSE']._serialized_end=1546
+  _globals['_TRAINBATCHREQUEST']._serialized_start=1548
+  _globals['_TRAINBATCHREQUEST']._serialized_end=1628
+  _globals['_TRAINBATCHRESPONSE']._serialized_start=1630
+  _globals['_TRAINBATCHRESPONSE']._serialized_end=1747
+  _globals['_EVALGLOBALREQUEST']._serialized_start=1749
+  _globals['_EVALGLOBALREQUEST']._serialized_end=1807
+  _globals['_EVALGLOBALRESPONSE']._serialized_start=1809
+  _globals['_EVALGLOBALRESPONSE']._serialized_end=1922
+  _globals['_EVALREQUEST']._serialized_start=1924
+  _globals['_EVALREQUEST']._serialized_end=1986
+  _globals['_EVALRESPONSE']._serialized_start=1988
+  _globals['_EVALRESPONSE']._serialized_end=2068
+  _globals['_EVALBATCHREQUEST']._serialized_start=2070
+  _globals['_EVALBATCHREQUEST']._serialized_end=2149
+  _globals['_EVALBATCHRESPONSE']._serialized_start=2151
+  _globals['_EVALBATCHRESPONSE']._serialized_end=2263
+  _globals['_FULLMODELTRAINREQUEST']._serialized_start=2265
+  _globals['_FULLMODELTRAINREQUEST']._serialized_end=2324
+  _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2327
+  _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2533
+  _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2535
+  _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2559
+  _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2561
+  _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2652
+  _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2654
+  _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2676
+  _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2678
+  _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=2767
+  _globals['_BATTERYSTATUSREQUEST']._serialized_start=2769
+  _globals['_BATTERYSTATUSREQUEST']._serialized_end=2791
+  _globals['_BATTERYSTATUSRESPONSE']._serialized_start=2793
+  _globals['_BATTERYSTATUSRESPONSE']._serialized_end=2914
+  _globals['_DATASETMODELINFOREQUEST']._serialized_start=2916
+  _globals['_DATASETMODELINFOREQUEST']._serialized_end=2941
+  _globals['_DATASETMODELINFORESPONSE']._serialized_start=2944
+  _globals['_DATASETMODELINFORESPONSE']._serialized_end=3143
+  _globals['_DEVICE']._serialized_start=3146
+  _globals['_DEVICE']._serialized_end=4219
 # @@protoc_insertion_point(module_scope)
diff --git a/edml/generated/connection_pb2.pyi b/edml/generated/connection_pb2.pyi
index ee60f70..9cd3049 100644
--- a/edml/generated/connection_pb2.pyi
+++ b/edml/generated/connection_pb2.pyi
@@ -38,42 +38,50 @@ class SingleBatchTrainingResponse(_message.Message):
     def __init__(self, smashed_data: _Optional[_Union[_datastructures_pb2.Activations, _Mapping]] = ..., labels: _Optional[_Union[_datastructures_pb2.Labels, _Mapping]] = ...) -> None: ...
 
 class TrainGlobalParallelSplitLearningRequest(_message.Message):
-    __slots__ = ["round_no"]
+    __slots__ = ["round_no", "optimizer_state"]
     ROUND_NO_FIELD_NUMBER: _ClassVar[int]
+    OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int]
     round_no: int
-    def __init__(self, round_no: _Optional[int] = ...) -> None: ...
+    optimizer_state: _datastructures_pb2.StateDict
+    def __init__(self, round_no: _Optional[int] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
 
 class TrainGlobalParallelSplitLearningResponse(_message.Message):
-    __slots__ = ["client_weights", "server_weights", "metrics", "diagnostic_metrics"]
+    __slots__ = ["client_weights", "server_weights", "metrics", "diagnostic_metrics", "optimizer_state"]
     CLIENT_WEIGHTS_FIELD_NUMBER: _ClassVar[int]
     SERVER_WEIGHTS_FIELD_NUMBER: _ClassVar[int]
     METRICS_FIELD_NUMBER: _ClassVar[int]
     DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int]
+    OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int]
     client_weights: _datastructures_pb2.Weights
     server_weights: _datastructures_pb2.Weights
     metrics: _datastructures_pb2.Metrics
     diagnostic_metrics: _datastructures_pb2.Metrics
-    def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ...
+    optimizer_state: _datastructures_pb2.StateDict
+    def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
 
 class TrainGlobalRequest(_message.Message):
-    __slots__ = ["epochs", "round_no"]
+    __slots__ = ["epochs", "round_no", "optimizer_state"]
     EPOCHS_FIELD_NUMBER: _ClassVar[int]
     ROUND_NO_FIELD_NUMBER: _ClassVar[int]
+    OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int]
     epochs: int
     round_no: int
-    def __init__(self, epochs: _Optional[int] = ..., round_no: _Optional[int] = ...) -> None: ...
+    optimizer_state: _datastructures_pb2.StateDict
+    def __init__(self, epochs: _Optional[int] = ..., round_no: _Optional[int] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
 
 class TrainGlobalResponse(_message.Message):
-    __slots__ = ["client_weights", "server_weights", "metrics", "diagnostic_metrics"]
+    __slots__ = ["client_weights", "server_weights", "metrics", "diagnostic_metrics", "optimizer_state"]
     CLIENT_WEIGHTS_FIELD_NUMBER: _ClassVar[int]
     SERVER_WEIGHTS_FIELD_NUMBER: _ClassVar[int]
     METRICS_FIELD_NUMBER: _ClassVar[int]
     DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int]
+    OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int]
     client_weights: _datastructures_pb2.Weights
     server_weights: _datastructures_pb2.Weights
     metrics: _datastructures_pb2.Metrics
     diagnostic_metrics: _datastructures_pb2.Metrics
-    def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ...
+    optimizer_state: _datastructures_pb2.StateDict
+    def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
 
 class SetWeightsRequest(_message.Message):
     __slots__ = ["weights", "on_client"]
diff --git a/edml/proto/connection.proto b/edml/proto/connection.proto
index 1875097..51aa5fa 100644
--- a/edml/proto/connection.proto
+++ b/edml/proto/connection.proto
@@ -44,6 +44,7 @@ message SingleBatchTrainingResponse {
 
 message TrainGlobalParallelSplitLearningRequest {
   optional int32 round_no = 1;
+  optional StateDict optimizer_state = 2;
 }
 
 message TrainGlobalParallelSplitLearningResponse {
@@ -51,11 +52,13 @@ message TrainGlobalParallelSplitLearningResponse {
   Weights server_weights = 2;
   Metrics metrics = 3;
   optional Metrics diagnostic_metrics = 4;
+  optional StateDict optimizer_state = 5;
 }
 
 message TrainGlobalRequest {
   int32 epochs = 1;
   optional int32 round_no = 2;
+  optional StateDict optimizer_state = 3;
 }
 
 message TrainGlobalResponse {
@@ -63,6 +66,7 @@ message TrainGlobalResponse {
   Weights server_weights = 2;
   Metrics metrics = 3;
   optional Metrics diagnostic_metrics = 4;
+  optional StateDict optimizer_state = 5;
 }
 
 message SetWeightsRequest {
diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py
index dd598ea..5a4ab8c 100644
--- a/edml/tests/controllers/swarm_controller_test.py
+++ b/edml/tests/controllers/swarm_controller_test.py
@@ -31,28 +31,32 @@ class SwarmControllerTest(unittest.TestCase):
             {"weights": 43},
             ModelMetricResultContainer(),
             DiagnosticMetricResultContainer(),
+            {"optimizer_state": 42},
         )
 
-        client_weights, server_weights, metrics, diagnostic_metrics = (
+        client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = (
             self.swarm_controller._swarm_train_round(None, None, "d1", 0)
         )
 
         self.assertEqual(client_weights, {"weights": 42})
         self.assertEqual(server_weights, {"weights": 43})
         self.assertEqual(metrics, ModelMetricResultContainer())
-        self.assertEqual(metrics, DiagnosticMetricResultContainer())
+        self.assertEqual(diagnostic_metrics, DiagnosticMetricResultContainer())
+        self.assertEqual(optimizer_state, {"optimizer_state": 42})
         self.mock.set_weights_on.assert_has_calls(
             [
                 call(device_id="d0", state_dict=None, on_client=True),
                 call(device_id="d1", state_dict=None, on_client=False),
             ]
         )
-        self.mock.train_global_on.assert_called_once_with("d1", epochs=1, round_no=0)
+        self.mock.train_global_on.assert_called_once_with(
+            "d1", epochs=1, round_no=0, optimizer_state=None
+        )
 
     def test_split_train_round_with_inactive_server_device(self):
         self.mock.train_global_on.return_value = False
 
-        client_weights, server_weights, metrics, diagnostic_metrics = (
+        client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = (
             self.swarm_controller._swarm_train_round(None, None, "d1", 0)
         )
 
@@ -60,13 +64,16 @@ class SwarmControllerTest(unittest.TestCase):
         self.assertEqual(server_weights, None)
         self.assertEqual(metrics, None)
         self.assertEqual(diagnostic_metrics, None)
+        self.assertEqual(optimizer_state, None)
         self.mock.set_weights_on.assert_has_calls(
             [
                 call(device_id="d0", state_dict=None, on_client=True),
                 call(device_id="d1", state_dict=None, on_client=False),
             ]
         )
-        self.mock.train_global_on.assert_called_once_with("d1", epochs=1, round_no=0)
+        self.mock.train_global_on.assert_called_once_with(
+            "d1", epochs=1, round_no=0, optimizer_state=None
+        )
 
 
 class ServerDeviceSelectionTest(unittest.TestCase):
diff --git a/edml/tests/core/device_test.py b/edml/tests/core/device_test.py
index d91f367..54b0518 100644
--- a/edml/tests/core/device_test.py
+++ b/edml/tests/core/device_test.py
@@ -129,6 +129,7 @@ class RPCDeviceServicerTest(unittest.TestCase):
             {"weights": Tensor([43])},
             self.metrics,
             self.diagnostic_metrics,
+            {"optimizer_state": 44},
         )
         request = connection_pb2.TrainGlobalRequest(epochs=42)
 
@@ -142,6 +143,9 @@ class RPCDeviceServicerTest(unittest.TestCase):
             ({"weights": Tensor([42])}, {"weights": Tensor([43])}),
         )
         self.assertEqual(proto_to_metrics(response.metrics), self.metrics)
+        self.assertEqual(
+            proto_to_state_dict(response.optimizer_state), {"optimizer_state": 44}
+        )
         self.assertEqual(code, StatusCode.OK)
-        self.mock_device.train_global.assert_called_once_with(42, 0)
+        self.mock_device.train_global.assert_called_once_with(42, 0, None)
         self.assertEqual(
@@ -493,18 +497,23 @@ class RequestDispatcherTest(unittest.TestCase):
             server_weights=weights_to_proto(self.weights),
             metrics=metrics_to_proto(self.metrics),
             diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics),
+            optimizer_state=state_dict_to_proto({"optimizer_state": 42}),
         )
 
-        client_weights, server_weights, metrics, diagnostic_metrics = (
+        client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = (
             self.dispatcher.train_global_on("1", 42, 43)
         )
 
         self.assertEqual(client_weights, self.weights)
         self.assertEqual(server_weights, self.weights)
         self.assertEqual(metrics, self.metrics)
+        self.assertEqual(optimizer_state, {"optimizer_state": 42})
+
         self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
         self.mock_stub.TrainGlobal.assert_called_once_with(
-            connection_pb2.TrainGlobalRequest(epochs=42, round_no=43)
+            connection_pb2.TrainGlobalRequest(
+                epochs=42, round_no=43, optimizer_state=state_dict_to_proto(None)
+            )
         )
 
     def test_train_global_on_with_error(self):
@@ -514,7 +523,9 @@ class RequestDispatcherTest(unittest.TestCase):
 
         self.assertEqual(response, False)
         self.mock_stub.TrainGlobal.assert_called_once_with(
-            connection_pb2.TrainGlobalRequest(epochs=42, round_no=43)
+            connection_pb2.TrainGlobalRequest(
+                epochs=42, round_no=43, optimizer_state=state_dict_to_proto(None)
+            )
         )
 
     def test_set_weights_on_without_error(self):
-- 
GitLab