From b8bf39b7c8e6b1ba0dcb8b8ff0be0f8b0b7e00f4 Mon Sep 17 00:00:00 2001 From: Tim Bauerle <tim.bauerle@rwth-aachen.de> Date: Fri, 28 Jun 2024 14:17:00 +0200 Subject: [PATCH] Fixed optimizer state for sl and psl --- edml/controllers/parallel_split_controller.py | 2 +- edml/controllers/split_controller.py | 4 ++-- edml/core/device.py | 4 ++-- edml/tests/controllers/swarm_controller_test.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/edml/controllers/parallel_split_controller.py b/edml/controllers/parallel_split_controller.py index 895de9b..82bf519 100644 --- a/edml/controllers/parallel_split_controller.py +++ b/edml/controllers/parallel_split_controller.py @@ -76,7 +76,7 @@ class ParallelSplitController(BaseController): print(f"Training response was false.") break else: - cw, server_weights, metrics, _, optimizer_state = training_response + cw, server_weights, metrics, optimizer_state, _ = training_response self._aggregate_and_log_metrics(metrics, i) diff --git a/edml/controllers/split_controller.py b/edml/controllers/split_controller.py index 0f4fcc9..309f6d9 100644 --- a/edml/controllers/split_controller.py +++ b/edml/controllers/split_controller.py @@ -34,8 +34,8 @@ class SplitController(BaseController): if training_response is False: # server device unavailable break else: - client_weights, server_weights, metrics, _ = ( - training_response # no need for optimizer state + client_weights, server_weights, metrics, _, _ = ( + training_response # no need for optimizer state and diagnostic metrics ) self._aggregate_and_log_metrics(metrics, i) diff --git a/edml/core/device.py b/edml/core/device.py index d3f8a82..7fc7d8f 100644 --- a/edml/core/device.py +++ b/edml/core/device.py @@ -635,8 +635,8 @@ class DeviceRequestDispatcher: proto_to_weights(response.client_weights), proto_to_weights(response.server_weights), proto_to_metrics(response.metrics), - self._add_byte_size_to_diagnostic_metrics(response, self.device_id), proto_to_state_dict(response.optimizer_state), + self._add_byte_size_to_diagnostic_metrics(response, self.device_id), ) except grpc.RpcError: self._handle_rpc_error(server_device_id) @@ -730,8 +730,8 @@ class DeviceRequestDispatcher: Dict[str, Any], Dict[str, Any], ModelMetricResultContainer, - DiagnosticMetricResultContainer, Dict[str, Any], + DiagnosticMetricResultContainer, ], bool, ]: diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py index 06be96d..56e9a5e 100644 --- a/edml/tests/controllers/swarm_controller_test.py +++ b/edml/tests/controllers/swarm_controller_test.py @@ -30,8 +30,8 @@ class SwarmControllerTest(unittest.TestCase): {"weights": 42}, {"weights": 43}, ModelMetricResultContainer(), - DiagnosticMetricResultContainer(), {"optimizer_state": 42}, + DiagnosticMetricResultContainer(), ) client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( -- GitLab