diff --git a/edml/controllers/base_controller.py b/edml/controllers/base_controller.py index a2adca85736844b23f40a7c8c5c1a5c4574ef280..76e1a6d61a6f8258b70e12804b52b57d45a9e58a 100644 --- a/edml/controllers/base_controller.py +++ b/edml/controllers/base_controller.py @@ -5,7 +5,7 @@ from typing import Optional import torch from edml.controllers.early_stopping import create_early_stopping_callback -from edml.core.device import DeviceRequestDispatcher +from edml.core.device_request_dispatcher import DeviceRequestDispatcher from edml.core.start_device import _get_models from edml.helpers.logging import SimpleLogger, create_logger from edml.helpers.metrics import ModelMetricResultContainer diff --git a/edml/core/device.py b/edml/core/device.py index aea1ae95de2735dc4c984a13d49dadd632d66432..873a39b08ac386f807c0448d13824e282f4ff2e5 100644 --- a/edml/core/device.py +++ b/edml/core/device.py @@ -2,76 +2,25 @@ from __future__ import annotations import threading from abc import ABC, abstractmethod -from typing import Optional, Dict, Any, List, Union, Tuple, cast +from typing import Optional, Any, List, Tuple, cast -import grpc -from google.protobuf.message import Message from omegaconf import DictConfig -from torch import Tensor from torch.autograd import Variable from edml.core.battery import Battery from edml.core.client import DeviceClient +from edml.core.device_request_dispatcher import DeviceRequestDispatcher from edml.core.server import DeviceServer -from edml.generated import connection_pb2 -from edml.generated.connection_pb2 import ( - SetGradientsRequest, - SetWeightsRequest, - TrainBatchRequest, - TrainGlobalResponse, - TrainEpochResponse, - TrainBatchResponse, - EvalGlobalResponse, - EvalResponse, - EvalBatchResponse, - FullModelTrainResponse, - BatteryStatusResponse, - DatasetModelInfoResponse, - EndExperimentResponse, - StartExperimentResponse, - SingleBatchTrainingResponse, - SingleBatchBackwardRequest, - TrainGlobalParallelSplitLearningResponse, - SingleBatchBackwardResponse, -) -from edml.generated.connection_pb2_grpc import DeviceServicer, DeviceStub -from edml.generated.datastructures_pb2 import ( - Gradients, - Weights, - DeviceInfo, - Activations, - Labels, - BatteryStatus, - Empty, - Metrics, -) from edml.helpers.decorators import ( log_execution_time, update_battery, add_time_to_diagnostic_metrics, ) -from edml.helpers.interceptors import DeviceClientInterceptor from edml.helpers.logging import SimpleLogger from edml.helpers.metrics import ( ModelMetricResultContainer, DiagnosticMetricResultContainer, ) -from edml.helpers.proto_helpers import ( - proto_to_tensor, - tensor_to_proto, - state_dict_to_proto, - proto_to_state_dict, - proto_to_weights, - metrics_to_proto, - proto_to_metrics, - _proto_size_per_field, -) -from edml.helpers.types import ( - HasMetrics, - StateDict, - DeviceBatteryStatus, - DeviceBatteryStatusReport, -) class Device(ABC): @@ -224,6 +173,38 @@ class Device(ABC): class NetworkDevice(Device): + """ + Main implementation of the Device class. Provides actions related to ML tasks which are delegated accordingly. + There are two hooks for each action, e.g. set_weights and set_weights_on. Calling set_weights lets the device + set the provided weights on itself, while set_weights_on delegates the weights to a certain receiver. This way, + the "client" or "server" part of split learning can request certain actions without handling the communication details. 
+ Outgoing communication with other devices is delegated to the DeviceRequestDispatcher and incoming requests are handled + by the RPCDeviceServicer calling appropriate actions on the device. + To set up the communication, set_devices must be called with the other devices and connection information. + Other responsibilities of this class are logging and updating the battery. + + Attributes: + device_id (str): This device's id. + logger (SimpleLogger): The logger instance the device can use. + battery (Battery): The device's battery. Certain function consume energy and drain the battery. + client (DeviceClient): The client part of this device. Initialized later by explicitly calling + py:meth:`set_client`. + server (DeviceServer): The server part of this device. Initialized later by explicitly calling + py:meth:`set_server`. + """ + + def __init__( + self, + device_id: str, + logger: SimpleLogger, + battery: Battery, + stop_event: Optional[threading.Event] = None, + ): + self.devices: List[DictConfig[str, Any]] = [] + self.request_dispatcher = DeviceRequestDispatcher([], device_id=device_id) + self.stop_event = stop_event + super().__init__(device_id, logger, battery) + @update_battery @log_execution_time("logger", "finalize_gradients") def set_gradient_and_finalize_training_on_client_only_on( @@ -287,18 +268,6 @@ class NetworkDevice(Device): device_id=device_id, batch_index=batch_index, round_no=round_no ) - def __init__( - self, - device_id: str, - logger: SimpleLogger, - battery: Battery, - stop_event: Optional[threading.Event] = None, - ): - self.devices: List[DictConfig[str, Any]] = [] - self.request_dispatcher = DeviceRequestDispatcher([], device_id=device_id) - self.stop_event = stop_event - super().__init__(device_id, logger, battery) - @add_time_to_diagnostic_metrics("train_global") @update_battery @log_execution_time("logger", "train_global_time") @@ -444,625 +413,3 @@ class NetworkDevice(Device): def _log_current_battery_capacity(self): """Wrapper for logging the current battery capacity""" self.logger.log({"battery": self.battery.remaining_capacity()}) - - -class RPCDeviceServicer(DeviceServicer): - def __init__(self, device: NetworkDevice): - self.device = device - - def TrainGlobal(self, request, context): - print(f"Called TrainGlobal on device {self.device.device_id}") - client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( - self.device.train_global( - request.epochs, - request.round_no, - request.adaptive_threshold_value, - proto_to_state_dict(request.optimizer_state), - ) - ) - response = connection_pb2.TrainGlobalResponse( - client_weights=Weights(weights=state_dict_to_proto(client_weights)), - server_weights=Weights(weights=state_dict_to_proto(server_weights)), - metrics=metrics_to_proto(metrics), - optimizer_state=state_dict_to_proto(optimizer_state), - diagnostic_metrics=metrics_to_proto(diagnostic_metrics), - ) - return response - - def SetWeights(self, request, context): - print(f"Called SetWeights on device {self.device.device_id}") - weights = proto_to_state_dict(request.weights.weights) - self.device.set_weights(weights, request.on_client) - return connection_pb2.SetWeightsResponse() - - def TrainEpoch(self, request, context): - print(f"Called TrainEpoch on device {self.device.device_id}") - device_info = request.server - device_id = device_info.device_id - round_no = request.round_no - weights, diagnostic_metrics = self.device.train_epoch(device_id, round_no) - proto_weights = state_dict_to_proto(weights) - return 
connection_pb2.TrainEpochResponse( - weights=Weights(weights=proto_weights), - diagnostic_metrics=metrics_to_proto(diagnostic_metrics), - ) - - def TrainBatch(self, request, context): - activations = proto_to_tensor(request.smashed_data.activations) - labels = proto_to_tensor(request.labels.labels) - gradients, loss, diagnostic_metrics = self.device.train_batch( - activations, labels - ) - proto_gradients = Gradients(gradients=tensor_to_proto(gradients)) - return connection_pb2.TrainBatchResponse( - gradients=proto_gradients, - loss=loss, - diagnostic_metrics=metrics_to_proto(diagnostic_metrics), - ) - - def EvaluateGlobal(self, request, context): - print(f"Called EvaluateGlobal on device {self.device.device_id}") - metrics, diagnostic_metrics = self.device.evaluate_global( - request.validation, request.federated - ) - return connection_pb2.EvalGlobalResponse( - metrics=metrics_to_proto(metrics), - diagnostic_metrics=metrics_to_proto(diagnostic_metrics), - ) - - def Evaluate(self, request, context): - print(f"Called Evaluate on device {self.device.device_id}") - diagnostic_metrics = self.device.evaluate( - request.server.device_id, request.validation - ) - return connection_pb2.EvalResponse( - diagnostic_metrics=metrics_to_proto(diagnostic_metrics) - ) - - def EvaluateBatch(self, request, context): - activations = proto_to_tensor(request.smashed_data.activations) - labels = proto_to_tensor(request.labels.labels) - diagnostic_metrics = self.device.evaluate_batch(activations, labels) - return connection_pb2.EvalBatchResponse( - diagnostic_metrics=metrics_to_proto(diagnostic_metrics) - ) - - def FullModelTraining(self, request, context): - print(f"Called Full Training on device {self.device.device_id}") - client_weights, server_weights, num_samples, metrics, diagnostic_metrics = ( - self.device.federated_train(request.round_no) - ) - return connection_pb2.FullModelTrainResponse( - client_weights=Weights(weights=state_dict_to_proto(client_weights)), - server_weights=Weights(weights=state_dict_to_proto(server_weights)), - num_samples=num_samples, - metrics=metrics_to_proto(metrics), - diagnostic_metrics=metrics_to_proto(diagnostic_metrics), - ) - - def StartExperiment(self, request, context) -> StartExperimentResponse: - print(f"Start Experiment on device {self.device.device_id}") - self.device.start_experiment() - return connection_pb2.StartExperimentResponse() - - def EndExperiment(self, request, context) -> EndExperimentResponse: - print(f"End Experiment on device {self.device.device_id}") - print(f"Remaining battery capacity {self.device.battery.remaining_capacity()}") - self.device.end_experiment() - return connection_pb2.EndExperimentResponse() - - def GetBatteryStatus(self, request, context): - print(f"Get Battery Status on device {self.device.device_id}") - initial_capacity, remaining_capacity = self.device.get_battery_status() - - return connection_pb2.BatteryStatusResponse( - status=BatteryStatus( - initial_battery_level=initial_capacity, - current_battery_level=remaining_capacity, - ) - ) - - def GetDatasetModelInfo(self, request, context): - print(f"Get Dataset and Model Info on device {self.device.device_id}") - return connection_pb2.DatasetModelInfoResponse( - train_samples=len(self.device.client._train_data.dataset), - validation_samples=len(self.device.client._val_data.dataset), - client_fw_flops=int(self.device.client._model_flops["FW"]), - server_fw_flops=int(self.device.server._model_flops["FW"]), - client_bw_flops=int(self.device.client._model_flops["BW"]), - 
server_bw_flops=int(self.device.server._model_flops["BW"]), - ) - - def TrainGlobalParallelSplitLearning(self, request, context): - print(f"Starting parallel split learning") - clients = self.device.__get_device_ids__() - round_no = request.round_no - adaptive_threshold_value = request.adaptive_threshold_value - optimizer_state = proto_to_state_dict(request.optimizer_state) - - cw, sw, model_metrics, optimizer_state, diagnostic_metrics = ( - self.device.train_parallel_split_learning( - clients=clients, - round_no=round_no, - adaptive_threshold_value=adaptive_threshold_value, - optimizer_state=optimizer_state, - ) - ) - response = connection_pb2.TrainGlobalParallelSplitLearningResponse( - client_weights=Weights(weights=state_dict_to_proto(cw)), - server_weights=Weights(weights=state_dict_to_proto(sw)), - metrics=metrics_to_proto(model_metrics), - optimizer_state=state_dict_to_proto(optimizer_state), - diagnostic_metrics=metrics_to_proto(diagnostic_metrics), - ) - return response - - def TrainSingleBatchOnClient(self, request, context): - batch_index = request.batch_index - round_no = request.round_no - - smashed_data, labels = self.device.client.train_single_batch( - batch_index, round_no=round_no - ) - - smashed_data = Activations(activations=tensor_to_proto(smashed_data)) - labels = Labels(labels=tensor_to_proto(labels)) - return connection_pb2.SingleBatchTrainingResponse( - smashed_data=smashed_data, - labels=labels, - ) - - def BackwardPropagationSingleBatchOnClient( - self, request: SingleBatchBackwardRequest, context - ): - gradients = proto_to_tensor(request.gradients.gradients) - - metrics, gradients = self.device.client.backward_single_batch( - gradients=gradients - ) - return connection_pb2.SingleBatchBackwardResponse( - metrics=metrics_to_proto(metrics), - gradients=Gradients(gradients=tensor_to_proto(gradients)), - ) - - def SetGradientsAndFinalizeTrainingStep( - self, request: SetGradientsRequest, context - ): - gradients = proto_to_tensor(request.gradients.gradients) - self.device.client.set_gradient_and_finalize_training(gradients=gradients) - return Empty() - - -class DeviceRequestDispatcher: - - def __init__( - self, - devices: List[DictConfig[str, Any]], - logger: Optional[SimpleLogger] = None, - battery: Optional[Battery] = None, - stop_event: Optional[threading.Event] = None, - device_id: Optional[str] = None, - ): - self.devices = devices - self.connections: Dict[str, DeviceStub] = {} - # optional, interceptor only works if all three are set - self.logger = logger - self.battery = battery - self.stop_event = stop_event - - self._establish_connections() - self.connections_lock = threading.Lock() - self.device_id = device_id # used for diagnostic metrics to assign the source device correctly - - def __get_device_address__(self, device_id: str) -> Optional[str]: - for device in self.devices: - if device.device_id == device_id: - return device.address - return None - - def train_parallel_on_server( - self, - server_device_id: str, - epochs: int, - round_no: int, - adaptive_threshold_value: Optional[float] = None, - optimizer_state: dict[str, Any] = None, - ): - print(f"><><><> {adaptive_threshold_value}") - - try: - response: TrainGlobalParallelSplitLearningResponse = self._get_connection( - server_device_id - ).TrainGlobalParallelSplitLearning( - connection_pb2.TrainGlobalParallelSplitLearningRequest( - round_no=round_no, - adaptive_threshold_value=adaptive_threshold_value, - optimizer_state=state_dict_to_proto(optimizer_state), - ) - ) - return ( - 
proto_to_weights(response.client_weights), - proto_to_weights(response.server_weights), - proto_to_metrics(response.metrics), - proto_to_state_dict(response.optimizer_state), - self._add_byte_size_to_diagnostic_metrics(response, self.device_id), - ) - except grpc.RpcError: - self._handle_rpc_error(server_device_id) - except KeyError: - self._handle_unknown_device_id(server_device_id) - return False - - def _establish_connections(self): - for device in self.devices: - channel = grpc.insecure_channel( - device.address, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - if ( - self.logger is not None - and self.battery is not None - and self.stop_event is not None - ): - channel = grpc.intercept_channel( - channel, - DeviceClientInterceptor(self.logger, self.battery, self.stop_event), - ) - stub = DeviceStub(channel) - self.connections[device.device_id] = stub - print(f"active devices {self.connections.keys()}") - - def _get_connection(self, device_id: str) -> DeviceStub: - """Returns the connection to the device with the given id. - If the device_id is not a valid dictionary key, the caller has to handle the KeyError. - """ - with self.connections_lock: - return self.connections[device_id] - - def active_devices(self) -> List[str]: - with self.connections_lock: # lock connections to prevent remove while iterating - return list(self.connections.keys()) - - def _handle_rpc_error(self, device_id: str): - """Hook for handling an RpcError. Happens when the device encounters an empty battery during handling - the request or is not reachable (i.e. most probably shut down). - Here, the connection to the device is removed using a lock to avoid race conditions. - """ - with self.connections_lock: - del self.connections[device_id] - - def _handle_unknown_device_id(self, device_id: str): - """Handler in case the requested device id is not in the connection dictionary. - Could happen if the active devices weren't refreshed properly.""" - if device_id in [device.device_id for device in self.devices]: - print(f"Device {device_id} not active") - else: - print(f"Unknown Device ID {device_id}") - - def _add_byte_size_to_diagnostic_metrics( - self, response: Message, device_id: str, request=None - ): - """ - Adds the byte size of the response to the diagnostic metrics. - - Args: - response: A protobuf message. - device_id: The id of the device the request was sent from. - request: A protobuf message, if the request is of interest. - Returns: - A DiagnosticMetricResultContainer including the added byte size and previous metrics from the response (and possibly the request). 
- Raises: - None - Notes: - """ - if response.HasField("diagnostic_metrics"): - response: HasMetrics - diagnostic_metrics = proto_to_metrics(response.diagnostic_metrics) - else: - diagnostic_metrics = DiagnosticMetricResultContainer() - diagnostic_metrics.merge(_proto_size_per_field(response, device_id)) - if request is not None: - diagnostic_metrics.merge(_proto_size_per_field(request, device_id)) - return diagnostic_metrics - - def train_global_on( - self, - device_id: str, - epochs: int, - round_no: int = -1, - adaptive_threshold_value: Optional[float] = None, - optimizer_state: dict[str, Any] = None, - ) -> Union[ - Tuple[ - Dict[str, Any], - Dict[str, Any], - ModelMetricResultContainer, - Dict[str, Any], - DiagnosticMetricResultContainer, - ], - bool, - ]: - try: - response: TrainGlobalResponse = self._get_connection(device_id).TrainGlobal( - connection_pb2.TrainGlobalRequest( - epochs=epochs, - round_no=round_no, - adaptive_threshold_value=adaptive_threshold_value, - optimizer_state=state_dict_to_proto(optimizer_state), - ) - ) - return ( - proto_to_weights(response.client_weights), - proto_to_weights(response.server_weights), - proto_to_metrics(response.metrics), - proto_to_state_dict(response.optimizer_state), - self._add_byte_size_to_diagnostic_metrics(response, self.device_id), - ) - except grpc.RpcError: - self._handle_rpc_error(device_id) - except KeyError: - self._handle_unknown_device_id(device_id) - return False - - def set_weights_on( - self, device_id: str, state_dict, on_client: bool, wait_for_ready: bool = False - ): - try: - self._get_connection(device_id).SetWeights( - SetWeightsRequest( - weights=Weights(weights=state_dict_to_proto(state_dict)), - on_client=on_client, - ), - wait_for_ready=wait_for_ready, - ) - return True - except grpc.RpcError: - self._handle_rpc_error(device_id) - except KeyError: - self._handle_unknown_device_id(device_id) - return False - - def train_epoch_on(self, device_id: str, server_device: str, round_no: int = -1): - try: - response: TrainEpochResponse = self._get_connection(device_id).TrainEpoch( - connection_pb2.TrainEpochRequest( - server=DeviceInfo( - device_id=server_device, - address=self.__get_device_address__(server_device), - ), - round_no=round_no, - ) - ) - return proto_to_state_dict( - response.weights.weights - ), self._add_byte_size_to_diagnostic_metrics(response, self.device_id) - except grpc.RpcError: - self._handle_rpc_error(device_id) - except KeyError: - self._handle_unknown_device_id(device_id) - return False - - def train_batch_on(self, device_id: str, smashed_data, labels): - try: - request = TrainBatchRequest( - smashed_data=Activations(activations=tensor_to_proto(smashed_data)), - labels=Labels(labels=tensor_to_proto(labels)), - ) - response: TrainBatchResponse = self._get_connection(device_id).TrainBatch( - request - ) - return ( - proto_to_tensor(response.gradients.gradients), - response.loss, - self._add_byte_size_to_diagnostic_metrics( - response, self.device_id, request - ), - ) - except grpc.RpcError: - self._handle_rpc_error(device_id) - except KeyError: - self._handle_unknown_device_id(device_id) - return False - - def evaluate_global_on(self, device_id: str, val: bool = True, fed: bool = False): - try: - response: EvalGlobalResponse = self._get_connection( - device_id - ).EvaluateGlobal( - connection_pb2.EvalGlobalRequest(validation=val, federated=fed) - ) - return proto_to_metrics( - response.metrics - ), self._add_byte_size_to_diagnostic_metrics(response, self.device_id) - except grpc.RpcError: - 
self._handle_rpc_error(device_id) - except KeyError: - self._handle_unknown_device_id(device_id) - return False - - def evaluate_on(self, device_id: str, server_device: str, val: bool): - try: - response: EvalResponse = self._get_connection(device_id).Evaluate( - connection_pb2.EvalRequest( - server=DeviceInfo( - device_id=server_device, - address=self.__get_device_address__(server_device), - ), - validation=val, - ) - ) - return self._add_byte_size_to_diagnostic_metrics(response, self.device_id) - except grpc.RpcError: - self._handle_rpc_error(device_id) - except KeyError: - self._handle_unknown_device_id(device_id) - return False - - def evaluate_batch_on(self, device_id: str, smashed_data, labels): - try: - request = connection_pb2.EvalBatchRequest( - smashed_data=Activations(activations=tensor_to_proto(smashed_data)), - labels=Labels(labels=tensor_to_proto(labels)), - ) - response: EvalBatchResponse = self._get_connection(device_id).EvaluateBatch( - request - ) - return self._add_byte_size_to_diagnostic_metrics( - response, self.device_id, request - ) - except grpc.RpcError: - self._handle_rpc_error(device_id) - except KeyError: - self._handle_unknown_device_id(device_id) - return False - - def federated_train_on(self, device_id: str, round_no: int = -1) -> Union[ - Tuple[ - StateDict, - StateDict, - int, - ], - bool, - ]: - try: - response: FullModelTrainResponse = self._get_connection( - device_id - ).FullModelTraining(connection_pb2.FullModelTrainRequest(round_no=round_no)) - return ( - proto_to_weights(response.client_weights), - proto_to_weights(response.server_weights), - response.num_samples, - proto_to_metrics(response.metrics), - self._add_byte_size_to_diagnostic_metrics(response, self.device_id), - ) - except grpc.RpcError: - self._handle_rpc_error(device_id) - except KeyError: - self._handle_unknown_device_id(device_id) - return False - - def start_experiment_on(self, device_id: str, wait_for_ready: bool = False) -> bool: - try: - self._get_connection(device_id).StartExperiment( - connection_pb2.StartExperimentRequest(), wait_for_ready=wait_for_ready - ) - return True - except grpc.RpcError: - self._handle_rpc_error(device_id) - except KeyError: - self._handle_unknown_device_id(device_id) - return False - - def end_experiment_on(self, device_id: str) -> bool: - try: - self._get_connection(device_id).EndExperiment( - connection_pb2.EndExperimentRequest() - ) - return True - except grpc.RpcError: - self._handle_rpc_error(device_id) - except KeyError: - self._handle_unknown_device_id(device_id) - return False - - def get_battery_status_on(self, device_id: str) -> DeviceBatteryStatusReport: - try: - response: BatteryStatusResponse = self._get_connection( - device_id - ).GetBatteryStatus(connection_pb2.BatteryStatusRequest()) - return DeviceBatteryStatus( - current_capacity=response.status.current_battery_level, - initial_capacity=response.status.initial_battery_level, - ) - except grpc.RpcError: - self._handle_rpc_error(device_id) - except KeyError: - self._handle_unknown_device_id(device_id) - return False - - def train_batch_on_client_only( - self, device_id: str, batch_index: int, round_no: int - ) -> Tuple[Tensor, Tensor] | None: - try: - response: SingleBatchTrainingResponse = self._get_connection( - device_id - ).TrainSingleBatchOnClient( - connection_pb2.SingleBatchTrainingRequest( - batch_index=batch_index, round_no=round_no - ) - ) - - # The response can only be None if the last batch was smaller than the configured batch size. 
- if response.HasField("smashed_data"): - return ( - proto_to_tensor(response.smashed_data.activations), - proto_to_tensor(response.labels.labels), - ) - - return None - except grpc.RpcError: - self._handle_rpc_error(device_id) - except KeyError: - self._handle_unknown_device_id(device_id) - return False - - def get_dataset_model_info_on( - self, device_id: str - ) -> Union[Tuple[int, int, int, int, int, int], bool]: - try: - response: DatasetModelInfoResponse = self._get_connection( - device_id - ).GetDatasetModelInfo(connection_pb2.DatasetModelInfoRequest()) - return ( - response.train_samples, - response.validation_samples, - response.client_fw_flops, - response.server_fw_flops, - response.client_bw_flops, - response.server_bw_flops, - ) - except grpc.RpcError: - self._handle_rpc_error(device_id) - except KeyError: - self._handle_unknown_device_id(device_id) - return False - - def backpropagation_on_client_only(self, device_id, gradients): - try: - response: SingleBatchBackwardResponse = self._get_connection( - device_id - ).BackwardPropagationSingleBatchOnClient( - connection_pb2.SingleBatchBackwardRequest( - gradients=Gradients(gradients=tensor_to_proto(gradients)) - ) - ) - return ( - None, - proto_to_tensor(response.gradients.gradients), - ) - except grpc.RpcError: - self._handle_rpc_error(device_id) - except KeyError: - self._handle_unknown_device_id(device_id) - return False - - def set_gradient_and_finalize_training_on_client_only( - self, client_id: str, gradients: Any - ): - try: - response: Empty = self._get_connection( - client_id - ).SetGradientsAndFinalizeTrainingStep( - connection_pb2.SetGradientsRequest( - gradients=Gradients(gradients=tensor_to_proto(gradients)) - ) - ) - return response - except grpc.RpcError: - self._handle_rpc_error(client_id) - except KeyError: - self._handle_unknown_device_id(client_id) - return False diff --git a/edml/core/device_request_dispatcher.py b/edml/core/device_request_dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..1c7802ffb35d9cb4eea543938047d1e91ac48200 --- /dev/null +++ b/edml/core/device_request_dispatcher.py @@ -0,0 +1,468 @@ +from __future__ import annotations + +import threading +from contextlib import contextmanager +from typing import List, Any, Optional, Dict, Union, Tuple + +import grpc +from google.protobuf.message import Message +from omegaconf import DictConfig +from torch import Tensor + +from edml.core.battery import Battery +from edml.generated import connection_pb2 +from edml.generated.connection_pb2 import ( + TrainGlobalParallelSplitLearningResponse, + TrainGlobalResponse, + SetWeightsRequest, + TrainEpochResponse, + TrainBatchRequest, + TrainBatchResponse, + EvalGlobalResponse, + EvalResponse, + EvalBatchResponse, + FullModelTrainResponse, + BatteryStatusResponse, + SingleBatchTrainingResponse, + DatasetModelInfoResponse, + SingleBatchBackwardResponse, +) +from edml.generated.connection_pb2_grpc import DeviceStub +from edml.generated.datastructures_pb2 import ( + Weights, + DeviceInfo, + Activations, + Labels, + Gradients, + Empty, +) +from edml.helpers.interceptors import DeviceClientInterceptor +from edml.helpers.logging import SimpleLogger +from edml.helpers.metrics import ( + DiagnosticMetricResultContainer, + ModelMetricResultContainer, +) +from edml.helpers.proto_helpers import ( + proto_to_metrics, + _proto_size_per_field, + state_dict_to_proto, + proto_to_weights, + proto_to_state_dict, + tensor_to_proto, + proto_to_tensor, +) +from edml.helpers.types import ( + HasMetrics, + 
StateDict, + DeviceBatteryStatusReport, + DeviceBatteryStatus, +) + + +class DeviceRequestDispatcher: + """ + Implements functionality to make gRPC calls to call actions on other devices. Serializes python objects to protobufs + for sending. Maintains connection information to the other devices. + If the device's battery is empty, it triggers a stop_event to terminate the device's process. + This way, it is assured that a device cannot call other devices while having no (simulated) energy left. + To account for the energy, an interceptor is defined to measure the communication overhead. + + Attributes: + device_id (str): This device's id. + logger (SimpleLogger): The logger instance the device can use. + battery (Battery): The device's battery. Certain function consume energy and drain the battery. Here it is + updated to account for the communication overhead. + stop_event (threading.Event): Event to be set when the device is out of battery + devices (List[DictConfig[str, Any]]): Connection information to the other devices. + """ + + def __init__( + self, + devices: List[DictConfig[str, Any]], + logger: Optional[SimpleLogger] = None, + battery: Optional[Battery] = None, + stop_event: Optional[threading.Event] = None, + device_id: Optional[str] = None, + ): + self.devices = devices + self.connections: Dict[str, DeviceStub] = {} + # optional, interceptor only works if all three are set + self.logger = logger + self.battery = battery + self.stop_event = stop_event + + self._establish_connections() + self.connections_lock = threading.Lock() + self.device_id = device_id # used for diagnostic metrics to assign the source device correctly + + def __get_device_address__(self, device_id: str) -> Optional[str]: + for device in self.devices: + if device.device_id == device_id: + return device.address + return None + + def _establish_connections(self): + for device in self.devices: + channel = grpc.insecure_channel( + device.address, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + if ( + self.logger is not None + and self.battery is not None + and self.stop_event is not None + ): + channel = grpc.intercept_channel( + channel, + DeviceClientInterceptor(self.logger, self.battery, self.stop_event), + ) + stub = DeviceStub(channel) + self.connections[device.device_id] = stub + print(f"active devices {self.connections.keys()}") + + def _get_connection(self, device_id: str) -> DeviceStub: + """Returns the connection to the device with the given id. + If the device_id is not a valid dictionary key, the caller has to handle the KeyError. + """ + with self.connections_lock: + return self.connections[device_id] + + def active_devices(self) -> List[str]: + with self.connections_lock: # lock connections to prevent remove while iterating + return list(self.connections.keys()) + + def _handle_rpc_error(self, device_id: str): + """Hook for handling an RpcError. Happens when the device encounters an empty battery during handling + the request or is not reachable (i.e. most probably shut down). + Here, the connection to the device is removed using a lock to avoid race conditions. + """ + with self.connections_lock: + del self.connections[device_id] + + def _handle_unknown_device_id(self, device_id: str): + """Handler in case the requested device id is not in the connection dictionary. 
+ Could happen if the active devices weren't refreshed properly.""" + if device_id in [device.device_id for device in self.devices]: + print(f"Device {device_id} not active") + else: + print(f"Unknown Device ID {device_id}") + + def _add_byte_size_to_diagnostic_metrics( + self, response: Message, device_id: str, request=None + ): + """ + Adds the byte size of the response to the diagnostic metrics. + + Args: + response: A protobuf message. + device_id: The id of the device the request was sent from. + request: A protobuf message, if the request is of interest. + Returns: + A DiagnosticMetricResultContainer including the added byte size and previous metrics from the response (and possibly the request). + Raises: + None + Notes: + """ + if response.HasField("diagnostic_metrics"): + response: HasMetrics + diagnostic_metrics = proto_to_metrics(response.diagnostic_metrics) + else: + diagnostic_metrics = DiagnosticMetricResultContainer() + diagnostic_metrics.merge(_proto_size_per_field(response, device_id)) + if request is not None: + diagnostic_metrics.merge(_proto_size_per_field(request, device_id)) + return diagnostic_metrics + + @contextmanager + def handle_connection(self, device_id): + """ + Context manager that wraps RPCs and handles errors appropriately. + + Args: + device_id: The id of the device the request is sent to. + Returns: + None + Raises: + None + Notes: + Successful RPCs should return within the with handle_connection() block. + To indicate an unsuccessful RPC, the with block should be followed by a return False statement (some linters + may count this as unreachable which is not true). + """ + try: + # this returns the control to the RPC call + yield + # check for errors after RPC call + except grpc.RpcError: + self._handle_rpc_error(device_id) + except KeyError: + self._handle_unknown_device_id(device_id) + + def train_parallel_on_server( + self, + server_device_id: str, + epochs: int, + round_no: int, + adaptive_threshold_value: Optional[float] = None, + optimizer_state: dict[str, Any] = None, + ): + with self.handle_connection(server_device_id): + response: TrainGlobalParallelSplitLearningResponse = self._get_connection( + server_device_id + ).TrainGlobalParallelSplitLearning( + connection_pb2.TrainGlobalParallelSplitLearningRequest( + round_no=round_no, + adaptive_threshold_value=adaptive_threshold_value, + optimizer_state=state_dict_to_proto(optimizer_state), + ) + ) + return ( + proto_to_weights(response.client_weights), + proto_to_weights(response.server_weights), + proto_to_metrics(response.metrics), + proto_to_state_dict(response.optimizer_state), + self._add_byte_size_to_diagnostic_metrics(response, self.device_id), + ) + return False + + def train_global_on( + self, + device_id: str, + epochs: int, + round_no: int = -1, + adaptive_threshold_value: Optional[float] = None, + optimizer_state: dict[str, Any] = None, + ) -> Union[ + Tuple[ + Dict[str, Any], + Dict[str, Any], + ModelMetricResultContainer, + Dict[str, Any], + DiagnosticMetricResultContainer, + ], + bool, + ]: + with self.handle_connection(device_id): + response: TrainGlobalResponse = self._get_connection(device_id).TrainGlobal( + connection_pb2.TrainGlobalRequest( + epochs=epochs, + round_no=round_no, + adaptive_threshold_value=adaptive_threshold_value, + optimizer_state=state_dict_to_proto(optimizer_state), + ) + ) + return ( + proto_to_weights(response.client_weights), + proto_to_weights(response.server_weights), + proto_to_metrics(response.metrics), + proto_to_state_dict(response.optimizer_state), + 
self._add_byte_size_to_diagnostic_metrics(response, self.device_id), + ) + return False + + def set_weights_on( + self, device_id: str, state_dict, on_client: bool, wait_for_ready: bool = False + ): + with self.handle_connection(device_id): + self._get_connection(device_id).SetWeights( + SetWeightsRequest( + weights=Weights(weights=state_dict_to_proto(state_dict)), + on_client=on_client, + ), + wait_for_ready=wait_for_ready, + ) + return True + return False + + def train_epoch_on(self, device_id: str, server_device: str, round_no: int = -1): + with self.handle_connection(device_id): + response: TrainEpochResponse = self._get_connection(device_id).TrainEpoch( + connection_pb2.TrainEpochRequest( + server=DeviceInfo( + device_id=server_device, + address=self.__get_device_address__(server_device), + ), + round_no=round_no, + ) + ) + return proto_to_state_dict( + response.weights.weights + ), self._add_byte_size_to_diagnostic_metrics(response, self.device_id) + return False + + def train_batch_on(self, device_id: str, smashed_data, labels): + with self.handle_connection(device_id): + request = TrainBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(smashed_data)), + labels=Labels(labels=tensor_to_proto(labels)), + ) + response: TrainBatchResponse = self._get_connection(device_id).TrainBatch( + request + ) + return ( + proto_to_tensor(response.gradients.gradients), + response.loss, + self._add_byte_size_to_diagnostic_metrics( + response, self.device_id, request + ), + ) + return False + + def evaluate_global_on(self, device_id: str, val: bool = True, fed: bool = False): + with self.handle_connection(device_id): + response: EvalGlobalResponse = self._get_connection( + device_id + ).EvaluateGlobal( + connection_pb2.EvalGlobalRequest(validation=val, federated=fed) + ) + return proto_to_metrics( + response.metrics + ), self._add_byte_size_to_diagnostic_metrics(response, self.device_id) + return False + + def evaluate_on(self, device_id: str, server_device: str, val: bool): + with self.handle_connection(device_id): + response: EvalResponse = self._get_connection(device_id).Evaluate( + connection_pb2.EvalRequest( + server=DeviceInfo( + device_id=server_device, + address=self.__get_device_address__(server_device), + ), + validation=val, + ) + ) + return self._add_byte_size_to_diagnostic_metrics(response, self.device_id) + return False + + def evaluate_batch_on(self, device_id: str, smashed_data, labels): + with self.handle_connection(device_id): + request = connection_pb2.EvalBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(smashed_data)), + labels=Labels(labels=tensor_to_proto(labels)), + ) + response: EvalBatchResponse = self._get_connection(device_id).EvaluateBatch( + request + ) + return self._add_byte_size_to_diagnostic_metrics( + response, self.device_id, request + ) + return False + + def federated_train_on(self, device_id: str, round_no: int = -1) -> Union[ + Tuple[ + StateDict, + StateDict, + int, + ], + bool, + ]: + with self.handle_connection(device_id): + response: FullModelTrainResponse = self._get_connection( + device_id + ).FullModelTraining(connection_pb2.FullModelTrainRequest(round_no=round_no)) + return ( + proto_to_weights(response.client_weights), + proto_to_weights(response.server_weights), + response.num_samples, + proto_to_metrics(response.metrics), + self._add_byte_size_to_diagnostic_metrics(response, self.device_id), + ) + return False + + def start_experiment_on(self, device_id: str, wait_for_ready: bool = False) -> bool: + with 
self.handle_connection(device_id): + self._get_connection(device_id).StartExperiment( + connection_pb2.StartExperimentRequest(), wait_for_ready=wait_for_ready + ) + return True + return False + + def end_experiment_on(self, device_id: str) -> bool: + with self.handle_connection(device_id): + self._get_connection(device_id).EndExperiment( + connection_pb2.EndExperimentRequest() + ) + return True + return False + + def get_battery_status_on(self, device_id: str) -> DeviceBatteryStatusReport: + with self.handle_connection(device_id): + response: BatteryStatusResponse = self._get_connection( + device_id + ).GetBatteryStatus(connection_pb2.BatteryStatusRequest()) + return DeviceBatteryStatus( + current_capacity=response.status.current_battery_level, + initial_capacity=response.status.initial_battery_level, + ) + return False + + def train_batch_on_client_only( + self, device_id: str, batch_index: int, round_no: int + ) -> Tuple[Tensor, Tensor] | None: + with self.handle_connection(device_id): + response: SingleBatchTrainingResponse = self._get_connection( + device_id + ).TrainSingleBatchOnClient( + connection_pb2.SingleBatchTrainingRequest( + batch_index=batch_index, round_no=round_no + ) + ) + + # The response can only be None if the last batch was smaller than the configured batch size. + if response.HasField("smashed_data"): + return ( + proto_to_tensor(response.smashed_data.activations), + proto_to_tensor(response.labels.labels), + ) + + return None + return False + + def get_dataset_model_info_on( + self, device_id: str + ) -> Union[Tuple[int, int, int, int, int, int], bool]: + with self.handle_connection(device_id): + response: DatasetModelInfoResponse = self._get_connection( + device_id + ).GetDatasetModelInfo(connection_pb2.DatasetModelInfoRequest()) + return ( + response.train_samples, + response.validation_samples, + response.client_fw_flops, + response.server_fw_flops, + response.client_bw_flops, + response.server_bw_flops, + ) + return False + + def backpropagation_on_client_only(self, device_id, gradients): + with self.handle_connection(device_id): + response: SingleBatchBackwardResponse = self._get_connection( + device_id + ).BackwardPropagationSingleBatchOnClient( + connection_pb2.SingleBatchBackwardRequest( + gradients=Gradients(gradients=tensor_to_proto(gradients)) + ) + ) + return ( + None, + proto_to_tensor(response.gradients.gradients), + ) + return False + + def set_gradient_and_finalize_training_on_client_only( + self, client_id: str, gradients: Any + ): + with self.handle_connection(client_id): + response: Empty = self._get_connection( + client_id + ).SetGradientsAndFinalizeTrainingStep( + connection_pb2.SetGradientsRequest( + gradients=Gradients(gradients=tensor_to_proto(gradients)) + ) + ) + return response + return False diff --git a/edml/core/rpc_device_servicer.py b/edml/core/rpc_device_servicer.py new file mode 100644 index 0000000000000000000000000000000000000000..cbaddec41b5cd8eec2bead163b3119e548218af6 --- /dev/null +++ b/edml/core/rpc_device_servicer.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +from edml.core.device import NetworkDevice +from edml.generated import connection_pb2 +from edml.generated.connection_pb2 import ( + StartExperimentResponse, + EndExperimentResponse, + SingleBatchBackwardRequest, + SetGradientsRequest, +) +from edml.generated.connection_pb2_grpc import DeviceServicer +from edml.generated.datastructures_pb2 import ( + Weights, + Gradients, + BatteryStatus, + Activations, + Labels, + Empty, +) +from edml.helpers.proto_helpers 
import ( + proto_to_state_dict, + state_dict_to_proto, + metrics_to_proto, + proto_to_tensor, + tensor_to_proto, +) + + +class RPCDeviceServicer(DeviceServicer): + """ + Implements the handling for incoming gRPC calls according to the protobuf definition. Added to the gRPC server upon + starting a device. Deserializes protobufs to python objects. + + Attributes: + device (NetworkDevice): This device to receive the calls. + """ + + def __init__(self, device: NetworkDevice): + self.device = device + + def TrainGlobal(self, request, context): + print(f"Called TrainGlobal on device {self.device.device_id}") + client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( + self.device.train_global( + request.epochs, + request.round_no, + request.adaptive_threshold_value, + proto_to_state_dict(request.optimizer_state), + ) + ) + response = connection_pb2.TrainGlobalResponse( + client_weights=Weights(weights=state_dict_to_proto(client_weights)), + server_weights=Weights(weights=state_dict_to_proto(server_weights)), + metrics=metrics_to_proto(metrics), + optimizer_state=state_dict_to_proto(optimizer_state), + diagnostic_metrics=metrics_to_proto(diagnostic_metrics), + ) + return response + + def SetWeights(self, request, context): + print(f"Called SetWeights on device {self.device.device_id}") + weights = proto_to_state_dict(request.weights.weights) + self.device.set_weights(weights, request.on_client) + return connection_pb2.SetWeightsResponse() + + def TrainEpoch(self, request, context): + print(f"Called TrainEpoch on device {self.device.device_id}") + device_info = request.server + device_id = device_info.device_id + round_no = request.round_no + weights, diagnostic_metrics = self.device.train_epoch(device_id, round_no) + proto_weights = state_dict_to_proto(weights) + return connection_pb2.TrainEpochResponse( + weights=Weights(weights=proto_weights), + diagnostic_metrics=metrics_to_proto(diagnostic_metrics), + ) + + def TrainBatch(self, request, context): + activations = proto_to_tensor(request.smashed_data.activations) + labels = proto_to_tensor(request.labels.labels) + gradients, loss, diagnostic_metrics = self.device.train_batch( + activations, labels + ) + proto_gradients = Gradients(gradients=tensor_to_proto(gradients)) + return connection_pb2.TrainBatchResponse( + gradients=proto_gradients, + loss=loss, + diagnostic_metrics=metrics_to_proto(diagnostic_metrics), + ) + + def EvaluateGlobal(self, request, context): + print(f"Called EvaluateGlobal on device {self.device.device_id}") + metrics, diagnostic_metrics = self.device.evaluate_global( + request.validation, request.federated + ) + return connection_pb2.EvalGlobalResponse( + metrics=metrics_to_proto(metrics), + diagnostic_metrics=metrics_to_proto(diagnostic_metrics), + ) + + def Evaluate(self, request, context): + print(f"Called Evaluate on device {self.device.device_id}") + diagnostic_metrics = self.device.evaluate( + request.server.device_id, request.validation + ) + return connection_pb2.EvalResponse( + diagnostic_metrics=metrics_to_proto(diagnostic_metrics) + ) + + def EvaluateBatch(self, request, context): + activations = proto_to_tensor(request.smashed_data.activations) + labels = proto_to_tensor(request.labels.labels) + diagnostic_metrics = self.device.evaluate_batch(activations, labels) + return connection_pb2.EvalBatchResponse( + diagnostic_metrics=metrics_to_proto(diagnostic_metrics) + ) + + def FullModelTraining(self, request, context): + print(f"Called Full Training on device {self.device.device_id}") + 
client_weights, server_weights, num_samples, metrics, diagnostic_metrics = ( + self.device.federated_train(request.round_no) + ) + return connection_pb2.FullModelTrainResponse( + client_weights=Weights(weights=state_dict_to_proto(client_weights)), + server_weights=Weights(weights=state_dict_to_proto(server_weights)), + num_samples=num_samples, + metrics=metrics_to_proto(metrics), + diagnostic_metrics=metrics_to_proto(diagnostic_metrics), + ) + + def StartExperiment(self, request, context) -> StartExperimentResponse: + print(f"Start Experiment on device {self.device.device_id}") + self.device.start_experiment() + return connection_pb2.StartExperimentResponse() + + def EndExperiment(self, request, context) -> EndExperimentResponse: + print(f"End Experiment on device {self.device.device_id}") + print(f"Remaining battery capacity {self.device.battery.remaining_capacity()}") + self.device.end_experiment() + return connection_pb2.EndExperimentResponse() + + def GetBatteryStatus(self, request, context): + print(f"Get Battery Status on device {self.device.device_id}") + initial_capacity, remaining_capacity = self.device.get_battery_status() + + return connection_pb2.BatteryStatusResponse( + status=BatteryStatus( + initial_battery_level=initial_capacity, + current_battery_level=remaining_capacity, + ) + ) + + def GetDatasetModelInfo(self, request, context): + print(f"Get Dataset and Model Info on device {self.device.device_id}") + return connection_pb2.DatasetModelInfoResponse( + train_samples=len(self.device.client._train_data.dataset), + validation_samples=len(self.device.client._val_data.dataset), + client_fw_flops=int(self.device.client._model_flops["FW"]), + server_fw_flops=int(self.device.server._model_flops["FW"]), + client_bw_flops=int(self.device.client._model_flops["BW"]), + server_bw_flops=int(self.device.server._model_flops["BW"]), + ) + + def TrainGlobalParallelSplitLearning(self, request, context): + print(f"Starting parallel split learning") + clients = self.device.__get_device_ids__() + round_no = request.round_no + adaptive_threshold_value = request.adaptive_threshold_value + optimizer_state = proto_to_state_dict(request.optimizer_state) + + cw, sw, model_metrics, optimizer_state, diagnostic_metrics = ( + self.device.train_parallel_split_learning( + clients=clients, + round_no=round_no, + adaptive_threshold_value=adaptive_threshold_value, + optimizer_state=optimizer_state, + ) + ) + response = connection_pb2.TrainGlobalParallelSplitLearningResponse( + client_weights=Weights(weights=state_dict_to_proto(cw)), + server_weights=Weights(weights=state_dict_to_proto(sw)), + metrics=metrics_to_proto(model_metrics), + optimizer_state=state_dict_to_proto(optimizer_state), + diagnostic_metrics=metrics_to_proto(diagnostic_metrics), + ) + return response + + def TrainSingleBatchOnClient(self, request, context): + batch_index = request.batch_index + round_no = request.round_no + + smashed_data, labels = self.device.client.train_single_batch( + batch_index, round_no=round_no + ) + + smashed_data = Activations(activations=tensor_to_proto(smashed_data)) + labels = Labels(labels=tensor_to_proto(labels)) + return connection_pb2.SingleBatchTrainingResponse( + smashed_data=smashed_data, + labels=labels, + ) + + def BackwardPropagationSingleBatchOnClient( + self, request: SingleBatchBackwardRequest, context + ): + gradients = proto_to_tensor(request.gradients.gradients) + + metrics, gradients = self.device.client.backward_single_batch( + gradients=gradients + ) + return 
connection_pb2.SingleBatchBackwardResponse( + metrics=metrics_to_proto(metrics), + gradients=Gradients(gradients=tensor_to_proto(gradients)), + ) + + def SetGradientsAndFinalizeTrainingStep( + self, request: SetGradientsRequest, context + ): + gradients = proto_to_tensor(request.gradients.gradients) + self.device.client.set_gradient_and_finalize_training(gradients=gradients) + return Empty() diff --git a/edml/core/start_device.py b/edml/core/start_device.py index 49df053b5073a334d185ce21bc30c2188a545716..79222554154103605e88a25f3bd21cebec9db7ef 100644 --- a/edml/core/start_device.py +++ b/edml/core/start_device.py @@ -8,7 +8,8 @@ from torch import nn from edml.core.battery import Battery from edml.core.client import DeviceClient -from edml.core.device import NetworkDevice, RPCDeviceServicer +from edml.core.device import NetworkDevice +from edml.core.rpc_device_servicer import RPCDeviceServicer from edml.core.server import DeviceServer from edml.dataset_utils.mnist.mnist import single_batch_dataloaders from edml.generated import connection_pb2_grpc diff --git a/edml/tests/controllers/base_controller_test.py b/edml/tests/controllers/base_controller_test.py index 5238ae39c2e535ac22443d8727a13b0939821e84..f2b01b4ccfa62382fe7bcbccff97e0a5d40a05df 100644 --- a/edml/tests/controllers/base_controller_test.py +++ b/edml/tests/controllers/base_controller_test.py @@ -2,7 +2,7 @@ import unittest from unittest.mock import Mock from edml.controllers.base_controller import BaseController -from edml.core.device import DeviceRequestDispatcher +from edml.core.device_request_dispatcher import DeviceRequestDispatcher from edml.tests.controllers.test_helper import load_sample_config diff --git a/edml/tests/controllers/fed_controller_test.py b/edml/tests/controllers/fed_controller_test.py index 45c4ef900c74a02035674ad7cf96be12bd95a5af..df07b2b0f8cc3724de47216a1fba97e4293ba524 100644 --- a/edml/tests/controllers/fed_controller_test.py +++ b/edml/tests/controllers/fed_controller_test.py @@ -6,7 +6,7 @@ from omegaconf import DictConfig from torch import Tensor from edml.controllers.fed_controller import FedController, fed_average -from edml.core.device import DeviceRequestDispatcher +from edml.core.device_request_dispatcher import DeviceRequestDispatcher from edml.generated.connection_pb2 import FullModelTrainResponse from edml.helpers.metrics import ( ModelMetricResultContainer, diff --git a/edml/tests/controllers/split_controller_test.py b/edml/tests/controllers/split_controller_test.py index 8531a2d1fe0aa146494d1ffb45317739a1c01dc1..c5533a1e38bb490f00ae714fae87fbbcd5b01aca 100644 --- a/edml/tests/controllers/split_controller_test.py +++ b/edml/tests/controllers/split_controller_test.py @@ -2,7 +2,7 @@ import unittest from unittest.mock import Mock from edml.controllers.split_controller import SplitController -from edml.core.device import DeviceRequestDispatcher +from edml.core.device_request_dispatcher import DeviceRequestDispatcher from edml.helpers.metrics import ( ModelMetricResultContainer, DiagnosticMetricResultContainer, diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py index 89ad8aa1fe830bc397ff2831cd037df49ba3b4b7..9aa29a7d64befd7bf85b016bf809152e43487da7 100644 --- a/edml/tests/controllers/swarm_controller_test.py +++ b/edml/tests/controllers/swarm_controller_test.py @@ -3,12 +3,9 @@ from unittest.mock import Mock, call from omegaconf import ListConfig -from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler -from 
edml.controllers.scheduler.random import RandomNextServerScheduler from edml.controllers.scheduler.sequential import SequentialNextServerScheduler -from edml.controllers.scheduler.smart import SmartNextServerScheduler from edml.controllers.swarm_controller import SwarmController -from edml.core.device import DeviceRequestDispatcher +from edml.core.device_request_dispatcher import DeviceRequestDispatcher from edml.helpers.metrics import ( ModelMetricResultContainer, DiagnosticMetricResultContainer, diff --git a/edml/tests/controllers/test_controller_test.py b/edml/tests/controllers/test_controller_test.py index 848071a068d6418f6ac44303f83ada770755bbe9..f63fc5d21afad4de9ae1403f4834eb5a0fa56f63 100644 --- a/edml/tests/controllers/test_controller_test.py +++ b/edml/tests/controllers/test_controller_test.py @@ -2,7 +2,7 @@ import unittest from unittest.mock import patch, Mock, call from edml.controllers.test_controller import TestController -from edml.core.device import DeviceRequestDispatcher +from edml.core.device_request_dispatcher import DeviceRequestDispatcher from edml.tests.controllers.test_helper import load_sample_config diff --git a/edml/tests/core/device_request_dispatcher_test.py b/edml/tests/core/device_request_dispatcher_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a6de7ec4be7a87b578a01088835c9f4b3fd13763 --- /dev/null +++ b/edml/tests/core/device_request_dispatcher_test.py @@ -0,0 +1,538 @@ +import concurrent.futures +import math +import threading +import unittest +from inspect import signature +from unittest.mock import Mock, patch + +import grpc +from omegaconf import DictConfig +from torch import Tensor + +from edml.core.device_request_dispatcher import DeviceRequestDispatcher +from edml.generated import connection_pb2 +from edml.generated.connection_pb2 import ( + SetWeightsResponse, + TrainEpochResponse, + EvalResponse, + EvalBatchResponse, +) +from edml.generated.datastructures_pb2 import ( + DeviceInfo, + Activations, + Labels, + BatteryStatus, +) +from edml.helpers.metrics import ( + ModelMetricResultContainer, + ModelMetricResult, + DiagnosticMetricResultContainer, + DiagnosticMetricResult, +) +from edml.helpers.proto_helpers import ( + weights_to_proto, + metrics_to_proto, + state_dict_to_proto, + gradients_to_proto, + tensor_to_proto, +) +from edml.helpers.types import DeviceBatteryStatus +from edml.tests.controllers.test_helper import get_side_effect + + +class RequestDispatcherTest(unittest.TestCase): + + def setUp(self) -> None: + self.dispatcher = DeviceRequestDispatcher( + [] + ) # pass no devices to avoid grpc calls + # mock connections instead + self.dispatcher.connections["1"] = Mock( + spec=[ + "TrainGlobal", + "SetWeights", + "TrainEpoch", + "TrainBatch", + "EvaluateGlobal", + "Evaluate", + "EvaluateBatch", + "FullModelTraining", + "StartExperiment", + "EndExperiment", + "GetBatteryStatus", + "GetDatasetModelInfo", + ] + ) + self.dispatcher.devices = [ + DictConfig({"device_id": "0", "address": "42"}), # inactive device + DictConfig({"device_id": "1", "address": "43"}), + ] + self.mock_stub = self.dispatcher.connections["1"] # for convenience + self.weights = {"weights": Tensor([42])} + self.gradients = Tensor([42]) + self.activations = Tensor([1]) + self.labels = Tensor([1]) + self.metrics = ModelMetricResultContainer( + [ModelMetricResult("d1", "accuracy", "val", 0.42, 42)] + ) + self.diagnostic_metrics = DiagnosticMetricResultContainer( + [DiagnosticMetricResult("d1", "comp_time", "train", 42)] + ) + + def 
test_train_global_on_without_error(self): + self.mock_stub.TrainGlobal.return_value = connection_pb2.TrainGlobalResponse( + client_weights=weights_to_proto(self.weights), + server_weights=weights_to_proto(self.weights), + metrics=metrics_to_proto(self.metrics), + diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), + optimizer_state=state_dict_to_proto({"optimizer_state": 42}), + ) + + client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( + self.dispatcher.train_global_on("1", 42, 43, 3, {"optimizer_state": 44}) + ) + + self.assertEqual(client_weights, self.weights) + self.assertEqual(server_weights, self.weights) + self.assertEqual(metrics, self.metrics) + self.assertEqual(optimizer_state, {"optimizer_state": 42}) + + self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) + self.mock_stub.TrainGlobal.assert_called_once_with( + connection_pb2.TrainGlobalRequest( + epochs=42, + round_no=43, + adaptive_threshold_value=3, + optimizer_state=state_dict_to_proto({"optimizer_state": 44}), + ) + ) + + def test_train_global_on_with_error(self): + self.mock_stub.TrainGlobal.side_effect = grpc.RpcError() + + response = self.dispatcher.train_global_on( + "1", + 42, + round_no=43, + adaptive_threshold_value=3, + optimizer_state={"optimizer_state": 44}, + ) + + self.assertEqual(response, False) + self.mock_stub.TrainGlobal.assert_called_once_with( + connection_pb2.TrainGlobalRequest( + epochs=42, + round_no=43, + adaptive_threshold_value=3, + optimizer_state=state_dict_to_proto({"optimizer_state": 44}), + ) + ) + + def test_set_weights_on_without_error(self): + self.mock_stub.SetWeights.return_value = SetWeightsResponse() + + self.dispatcher.set_weights_on("1", self.weights, True) + + self.mock_stub.SetWeights.assert_called_once_with( + connection_pb2.SetWeightsRequest( + weights=weights_to_proto(self.weights), on_client=True + ), + wait_for_ready=False, + ) + + def test_set_weights_on_with_error(self): + self.mock_stub.SetWeights.side_effect = grpc.RpcError() + + response = self.dispatcher.set_weights_on("1", self.weights, True) + + self.mock_stub.SetWeights.assert_called_once_with( + connection_pb2.SetWeightsRequest( + weights=weights_to_proto(self.weights), on_client=True + ), + wait_for_ready=False, + ) + self.assertEqual(response, False) + + def test_train_epoch_on_without_error(self): + self.mock_stub.TrainEpoch.return_value = TrainEpochResponse( + weights=weights_to_proto(self.weights), + diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), + ) + + weights, diagnostic_metrics = self.dispatcher.train_epoch_on("1", "0", 42) + + self.assertEqual(weights, self.weights) + self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) + self.mock_stub.TrainEpoch.assert_called_once_with( + connection_pb2.TrainEpochRequest( + server=DeviceInfo(device_id="0", address="42"), round_no=42 + ) + ) + + def test_train_epoch_on_with_error(self): + self.mock_stub.TrainEpoch.side_effect = grpc.RpcError() + + response = self.dispatcher.train_epoch_on("1", "0", 42) + + self.assertEqual(response, False) + self.mock_stub.TrainEpoch.assert_called_once_with( + connection_pb2.TrainEpochRequest( + server=DeviceInfo(device_id="0", address="42"), round_no=42 + ) + ) + + def test_train_batch_on_without_error(self): + self.mock_stub.TrainBatch.return_value = connection_pb2.TrainBatchResponse( + gradients=gradients_to_proto(self.gradients), + diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), + loss=42.0, + ) + + gradients, loss, 
diagnostic_metrics = self.dispatcher.train_batch_on( + "1", self.activations, self.labels + ) + + self.assertEqual(gradients, self.gradients) + self.assertEqual(loss, 42.0) + self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) + self.mock_stub.TrainBatch.assert_called_once_with( + connection_pb2.TrainBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(self.activations)), + labels=Labels(labels=tensor_to_proto(self.labels)), + ) + ) + + def test_train_batch_on_with_error(self): + self.mock_stub.TrainBatch.side_effect = grpc.RpcError() + + response = self.dispatcher.train_batch_on("1", self.activations, self.labels) + + self.assertEqual(response, False) + self.mock_stub.TrainBatch.assert_called_once_with( + connection_pb2.TrainBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(self.activations)), + labels=Labels(labels=tensor_to_proto(self.labels)), + ) + ) + + def test_evaluate_global_on_without_error(self): + self.mock_stub.EvaluateGlobal.return_value = connection_pb2.EvalGlobalResponse( + metrics=metrics_to_proto(self.metrics), + diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), + ) + + metrics, diagnostic_metrics = self.dispatcher.evaluate_global_on( + "1", True, True + ) + + self.assertEqual(metrics, self.metrics) + self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) + self.mock_stub.EvaluateGlobal.assert_called_once_with( + connection_pb2.EvalGlobalRequest(validation=True, federated=True) + ) + + def test_evaluate_global_on_with_error(self): + self.mock_stub.EvaluateGlobal.side_effect = grpc.RpcError() + + response = self.dispatcher.evaluate_global_on("1", True, True) + + self.assertEqual(response, False) + self.mock_stub.EvaluateGlobal.assert_called_once_with( + connection_pb2.EvalGlobalRequest(validation=True, federated=True) + ) + + def test_evaluate_on_without_error(self): + self.mock_stub.Evaluate.return_value = EvalResponse( + diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics) + ) + + diagnostic_metrics = self.dispatcher.evaluate_on("1", "0", True) + + self.assertEqual(diagnostic_metrics, self.diagnostic_metrics) + self.mock_stub.Evaluate.assert_called_once_with( + connection_pb2.EvalRequest( + server=DeviceInfo(device_id="0", address="42"), validation=True + ) + ) + + def test_evaluate_on_with_error(self): + self.mock_stub.Evaluate.side_effect = grpc.RpcError() + + response = self.dispatcher.evaluate_on("1", "0", True) + + self.assertEqual(response, False) + self.mock_stub.Evaluate.assert_called_once_with( + connection_pb2.EvalRequest( + server=DeviceInfo(device_id="0", address="42"), validation=True + ) + ) + + def test_evaluate_batch_on_without_error(self): + self.mock_stub.EvaluateBatch.return_value = EvalBatchResponse( + diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics) + ) + + diagnostic_metrics = self.dispatcher.evaluate_batch_on( + "1", self.activations, self.labels + ) + + self._assert_field_size_added_to_diagnostic_metrics( + diagnostic_metrics + ) # metric field present in response, but not used in practice. 
Thus, a field size is added + self.mock_stub.EvaluateBatch.assert_called_once_with( + connection_pb2.EvalBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(self.activations)), + labels=Labels(labels=tensor_to_proto(self.labels)), + ) + ) + + def test_evaluate_batch_on_with_error(self): + self.mock_stub.EvaluateBatch.side_effect = grpc.RpcError() + + response = self.dispatcher.evaluate_batch_on("1", self.activations, self.labels) + + self.assertEqual(response, False) + self.mock_stub.EvaluateBatch.assert_called_once_with( + connection_pb2.EvalBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(self.activations)), + labels=Labels(labels=tensor_to_proto(self.labels)), + ) + ) + + def test_full_model_training_without_error(self): + self.mock_stub.FullModelTraining.return_value = ( + connection_pb2.FullModelTrainResponse( + client_weights=weights_to_proto(self.weights), + server_weights=weights_to_proto(self.weights), + num_samples=42, + metrics=metrics_to_proto(self.metrics), + diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), + ) + ) + + client_weights, server_weights, num_samples, metrics, diagnostic_metrics = ( + self.dispatcher.federated_train_on("1", 42) + ) + + self.assertEqual(client_weights, self.weights) + self.assertEqual(server_weights, self.weights) + self.assertEqual(num_samples, 42) + self.assertEqual(metrics, self.metrics) + self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) + self.mock_stub.FullModelTraining.assert_called_once_with( + connection_pb2.FullModelTrainRequest(round_no=42) + ) + + def test_full_model_training_with_error(self): + self.mock_stub.FullModelTraining.side_effect = grpc.RpcError() + + response = self.dispatcher.federated_train_on("1", 42) + + self.assertEqual(response, False) + self.mock_stub.FullModelTraining.assert_called_once_with( + connection_pb2.FullModelTrainRequest(round_no=42) + ) + + def test_start_experiment_on_without_error(self): + self.mock_stub.StartExperiment.return_value = ( + connection_pb2.StartExperimentResponse() + ) + + response = self.dispatcher.start_experiment_on("1", True) + + self.assertEqual(response, True) + self.mock_stub.StartExperiment.assert_called_once_with( + connection_pb2.StartExperimentRequest(), wait_for_ready=True + ) + + def test_start_experiment_on_with_error(self): + self.mock_stub.StartExperiment.side_effect = grpc.RpcError() + + response = self.dispatcher.start_experiment_on("1", True) + + self.assertEqual(response, False) + self.mock_stub.StartExperiment.assert_called_once_with( + connection_pb2.StartExperimentRequest(), wait_for_ready=True + ) + + def test_end_experiment_on_without_error(self): + self.mock_stub.EndExperiment.return_value = ( + connection_pb2.EndExperimentResponse() + ) + + response = self.dispatcher.end_experiment_on("1") + + self.assertEqual(response, True) + self.mock_stub.EndExperiment.assert_called_once_with( + connection_pb2.EndExperimentRequest() + ) + + def test_end_experiment_on_with_error(self): + self.mock_stub.EndExperiment.side_effect = grpc.RpcError() + + response = self.dispatcher.end_experiment_on("1") + + self.assertEqual(response, False) + self.mock_stub.EndExperiment.assert_called_once_with( + connection_pb2.EndExperimentRequest() + ) + + def test_get_battery_status_without_error(self): + self.mock_stub.GetBatteryStatus.return_value = ( + connection_pb2.BatteryStatusResponse( + status=BatteryStatus(initial_battery_level=42, current_battery_level=21) + ) + ) + + response = 
self.dispatcher.get_battery_status_on("1")
+
+        self.assertEqual(
+            response, DeviceBatteryStatus(initial_capacity=42, current_capacity=21)
+        )
+        self.mock_stub.GetBatteryStatus.assert_called_once_with(
+            connection_pb2.BatteryStatusRequest()
+        )
+
+    def test_get_battery_status_with_error(self):
+        self.mock_stub.GetBatteryStatus.side_effect = grpc.RpcError()
+
+        response = self.dispatcher.get_battery_status_on("1")
+
+        self.assertEqual(response, False)
+        self.mock_stub.GetBatteryStatus.assert_called_once_with(
+            connection_pb2.BatteryStatusRequest()
+        )
+
+    def test_get_dataset_model_info_without_error(self):
+        self.mock_stub.GetDatasetModelInfo.return_value = (
+            connection_pb2.DatasetModelInfoResponse(
+                train_samples=42,
+                validation_samples=21,
+                client_fw_flops=1,
+                server_fw_flops=2,
+                client_bw_flops=3,
+                server_bw_flops=4,
+            )
+        )
+
+        response = self.dispatcher.get_dataset_model_info_on("1")
+
+        self.assertEqual(response, (42, 21, 1, 2, 3, 4))
+        self.mock_stub.GetDatasetModelInfo.assert_called_once_with(
+            connection_pb2.DatasetModelInfoRequest()
+        )
+
+    def test_get_dataset_model_info_with_error(self):
+        self.mock_stub.GetDatasetModelInfo.side_effect = grpc.RpcError()
+
+        response = self.dispatcher.get_dataset_model_info_on("1")
+
+        self.assertEqual(response, False)
+        self.mock_stub.GetDatasetModelInfo.assert_called_once_with(
+            connection_pb2.DatasetModelInfoRequest()
+        )
+
+    def test_handle_calls_to_inactive_device(self):
+        """Tests that every dispatcher method performing RPC calls handles requests to inactive devices.
+        Covers one case where the device is known but inactive and one where the device is unknown.
+        """
+        methods_names = [
+            "train_global_on",
+            "set_weights_on",
+            "train_epoch_on",
+            "train_batch_on",
+            "evaluate_global_on",
+            "evaluate_on",
+            "evaluate_batch_on",
+            "federated_train_on",
+            "start_experiment_on",
+            "end_experiment_on",
+            "get_battery_status_on",
+            "get_dataset_model_info_on",
+        ]
+        for device in [("0", True), ("2", False)]:  # 0 is inactive, 2 is unknown
+            for method_name in methods_names:
+                with self.subTest(method_name=f"{method_name}_device{device[0]}"):
+                    with patch("builtins.print") as print_patch:
+                        method = getattr(self.dispatcher, method_name)
+                        params = list(signature(method).parameters)
+
+                        # make method call with device id and all other params as None
+                        response = method(
+                            device[0],
+                            *[None for _ in params if _ not in ["self", "device_id"]],
+                        )
+
+                        self.assertEqual(response, False)
+                        self.mock_stub.assert_not_called()
+                        if device[1]:
+                            print_patch.assert_called_once_with(
+                                f"Device {device[0]} not active"
+                            )
+                        else:
+                            print_patch.assert_called_once_with(
+                                f"Unknown Device ID {device[0]}"
+                            )
+
+    def _assert_field_size_added_to_diagnostic_metrics(self, diagnostic_metrics):
+        """Do not use when the response only includes diagnostic metrics, as these are ignored for the field size."""
+        self.assertEqual(
+            diagnostic_metrics.get_as_list()[0],
+            self.diagnostic_metrics.get_as_list()[0],
+        )  # check that previous diagnostic metrics still exist
+        self.assertEqual(
+            diagnostic_metrics.get_as_list()[1].name, "size"
+        )  # check that at least one new diagnostic metric for size was added
+
+
+class RequestDispatcherThreadingTest(unittest.TestCase):
+
+    def test_handle_rpc_error_thread_safety(self):
+
+        # run the procedure with 2 to 20 threads
+        for n in range(2, 21):
+            with self.subTest(method_name=f"Test with {n} threads"):
+                print(f"Test with {n} threads")
+                self.request_dispatcher = DeviceRequestDispatcher([])
+
+                # mock n connections with half of them returning errors and the other half valid responses
+                self.request_dispatcher.connections = {
+                    f"d{i}": Mock(spec=["GetBatteryStatus"]) for i in range(n)
+                }
+                response = connection_pb2.BatteryStatusResponse(
+                    status=BatteryStatus(
+                        initial_battery_level=42, current_battery_level=21
+                    )
+                )
+                for i, mock in enumerate(self.request_dispatcher.connections.values()):
+                    mock.GetBatteryStatus.side_effect = get_side_effect(i, response)
+
+                # make n concurrent calls to the dispatcher
+                responses = []
+                response_lock = threading.Lock()
+                with concurrent.futures.ThreadPoolExecutor(max_workers=n) as executor:
+                    futures = [
+                        executor.submit(
+                            self.request_dispatcher.get_battery_status_on, device_id
+                        )
+                        for device_id in self.request_dispatcher.active_devices()
+                    ]
+                    for future in concurrent.futures.as_completed(futures):
+                        with response_lock:
+                            responses.append(future.result())
+
+                # check that only error connections were removed
+                self.assertEqual(
+                    list(self.request_dispatcher.connections.keys()),
+                    [f"d{i}" for i in range(n) if i % 2 == 0],
+                )
+
+                response_battery_level = DeviceBatteryStatus(
+                    initial_capacity=42, current_capacity=21
+                )
+
+                # check that there are n responses in total: failed calls return False, successful ones the battery status
+                self.assertEqual(responses.count(False), math.floor(n / 2))
+                self.assertEqual(
+                    responses.count(response_battery_level), math.ceil(n / 2)
+                )
diff --git a/edml/tests/core/device_test.py b/edml/tests/core/device_test.py
index 85b0e1f7f7204efa6a66b3c95f89a5c392098855..809c69587676f3024fff22af46e17e1347469e4b 100644
--- a/edml/tests/core/device_test.py
+++ b/edml/tests/core/device_test.py
@@ -1,53 +1,24 @@
-import concurrent.futures
-import math
-import threading
 import unittest
-from inspect import signature
-from unittest.mock import Mock, patch
+from unittest.mock import Mock

-import grpc
-from grpc import StatusCode
-from grpc_testing import server_from_dictionary, strict_real_time
 from omegaconf import DictConfig
 from torch import Tensor

-from edml.core.battery import Battery, BatteryEmptyException
+from edml.core.battery import Battery
 from edml.core.client import DeviceClient
-from edml.core.device import RPCDeviceServicer, NetworkDevice, DeviceRequestDispatcher
+from edml.core.device import NetworkDevice
 from edml.core.server import DeviceServer
-from edml.generated import connection_pb2, datastructures_pb2
+from edml.generated import connection_pb2
 from edml.generated.connection_pb2 import (
-    EvalResponse,
-    EvalBatchResponse,
-    SetWeightsResponse,
     TrainEpochResponse,
 )
 from edml.generated.datastructures_pb2 import (
-    Activations,
-    Labels,
     Weights,
     DeviceInfo,
-    BatteryStatus,
-)
-from edml.helpers.metrics import (
-    ModelMetricResult,
-    ModelMetricResultContainer,
-    DiagnosticMetricResultContainer,
-    DiagnosticMetricResult,
 )
 from edml.helpers.proto_helpers import (
-    tensor_to_proto,
-    proto_to_tensor,
     state_dict_to_proto,
-    proto_to_state_dict,
-    proto_to_weights,
-    weights_to_proto,
-    gradients_to_proto,
-    metrics_to_proto,
-    proto_to_metrics,
 )
-from edml.helpers.types import DeviceBatteryStatus
-from edml.tests.controllers.test_helper import get_side_effect


 class NetworkDeviceTest(unittest.TestCase):
@@ -93,365 +64,6 @@ class NetworkDeviceTest(unittest.TestCase):
         )


-class RPCDeviceServicerTest(unittest.TestCase):
-
-    def setUp(self) -> None:
-        # instantiating NetworkDevice to mock Device(ABC)'s properties
-        # store mock_device reference for later function call assertions
-        self.mock_device = Mock(spec=NetworkDevice(42, None, Mock(Battery)))
-        
self.mock_device.device_id = 42 - - my_servicer = RPCDeviceServicer(device=self.mock_device) - servicers = {connection_pb2.DESCRIPTOR.services_by_name["Device"]: my_servicer} - self.test_server = server_from_dictionary(servicers, strict_real_time()) - - self.metrics = ModelMetricResultContainer( - [ModelMetricResult("d1", "accuracy", "val", 0.42, 42)] - ) - self.diagnostic_metrics = DiagnosticMetricResultContainer( - [DiagnosticMetricResult("d1", "comp_time", "train", 42)] - ) - - def make_call(self, method_name, request): - method = self.test_server.invoke_unary_unary( - method_descriptor=connection_pb2.DESCRIPTOR.services_by_name[ - "Device" - ].methods_by_name[method_name], - invocation_metadata={}, - request=request, - timeout=None, - ) - return method.termination() - - def test_train_global(self): - self.mock_device.train_global.return_value = ( - {"weights": Tensor([42])}, - {"weights": Tensor([43])}, - self.metrics, - {"optimizer_state": 44}, - self.diagnostic_metrics, - ) - request = connection_pb2.TrainGlobalRequest( - epochs=42, - round_no=1, - adaptive_threshold_value=3, - optimizer_state=state_dict_to_proto({"optimizer_state": 42}), - ) - - response, metadata, code, details = self.make_call("TrainGlobal", request) - - self.assertEqual( - ( - proto_to_weights(response.client_weights), - proto_to_weights(response.server_weights), - ), - ({"weights": Tensor([42])}, {"weights": Tensor([43])}), - ) - self.assertEqual(proto_to_metrics(response.metrics), self.metrics) - self.assertEqual( - proto_to_state_dict(response.optimizer_state), {"optimizer_state": 44} - ) - self.assertEqual(code, StatusCode.OK) - self.mock_device.train_global.assert_called_once_with( - 42, 1, 3, {"optimizer_state": 42} - ) - self.assertEqual( - proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics - ) - - def test_set_weights(self): - request = connection_pb2.SetWeightsRequest( - weights=Weights(weights=state_dict_to_proto({"weights": Tensor([42])})), - on_client=True, - ) - - response, metadata, code, details = self.make_call("SetWeights", request) - - self.assertEqual(type(response), type(SetWeightsResponse())) - self.assertEqual(code, StatusCode.OK) - self.mock_device.set_weights.assert_called_once_with( - {"weights": Tensor([42])}, True - ) - - def test_train_epoch(self): - self.mock_device.train_epoch.return_value = { - "weights": Tensor([42]) - }, self.diagnostic_metrics - request = connection_pb2.TrainEpochRequest( - server=datastructures_pb2.DeviceInfo(device_id="42", address="") - ) - - response, metadata, code, details = self.make_call("TrainEpoch", request) - - self.assertEqual( - proto_to_state_dict(response.weights.weights), {"weights": Tensor([42])} - ) - self.assertEqual(code, StatusCode.OK) - self.mock_device.train_epoch.assert_called_once_with("42", 0) - self.assertEqual( - proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics - ) - - def test_train_batch(self): - self.mock_device.train_batch.return_value = ( - Tensor([42]), - 42.0, - self.diagnostic_metrics, - ) - request = connection_pb2.TrainBatchRequest( - smashed_data=Activations(activations=tensor_to_proto(Tensor([1.0]))), - labels=Labels(labels=tensor_to_proto(Tensor([1]))), - ) - - response, metadata, code, details = self.make_call("TrainBatch", request) - - self.assertEqual(code, StatusCode.OK) - self.mock_device.train_batch.assert_called_once_with(Tensor([1.0]), Tensor([1])) - self.assertEqual(proto_to_tensor(response.gradients.gradients), Tensor([42])) - self.assertEqual( - 
proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics - ) - - def test_evaluate_global(self): - self.mock_device.evaluate_global.return_value = ( - self.metrics, - self.diagnostic_metrics, - ) - request = connection_pb2.EvalGlobalRequest(validation=False, federated=False) - - response, metadata, code, details = self.make_call("EvaluateGlobal", request) - - self.assertEqual(proto_to_metrics(response.metrics), self.metrics) - self.assertEqual(code, StatusCode.OK) - self.mock_device.evaluate_global.assert_called_once_with(False, False) - self.assertEqual( - proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics - ) - - def test_evaluate(self): - self.mock_device.evaluate.return_value = self.diagnostic_metrics - request = connection_pb2.EvalRequest( - server=DeviceInfo(device_id="42", address=""), validation=True - ) - - response, metadata, code, details = self.make_call("Evaluate", request) - - self.assertEqual(type(response), type(EvalResponse())) - self.assertEqual(code, StatusCode.OK) - self.mock_device.evaluate.assert_called_once_with("42", True) - self.assertEqual( - proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics - ) - - def test_evaluate_batch(self): - self.mock_device.evaluate_batch.return_value = self.diagnostic_metrics - request = connection_pb2.EvalBatchRequest( - smashed_data=Activations(activations=tensor_to_proto(Tensor([1.0]))), - labels=Labels(labels=tensor_to_proto(Tensor([1]))), - ) - - response, metadata, code, details = self.make_call("EvaluateBatch", request) - - self.assertEqual(type(response), type(EvalBatchResponse())) - self.assertEqual(code, StatusCode.OK) - self.mock_device.evaluate_batch.assert_called_once_with( - Tensor([1.0]), Tensor([1]) - ) - self.assertEqual( - proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics - ) - - def test_full_model_training(self): - self.mock_device.federated_train.return_value = ( - {"weights": Tensor([42])}, - {"weights": Tensor([43])}, - 44, - self.metrics, - self.diagnostic_metrics, - ) - request = connection_pb2.FullModelTrainRequest() - - response, metadata, code, details = self.make_call("FullModelTraining", request) - - self.assertEqual( - proto_to_state_dict(response.client_weights.weights), - {"weights": Tensor([42])}, - ) - self.assertEqual( - proto_to_state_dict(response.server_weights.weights), - {"weights": Tensor([43])}, - ) - self.assertEqual(response.num_samples, 44) - self.assertEqual(proto_to_metrics(response.metrics), self.metrics) - self.assertEqual(code, StatusCode.OK) - self.mock_device.federated_train.assert_called_once() - - def test_start_experiment(self): - self.mock_device.start_experiment.return_value = None - request = connection_pb2.StartExperimentRequest() - - response, metadata, code, details = self.make_call("StartExperiment", request) - - self.assertEqual(type(response), type(connection_pb2.StartExperimentResponse())) - self.assertEqual(code, StatusCode.OK) - self.mock_device.start_experiment.assert_called_once() - - def test_end_experiment(self): - self.mock_device.end_experiment.return_value = None - request = connection_pb2.EndExperimentRequest() - - response, metadata, code, details = self.make_call("EndExperiment", request) - - self.assertEqual(type(response), type(connection_pb2.EndExperimentResponse())) - self.assertEqual(code, StatusCode.OK) - self.mock_device.end_experiment.assert_called_once() - - def test_battery_status(self): - self.mock_device.get_battery_status.return_value = (42, 21) - request = 
connection_pb2.BatteryStatusRequest() - - response, metadata, code, details = self.make_call("GetBatteryStatus", request) - - self.assertEqual( - response.status, - BatteryStatus(initial_battery_level=42, current_battery_level=21), - ) - self.assertEqual(code, StatusCode.OK) - self.mock_device.get_battery_status.assert_called_once() - - def test_dataset_model_info(self): - self.mock_device.client._train_data.dataset = [1] - self.mock_device.client._val_data.dataset = [2] - self.mock_device.client._model_flops = {"FW": 3, "BW": 6} - self.mock_device.server._model_flops = {"FW": 4, "BW": 8} - request = connection_pb2.DatasetModelInfoRequest() - - response, metadata, code, details = self.make_call( - "GetDatasetModelInfo", request - ) - - self.assertEqual(code, StatusCode.OK) - self.assertEqual(response.train_samples, 1) - self.assertEqual(response.validation_samples, 1) - self.assertEqual(response.client_fw_flops, 3) - self.assertEqual(response.server_fw_flops, 4) - self.assertEqual(response.client_bw_flops, 6) - self.assertEqual(response.server_bw_flops, 8) - - -class RPCDeviceServicerBatteryEmptyTest(unittest.TestCase): - - def setUp(self) -> None: - # instantiating NetworkDevice to mock Device(ABC)'s properties - # store mock_device reference for later function call assertions - self.mock_device = Mock(spec=NetworkDevice(42, None, Mock(Battery))) - my_servicer = RPCDeviceServicer(device=self.mock_device) - servicers = {connection_pb2.DESCRIPTOR.services_by_name["Device"]: my_servicer} - self.test_server = server_from_dictionary(servicers, strict_real_time()) - - def make_call(self, method_name, request): - method = self.test_server.invoke_unary_unary( - method_descriptor=connection_pb2.DESCRIPTOR.services_by_name[ - "Device" - ].methods_by_name[method_name], - invocation_metadata={}, - request=request, - timeout=None, - ) - return method.termination() - - def test_stop_at_train_global(self): - self.mock_device.train_global.side_effect = BatteryEmptyException( - "Battery empty" - ) - request = connection_pb2.TrainGlobalRequest( - optimizer_state=state_dict_to_proto(None) - ) - self._test_device_out_of_battery_lets_rpc_fail(request, "TrainGlobal") - - def test_stop_at_set_weights(self): - self.mock_device.set_weights.side_effect = BatteryEmptyException( - "Battery empty" - ) - request = connection_pb2.SetWeightsRequest(weights=weights_to_proto({})) - self._test_device_out_of_battery_lets_rpc_fail(request, "SetWeights") - - def test_stop_at_train_epoch(self): - self.mock_device.train_epoch.side_effect = BatteryEmptyException( - "Battery empty" - ) - request = connection_pb2.TrainEpochRequest(server=None) - self._test_device_out_of_battery_lets_rpc_fail(request, "TrainEpoch") - - def test_stop_at_train_batch(self): - self.mock_device.train_batch.side_effect = BatteryEmptyException( - "Battery empty" - ) - request = connection_pb2.TrainBatchRequest( - smashed_data=Activations(activations=tensor_to_proto(Tensor([1]))), - labels=Labels(labels=tensor_to_proto(Tensor([1]))), - ) - self._test_device_out_of_battery_lets_rpc_fail(request, "TrainBatch") - - def test_stop_at_evaluate_global(self): - self.mock_device.evaluate_global.side_effect = BatteryEmptyException( - "Battery empty" - ) - request = connection_pb2.EvalGlobalRequest() - self._test_device_out_of_battery_lets_rpc_fail(request, "EvaluateGlobal") - - def test_stop_at_evaluate(self): - self.mock_device.evaluate.side_effect = BatteryEmptyException("Battery empty") - request = connection_pb2.EvalRequest(server=None) - 
self._test_device_out_of_battery_lets_rpc_fail(request, "Evaluate") - - def test_stop_at_evaluate_batch(self): - self.mock_device.evaluate_batch.side_effect = BatteryEmptyException( - "Battery empty" - ) - request = connection_pb2.EvalBatchRequest( - smashed_data=Activations(activations=tensor_to_proto(Tensor([1]))), - labels=Labels(labels=tensor_to_proto(Tensor([1]))), - ) - self._test_device_out_of_battery_lets_rpc_fail(request, "EvaluateBatch") - - def test_stop_at_full_model_training(self): - self.mock_device.federated_train.side_effect = BatteryEmptyException( - "Battery empty" - ) - request = connection_pb2.FullModelTrainRequest() - self._test_device_out_of_battery_lets_rpc_fail(request, "FullModelTraining") - - def test_stop_at_start_experiment(self): - self.mock_device.start_experiment.side_effect = BatteryEmptyException( - "Battery empty" - ) - request = connection_pb2.StartExperimentRequest() - self._test_device_out_of_battery_lets_rpc_fail(request, "StartExperiment") - - def test_stop_at_end_experiment(self): - self.mock_device.end_experiment.side_effect = BatteryEmptyException( - "Battery empty" - ) - request = connection_pb2.EndExperimentRequest() - self._test_device_out_of_battery_lets_rpc_fail(request, "EndExperiment") - - def test_stop_at_get_battery_status(self): - self.mock_device.get_battery_status.side_effect = BatteryEmptyException( - "Battery empty" - ) - request = connection_pb2.BatteryStatusRequest() - self._test_device_out_of_battery_lets_rpc_fail(request, "GetBatteryStatus") - - def _test_device_out_of_battery_lets_rpc_fail(self, request, servicer_method_name): - response, metadata, code, details = self.make_call( - servicer_method_name, request - ) - self.assertIsNone(response) - self.assertEqual(code, StatusCode.UNKNOWN) - self.assertEqual(details, "Exception calling application: Battery empty") - - class BatteryUpdateTest(unittest.TestCase): def setUp(self): @@ -462,501 +74,3 @@ class BatteryUpdateTest(unittest.TestCase): device.set_client(Mock(DeviceClient)) device.train_epoch(None) self.battery.update_time.assert_called() - - -class RequestDispatcherTest(unittest.TestCase): - - def setUp(self) -> None: - self.dispatcher = DeviceRequestDispatcher( - [] - ) # pass no devices to avoid grpc calls - # mock connections instead - self.dispatcher.connections["1"] = Mock( - spec=[ - "TrainGlobal", - "SetWeights", - "TrainEpoch", - "TrainBatch", - "EvaluateGlobal", - "Evaluate", - "EvaluateBatch", - "FullModelTraining", - "StartExperiment", - "EndExperiment", - "GetBatteryStatus", - "GetDatasetModelInfo", - ] - ) - self.dispatcher.devices = [ - DictConfig({"device_id": "0", "address": "42"}), # inactive device - DictConfig({"device_id": "1", "address": "43"}), - ] - self.mock_stub = self.dispatcher.connections["1"] # for convenience - self.weights = {"weights": Tensor([42])} - self.gradients = Tensor([42]) - self.activations = Tensor([1]) - self.labels = Tensor([1]) - self.metrics = ModelMetricResultContainer( - [ModelMetricResult("d1", "accuracy", "val", 0.42, 42)] - ) - self.diagnostic_metrics = DiagnosticMetricResultContainer( - [DiagnosticMetricResult("d1", "comp_time", "train", 42)] - ) - - def test_train_global_on_without_error(self): - self.mock_stub.TrainGlobal.return_value = connection_pb2.TrainGlobalResponse( - client_weights=weights_to_proto(self.weights), - server_weights=weights_to_proto(self.weights), - metrics=metrics_to_proto(self.metrics), - diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), - 
optimizer_state=state_dict_to_proto({"optimizer_state": 42}), - ) - - client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( - self.dispatcher.train_global_on("1", 42, 43, 3, {"optimizer_state": 44}) - ) - - self.assertEqual(client_weights, self.weights) - self.assertEqual(server_weights, self.weights) - self.assertEqual(metrics, self.metrics) - self.assertEqual(optimizer_state, {"optimizer_state": 42}) - - self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) - self.mock_stub.TrainGlobal.assert_called_once_with( - connection_pb2.TrainGlobalRequest( - epochs=42, - round_no=43, - adaptive_threshold_value=3, - optimizer_state=state_dict_to_proto({"optimizer_state": 44}), - ) - ) - - def test_train_global_on_with_error(self): - self.mock_stub.TrainGlobal.side_effect = grpc.RpcError() - - response = self.dispatcher.train_global_on( - "1", - 42, - round_no=43, - adaptive_threshold_value=3, - optimizer_state={"optimizer_state": 44}, - ) - - self.assertEqual(response, False) - self.mock_stub.TrainGlobal.assert_called_once_with( - connection_pb2.TrainGlobalRequest( - epochs=42, - round_no=43, - adaptive_threshold_value=3, - optimizer_state=state_dict_to_proto({"optimizer_state": 44}), - ) - ) - - def test_set_weights_on_without_error(self): - self.mock_stub.SetWeights.return_value = SetWeightsResponse() - - self.dispatcher.set_weights_on("1", self.weights, True) - - self.mock_stub.SetWeights.assert_called_once_with( - connection_pb2.SetWeightsRequest( - weights=weights_to_proto(self.weights), on_client=True - ), - wait_for_ready=False, - ) - - def test_set_weights_on_with_error(self): - self.mock_stub.SetWeights.side_effect = grpc.RpcError() - - response = self.dispatcher.set_weights_on("1", self.weights, True) - - self.mock_stub.SetWeights.assert_called_once_with( - connection_pb2.SetWeightsRequest( - weights=weights_to_proto(self.weights), on_client=True - ), - wait_for_ready=False, - ) - self.assertEqual(response, False) - - def test_train_epoch_on_without_error(self): - self.mock_stub.TrainEpoch.return_value = TrainEpochResponse( - weights=weights_to_proto(self.weights), - diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), - ) - - weights, diagnostic_metrics = self.dispatcher.train_epoch_on("1", "0", 42) - - self.assertEqual(weights, self.weights) - self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) - self.mock_stub.TrainEpoch.assert_called_once_with( - connection_pb2.TrainEpochRequest( - server=DeviceInfo(device_id="0", address="42"), round_no=42 - ) - ) - - def test_train_epoch_on_with_error(self): - self.mock_stub.TrainEpoch.side_effect = grpc.RpcError() - - response = self.dispatcher.train_epoch_on("1", "0", 42) - - self.assertEqual(response, False) - self.mock_stub.TrainEpoch.assert_called_once_with( - connection_pb2.TrainEpochRequest( - server=DeviceInfo(device_id="0", address="42"), round_no=42 - ) - ) - - def test_train_batch_on_without_error(self): - self.mock_stub.TrainBatch.return_value = connection_pb2.TrainBatchResponse( - gradients=gradients_to_proto(self.gradients), - diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), - loss=42.0, - ) - - gradients, loss, diagnostic_metrics = self.dispatcher.train_batch_on( - "1", self.activations, self.labels - ) - - self.assertEqual(gradients, self.gradients) - self.assertEqual(loss, 42.0) - self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) - self.mock_stub.TrainBatch.assert_called_once_with( - connection_pb2.TrainBatchRequest( - 
smashed_data=Activations(activations=tensor_to_proto(self.activations)), - labels=Labels(labels=tensor_to_proto(self.labels)), - ) - ) - - def test_train_batch_on_with_error(self): - self.mock_stub.TrainBatch.side_effect = grpc.RpcError() - - response = self.dispatcher.train_batch_on("1", self.activations, self.labels) - - self.assertEqual(response, False) - self.mock_stub.TrainBatch.assert_called_once_with( - connection_pb2.TrainBatchRequest( - smashed_data=Activations(activations=tensor_to_proto(self.activations)), - labels=Labels(labels=tensor_to_proto(self.labels)), - ) - ) - - def test_evaluate_global_on_without_error(self): - self.mock_stub.EvaluateGlobal.return_value = connection_pb2.EvalGlobalResponse( - metrics=metrics_to_proto(self.metrics), - diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), - ) - - metrics, diagnostic_metrics = self.dispatcher.evaluate_global_on( - "1", True, True - ) - - self.assertEqual(metrics, self.metrics) - self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) - self.mock_stub.EvaluateGlobal.assert_called_once_with( - connection_pb2.EvalGlobalRequest(validation=True, federated=True) - ) - - def test_evaluate_global_on_with_error(self): - self.mock_stub.EvaluateGlobal.side_effect = grpc.RpcError() - - response = self.dispatcher.evaluate_global_on("1", True, True) - - self.assertEqual(response, False) - self.mock_stub.EvaluateGlobal.assert_called_once_with( - connection_pb2.EvalGlobalRequest(validation=True, federated=True) - ) - - def test_evaluate_on_without_error(self): - self.mock_stub.Evaluate.return_value = EvalResponse( - diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics) - ) - - diagnostic_metrics = self.dispatcher.evaluate_on("1", "0", True) - - self.assertEqual(diagnostic_metrics, self.diagnostic_metrics) - self.mock_stub.Evaluate.assert_called_once_with( - connection_pb2.EvalRequest( - server=DeviceInfo(device_id="0", address="42"), validation=True - ) - ) - - def test_evaluate_on_with_error(self): - self.mock_stub.Evaluate.side_effect = grpc.RpcError() - - response = self.dispatcher.evaluate_on("1", "0", True) - - self.assertEqual(response, False) - self.mock_stub.Evaluate.assert_called_once_with( - connection_pb2.EvalRequest( - server=DeviceInfo(device_id="0", address="42"), validation=True - ) - ) - - def test_evaluate_batch_on_without_error(self): - self.mock_stub.EvaluateBatch.return_value = EvalBatchResponse( - diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics) - ) - - diagnostic_metrics = self.dispatcher.evaluate_batch_on( - "1", self.activations, self.labels - ) - - self._assert_field_size_added_to_diagnostic_metrics( - diagnostic_metrics - ) # metric field present in response, but not used in practice. 
Thus, a field size is added - self.mock_stub.EvaluateBatch.assert_called_once_with( - connection_pb2.EvalBatchRequest( - smashed_data=Activations(activations=tensor_to_proto(self.activations)), - labels=Labels(labels=tensor_to_proto(self.labels)), - ) - ) - - def test_evaluate_batch_on_with_error(self): - self.mock_stub.EvaluateBatch.side_effect = grpc.RpcError() - - response = self.dispatcher.evaluate_batch_on("1", self.activations, self.labels) - - self.assertEqual(response, False) - self.mock_stub.EvaluateBatch.assert_called_once_with( - connection_pb2.EvalBatchRequest( - smashed_data=Activations(activations=tensor_to_proto(self.activations)), - labels=Labels(labels=tensor_to_proto(self.labels)), - ) - ) - - def test_full_model_training_without_error(self): - self.mock_stub.FullModelTraining.return_value = ( - connection_pb2.FullModelTrainResponse( - client_weights=weights_to_proto(self.weights), - server_weights=weights_to_proto(self.weights), - num_samples=42, - metrics=metrics_to_proto(self.metrics), - diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics), - ) - ) - - client_weights, server_weights, num_samples, metrics, diagnostic_metrics = ( - self.dispatcher.federated_train_on("1", 42) - ) - - self.assertEqual(client_weights, self.weights) - self.assertEqual(server_weights, self.weights) - self.assertEqual(num_samples, 42) - self.assertEqual(metrics, self.metrics) - self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) - self.mock_stub.FullModelTraining.assert_called_once_with( - connection_pb2.FullModelTrainRequest(round_no=42) - ) - - def test_full_model_training_with_error(self): - self.mock_stub.FullModelTraining.side_effect = grpc.RpcError() - - response = self.dispatcher.federated_train_on("1", 42) - - self.assertEqual(response, False) - self.mock_stub.FullModelTraining.assert_called_once_with( - connection_pb2.FullModelTrainRequest(round_no=42) - ) - - def test_start_experiment_on_without_error(self): - self.mock_stub.StartExperiment.return_value = ( - connection_pb2.StartExperimentResponse() - ) - - response = self.dispatcher.start_experiment_on("1", True) - - self.assertEqual(response, True) - self.mock_stub.StartExperiment.assert_called_once_with( - connection_pb2.StartExperimentRequest(), wait_for_ready=True - ) - - def test_start_experiment_on_with_error(self): - self.mock_stub.StartExperiment.side_effect = grpc.RpcError() - - response = self.dispatcher.start_experiment_on("1", True) - - self.assertEqual(response, False) - self.mock_stub.StartExperiment.assert_called_once_with( - connection_pb2.StartExperimentRequest(), wait_for_ready=True - ) - - def test_end_experiment_on_without_error(self): - self.mock_stub.EndExperiment.return_value = ( - connection_pb2.EndExperimentResponse() - ) - - response = self.dispatcher.end_experiment_on("1") - - self.assertEqual(response, True) - self.mock_stub.EndExperiment.assert_called_once_with( - connection_pb2.EndExperimentRequest() - ) - - def test_end_experiment_on_with_error(self): - self.mock_stub.EndExperiment.side_effect = grpc.RpcError() - - response = self.dispatcher.end_experiment_on("1") - - self.assertEqual(response, False) - self.mock_stub.EndExperiment.assert_called_once_with( - connection_pb2.EndExperimentRequest() - ) - - def test_get_battery_status_without_error(self): - self.mock_stub.GetBatteryStatus.return_value = ( - connection_pb2.BatteryStatusResponse( - status=BatteryStatus(initial_battery_level=42, current_battery_level=21) - ) - ) - - response = 
self.dispatcher.get_battery_status_on("1") - - self.assertEqual( - response, DeviceBatteryStatus(initial_capacity=42, current_capacity=21) - ) - self.mock_stub.GetBatteryStatus.assert_called_once_with( - connection_pb2.BatteryStatusRequest() - ) - - def test_get_battery_status_with_error(self): - self.mock_stub.GetBatteryStatus.side_effect = grpc.RpcError() - - response = self.dispatcher.get_battery_status_on("1") - - self.assertEqual(response, False) - self.mock_stub.GetBatteryStatus.assert_called_once_with( - connection_pb2.BatteryStatusRequest() - ) - - def test_get_dataset_model_info_without_error(self): - self.mock_stub.GetDatasetModelInfo.return_value = ( - connection_pb2.DatasetModelInfoResponse( - train_samples=42, - validation_samples=21, - client_fw_flops=1, - server_fw_flops=2, - client_bw_flops=3, - server_bw_flops=4, - ) - ) - - response = self.dispatcher.get_dataset_model_info_on("1") - - self.assertEqual(response, (42, 21, 1, 2, 3, 4)) - self.mock_stub.GetDatasetModelInfo.assert_called_once_with( - connection_pb2.DatasetModelInfoRequest() - ) - - def test_get_dataset_model_info_with_error(self): - self.mock_stub.GetDatasetModelInfo.side_effect = grpc.RpcError() - - response = self.dispatcher.get_dataset_model_info_on("1") - - self.assertEqual(response, False) - self.mock_stub.GetDatasetModelInfo.assert_called_once_with( - connection_pb2.DatasetModelInfoRequest() - ) - - def test_handle_calls_to_inactive_device(self): - """Test each method of the dispatcher that does RPC calls to handle calls to inactive devices. - One test where the device is known, but inactive and one where the device is unknown. - """ - methods_names = [ - "train_global_on", - "set_weights_on", - "train_epoch_on", - "train_batch_on", - "evaluate_global_on", - "evaluate_on", - "evaluate_batch_on", - "federated_train_on", - "start_experiment_on", - "end_experiment_on", - "get_battery_status_on", - "get_dataset_model_info_on", - ] - for device in [("0", True), ("2", False)]: # 0 is inactive, 2 is unknown - for method_name in methods_names: - with self.subTest(method_name=f"{method_name}_device{device[0]}"): - with patch("builtins.print") as print_patch: - method = getattr(self.dispatcher, method_name) - params = list(signature(method).parameters) - - # make method call with device id and all other params as None - response = method( - device[0], - *[None for _ in params if _ not in ["self", "device_id"]], - ) - - self.assertEqual(response, False) - self.mock_stub.assert_not_called() - if device[1]: - print_patch.assert_called_once_with( - f"Device {device[0]} not active" - ) - else: - print_patch.assert_called_once_with( - f"Unknown Device ID {device[0]}" - ) - - def _assert_field_size_added_to_diagnostic_metrics(self, diagnostic_metrics): - """Not to use when response only includes diagnostic metrics, as these are ignored for the field size""" - self.assertEqual( - diagnostic_metrics.get_as_list()[0], - self.diagnostic_metrics.get_as_list()[0], - ) # check that previous diagnostic metrics still exist - self.assertEqual( - diagnostic_metrics.get_as_list()[1].name, "size" - ) # check that at least one new diagnostic metric for size was added - - -class RequestDispatcherThreadingTest(unittest.TestCase): - - def test_handle_rpc_error_thread_safety(self): - - # run the procedure with 2 to 20 threads - for n in range(2, 21): - with self.subTest(method_name=f"Test with {n} threads"): - print(f"Test with {n} threads") - self.request_dispatcher = DeviceRequestDispatcher([]) - - # mock n connections with half of them 
returning errors and the other half valid responses - self.request_dispatcher.connections = { - f"d{i}": Mock(spec=["GetBatteryStatus"]) for i in range(n) - } - response = connection_pb2.BatteryStatusResponse( - status=BatteryStatus( - initial_battery_level=42, current_battery_level=21 - ) - ) - for i, mock in enumerate(self.request_dispatcher.connections.values()): - mock.GetBatteryStatus.side_effect = get_side_effect(i, response) - - # make n concurrent calls to the dispatcher - responses = [] - response_lock = threading.Lock() - with concurrent.futures.ThreadPoolExecutor(max_workers=n) as executor: - futures = [ - executor.submit( - self.request_dispatcher.get_battery_status_on, device_id - ) - for device_id in self.request_dispatcher.active_devices() - ] - for future in concurrent.futures.as_completed(futures): - with response_lock: - responses.append(future.result()) - - # check that only error connections were removed - self.assertEqual( - list(self.request_dispatcher.connections.keys()), - [f"d{i}" for i in range(n) if i % 2 == 0], - ) - - response_battery_level = DeviceBatteryStatus( - initial_capacity=42, current_capacity=21 - ) - - # check that there are n responses in total, half of them valid and half of them None - self.assertEqual(responses.count(False), math.floor(n / 2)) - self.assertEqual( - responses.count(response_battery_level), math.ceil(n / 2) - ) diff --git a/edml/tests/core/rpc_device_servicer_test.py b/edml/tests/core/rpc_device_servicer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d392f3e8318e7417b1185405dda793a9fda3247d --- /dev/null +++ b/edml/tests/core/rpc_device_servicer_test.py @@ -0,0 +1,397 @@ +import unittest +from unittest.mock import Mock + +from grpc import StatusCode +from grpc_testing import server_from_dictionary, strict_real_time +from torch import Tensor + +from edml.core.battery import Battery, BatteryEmptyException +from edml.core.device import NetworkDevice +from edml.core.rpc_device_servicer import RPCDeviceServicer +from edml.generated import connection_pb2, datastructures_pb2 +from edml.generated.connection_pb2 import ( + SetWeightsResponse, + EvalResponse, + EvalBatchResponse, +) +from edml.generated.datastructures_pb2 import ( + Weights, + Activations, + Labels, + DeviceInfo, + BatteryStatus, +) +from edml.helpers.metrics import ( + ModelMetricResultContainer, + ModelMetricResult, + DiagnosticMetricResultContainer, + DiagnosticMetricResult, +) +from edml.helpers.proto_helpers import ( + state_dict_to_proto, + proto_to_weights, + proto_to_metrics, + proto_to_state_dict, + tensor_to_proto, + proto_to_tensor, + weights_to_proto, +) + + +class RPCDeviceServicerTest(unittest.TestCase): + + def setUp(self) -> None: + # instantiating NetworkDevice to mock Device(ABC)'s properties + # store mock_device reference for later function call assertions + self.mock_device = Mock(spec=NetworkDevice(42, None, Mock(Battery))) + self.mock_device.device_id = 42 + + my_servicer = RPCDeviceServicer(device=self.mock_device) + servicers = {connection_pb2.DESCRIPTOR.services_by_name["Device"]: my_servicer} + self.test_server = server_from_dictionary(servicers, strict_real_time()) + + self.metrics = ModelMetricResultContainer( + [ModelMetricResult("d1", "accuracy", "val", 0.42, 42)] + ) + self.diagnostic_metrics = DiagnosticMetricResultContainer( + [DiagnosticMetricResult("d1", "comp_time", "train", 42)] + ) + + def make_call(self, method_name, request): + method = self.test_server.invoke_unary_unary( + 
method_descriptor=connection_pb2.DESCRIPTOR.services_by_name[ + "Device" + ].methods_by_name[method_name], + invocation_metadata={}, + request=request, + timeout=None, + ) + return method.termination() + + def test_train_global(self): + self.mock_device.train_global.return_value = ( + {"weights": Tensor([42])}, + {"weights": Tensor([43])}, + self.metrics, + {"optimizer_state": 44}, + self.diagnostic_metrics, + ) + request = connection_pb2.TrainGlobalRequest( + epochs=42, + round_no=1, + adaptive_threshold_value=3, + optimizer_state=state_dict_to_proto({"optimizer_state": 42}), + ) + + response, metadata, code, details = self.make_call("TrainGlobal", request) + + self.assertEqual( + ( + proto_to_weights(response.client_weights), + proto_to_weights(response.server_weights), + ), + ({"weights": Tensor([42])}, {"weights": Tensor([43])}), + ) + self.assertEqual(proto_to_metrics(response.metrics), self.metrics) + self.assertEqual( + proto_to_state_dict(response.optimizer_state), {"optimizer_state": 44} + ) + self.assertEqual(code, StatusCode.OK) + self.mock_device.train_global.assert_called_once_with( + 42, 1, 3, {"optimizer_state": 42} + ) + self.assertEqual( + proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics + ) + + def test_set_weights(self): + request = connection_pb2.SetWeightsRequest( + weights=Weights(weights=state_dict_to_proto({"weights": Tensor([42])})), + on_client=True, + ) + + response, metadata, code, details = self.make_call("SetWeights", request) + + self.assertEqual(type(response), type(SetWeightsResponse())) + self.assertEqual(code, StatusCode.OK) + self.mock_device.set_weights.assert_called_once_with( + {"weights": Tensor([42])}, True + ) + + def test_train_epoch(self): + self.mock_device.train_epoch.return_value = { + "weights": Tensor([42]) + }, self.diagnostic_metrics + request = connection_pb2.TrainEpochRequest( + server=datastructures_pb2.DeviceInfo(device_id="42", address="") + ) + + response, metadata, code, details = self.make_call("TrainEpoch", request) + + self.assertEqual( + proto_to_state_dict(response.weights.weights), {"weights": Tensor([42])} + ) + self.assertEqual(code, StatusCode.OK) + self.mock_device.train_epoch.assert_called_once_with("42", 0) + self.assertEqual( + proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics + ) + + def test_train_batch(self): + self.mock_device.train_batch.return_value = ( + Tensor([42]), + 42.0, + self.diagnostic_metrics, + ) + request = connection_pb2.TrainBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(Tensor([1.0]))), + labels=Labels(labels=tensor_to_proto(Tensor([1]))), + ) + + response, metadata, code, details = self.make_call("TrainBatch", request) + + self.assertEqual(code, StatusCode.OK) + self.mock_device.train_batch.assert_called_once_with(Tensor([1.0]), Tensor([1])) + self.assertEqual(proto_to_tensor(response.gradients.gradients), Tensor([42])) + self.assertEqual( + proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics + ) + + def test_evaluate_global(self): + self.mock_device.evaluate_global.return_value = ( + self.metrics, + self.diagnostic_metrics, + ) + request = connection_pb2.EvalGlobalRequest(validation=False, federated=False) + + response, metadata, code, details = self.make_call("EvaluateGlobal", request) + + self.assertEqual(proto_to_metrics(response.metrics), self.metrics) + self.assertEqual(code, StatusCode.OK) + self.mock_device.evaluate_global.assert_called_once_with(False, False) + self.assertEqual( + 
proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics + ) + + def test_evaluate(self): + self.mock_device.evaluate.return_value = self.diagnostic_metrics + request = connection_pb2.EvalRequest( + server=DeviceInfo(device_id="42", address=""), validation=True + ) + + response, metadata, code, details = self.make_call("Evaluate", request) + + self.assertEqual(type(response), type(EvalResponse())) + self.assertEqual(code, StatusCode.OK) + self.mock_device.evaluate.assert_called_once_with("42", True) + self.assertEqual( + proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics + ) + + def test_evaluate_batch(self): + self.mock_device.evaluate_batch.return_value = self.diagnostic_metrics + request = connection_pb2.EvalBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(Tensor([1.0]))), + labels=Labels(labels=tensor_to_proto(Tensor([1]))), + ) + + response, metadata, code, details = self.make_call("EvaluateBatch", request) + + self.assertEqual(type(response), type(EvalBatchResponse())) + self.assertEqual(code, StatusCode.OK) + self.mock_device.evaluate_batch.assert_called_once_with( + Tensor([1.0]), Tensor([1]) + ) + self.assertEqual( + proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics + ) + + def test_full_model_training(self): + self.mock_device.federated_train.return_value = ( + {"weights": Tensor([42])}, + {"weights": Tensor([43])}, + 44, + self.metrics, + self.diagnostic_metrics, + ) + request = connection_pb2.FullModelTrainRequest() + + response, metadata, code, details = self.make_call("FullModelTraining", request) + + self.assertEqual( + proto_to_state_dict(response.client_weights.weights), + {"weights": Tensor([42])}, + ) + self.assertEqual( + proto_to_state_dict(response.server_weights.weights), + {"weights": Tensor([43])}, + ) + self.assertEqual(response.num_samples, 44) + self.assertEqual(proto_to_metrics(response.metrics), self.metrics) + self.assertEqual(code, StatusCode.OK) + self.mock_device.federated_train.assert_called_once() + + def test_start_experiment(self): + self.mock_device.start_experiment.return_value = None + request = connection_pb2.StartExperimentRequest() + + response, metadata, code, details = self.make_call("StartExperiment", request) + + self.assertEqual(type(response), type(connection_pb2.StartExperimentResponse())) + self.assertEqual(code, StatusCode.OK) + self.mock_device.start_experiment.assert_called_once() + + def test_end_experiment(self): + self.mock_device.end_experiment.return_value = None + request = connection_pb2.EndExperimentRequest() + + response, metadata, code, details = self.make_call("EndExperiment", request) + + self.assertEqual(type(response), type(connection_pb2.EndExperimentResponse())) + self.assertEqual(code, StatusCode.OK) + self.mock_device.end_experiment.assert_called_once() + + def test_battery_status(self): + self.mock_device.get_battery_status.return_value = (42, 21) + request = connection_pb2.BatteryStatusRequest() + + response, metadata, code, details = self.make_call("GetBatteryStatus", request) + + self.assertEqual( + response.status, + BatteryStatus(initial_battery_level=42, current_battery_level=21), + ) + self.assertEqual(code, StatusCode.OK) + self.mock_device.get_battery_status.assert_called_once() + + def test_dataset_model_info(self): + self.mock_device.client._train_data.dataset = [1] + self.mock_device.client._val_data.dataset = [2] + self.mock_device.client._model_flops = {"FW": 3, "BW": 6} + self.mock_device.server._model_flops = {"FW": 4, "BW": 8} 
+ request = connection_pb2.DatasetModelInfoRequest() + + response, metadata, code, details = self.make_call( + "GetDatasetModelInfo", request + ) + + self.assertEqual(code, StatusCode.OK) + self.assertEqual(response.train_samples, 1) + self.assertEqual(response.validation_samples, 1) + self.assertEqual(response.client_fw_flops, 3) + self.assertEqual(response.server_fw_flops, 4) + self.assertEqual(response.client_bw_flops, 6) + self.assertEqual(response.server_bw_flops, 8) + + +class RPCDeviceServicerBatteryEmptyTest(unittest.TestCase): + + def setUp(self) -> None: + # instantiating NetworkDevice to mock Device(ABC)'s properties + # store mock_device reference for later function call assertions + self.mock_device = Mock(spec=NetworkDevice(42, None, Mock(Battery))) + my_servicer = RPCDeviceServicer(device=self.mock_device) + servicers = {connection_pb2.DESCRIPTOR.services_by_name["Device"]: my_servicer} + self.test_server = server_from_dictionary(servicers, strict_real_time()) + + def make_call(self, method_name, request): + method = self.test_server.invoke_unary_unary( + method_descriptor=connection_pb2.DESCRIPTOR.services_by_name[ + "Device" + ].methods_by_name[method_name], + invocation_metadata={}, + request=request, + timeout=None, + ) + return method.termination() + + def test_stop_at_train_global(self): + self.mock_device.train_global.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.TrainGlobalRequest( + optimizer_state=state_dict_to_proto(None) + ) + self._test_device_out_of_battery_lets_rpc_fail(request, "TrainGlobal") + + def test_stop_at_set_weights(self): + self.mock_device.set_weights.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.SetWeightsRequest(weights=weights_to_proto({})) + self._test_device_out_of_battery_lets_rpc_fail(request, "SetWeights") + + def test_stop_at_train_epoch(self): + self.mock_device.train_epoch.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.TrainEpochRequest(server=None) + self._test_device_out_of_battery_lets_rpc_fail(request, "TrainEpoch") + + def test_stop_at_train_batch(self): + self.mock_device.train_batch.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.TrainBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(Tensor([1]))), + labels=Labels(labels=tensor_to_proto(Tensor([1]))), + ) + self._test_device_out_of_battery_lets_rpc_fail(request, "TrainBatch") + + def test_stop_at_evaluate_global(self): + self.mock_device.evaluate_global.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.EvalGlobalRequest() + self._test_device_out_of_battery_lets_rpc_fail(request, "EvaluateGlobal") + + def test_stop_at_evaluate(self): + self.mock_device.evaluate.side_effect = BatteryEmptyException("Battery empty") + request = connection_pb2.EvalRequest(server=None) + self._test_device_out_of_battery_lets_rpc_fail(request, "Evaluate") + + def test_stop_at_evaluate_batch(self): + self.mock_device.evaluate_batch.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.EvalBatchRequest( + smashed_data=Activations(activations=tensor_to_proto(Tensor([1]))), + labels=Labels(labels=tensor_to_proto(Tensor([1]))), + ) + self._test_device_out_of_battery_lets_rpc_fail(request, "EvaluateBatch") + + def test_stop_at_full_model_training(self): + self.mock_device.federated_train.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = 
connection_pb2.FullModelTrainRequest() + self._test_device_out_of_battery_lets_rpc_fail(request, "FullModelTraining") + + def test_stop_at_start_experiment(self): + self.mock_device.start_experiment.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.StartExperimentRequest() + self._test_device_out_of_battery_lets_rpc_fail(request, "StartExperiment") + + def test_stop_at_end_experiment(self): + self.mock_device.end_experiment.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.EndExperimentRequest() + self._test_device_out_of_battery_lets_rpc_fail(request, "EndExperiment") + + def test_stop_at_get_battery_status(self): + self.mock_device.get_battery_status.side_effect = BatteryEmptyException( + "Battery empty" + ) + request = connection_pb2.BatteryStatusRequest() + self._test_device_out_of_battery_lets_rpc_fail(request, "GetBatteryStatus") + + def _test_device_out_of_battery_lets_rpc_fail(self, request, servicer_method_name): + response, metadata, code, details = self.make_call( + servicer_method_name, request + ) + self.assertIsNone(response) + self.assertEqual(code, StatusCode.UNKNOWN) + self.assertEqual(details, "Exception calling application: Battery empty") diff --git a/edml/tests/integration/rpc_server_test.py b/edml/tests/integration/rpc_server_test.py index faafd7b9b94300c9509d777453694381dfd1529c..0314d1094ffb392d7702f5d7660d2abda8f6f68c 100644 --- a/edml/tests/integration/rpc_server_test.py +++ b/edml/tests/integration/rpc_server_test.py @@ -7,7 +7,9 @@ import grpc from omegaconf import DictConfig from edml.core.battery import Battery, BatteryEmptyException -from edml.core.device import NetworkDevice, RPCDeviceServicer, DeviceRequestDispatcher +from edml.core.device import NetworkDevice +from edml.core.device_request_dispatcher import DeviceRequestDispatcher +from edml.core.rpc_device_servicer import RPCDeviceServicer from edml.generated import connection_pb2_grpc from edml.helpers.interceptors import DeviceServerInterceptor from edml.helpers.logging import SimpleLogger