diff --git a/edml/core/client.py b/edml/core/client.py index ace69aca8f663b895b5dd8a84393e1bf8c563591..d8096c7ca5bc479a2aec458f2daf1cc442db4986 100644 --- a/edml/core/client.py +++ b/edml/core/client.py @@ -18,7 +18,7 @@ from edml.helpers.decorators import ( from edml.helpers.flops import estimate_model_flops from edml.helpers.load_optimizer import get_optimizer_and_scheduler from edml.helpers.metrics import DiagnosticMetricResultContainer, DiagnosticMetricResult -from edml.helpers.types import StateDict, SLTrainBatchResult +from edml.helpers.types import StateDict if TYPE_CHECKING: from edml.core.device import Device @@ -136,12 +136,18 @@ class DeviceClient: @check_device_set() def train_single_batch( - self, batch_index: int + self, batch_index: int, round_no: int = -1 ) -> Optional[torch.Tensor, torch.Tensor]: torch.cuda.set_device(self._device) # We have to re-initialize the data loader in the case that we do another epoch. if batch_index == 0: self._batchable_data_loader = iter(self._train_data) + # update lr scheduler in the beginning of each round + if self._lr_scheduler is not None: + if round_no != -1: + self._lr_scheduler.step(round_no) + else: + self._lr_scheduler.step() # Used to measure training time. The problem we have with parallel split learning is that forward- and backward- # passes are orchestrated by the current server. diff --git a/edml/core/device.py b/edml/core/device.py index 1348bea7e4a4fadb44a6bd304bfc466f44bbba41..94bb18ac35241fc86eec778b8ef2d9145a8ed51f 100644 --- a/edml/core/device.py +++ b/edml/core/device.py @@ -207,7 +207,9 @@ class Device(ABC): """Evaluates a batch on the server of the device with the given id""" @abstractmethod - def train_batch_on_client_only_on(self, device_id: str, batch_index: int): + def train_batch_on_client_only_on( + self, device_id: str, batch_index: int, round_no: int + ): """""" @abstractmethod @@ -265,17 +267,23 @@ class NetworkDevice(Device): @update_battery @log_execution_time("logger", "client_only_batch_train") - def train_batch_on_client_only(self, batch_index: int): - smashed_data, labels = self.client.train_single_batch(batch_index=batch_index) + def train_batch_on_client_only(self, batch_index: int, round_no: int): + smashed_data, labels = self.client.train_single_batch( + batch_index=batch_index, round_no=round_no + ) return smashed_data, labels @update_battery - def train_batch_on_client_only_on(self, device_id: str, batch_index: int): + def train_batch_on_client_only_on( + self, device_id: str, batch_index: int, round_no: int + ): if self.device_id == device_id: - return self.train_batch_on_client_only(batch_index=batch_index) + return self.train_batch_on_client_only( + batch_index=batch_index, round_no=round_no + ) else: return self.request_dispatcher.train_batch_on_client_only( - device_id=device_id, batch_index=batch_index + device_id=device_id, batch_index=batch_index, round_no=round_no ) def __init__( @@ -575,8 +583,11 @@ class RPCDeviceServicer(DeviceServicer): def TrainSingleBatchOnClient(self, request, context): batch_index = request.batch_index + round_no = request.round_no - smashed_data, labels = self.device.client.train_single_batch(batch_index) + smashed_data, labels = self.device.client.train_single_batch( + batch_index, round_no=round_no + ) smashed_data = Activations(activations=tensor_to_proto(smashed_data)) labels = Labels(labels=tensor_to_proto(labels)) @@ -955,13 +966,15 @@ class DeviceRequestDispatcher: return False def train_batch_on_client_only( - self, device_id: str, batch_index: int + self, device_id: str, batch_index: int, round_no: int ) -> Tuple[Tensor, Tensor] | None: try: response: SingleBatchTrainingResponse = self._get_connection( device_id ).TrainSingleBatchOnClient( - connection_pb2.SingleBatchTrainingRequest(batch_index=batch_index) + connection_pb2.SingleBatchTrainingRequest( + batch_index=batch_index, round_no=round_no + ) ) # The response can only be None if the last batch was smaller than the configured batch size. diff --git a/edml/core/server.py b/edml/core/server.py index c0e60fb928ddbdaa2022ff08c5fb6192aa36cf10..7503389f0c080b2eddb33bf9b932bbbdcf76d758 100644 --- a/edml/core/server.py +++ b/edml/core/server.py @@ -220,7 +220,10 @@ class DeviceServer: ): def client_training_job(client_id: str, batch_index: int): result = self.node_device.train_batch_on_client_only_on( - device_id=client_id, batch_index=batch_index + device_id=client_id, + batch_index=batch_index, + round_no=round_no, + # round_no is taken from outer method arg ) return (client_id, result) diff --git a/edml/generated/connection_pb2.py b/edml/generated/connection_pb2.py index ce271bbab4fa5d30e0e8c5a0a8b309d96eaccf07..f3c72442f44b99d587925a030002e1af92924ac8 100644 --- a/edml/generated/connection_pb2.py +++ b/edml/generated/connection_pb2.py @@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default() import datastructures_pb2 as datastructures__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"4\n\x13SetGradientsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"j\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12\"\n\tgradients\x18\x02 \x01(\x0b\x32\n.GradientsH\x00\x88\x01\x01\x42\x0c\n\n_gradients\"1\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xf8\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x12\x45\n#SetGradientsAndFinalizeTrainingStep\x12\x14.SetGradientsRequest\x1a\x06.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"4\n\x13SetGradientsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"j\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12\"\n\tgradients\x18\x02 \x01(\x0b\x32\n.GradientsH\x00\x88\x01\x01\x42\x0c\n\n_gradients\"C\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\x12\x10\n\x08round_no\x18\x02 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xf8\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x12\x45\n#SetGradientsAndFinalizeTrainingStep\x12\x14.SetGradientsRequest\x1a\x06.Empty\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -31,61 +31,61 @@ if _descriptor._USE_C_DESCRIPTORS == False: _globals['_SINGLEBATCHBACKWARDRESPONSE']._serialized_start=212 _globals['_SINGLEBATCHBACKWARDRESPONSE']._serialized_end=318 _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_start=320 - _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_end=369 - _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_start=372 - _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_end=500 - _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_start=503 - _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=716 - _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=719 - _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=984 - _globals['_TRAINGLOBALREQUEST']._serialized_start=987 - _globals['_TRAINGLOBALREQUEST']._serialized_end=1121 - _globals['_TRAINGLOBALRESPONSE']._serialized_start=1124 - _globals['_TRAINGLOBALRESPONSE']._serialized_end=1368 - _globals['_SETWEIGHTSREQUEST']._serialized_start=1370 - _globals['_SETWEIGHTSREQUEST']._serialized_end=1435 - _globals['_SETWEIGHTSRESPONSE']._serialized_start=1437 - _globals['_SETWEIGHTSRESPONSE']._serialized_end=1523 - _globals['_TRAINEPOCHREQUEST']._serialized_start=1525 - _globals['_TRAINEPOCHREQUEST']._serialized_end=1609 - _globals['_TRAINEPOCHRESPONSE']._serialized_start=1611 - _globals['_TRAINEPOCHRESPONSE']._serialized_end=1724 - _globals['_TRAINBATCHREQUEST']._serialized_start=1726 - _globals['_TRAINBATCHREQUEST']._serialized_end=1806 - _globals['_TRAINBATCHRESPONSE']._serialized_start=1809 - _globals['_TRAINBATCHRESPONSE']._serialized_end=1954 - _globals['_EVALGLOBALREQUEST']._serialized_start=1956 - _globals['_EVALGLOBALREQUEST']._serialized_end=2014 - _globals['_EVALGLOBALRESPONSE']._serialized_start=2016 - _globals['_EVALGLOBALRESPONSE']._serialized_end=2129 - _globals['_EVALREQUEST']._serialized_start=2131 - _globals['_EVALREQUEST']._serialized_end=2193 - _globals['_EVALRESPONSE']._serialized_start=2195 - _globals['_EVALRESPONSE']._serialized_end=2275 - _globals['_EVALBATCHREQUEST']._serialized_start=2277 - _globals['_EVALBATCHREQUEST']._serialized_end=2356 - _globals['_EVALBATCHRESPONSE']._serialized_start=2358 - _globals['_EVALBATCHRESPONSE']._serialized_end=2470 - _globals['_FULLMODELTRAINREQUEST']._serialized_start=2472 - _globals['_FULLMODELTRAINREQUEST']._serialized_end=2531 - _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2534 - _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2740 - _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2742 - _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2766 - _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2768 - _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2859 - _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2861 - _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2883 - _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2885 - _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=2974 - _globals['_BATTERYSTATUSREQUEST']._serialized_start=2976 - _globals['_BATTERYSTATUSREQUEST']._serialized_end=2998 - _globals['_BATTERYSTATUSRESPONSE']._serialized_start=3000 - _globals['_BATTERYSTATUSRESPONSE']._serialized_end=3121 - _globals['_DATASETMODELINFOREQUEST']._serialized_start=3123 - _globals['_DATASETMODELINFOREQUEST']._serialized_end=3148 - _globals['_DATASETMODELINFORESPONSE']._serialized_start=3151 - _globals['_DATASETMODELINFORESPONSE']._serialized_end=3350 - _globals['_DEVICE']._serialized_start=3353 - _globals['_DEVICE']._serialized_end=4497 + _globals['_SINGLEBATCHTRAININGREQUEST']._serialized_end=387 + _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_start=390 + _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_end=518 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_start=521 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=734 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=737 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=1002 + _globals['_TRAINGLOBALREQUEST']._serialized_start=1005 + _globals['_TRAINGLOBALREQUEST']._serialized_end=1139 + _globals['_TRAINGLOBALRESPONSE']._serialized_start=1142 + _globals['_TRAINGLOBALRESPONSE']._serialized_end=1386 + _globals['_SETWEIGHTSREQUEST']._serialized_start=1388 + _globals['_SETWEIGHTSREQUEST']._serialized_end=1453 + _globals['_SETWEIGHTSRESPONSE']._serialized_start=1455 + _globals['_SETWEIGHTSRESPONSE']._serialized_end=1541 + _globals['_TRAINEPOCHREQUEST']._serialized_start=1543 + _globals['_TRAINEPOCHREQUEST']._serialized_end=1627 + _globals['_TRAINEPOCHRESPONSE']._serialized_start=1629 + _globals['_TRAINEPOCHRESPONSE']._serialized_end=1742 + _globals['_TRAINBATCHREQUEST']._serialized_start=1744 + _globals['_TRAINBATCHREQUEST']._serialized_end=1824 + _globals['_TRAINBATCHRESPONSE']._serialized_start=1827 + _globals['_TRAINBATCHRESPONSE']._serialized_end=1972 + _globals['_EVALGLOBALREQUEST']._serialized_start=1974 + _globals['_EVALGLOBALREQUEST']._serialized_end=2032 + _globals['_EVALGLOBALRESPONSE']._serialized_start=2034 + _globals['_EVALGLOBALRESPONSE']._serialized_end=2147 + _globals['_EVALREQUEST']._serialized_start=2149 + _globals['_EVALREQUEST']._serialized_end=2211 + _globals['_EVALRESPONSE']._serialized_start=2213 + _globals['_EVALRESPONSE']._serialized_end=2293 + _globals['_EVALBATCHREQUEST']._serialized_start=2295 + _globals['_EVALBATCHREQUEST']._serialized_end=2374 + _globals['_EVALBATCHRESPONSE']._serialized_start=2376 + _globals['_EVALBATCHRESPONSE']._serialized_end=2488 + _globals['_FULLMODELTRAINREQUEST']._serialized_start=2490 + _globals['_FULLMODELTRAINREQUEST']._serialized_end=2549 + _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2552 + _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2758 + _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2760 + _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2784 + _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2786 + _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2877 + _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2879 + _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2901 + _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2903 + _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=2992 + _globals['_BATTERYSTATUSREQUEST']._serialized_start=2994 + _globals['_BATTERYSTATUSREQUEST']._serialized_end=3016 + _globals['_BATTERYSTATUSRESPONSE']._serialized_start=3018 + _globals['_BATTERYSTATUSRESPONSE']._serialized_end=3139 + _globals['_DATASETMODELINFOREQUEST']._serialized_start=3141 + _globals['_DATASETMODELINFOREQUEST']._serialized_end=3166 + _globals['_DATASETMODELINFORESPONSE']._serialized_start=3169 + _globals['_DATASETMODELINFORESPONSE']._serialized_end=3368 + _globals['_DEVICE']._serialized_start=3371 + _globals['_DEVICE']._serialized_end=4515 # @@protoc_insertion_point(module_scope) diff --git a/edml/generated/connection_pb2.pyi b/edml/generated/connection_pb2.pyi index a9735505e808ee35b156158e972ab6b206e90b0a..89353343aa1c7e39072bd0ea03c891bd2df7b4df 100644 --- a/edml/generated/connection_pb2.pyi +++ b/edml/generated/connection_pb2.pyi @@ -32,10 +32,12 @@ class SingleBatchBackwardResponse(_message.Message): def __init__(self, metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., gradients: _Optional[_Union[_datastructures_pb2.Gradients, _Mapping]] = ...) -> None: ... class SingleBatchTrainingRequest(_message.Message): - __slots__ = ["batch_index"] + __slots__ = ["batch_index", "round_no"] BATCH_INDEX_FIELD_NUMBER: _ClassVar[int] + ROUND_NO_FIELD_NUMBER: _ClassVar[int] batch_index: int - def __init__(self, batch_index: _Optional[int] = ...) -> None: ... + round_no: int + def __init__(self, batch_index: _Optional[int] = ..., round_no: _Optional[int] = ...) -> None: ... class SingleBatchTrainingResponse(_message.Message): __slots__ = ["smashed_data", "labels"] diff --git a/edml/proto/connection.proto b/edml/proto/connection.proto index 2f12881a1cf7e3ee6ae2e076b2f9005141cfece2..5755518d01796bdff68957655d4d339dcf02ab1d 100644 --- a/edml/proto/connection.proto +++ b/edml/proto/connection.proto @@ -41,6 +41,7 @@ message SingleBatchBackwardResponse { message SingleBatchTrainingRequest { int32 batch_index = 1; + int32 round_no = 2; } message SingleBatchTrainingResponse {