diff --git a/edml/core/device.py b/edml/core/device.py index 861625a482ac97ca36ed468577e8c67f81e325db..d3f8a82aa4d6340bec9bf7a54f4f3349e8fd7f16 100644 --- a/edml/core/device.py +++ b/edml/core/device.py @@ -403,7 +403,7 @@ class NetworkDevice(Device): Any, Any, int, ModelMetricResultContainer, DiagnosticMetricResultContainer ]: """Returns client and server weights, the number of samples used for training and metrics""" - client_weights, server_weights, metrics, diagnostic_metrics, _ = ( + client_weights, server_weights, metrics, _, diagnostic_metrics = ( self.server.train(devices=[self.device_id], epochs=1, round_no=round_no) ) num_samples = self.client.get_num_samples() diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py index 5a4ab8c061014ab820bc791da6f417b8d478f76b..06be96dfed83e48fb0b4bdc59fa5b79c73e2fd54 100644 --- a/edml/tests/controllers/swarm_controller_test.py +++ b/edml/tests/controllers/swarm_controller_test.py @@ -34,7 +34,7 @@ class SwarmControllerTest(unittest.TestCase): {"optimizer_state": 42}, ) - client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = ( + client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( self.swarm_controller._swarm_train_round(None, None, "d1", 0) ) @@ -56,7 +56,7 @@ class SwarmControllerTest(unittest.TestCase): def test_split_train_round_with_inactive_server_device(self): self.mock.train_global_on.return_value = False - client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = ( + client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( self.swarm_controller._swarm_train_round(None, None, "d1", 0) )