Skip to content
Snippets Groups Projects
Commit b8bf39b7 authored by Tim Tobias Bauerle's avatar Tim Tobias Bauerle
Browse files

Fixed optimizer state for sl and psl

parent 0e16a108
Branches
No related tags found
1 merge request!18Merge in main
......@@ -76,7 +76,7 @@ class ParallelSplitController(BaseController):
print(f"Training response was false.")
break
else:
cw, server_weights, metrics, _, optimizer_state = training_response
cw, server_weights, metrics, optimizer_state, _ = training_response
self._aggregate_and_log_metrics(metrics, i)
......
......@@ -34,8 +34,8 @@ class SplitController(BaseController):
if training_response is False: # server device unavailable
break
else:
client_weights, server_weights, metrics, _ = (
training_response # no need for optimizer state
client_weights, server_weights, metrics, _, _ = (
training_response # no need for optimizer state and diagnostic metrics
)
self._aggregate_and_log_metrics(metrics, i)
......
......@@ -635,8 +635,8 @@ class DeviceRequestDispatcher:
proto_to_weights(response.client_weights),
proto_to_weights(response.server_weights),
proto_to_metrics(response.metrics),
self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
proto_to_state_dict(response.optimizer_state),
self._add_byte_size_to_diagnostic_metrics(response, self.device_id),
)
except grpc.RpcError:
self._handle_rpc_error(server_device_id)
......@@ -730,8 +730,8 @@ class DeviceRequestDispatcher:
Dict[str, Any],
Dict[str, Any],
ModelMetricResultContainer,
DiagnosticMetricResultContainer,
Dict[str, Any],
DiagnosticMetricResultContainer,
],
bool,
]:
......
......@@ -30,8 +30,8 @@ class SwarmControllerTest(unittest.TestCase):
{"weights": 42},
{"weights": 43},
ModelMetricResultContainer(),
DiagnosticMetricResultContainer(),
{"optimizer_state": 42},
DiagnosticMetricResultContainer(),
)
client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment