From b8bf39b7c8e6b1ba0dcb8b8ff0be0f8b0b7e00f4 Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Fri, 28 Jun 2024 14:17:00 +0200
Subject: [PATCH] Fixed optimizer state for sl and psl

---
 edml/controllers/parallel_split_controller.py   | 2 +-
 edml/controllers/split_controller.py            | 4 ++--
 edml/core/device.py                             | 4 ++--
 edml/tests/controllers/swarm_controller_test.py | 2 +-
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/edml/controllers/parallel_split_controller.py b/edml/controllers/parallel_split_controller.py
index 895de9b..82bf519 100644
--- a/edml/controllers/parallel_split_controller.py
+++ b/edml/controllers/parallel_split_controller.py
@@ -76,7 +76,7 @@ class ParallelSplitController(BaseController):
                 print(f"Training response was false.")
                 break
             else:
-                cw, server_weights, metrics, _, optimizer_state = training_response
+                cw, server_weights, metrics, optimizer_state, _ = training_response
 
                 self._aggregate_and_log_metrics(metrics, i)
 
diff --git a/edml/controllers/split_controller.py b/edml/controllers/split_controller.py
index 0f4fcc9..309f6d9 100644
--- a/edml/controllers/split_controller.py
+++ b/edml/controllers/split_controller.py
@@ -34,8 +34,8 @@ class SplitController(BaseController):
             if training_response is False:  # server device unavailable
                 break
             else:
-                client_weights, server_weights, metrics, _ = (
-                    training_response  # no need for optimizer state
+                client_weights, server_weights, metrics, _, _ = (
+                    training_response  # no need for optimizer state and diagnostic metrics
                 )
 
                 self._aggregate_and_log_metrics(metrics, i)
diff --git a/edml/core/device.py b/edml/core/device.py
index d3f8a82..7fc7d8f 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -635,8 +635,8 @@ class DeviceRequestDispatcher:
                 proto_to_weights(response.client_weights),
                 proto_to_weights(response.server_weights),
                 proto_to_metrics(response.metrics),
-                self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
                 proto_to_state_dict(response.optimizer_state),
+                self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
             )
         except grpc.RpcError:
             self._handle_rpc_error(server_device_id)
@@ -730,8 +730,8 @@ class DeviceRequestDispatcher:
             Dict[str, Any],
             Dict[str, Any],
             ModelMetricResultContainer,
-            DiagnosticMetricResultContainer,
             Dict[str, Any],
+            DiagnosticMetricResultContainer,
         ],
         bool,
     ]:
diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py
index 06be96d..56e9a5e 100644
--- a/edml/tests/controllers/swarm_controller_test.py
+++ b/edml/tests/controllers/swarm_controller_test.py
@@ -30,8 +30,8 @@ class SwarmControllerTest(unittest.TestCase):
             {"weights": 42},
             {"weights": 43},
             ModelMetricResultContainer(),
-            DiagnosticMetricResultContainer(),
             {"optimizer_state": 42},
+            DiagnosticMetricResultContainer(),
         )
 
         client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
-- 
GitLab