Skip to content
Snippets Groups Projects
Commit 0e16a108 authored by Tim Tobias Bauerle's avatar Tim Tobias Bauerle
Browse files

Fixed FL global optimizer state

parent 5c9a4a9c
Branches
No related tags found
1 merge request!18Merge in main
......@@ -403,7 +403,7 @@ class NetworkDevice(Device):
Any, Any, int, ModelMetricResultContainer, DiagnosticMetricResultContainer
]:
"""Returns client and server weights, the number of samples used for training and metrics"""
client_weights, server_weights, metrics, diagnostic_metrics, _ = (
client_weights, server_weights, metrics, _, diagnostic_metrics = (
self.server.train(devices=[self.device_id], epochs=1, round_no=round_no)
)
num_samples = self.client.get_num_samples()
......
......@@ -34,7 +34,7 @@ class SwarmControllerTest(unittest.TestCase):
{"optimizer_state": 42},
)
client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = (
client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
self.swarm_controller._swarm_train_round(None, None, "d1", 0)
)
......@@ -56,7 +56,7 @@ class SwarmControllerTest(unittest.TestCase):
def test_split_train_round_with_inactive_server_device(self):
self.mock.train_global_on.return_value = False
client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = (
client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
self.swarm_controller._swarm_train_round(None, None, "d1", 0)
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment