diff --git a/edml/controllers/parallel_split_controller.py b/edml/controllers/parallel_split_controller.py index 895de9b52407d12ddb71b712f24a02a7843b984c..82bf519a06dc456d79ba25368914c376236275c9 100644 --- a/edml/controllers/parallel_split_controller.py +++ b/edml/controllers/parallel_split_controller.py @@ -76,7 +76,7 @@ class ParallelSplitController(BaseController): print(f"Training response was false.") break else: - cw, server_weights, metrics, _, optimizer_state = training_response + cw, server_weights, metrics, optimizer_state, _ = training_response self._aggregate_and_log_metrics(metrics, i) diff --git a/edml/controllers/split_controller.py b/edml/controllers/split_controller.py index 0f4fcc97e240ffb214844b4493967356c210caaf..309f6d95f586417261c01b185ef0f033e4ac8a46 100644 --- a/edml/controllers/split_controller.py +++ b/edml/controllers/split_controller.py @@ -34,8 +34,8 @@ class SplitController(BaseController): if training_response is False: # server device unavailable break else: - client_weights, server_weights, metrics, _ = ( - training_response # no need for optimizer state + client_weights, server_weights, metrics, _, _ = ( + training_response # no need for optimizer state and diagnostic metrics ) self._aggregate_and_log_metrics(metrics, i) diff --git a/edml/core/device.py b/edml/core/device.py index d3f8a82aa4d6340bec9bf7a54f4f3349e8fd7f16..7fc7d8fb1638792077bf04dd2e8d7c5257c5cc5c 100644 --- a/edml/core/device.py +++ b/edml/core/device.py @@ -635,8 +635,8 @@ class DeviceRequestDispatcher: proto_to_weights(response.client_weights), proto_to_weights(response.server_weights), proto_to_metrics(response.metrics), - self._add_byte_size_to_diagnostic_metrics(response, self.device_id), proto_to_state_dict(response.optimizer_state), + self._add_byte_size_to_diagnostic_metrics(response, self.device_id), ) except grpc.RpcError: self._handle_rpc_error(server_device_id) @@ -730,8 +730,8 @@ class DeviceRequestDispatcher: Dict[str, Any], Dict[str, Any], ModelMetricResultContainer, - DiagnosticMetricResultContainer, Dict[str, Any], + DiagnosticMetricResultContainer, ], bool, ]: diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py index 06be96dfed83e48fb0b4bdc59fa5b79c73e2fd54..56e9a5ee46f6c7f4b36cf2fcfcae1030fc0278cf 100644 --- a/edml/tests/controllers/swarm_controller_test.py +++ b/edml/tests/controllers/swarm_controller_test.py @@ -30,8 +30,8 @@ class SwarmControllerTest(unittest.TestCase): {"weights": 42}, {"weights": 43}, ModelMetricResultContainer(), - DiagnosticMetricResultContainer(), {"optimizer_state": 42}, + DiagnosticMetricResultContainer(), ) client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (