diff --git a/config/default.yaml b/config/default.yaml index dac99d0ff6f6b963c60d9cc8e16b0c82df22d217..cf51f4118495971a2b319f335d6844cc6cb1351f 100644 --- a/config/default.yaml +++ b/config/default.yaml @@ -14,8 +14,9 @@ defaults: - wandb: default - _self_ -# If true, controllers will run devices in parallel. If false, they will run sequentially and their runtime is corrected +# If False, controllers will run devices in parallel. If True, they will run sequentially and their runtime is corrected # to account for the parallelism in post-processing. +# Important note: in a limited-energy setting, the wall-clock runtime will not be accounted for correctly if parallelism is only simulated. simulate_parallelism: False own_device_id: "d0" num_devices: ${len:${topology.devices}} diff --git a/edml/config/controller/swarm.yaml b/edml/config/controller/swarm.yaml index f31a60a20c1e4cff7f8eec4bd9b1ff4958753ab5..39bb393eef5bf2b73a6aa5c755c9f41ad765809e 100644 --- a/edml/config/controller/swarm.yaml +++ b/edml/config/controller/swarm.yaml @@ -3,3 +3,4 @@ _target_: edml.controllers.swarm_controller.SwarmController _partial_: true defaults: - scheduler: sequential + - adaptive_threshold_fn: !!null diff --git a/edml/controllers/parallel_split_controller.py b/edml/controllers/parallel_split_controller.py index 82bf519a06dc456d79ba25368914c376236275c9..05e5e49685086fcc47b7af390a64c825118abbe0 100644 --- a/edml/controllers/parallel_split_controller.py +++ b/edml/controllers/parallel_split_controller.py @@ -49,13 +49,13 @@ class ParallelSplitController(BaseController): ) # Start parallel training of all client devices. - adaptive_threshold = self._adaptive_threshold_fn.invoke(i) - self.logger.log({"adaptive-threshold": adaptive_threshold}) + adaptive_threshold_value = self._adaptive_threshold_fn.invoke(i) + self.logger.log({"adaptive-threshold": adaptive_threshold_value}) training_response = self.request_dispatcher.train_parallel_on_server( server_device_id=server_device_id, epochs=1, round_no=i, - adaptive_learning_threshold=adaptive_threshold, + adaptive_threshold_value=adaptive_threshold_value, optimizer_state=optimizer_state, ) diff --git a/edml/controllers/scheduler/smart.py b/edml/controllers/scheduler/smart.py index 48940331247f647a855d72bf3e8a54b44f13ba78..eddf1e98f9b18b0589ff0f61782e54a98ccdd130 100644 --- a/edml/controllers/scheduler/smart.py +++ b/edml/controllers/scheduler/smart.py @@ -90,8 +90,10 @@ class SmartNextServerScheduler(NextServerScheduler): device_params_list.append(device_params) global_params = GlobalParams() global_params.fill_values_from_config(self.cfg) - global_params.client_model_flops = model_flops["client"] - global_params.server_model_flops = model_flops["server"] + global_params.client_fw_flops = model_flops["client"]["FW"] + global_params.server_fw_flops = model_flops["server"]["FW"] + global_params.client_bw_flops = model_flops["client"]["BW"] + global_params.server_bw_flops = model_flops["server"]["BW"] global_params.client_norm_fw_time = optimization_metrics["client_norm_fw_time"] global_params.client_norm_bw_time = optimization_metrics["client_norm_bw_time"] global_params.server_norm_fw_time = optimization_metrics["server_norm_fw_time"] diff --git a/edml/controllers/strategy_optimization.py b/edml/controllers/strategy_optimization.py index cd7bae8ac0db1c8eb4bcf818713ed8b2f1ed0e47..7c11f73652cf1aecd792a26e603a099c488b9943 100644 --- a/edml/controllers/strategy_optimization.py +++ b/edml/controllers/strategy_optimization.py @@ -29,8 +29,10 @@ class GlobalParams:
cost_per_byte_sent=None, cost_per_byte_received=None, cost_per_flop=None, - client_model_flops=None, - server_model_flops=None, + client_fw_flops=None, # mandatory forward FLOPs + server_fw_flops=None, # mandatory forward FLOPs + client_bw_flops=None, # optional backward FLOPs if bw FLOPs != 2 * fw FLOPs (e.g. for AE) + server_bw_flops=None, # optional backward FLOPs if bw FLOPs != 2 * fw FLOPs (e.g. for AE) smashed_data_size=None, label_size=None, gradient_size=None, @@ -49,8 +51,10 @@ class GlobalParams: self.cost_per_byte_sent = cost_per_byte_sent self.cost_per_byte_received = cost_per_byte_received self.cost_per_flop = cost_per_flop - self.client_model_flops = client_model_flops - self.server_model_flops = server_model_flops + self.client_fw_flops = client_fw_flops + self.server_fw_flops = server_fw_flops + self.client_bw_flops = client_bw_flops # optional backward FLOPs if bw FLOPs != 2 * fw FLOPs (e.g. for AE) + self.server_bw_flops = server_bw_flops # optional backward FLOPs if bw FLOPs != 2 * fw FLOPs (e.g. for AE) self.smashed_data_size = smashed_data_size self.batch_size = batch_size self.client_weights_size = client_weights_size @@ -79,8 +83,8 @@ class GlobalParams: and self.cost_per_byte_sent is not None and self.cost_per_byte_received is not None and self.cost_per_flop is not None - and self.client_model_flops is not None - and self.server_model_flops is not None + and self.client_fw_flops is not None + and self.server_fw_flops is not None and self.optimizer_state_size is not None and self.smashed_data_size is not None and self.label_size is not None @@ -181,22 +185,24 @@ class ServerChoiceOptimizer: def num_flops_per_round_on_device(self, device_id, server_device_id): device = self._get_device_params(device_id) total_flops = 0 + # set bw FLOPs to 2 * fw FLOPs if not provided + if self.global_params.client_bw_flops and self.global_params.server_bw_flops: + client_bw_flops = self.global_params.client_bw_flops + server_bw_flops = self.global_params.server_bw_flops + else: + client_bw_flops = ( + 2 * self.global_params.client_fw_flops + ) # by default bw FLOPs = 2*fw FLOPs + server_bw_flops = 2 * self.global_params.server_fw_flops + client_fw_flops = self.global_params.client_fw_flops + server_fw_flops = self.global_params.server_fw_flops if device_id == server_device_id: total_flops += ( - self.global_params.server_model_flops - * self._total_train_dataset_size() - * 3 - ) # fw + 2bw - total_flops += ( - self.global_params.server_model_flops - * self._total_validation_dataset_size() - ) # fw - total_flops += ( - self.global_params.client_model_flops * device.train_samples * 3 - ) # fw + 2bw - total_flops += ( - self.global_params.client_model_flops * device.validation_samples - ) # fw + server_fw_flops + server_bw_flops + ) * self._total_train_dataset_size() + total_flops += server_fw_flops * self._total_validation_dataset_size() + total_flops += (client_fw_flops + client_bw_flops) * device.train_samples + total_flops += client_fw_flops * device.validation_samples return total_flops def num_bytes_sent_per_round_on_device(self, device_id, server_device_id): @@ -476,10 +482,10 @@ class EnergySimulator: def _fl_flops_on_device(self, device_id): device = self._get_device_params(device_id) total_flops = 0 - total_flops += self.global_params.client_model_flops * device.train_samples * 3 - total_flops += self.global_params.client_model_flops * device.validation_samples - total_flops += self.global_params.server_model_flops * device.train_samples * 3 - total_flops += 
self.global_params.server_model_flops * device.validation_samples + total_flops += self.global_params.client_fw_flops * device.train_samples * 3 + total_flops += self.global_params.client_fw_flops * device.validation_samples + total_flops += self.global_params.server_fw_flops * device.train_samples * 3 + total_flops += self.global_params.server_fw_flops * device.validation_samples return total_flops def _fl_data_sent_per_device(self): diff --git a/edml/controllers/swarm_controller.py b/edml/controllers/swarm_controller.py index 4463e2cc8ffd3a89471e8728ba0a117e163e8fea..716b73ec20546a83c7b63e34dd9287a113b4a8e4 100644 --- a/edml/controllers/swarm_controller.py +++ b/edml/controllers/swarm_controller.py @@ -1,5 +1,9 @@ from typing import Any +from edml.controllers.adaptive_threshold_mechanism import AdaptiveThresholdFn +from edml.controllers.adaptive_threshold_mechanism.static import ( + StaticAdaptiveThresholdFn, +) from edml.controllers.base_controller import BaseController from edml.controllers.scheduler.base import NextServerScheduler from edml.helpers.config_helpers import get_device_index_by_id @@ -7,10 +11,16 @@ from edml.helpers.config_helpers import get_device_index_by_id class SwarmController(BaseController): - def __init__(self, cfg, scheduler: NextServerScheduler): + def __init__( + self, + cfg, + scheduler: NextServerScheduler, + adaptive_threshold_fn: AdaptiveThresholdFn = StaticAdaptiveThresholdFn(0.0), + ): super().__init__(cfg) scheduler.initialize(self) self._next_server_scheduler = scheduler + self._adaptive_threshold_fn = adaptive_threshold_fn def _train(self): client_weights = None @@ -87,10 +97,13 @@ class SwarmController(BaseController): device_id=server_device_id, state_dict=server_weights, on_client=False ) + adaptive_threshold = self._adaptive_threshold_fn.invoke(round_no) + self.logger.log({"adaptive-threshold": adaptive_threshold}) training_response = self.request_dispatcher.train_global_on( server_device_id, epochs=1, round_no=round_no, + adaptive_threshold_value=adaptive_threshold, optimizer_state=optimizer_state, ) @@ -121,17 +134,32 @@ class SwarmController(BaseController): """Returns the dataset sizes and model flops of active devices only.""" dataset_sizes = {} model_flops = {} - client_flop_list = [] - server_flop_list = [] + client_fw_flop_list = [] + server_fw_flop_list = [] + client_bw_flop_list = [] + server_bw_flop_list = [] for device_id in self.active_devices: - train_samples, val_samples, client_flops, server_flops = ( - self.request_dispatcher.get_dataset_model_info_on(device_id) - ) + ( + train_samples, + val_samples, + client_fw_flops, + server_fw_flops, + client_bw_flops, + server_bw_flops, + ) = self.request_dispatcher.get_dataset_model_info_on(device_id) dataset_sizes[device_id] = (train_samples, val_samples) - client_flop_list.append(client_flops) - server_flop_list.append(server_flops) + client_fw_flop_list.append(client_fw_flops) + server_fw_flop_list.append(server_fw_flops) + client_bw_flop_list.append(client_bw_flops) + server_bw_flop_list.append(server_bw_flops) # avoid that flops are taken from a device that wasn't used for training and thus has no flops # apart from that, FLOPs should be the same everywhere - model_flops["client"] = max(client_flop_list) - model_flops["server"] = max(server_flop_list) + model_flops["client"] = { + "FW": max(client_fw_flop_list), + "BW": max(client_bw_flop_list), + } + model_flops["server"] = { + "FW": max(server_fw_flop_list), + "BW": max(server_bw_flop_list), + } return dataset_sizes, model_flops diff 
--git a/edml/core/client.py b/edml/core/client.py index d8096c7ca5bc479a2aec458f2daf1cc442db4986..03437ce4a4ade8472a889586fb3f82981a6278c8 100644 --- a/edml/core/client.py +++ b/edml/core/client.py @@ -1,6 +1,5 @@ from __future__ import annotations -import itertools import time from typing import Optional, Tuple, TYPE_CHECKING, Any @@ -19,6 +18,7 @@ from edml.helpers.flops import estimate_model_flops from edml.helpers.load_optimizer import get_optimizer_and_scheduler from edml.helpers.metrics import DiagnosticMetricResultContainer, DiagnosticMetricResult from edml.helpers.types import StateDict +from edml.models.provider.base import has_autoencoder if TYPE_CHECKING: from edml.core.device import Device @@ -69,13 +69,13 @@ class DeviceClient: self._device = torch.device(get_torch_device_id(cfg)) self._model = model.to(self._device) self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler( - cfg, self._model.parameters() + cfg, self._model.get_optimizer_params() ) # get first sample from train data to estimate model flops sample = self._train_data.dataset.__getitem__(0)[0] if not isinstance(sample, torch.Tensor): sample = torch.tensor(data=sample) - self._model_flops = estimate_model_flops( + self._model_flops: dict[str, int] = estimate_model_flops( self._model, sample.to(self._device).unsqueeze(0) ) self._cfg = cfg @@ -93,6 +93,7 @@ class DeviceClient: If a latency factor is specified, this function sleeps for said amount before returning. """ self.node_device = node_device + self.node_device.logger.log({"client_model_flops": self._model_flops}) @simulate_latency_decorator(latency_factor_attr="latency_factor") def set_weights(self, state_dict: StateDict): @@ -168,10 +169,10 @@ class DeviceClient: # Safety check to ensure that we train same-sized batches only. batch_data, batch_labels = next(self._batchable_data_loader) - # Updates the battery capacity by simulating the required energy consumption for conducting the training step. - self.node_device.battery.update_flops(self._model_flops * len(batch_data)) + # Updates the battery capacity by simulating the required energy consumption for conducting the forward pass. + self.node_device.battery.update_flops(self._model_flops["FW"] * len(batch_data)) - # We train the model using the single batch and return the activations and labels. These get send over to the + # We train the model using the single batch and return the activations and labels. These get sent over to the # server to be then further processed with LatencySimulator(latency_factor=self.latency_factor): @@ -204,11 +205,12 @@ class DeviceClient: start_time_2 = time.time() - self.node_device.battery.update_flops( - self._model_flops * len(batch_data) * 2 - ) # 2x for backward pass + self.node_device.battery.update_flops(self._model_flops["BW"] * len(batch_data)) gradients = gradients.to(self._device) - smashed_data.backward(gradients) + if has_autoencoder(self._model): + self._model.trainable_layers_output.backward(gradients) + else: + smashed_data.backward(gradients) # self._optimizer.step() # We need to store a reference to the smashed_data to make it possible to finalize the training step. 
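The backward path in this hunk bypasses the frozen autoencoder: instead of calling `backward()` on the smashed data, the client backpropagates the server's gradients directly into the cached `trainable_layers_output`. A minimal sketch of that pattern, with stand-in modules and shapes that are illustrative assumptions rather than repo code:

```python
import torch
from torch import nn

# Stand-ins: a trainable client split and a frozen, shape-preserving "autoencoder".
client_layers = nn.Linear(8, 4)
autoencoder = nn.Linear(4, 4).requires_grad_(False)

x = torch.randn(2, 8)
trainable_out = client_layers(x)           # cached as `trainable_layers_output`
smashed_data = autoencoder(trainable_out)  # what gets sent to the server

# The server returns gradients w.r.t. its trainable input; faked here with ones.
server_grad = torch.ones_like(trainable_out)

# Backpropagate into the cached tensor, skipping the frozen autoencoder.
trainable_out.backward(server_grad)
assert client_layers.weight.grad is not None  # client parameters got gradients
```

This hand-off only works if the gradients returned by the server have the same shape as the client's trainable output, which is what the PR's autoencoder bridge assumes.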
@@ -225,7 +227,7 @@ class DeviceClient: metrics_container = DiagnosticMetricResultContainer([metric]) gradients = [] - for param in self._model.parameters(): + for param in self._model.get_optimizer_params(): if param.grad is not None: gradients.append(param.grad) else: @@ -276,7 +278,9 @@ class DeviceClient: self._model.train() diagnostic_metric_container = DiagnosticMetricResultContainer() for idx, (batch_data, batch_labels) in enumerate(self._train_data): - self.node_device.battery.update_flops(self._model_flops * len(batch_data)) + self.node_device.battery.update_flops( + self._model_flops["FW"] * len(batch_data) + ) with LatencySimulator(latency_factor=self.latency_factor): batch_data = batch_data.to(self._device) @@ -299,12 +303,16 @@ class DeviceClient: break server_grad, _server_loss, diagnostic_metrics = train_batch_response diagnostic_metric_container.merge(diagnostic_metrics) - self.node_device.battery.update_flops( - self._model_flops * len(batch_data) * 2 - ) # 2x for backward pass - server_grad = server_grad.to(self._device) - smashed_data.backward(server_grad) - self._optimizer.step() + if server_grad is not None: # otherwise threshold was applied + self.node_device.battery.update_flops( + self._model_flops["BW"] * len(batch_data) + ) + server_grad = server_grad.to(self._device) + if has_autoencoder(self._model): + self._model.trainable_layers_output.backward(server_grad) + else: + smashed_data.backward(server_grad) + self._optimizer.step() client_train_time = ( time.time() - client_train_start_time - sum(server_train_batch_times) @@ -355,7 +363,7 @@ class DeviceClient: for b, (batch_data, batch_labels) in enumerate(dataloader): with LatencySimulator(latency_factor=self.latency_factor): self.node_device.battery.update_flops( - self._model_flops * len(batch_data) + self._model_flops["FW"] * len(batch_data) ) batch_data = batch_data.to(self._device) batch_labels = batch_labels.to(self._device) @@ -382,7 +390,7 @@ class DeviceClient: return diagnostic_metric_results def set_gradient_and_finalize_training(self, gradients: Any): - for param, grad in zip(self._model.parameters(), gradients): + for param, grad in zip(self._model.get_optimizer_params(), gradients): param.grad = grad.to(self._device) self._optimizer.step() diff --git a/edml/core/device.py b/edml/core/device.py index 6770e4271e8dc030af1abc7f8ac010a57242f5bb..aea1ae95de2735dc4c984a13d49dadd632d66432 100644 --- a/edml/core/device.py +++ b/edml/core/device.py @@ -242,13 +242,13 @@ class NetworkDevice(Device): self, clients: list[str], round_no: int, - adaptive_learning_threshold: Optional[float] = None, + adaptive_threshold_value: Optional[float] = None, optimizer_state: dict[str, Any] = None, ): return self.server.train_parallel_split_learning( clients=clients, round_no=round_no, - adaptive_learning_threshold=adaptive_learning_threshold, + adaptive_threshold_value=adaptive_threshold_value, optimizer_state=optimizer_state, ) @@ -303,7 +303,11 @@ class NetworkDevice(Device): @update_battery @log_execution_time("logger", "train_global_time") def train_global( - self, epochs: int, round_no: int = -1, optimizer_state: dict[str, Any] = None + self, + epochs: int, + round_no: int = -1, + adaptive_threshold_value: Optional[float] = None, + optimizer_state: dict[str, Any] = None, ) -> Tuple[ Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer ]: @@ -311,6 +315,7 @@ class NetworkDevice(Device): devices=self.__get_device_ids__(), epochs=epochs, round_no=round_no, + 
adaptive_threshold_value=adaptive_threshold_value, optimizer_state=optimizer_state, ) @@ -448,7 +453,12 @@ class RPCDeviceServicer(DeviceServicer): def TrainGlobal(self, request, context): print(f"Called TrainGlobal on device {self.device.device_id}") client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( - self.device.train_global(request.epochs, request.round_no) + self.device.train_global( + request.epochs, + request.round_no, + request.adaptive_threshold_value, + proto_to_state_dict(request.optimizer_state), + ) ) response = connection_pb2.TrainGlobalResponse( client_weights=Weights(weights=state_dict_to_proto(client_weights)), @@ -557,21 +567,25 @@ class RPCDeviceServicer(DeviceServicer): return connection_pb2.DatasetModelInfoResponse( train_samples=len(self.device.client._train_data.dataset), validation_samples=len(self.device.client._val_data.dataset), - client_model_flops=int(self.device.client._model_flops), - server_model_flops=int(self.device.server._model_flops), + client_fw_flops=int(self.device.client._model_flops["FW"]), + server_fw_flops=int(self.device.server._model_flops["FW"]), + client_bw_flops=int(self.device.client._model_flops["BW"]), + server_bw_flops=int(self.device.server._model_flops["BW"]), ) def TrainGlobalParallelSplitLearning(self, request, context): print(f"Starting parallel split learning") clients = self.device.__get_device_ids__() round_no = request.round_no - adaptive_learning_threshold = request.adaptive_learning_threshold + adaptive_threshold_value = request.adaptive_threshold_value + optimizer_state = proto_to_state_dict(request.optimizer_state) cw, sw, model_metrics, optimizer_state, diagnostic_metrics = ( self.device.train_parallel_split_learning( clients=clients, round_no=round_no, - adaptive_learning_threshold=adaptive_learning_threshold, + adaptive_threshold_value=adaptive_threshold_value, + optimizer_state=optimizer_state, ) ) response = connection_pb2.TrainGlobalParallelSplitLearningResponse( @@ -651,10 +665,10 @@ class DeviceRequestDispatcher: server_device_id: str, epochs: int, round_no: int, - adaptive_learning_threshold: Optional[float] = None, + adaptive_threshold_value: Optional[float] = None, optimizer_state: dict[str, Any] = None, ): - print(f"><><><> {adaptive_learning_threshold}") + print(f"><><><> {adaptive_threshold_value}") try: response: TrainGlobalParallelSplitLearningResponse = self._get_connection( @@ -662,7 +676,7 @@ class DeviceRequestDispatcher: ).TrainGlobalParallelSplitLearning( connection_pb2.TrainGlobalParallelSplitLearningRequest( round_no=round_no, - adaptive_learning_threshold=adaptive_learning_threshold, + adaptive_threshold_value=adaptive_threshold_value, optimizer_state=state_dict_to_proto(optimizer_state), ) ) @@ -759,6 +773,7 @@ class DeviceRequestDispatcher: device_id: str, epochs: int, round_no: int = -1, + adaptive_threshold_value: Optional[float] = None, optimizer_state: dict[str, Any] = None, ) -> Union[ Tuple[ @@ -775,6 +790,7 @@ class DeviceRequestDispatcher: connection_pb2.TrainGlobalRequest( epochs=epochs, round_no=round_no, + adaptive_threshold_value=adaptive_threshold_value, optimizer_state=state_dict_to_proto(optimizer_state), ) ) @@ -995,7 +1011,7 @@ class DeviceRequestDispatcher: def get_dataset_model_info_on( self, device_id: str - ) -> Union[Tuple[int, int, float, float], bool]: + ) -> Union[Tuple[int, int, int, int, int, int], bool]: try: response: DatasetModelInfoResponse = self._get_connection( device_id @@ -1003,8 +1019,10 @@ class DeviceRequestDispatcher: return ( 
response.train_samples, response.validation_samples, - response.client_model_flops, - response.server_model_flops, + response.client_fw_flops, + response.server_fw_flops, + response.client_bw_flops, + response.server_bw_flops, ) except grpc.RpcError: self._handle_rpc_error(device_id) diff --git a/edml/core/server.py b/edml/core/server.py index 4279f40301205c361a8648fe907d9084bd159ac5..f93f7fb3025095a93d72f08b26151db751032fc0 100644 --- a/edml/core/server.py +++ b/edml/core/server.py @@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Any, TYPE_CHECKING import torch from omegaconf import DictConfig -from colorama import Fore from torch import nn from torch.autograd import Variable @@ -20,6 +19,7 @@ from edml.helpers.metrics import ( DiagnosticMetricResultContainer, ) from edml.helpers.types import StateDict, LossFn +from edml.models.provider.base import has_autoencoder if TYPE_CHECKING: from edml.core.device import Device @@ -40,9 +40,9 @@ class DeviceServer: self._device = torch.device(get_torch_device_id(cfg)) self._model = model.to(self._device) self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler( - cfg, self._model.parameters() + cfg, self._model.get_optimizer_params() ) - self._model_flops = 0 # determine later + self._model_flops: dict[str, int] = {"FW": 0, "BW": 0} # determine later self._metrics = create_metrics( cfg.experiment.metrics, cfg.dataset.num_classes, cfg.dataset.average_setting ) @@ -50,6 +50,7 @@ class DeviceServer: self._cfg = cfg self.node_device: Optional[Device] = None self.latency_factor = latency_factor + self.adaptive_threshold_value = None def set_device(self, node_device: Device): """Sets the device reference for the server.""" @@ -72,6 +73,7 @@ class DeviceServer: devices: List[str], epochs: int = 1, round_no: int = -1, + adaptive_threshold_value: Optional[float] = None, optimizer_state: dict[str, Any] = None, ) -> Tuple[ Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer @@ -82,6 +84,7 @@ class DeviceServer: devices: The devices to train on epochs: Optionally, the number of epochs to train. round_no: Optionally, the current global epoch number if a learning rate scheduler is used. + adaptive_threshold_value: Optionally, the loss threshold to not send the gradients to the client optimizer_state: Optionally, the optimizer_state to proceed from """ client_weights = None @@ -89,6 +92,8 @@ class DeviceServer: diagnostic_metric_container = DiagnosticMetricResultContainer() if optimizer_state is not None: self._optimizer.load_state_dict(optimizer_state) + if adaptive_threshold_value is not None: + self.adaptive_threshold_value = adaptive_threshold_value for epoch in range(epochs): if self._lr_scheduler is not None: if round_no != -1: @@ -134,22 +139,26 @@ class DeviceServer: ) @simulate_latency_decorator(latency_factor_attr="latency_factor") - def train_batch(self, smashed_data, labels) -> Tuple[Variable, float]: + def train_batch(self, smashed_data, labels) -> Tuple[Optional[Variable], float]: """Train the model on the given batch of data and labels. 
Returns the gradients of the model's parameters.""" smashed_data, labels = smashed_data.to(self._device), labels.to(self._device) - self._set_model_flops(smashed_data) + self._set_model_flops(smashed_data[0]) self._optimizer.zero_grad() - self.node_device.battery.update_flops(self._model_flops * len(smashed_data)) + self.node_device.battery.update_flops( + self._model_flops["FW"] * len(smashed_data) + ) smashed_data = Variable(smashed_data, requires_grad=True) output_train = self._model(smashed_data) loss_train = self._loss_fn(output_train, labels) - self.node_device.battery.update_flops(self._model_flops * len(smashed_data) * 2) + self.node_device.battery.update_flops( + self._model_flops["BW"] * len(smashed_data) + ) loss_train.backward() self._optimizer.step() @@ -157,14 +166,27 @@ class DeviceServer: self.node_device.log({"loss": loss_train.item()}) self._metrics.metrics_on_batch(output_train.cpu(), labels.cpu().int()) - return smashed_data.grad, loss_train.item() + if has_autoencoder(self._model): + gradients = self._model.trainable_layers_input.grad + else: + gradients = smashed_data.grad + if ( + self.adaptive_threshold_value + and loss_train.item() < self.adaptive_threshold_value + ): + self.node_device.log( + {"adaptive_learning_threshold_applied": gradients.size(0)} + ) + return None, loss_train.item() + return gradients, loss_train.item() - def _set_model_flops(self, smashed_data): + def _set_model_flops(self, sample): """Helper to determine the model flops when smashed data are available for the first time.""" - if self._model_flops == 0: - self._model_flops = estimate_model_flops(self._model, smashed_data) / len( - smashed_data + if self._model_flops["FW"] == 0 or self._model_flops["BW"] == 0: + self._model_flops = estimate_model_flops( + self._model, sample.to(self._device).unsqueeze(0) ) + self.node_device.logger.log({"server_model_flops": self._model_flops}) @simulate_latency_decorator(latency_factor_attr="latency_factor") def finalize_metrics(self, device_id: str, phase: str): @@ -205,8 +227,10 @@ class DeviceServer: """Evaluates the model on the given batch of data and labels""" with torch.no_grad(): smashed_data = smashed_data.to(self._device) - self._set_model_flops(smashed_data) - self.node_device.battery.update_flops(self._model_flops * len(smashed_data)) + self._set_model_flops(smashed_data[0]) + self.node_device.battery.update_flops( + self._model_flops["FW"] * len(smashed_data) + ) pred = self._model(smashed_data) self._metrics.metrics_on_batch(pred.cpu(), labels.cpu().int()) @@ -215,7 +239,7 @@ class DeviceServer: self, clients: List[str], round_no: int, - adaptive_learning_threshold: Optional[float] = None, + adaptive_threshold_value: Optional[float] = None, optimizer_state: dict[str, Any] = None, ): def client_training_job(client_id: str, batch_index: int): @@ -249,7 +273,8 @@ class DeviceServer: self._lr_scheduler.step(round_no + 1) # epoch=1 else: self._lr_scheduler.step() - + if adaptive_threshold_value is not None: + self.adaptive_threshold_value = adaptive_threshold_value num_threads = len(clients) executor = create_executor_with_threads(num_threads) @@ -299,25 +324,9 @@ class DeviceServer: server_gradients, server_loss, server_metrics = ( self.node_device.train_batch(server_batch, server_labels) ) # DiagnosticMetricResultContainer - # We check if the server should activate the adaptive learning threshold. And if true, we make sure to only - # do the client propagation once the current loss value is larger than the threshold. 
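After this refactor the threshold check lives inside `train_batch`, which returns `None` in place of gradients whenever the batch loss falls below `adaptive_threshold_value`, so the training loops only need the `server_gradients is None` check. A small sketch of that contract under the same truthiness semantics (function and values are illustrative, not repo code):

```python
def train_batch_sketch(loss_value, gradients, threshold):
    # Mirrors the truthiness check: a threshold of 0.0, as produced by
    # StaticAdaptiveThresholdFn(0.0), never suppresses gradients.
    if threshold and loss_value < threshold:
        return None, loss_value  # caller skips client backprop for this batch
    return gradients, loss_value

assert train_batch_sketch(0.1, "grads", threshold=0.5)[0] is None    # skipped
assert train_batch_sketch(0.9, "grads", threshold=0.5)[0] == "grads"
assert train_batch_sketch(0.1, "grads", threshold=0.0)[0] == "grads"  # disabled
```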
- print( - f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}" - ) if ( - adaptive_learning_threshold - and server_loss < adaptive_learning_threshold - ): - print( - f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}" - ) - self.node_device.log( - { - "adaptive_learning_threshold_applied": server_gradients.size( - 0 - ) - } - ) + server_gradients is None + ): # loss threshold was reached, skip client backprop continue num_client_gradients = len(client_forward_pass_responses) @@ -397,25 +406,9 @@ class DeviceServer: server_gradients, server_loss, server_metrics = ( self.node_device.train_batch(server_batch, server_labels) ) # DiagnosticMetricResultContainer - # We check if the server should activate the adaptive learning threshold. And if true, we make sure to only - # do the client propagation once the current loss value is larger than the threshold. - print( - f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}" - ) if ( - adaptive_learning_threshold - and server_loss < adaptive_learning_threshold - ): - print( - f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}" - ) - self.node_device.log( - { - "adaptive_learning_threshold_applied": server_gradients.size( - 0 - ) - } - ) + server_gradients is None + ): # loss threshold was reached, skip client backprop continue num_client_gradients = len(client_forward_pass_responses) diff --git a/edml/generated/connection_pb2.py b/edml/generated/connection_pb2.py index f3c72442f44b99d587925a030002e1af92924ac8..3c9a0d7b48b4739017066ae0dfedaa636dd64a62 100644 --- a/edml/generated/connection_pb2.py +++ b/edml/generated/connection_pb2.py @@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default() import datastructures_pb2 as datastructures__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"4\n\x13SetGradientsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"j\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12\"\n\tgradients\x18\x02 \x01(\x0b\x32\n.GradientsH\x00\x88\x01\x01\x42\x0c\n\n_gradients\"C\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\x12\x10\n\x08round_no\x18\x02 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 
\x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 
\x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xf8\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x12\x45\n#SetGradientsAndFinalizeTrainingStep\x12\x14.SetGradientsRequest\x1a\x06.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"4\n\x13SetGradientsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"j\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12\"\n\tgradients\x18\x02 \x01(\x0b\x32\n.GradientsH\x00\x88\x01\x01\x42\x0c\n\n_gradients\"C\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\x12\x10\n\x08round_no\x18\x02 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xcf\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12%\n\x18\x61\x64\x61ptive_threshold_value\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 
\x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1b\n\x19_adaptive_threshold_valueB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\xca\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12%\n\x18\x61\x64\x61ptive_threshold_value\x18\x03 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1b\n\x19_adaptive_threshold_valueB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 
\n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xa5\x02\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x17\n\x0f\x63lient_fw_flops\x18\x03 \x01(\x05\x12\x17\n\x0fserver_fw_flops\x18\x04 \x01(\x05\x12\x1c\n\x0f\x63lient_bw_flops\x18\x05 \x01(\x05H\x00\x88\x01\x01\x12\x1c\n\x0fserver_bw_flops\x18\x06 \x01(\x05H\x01\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x07 \x01(\x0b\x32\x08.MetricsH\x02\x88\x01\x01\x42\x12\n\x10_client_bw_flopsB\x12\n\x10_server_bw_flopsB\x15\n\x13_diagnostic_metrics2\xf8\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x12\x45\n#SetGradientsAndFinalizeTrainingStep\x12\x14.SetGradientsRequest\x1a\x06.Empty\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -35,57 +35,57 @@ if _descriptor._USE_C_DESCRIPTORS == False: _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_start=390 _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_end=518 _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_start=521 - _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=734 - _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=737 - 
_globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=1002 - _globals['_TRAINGLOBALREQUEST']._serialized_start=1005 - _globals['_TRAINGLOBALREQUEST']._serialized_end=1139 - _globals['_TRAINGLOBALRESPONSE']._serialized_start=1142 - _globals['_TRAINGLOBALRESPONSE']._serialized_end=1386 - _globals['_SETWEIGHTSREQUEST']._serialized_start=1388 - _globals['_SETWEIGHTSREQUEST']._serialized_end=1453 - _globals['_SETWEIGHTSRESPONSE']._serialized_start=1455 - _globals['_SETWEIGHTSRESPONSE']._serialized_end=1541 - _globals['_TRAINEPOCHREQUEST']._serialized_start=1543 - _globals['_TRAINEPOCHREQUEST']._serialized_end=1627 - _globals['_TRAINEPOCHRESPONSE']._serialized_start=1629 - _globals['_TRAINEPOCHRESPONSE']._serialized_end=1742 - _globals['_TRAINBATCHREQUEST']._serialized_start=1744 - _globals['_TRAINBATCHREQUEST']._serialized_end=1824 - _globals['_TRAINBATCHRESPONSE']._serialized_start=1827 - _globals['_TRAINBATCHRESPONSE']._serialized_end=1972 - _globals['_EVALGLOBALREQUEST']._serialized_start=1974 - _globals['_EVALGLOBALREQUEST']._serialized_end=2032 - _globals['_EVALGLOBALRESPONSE']._serialized_start=2034 - _globals['_EVALGLOBALRESPONSE']._serialized_end=2147 - _globals['_EVALREQUEST']._serialized_start=2149 - _globals['_EVALREQUEST']._serialized_end=2211 - _globals['_EVALRESPONSE']._serialized_start=2213 - _globals['_EVALRESPONSE']._serialized_end=2293 - _globals['_EVALBATCHREQUEST']._serialized_start=2295 - _globals['_EVALBATCHREQUEST']._serialized_end=2374 - _globals['_EVALBATCHRESPONSE']._serialized_start=2376 - _globals['_EVALBATCHRESPONSE']._serialized_end=2488 - _globals['_FULLMODELTRAINREQUEST']._serialized_start=2490 - _globals['_FULLMODELTRAINREQUEST']._serialized_end=2549 - _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2552 - _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2758 - _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2760 - _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2784 - _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2786 - _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2877 - _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2879 - _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2901 - _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2903 - _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=2992 - _globals['_BATTERYSTATUSREQUEST']._serialized_start=2994 - _globals['_BATTERYSTATUSREQUEST']._serialized_end=3016 - _globals['_BATTERYSTATUSRESPONSE']._serialized_start=3018 - _globals['_BATTERYSTATUSRESPONSE']._serialized_end=3139 - _globals['_DATASETMODELINFOREQUEST']._serialized_start=3141 - _globals['_DATASETMODELINFOREQUEST']._serialized_end=3166 - _globals['_DATASETMODELINFORESPONSE']._serialized_start=3169 - _globals['_DATASETMODELINFORESPONSE']._serialized_end=3368 - _globals['_DEVICE']._serialized_start=3371 - _globals['_DEVICE']._serialized_end=4515 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=728 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=731 + _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=996 + _globals['_TRAINGLOBALREQUEST']._serialized_start=999 + _globals['_TRAINGLOBALREQUEST']._serialized_end=1201 + _globals['_TRAINGLOBALRESPONSE']._serialized_start=1204 + _globals['_TRAINGLOBALRESPONSE']._serialized_end=1448 + _globals['_SETWEIGHTSREQUEST']._serialized_start=1450 + _globals['_SETWEIGHTSREQUEST']._serialized_end=1515 + _globals['_SETWEIGHTSRESPONSE']._serialized_start=1517 + 
_globals['_SETWEIGHTSRESPONSE']._serialized_end=1603 + _globals['_TRAINEPOCHREQUEST']._serialized_start=1605 + _globals['_TRAINEPOCHREQUEST']._serialized_end=1689 + _globals['_TRAINEPOCHRESPONSE']._serialized_start=1691 + _globals['_TRAINEPOCHRESPONSE']._serialized_end=1804 + _globals['_TRAINBATCHREQUEST']._serialized_start=1806 + _globals['_TRAINBATCHREQUEST']._serialized_end=1886 + _globals['_TRAINBATCHRESPONSE']._serialized_start=1889 + _globals['_TRAINBATCHRESPONSE']._serialized_end=2034 + _globals['_EVALGLOBALREQUEST']._serialized_start=2036 + _globals['_EVALGLOBALREQUEST']._serialized_end=2094 + _globals['_EVALGLOBALRESPONSE']._serialized_start=2096 + _globals['_EVALGLOBALRESPONSE']._serialized_end=2209 + _globals['_EVALREQUEST']._serialized_start=2211 + _globals['_EVALREQUEST']._serialized_end=2273 + _globals['_EVALRESPONSE']._serialized_start=2275 + _globals['_EVALRESPONSE']._serialized_end=2355 + _globals['_EVALBATCHREQUEST']._serialized_start=2357 + _globals['_EVALBATCHREQUEST']._serialized_end=2436 + _globals['_EVALBATCHRESPONSE']._serialized_start=2438 + _globals['_EVALBATCHRESPONSE']._serialized_end=2550 + _globals['_FULLMODELTRAINREQUEST']._serialized_start=2552 + _globals['_FULLMODELTRAINREQUEST']._serialized_end=2611 + _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2614 + _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2820 + _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2822 + _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2846 + _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2848 + _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2939 + _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2941 + _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2963 + _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2965 + _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=3054 + _globals['_BATTERYSTATUSREQUEST']._serialized_start=3056 + _globals['_BATTERYSTATUSREQUEST']._serialized_end=3078 + _globals['_BATTERYSTATUSRESPONSE']._serialized_start=3080 + _globals['_BATTERYSTATUSRESPONSE']._serialized_end=3201 + _globals['_DATASETMODELINFOREQUEST']._serialized_start=3203 + _globals['_DATASETMODELINFOREQUEST']._serialized_end=3228 + _globals['_DATASETMODELINFORESPONSE']._serialized_start=3231 + _globals['_DATASETMODELINFORESPONSE']._serialized_end=3524 + _globals['_DEVICE']._serialized_start=3527 + _globals['_DEVICE']._serialized_end=4671 # @@protoc_insertion_point(module_scope) diff --git a/edml/generated/connection_pb2.pyi b/edml/generated/connection_pb2.pyi index 89353343aa1c7e39072bd0ea03c891bd2df7b4df..bc0c09189004803b0d556d0c7eaeeecaf82922e0 100644 --- a/edml/generated/connection_pb2.pyi +++ b/edml/generated/connection_pb2.pyi @@ -48,14 +48,14 @@ class SingleBatchTrainingResponse(_message.Message): def __init__(self, smashed_data: _Optional[_Union[_datastructures_pb2.Activations, _Mapping]] = ..., labels: _Optional[_Union[_datastructures_pb2.Labels, _Mapping]] = ...) -> None: ... 
class TrainGlobalParallelSplitLearningRequest(_message.Message): - __slots__ = ["round_no", "adaptive_learning_threshold", "optimizer_state"] + __slots__ = ["round_no", "adaptive_threshold_value", "optimizer_state"] ROUND_NO_FIELD_NUMBER: _ClassVar[int] - ADAPTIVE_LEARNING_THRESHOLD_FIELD_NUMBER: _ClassVar[int] + ADAPTIVE_THRESHOLD_VALUE_FIELD_NUMBER: _ClassVar[int] OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int] round_no: int - adaptive_learning_threshold: float + adaptive_threshold_value: float optimizer_state: _datastructures_pb2.StateDict - def __init__(self, round_no: _Optional[int] = ..., adaptive_learning_threshold: _Optional[float] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ... + def __init__(self, round_no: _Optional[int] = ..., adaptive_threshold_value: _Optional[float] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ... class TrainGlobalParallelSplitLearningResponse(_message.Message): __slots__ = ["client_weights", "server_weights", "metrics", "optimizer_state", "diagnostic_metrics"] @@ -72,14 +72,16 @@ class TrainGlobalParallelSplitLearningResponse(_message.Message): def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... class TrainGlobalRequest(_message.Message): - __slots__ = ["epochs", "round_no", "optimizer_state"] + __slots__ = ["epochs", "round_no", "adaptive_threshold_value", "optimizer_state"] EPOCHS_FIELD_NUMBER: _ClassVar[int] ROUND_NO_FIELD_NUMBER: _ClassVar[int] + ADAPTIVE_THRESHOLD_VALUE_FIELD_NUMBER: _ClassVar[int] OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int] epochs: int round_no: int + adaptive_threshold_value: float optimizer_state: _datastructures_pb2.StateDict - def __init__(self, epochs: _Optional[int] = ..., round_no: _Optional[int] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ... + def __init__(self, epochs: _Optional[int] = ..., round_no: _Optional[int] = ..., adaptive_threshold_value: _Optional[float] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ... class TrainGlobalResponse(_message.Message): __slots__ = ["client_weights", "server_weights", "metrics", "optimizer_state", "diagnostic_metrics"] @@ -246,15 +248,19 @@ class DatasetModelInfoRequest(_message.Message): def __init__(self) -> None: ... 
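The new `client_bw_flops` and `server_bw_flops` fields are declared `optional` in the proto, so consumers can distinguish "not provided" from zero via explicit presence checks. A hedged sketch of reading them with the BW = 2 * FW fallback used elsewhere in this PR (`read_model_flops` is an illustrative helper, not repo code):

```python
def read_model_flops(response):
    """Read FW/BW FLOPs from a DatasetModelInfoResponse-like message."""
    fw = {"client": response.client_fw_flops, "server": response.server_fw_flops}
    # proto3 `optional` scalars support HasField() for explicit presence checks.
    if response.HasField("client_bw_flops") and response.HasField("server_bw_flops"):
        bw = {"client": response.client_bw_flops, "server": response.server_bw_flops}
    else:
        bw = {side: 2 * flops for side, flops in fw.items()}  # BW = 2 * FW default
    return fw, bw
```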
class DatasetModelInfoResponse(_message.Message): - __slots__ = ["train_samples", "validation_samples", "client_model_flops", "server_model_flops", "diagnostic_metrics"] + __slots__ = ["train_samples", "validation_samples", "client_fw_flops", "server_fw_flops", "client_bw_flops", "server_bw_flops", "diagnostic_metrics"] TRAIN_SAMPLES_FIELD_NUMBER: _ClassVar[int] VALIDATION_SAMPLES_FIELD_NUMBER: _ClassVar[int] - CLIENT_MODEL_FLOPS_FIELD_NUMBER: _ClassVar[int] - SERVER_MODEL_FLOPS_FIELD_NUMBER: _ClassVar[int] + CLIENT_FW_FLOPS_FIELD_NUMBER: _ClassVar[int] + SERVER_FW_FLOPS_FIELD_NUMBER: _ClassVar[int] + CLIENT_BW_FLOPS_FIELD_NUMBER: _ClassVar[int] + SERVER_BW_FLOPS_FIELD_NUMBER: _ClassVar[int] DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int] train_samples: int validation_samples: int - client_model_flops: int - server_model_flops: int + client_fw_flops: int + server_fw_flops: int + client_bw_flops: int + server_bw_flops: int diagnostic_metrics: _datastructures_pb2.Metrics - def __init__(self, train_samples: _Optional[int] = ..., validation_samples: _Optional[int] = ..., client_model_flops: _Optional[int] = ..., server_model_flops: _Optional[int] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... + def __init__(self, train_samples: _Optional[int] = ..., validation_samples: _Optional[int] = ..., client_fw_flops: _Optional[int] = ..., server_fw_flops: _Optional[int] = ..., client_bw_flops: _Optional[int] = ..., server_bw_flops: _Optional[int] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ... diff --git a/edml/helpers/flops.py b/edml/helpers/flops.py index 4388c9754e4c42e33a958c309f8a73a2c21f4265..038f27fd31391a8feac144da37db988c1aa93c69 100644 --- a/edml/helpers/flops.py +++ b/edml/helpers/flops.py @@ -3,19 +3,43 @@ from typing import Union, Tuple from fvcore.nn import FlopCountAnalysis from torch import Tensor, nn +from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder +from edml.models.provider.base import has_autoencoder + def estimate_model_flops( model: nn.Module, sample: Union[Tensor, Tuple[Tensor, ...]] -) -> int: +) -> dict[str, int]: """ - Estimates the FLOPs of one forward pass of the model using the sample data provided. - + Estimates the FLOPs of one forward pass and one backward pass of the model using the sample data provided. + If forward and backward pass affect the same parameters, the number of FLOPs of one backward pass is assumed to be + two times the number of FLOPs for the forward pass. + In case of non-trainable parameters, such as an autoencoder in between, the number of backward FLOPs is assumed to be + twice the number of the forward FLOPs of the trainable part, while the number of actual forward FLOPs is estimated + using the full model. Args: model (nn.Module): the neural network model to calculate the FLOPs for. - sample: The data used to calculate the FLOPs. + sample: The data used to calculate the FLOPs. The first dimension should be the batch dimension. Returns: - int: the number of FLOPs. - + dict(str, int): {"FW": #ForwardFLOPs, "BW": #BackwardFLOPs} the number of estimated forward and backward FLOPs. 
""" - return FlopCountAnalysis(model, sample).total() + fw_flops = FlopCountAnalysis(model, sample).total() + if has_autoencoder(model): + bw_flops = 0 + if isinstance( + model, ClientWithAutoencoder + ): # run FW pass until Encoder and multiply by 2 + bw_flops = FlopCountAnalysis(model.model, sample).total() * 2 + elif isinstance( + model, ServerWithAutoencoder + ): # run FW pass from decoder output and multiply by 2 + bw_flops = ( + FlopCountAnalysis(model.model, model.autoencoder(sample)).total() * 2 + ) + else: + raise NotImplementedError() + else: + bw_flops = 2 * fw_flops + + return {"FW": fw_flops, "BW": bw_flops} diff --git a/edml/models/autoencoder.py b/edml/models/autoencoder.py index ed16f3000d694c53856e901a2634776fcf488433..7814b7e465e6ca8bb218fd4e4798760917999a9e 100644 --- a/edml/models/autoencoder.py +++ b/edml/models/autoencoder.py @@ -1,5 +1,5 @@ -import torch from torch import nn +from torch.autograd import Variable class ClientWithAutoencoder(nn.Module): @@ -7,10 +7,13 @@ class ClientWithAutoencoder(nn.Module): super().__init__() self.model = model self.autoencoder = autoencoder.requires_grad_(False) + self.trainable_layers_output = ( + None # needed to continue the backprop skipping the AE + ) def forward(self, x): - x = self.model(x) - return self.autoencoder(x) + self.trainable_layers_output = self.model(x) + return self.autoencoder(self.trainable_layers_output) class ServerWithAutoencoder(nn.Module): @@ -18,7 +21,13 @@ class ServerWithAutoencoder(nn.Module): super().__init__() self.model = model self.autoencoder = autoencoder.requires_grad_(False) + self.trainable_layers_input = ( + None # needed to stop backprop before AE and send the gradients to client + ) def forward(self, x): - x = self.autoencoder(x) - return self.model(x) + self.trainable_layers_input = self.autoencoder(x) + self.trainable_layers_input = Variable( + self.trainable_layers_input, requires_grad=True + ) + return self.model(self.trainable_layers_input) diff --git a/edml/models/provider/base.py b/edml/models/provider/base.py index e60736bae46fb8f3a2239c9ff016ddb32ce0578e..9fa81abd5a7c58e880f163d2d88eb070c7271e47 100644 --- a/edml/models/provider/base.py +++ b/edml/models/provider/base.py @@ -1,10 +1,33 @@ from torch import nn +def has_autoencoder(model: nn.Module): + if hasattr(model, "model") and hasattr(model, "autoencoder"): + return True + return False + + +def add_optimizer_params_function(model: nn.Module): + # exclude AE params + if has_autoencoder(model): + + def get_optimizer_params(model: nn.Module): + return model.model.parameters + + model.get_optimizer_params = get_optimizer_params(model) + else: + + def get_optimizer_params(model: nn.Module): + return model.parameters + + model.get_optimizer_params = get_optimizer_params(model) + return model + + class ModelProvider: def __init__(self, client: nn.Module, server: nn.Module): - self._client = client - self._server = server + self._client = add_optimizer_params_function(client) + self._server = add_optimizer_params_function(server) @property def models(self) -> tuple[nn.Module, nn.Module]: diff --git a/edml/proto/connection.proto b/edml/proto/connection.proto index 5755518d01796bdff68957655d4d339dcf02ab1d..f4c7952a62b9db7225760a9b7596a03e4b9f09f4 100644 --- a/edml/proto/connection.proto +++ b/edml/proto/connection.proto @@ -51,7 +51,7 @@ message SingleBatchTrainingResponse { message TrainGlobalParallelSplitLearningRequest { optional int32 round_no = 1; - optional double adaptive_learning_threshold = 2; + optional double 
adaptive_threshold_value = 2; optional StateDict optimizer_state = 3; } @@ -66,7 +66,8 @@ message TrainGlobalParallelSplitLearningResponse { message TrainGlobalRequest { int32 epochs = 1; optional int32 round_no = 2; - optional StateDict optimizer_state = 3; + optional double adaptive_threshold_value = 3; + optional StateDict optimizer_state = 4; } @@ -173,7 +174,9 @@ message DatasetModelInfoRequest {} message DatasetModelInfoResponse { int32 train_samples = 1; int32 validation_samples = 2; - int32 client_model_flops = 3; - int32 server_model_flops = 4; - optional Metrics diagnostic_metrics = 5; + int32 client_fw_flops = 3; + int32 server_fw_flops = 4; + optional int32 client_bw_flops = 5; + optional int32 server_bw_flops = 6; + optional Metrics diagnostic_metrics = 7; } diff --git a/edml/tests/controllers/fed_controller_test.py b/edml/tests/controllers/fed_controller_test.py index a44090eefc7b1fb3b791cc749c516835e6d8cc67..45c4ef900c74a02035674ad7cf96be12bd95a5af 100644 --- a/edml/tests/controllers/fed_controller_test.py +++ b/edml/tests/controllers/fed_controller_test.py @@ -149,6 +149,7 @@ class FedTrainRoundThreadingTest(unittest.TestCase): "num_devices": 0, "wandb": {"enabled": False}, "experiment": {"early_stopping": False}, + "simulate_parallelism": False, } ) ) diff --git a/edml/tests/controllers/optimization_test.py b/edml/tests/controllers/optimization_test.py index 2e8cc390c7a18e0d2b5c8e43822930973f32f953..ab89b184b992891dbe7793e31c889ce4ea912212 100644 --- a/edml/tests/controllers/optimization_test.py +++ b/edml/tests/controllers/optimization_test.py @@ -42,8 +42,8 @@ class StrategyOptimizationTest(unittest.TestCase): cost_per_byte_sent=1, cost_per_byte_received=1, cost_per_flop=1, - client_model_flops=10, - server_model_flops=20, + client_fw_flops=10, + server_fw_flops=20, smashed_data_size=50, label_size=5, gradient_size=50, @@ -168,8 +168,8 @@ class EnergySimulatorTest(unittest.TestCase): cost_per_byte_sent=1, cost_per_byte_received=1, cost_per_flop=1, - client_model_flops=10, - server_model_flops=20, + client_fw_flops=10, + server_fw_flops=20, smashed_data_size=50, label_size=5, gradient_size=50, @@ -284,8 +284,8 @@ class TestWithRealData(unittest.TestCase): cost_per_byte_sent=self.cost_per_mbyte_sent / 1000000, cost_per_byte_received=self.cost_per_mbyte_received / 1000000, cost_per_flop=self.cost_per_mflop / 1000000, - client_model_flops=5405760, - server_model_flops=11215800, + client_fw_flops=5405760, + server_fw_flops=11215800, smashed_data_size=36871, label_size=14, gradient_size=36871, @@ -450,3 +450,85 @@ class TestWithRealData(unittest.TestCase): ) print(solution, sum(solution.values()), num_rounds, remaining_batteries) self.assertGreater(sum(solution.values()), num_rounds) + + +class TestBug(unittest.TestCase): + """Case studies for estimating the number of rounds of each strategy with given energy constraints.""" + + def setUp(self): + self.global_params = GlobalParams( + cost_per_sec=0.02, + cost_per_byte_sent=2e-10, + cost_per_byte_received=2e-10, + cost_per_flop=4.9999999999999995e-14, + client_fw_flops=88408064, + server_fw_flops=169728256, + smashed_data_size=65542.875, + batch_size=64, + client_weights_size=416469, + server_weights_size=6769181, + optimizer_state_size=6664371, + train_global_time=204.80956506729126, + last_server_device_id="d0", + label_size=14.359375, + gradient_size=65542.875, + client_norm_fw_time=0.0002889558474222819, + client_norm_bw_time=0.00033905270364549425, + server_norm_fw_time=0.0035084471996721703, +
server_norm_bw_time=0.005231082973148723, + ) + self.device_params_list = [ + DeviceParams( + device_id="d0", + initial_battery=400.0, + current_battery=393.09099611222206, + train_samples=4500, + validation_samples=500, + comp_latency_factor=3.322075197090219, + ), + DeviceParams( + device_id="d1", + initial_battery=400.0, + current_battery=398.2543529099223, + train_samples=4500, + validation_samples=500, + comp_latency_factor=3.5023548154181494, + ), + DeviceParams( + device_id="d2", + initial_battery=300.0, + current_battery=297.6957505401916, + train_samples=4500, + validation_samples=500, + comp_latency_factor=3.2899405054194144, + ), + DeviceParams( + device_id="d3", + initial_battery=200.0, + current_battery=197.12318101733192, + train_samples=4500, + validation_samples=500, + comp_latency_factor=3.4622951339692114, + ), + DeviceParams( + device_id="d4", + initial_battery=200.0, + current_battery=194.35766488685508, + train_samples=27000, + validation_samples=3000, + comp_latency_factor=1.0, + ), + ] + + self.optimizer = ServerChoiceOptimizer( + self.device_params_list, self.global_params + ) + self.simulator = EnergySimulator(self.device_params_list, self.global_params) + + def test_bug(self): + solution, status = self.optimizer.optimize() + num_rounds, schedule, remaining_batteries = ( + self.simulator.simulate_smart_selection() + ) + print(solution, sum(solution.values()), num_rounds, remaining_batteries) + self.assertEqual(sum(solution.values()), num_rounds) diff --git a/edml/tests/controllers/sample_config.yaml b/edml/tests/controllers/sample_config.yaml index 505bdae73da5d1fd3901b53d94ae2316cb6ae560..ad8e9175cb16b04de23351fa6794a91b2e41ed92 100644 --- a/edml/tests/controllers/sample_config.yaml +++ b/edml/tests/controllers/sample_config.yaml @@ -24,3 +24,4 @@ battery: deduction_per_mbyte_sent: 1 deduction_per_mbyte_received: 1 deduction_per_mflop: 1 +simulate_parallelism: False diff --git a/edml/tests/controllers/scheduler/smart_test.py b/edml/tests/controllers/scheduler/smart_test.py index 6eb7e758073a82ff965eaf414fc37946164d2541..a711f7d78deb4795859c09edb67b2b10431ae087 100644 --- a/edml/tests/controllers/scheduler/smart_test.py +++ b/edml/tests/controllers/scheduler/smart_test.py @@ -39,7 +39,10 @@ class SmartServerDeviceSelectionTest(unittest.TestCase): "d0": (2, 1), "d1": (4, 1), }, - {"client": 1000000, "server": 1000000}, + { + "client": {"FW": 1000000, "BW": 2000000}, + "server": {"FW": 1000000, "BW": 2000000}, + }, ] def test_select_server_device_smart_first_round(self): diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py index 6c025f4b78a2a648b6055751cf5875b4b7224494..89ad8aa1fe830bc397ff2831cd037df49ba3b4b7 100644 --- a/edml/tests/controllers/swarm_controller_test.py +++ b/edml/tests/controllers/swarm_controller_test.py @@ -37,7 +37,9 @@ class SwarmControllerTest(unittest.TestCase): ) client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( - self.swarm_controller._swarm_train_round(None, None, "d1", 0) + self.swarm_controller._swarm_train_round( + None, None, "d1", 0, optimizer_state={"optimizer_state": 43} + ) ) self.assertEqual(client_weights, {"weights": 42}) @@ -52,7 +54,11 @@ class SwarmControllerTest(unittest.TestCase): ] ) self.mock.train_global_on.assert_called_once_with( - "d1", epochs=1, round_no=0, optimizer_state=None + "d1", + epochs=1, + round_no=0, + adaptive_threshold_value=0.0, + optimizer_state={"optimizer_state": 43}, ) def 
test_split_train_round_with_inactive_server_device(self): @@ -74,7 +80,11 @@ class SwarmControllerTest(unittest.TestCase): ] ) self.mock.train_global_on.assert_called_once_with( - "d1", epochs=1, round_no=0, optimizer_state=None + "d1", + epochs=1, + round_no=0, + adaptive_threshold_value=0.0, + optimizer_state=None, ) diff --git a/edml/tests/core/device_test.py b/edml/tests/core/device_test.py index 3827df9c727254bb13482e22bef6899bb2ae2c47..85b0e1f7f7204efa6a66b3c95f89a5c392098855 100644 --- a/edml/tests/core/device_test.py +++ b/edml/tests/core/device_test.py @@ -131,7 +131,12 @@ class RPCDeviceServicerTest(unittest.TestCase): {"optimizer_state": 44}, self.diagnostic_metrics, ) - request = connection_pb2.TrainGlobalRequest(epochs=42) + request = connection_pb2.TrainGlobalRequest( + epochs=42, + round_no=1, + adaptive_threshold_value=3, + optimizer_state=state_dict_to_proto({"optimizer_state": 42}), + ) response, metadata, code, details = self.make_call("TrainGlobal", request) @@ -147,7 +152,9 @@ class RPCDeviceServicerTest(unittest.TestCase): proto_to_state_dict(response.optimizer_state), {"optimizer_state": 44} ) self.assertEqual(code, StatusCode.OK) - self.mock_device.train_global.assert_called_once_with(42, 0) + self.mock_device.train_global.assert_called_once_with( + 42, 1, 3, {"optimizer_state": 42} + ) self.assertEqual( proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics ) @@ -315,8 +322,8 @@ class RPCDeviceServicerTest(unittest.TestCase): def test_dataset_model_info(self): self.mock_device.client._train_data.dataset = [1] self.mock_device.client._val_data.dataset = [2] - self.mock_device.client._model_flops = 3 - self.mock_device.server._model_flops = 4 + self.mock_device.client._model_flops = {"FW": 3, "BW": 6} + self.mock_device.server._model_flops = {"FW": 4, "BW": 8} request = connection_pb2.DatasetModelInfoRequest() response, metadata, code, details = self.make_call( @@ -326,8 +333,10 @@ class RPCDeviceServicerTest(unittest.TestCase): self.assertEqual(code, StatusCode.OK) self.assertEqual(response.train_samples, 1) self.assertEqual(response.validation_samples, 1) - self.assertEqual(response.client_model_flops, 3) - self.assertEqual(response.server_model_flops, 4) + self.assertEqual(response.client_fw_flops, 3) + self.assertEqual(response.server_fw_flops, 4) + self.assertEqual(response.client_bw_flops, 6) + self.assertEqual(response.server_bw_flops, 8) class RPCDeviceServicerBatteryEmptyTest(unittest.TestCase): @@ -355,7 +364,9 @@ class RPCDeviceServicerBatteryEmptyTest(unittest.TestCase): self.mock_device.train_global.side_effect = BatteryEmptyException( "Battery empty" ) - request = connection_pb2.TrainGlobalRequest() + request = connection_pb2.TrainGlobalRequest( + optimizer_state=state_dict_to_proto(None) + ) self._test_device_out_of_battery_lets_rpc_fail(request, "TrainGlobal") def test_stop_at_set_weights(self): @@ -502,7 +513,7 @@ class RequestDispatcherTest(unittest.TestCase): ) client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = ( - self.dispatcher.train_global_on("1", 42, 43) + self.dispatcher.train_global_on("1", 42, 43, 3, {"optimizer_state": 44}) ) self.assertEqual(client_weights, self.weights) @@ -513,19 +524,31 @@ class RequestDispatcherTest(unittest.TestCase): self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics) self.mock_stub.TrainGlobal.assert_called_once_with( connection_pb2.TrainGlobalRequest( - epochs=42, round_no=43, optimizer_state=state_dict_to_proto(None) + epochs=42, + round_no=43, + 
adaptive_threshold_value=3, + optimizer_state=state_dict_to_proto({"optimizer_state": 44}), ) ) def test_train_global_on_with_error(self): self.mock_stub.TrainGlobal.side_effect = grpc.RpcError() - response = self.dispatcher.train_global_on("1", 42, round_no=43) + response = self.dispatcher.train_global_on( + "1", + 42, + round_no=43, + adaptive_threshold_value=3, + optimizer_state={"optimizer_state": 44}, + ) self.assertEqual(response, False) self.mock_stub.TrainGlobal.assert_called_once_with( connection_pb2.TrainGlobalRequest( - epochs=42, round_no=43, optimizer_state=state_dict_to_proto(None) + epochs=42, + round_no=43, + adaptive_threshold_value=3, + optimizer_state=state_dict_to_proto({"optimizer_state": 44}), ) ) @@ -809,14 +832,16 @@ class RequestDispatcherTest(unittest.TestCase): connection_pb2.DatasetModelInfoResponse( train_samples=42, validation_samples=21, - client_model_flops=42, - server_model_flops=21, + client_fw_flops=1, + server_fw_flops=2, + client_bw_flops=3, + server_bw_flops=4, ) ) response = self.dispatcher.get_dataset_model_info_on("1") - self.assertEqual(response, (42, 21, 42, 21)) + self.assertEqual(response, (42, 21, 1, 2, 3, 4)) self.mock_stub.GetDatasetModelInfo.assert_called_once_with( connection_pb2.DatasetModelInfoRequest() ) diff --git a/edml/tests/core/server_test.py b/edml/tests/core/server_test.py index 1331f8aa7bad195eb02044bcf25d381fe2ee2a13..c855d82bdb82b26356b91c04960c4c5f8dafd164 100644 --- a/edml/tests/core/server_test.py +++ b/edml/tests/core/server_test.py @@ -13,6 +13,7 @@ from edml.core.battery import Battery from edml.core.client import DeviceClient from edml.core.device import Device from edml.core.server import DeviceServer +from edml.helpers.logging import SimpleLogger class ClientModel(nn.Module): @@ -76,6 +77,7 @@ class PSLTest(unittest.TestCase): ] }, "own_device_id": "d0", + "simulate_parallelism": False, } ) # init models with fixed weights for repeatability @@ -90,10 +92,19 @@ class PSLTest(unittest.TestCase): ) server_model = ServerModel() server_model.load_state_dict(server_state_dict) + server_model.get_optimizer_params = ( + server_model.parameters + ) # using a model provider, this is created automatically client_model1 = ClientModel() client_model1.load_state_dict(client_state_dict) + client_model1.get_optimizer_params = ( + client_model1.parameters + ) # using a model provider, this is created automatically client_model2 = ClientModel() client_model2.load_state_dict(client_state_dict) + client_model2.get_optimizer_params = ( + client_model2.parameters + ) # using a model provider, this is created automatically self.server = DeviceServer( model=server_model, loss_fn=torch.nn.L1Loss(), cfg=cfg ) @@ -157,6 +168,7 @@ class PSLTest(unittest.TestCase): node_device = Mock(Device) node_device.battery = Mock(Battery) + node_device.logger = Mock(SimpleLogger) node_device.train_batch_on_client_only_on.side_effect = get_client_side_effect( "train_single_batch" ) diff --git a/edml/tests/core/start_device_test.py b/edml/tests/core/start_device_test.py index abb5d9d61d68c791657d168f9e266abec366a74c..dd1f40955f492fda856fbb05194b3d22aa1e165f 100644 --- a/edml/tests/core/start_device_test.py +++ b/edml/tests/core/start_device_test.py @@ -4,11 +4,51 @@ from copy import deepcopy import torch from omegaconf import OmegaConf -from torch.autograd import Variable -from edml.core.start_device import _get_models from edml.helpers.model_splitting import Part from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder +from
edml.tests.models.model_loading_helpers import ( + _get_model_from_model_provider_config, +) + + +def plot_grad_flow(named_parameters): + """Plots the gradients flowing through different layers in the net during training. + Can be used for checking for possible gradient vanishing / exploding problems. + + Usage: Plug this function into the Trainer class after loss.backward() as + "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow""" + from matplotlib import pyplot as plt + import numpy as np + from matplotlib.lines import Line2D + + ave_grads = [] + max_grads = [] + layers = [] + for n, p in named_parameters: + if p.requires_grad and ("bias" not in n): + layers.append(n) + ave_grads.append(p.grad.abs().mean()) + max_grads.append(p.grad.abs().max()) + plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c") + plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b") + plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k") + plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical") + plt.xlim(left=0, right=len(ave_grads)) + plt.ylim(bottom=-0.001) # zoom in on the lower gradient regions + plt.xlabel("Layers") + plt.ylabel("average gradient") + plt.title("Gradient flow") + plt.grid(True) + plt.legend( + [ + Line2D([0], [0], color="c", lw=4), + Line2D([0], [0], color="b", lw=4), + Line2D([0], [0], color="k", lw=4), + ], + ["max-gradient", "mean-gradient", "zero-gradient"], + ) + plt.show() class GetModelsTest(unittest.TestCase): @@ -22,17 +62,8 @@ class GetModelsTest(unittest.TestCase): ) ) - def _get_model_from_model_provider_config(self, config_name): - self.cfg.model_provider = OmegaConf.load( - os.path.join( - os.path.dirname(__file__), - f"../../config/model_provider/{config_name}.yaml", - ) - ) - return _get_models(self.cfg) - def test_load_resnet20(self): - client, server = self._get_model_from_model_provider_config("resnet20") + client, server = _get_model_from_model_provider_config(self.cfg, "resnet20") self.assertIsInstance(client, Part) self.assertIsInstance(server, Part) self.assertEqual(len(client.layers), 4) @@ -40,29 +71,18 @@ class GetModelsTest(unittest.TestCase): self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100)) def test_load_resnet20_with_ae(self): - client, server = self._get_model_from_model_provider_config( - "resnet20-with-autoencoder" + client, server = _get_model_from_model_provider_config( + self.cfg, "resnet20-with-autoencoder" ) self.assertIsInstance(client, ClientWithAutoencoder) self.assertIsInstance(server, ServerWithAutoencoder) self.assertEqual(len(client.model.layers), 4) self.assertEqual(len(server.model.layers), 5) self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100)) - optimizer = torch.optim.Adam(server.parameters()) - smashed_data = client(torch.zeros(1, 3, 32, 32)) - server_smashed_data = Variable(smashed_data, requires_grad=True) - output_train = server(server_smashed_data) - loss_train = torch.nn.functional.cross_entropy( - output_train, torch.zeros((1, 100)) - ) - loss_train.backward() - optimizer.step() - smashed_data.backward(server_smashed_data.grad) - optimizer.step() def test_training_resnet20_with_ae_as_non_trainable_layers(self): - client_encoder, server_decoder = self._get_model_from_model_provider_config( - "resnet20-with-autoencoder" + client_encoder, server_decoder = _get_model_from_model_provider_config( + self.cfg, "resnet20-with-autoencoder" ) client_params = deepcopy(str(client_encoder.model.state_dict()))
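# Note: the state dicts are snapshotted as strings before the training step so
# the test can later compare them textually and assert that the frozen
# autoencoder parameters stayed identical while the trainable client and
# server parts did change.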
encoder_params = deepcopy(str(client_encoder.autoencoder.state_dict())) @@ -70,17 +90,18 @@ class GetModelsTest(unittest.TestCase): decoder_params = deepcopy(str(server_decoder.autoencoder.state_dict())) # Training loop - client_optimizer = torch.optim.Adam(client_encoder.parameters()) - server_optimizer = torch.optim.Adam(server_decoder.parameters()) - smashed_data = client_encoder(torch.zeros(1, 3, 32, 32)) - server_smashed_data = Variable(smashed_data, requires_grad=True) - output_train = server_decoder(server_smashed_data) + client_optimizer = torch.optim.Adam(client_encoder.get_optimizer_params()) + server_optimizer = torch.optim.Adam(server_decoder.get_optimizer_params()) + smashed_data = client_encoder(torch.zeros(7, 3, 32, 32)) + output_train = server_decoder(smashed_data) loss_train = torch.nn.functional.cross_entropy( - output_train, torch.rand((1, 100)) + output_train, torch.rand((7, 100)) ) loss_train.backward() server_optimizer.step() - smashed_data.backward(server_smashed_data.grad) + client_encoder.trainable_layers_output.backward( + server_decoder.trainable_layers_input.grad + ) client_optimizer.step() # check that AE hasn't changed, but client and server have @@ -90,7 +111,7 @@ class GetModelsTest(unittest.TestCase): self.assertNotEqual(server_params, str(server_decoder.model.state_dict())) def test_load_resnet110(self): - client, server = self._get_model_from_model_provider_config("resnet110") + client, server = _get_model_from_model_provider_config(self.cfg, "resnet110") self.assertIsInstance(client, Part) self.assertIsInstance(server, Part) self.assertEqual(len(client.layers), 4) @@ -98,8 +119,8 @@ class GetModelsTest(unittest.TestCase): self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100)) def test_load_resnet110_with_ae(self): - client, server = self._get_model_from_model_provider_config( - "resnet110-with-autoencoder" + client, server = _get_model_from_model_provider_config( + self.cfg, "resnet110-with-autoencoder" ) self.assertIsInstance(client, ClientWithAutoencoder) self.assertIsInstance(server, ServerWithAutoencoder) diff --git a/edml/tests/helpers/flops_test.py b/edml/tests/helpers/flops_test.py index 4b7f0d66bbfe6c74ba37e517969bb596842277b4..cb966d4864c1b11b16f75e1f1ed94a504db98a1f 100644 --- a/edml/tests/helpers/flops_test.py +++ b/edml/tests/helpers/flops_test.py @@ -1,10 +1,15 @@ +import os import unittest import torch import torch.nn as nn +from omegaconf import OmegaConf from edml.helpers.flops import estimate_model_flops from edml.models.mnist_models import ClientNet, ServerNet +from edml.tests.models.model_loading_helpers import ( + _get_model_from_model_provider_config, +) class FullTestModel(nn.Module): @@ -55,9 +60,15 @@ class FlopsTest(unittest.TestCase): client_flops = estimate_model_flops(client_model, inputs) server_flops = estimate_model_flops(server_model, server_inputs) - self.assertEqual(client_flops, 30000) - self.assertEqual(server_flops, 101000) - self.assertEqual(full_flops, client_flops + server_flops) + self.assertEqual(client_flops, {"BW": 60000, "FW": 30000}) + self.assertEqual(server_flops, {"BW": 202000, "FW": 101000}) + self.assertEqual( + full_flops, + { + "BW": client_flops["BW"] + server_flops["BW"], + "FW": client_flops["FW"] + server_flops["FW"], + }, + ) def test_mnist_split_count(self): client_model = ClientNet() @@ -69,5 +80,34 @@ class FlopsTest(unittest.TestCase): client_flops = estimate_model_flops(client_model, inputs) server_flops = estimate_model_flops(server_model, server_inputs) - 
self.assertEqual(client_flops, 5405760) - self.assertEqual(server_flops, 11215800) + self.assertEqual(client_flops, {"BW": 10811520, "FW": 5405760}) + self.assertEqual(server_flops, {"BW": 22431600, "FW": 11215800}) + + def test_autoencoder_does_not_affect_backward_flops(self): + os.chdir(os.path.join(os.path.dirname(__file__), "../../../")) + client, server = _get_model_from_model_provider_config( + OmegaConf.create({}), "resnet20" + ) + client_with_encoder, server_with_decoder = ( + _get_model_from_model_provider_config( + OmegaConf.create({}), "resnet20-with-autoencoder" + ) + ) + input = torch.randn(1, 3, 32, 32) + + ae_smashed_data = client_with_encoder(input) + smashed_data = client(input) + + client_flops = estimate_model_flops(client, input) + server_flops = estimate_model_flops(server, smashed_data) + + client_with_encoder_flops = estimate_model_flops(client_with_encoder, input) + server_with_decoder_flops = estimate_model_flops( + server_with_decoder, ae_smashed_data + ) + + self.assertEqual(client_flops["BW"], client_with_encoder_flops["BW"]) + self.assertEqual(server_flops["BW"], server_with_decoder_flops["BW"]) + + self.assertGreater(client_with_encoder_flops["FW"], client_flops["FW"]) + self.assertGreater(server_with_decoder_flops["FW"], server_flops["FW"]) diff --git a/edml/tests/models/model_loading_helpers.py b/edml/tests/models/model_loading_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..18b88a0618c7c3ac490400ee500a019ec19578ae --- /dev/null +++ b/edml/tests/models/model_loading_helpers.py @@ -0,0 +1,15 @@ +import os + +from omegaconf import OmegaConf + +from edml.core.start_device import _get_models + + +def _get_model_from_model_provider_config(cfg, config_name): + cfg.model_provider = OmegaConf.load( + os.path.join( + os.path.dirname(__file__), + f"../../config/model_provider/{config_name}.yaml", + ) + ) + return _get_models(cfg) diff --git a/edml/tests/models/model_provider_test.py b/edml/tests/models/model_provider_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bf872d0c07d0e40188208978e258fd1298c74eaa --- /dev/null +++ b/edml/tests/models/model_provider_test.py @@ -0,0 +1,43 @@ +import os +import unittest + +from omegaconf import OmegaConf + +from edml.models.provider.base import has_autoencoder +from edml.tests.models.model_loading_helpers import ( + _get_model_from_model_provider_config, +) + + +class ModelProviderTest(unittest.TestCase): + def setUp(self): + os.chdir(os.path.join(os.path.dirname(__file__), "../../../")) + self.cfg = OmegaConf.create({"some_key": "some_value"}) + + def test_has_autoencoder_for_model_without_autoencoder(self): + client, server = _get_model_from_model_provider_config(self.cfg, "resnet20") + self.assertFalse(has_autoencoder(client)) + self.assertFalse(has_autoencoder(server)) + + def test_has_autoencoder_for_model_with_autoencoder(self): + client, server = _get_model_from_model_provider_config( + self.cfg, "resnet20-with-autoencoder" + ) + self.assertTrue(has_autoencoder(client)) + self.assertTrue(has_autoencoder(server)) + + def test_get_optimizer_params_for_model_without_autoencoder(self): + client, server = _get_model_from_model_provider_config(self.cfg, "resnet20") + self.assertEqual(list(client.get_optimizer_params()), list(client.parameters())) + self.assertEqual(list(server.get_optimizer_params()), list(server.parameters())) + + def test_get_optimizer_params_for_model_with_autoencoder(self): + client, server = _get_model_from_model_provider_config( + self.cfg, 
"resnet20-with-autoencoder" + ) + self.assertEqual( + list(client.get_optimizer_params()), list(client.model.parameters()) + ) + self.assertEqual( + list(server.get_optimizer_params()), list(server.model.parameters()) + ) diff --git a/results/result_generation.ipynb b/results/result_generation.ipynb index 5c3ed118508856d66ca7b70f0027b99a2cb8052f..fe4e16d604eb3d3ffcc2f1bc3f8c731019b6b797 100644 --- a/results/result_generation.ipynb +++ b/results/result_generation.ipynb @@ -13,31 +13,56 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, + "outputs": [], "source": [ - "# Downloading the dataframes (takes time)" + "df_base_dir = \"./dataframes\"\n", + "projects_with_model = [\n", + " (\"5_devices_unlimited_new\", \"resnet110\"),\n", + " (\"50_devices_unlimited_new\", \"resnet110\"),\n", + " (\"controller_comparison\", \"resnet110\"),\n", + " (\"controller_comparison_het_bat\", \"resnet110\"),\n", + " (\"controller_comparison_homogeneous\", \"resnet110\")\n", + "]" ], "metadata": { "collapsed": false }, - "id": "3e72d0fe76a55e56" + "id": "5b81c8c9ba4b483d" }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [ - "df_base_dir = \"./dataframes\"\n", - "projects_with_model = [\n", - " (\"5_devices_unlimited_new\", \"resnet110\"),\n", - " (\"50_devices_unlimited_new\", \"resnet110\"),\n", - " (\"controller_comparison\", \"resnet110\")\n", - "]" + "strategy_autoencoder = {\n", + " \"psl_sequential__\": False,\n", + " \"fed___\": False,\n", + " \"split___\": False,\n", + " \"swarm_sequential__\": False,\n", + " \"swarm_max_battery__\": False,\n", + " \"swarm_smart__\": False,\n", + " \"psl_sequential_static_at_resnet_decoderpth\": True,\n", + " \"psl_sequential__resnet_decoderpth\": True,\n", + " \"psl_sequential_static_at_\": False,\n", + "}\n", + "strategies = list(strategy_autoencoder.keys())" ], "metadata": { "collapsed": false }, - "id": "5b81c8c9ba4b483d" + "id": "26e70757d4650fc1" + }, + { + "cell_type": "markdown", + "source": [ + "# Downloading the dataframes (takes time)" + ], + "metadata": { + "collapsed": false + }, + "id": "3e72d0fe76a55e56" }, { "cell_type": "code", @@ -45,29 +70,7 @@ "outputs": [], "source": [ "for project_name, _ in projects_with_model:\n", - " save_dataframes(project_name=project_name, strategies=[\n", - " #\"swarm_seq\",\n", - " #\"fed\",\n", - " #\"swarm_max\",\n", - " #\"swarm_rand\",\n", - " #\"swarm_smart\",\n", - " #\"split\",\n", - " #\"psl_rand_\",\n", - " #\"psl_sequential_\",\n", - " #\"psl_max_batteries_\",\n", - " #\"swarm_rand_\",\n", - " #\"swarm_sequential_\",\n", - " #\"swarm_max_batteries_\",\n", - " \"psl_sequential__\",\n", - " \"fed___\",\n", - " \"split___\",\n", - " \"swarm_sequential__\",\n", - " \"swarm_max_battery__\",\n", - " \"swarm_smart__\",\n", - " \"psl_sequential_static_at_resnet_decoderpth\",\n", - " \"psl_sequential__resnet_decoderpth\",\n", - " \"psl_sequential_static_at_\",\n", - " ])" + " save_dataframes(project_name=project_name, strategies=strategies)" ], "metadata": { "collapsed": false @@ -90,27 +93,23 @@ "outputs": [], "source": [ "# Required for total number of FLOPs computation\n", + "# plain = forward FLOPs without AE = 1/2 backward FLOPs with or without AE (number of backward FLOPs equal with and without AE as AE is skipped during BP)\n", + "# ae = forward FLOPs with AE\n", "model_flops = {\n", - " \"resnet20\": 41498880,\n", - " \"resnet20_ae\": 45758720,\n", - " \"resnet110\": 258136320,\n", - " \"resnet110_ae\": 262396160,\n", - " \"tcn\": 27240000,\n", - " 
\"simple_conv\": 16621560\n", + " \"resnet20\": {\"plain\": 41498880, \"ae\": 45758720},\n", + " \"resnet110\": {\"plain\": 258136320, \"ae\": 262396160},\n", + " \"tcn\": {\"plain\": 27240000, \"ae\": 27240000}, # no AE implemented yet\n", + " \"simple_conv\": {\"plain\": 16621560, \"ae\": 16621560}, # no AE implemented yet\n", "}\n", "\n", "client_model_flops = {\n", - " \"resnet20\": 15171584,\n", - " \"resnet20_ae\": 19005440,\n", - " \"resnet110\": 88408064,\n", - " \"resnet110_ae\": 92241920,\n", + " \"resnet20\": {\"plain\": 15171584, \"ae\": 19005440},\n", + " \"resnet110\": {\"plain\": 88408064, \"ae\": 92241920},\n", "}\n", "\n", "server_model_flops = {\n", - " \"resnet20\": 26327296,\n", - " \"resnet20_ae\": 26753280,\n", - " \"resnet110\": 169728256,\n", - " \"resnet110_ae\": 170154240,\n", + " \"resnet20\": {\"plain\": 26327296, \"ae\": 26753280},\n", + " \"resnet110\": {\"plain\": 169728256, \"ae\": 170154240},\n", "}\n", "experiment_batch_size = 64" ], @@ -144,7 +143,7 @@ " dataframes = load_dataframes(proj_name, df_base_dir)\n", " print(\" generating metrics\")\n", " generate_metric_files(dataframes, proj_name, model_flops[model_name], client_model_flops[model_name],\n", - " # TODO distinguish AE\n", + " strategy_autoencoder_mapping=strategy_autoencoder,\n", " base_path=metrics_base_path, batch_size=experiment_batch_size)\n", " print(\" generating plots\")\n", " generate_plots(dataframes, proj_name)" diff --git a/results/result_generation.py b/results/result_generation.py index caa42642976785d82f0313733b324c9cd63b1166..f22bca3d8558982250e4fa8d3376ee41b70d7539 100644 --- a/results/result_generation.py +++ b/results/result_generation.py @@ -39,6 +39,8 @@ LABEL_MAPPING = { "device": "Device", } +DPI = 300 + def scale_parallel_time(run_df, scale_factor=1.0): """ @@ -230,7 +232,13 @@ def load_dataframes(project_name, base_dir="./dataframes"): return history_groups -def get_total_flops(groups, total_model_flops, client_model_flops, batch_size=64): +def get_total_flops( + groups, + total_model_flops, + client_model_flops, + strategy_autoencoder_mapping, + batch_size=64, +): """ Returns the total number of FLOPs for each group. 
Args: @@ -242,6 +250,13 @@ def get_total_flops(groups, total_model_flops, client_model_flops, batch_size=64 """ flops_per_group = {"strategy": [], "flops": []} for (strategy, job), group in groups.items(): + # determine model FLOPs depending on whether an autoencoder was used + if strategy_autoencoder_mapping[strategy]: # AE + total_model_fw_flops = total_model_flops["ae"] + else: # no AE + total_model_fw_flops = total_model_flops["plain"] + total_model_bw_flops = 2 * total_model_flops["plain"] + client_model_bw_flops = 2 * client_model_flops["plain"] if job == "train": flops = 0 num_runs = 1 # avoid division by 0 @@ -254,12 +269,12 @@ for run_df in runs: for col_name in run_df.columns: if col_name == "train_accuracy.num_samples": - flops += ( - run_df[col_name].sum() * total_model_flops * 3 + flops += run_df[col_name].sum() * ( + total_model_fw_flops + total_model_bw_flops ) # 1x forward + 1x backward (backward = 2x forward FLOPs) if col_name == "val_accuracy.num_samples": flops += ( - run_df[col_name].sum() * total_model_flops + run_df[col_name].sum() * total_model_fw_flops ) # 1x forward if col_name == "adaptive_learning_threshold_applied": # deduct the client backward FLOPs (2x forward) where client backprop was skipped @@ -285,13 +300,12 @@ def get_total_flops(groups, total_model_flops, client_model_flops, batch_size=64 ) flops -= ( len(run_df[col_name].dropna()) - * client_model_flops - * 2 + * client_model_bw_flops * avg_samples_per_batch ) else: # numbers of samples skipped are logged -> sum up flops -= ( - run_df[col_name].sum() * client_model_flops * 2 + run_df[col_name].sum() * client_model_bw_flops ) flops = flops / num_runs flops_per_group["strategy"].append(STRATEGY_MAPPING[strategy]) @@ -401,7 +415,7 @@ def plot_remaining_devices(devices_per_round, save_path=None): save_path: (str) the path to save the plot to """ - plt.figure() + plt.figure(dpi=DPI) num_rounds = [0] max_devices = [] for (strategy, job), (rounds, num_devices) in devices_per_round.items(): @@ -492,7 +506,7 @@ def plot_accuracies(accuracies_per_round, save_path=None, phase="train"): accuracies_per_round: (dict) the accuracy (list(float)) per round (list(int)) for each group save_path: (str) the path to save the plot to """ - plt.figure() + plt.figure(dpi=DPI) num_rounds = [0] for (strategy, job), (rounds, accs) in accuracies_per_round.items(): plt.plot(rounds, accs, label=f"{STRATEGY_MAPPING[strategy]}") @@ -517,7 +531,7 @@ def plot_accuracies_over_time(accuracies_per_time, save_path=None, phase="train" accuracies_per_time: (dict) the accuracy (list(float)) per time (list(float)) for each group save_path: (str) the path to save the plot to """ - plt.figure() + plt.figure(dpi=DPI) for (strategy, job), (time, accs) in accuracies_per_time.items(): plt.plot(time, accs, label=f"{STRATEGY_MAPPING[strategy]}") plt.xlabel(LABEL_MAPPING["runtime"]) @@ -688,7 +702,7 @@ def plot_batteries_over_time( aggregated: (bool) whether the battery is aggregated or not """ if aggregated: - plt.figure() + plt.figure(dpi=DPI) plt.rcParams.update({"font.size": 13}) for (strategy, job), series in batteries_over_time.items(): runtime = max_runtimes[(strategy, job)] @@ -709,7 +723,7 @@ plt.close() else: for (strategy, job), series_dict in batteries_over_time.items(): - plt.figure() + plt.figure(dpi=DPI) plt.rcParams.update({"font.size": 13}) for device_id, series in series_dict.items(): runtime = max_runtimes[(strategy, job)] @@ -739,7 +753,7 @@ def
plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T aggregated: (bool) whether the battery is aggregated or not """ if aggregated: - plt.figure() + plt.figure(dpi=DPI) plt.rcParams.update({"font.size": 13}) num_rounds = [0] for (strategy, job), series in batteries_over_epoch.items(): @@ -760,7 +774,7 @@ def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T plt.close() else: for (strategy, job), series_dict in batteries_over_epoch.items(): - plt.figure() + plt.figure(dpi=DPI) plt.rcParams.update({"font.size": 13}) num_rounds = [0] for device_id, series in series_dict.items(): @@ -899,7 +913,7 @@ def plot_batteries_over_time_with_activity( 0, train_time_end(server_train_times, client_train_times) * 1.05, ) # set end 5% after last activity timestamp - plt.figure() + plt.figure(dpi=DPI) plt.rcParams.update({"font.size": 13}) # battery_plot plt.subplot(2, 1, 1) @@ -976,7 +990,7 @@ def plot_batteries_over_epoch_with_activity_at_time_scale( ], [str(i) for i in range(0, len(start_times), max(1, len(start_times) // 8))], ) - plt.figure() + plt.figure(dpi=DPI) plt.rcParams.update({"font.size": 13}) # battery_plot plt.subplot(2, 1, 1) @@ -1048,7 +1062,7 @@ def plot_batteries_over_epoch_with_activity_at_epoch_scale( ).sort_values("start") # battery plot - plt.figure() + plt.figure(dpi=DPI) plt.rcParams.update({"font.size": 13}) plt.subplot(2, 1, 1) num_rounds = [] @@ -1242,6 +1256,7 @@ def generate_metric_files( project_name, total_model_flops, client_model_flops, + strategy_autoencoder_mapping, base_path="./metrics", batch_size=64, ): @@ -1265,7 +1280,11 @@ def generate_metric_files( ).set_index("strategy") total_flops = pd.DataFrame.from_dict( get_total_flops( - history_groups, total_model_flops, client_model_flops, batch_size + history_groups, + total_model_flops, + client_model_flops, + strategy_autoencoder_mapping, + batch_size, ) ).set_index("strategy") df = pd.concat([test_acc, comm_overhead, total_flops], axis=1)
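To make the new forward/backward FLOPs contract concrete, a minimal self-contained sketch of calling the updated helper; the toy model below is invented for illustration, and only estimate_model_flops and its {"FW", "BW"} return format come from this diff:

import torch
from torch import nn

from edml.helpers.flops import estimate_model_flops

# Toy two-layer model; any module that fvcore can trace works here.
model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 10))
sample = torch.randn(4, 32)  # first dimension is the batch dimension

flops = estimate_model_flops(model, sample)
# Without a frozen autoencoder in between, backward FLOPs default to twice
# the forward FLOPs:
assert flops["BW"] == 2 * flops["FW"]

# Under this convention, one training step on the sample batch costs one
# forward plus one backward pass:
train_flops_per_batch = flops["FW"] + flops["BW"]
print(train_flops_per_batch)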