diff --git a/config/controller/parallel_swarm2.yaml b/config/controller/parallel_swarm2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fcf824f56607b1f672a2f6d5d5cef8e457d5480a
--- /dev/null
+++ b/config/controller/parallel_swarm2.yaml
@@ -0,0 +1,5 @@
+name: psl
+_target_: edml.controllers.parallel_split_controller.ParallelSplitController
+_partial_: true
+scheduler:
+  _target_: edml.controllers.scheduler.sequential.SequentialNextServerScheduler
diff --git a/config/controller/parallel_swarm_ash_1.65.yaml b/config/controller/parallel_swarm_ash_1.65.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9f5616323ef4deb46eefcf2bf4ba091b406437da
--- /dev/null
+++ b/config/controller/parallel_swarm_ash_1.65.yaml
@@ -0,0 +1,7 @@
+_target_: edml.controllers.parallel_split_controller.ParallelSplitController
+_partial_: true
+scheduler:
+  _target_: edml.controllers.scheduler.sequential.SequentialNextServerScheduler
+adaptive_threshold_fn:
+  _target_: edml.controllers.adaptive_threshold_mechanism.static.StaticAdaptiveThresholdFn
+  threshold: 1.65
diff --git a/config/default.yaml b/config/default.yaml
index 4e6cf1abac94a4364b165312cbd7926837d89756..1b82a0f94f915f7c73ed6065c3959532d3bc08d3 100644
--- a/config/default.yaml
+++ b/config/default.yaml
@@ -4,12 +4,12 @@ defaults:
   - dataset: mnist
   - battery: flops_and_communication
   - loss_fn: !!null
+  - experiment: default_experiment
   - model_provider: mnist
   - optimizer: !!null
   - scheduler: !!null
   - seed: default
   - topology: equal_batteries
-  - experiment: default_experiment
   - grpc: default
   - wandb: default
   - _self_
@@ -17,6 +17,12 @@
 own_device_id: "d0"
 num_devices: ${len:${topology.devices}}
 
+# Define which config attributes make up the experiment group name:
+group_by:
+  - controller: [ name, scheduler: name, adaptive_threshold_fn: name ]
+# The group attribute is composed by the group_name resolver from these attributes.
+group: ${group_name:${group_by}}
+
 # This registers the framework-provided configuration files with hydra.
 hydra:
   searchpath:
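Note: the `group_by`/`group` pair added here feeds the `group_name` resolver that this patch registers in `edml/helpers/config_helpers.py` further down. A rough standalone sketch of what the resolver produces (plain OmegaConf, not the project's helper; the names are taken from the configs in this patch):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "controller": {
                "name": "psl",
                "scheduler": {"name": "sequential"},
                "adaptive_threshold_fn": {"name": "static_at"},
            }
        }
    )

    # Walk the attribute paths named under group_by and join the resolved values.
    paths = [
        ["controller", "name"],
        ["controller", "scheduler", "name"],
        ["controller", "adaptive_threshold_fn", "name"],
    ]
    group = "_".join(OmegaConf.select(cfg, ".".join(p), default="") for p in paths)
    print(group)  # psl_sequential_static_at
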
diff --git a/config/experiment/baseline.yaml b/config/experiment/baseline.yaml
index cb022990d65e2a2305011e7ef99968d366912953..51dd14aed7bb47f4a5cbfc08b02b81cac5e85242 100644
--- a/config/experiment/baseline.yaml
+++ b/config/experiment/baseline.yaml
@@ -19,8 +19,8 @@ early_stopping_metric: accuracy
 
 # Dataset partitioning.
 partition: True
-fractions: [ 0.1, 0.1, 0.1, 0.1, 0.1 ] # set to !!null if dataset should not be partitioned or partitioned equally
-latency: [ 0.0, 1.0, 0.0, 0.0, 0.0 ] # set to !!null for no latency
+fractions: !!null # set to !!null if dataset should not be partitioned or partitioned equally
+latency: !!null # set to !!null for no latency
 
 # Debug.
 load_single_batch_for_debugging: False
diff --git a/config/experiment/cifar100-effectiveness-adaptive-threshold-mechanism-none.yaml b/config/experiment/cifar100-effectiveness-adaptive-threshold-mechanism-none.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..89d487cef2d286fb0139cf965c44cbb4551b05bd
--- /dev/null
+++ b/config/experiment/cifar100-effectiveness-adaptive-threshold-mechanism-none.yaml
@@ -0,0 +1,26 @@
+# Base properties for the experiment.
+project: inda-ml-comparisons
+name: cifar100-effectiveness-adaptive-threshold-mechanism-none
+job: train
+
+# Training parameters.
+batch_size: 64
+max_epochs: 1
+max_rounds: 200
+metrics: [ accuracy ]
+
+# Checkpoint saving and early stopping.
+save_weights: True
+server_model_save_path: "edml/models/weights/"
+client_model_save_path: "edml/models/weights/"
+early_stopping: True
+early_stopping_patience: 200
+early_stopping_metric: accuracy
+
+# Dataset partitioning.
+partition: True
+fractions: !!null
+latency: !!null
+
+# Debug.
+load_single_batch_for_debugging: False
diff --git a/config/experiment/cifar100-effectiveness-adaptive-threshold-mechanism.yaml b/config/experiment/cifar100-effectiveness-adaptive-threshold-mechanism.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ad7614dcf4a01d6c75776bef1f136bdea03c1010
--- /dev/null
+++ b/config/experiment/cifar100-effectiveness-adaptive-threshold-mechanism.yaml
@@ -0,0 +1,26 @@
+# Base properties for the experiment.
+project: inda-ml-comparisons
+name: cifar100-effectiveness-adaptive-threshold-mechanism
+job: train
+
+# Training parameters.
+batch_size: 64
+max_epochs: 1
+max_rounds: 200
+metrics: [ accuracy ]
+
+# Checkpoint saving and early stopping.
+save_weights: True
+server_model_save_path: "edml/models/weights/"
+client_model_save_path: "edml/models/weights/"
+early_stopping: True
+early_stopping_patience: 200
+early_stopping_metric: accuracy
+
+# Dataset partitioning.
+partition: True
+fractions: !!null
+latency: !!null
+
+# Debug.
+load_single_batch_for_debugging: False
diff --git a/config/optimizer/sdg_with_momentum.yaml b/config/optimizer/sdg_with_momentum.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a32f2ff6ddbcb1777c847d651c39386711cb5252
--- /dev/null
+++ b/config/optimizer/sdg_with_momentum.yaml
@@ -0,0 +1,4 @@
+_target_: torch.optim.SGD
+lr: 0.1
+momentum: 0.9
+weight_decay: 0.0001
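Note: the new optimizer config maps onto a plain `torch.optim.SGD` instance; roughly equivalent hand-written code looks as follows (the `nn.Linear` is only a placeholder, the framework supplies the real model parameters when it instantiates the config):

    import torch
    from torch import nn

    model = nn.Linear(10, 2)  # stand-in module, just to have parameters
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001
    )
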
diff --git a/config/sweep/cifar100/cifar100-effectiveness-adaptive-threshold-mechanism.yaml b/config/sweep/cifar100/cifar100-effectiveness-adaptive-threshold-mechanism.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e8f367ee18ad6235e9eaf81fc5c57393cca363f4
--- /dev/null
+++ b/config/sweep/cifar100/cifar100-effectiveness-adaptive-threshold-mechanism.yaml
@@ -0,0 +1,19 @@
+# @package _global_
+defaults:
+  - override /battery: unlimited
+  - override /dataset: cifar100
+  - override /experiment: cifar100-effectiveness-adaptive-threshold-mechanism
+  - override /loss_fn: cross_entropy
+  - override /model_provider: resnet20
+  - override /optimizer: sdg_with_momentum
+  - override /scheduler: multistep
+  - override /topology: equal_batteries
+  - _self_
+
+hydra:
+  mode: MULTIRUN
+  sweeper:
+    params:
+      +controller: parallel_swarm #parallel_swarm_ash_1.65
+      # +controller/scheduler: max_battery
+      # controller.adaptive_learning_threshold: 1.65
diff --git a/config/sweep/mnist/all.yaml b/config/sweep/mnist/all.yaml
index c3546725359f4eb3e08ecf3549b5f7c072b64f8f..8137a13cc92426ad4c5eefb7b5ad1b04205f54e2 100644
--- a/config/sweep/mnist/all.yaml
+++ b/config/sweep/mnist/all.yaml
@@ -11,5 +11,5 @@ hydra:
   mode: MULTIRUN
   sweeper:
     params:
-      +controller: fed,swarm,parallel_swarm
-      +controller/scheduler: max_battery,sequential,rand
+      +controller: swarm,parallel_swarm
+      controller/scheduler: max_battery,sequential,rand
diff --git a/edml/config/controller/adaptive_threshold_fn/dynamic.yaml b/edml/config/controller/adaptive_threshold_fn/dynamic.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4d81a4797e2a99b0697029f37d8e94c27f9fa707
--- /dev/null
+++ b/edml/config/controller/adaptive_threshold_fn/dynamic.yaml
@@ -0,0 +1,5 @@
+_target_: edml.controllers.adaptive_threshold_mechanism.dynamic.LogarithmicDecayAdaptiveThresholdFn
+name: log_decay_at
+starting_value: 4
+approach_value: 1
+decay_rate: 0.05
diff --git a/edml/config/controller/adaptive_threshold_fn/static.yaml b/edml/config/controller/adaptive_threshold_fn/static.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6ec21ff8dffffc281964d798c13b455ea9e43b15
--- /dev/null
+++ b/edml/config/controller/adaptive_threshold_fn/static.yaml
@@ -0,0 +1,3 @@
+name: static_at
+_target_: edml.controllers.adaptive_threshold_mechanism.static.StaticAdaptiveThresholdFn
+threshold: 1.65
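Note: both threshold configs are ordinary Hydra targets. Stripped of the extra `name` key (which only exists for group naming and is removed again by the drop-irrelevant-keys helpers added to `config_helpers.py` below), they can be instantiated directly. A small sketch, assuming the `edml` package from this patch is importable:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "_target_": "edml.controllers.adaptive_threshold_mechanism.static.StaticAdaptiveThresholdFn",
            "threshold": 1.65,
        }
    )
    threshold_fn = instantiate(cfg)
    print(threshold_fn.invoke(round_no=10))  # 1.65
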
diff --git a/edml/config/controller/fed.yaml b/edml/config/controller/fed.yaml
index 97562c045e264f7e58f12a9d0304d881e4830a4f..f0a0886c6aff43a87b8ca5d0616f44393a31ced3 100644
--- a/edml/config/controller/fed.yaml
+++ b/edml/config/controller/fed.yaml
@@ -1,2 +1,3 @@
+name: fed
 _target_: edml.controllers.fed_controller.FedController
 _partial_: true
diff --git a/edml/config/controller/parallel_swarm.yaml b/edml/config/controller/parallel_swarm.yaml
index d835de6f24284633ef29f49e902f2ef4903006ef..be9d2a6b43df282facea48954ff6fa6ca7907fdc 100644
--- a/edml/config/controller/parallel_swarm.yaml
+++ b/edml/config/controller/parallel_swarm.yaml
@@ -1,4 +1,6 @@
+name: psl
 _target_: edml.controllers.parallel_split_controller.ParallelSplitController
 _partial_: true
-scheduler:
-  _target_: edml.controllers.scheduler.sequential.SequentialNextServerScheduler
+defaults:
+  - scheduler: sequential
+  - adaptive_threshold_fn: !!null
diff --git a/edml/config/controller/scheduler/max_battery.yaml b/edml/config/controller/scheduler/max_battery.yaml
index 045603ca935f06679fe839119963ca74497c2fc2..d0a258cb82414f2adec4685c1e02cf559dd120d3 100644
--- a/edml/config/controller/scheduler/max_battery.yaml
+++ b/edml/config/controller/scheduler/max_battery.yaml
@@ -1 +1,2 @@
+name: max_battery
 _target_: edml.controllers.scheduler.max_battery.MaxBatteryNextServerScheduler
diff --git a/edml/config/controller/scheduler/rand.yaml b/edml/config/controller/scheduler/rand.yaml
index 1e148289377f65917fa8550d0fe0efead49613cf..1db64320ca7c09694b7108c176d8569e7fa88c17 100644
--- a/edml/config/controller/scheduler/rand.yaml
+++ b/edml/config/controller/scheduler/rand.yaml
@@ -1 +1,2 @@
+name: rand
 _target_: edml.controllers.scheduler.random.RandomNextServerScheduler
diff --git a/edml/config/controller/scheduler/sequential.yaml b/edml/config/controller/scheduler/sequential.yaml
index fffebf9175de6c3aea5fe8121f2687e6af4bf404..0f920dc24f8766f5a35b4ba2e4962f59fd041307 100644
--- a/edml/config/controller/scheduler/sequential.yaml
+++ b/edml/config/controller/scheduler/sequential.yaml
@@ -1 +1,2 @@
+name: sequential
 _target_: edml.controllers.scheduler.sequential.SequentialNextServerScheduler
diff --git a/edml/config/controller/swarm.yaml b/edml/config/controller/swarm.yaml
index 328b5085f3c96220a477a22007fc59cda943cd8a..f31a60a20c1e4cff7f8eec4bd9b1ff4958753ab5 100644
--- a/edml/config/controller/swarm.yaml
+++ b/edml/config/controller/swarm.yaml
@@ -1,4 +1,5 @@
+name: swarm
 _target_: edml.controllers.swarm_controller.SwarmController
 _partial_: true
-scheduler:
-  _target_: edml.controllers.scheduler.sequential.SequentialNextServerScheduler
+defaults:
+  - scheduler: sequential
diff --git a/edml/config/model_provider/resnet20.yaml b/edml/config/model_provider/resnet20.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..144ba60783f58c22e791f84c67725e779502a648
--- /dev/null
+++ b/edml/config/model_provider/resnet20.yaml
@@ -0,0 +1,9 @@
+_target_: edml.models.provider.cut_layer.CutLayerModelProvider
+model:
+  _target_: edml.models.resnet_models.ResNet
+  block:
+    _target_: hydra.utils.get_class
+    path: edml.models.resnet_models.BasicBlock
+  num_blocks: [ 3, 3, 3 ]
+  num_classes: 100
+cut_layer: 4
diff --git a/edml/controllers/adaptive_threshold_mechanism/__init__.py b/edml/controllers/adaptive_threshold_mechanism/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e85555cd7127d35fa7be4ffb2d0e1ce9c1517fd
--- /dev/null
+++ b/edml/controllers/adaptive_threshold_mechanism/__init__.py
@@ -0,0 +1,9 @@
+from abc import ABC, abstractmethod
+
+
+class AdaptiveThresholdFn(ABC):
+    """A function that returns the adaptive threshold value based on the current round."""
+
+    @abstractmethod
+    def invoke(self, round_no: int) -> float:
+        """Return the adaptive threshold value for the given round number."""
diff --git a/edml/controllers/adaptive_threshold_mechanism/dynamic.py b/edml/controllers/adaptive_threshold_mechanism/dynamic.py
new file mode 100644
index 0000000000000000000000000000000000000000..b73e3e1cfdb82a9183d7ad7a4afdecb6dfe2c91d
--- /dev/null
+++ b/edml/controllers/adaptive_threshold_mechanism/dynamic.py
@@ -0,0 +1,19 @@
+import numpy as np
+
+from edml.controllers.adaptive_threshold_mechanism import AdaptiveThresholdFn
+
+
+class LogarithmicDecayAdaptiveThresholdFn(AdaptiveThresholdFn):
+
+    def __init__(
+        self, starting_value: float, approach_value: float, decay_rate: float = 1.0
+    ):
+        super().__init__()
+        self._start = starting_value
+        self._end = approach_value
+        self._decay_rate = decay_rate
+
+    def invoke(self, round_no: int):
+        return self._end + (self._start - self._end) * np.exp(
+            -self._decay_rate * round_no
+        )
diff --git a/edml/controllers/adaptive_threshold_mechanism/static.py b/edml/controllers/adaptive_threshold_mechanism/static.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb6e78e12d9fc7238b231f591faccf5b8c6b1562
--- /dev/null
+++ b/edml/controllers/adaptive_threshold_mechanism/static.py
@@ -0,0 +1,10 @@
+from edml.controllers.adaptive_threshold_mechanism import AdaptiveThresholdFn
+
+
+class StaticAdaptiveThresholdFn(AdaptiveThresholdFn):
+    def __init__(self, threshold: float):
+        super().__init__()
+        self._threshold = threshold
+
+    def invoke(self, round_no: int) -> float:
+        return self._threshold
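Note: despite the class name, `LogarithmicDecayAdaptiveThresholdFn` implements an exponential decay from `starting_value` towards `approach_value`. With the defaults from `dynamic.yaml` the per-round thresholds fall off roughly like this (same formula, evaluated standalone):

    import numpy as np

    start, end, rate = 4.0, 1.0, 0.05  # values from dynamic.yaml
    for round_no in (0, 10, 20, 50, 100):
        threshold = end + (start - end) * np.exp(-rate * round_no)
        print(round_no, round(float(threshold), 3))
    # 0 4.0 | 10 2.82 | 20 2.104 | 50 1.246 | 100 1.02
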
diff --git a/edml/controllers/base_controller.py b/edml/controllers/base_controller.py
index 4967be56448e96a3ecf24e02433e91f3ceadc04b..64e763c96197971dfd1f8243405d9107fc7ce577 100644
--- a/edml/controllers/base_controller.py
+++ b/edml/controllers/base_controller.py
@@ -200,4 +200,4 @@ class BaseController(abc.ABC):
 
     def __model_prefix__(self):
         """Returns the model prefix for the current experiment."""
-        return f"{self.cfg.experiment.project}_{self.cfg.experiment.name}"
+        return f"{self.cfg.experiment.project}_{self.cfg.group}"
diff --git a/edml/controllers/parallel_split_controller.py b/edml/controllers/parallel_split_controller.py
index d250d21f4eccd914b2514121c8113f55b72dada8..895de9b52407d12ddb71b712f24a02a7843b984c 100644
--- a/edml/controllers/parallel_split_controller.py
+++ b/edml/controllers/parallel_split_controller.py
@@ -1,3 +1,7 @@
+from edml.controllers.adaptive_threshold_mechanism import AdaptiveThresholdFn
+from edml.controllers.adaptive_threshold_mechanism.static import (
+    StaticAdaptiveThresholdFn,
+)
 from edml.controllers.base_controller import BaseController
 from edml.controllers.scheduler.base import NextServerScheduler
 from edml.helpers.config_helpers import get_device_index_by_id
@@ -8,10 +12,12 @@ class ParallelSplitController(BaseController):
         self,
         cfg,
         scheduler: NextServerScheduler,
+        adaptive_threshold_fn: AdaptiveThresholdFn = StaticAdaptiveThresholdFn(0.0),
     ):
         super().__init__(cfg)
         scheduler.initialize(self)
         self._next_server_scheduler = scheduler
+        self._adaptive_threshold_fn = adaptive_threshold_fn
 
     def _train(self):
         client_weights = None
@@ -43,10 +49,13 @@ class ParallelSplitController(BaseController):
             )
 
             # Start parallel training of all client devices.
+            adaptive_threshold = self._adaptive_threshold_fn.invoke(i)
+            self.logger.log({"adaptive-threshold": adaptive_threshold})
             training_response = self.request_dispatcher.train_parallel_on_server(
                 server_device_id=server_device_id,
                 epochs=1,
                 round_no=i,
+                adaptive_learning_threshold=adaptive_threshold,
                 optimizer_state=optimizer_state,
             )
 
@@ -64,6 +73,7 @@ class ParallelSplitController(BaseController):
            )  # log the server device index for convenience
            if training_response is False:  # server device unavailable
+                print("Training response was false.")
                break
            else:
                cw, server_weights, metrics, _, optimizer_state = training_response
 
@@ -72,6 +82,7 @@ class ParallelSplitController(BaseController):
 
            early_stop = self.early_stopping(metrics, i)
            if early_stop:
+                print("Early stopping triggered.")
                break
 
            self._save_weights(
diff --git a/edml/core/client.py b/edml/core/client.py
index c89c1f985c7407b47fec574d9d79a1d66e7914df..bcedd28435f91f237ff31cf3fc68b1aa1ddb2c07 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -279,7 +279,7 @@ class DeviceClient:
                 train_batch_response is False or train_batch_response is None
             ):  # server device unavailable
                 break
-            server_grad, diagnostic_metrics = train_batch_response
+            server_grad, _server_loss, diagnostic_metrics = train_batch_response
             diagnostic_metric_container.merge(diagnostic_metrics)
             self.node_device.battery.update_flops(
                 self._model_flops * len(batch_data) * 2
diff --git a/edml/core/device.py b/edml/core/device.py
index ee12f4802a27717764b13e3da365566c37df6cbd..7fba252498c7ccdbfd7f11c2dcc9175ee8884dae 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -217,10 +217,17 @@ class NetworkDevice(Device):
     @update_battery
     @log_execution_time("logger", "train_parallel_split_learning")
     def train_parallel_split_learning(
-        self, clients: list[str], round_no: int, optimizer_state: dict[str, Any] = None
+        self,
+        clients: list[str],
+        round_no: int,
+        adaptive_learning_threshold: Optional[float] = None,
+        optimizer_state: dict[str, Any] = None
     ):
         return self.server.train_parallel_split_learning(
-            clients=clients, round_no=round_no, optimizer_state=optimizer_state
+            clients=clients,
+            round_no=round_no,
+            adaptive_learning_threshold=adaptive_learning_threshold,
+            optimizer_state=optimizer_state
         )
 
     @update_battery
@@ -446,10 +453,13 @@ class RPCDeviceServicer(DeviceServicer):
     def TrainBatch(self, request, context):
         activations = proto_to_tensor(request.smashed_data.activations)
         labels = proto_to_tensor(request.labels.labels)
-        gradients, diagnostic_metrics = self.device.train_batch(activations, labels)
+        gradients, loss, diagnostic_metrics = self.device.train_batch(
+            activations, labels
+        )
         proto_gradients = Gradients(gradients=tensor_to_proto(gradients))
         return connection_pb2.TrainBatchResponse(
             gradients=proto_gradients,
+            loss=loss,
             diagnostic_metrics=metrics_to_proto(diagnostic_metrics),
         )
 
@@ -528,10 +538,13 @@ class RPCDeviceServicer(DeviceServicer):
         print(f"Starting parallel split learning")
         clients = self.device.__get_device_ids__()
         round_no = request.round_no
+        adaptive_learning_threshold = request.adaptive_learning_threshold
         cw, sw, model_metrics, diagnostic_metrics, optimizer_state = (
             self.device.train_parallel_split_learning(
-                clients=clients, round_no=round_no
+                clients=clients,
+                round_no=round_no,
+                adaptive_learning_threshold=adaptive_learning_threshold,
             )
         )
         response = connection_pb2.TrainGlobalParallelSplitLearningResponse(
@@ -604,14 +617,18 @@ class DeviceRequestDispatcher:
         server_device_id: str,
         epochs: int,
         round_no: int,
+        adaptive_learning_threshold: Optional[float] = None,
         optimizer_state: dict[str, Any] = None,
     ):
+        print(f"><><><> {adaptive_learning_threshold}")
+
         try:
             response: TrainGlobalParallelSplitLearningResponse = self._get_connection(
                 server_device_id
             ).TrainGlobalParallelSplitLearning(
                 connection_pb2.TrainGlobalParallelSplitLearningRequest(
                     round_no=round_no,
+                    adaptive_learning_threshold=adaptive_learning_threshold,
                     optimizer_state=state_dict_to_proto(optimizer_state),
                 )
             )
@@ -787,10 +804,12 @@ class DeviceRequestDispatcher:
             response: TrainBatchResponse = self._get_connection(device_id).TrainBatch(
                 request
             )
-            return proto_to_tensor(
-                response.gradients.gradients
-            ), self._add_byte_size_to_diagnostic_metrics(
-                response, self.device_id, request
+            return (
+                proto_to_tensor(response.gradients.gradients),
+                response.loss,
+                self._add_byte_size_to_diagnostic_metrics(
+                    response, self.device_id, request
+                ),
             )
         except grpc.RpcError:
             self._handle_rpc_error(device_id)
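Note: `RPCDeviceServicer.TrainGlobalParallelSplitLearning` reads `request.adaptive_learning_threshold` without a presence check, so a request that never sets the optional field delivers `0.0`, which the falsy check in `DeviceServer.train_parallel_split_learning` (below) treats as "mechanism disabled". A quick sketch against the regenerated stubs, assuming the protos from this patch have been recompiled:

    from edml.generated import connection_pb2

    request = connection_pb2.TrainGlobalParallelSplitLearningRequest(round_no=3)
    print(request.adaptive_learning_threshold)              # 0.0, the proto3 default
    print(request.HasField("adaptive_learning_threshold"))  # False, field was never set

    request.adaptive_learning_threshold = 1.65
    print(request.HasField("adaptive_learning_threshold"))  # True
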
diff --git a/edml/core/server.py b/edml/core/server.py
index 02033d439dc6f24d25dabefcf705d7a62c55c248..7654a540f1854f3110d5897b34eecd5238476085 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -6,6 +6,7 @@ from typing import List, Optional, Tuple, Any, TYPE_CHECKING
 
 import torch
 from omegaconf import DictConfig
+from colorama import init, Fore
 from torch import nn
 from torch.autograd import Variable
 
@@ -135,7 +136,7 @@ class DeviceServer:
         )
 
     @simulate_latency_decorator(latency_factor_attr="latency_factor")
-    def train_batch(self, smashed_data, labels) -> Variable:
+    def train_batch(self, smashed_data, labels) -> Tuple[Variable, float]:
         """Train the model on the given batch of data and labels.
         Returns the gradients of the model's parameters."""
         smashed_data, labels = smashed_data.to(self._device), labels.to(self._device)
@@ -158,7 +159,7 @@ class DeviceServer:
         self.node_device.log({"loss": loss_train.item()})
         self._metrics.metrics_on_batch(output_train.cpu(), labels.cpu().int())
 
-        return smashed_data.grad
+        return smashed_data.grad, loss_train.item()
 
     def _set_model_flops(self, smashed_data):
         """Helper to determine the model flops when smashed data are available for the first time."""
@@ -213,7 +214,11 @@ class DeviceServer:
 
     @simulate_latency_decorator(latency_factor_attr="latency_factor")
     def train_parallel_split_learning(
-        self, clients: List[str], round_no: int, optimizer_state: dict[str, Any] = None
+        self,
+        clients: List[str],
+        round_no: int,
+        adaptive_learning_threshold: Optional[float] = None,
+        optimizer_state: dict[str, Any] = None
    ):
        def client_training_job(client_id: str, batch_index: int) -> SLTrainBatchResult:
            return self.node_device.train_batch_on_client_only_on(
@@ -260,9 +265,25 @@ class DeviceServer:
             # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need
             # to split the gradients back into batch-sized tensors to average them before sending them to the client.
-            server_gradients, server_metrics = self.node_device.train_batch(
-                server_batch, server_labels
+            server_gradients, server_loss, server_metrics = (
+                self.node_device.train_batch(server_batch, server_labels)
             )  # DiagnosticMetricResultContainer
+
+            # If the adaptive learning threshold is set, gradients are only propagated back to the
+            # clients when the current server loss is above the threshold; otherwise skip to the next batch.
+            print(
+                f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
+            )
+            if (
+                adaptive_learning_threshold
+                and server_loss < adaptive_learning_threshold
+            ):
+                print(
+                    f"\n{Fore.RED}ADAPTIVE THRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
+                )
+                self.node_device.log({"adaptive_learning_threshold_applied": True})
+                continue
+
             num_client_gradients = len(batches)
             print(
                 f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
             )
@@ -306,6 +327,12 @@ class DeviceServer:
         model_metrics.add_results(val_metrics)
 
         optimizer_state = self._optimizer.state_dict()
+        if self._lr_scheduler is not None:
+            if round_no != -1:
+                self._lr_scheduler.step(round_no + 1)  # epoch=1
+            else:
+                self._lr_scheduler.step()
+
         return (
             self.node_device.client.get_weights(),
             self.get_weights(),
@@ -316,8 +343,10 @@ class DeviceServer:
 
 
 def _concat_smashed_data(data: List[Any]) -> Any:
+    """Creates a single batch tensor from a list of tensors."""
     return torch.cat(data, dim=0)
 
 
 def _empty_batches(batches):
+    """Checks if all the list entries are `None`."""
    return batches.count(None) == len(batches)
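Note: the core of the adaptive threshold mechanism is the early `continue` above. Boiled down to a self-contained predicate (the helper name is hypothetical, the logic mirrors the diff):

    from typing import Optional

    def should_skip_client_backprop(server_loss: float, threshold: Optional[float]) -> bool:
        # Condensed restatement of the new check: with a configured (non-zero) threshold,
        # batches whose server loss is already below it do not trigger client backpropagation.
        return bool(threshold) and server_loss < threshold

    assert should_skip_client_backprop(0.9, 1.65) is True
    assert should_skip_client_backprop(2.3, 1.65) is False
    assert should_skip_client_backprop(0.9, None) is False
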
diff --git a/edml/generated/connection_pb2.pyi b/edml/generated/connection_pb2.pyi
index 9cd30493daff87b1f94286dd5f117dfd863ae351..0baa1e9eccfea89fb971a4cebdbb7ed74f6be8a8 100644
--- a/edml/generated/connection_pb2.pyi
+++ b/edml/generated/connection_pb2.pyi
@@ -122,12 +122,14 @@ class TrainBatchRequest(_message.Message):
     def __init__(self, smashed_data: _Optional[_Union[_datastructures_pb2.Activations, _Mapping]] = ..., labels: _Optional[_Union[_datastructures_pb2.Labels, _Mapping]] = ...) -> None: ...
 
 class TrainBatchResponse(_message.Message):
-    __slots__ = ["gradients", "diagnostic_metrics"]
+    __slots__ = ["gradients", "diagnostic_metrics", "loss"]
     GRADIENTS_FIELD_NUMBER: _ClassVar[int]
     DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int]
+    LOSS_FIELD_NUMBER: _ClassVar[int]
     gradients: _datastructures_pb2.Gradients
     diagnostic_metrics: _datastructures_pb2.Metrics
+    loss: float
-    def __init__(self, gradients: _Optional[_Union[_datastructures_pb2.Gradients, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ...
+    def __init__(self, gradients: _Optional[_Union[_datastructures_pb2.Gradients, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., loss: _Optional[float] = ...) -> None: ...
 
 class EvalGlobalRequest(_message.Message):
     __slots__ = ["validation", "federated"]
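Note: the regenerated stub mirrors the new optional `loss` field on `TrainBatchResponse`, so the raw batch loss travels back alongside the gradients. For example (again assuming recompiled protos):

    from edml.generated import connection_pb2

    response = connection_pb2.TrainBatchResponse(loss=0.7315)
    print(round(response.loss, 4))  # 0.7315
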
""" OmegaConf.register_new_resolver("len", lambda x: len(x), replace=True) + OmegaConf.register_new_resolver( + "group_name", lambda group_by: _group_resolver(cfg, group_by), replace=True + ) OmegaConf.resolve(cfg) # In case someone specified an integer instead of a proper device_id (str), we look up the proper device by indexing @@ -85,6 +141,49 @@ def preprocess_config(cfg: DictConfig): cfg.own_device_id = get_device_id_by_index(cfg, cfg.own_device_id) +def __drop_irrelevant_keys__(cfg: DictConfig) -> DictConfig: + """ + Removes keys from config not needed to instantiate the specified _target_ class. + Assumes that cfg has a key _target_. Hydra keys _recursive_ and _partial_ are not removed. + + Args: + cfg: The controller configuration. + + Returns: + A DictConfig without unnecessary keys. + """ + controller_class = get_class(cfg._target_) + controller_signature = signature(controller_class.__init__) + controller_args = controller_signature.parameters.keys() + + # These are special hydra keywords that we do not want to filter out. + special_keys = ["_target_", "_recursive_", "_partial_"] + cfg = {k: v for k, v in cfg.items() if k in controller_args or k in special_keys} + return cfg + + +def drop_irrelevant_keys_recursively(cfg: DictConfig) -> DictConfig: + """ + Removes parameters that are not necessary to instantiate the specified classes. + This is done for the controller class as well as for scheduler and adaptive threshold if present. + This is needed because hydra's instantiation mechanism expects that all given parameters are actually needed. + + Args: + cfg: The controller configuration. + + Returns: + A DictConfig that contains only the parameters actually needed to instantiate the specified classes. + """ + cfg.controller = __drop_irrelevant_keys__(cfg.controller) + if cfg.controller.get("scheduler", False): + cfg.controller.scheduler = __drop_irrelevant_keys__(cfg.controller.scheduler) + if cfg.controller.get("adaptive_threshold_fn", False): + cfg.controller.adaptive_threshold_fn = __drop_irrelevant_keys__( + cfg.controller.adaptive_threshold_fn + ) + return cfg + + def instantiate_controller(cfg: DictConfig) -> BaseController: """ Instantiates a controller based on the configuration. This method filters out extra parameters defined through hydra @@ -97,27 +196,17 @@ def instantiate_controller(cfg: DictConfig) -> BaseController: Returns: An instance of `BaseController`. """ - + original_cfg = deepcopy(cfg) # Filter out any arguments not present in the controller constructor. This is a hack required to make multirun work. # We want to be able to use different scheduling strategies combined with different controllers. But hydra's # `instantiate` method is strict and fails if it receives any extra arguments. - controller_class = get_class(cfg.controller._target_) - controller_signature = signature(controller_class.__init__) - controller_args = controller_signature.parameters.keys() - - # These are special hydra keywords that we do not want to filter out. - special_keys = ["_target_", "_recursive_", "_partial_"] - cfg.controller = { - k: v - for k, v in cfg.controller.items() - if k in controller_args or k in special_keys - } + cfg = drop_irrelevant_keys_recursively(cfg) # Update the device ID and set it to controller. cfg.own_device_id = "controller" # Instantiate the controller. 
diff --git a/edml/helpers/logging.py b/edml/helpers/logging.py
index 44b12e59fae3ac15ca422d200f1a5c06ae2924ab..e6af989d8038f06ef197bfb0c860a81200c420cb 100644
--- a/edml/helpers/logging.py
+++ b/edml/helpers/logging.py
@@ -66,7 +66,7 @@ class WandbLogger(SimpleLogger):
             entity=self.cfg.wandb.entity,
             project=self.cfg.experiment.project,  # project = set of experiments
             job_type=self.cfg.experiment.job,  # train or test
-            group=self.cfg.experiment.name,  # group runs by experiment name
+            group=self.cfg.group,
             name=self.cfg.own_device_id,  # name runs by device id
             config=dict(self.cfg),
         )
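Note: grouping now keys off `cfg.group` instead of the experiment name, so every device run of one controller/scheduler/threshold combination lands in the same W&B group while the run name still identifies the device. Roughly what the logger ends up calling (values are illustrative):

    import wandb

    # Hypothetical values; mirrors the arguments WandbLogger now passes.
    wandb.init(
        project="inda-ml-comparisons",      # cfg.experiment.project
        job_type="train",                   # cfg.experiment.job
        group="psl_sequential_static_at",   # cfg.group, composed by the group_name resolver
        name="d0",                          # cfg.own_device_id
    )
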
diff --git a/edml/proto/connection.proto b/edml/proto/connection.proto
index 51aa5fa7c2e3e5845894fa7395c8dac02a206474..6ed477dc344b6f5d4b4b64629daabeb9c8caecd4 100644
--- a/edml/proto/connection.proto
+++ b/edml/proto/connection.proto
@@ -44,7 +44,8 @@ message SingleBatchTrainingResponse {
 
 message TrainGlobalParallelSplitLearningRequest {
   optional int32 round_no = 1;
-  optional StateDict optimizer_state = 2;
+  optional double adaptive_learning_threshold = 2;
+  optional StateDict optimizer_state = 3;
 }
 
 message TrainGlobalParallelSplitLearningResponse {
@@ -97,6 +98,7 @@ message TrainBatchRequest {
 message TrainBatchResponse {
   Gradients gradients = 1;
   optional Metrics diagnostic_metrics = 2;
+  optional double loss = 3;
 }
 
 message EvalGlobalRequest {
diff --git a/edml/tests/core/device_test.py b/edml/tests/core/device_test.py
index 54b0518fb0cac4deaf31fb17875df5f59b5f7a26..6080e7f406b77bc68ec503cf2a655d657fdd4677 100644
--- a/edml/tests/core/device_test.py
+++ b/edml/tests/core/device_test.py
@@ -188,6 +188,7 @@ class RPCDeviceServicerTest(unittest.TestCase):
     def test_train_batch(self):
         self.mock_device.train_batch.return_value = (
             Tensor([42]),
+            42.0,
             self.diagnostic_metrics,
         )
         request = connection_pb2.TrainBatchRequest(
@@ -197,9 +198,9 @@
 
         response, metadata, code, details = self.make_call("TrainBatch", request)
 
-        self.assertEqual(proto_to_tensor(response.gradients.gradients), Tensor([42]))
         self.assertEqual(code, StatusCode.OK)
         self.mock_device.train_batch.assert_called_once_with(Tensor([1.0]), Tensor([1]))
+        self.assertEqual(proto_to_tensor(response.gradients.gradients), Tensor([42]))
         self.assertEqual(
             proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
         )
@@ -585,13 +586,15 @@ class RequestDispatcherTest(unittest.TestCase):
         self.mock_stub.TrainBatch.return_value = connection_pb2.TrainBatchResponse(
             gradients=gradients_to_proto(self.gradients),
             diagnostic_metrics=metrics_to_proto(self.diagnostic_metrics),
+            loss=42.0,
         )
 
-        gradients, diagnostic_metrics = self.dispatcher.train_batch_on(
+        gradients, loss, diagnostic_metrics = self.dispatcher.train_batch_on(
            "1", self.activations, self.labels
        )
 
        self.assertEqual(gradients, self.gradients)
+        self.assertEqual(loss, 42.0)
        self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
        self.mock_stub.TrainBatch.assert_called_once_with(
            connection_pb2.TrainBatchRequest(
diff --git a/edml/tests/helpers/config_helpers_test.py b/edml/tests/helpers/config_helpers_test.py
index 41f18361a36eb62464202fd2b91c5bc6bb2043a5..9279aeb714b84efdb27fd9895fca65393fa0b650 100644
--- a/edml/tests/helpers/config_helpers_test.py
+++ b/edml/tests/helpers/config_helpers_test.py
@@ -1,13 +1,17 @@
+import os
 import unittest
 from unittest.mock import patch
 
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 
+from edml.controllers.parallel_split_controller import ParallelSplitController
 from edml.helpers.config_helpers import (
     get_device_id_by_index,
     get_device_address_by_id,
     preprocess_config,
     get_device_index_by_id,
+    instantiate_controller,
+    __drop_irrelevant_keys__,
     get_torch_device_id,
 )
 
@@ -25,6 +29,18 @@
                 ]
             },
             "num_devices": "${len:${topology.devices}}",
+            "controller": {
+                "name": "swarm",
+                "scheduler": {"name": "max_battery"},
+                "adaptive_threshold_fn": {"name": "static"},
+            },
+            "group_by": {
+                "controller": [
+                    "name",
+                    {"scheduler": "name", "adaptive_threshold_fn": "name"},
+                ],
+            },
+            "group": "${group_name:${group_by}}",
         }
     )
 
@@ -57,6 +73,45 @@
 
         with patch("torch.cuda.is_available", return_value=False):
             self.assertEqual(get_torch_device_id(self.cfg), "cpu")
 
+    def test_preprocess_config_group_name(self):
+        preprocess_config(self.cfg)
+        self.assertEqual(self.cfg.group, "swarm_max_battery_static")
+
+
+class ControllerInstantiationTest(unittest.TestCase):
+    def setUp(self) -> None:
+        self.cfg = OmegaConf.create({"some_key": "some_value"})
+        self.cfg.controller = OmegaConf.load(
+            os.path.join(
+                os.path.dirname(__file__),
+                "../../../edml/config/controller/parallel_swarm.yaml",
+            )
+        )
+        self.cfg.controller.scheduler = OmegaConf.load(
+            os.path.join(
+                os.path.dirname(__file__),
+                "../../../edml/config/controller/scheduler/max_battery.yaml",
+            )
+        )
+
+    def test_parallel_split_controller_with_max_battery_instantiation(self):
+        with patch(
+            "edml.controllers.base_controller.BaseController.__init__"
+        ):  # Avoid initializing the base_controller for brevity
+            with patch(
+                "edml.controllers.base_controller.BaseController._get_device_ids"
+            ) as _get_device_ids:  # needed by scheduler
+                _get_device_ids.return_value = ["d0"]
+                controller = instantiate_controller(self.cfg)
+                self.assertIsInstance(controller, ParallelSplitController)
+
+    def test_drop_irrelevant_keys(self):
+        self.cfg.controller["irrelevant_key"] = "some value"
+        reduced_cfg = __drop_irrelevant_keys__(self.cfg.controller)
+        self.assertListEqual(
+            list(reduced_cfg.keys()), ["_target_", "_partial_", "scheduler"]
+        )
+
 
 class GetTorchDeviceIdTest(unittest.TestCase):