From 0e16a108bdc116f830a8c2ae1f42f9d33571b0b2 Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Fri, 28 Jun 2024 10:34:11 +0200
Subject: [PATCH] Fixed FL global optimizer state

---
 edml/core/device.py                             | 2 +-
 edml/tests/controllers/swarm_controller_test.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/edml/core/device.py b/edml/core/device.py
index 861625a..d3f8a82 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -403,7 +403,7 @@ class NetworkDevice(Device):
         Any, Any, int, ModelMetricResultContainer, DiagnosticMetricResultContainer
     ]:
         """Returns client and server weights, the number of samples used for training and metrics"""
-        client_weights, server_weights, metrics, diagnostic_metrics, _ = (
+        client_weights, server_weights, metrics, _, diagnostic_metrics = (
             self.server.train(devices=[self.device_id], epochs=1, round_no=round_no)
         )
         num_samples = self.client.get_num_samples()
diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py
index 5a4ab8c..06be96d 100644
--- a/edml/tests/controllers/swarm_controller_test.py
+++ b/edml/tests/controllers/swarm_controller_test.py
@@ -34,7 +34,7 @@ class SwarmControllerTest(unittest.TestCase):
             {"optimizer_state": 42},
         )
 
-        client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = (
+        client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
             self.swarm_controller._swarm_train_round(None, None, "d1", 0)
         )
 
@@ -56,7 +56,7 @@ class SwarmControllerTest(unittest.TestCase):
     def test_split_train_round_with_inactive_server_device(self):
         self.mock.train_global_on.return_value = False
 
-        client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = (
+        client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
             self.swarm_controller._swarm_train_round(None, None, "d1", 0)
         )
 
-- 
GitLab