From edde9250b865343f69a836aaffdf73e6c936e5ad Mon Sep 17 00:00:00 2001 From: Tim Bauerle <tim.bauerle@rwth-aachen.de> Date: Thu, 27 Jun 2024 15:36:58 +0200 Subject: [PATCH] Added sharing of the server optimizer state for SwarmSL-based training --- edml/controllers/parallel_split_controller.py | 8 +- edml/controllers/swarm_controller.py | 32 ++++- edml/core/device.py | 60 ++++++--- edml/core/server.py | 30 ++++- edml/generated/connection_pb2.py | 114 +++++++++--------- edml/generated/connection_pb2.pyi | 24 ++-- edml/proto/connection.proto | 5 + .../controllers/swarm_controller_test.py | 17 ++- edml/tests/core/device_test.py | 17 ++- 9 files changed, 207 insertions(+), 100 deletions(-) diff --git a/edml/controllers/parallel_split_controller.py b/edml/controllers/parallel_split_controller.py index a251898..d250d21 100644 --- a/edml/controllers/parallel_split_controller.py +++ b/edml/controllers/parallel_split_controller.py @@ -17,6 +17,7 @@ class ParallelSplitController(BaseController): client_weights = None server_weights = None server_device_id = self.cfg.topology.devices[0].device_id + optimizer_state = None for i in range(self.cfg.experiment.max_rounds): print(f"=================================Round {i}") @@ -43,7 +44,10 @@ class ParallelSplitController(BaseController): # Start parallel training of all client devices. 
training_response = self.request_dispatcher.train_parallel_on_server( - server_device_id=server_device_id, epochs=1, round_no=i + server_device_id=server_device_id, + epochs=1, + round_no=i, + optimizer_state=optimizer_state, ) self._refresh_active_devices() @@ -62,7 +66,7 @@ class ParallelSplitController(BaseController): if training_response is False: # server device unavailable break else: - cw, server_weights, metrics, _ = training_response + cw, server_weights, metrics, _, optimizer_state = training_response self._aggregate_and_log_metrics(metrics, i) diff --git a/edml/controllers/swarm_controller.py b/edml/controllers/swarm_controller.py index feb9990..f58a378 100644 --- a/edml/controllers/swarm_controller.py +++ b/edml/controllers/swarm_controller.py @@ -1,3 +1,5 @@ +from typing import Any + from edml.controllers.base_controller import BaseController from edml.controllers.scheduler.base import NextServerScheduler from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler @@ -31,6 +33,7 @@ class SwarmController(BaseController): server_weights = None server_device_id = None diagnostic_metric_container = None + optimizer_state = None for i in range(self.cfg.experiment.max_rounds): print(f"Round {i}") self._update_devices_battery_status() @@ -46,10 +49,18 @@ class SwarmController(BaseController): print("No active client devices left.") break - client_weights, server_weights, metrics, diagnostic_metric_container = ( - self._swarm_train_round( - client_weights, server_weights, server_device_id, round_no=i - ) + ( + client_weights, + server_weights, + metrics, + diagnostic_metric_container, + optimizer_state, + ) = self._swarm_train_round( + client_weights, + server_weights, + server_device_id, + round_no=i, + optimizer_state=optimizer_state, ) self._refresh_active_devices() @@ -75,7 +86,12 @@ class SwarmController(BaseController): self._save_weights(client_weights, server_weights, i) def _swarm_train_round( - self, client_weights, 
server_weights, server_device_id, round_no: int = -1 + self, + client_weights, + server_weights, + server_device_id, + round_no: int = -1, + optimizer_state: dict[str, Any] = None, ): self._refresh_active_devices() # set latest client weights on first device to train on @@ -88,7 +104,10 @@ class SwarmController(BaseController): ) training_response = self.request_dispatcher.train_global_on( - server_device_id, epochs=1, round_no=round_no + server_device_id, + epochs=1, + round_no=round_no, + optimizer_state=optimizer_state, ) if training_response is not False: # server device unavailable @@ -99,6 +118,7 @@ class SwarmController(BaseController): server_weights, None, None, + optimizer_state, ) # return most recent weights and no metrics def _select_server_device( diff --git a/edml/core/device.py b/edml/core/device.py index 939ea0b..ee12f48 100644 --- a/edml/core/device.py +++ b/edml/core/device.py @@ -216,9 +216,11 @@ class Device(ABC): class NetworkDevice(Device): @update_battery @log_execution_time("logger", "train_parallel_split_learning") - def train_parallel_split_learning(self, clients: list[str], round_no: int): + def train_parallel_split_learning( + self, clients: list[str], round_no: int, optimizer_state: dict[str, Any] = None + ): return self.server.train_parallel_split_learning( - clients=clients, round_no=round_no + clients=clients, round_no=round_no, optimizer_state=optimizer_state ) @update_battery @@ -268,10 +270,15 @@ class NetworkDevice(Device): @update_battery @log_execution_time("logger", "train_global_time") def train_global( - self, epochs: int, round_no: int = -1 - ) -> Tuple[Any, Any, ModelMetricResultContainer, DiagnosticMetricResultContainer]: + self, epochs: int, round_no: int = -1, optimizer_state: dict[str, Any] = None + ) -> Tuple[ + Any, Any, ModelMetricResultContainer, DiagnosticMetricResultContainer, Any + ]: return self.server.train( - devices=self.__get_device_ids__(), epochs=epochs, round_no=round_no + 
devices=self.__get_device_ids__(), + epochs=epochs, + round_no=round_no, + optimizer_state=optimizer_state, ) def __get_device_ids__(self) -> List[str]: @@ -389,8 +396,8 @@ class NetworkDevice(Device): Any, Any, int, ModelMetricResultContainer, DiagnosticMetricResultContainer ]: """Returns client and server weights, the number of samples used for training and metrics""" - client_weights, server_weights, metrics, diagnostic_metrics = self.server.train( - devices=[self.device_id], epochs=1, round_no=round_no + client_weights, server_weights, metrics, diagnostic_metrics, _ = ( + self.server.train(devices=[self.device_id], epochs=1, round_no=round_no) ) num_samples = self.client.get_num_samples() return client_weights, server_weights, num_samples, metrics, diagnostic_metrics @@ -406,15 +413,17 @@ class RPCDeviceServicer(DeviceServicer): def TrainGlobal(self, request, context): print(f"Called TrainGlobal on device {self.device.device_id}") - client_weights, server_weights, metrics, diagnostic_metrics = ( + client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = ( - self.device.train_global(request.epochs, request.round_no) + self.device.train_global(request.epochs, request.round_no, proto_to_state_dict(request.optimizer_state) if request.HasField("optimizer_state") else None) ) - return connection_pb2.TrainGlobalResponse( + response = connection_pb2.TrainGlobalResponse( client_weights=Weights(weights=state_dict_to_proto(client_weights)), server_weights=Weights(weights=state_dict_to_proto(server_weights)), metrics=metrics_to_proto(metrics), diagnostic_metrics=metrics_to_proto(diagnostic_metrics), + optimizer_state=state_dict_to_proto(optimizer_state), ) + return response def SetWeights(self, request, context): print(f"Called SetWeights on device {self.device.device_id}") @@ -520,17 +529,20 @@ class RPCDeviceServicer(DeviceServicer): clients = self.device.__get_device_ids__() round_no = request.round_no - cw, sw, model_metrics, diagnostic_metrics = ( + cw, sw, model_metrics, diagnostic_metrics, optimizer_state = ( self.device.train_parallel_split_learning( - clients=clients, round_no=round_no + clients=clients, round_no=round_no, optimizer_state=proto_to_state_dict(request.optimizer_state) if request.HasField("optimizer_state") else None ) ) - return
connection_pb2.TrainGlobalParallelSplitLearningResponse( + response = connection_pb2.TrainGlobalParallelSplitLearningResponse( client_weights=Weights(weights=state_dict_to_proto(cw)), server_weights=Weights(weights=state_dict_to_proto(sw)), metrics=metrics_to_proto(model_metrics), diagnostic_metrics=metrics_to_proto(diagnostic_metrics), ) + if optimizer_state is not None: + response.optimizer_state.CopyFrom(state_dict_to_proto(optimizer_state)) + return response def TrainSingleBatchOnClient(self, request, context): batch_index = request.batch_index @@ -588,14 +600,19 @@ class DeviceRequestDispatcher: return None def train_parallel_on_server( - self, server_device_id: str, epochs: int, round_no: int + self, + server_device_id: str, + epochs: int, + round_no: int, + optimizer_state: dict[str, Any] = None, ): try: response: TrainGlobalParallelSplitLearningResponse = self._get_connection( server_device_id ).TrainGlobalParallelSplitLearning( connection_pb2.TrainGlobalParallelSplitLearningRequest( - round_no=round_no + round_no=round_no, + optimizer_state=state_dict_to_proto(optimizer_state) if optimizer_state is not None else None, ) ) return ( @@ -603,6 +620,7 @@ class DeviceRequestDispatcher: proto_to_weights(response.server_weights), proto_to_metrics(response.metrics), self._add_byte_size_to_diagnostic_metrics(response, self.device_id), + proto_to_state_dict(response.optimizer_state) if response.HasField("optimizer_state") else None, ) except grpc.RpcError: self._handle_rpc_error(server_device_id) @@ -685,24 +703,36 @@ class DeviceRequestDispatcher: diagnostic_metrics.merge(_proto_size_per_field(request, device_id)) return diagnostic_metrics - def train_global_on(self, device_id: str, epochs: int, round_no: int = -1) -> Union[ + def train_global_on( + self, + device_id: str, + epochs: int, + round_no: int = -1, + optimizer_state: dict[str, Any] = None, + ) -> Union[ Tuple[ Dict[str, Any], Dict[str, Any], ModelMetricResultContainer, DiagnosticMetricResultContainer, + Dict[str, Any], ], bool, ]: try: response: TrainGlobalResponse =
self._get_connection(device_id).TrainGlobal( - connection_pb2.TrainGlobalRequest(epochs=epochs, round_no=round_no) + connection_pb2.TrainGlobalRequest( + epochs=epochs, + round_no=round_no, + optimizer_state=state_dict_to_proto(optimizer_state) if optimizer_state is not None else None, + ) ) return ( proto_to_weights(response.client_weights), proto_to_weights(response.server_weights), proto_to_metrics(response.metrics), self._add_byte_size_to_diagnostic_metrics(response, self.device_id), + proto_to_state_dict(response.optimizer_state) if response.HasField("optimizer_state") else None, ) except grpc.RpcError: self._handle_rpc_error(device_id) diff --git a/edml/core/server.py b/edml/core/server.py index 3273722..02033d4 100644 --- a/edml/core/server.py +++ b/edml/core/server.py @@ -69,18 +69,27 @@ class DeviceServer: @check_device_set() def train( - self, devices: List[str], epochs: int = 1, round_no: int = -1 - ) -> Tuple[Any, Any, ModelMetricResultContainer, DiagnosticMetricResultContainer]: + self, + devices: List[str], + epochs: int = 1, + round_no: int = -1, + optimizer_state: dict[str, Any] = None, + ) -> Tuple[ + Any, Any, ModelMetricResultContainer, DiagnosticMetricResultContainer, Any + ]: """Train the model on the given devices for the given number of epochs. Shares the weights among clients and saves the final weights to the configured paths. Args: devices: The devices to train on epochs: Optionally, the number of epochs to train. round_no: Optionally, the current global epoch number if a learning rate scheduler is used.
+ optimizer_state: Optionally, the optimizer_state to proceed from """ client_weights = None metrics = ModelMetricResultContainer() diagnostic_metric_container = DiagnosticMetricResultContainer() + if optimizer_state is not None: + self._optimizer.load_state_dict(optimizer_state) for epoch in range(epochs): for device_id in devices: print( @@ -117,7 +126,13 @@ class DeviceServer: self._lr_scheduler.step(round_no + epoch) else: self._lr_scheduler.step() - return client_weights, self.get_weights(), metrics, diagnostic_metric_container + return ( + client_weights, + self.get_weights(), + metrics, + diagnostic_metric_container, + self._optimizer.state_dict(), + ) @simulate_latency_decorator(latency_factor_attr="latency_factor") def train_batch(self, smashed_data, labels) -> Variable: @@ -197,7 +212,9 @@ class DeviceServer: self._metrics.metrics_on_batch(pred.cpu(), labels.cpu().int()) @simulate_latency_decorator(latency_factor_attr="latency_factor") - def train_parallel_split_learning(self, clients: List[str], round_no: int): + def train_parallel_split_learning( + self, clients: List[str], round_no: int, optimizer_state: dict[str, Any] = None + ): def client_training_job(client_id: str, batch_index: int) -> SLTrainBatchResult: return self.node_device.train_batch_on_client_only_on( device_id=client_id, batch_index=batch_index @@ -208,6 +225,9 @@ class DeviceServer: client_id=client_id, gradients=gradients ) + if optimizer_state is not None: + self._optimizer.load_state_dict(optimizer_state) + num_threads = len(clients) executor = create_executor_with_threads(num_threads) @@ -285,11 +305,13 @@ class DeviceServer: model_metrics.add_results(train_metrics) model_metrics.add_results(val_metrics) + optimizer_state = self._optimizer.state_dict() return ( self.node_device.client.get_weights(), self.get_weights(), model_metrics, diagnostic_metrics, + optimizer_state, ) diff --git a/edml/generated/connection_pb2.py b/edml/generated/connection_pb2.py index b011e8f..73fe17b 100644 
--- a/edml/generated/connection_pb2.py +++ b/edml/generated/connection_pb2.py @@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default() import datastructures_pb2 as datastructures__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"8\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\"1\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\"Z\n\x1bSingleBatchTrainingResponse\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"M\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xcb\x01\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x04 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"H\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xb6\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x04 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 
\x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"u\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 
\x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 
\x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xb1\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"8\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\"1\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 
\x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\x8b\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x02 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x04 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x05 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x12\n\x10_optimizer_state\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x04 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x05 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x12\n\x10_optimizer_state\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 
\x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"u\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 
\x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 
\x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xb1\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -30,60 +30,60 @@ if _descriptor._USE_C_DESCRIPTORS == False: _globals['_SINGLEBATCHBACKWARDRESPONSE']._serialized_end=214 _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_start=216 _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_end=265 - _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_start=267 - _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_end=357 - 
_globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_start=359 - _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=436 - _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=439 - _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=642 - _globals['_TRAINGLOBALREQUEST']._serialized_start=644 - _globals['_TRAINGLOBALREQUEST']._serialized_end=716 - _globals['_TRAINGLOBALRESPONSE']._serialized_start=719 - _globals['_TRAINGLOBALRESPONSE']._serialized_end=901 - _globals['_SETWEIGHTSREQUEST']._serialized_start=903 - _globals['_SETWEIGHTSREQUEST']._serialized_end=968 - _globals['_SETWEIGHTSRESPONSE']._serialized_start=970 - _globals['_SETWEIGHTSRESPONSE']._serialized_end=1056 - _globals['_TRAINEPOCHREQUEST']._serialized_start=1058 - _globals['_TRAINEPOCHREQUEST']._serialized_end=1142 - _globals['_TRAINEPOCHRESPONSE']._serialized_start=1144 - _globals['_TRAINEPOCHRESPONSE']._serialized_end=1257 - _globals['_TRAINBATCHREQUEST']._serialized_start=1259 - _globals['_TRAINBATCHREQUEST']._serialized_end=1339 - _globals['_TRAINBATCHRESPONSE']._serialized_start=1341 - _globals['_TRAINBATCHRESPONSE']._serialized_end=1458 - _globals['_EVALGLOBALREQUEST']._serialized_start=1460 - _globals['_EVALGLOBALREQUEST']._serialized_end=1518 - _globals['_EVALGLOBALRESPONSE']._serialized_start=1520 - _globals['_EVALGLOBALRESPONSE']._serialized_end=1633 - _globals['_EVALREQUEST']._serialized_start=1635 - _globals['_EVALREQUEST']._serialized_end=1697 - _globals['_EVALRESPONSE']._serialized_start=1699 - _globals['_EVALRESPONSE']._serialized_end=1779 - _globals['_EVALBATCHREQUEST']._serialized_start=1781 - _globals['_EVALBATCHREQUEST']._serialized_end=1860 - _globals['_EVALBATCHRESPONSE']._serialized_start=1862 - _globals['_EVALBATCHRESPONSE']._serialized_end=1974 - _globals['_FULLMODELTRAINREQUEST']._serialized_start=1976 - _globals['_FULLMODELTRAINREQUEST']._serialized_end=2035 - 
_globals['_FULLMODELTRAINRESPONSE']._serialized_start=2038 - _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2244 - _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2246 - _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2270 - _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2272 - _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2363 - _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2365 - _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2387 - _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2389 - _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=2478 - _globals['_BATTERYSTATUSREQUEST']._serialized_start=2480 - _globals['_BATTERYSTATUSREQUEST']._serialized_end=2502 - _globals['_BATTERYSTATUSRESPONSE']._serialized_start=2504 - _globals['_BATTERYSTATUSRESPONSE']._serialized_end=2625 - _globals['_DATASETMODELINFOREQUEST']._serialized_start=2627 - _globals['_DATASETMODELINFOREQUEST']._serialized_end=2652 - _globals['_DATASETMODELINFORESPONSE']._serialized_start=2655 - _globals['_DATASETMODELINFORESPONSE']._serialized_end=2854 - _globals['_DEVICE']._serialized_start=2857 - _globals['_DEVICE']._serialized_end=3930 + _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_start=268 + _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_end=396 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_start=399 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=538 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=541 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=806 + _globals['_TRAINGLOBALREQUEST']._serialized_start=809 + _globals['_TRAINGLOBALREQUEST']._serialized_end=943 + _globals['_TRAINGLOBALRESPONSE']._serialized_start=946 + _globals['_TRAINGLOBALRESPONSE']._serialized_end=1190 + _globals['_SETWEIGHTSREQUEST']._serialized_start=1192 + _globals['_SETWEIGHTSREQUEST']._serialized_end=1257 + _globals['_SETWEIGHTSRESPONSE']._serialized_start=1259 + 
_globals['_SETWEIGHTSRESPONSE']._serialized_end=1345 + _globals['_TRAINEPOCHREQUEST']._serialized_start=1347 + _globals['_TRAINEPOCHREQUEST']._serialized_end=1431 + _globals['_TRAINEPOCHRESPONSE']._serialized_start=1433 + _globals['_TRAINEPOCHRESPONSE']._serialized_end=1546 + _globals['_TRAINBATCHREQUEST']._serialized_start=1548 + _globals['_TRAINBATCHREQUEST']._serialized_end=1628 + _globals['_TRAINBATCHRESPONSE']._serialized_start=1630 + _globals['_TRAINBATCHRESPONSE']._serialized_end=1747 + _globals['_EVALGLOBALREQUEST']._serialized_start=1749 + _globals['_EVALGLOBALREQUEST']._serialized_end=1807 + _globals['_EVALGLOBALRESPONSE']._serialized_start=1809 + _globals['_EVALGLOBALRESPONSE']._serialized_end=1922 + _globals['_EVALREQUEST']._serialized_start=1924 + _globals['_EVALREQUEST']._serialized_end=1986 + _globals['_EVALRESPONSE']._serialized_start=1988 + _globals['_EVALRESPONSE']._serialized_end=2068 + _globals['_EVALBATCHREQUEST']._serialized_start=2070 + _globals['_EVALBATCHREQUEST']._serialized_end=2149 + _globals['_EVALBATCHRESPONSE']._serialized_start=2151 + _globals['_EVALBATCHRESPONSE']._serialized_end=2263 + _globals['_FULLMODELTRAINREQUEST']._serialized_start=2265 + _globals['_FULLMODELTRAINREQUEST']._serialized_end=2324 + _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2327 + _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2533 + _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2535 + _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2559 + _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2561 + _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2652 + _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2654 + _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2676 + _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2678 + _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=2767 + _globals['_BATTERYSTATUSREQUEST']._serialized_start=2769 + _globals['_BATTERYSTATUSREQUEST']._serialized_end=2791 + 
_globals['_BATTERYSTATUSRESPONSE']._serialized_start=2793 + _globals['_BATTERYSTATUSRESPONSE']._serialized_end=2914 + _globals['_DATASETMODELINFOREQUEST']._serialized_start=2916 + _globals['_DATASETMODELINFOREQUEST']._serialized_end=2941 + _globals['_DATASETMODELINFORESPONSE']._serialized_start=2944 + _globals['_DATASETMODELINFORESPONSE']._serialized_end=3143 + _globals['_DEVICE']._serialized_start=3146 + _globals['_DEVICE']._serialized_end=4219 # @@protoc_insertion_point(module_scope) diff --git a/edml/generated/connection_pb2.pyi b/edml/generated/connection_pb2.pyi index ee60f70..9cd3049 100644 --- a/edml/generated/connection_pb2.pyi +++ b/edml/generated/connection_pb2.pyi @@ -38,42 +38,50 @@ class SingleBatchTrainingResponse(_message.Message): def __init__(self, smashed_data: _Optional[_Union[_datastructures_pb2.Activations, _Mapping]] = ..., labels: _Optional[_Union[_datastructures_pb2.Labels, _Mapping]] = ...) -> None: ... class TrainGlobalParallelSplitLearningRequest(_message.Message): - __slots__ = ["round_no"] + __slots__ = ["round_no", "optimizer_state"] ROUND_NO_FIELD_NUMBER: _ClassVar[int] + OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int] round_no: int - def __init__(self, round_no: _Optional[int] = ...) -> None: ... + optimizer_state: _datastructures_pb2.StateDict + def __init__(self, round_no: _Optional[int] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ... 
class TrainGlobalParallelSplitLearningResponse(_message.Message): - __slots__ = ["client_weights", "server_weights", "metrics", "diagnostic_metrics"] + __slots__ = ["client_weights", "server_weights", "metrics", "diagnostic_metrics", "optimizer_state"] CLIENT_WEIGHTS_FIELD_NUMBER: _ClassVar[int] SERVER_WEIGHTS_FIELD_NUMBER: _ClassVar[int] METRICS_FIELD_NUMBER: _ClassVar[int] DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int] client_weights: _datastructures_pb2.Weights server_weights: _datastructures_pb2.Weights metrics: _datastructures_pb2.Metrics diagnostic_metrics: _datastructures_pb2.Metrics - def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... + optimizer_state: _datastructures_pb2.StateDict + def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ... class TrainGlobalRequest(_message.Message): - __slots__ = ["epochs", "round_no"] + __slots__ = ["epochs", "round_no", "optimizer_state"] EPOCHS_FIELD_NUMBER: _ClassVar[int] ROUND_NO_FIELD_NUMBER: _ClassVar[int] + OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int] epochs: int round_no: int - def __init__(self, epochs: _Optional[int] = ..., round_no: _Optional[int] = ...) -> None: ... 
+ optimizer_state: _datastructures_pb2.StateDict + def __init__(self, epochs: _Optional[int] = ..., round_no: _Optional[int] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ... class TrainGlobalResponse(_message.Message): - __slots__ = ["client_weights", "server_weights", "metrics", "diagnostic_metrics"] + __slots__ = ["client_weights", "server_weights", "metrics", "diagnostic_metrics", "optimizer_state"] CLIENT_WEIGHTS_FIELD_NUMBER: _ClassVar[int] SERVER_WEIGHTS_FIELD_NUMBER: _ClassVar[int] METRICS_FIELD_NUMBER: _ClassVar[int] DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] + OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int] client_weights: _datastructures_pb2.Weights server_weights: _datastructures_pb2.Weights metrics: _datastructures_pb2.Metrics diagnostic_metrics: _datastructures_pb2.Metrics - def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... + optimizer_state: _datastructures_pb2.StateDict + def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ... 
class SetWeightsRequest(_message.Message): __slots__ = ["weights", "on_client"] diff --git a/edml/proto/connection.proto b/edml/proto/connection.proto index 1875097..51aa5fa 100644 --- a/edml/proto/connection.proto +++ b/edml/proto/connection.proto @@ -44,6 +44,7 @@ message SingleBatchTrainingResponse { message TrainGlobalParallelSplitLearningRequest { optional int32 round_no = 1; + optional StateDict optimizer_state = 2; } message TrainGlobalParallelSplitLearningResponse { @@ -51,11 +52,14 @@ message TrainGlobalParallelSplitLearningResponse { Weights server_weights = 2; Metrics metrics = 3; optional Metrics diagnostic_metrics = 4; + optional StateDict optimizer_state = 5; } message TrainGlobalRequest { int32 epochs = 1; optional int32 round_no = 2; + optional StateDict optimizer_state = 3; + } message TrainGlobalResponse { @@ -63,6 +67,7 @@ message TrainGlobalResponse { Weights server_weights = 2; Metrics metrics = 3; optional Metrics diagnostic_metrics = 4; + optional StateDict optimizer_state = 5; } message SetWeightsRequest { diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py index dd598ea..5a4ab8c 100644 --- a/edml/tests/controllers/swarm_controller_test.py +++ b/edml/tests/controllers/swarm_controller_test.py @@ -31,28 +31,32 @@ class SwarmControllerTest(unittest.TestCase): {"weights": 43}, ModelMetricResultContainer(), DiagnosticMetricResultContainer(), + {"optimizer_state": 42}, ) - client_weights, server_weights, metrics, diagnostic_metrics = ( + client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = ( self.swarm_controller._swarm_train_round(None, None, "d1", 0) ) self.assertEqual(client_weights, {"weights": 42}) self.assertEqual(server_weights, {"weights": 43}) self.assertEqual(metrics, ModelMetricResultContainer()) - self.assertEqual(metrics, DiagnosticMetricResultContainer()) + self.assertEqual(diagnostic_metrics, DiagnosticMetricResultContainer()) + 
self.assertEqual(optimizer_state, {"optimizer_state": 42}) self.mock.set_weights_on.assert_has_calls( [ call(device_id="d0", state_dict=None, on_client=True), call(device_id="d1", state_dict=None, on_client=False), ] ) - self.mock.train_global_on.assert_called_once_with("d1", epochs=1, round_no=0) + self.mock.train_global_on.assert_called_once_with( + "d1", epochs=1, round_no=0, optimizer_state=None + ) def test_split_train_round_with_inactive_server_device(self): self.mock.train_global_on.return_value = False - client_weights, server_weights, metrics, diagnostic_metrics = ( + client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = ( self.swarm_controller._swarm_train_round(None, None, "d1", 0) ) @@ -60,13 +64,16 @@ class SwarmControllerTest(unittest.TestCase): self.assertEqual(server_weights, None) self.assertEqual(metrics, None) self.assertEqual(diagnostic_metrics, None) + self.assertEqual(optimizer_state, None) self.mock.set_weights_on.assert_has_calls( [ call(device_id="d0", state_dict=None, on_client=True), call(device_id="d1", state_dict=None, on_client=False), ] ) - self.mock.train_global_on.assert_called_once_with("d1", epochs=1, round_no=0) + self.mock.train_global_on.assert_called_once_with( + "d1", epochs=1, round_no=0, optimizer_state=None + ) class ServerDeviceSelectionTest(unittest.TestCase): diff --git a/edml/tests/core/device_test.py b/edml/tests/core/device_test.py index d91f367..54b0518 100644 --- a/edml/tests/core/device_test.py +++ b/edml/tests/core/device_test.py @@ -129,6 +129,7 @@ class RPCDeviceServicerTest(unittest.TestCase): {"weights": Tensor([43])}, self.metrics, self.diagnostic_metrics, + {"optimizer_state": 44}, ) request = connection_pb2.TrainGlobalRequest(epochs=42) @@ -142,6 +143,9 @@ class RPCDeviceServicerTest(unittest.TestCase): ({"weights": Tensor([42])}, {"weights": Tensor([43])}), ) self.assertEqual(proto_to_metrics(response.metrics), self.metrics) + self.assertEqual( + 
proto_to_state_dict(response.optimizer_state), {"optimizer_state": 44} + ) self.assertEqual(code, StatusCode.OK) self.mock_device.train_global.assert_called_once_with(42, 0) self.assertEqual( @@ -493,18 +497,23 @@ class RequestDispatcherTest(unittest.TestCase): server_weights=weights_to_proto(self.weights), metrics=metrics_to_proto(self.metrics), diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), + optimizer_state=state_dict_to_proto({"optimizer_state": 42}), ) - client_weights, server_weights, metrics, diagnostic_metrics = ( + client_weights, server_weights, metrics, diagnostic_metrics, optimizer_state = ( self.dispatcher.train_global_on("1", 42, 43) ) self.assertEqual(client_weights, self.weights) self.assertEqual(server_weights, self.weights) self.assertEqual(metrics, self.metrics) + self.assertEqual(optimizer_state, {"optimizer_state": 42}) + self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) self.mock_stub.TrainGlobal.assert_called_once_with( - connection_pb2.TrainGlobalRequest(epochs=42, round_no=43) + connection_pb2.TrainGlobalRequest( + epochs=42, round_no=43, optimizer_state=state_dict_to_proto(None) + ) ) def test_train_global_on_with_error(self): @@ -514,7 +523,9 @@ class RequestDispatcherTest(unittest.TestCase): self.assertEqual(response, False) self.mock_stub.TrainGlobal.assert_called_once_with( - connection_pb2.TrainGlobalRequest(epochs=42, round_no=43) + connection_pb2.TrainGlobalRequest( + epochs=42, round_no=43, optimizer_state=state_dict_to_proto(None) + ) ) def test_set_weights_on_without_error(self): -- GitLab