From 0e16a108bdc116f830a8c2ae1f42f9d33571b0b2 Mon Sep 17 00:00:00 2001 From: Tim Bauerle <tim.bauerle@rwth-aachen.de> Date: Fri, 28 Jun 2024 10:34:11 +0200 Subject: [PATCH] Fixed FL global optimizer state --- edml/core/device.py | 2 +- edml/tests/controllers/swarm_controller_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/edml/core/device.py b/edml/core/device.py index 861625a..d3f8a82 100644 --- a/edml/core/device.py +++ b/edml/core/device.py @@ -403,7 +403,7 @@ class NetworkDevice(Device): Any, Any, int, ModelMetricResultContainer, DiagnosticMetricResultContainer ]: """Returns client and server weights, the number of samples used for training and metrics""" - client_weights, server_weights, metrics, diagnostic_metrics, _ = ( + client_weights, server_weights, metrics, _, diagnostic_metrics = ( self.server.train(devices=[self.device_id], epochs=1, round_no=round_no) ) num_samples = self.client.get_num_samples() diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py index 5a4ab8c..06be96d 100644 --- a/edml/tests/controllers/swarm_controller_test.py +++ b/edml/tests/controllers/swarm_controller_test.py @@ -34,7 +34,7 @@ class SwarmControllerTest(unittest.TestCase): {"optimizer_state": 42}, ) - client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = ( + client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( self.swarm_controller._swarm_train_round(None, None, "d1", 0) ) @@ -56,7 +56,7 @@ class SwarmControllerTest(unittest.TestCase): def test_split_train_round_with_inactive_server_device(self): self.mock.train_global_on.return_value = False - client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = ( + client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( self.swarm_controller._swarm_train_round(None, None, "d1", 0) ) -- GitLab