diff --git a/edml/controllers/split_controller.py b/edml/controllers/split_controller.py
index d427a2121c97999009a5c171bd9312dab8bb7f35..0f4fcc97e240ffb214844b4493967356c210caaf 100644
--- a/edml/controllers/split_controller.py
+++ b/edml/controllers/split_controller.py
@@ -34,7 +34,9 @@ class SplitController(BaseController):
             if training_response is False:  # server device unavailable
                 break
             else:
-                client_weights, server_weights, metrics, _ = training_response
+                client_weights, server_weights, metrics, _ = (
+                    training_response  # no need for optimizer state
+                )
 
                 self._aggregate_and_log_metrics(metrics, i)
 
diff --git a/edml/controllers/swarm_controller.py b/edml/controllers/swarm_controller.py
index f58a37826765a3df692ec23488304e508347b777..595edafadb47382f158866a654053eb4ea54e150 100644
--- a/edml/controllers/swarm_controller.py
+++ b/edml/controllers/swarm_controller.py
@@ -53,8 +53,8 @@ class SwarmController(BaseController):
                 client_weights,
                 server_weights,
                 metrics,
-                diagnostic_metric_container,
                 optimizer_state,
+                diagnostic_metric_container,
             ) = self._swarm_train_round(
                 client_weights,
                 server_weights,
@@ -117,8 +117,8 @@ class SwarmController(BaseController):
                 client_weights,
                 server_weights,
                 None,
-                None,
                 optimizer_state,
+                None,
             )  # return most recent weights and no metrics
 
     def _select_server_device(
diff --git a/edml/core/device.py b/edml/core/device.py
index 7fba252498c7ccdbfd7f11c2dcc9175ee8884dae..861625a482ac97ca36ed468577e8c67f81e325db 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -221,13 +221,13 @@ class NetworkDevice(Device):
         clients: list[str],
         round_no: int,
         adaptive_learning_threshold: Optional[float] = None,
-        optimizer_state: dict[str, Any] = None
+        optimizer_state: dict[str, Any] = None,
     ):
         return self.server.train_parallel_split_learning(
             clients=clients,
             round_no=round_no,
             adaptive_learning_threshold=adaptive_learning_threshold,
-            optimizer_state=optimizer_state
+            optimizer_state=optimizer_state,
         )
 
     @update_battery
@@ -279,7 +279,7 @@ class NetworkDevice(Device):
     def train_global(
         self, epochs: int, round_no: int = -1, optimizer_state: dict[str, Any] = None
     ) -> Tuple[
-        Any, Any, ModelMetricResultContainer, DiagnosticMetricResultContainer, Any
+        Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer
     ]:
         return self.server.train(
             devices=self.__get_device_ids__(),
@@ -420,15 +420,15 @@ class RPCDeviceServicer(DeviceServicer):
 
     def TrainGlobal(self, request, context):
         print(f"Called TrainGlobal on device {self.device.device_id}")
-        client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = (
+        client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
             self.device.train_global(request.epochs, request.round_no)
         )
         response = connection_pb2.TrainGlobalResponse(
             client_weights=Weights(weights=state_dict_to_proto(client_weights)),
             server_weights=Weights(weights=state_dict_to_proto(server_weights)),
             metrics=metrics_to_proto(metrics),
-            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
             optimizer_state=state_dict_to_proto(optimizer_state),
+            diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
         )
         return response
 
@@ -540,7 +540,7 @@ class RPCDeviceServicer(DeviceServicer):
         round_no = request.round_no
         adaptive_learning_threshold = request.adaptive_learning_threshold
 
-        cw, sw, model_metrics, diagnostic_metrics, optimizer_state = (
+        cw, sw, model_metrics, optimizer_state, diagnostic_metrics = (
             self.device.train_parallel_split_learning(
                 clients=clients,
                 round_no=round_no,
@@ -551,10 +551,9 @@ class RPCDeviceServicer(DeviceServicer):
             client_weights=Weights(weights=state_dict_to_proto(cw)),
             server_weights=Weights(weights=state_dict_to_proto(sw)),
             metrics=metrics_to_proto(model_metrics),
+            optimizer_state=state_dict_to_proto(optimizer_state),
             diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
         )
-        if optimizer_state is not None:
-            response.optimizer_state = state_dict_to_proto(optimizer_state)
         return response
 
     def TrainSingleBatchOnClient(self, request, context):
@@ -748,8 +747,8 @@ class DeviceRequestDispatcher:
                 proto_to_weights(response.client_weights),
                 proto_to_weights(response.server_weights),
                 proto_to_metrics(response.metrics),
-                self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
                 proto_to_state_dict(response.optimizer_state),
+                self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
             )
         except grpc.RpcError:
             self._handle_rpc_error(device_id)
diff --git a/edml/core/server.py b/edml/core/server.py
index 7654a540f1854f3110d5897b34eecd5238476085..61cedcde19389a446e85da31299eacd5711fe0cf 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -76,7 +76,7 @@ class DeviceServer:
         round_no: int = -1,
         optimizer_state: dict[str, Any] = None,
     ) -> Tuple[
-        Any, Any, ModelMetricResultContainer, DiagnosticMetricResultContainer, Any
+        Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer
     ]:
         """Train the model on the given devices for the given number of epochs.
         Shares the weights among clients and saves the final weights to the configured paths.
@@ -131,8 +131,8 @@ class DeviceServer:
             client_weights,
             self.get_weights(),
             metrics,
-            diagnostic_metric_container,
             self._optimizer.state_dict(),
+            diagnostic_metric_container,
         )
 
     @simulate_latency_decorator(latency_factor_attr="latency_factor")
@@ -218,7 +218,7 @@ class DeviceServer:
         clients: List[str],
         round_no: int,
         adaptive_learning_threshold: Optional[float] = None,
-        optimizer_state: dict[str, Any] = None
+        optimizer_state: dict[str, Any] = None,
     ):
         def client_training_job(client_id: str, batch_index: int) -> SLTrainBatchResult:
             return self.node_device.train_batch_on_client_only_on(
@@ -337,8 +337,8 @@ class DeviceServer:
             self.node_device.client.get_weights(),
             self.get_weights(),
             model_metrics,
-            diagnostic_metrics,
             optimizer_state,
+            diagnostic_metrics,
         )
 
 
diff --git a/edml/generated/connection_pb2.py b/edml/generated/connection_pb2.py
index 13d7ec10fd8a955bc9e7617957704db80d32f207..d8899e228ac061499659d0d55c7575241674ad32 100644
--- a/edml/generated/connection_pb2.py
+++ b/edml/generated/connection_pb2.py
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
 import datastructures_pb2 as datastructures__pb2
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"8\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\"1\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x04 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x05 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x12\n\x10_optimizer_state\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x04 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x05 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x12\n\x10_optimizer_state\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xb1\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"8\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\"1\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xb1\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x62\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
diff --git a/edml/generated/connection_pb2.pyi b/edml/generated/connection_pb2.pyi
index 4e17f0439ab8618be609b7f5f43feb54d20e86a1..730bf6e320c91c3d08548a1cbac1607352c85cad 100644
--- a/edml/generated/connection_pb2.pyi
+++ b/edml/generated/connection_pb2.pyi
@@ -48,18 +48,18 @@ class TrainGlobalParallelSplitLearningRequest(_message.Message):
     def __init__(self, round_no: _Optional[int] = ..., adaptive_learning_threshold: _Optional[float] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
 
 class TrainGlobalParallelSplitLearningResponse(_message.Message):
-    __slots__ = ["client_weights", "server_weights", "metrics", "diagnostic_metrics", "optimizer_state"]
+    __slots__ = ["client_weights", "server_weights", "metrics", "optimizer_state", "diagnostic_metrics"]
     CLIENT_WEIGHTS_FIELD_NUMBER: _ClassVar[int]
     SERVER_WEIGHTS_FIELD_NUMBER: _ClassVar[int]
     METRICS_FIELD_NUMBER: _ClassVar[int]
-    DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int]
     OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int]
+    DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int]
     client_weights: _datastructures_pb2.Weights
     server_weights: _datastructures_pb2.Weights
     metrics: _datastructures_pb2.Metrics
-    diagnostic_metrics: _datastructures_pb2.Metrics
     optimizer_state: _datastructures_pb2.StateDict
-    def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
+    diagnostic_metrics: _datastructures_pb2.Metrics
+    def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ...
 
 class TrainGlobalRequest(_message.Message):
     __slots__ = ["epochs", "round_no", "optimizer_state"]
@@ -72,18 +72,18 @@ class TrainGlobalRequest(_message.Message):
     def __init__(self, epochs: _Optional[int] = ..., round_no: _Optional[int] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
 
 class TrainGlobalResponse(_message.Message):
-    __slots__ = ["client_weights", "server_weights", "metrics", "diagnostic_metrics", "optimizer_state"]
+    __slots__ = ["client_weights", "server_weights", "metrics", "optimizer_state", "diagnostic_metrics"]
     CLIENT_WEIGHTS_FIELD_NUMBER: _ClassVar[int]
     SERVER_WEIGHTS_FIELD_NUMBER: _ClassVar[int]
     METRICS_FIELD_NUMBER: _ClassVar[int]
-    DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int]
     OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int]
+    DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int]
     client_weights: _datastructures_pb2.Weights
     server_weights: _datastructures_pb2.Weights
     metrics: _datastructures_pb2.Metrics
-    diagnostic_metrics: _datastructures_pb2.Metrics
     optimizer_state: _datastructures_pb2.StateDict
-    def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
+    diagnostic_metrics: _datastructures_pb2.Metrics
+    def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ...
 
 class SetWeightsRequest(_message.Message):
     __slots__ = ["weights", "on_client"]
diff --git a/edml/proto/connection.proto b/edml/proto/connection.proto
index 6ed477dc344b6f5d4b4b64629daabeb9c8caecd4..d03d2d1ec39b7f884673de4baf6b46a0643d0cdd 100644
--- a/edml/proto/connection.proto
+++ b/edml/proto/connection.proto
@@ -52,8 +52,8 @@ message TrainGlobalParallelSplitLearningResponse {
   Weights client_weights = 1;
   Weights server_weights = 2;
   Metrics metrics = 3;
-  optional Metrics diagnostic_metrics = 4;
-  optional StateDict optimizer_state = 5;
+  optional StateDict optimizer_state = 4;
+  optional Metrics diagnostic_metrics = 5;
 }
 
 message TrainGlobalRequest {
@@ -67,8 +67,8 @@ message TrainGlobalResponse {
   Weights client_weights = 1;
   Weights server_weights = 2;
   Metrics metrics = 3;
-  optional Metrics diagnostic_metrics = 4;
-  optional StateDict optimizer_state = 5;
+  optional StateDict optimizer_state = 4;
+  optional Metrics diagnostic_metrics = 5;
 }
 
 message SetWeightsRequest {
diff --git a/edml/tests/core/device_test.py b/edml/tests/core/device_test.py
index 6080e7f406b77bc68ec503cf2a655d657fdd4677..3827df9c727254bb13482e22bef6899bb2ae2c47 100644
--- a/edml/tests/core/device_test.py
+++ b/edml/tests/core/device_test.py
@@ -128,8 +128,8 @@ class RPCDeviceServicerTest(unittest.TestCase):
             {"weights": Tensor([42])},
             {"weights": Tensor([43])},
             self.metrics,
-            self.diagnostic_metrics,
             {"optimizer_state": 44},
+            self.diagnostic_metrics,
         )
         request = connection_pb2.TrainGlobalRequest(epochs=42)
 
@@ -501,7 +501,7 @@ class RequestDispatcherTest(unittest.TestCase):
             optimizer_state=state_dict_to_proto({"optimizer_state": 42}),
         )
 
-        client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = (
+        client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
             self.dispatcher.train_global_on("1", 42, 43)
         )