diff --git a/config/default.yaml b/config/default.yaml
index 6ff71a77b1310a81a84df1bd7bfd040813ee1f9d..dac99d0ff6f6b963c60d9cc8e16b0c82df22d217 100644
--- a/config/default.yaml
+++ b/config/default.yaml
@@ -14,6 +14,9 @@ defaults:
   - wandb: default
   - _self_
 
+# If true, controllers will run devices sequentially and their runtime is corrected in post-processing
+# to account for the parallelism. If false, they will run devices in parallel.
+simulate_parallelism: False
 own_device_id: "d0"
 num_devices: ${len:${topology.devices}}
 
diff --git a/edml/config/battery/resnet110_cifar100_cost.yaml b/edml/config/battery/resnet110_cifar100_cost.yaml
index d1bcdb98d5d0d98d49eedff71d3cf12b9e1bf9d9..424c8d81b8db9ee8898edb91911d5d9b1f6c26cb 100644
--- a/edml/config/battery/resnet110_cifar100_cost.yaml
+++ b/edml/config/battery/resnet110_cifar100_cost.yaml
@@ -1,4 +1,4 @@
-deduction_per_second: 0.005
-deduction_per_mflop: 0.00000005
-deduction_per_mbyte_received: 0.0002
-deduction_per_mbyte_sent: 0.0002
+deduction_per_second: 0.007
+deduction_per_mflop: 0.00000001
+deduction_per_mbyte_received: 0.00005
+deduction_per_mbyte_sent: 0.00005
diff --git a/edml/config/topology/resnet110_cifar100_batteries.yaml b/edml/config/topology/resnet110_cifar100_batteries.yaml
index 3a81e0087447bd83066bf31840321d1cc2b4413d..61838e0fb17a5e720154b4acb0b10496decf48c9 100644
--- a/edml/config/topology/resnet110_cifar100_batteries.yaml
+++ b/edml/config/topology/resnet110_cifar100_batteries.yaml
@@ -2,26 +2,26 @@ devices: [
   {
     device_id: "d0",
     address: "localhost:50051",
-    battery_capacity: 750,
+    battery_capacity: 300,
   },
   {
     device_id: "d1",
     address: "localhost:50052",
-    battery_capacity: 750
+    battery_capacity: 300
   },
   {
     device_id: "d2",
     address: "localhost:50053",
-    battery_capacity: 600
+    battery_capacity: 300
   },
   {
     device_id: "d3",
     address: "localhost:50054",
-    battery_capacity: 600
+    battery_capacity: 300
   },
   {
     device_id: "d4",
     address: "localhost:50055",
-    battery_capacity: 600
+    battery_capacity: 300
   }
 ]
diff --git a/edml/controllers/fed_controller.py b/edml/controllers/fed_controller.py
index cdf6b6e86d2df4d59f5ea025288fb5e28d3d4efc..250e02c0307e002c2d84c50296a752ec651a8d74 100644
--- a/edml/controllers/fed_controller.py
+++ b/edml/controllers/fed_controller.py
@@ -1,10 +1,13 @@
 import concurrent.futures
+import contextlib
+import functools
 import threading
 from typing import Dict, List
 
 from overrides import override
 
 from edml.controllers.base_controller import BaseController
+from edml.helpers.decorators import Timer
 from edml.helpers.metrics import ModelMetricResultContainer
 
 
@@ -35,6 +38,14 @@ def fed_average(model_weights: list[Dict], weighting_scheme: List[float] = None)
     return None
 
 
+def simulate_parallelism(f):
+    """Pass-through placeholder; FedController branches on ``cfg.simulate_parallelism`` directly."""
+
+    @functools.wraps(f)
+    def inner(self):
+        return f(self)
+
+    return inner
+
+
 class FedController(BaseController):
     """Controller for federated learning."""
 
@@ -43,40 +54,66 @@ class FedController(BaseController):
     def _fed_train_round(self, round_no: int = -1):
         """Returns new client and server weights."""
-        client_weights_lock = threading.Lock()
-        server_weights_lock = threading.Lock()
-        samples_count_lock = threading.Lock()
-        metrics_lock = threading.Lock()
         client_weights = []
         server_weights = []
         samples_count = []
         metrics_container = ModelMetricResultContainer()
-        with concurrent.futures.ThreadPoolExecutor(
-            max_workers=max(len(self.active_devices), 1)
-        ) as executor:  # avoid exception when setting 0 workers
-            futures = [
-                executor.submit(
-                    self.request_dispatcher.federated_train_on, device_id, round_no
-                )
-                for device_id in 
self.active_devices - ] - for future in concurrent.futures.as_completed(futures): - response = future.result() - if response is not False: - new_client_weights, new_server_weights, num_samples, metrics, _ = ( - response # skip diagnostic metrics + + if self.cfg.simulate_parallelism: + parallel_times = [] + with Timer() as elapsed_time: + for device_id in self.active_devices: + with Timer() as individual_time: + response = self.request_dispatcher.federated_train_on( + device_id, round_no + ) + if response is not False: + ( + new_client_weights, + new_server_weights, + num_samples, + metrics, + _, + ) = response # skip diagnostic metrics + client_weights.append(new_client_weights) + server_weights.append(new_server_weights) + samples_count.append(num_samples) + metrics_container.merge(metrics) + parallel_times.append(individual_time.execution_time) + self.logger.log( + { + "parallel_fed_time": { + "elapsed_time": elapsed_time.execution_time, + "parallel_time": max(parallel_times), + } + } + ) + else: + with concurrent.futures.ThreadPoolExecutor( + max_workers=max(len(self.active_devices), 1) + ) as executor: # avoid exception when setting 0 workers + futures = [ + executor.submit( + self.request_dispatcher.federated_train_on, device_id, round_no ) - with client_weights_lock: + for device_id in self.active_devices + ] + for future in concurrent.futures.as_completed(futures): + response = future.result() + if response is not False: + ( + new_client_weights, + new_server_weights, + num_samples, + metrics, + _, + ) = response # skip diagnostic metrics client_weights.append(new_client_weights) - with server_weights_lock: server_weights.append(new_server_weights) - with samples_count_lock: samples_count.append(num_samples) - with metrics_lock: metrics_container.merge(metrics) print(f"samples count {samples_count}") - return ( fed_average(model_weights=client_weights, weighting_scheme=samples_count), fed_average(model_weights=server_weights, weighting_scheme=samples_count), diff --git a/edml/core/device.py b/edml/core/device.py index 94bb18ac35241fc86eec778b8ef2d9145a8ed51f..6770e4271e8dc030af1abc7f8ac010a57242f5bb 100644 --- a/edml/core/device.py +++ b/edml/core/device.py @@ -225,6 +225,7 @@ class Device(ABC): class NetworkDevice(Device): @update_battery + @log_execution_time("logger", "finalize_gradients") def set_gradient_and_finalize_training_on_client_only_on( self, client_id: str, gradients: Any ): @@ -404,6 +405,7 @@ class NetworkDevice(Device): @add_time_to_diagnostic_metrics("evaluate_batch") @update_battery + @log_execution_time("logger", "evaluate_batch_time") def evaluate_batch(self, smashed_data, labels): result = self.server.evaluate_batch(smashed_data, labels) self._log_current_battery_capacity() diff --git a/edml/core/server.py b/edml/core/server.py index 4b9ef80fd876d618075a98a14d61e03d83369b66..4279f40301205c361a8648fe907d9084bd159ac5 100644 --- a/edml/core/server.py +++ b/edml/core/server.py @@ -10,7 +10,7 @@ from torch import nn from torch.autograd import Variable from edml.helpers.config_helpers import get_torch_device_id -from edml.helpers.decorators import check_device_set, simulate_latency_decorator +from edml.helpers.decorators import check_device_set, simulate_latency_decorator, Timer from edml.helpers.executor import create_executor_with_threads from edml.helpers.flops import estimate_model_flops from edml.helpers.load_optimizer import get_optimizer_and_scheduler @@ -260,101 +260,245 @@ class DeviceServer: # We iterate over each batch, initializing all client training at 
once and processing the results afterward. num_batches = self.node_device.client.get_approximated_num_batches() print(f":: BATCHES :: {num_batches}") - for batch_index in range(num_batches): - client_forward_pass_responses = [] - futures = [ - executor.submit(client_training_job, client_id, batch_index) - for client_id in clients - ] - for future in concurrent.futures.as_completed(futures): - (client_id, result) = future.result() - if result is not None and result is not False: - client_forward_pass_responses.append((client_id, result)) - - # We want to split up the responses into a list of client IDs and batches again. - client_ids = [b[0] for b in client_forward_pass_responses] - client_batches = [b[1] for b in client_forward_pass_responses] - - server_batch = _concat_smashed_data( - [b[0].to(self._device) for b in client_batches] - ) - server_labels = _concat_smashed_data( - [b[1].to(self._device) for b in client_batches] - ) - # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need - # to split the gradients back into batch-sized tensors to average them before sending them to the client. - server_gradients, server_loss, server_metrics = ( - self.node_device.train_batch(server_batch, server_labels) - ) # DiagnosticMetricResultContainer - # We check if the server should activate the adaptive learning threshold. And if true, we make sure to only - # do the client propagation once the current loss value is larger than the threshold. - print( - f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}" - ) - if ( - adaptive_learning_threshold - and server_loss < adaptive_learning_threshold - ): + + if self._cfg.simulate_parallelism: + for batch_index in range(num_batches): + client_forward_pass_responses = [] + parallel_times = [] + with Timer() as elapsed_time: + for client_id in clients: + with Timer() as individual_time: + (client_id, result) = client_training_job( + client_id, batch_index + ) + if result is not None and result is not False: + client_forward_pass_responses.append( + (client_id, result) + ) + parallel_times.append(individual_time.execution_time) + self.node_device.log( + { + "parallel_client_train_time": { + "elapsed_time": elapsed_time.execution_time, + "parallel_time": max(parallel_times), + } + } + ) + # We want to split up the responses into a list of client IDs and batches again. + client_ids = [b[0] for b in client_forward_pass_responses] + client_batches = [b[1] for b in client_forward_pass_responses] + + server_batch = _concat_smashed_data( + [b[0].to(self._device) for b in client_batches] + ) + server_labels = _concat_smashed_data( + [b[1].to(self._device) for b in client_batches] + ) + # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need + # to split the gradients back into batch-sized tensors to average them before sending them to the client. + server_gradients, server_loss, server_metrics = ( + self.node_device.train_batch(server_batch, server_labels) + ) # DiagnosticMetricResultContainer + # We check if the server should activate the adaptive learning threshold. And if true, we make sure to only + # do the client propagation once the current loss value is larger than the threshold. 
print( - f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}" + f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}" + ) + if ( + adaptive_learning_threshold + and server_loss < adaptive_learning_threshold + ): + print( + f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}" + ) + self.node_device.log( + { + "adaptive_learning_threshold_applied": server_gradients.size( + 0 + ) + } + ) + continue + + num_client_gradients = len(client_forward_pass_responses) + print( + f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}" + ) + # clone single client gradients so that client_gradients is not a list of views of server_gradients + # if we just use torch.chunk, each client will receive the whole server_gradients + client_gradients = [ + t.clone().detach() + for t in torch.chunk(server_gradients, num_client_gradients) + ] + client_backpropagation_gradients = [] + parallel_times = [] + with Timer() as elapsed_time: + for idx, client_id in enumerate(client_ids): + with Timer() as individual_time: + _, grads = client_backpropagation_job( + client_id, client_gradients[idx] + ) + if grads is not None and grads is not False: + client_backpropagation_gradients.append(grads) + parallel_times.append(individual_time.execution_time) + self.node_device.log( + { + "parallel_client_backprop_time": { + "elapsed_time": elapsed_time.execution_time, + "parallel_time": max(parallel_times), + } + } + ) + # We want to average the client's backpropagation gradients and send them over again to finalize the + # current training step. + averaged_gradient = _calculate_gradient_mean( + client_backpropagation_gradients, self._device ) + parallel_times = [] + with Timer() as elapsed_time: + for client_id in clients: + with Timer() as individual_time: + client_set_gradient_and_finalize_training_job( + client_id, averaged_gradient + ) + parallel_times.append(individual_time.execution_time) self.node_device.log( - {"adaptive_learning_threshold_applied": server_gradients.size(0)} + { + "parallel_client_model_update_time": { + "elapsed_time": elapsed_time.execution_time, + "parallel_time": max(parallel_times), + } + } + ) + else: + for batch_index in range(num_batches): + client_forward_pass_responses = [] + futures = [ + executor.submit(client_training_job, client_id, batch_index) + for client_id in clients + ] + for future in concurrent.futures.as_completed(futures): + (client_id, result) = future.result() + if result is not None and result is not False: + client_forward_pass_responses.append((client_id, result)) + + # We want to split up the responses into a list of client IDs and batches again. + client_ids = [b[0] for b in client_forward_pass_responses] + client_batches = [b[1] for b in client_forward_pass_responses] + + server_batch = _concat_smashed_data( + [b[0].to(self._device) for b in client_batches] ) - continue + server_labels = _concat_smashed_data( + [b[1].to(self._device) for b in client_batches] + ) + # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need + # to split the gradients back into batch-sized tensors to average them before sending them to the client. + server_gradients, server_loss, server_metrics = ( + self.node_device.train_batch(server_batch, server_labels) + ) # DiagnosticMetricResultContainer + # We check if the server should activate the adaptive learning threshold. 
And if true, we make sure to only + # do the client propagation once the current loss value is larger than the threshold. + print( + f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}" + ) + if ( + adaptive_learning_threshold + and server_loss < adaptive_learning_threshold + ): + print( + f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}" + ) + self.node_device.log( + { + "adaptive_learning_threshold_applied": server_gradients.size( + 0 + ) + } + ) + continue - num_client_gradients = len(client_forward_pass_responses) - print( - f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}" - ) - # clone single client gradients so that client_gradients is not a list of views of server_gradients - # if we just use torch.chunk, each client will receive the whole server_gradients - client_gradients = [ - t.clone().detach() - for t in torch.chunk(server_gradients, num_client_gradients) - ] - futures = [ - executor.submit( - client_backpropagation_job, client_id, client_gradients[idx] + num_client_gradients = len(client_forward_pass_responses) + print( + f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}" ) - for (idx, client_id) in enumerate(client_ids) - ] - client_backpropagation_gradients = [] - for future in concurrent.futures.as_completed(futures): - _, grads = future.result() - if grads is not None and grads is not False: - client_backpropagation_gradients.append(grads) - # We want to average the client's backpropagation gradients and send them over again to finalize the - # current training step. - averaged_gradient = _calculate_gradient_mean( - client_backpropagation_gradients, self._device - ) - futures = [ - executor.submit( - client_set_gradient_and_finalize_training_job, - client_id, - averaged_gradient, + # clone single client gradients so that client_gradients is not a list of views of server_gradients + # if we just use torch.chunk, each client will receive the whole server_gradients + client_gradients = [ + t.clone().detach() + for t in torch.chunk(server_gradients, num_client_gradients) + ] + futures = [ + executor.submit( + client_backpropagation_job, client_id, client_gradients[idx] + ) + for (idx, client_id) in enumerate(client_ids) + ] + client_backpropagation_gradients = [] + for future in concurrent.futures.as_completed(futures): + _, grads = future.result() + if grads is not None and grads is not False: + client_backpropagation_gradients.append(grads) + # We want to average the client's backpropagation gradients and send them over again to finalize the + # current training step. + averaged_gradient = _calculate_gradient_mean( + client_backpropagation_gradients, self._device ) - for client_id in clients - ] - for future in concurrent.futures.as_completed(futures): - future.result() + futures = [ + executor.submit( + client_set_gradient_and_finalize_training_job, + client_id, + averaged_gradient, + ) + for client_id in clients + ] + for future in concurrent.futures.as_completed(futures): + future.result() # Now we have to determine the model metrics for each client. 
- for client_id in clients: - train_metrics = self.finalize_metrics(str(client_id), "train") - evaluation_diagnostics_metrics = self.node_device.evaluate_on( - device_id=client_id, - server_device=self.node_device.device_id, - val=True, + if self._cfg.simulate_parallelism: + parallel_times = [] + with Timer() as elapsed_time: + for client_id in clients: + with Timer() as individual_time: + train_metrics = self.finalize_metrics(str(client_id), "train") + + evaluation_diagnostics_metrics = self.node_device.evaluate_on( + device_id=client_id, + server_device=self.node_device.device_id, + val=True, + ) + # if evaluation_diagnostics_metrics: + # diagnostic_metrics.merge(evaluation_diagnostics_metrics) + val_metrics = self.finalize_metrics(str(client_id), "val") + + model_metrics.add_results(train_metrics) + model_metrics.add_results(val_metrics) + parallel_times.append(individual_time.execution_time) + self.node_device.log( + { + "parallel_client_eval_time": { + "elapsed_time": elapsed_time.execution_time, + "parallel_time": max(parallel_times), + } + } ) - # if evaluation_diagnostics_metrics: - # diagnostic_metrics.merge(evaluation_diagnostics_metrics) - val_metrics = self.finalize_metrics(str(client_id), "val") + else: + for client_id in clients: + train_metrics = self.finalize_metrics(str(client_id), "train") + + evaluation_diagnostics_metrics = self.node_device.evaluate_on( + device_id=client_id, + server_device=self.node_device.device_id, + val=True, + ) + # if evaluation_diagnostics_metrics: + # diagnostic_metrics.merge(evaluation_diagnostics_metrics) + val_metrics = self.finalize_metrics(str(client_id), "val") - model_metrics.add_results(train_metrics) - model_metrics.add_results(val_metrics) + model_metrics.add_results(train_metrics) + model_metrics.add_results(val_metrics) optimizer_state = self._optimizer.state_dict() # delete references and free GPU memory manually diff --git a/edml/helpers/decorators.py b/edml/helpers/decorators.py index 5623648311abd9e63c51281af657649ea950459e..6f4a8da8912f989163ddca9f829b428800db3749 100644 --- a/edml/helpers/decorators.py +++ b/edml/helpers/decorators.py @@ -165,6 +165,22 @@ class LatencySimulator: time.sleep(self.execution_time * self.latency_factor) +class Timer: + """ + Context Manager to measure execution time. + + Notes: + Access execution time via Timer.execution_time. + """ + + def __enter__(self): + self.start_time = time.perf_counter() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.execution_time = time.perf_counter() - self.start_time + + def add_time_to_diagnostic_metrics(method_name: str): """ A decorator factory that measures the execution time of the wrapped method. 
It then creates a diagnostic metric diff --git a/results/result_generation.ipynb b/results/result_generation.ipynb index 85b37ddb464b85b1e634dcbcde25c0a221baa9fb..5c3ed118508856d66ca7b70f0027b99a2cb8052f 100644 --- a/results/result_generation.ipynb +++ b/results/result_generation.ipynb @@ -29,30 +29,15 @@ "source": [ "df_base_dir = \"./dataframes\"\n", "projects_with_model = [\n", - " (\"greedy_vs_smart_ecg-non-iid_RESULT\", \"tcn\"),\n", - " (\"greedy_vs_smart_cifar100_resnet20_RESULT\", \"resnet20\"),\n", - " (\"greedy_vs_smart_ecg-iid_RESULT\", \"tcn\"),\n", - " (\"greedy_vs_smart_PTBXL_equal_devices_RESULT\", \"tcn\"),\n", - " (\"greedy_vs_smart_PTBXL_unequal_processors_RESULT\", \"tcn\"),\n", - " (\"greedy_vs_smart_MNIST_unequal_processors_RESULT\", \"simple_conv\"),\n", - " (\"greedy_vs_smart_MNIST_unequal_batteries_unequal_partition_RESULT\", \"simple_conv\"),\n", - " (\"greedy_vs_smart_MNIST_equal_devices_RESULT\", \"simple_conv\"),\n", - " (\"greedy_vs_smart_MNIST_unequal_batteries_RESULT\", \"simple_conv\"),\n", - " (\"fed_vs_split_MNIST_limited_batteries_RESULT\", \"simple_conv\"),\n", - " (\"fed_vs_split_MNIST_unlimited_batteries_RESULT\", \"simple_conv\"),\n", - " (\"fed_vs_split_PTBXL_limited_batteries_RESULT\", \"tcn\"),\n", - " (\"fed_vs_split_PTBXL_unlimited_batteries_RESULT\", \"tcn\"),\n", - " (\"fed_vs_split_cifar100_unlimited_batteries_RESULT\", \"resnet20\"),\n", - " (\"fed_vs_split_CIFAR100_limited_batteries_RESULT\", \"resnet20\"),\n", - " (\"fed_vs_split_50_devices_RESULT\", \"resnet110\"),\n", - " (\"greedy_vs_smart_CIFAR100_equal_devices_RESULT\", \"resnet20\"),\n", - " (\"greedy_vs_smart_CIFAR100_unequal_processors_RESULT\", \"resnet20\"),\n", + " (\"5_devices_unlimited_new\", \"resnet110\"),\n", + " (\"50_devices_unlimited_new\", \"resnet110\"),\n", + " (\"controller_comparison\", \"resnet110\")\n", "]" ], "metadata": { "collapsed": false }, - "id": "6695251b9af7ea4b" + "id": "5b81c8c9ba4b483d" }, { "cell_type": "code", @@ -61,18 +46,33 @@ "source": [ "for project_name, _ in projects_with_model:\n", " save_dataframes(project_name=project_name, strategies=[\n", - " \"swarm_seq\",\n", - " \"fed\",\n", - " \"swarm_max\",\n", - " \"swarm_rand\",\n", - " \"swarm_smart\",\n", - " \"split\"\n", + " #\"swarm_seq\",\n", + " #\"fed\",\n", + " #\"swarm_max\",\n", + " #\"swarm_rand\",\n", + " #\"swarm_smart\",\n", + " #\"split\",\n", + " #\"psl_rand_\",\n", + " #\"psl_sequential_\",\n", + " #\"psl_max_batteries_\",\n", + " #\"swarm_rand_\",\n", + " #\"swarm_sequential_\",\n", + " #\"swarm_max_batteries_\",\n", + " \"psl_sequential__\",\n", + " \"fed___\",\n", + " \"split___\",\n", + " \"swarm_sequential__\",\n", + " \"swarm_max_battery__\",\n", + " \"swarm_smart__\",\n", + " \"psl_sequential_static_at_resnet_decoderpth\",\n", + " \"psl_sequential__resnet_decoderpth\",\n", + " \"psl_sequential_static_at_\",\n", " ])" ], "metadata": { "collapsed": false }, - "id": "b07913828b33ffcc" + "id": "118f1ed9e7537718" }, { "cell_type": "markdown", @@ -82,7 +82,7 @@ "metadata": { "collapsed": false }, - "id": "d8269abd823cdcc7" + "id": "bbc47124f3c80f1c" }, { "cell_type": "code", @@ -92,28 +92,45 @@ "# Required for total number of FLOPs computation\n", "model_flops = {\n", " \"resnet20\": 41498880,\n", + " \"resnet20_ae\": 45758720,\n", " \"resnet110\": 258136320,\n", + " \"resnet110_ae\": 262396160,\n", " \"tcn\": 27240000,\n", " \"simple_conv\": 16621560\n", - "}" + "}\n", + "\n", + "client_model_flops = {\n", + " \"resnet20\": 15171584,\n", + " \"resnet20_ae\": 19005440,\n", + " 
\"resnet110\": 88408064,\n", + " \"resnet110_ae\": 92241920,\n", + "}\n", + "\n", + "server_model_flops = {\n", + " \"resnet20\": 26327296,\n", + " \"resnet20_ae\": 26753280,\n", + " \"resnet110\": 169728256,\n", + " \"resnet110_ae\": 170154240,\n", + "}\n", + "experiment_batch_size = 64" ], "metadata": { "collapsed": false }, - "id": "828bcb4737b21c6d" + "id": "6e4e3bd0198fe7e7" }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [ - "plots_base_path=\"./plots\"\n", - "metrics_base_path=\"./metrics\"" + "plots_base_path = \"./plots\"\n", + "metrics_base_path = \"./metrics\"" ], "metadata": { "collapsed": false }, - "id": "b13f9e0e98b7ac5b" + "id": "ede70693af668f55" }, { "cell_type": "code", @@ -126,14 +143,16 @@ " print(\" loading data from disk\")\n", " dataframes = load_dataframes(proj_name, df_base_dir)\n", " print(\" generating metrics\")\n", - " generate_metric_files(dataframes, proj_name, model_flops[model_name])\n", + " generate_metric_files(dataframes, proj_name, model_flops[model_name], client_model_flops[model_name],\n", + " # TODO distinguish AE\n", + " base_path=metrics_base_path, batch_size=experiment_batch_size)\n", " print(\" generating plots\")\n", " generate_plots(dataframes, proj_name)" ], "metadata": { "collapsed": false }, - "id": "c4b1ed2d809c54e2" + "id": "1c72379feadc98cb" }, { "cell_type": "code", @@ -145,7 +164,17 @@ "metadata": { "collapsed": false }, - "id": "378bf3365dd9fde2" + "id": "7927831aecd5a02" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + }, + "id": "fc698fc664867532" } ], "metadata": { diff --git a/results/result_generation.py b/results/result_generation.py index b95183ebae7218c3c6ae7453b6bc103d69ba163e..caa42642976785d82f0313733b324c9cd63b1166 100644 --- a/results/result_generation.py +++ b/results/result_generation.py @@ -5,6 +5,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd import wandb +import math # For plotting STRATEGY_MAPPING = { @@ -14,6 +15,16 @@ STRATEGY_MAPPING = { "swarm_rand": "Swarm SL (Rand)", "swarm_max": "Swarm SL (Greedy)", "fed": "Vanilla FL", + "psl_sequential__": "PSSL (Seq)", + "fed___": "Vanilla FL", + "swarm_sequential__": "Swarm SL (Seq)", + "swarm_smart__": "Swarm SL (Smart)", + "swarm_rand__": "Swarm SL (Rand)", + "swarm_max_battery__": "Swarm SL (Greedy)", + "split___": "Vanilla SL", + "psl_sequential_static_at_resnet_decoderpth": "PSSL (Seq) AE Static", + "psl_sequential__resnet_decoderpth": "PSSL (Seq) AE", + "psl_sequential_static_at_": "PSSL (Seq) Static", } LABEL_MAPPING = { @@ -29,6 +40,75 @@ LABEL_MAPPING = { } +def scale_parallel_time(run_df, scale_factor=1.0): + """ + Scales the time by the provided scale_factor. + Args: + run_df: The dataframe of the project + scale_factor: (float) the factor to shorten time e.g. 2 halves the total time + Returns: + run_df: the dataframe with scaled timestamps + """ + if scale_factor == 1: + return run_df + if "_timestamp" in run_df.columns: + start_time = run_df["_timestamp"].min() + for col in run_df.columns: + if col.endswith(".start") or col.endswith(".end") or col == "_timestamp": + run_df[col] = (run_df[col] - start_time) / scale_factor + start_time + if col.endswith(".duration") or col == "_runtime": + run_df[col] = run_df[col] / scale_factor + return run_df + + +def get_scale_factors(group): + """ + Determines the scale factor to account for parallelism. + For each set of runs (i.e. 
one run of controller, d0, d1, ...), the time overhead introduced by running a + parallel operation sequentially is determined and the resulting factor to scale the runs down as well. + If no parallel operations were simulated, no time is deduced and the scale factor will equal 1. + Args: + group: the group of runs + Returns: + A list of factors to scale down each set of runs. + """ + columns_to_count = [ + "parallel_client_train_time", + "parallel_client_backprop_time", + "parallel_client_model_update_time", + "parallel_fed_time", + ] + scale_factors = [] + num_runs = len(next(iter(group.values()))) + max_runtime = [0] * num_runs + elapsed_time = [0] * num_runs + parallel_time = [0] * num_runs + for name, runs in group.items(): + for i, run_df in enumerate(runs): + if "_runtime" in run_df.columns: # assure that run_df is not empty + if run_df["_runtime"].max() > max_runtime[i]: + max_runtime[i] = run_df["_runtime"].max() + for col_name in columns_to_count: + if f"{col_name}.parallel_time" in run_df.columns: + elapsed_time[i] += run_df[f"{col_name}.elapsed_time"].sum() + parallel_time[i] += run_df[f"{col_name}.parallel_time"].sum() + if "parallel_client_eval_time.parallel_time" in run_df.columns: + if "evaluate_batch_time.duration" in run_df.columns: + elapsed_time[i] += run_df[ + "parallel_client_eval_time.elapsed_time" + ].sum() + parallel_time[i] += ( + run_df["parallel_client_eval_time.parallel_time"].sum() + - run_df["evaluate_batch_time.duration"].sum() + ) # evaluate batch time measured at server -> sequential either way + for i, max_rt in enumerate(max_runtime): + if max_rt > 0: + scale_factors.append(max_rt / (max_rt - elapsed_time[i] + parallel_time[i])) + else: + scale_factors.append(1.0) + return scale_factors + + def save_dataframes(project_name, strategies, base_dir="./dataframes"): """ Fetches the dataframes from wandb and saves them to the base_dir. @@ -81,13 +161,22 @@ def save_dataframes(project_name, strategies, base_dir="./dataframes"): history_groups = {} for (strategy, job), group in run_groups.items(): print(f" {strategy} {job}") - history = defaultdict(list) + unscaled_runs = defaultdict(list) for name, runs in group.items(): print(f" {name}") for run in runs: history_df = pd.DataFrame(run.scan_history()) - history[name].append(history_df) - history_groups[(strategy, job)] = history + unscaled_runs[name].append(history_df) + # rescale if parallelism was only simulated + if job == "train" and len(unscaled_runs) > 0: + scale_factors = get_scale_factors(unscaled_runs) + scaled_runs = defaultdict(list) + for name, runs in unscaled_runs.items(): + for i, run in enumerate(runs): + scaled_runs[name].append(scale_parallel_time(run, scale_factors[i])) + history_groups[(strategy, job)] = scaled_runs + else: + history_groups[(strategy, job)] = unscaled_runs # save dataframe print("saving data") for (strategy, job), group in history_groups.items(): @@ -105,7 +194,7 @@ def save_dataframes(project_name, strategies, base_dir="./dataframes"): def load_dataframes(project_name, base_dir="./dataframes"): """ - Loades saved dataframes from the given project. + Loads saved dataframes from the given project. 
Args: project_name: (str) the name of the project folder base_dir: (str) the base directory to fetch the dataframes from @@ -130,7 +219,7 @@ def load_dataframes(project_name, base_dir="./dataframes"): project_dir, strategy, job, device_id = path.split(os.sep) # Load dataframe from csv - df = pd.read_csv(os.path.join(root, file)) + df = pd.read_csv(os.path.join(root, file), low_memory=False) # Add dataframe to dictionary if (strategy, job) not in history_groups: @@ -141,12 +230,13 @@ def load_dataframes(project_name, base_dir="./dataframes"): return history_groups -def get_total_flops(groups, total_model_flops): +def get_total_flops(groups, total_model_flops, client_model_flops, batch_size=64): """ Returns the total number of FLOPs for each group. Args: groups: The runs of one project, according to the structure of the wandb project total_model_flops: (int) the total number of FLOPs of the model + client_model_flops: (int) the total number of FLOPs of the client model Returns: flops_per_group: (dict) the total number of FLOPs for each group """ @@ -155,6 +245,7 @@ def get_total_flops(groups, total_model_flops): if job == "train": flops = 0 num_runs = 1 # avoid division by 0 + num_clients = len(group.items()) - 1 # minus controller for name, runs in group.items(): if ( name != "controller" @@ -170,6 +261,38 @@ def get_total_flops(groups, total_model_flops): flops += ( run_df[col_name].sum() * total_model_flops ) # 1x forward + if col_name == "adaptive_learning_threshold_applied": + # deduce client model flops twice as client backprop is avoided + if ( + run_df[col_name].dtype == "object" + ): # if boolean values were logged + # assumptions: compute avg number of samples per batch + avg_samples_per_epoch = sum( + run_df["train_accuracy.num_samples"].dropna() + ) / len( + run_df["train_accuracy.num_samples"].dropna() + ) + avg_num_batches = ( + math.ceil( + avg_samples_per_epoch + / num_clients + / batch_size + ) + * num_clients + ) + avg_samples_per_batch = ( + avg_samples_per_epoch / avg_num_batches + ) + flops -= ( + len(run_df[col_name].dropna()) + * client_model_flops + * 2 + * avg_samples_per_batch + ) + else: # numbers of samples skipped are logged -> sum up + flops -= ( + run_df[col_name].sum() * client_model_flops * 2 + ) flops = flops / num_runs flops_per_group["strategy"].append(STRATEGY_MAPPING[strategy]) flops_per_group["flops"].append(round(flops / 1000000000, 3)) # in GFLOPs @@ -341,6 +464,27 @@ def accuracy_over_epoch(history_groups, phase="train"): return results +def accuracy_over_time(history_groups, phase="train"): + """ + Returns the accuracy over time for each group. No averaging implemented yet if there are multiple runs per group! + Args: + history_groups: The runs of one project, according to the structure of the wandb project + phase: (str) the phase to get the accuracy for, either "train" or "val" + Returns: + results: (dict) the accuracy (list(float)) per round (list(int)) for each group + """ + results = {} + for (strategy, job), group in history_groups.items(): + if job == "train": + run_df = group["controller"][0] # no averaging + time_acc = run_df[[f"{phase}_accuracy.value", "_runtime"]].dropna() + results[(strategy, job)] = ( + time_acc["_runtime"], + time_acc[f"{phase}_accuracy.value"], + ) + return results + + def plot_accuracies(accuracies_per_round, save_path=None, phase="train"): """ Plots the accuracy over the epoch for each group. 
@@ -366,6 +510,28 @@ def plot_accuracies(accuracies_per_round, save_path=None, phase="train"): plt.close() +def plot_accuracies_over_time(accuracies_per_time, save_path=None, phase="train"): + """ + Plots the accuracy over the time for each group. + Args: + accuracies_per_time: (dict) the accuracy (list(float)) per time (list(float)) for each group + save_path: (str) the path to save the plot to + """ + plt.figure() + for (strategy, job), (time, accs) in accuracies_per_time.items(): + plt.plot(time, accs, label=f"{STRATEGY_MAPPING[strategy]}") + plt.xlabel(LABEL_MAPPING["runtime"]) + plt.ylabel(LABEL_MAPPING[f"{phase} accuracy"]) + plt.legend() + plt.tight_layout() + if save_path is None: + plt.show() + else: + plt.savefig(f"{save_path}.pdf", format="pdf") + plt.savefig(f"{save_path}.png", format="png") + plt.close() + + def battery_over_time(history_groups, num_intervals=1000): """ Returns the average battery over time for each group. @@ -1028,20 +1194,6 @@ def generate_plots(history_groups, project_name, base_path="./plots"): aggregated=False, ) - train_times = get_train_times(history_groups) - plot_batteries_over_time_with_activity( - batteries_over_time, - max_runtimes, - train_times, - save_path=f"{project_path}/activity_over_time", - ) - plot_batteries_over_epoch_with_activity_at_time_scale( - batteries_over_time, - max_runtimes, - train_times, - save_path=f"{project_path}/activity_over_time_with_epoch", - ) - # batteries over epoch batteries_over_epoch = aggregated_battery_over_epoch( history_groups, num_intervals=1000 @@ -1058,13 +1210,6 @@ def generate_plots(history_groups, project_name, base_path="./plots"): save_path=f"{project_path}/batteries_over_epoch", aggregated=False, ) - training_times = get_train_times(history_groups) - plot_batteries_over_epoch_with_activity_at_epoch_scale( - batteries_over_epoch, - training_times=training_times, - save_path=f"{project_path}/activity_over_epoch", - ) - # remaining devices remaining_devices = remaining_devices_per_round(history_groups) plot_remaining_devices( @@ -1080,16 +1225,32 @@ def generate_plots(history_groups, project_name, base_path="./plots"): val_accs = accuracy_over_epoch(history_groups, "val") plot_accuracies(val_accs, save_path=f"{project_path}/val_accuracy", phase="val") + time_train_accs = accuracy_over_time(history_groups, "train") + plot_accuracies_over_time( + time_train_accs, + save_path=f"{project_path}/train_accuracy_over_time", + phase="train", + ) + time_val_accs = accuracy_over_time(history_groups, "val") + plot_accuracies_over_time( + time_val_accs, save_path=f"{project_path}/val_accuracy_over_time", phase="val" + ) + def generate_metric_files( - history_groups, project_name, model_flops, base_path="./metrics" + history_groups, + project_name, + total_model_flops, + client_model_flops, + base_path="./metrics", + batch_size=64, ): """ Generates metric file for the given history groups and saves them to the project_name folder. 
     Args:
         history_groups: The runs of one project, according to the structure of the wandb project
         project_name: (str) the name of the project
-        model_flops: (int) the total number of FLOPs of the model
+        total_model_flops: (int) the total number of FLOPs of the model
+        client_model_flops: (int) the number of FLOPs of the client-side part of the model
         base_path: (str) the base path to save the metrics to
+        batch_size: (int) the batch size used in the experiments
     """
     project_path = f"{base_path}/{project_name}"
@@ -1103,7 +1264,9 @@ def generate_metric_files(
         get_communication_overhead(history_groups)
     ).set_index("strategy")
     total_flops = pd.DataFrame.from_dict(
-        get_total_flops(history_groups, model_flops)
+        get_total_flops(
+            history_groups, total_model_flops, client_model_flops, batch_size
+        )
     ).set_index("strategy")
     df = pd.concat([test_acc, comm_overhead, total_flops], axis=1)
     df.to_csv(f"{project_path}/metrics.csv")
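The `simulate_parallelism` flag introduced in `config/default.yaml` makes `FedController` and `DeviceServer` run the per-device jobs one after another, timing each job with the new `Timer` context manager and logging both the sequential `elapsed_time` and the `parallel_time` (the maximum of the individual job times, i.e. roughly what a truly parallel run would have taken). Below is a minimal sketch of that pattern, not the actual controller code; `jobs` and `log` are stand-ins for the dispatcher calls and `self.logger.log` / `node_device.log`.

```python
import time


class Timer:
    """Context manager measuring wall-clock time (mirrors edml.helpers.decorators.Timer)."""

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.execution_time = time.perf_counter() - self.start_time


def run_with_simulated_parallelism(jobs, log):
    """Run `jobs` sequentially but log what a parallel run would have cost."""
    parallel_times = []
    with Timer() as elapsed_time:
        for job in jobs:
            with Timer() as individual_time:
                job()
            parallel_times.append(individual_time.execution_time)
    log(
        {
            "parallel_fed_time": {
                "elapsed_time": elapsed_time.execution_time,  # sequential wall-clock time
                "parallel_time": max(parallel_times),  # slowest single job
            }
        }
    )


if __name__ == "__main__":
    run_with_simulated_parallelism([lambda: time.sleep(0.2), lambda: time.sleep(0.1)], print)
```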
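In post-processing, `get_scale_factors` in `results/result_generation.py` turns these logged pairs into a per-run scale factor `max_runtime / (max_runtime - elapsed_time + parallel_time)`, and `scale_parallel_time` divides relative timestamps and durations by that factor, so the rescaled run looks as if the simulated-parallel sections had actually run concurrently; if no such sections were logged, the factor stays 1 and the run is left unchanged. A worked example with made-up numbers:

```python
# Hypothetical numbers, not values measured in these experiments.
max_runtime = 1000.0   # observed wall-clock runtime of the run (seconds)
elapsed_time = 600.0   # summed elapsed_time of all simulated-parallel sections
parallel_time = 200.0  # summed parallel_time (slowest job per section)

corrected_runtime = max_runtime - elapsed_time + parallel_time  # 600.0 s
scale_factor = max_runtime / corrected_runtime                  # ~1.667

# scale_parallel_time() divides "*.start"/"*.end" offsets, "*.duration" and "_runtime"
# by this factor, shrinking the run from 1000 s to the corrected 600 s.
print(round(scale_factor, 3), max_runtime / scale_factor)
```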
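`get_total_flops` now also subtracts the client backward passes skipped by the adaptive learning threshold. When only booleans were logged under `adaptive_learning_threshold_applied`, it estimates the average number of samples per concatenated server batch from `train_accuracy.num_samples` and deducts twice the client-model FLOPs for each skipped batch. The arithmetic, with assumed example values (only the per-sample client FLOP count is taken from the notebook's table):

```python
import math

# Assumed example values for illustration.
client_model_flops = 88_408_064  # resnet110 client part, per sample (from the notebook)
batch_size = 64
num_clients = 4                  # devices in the group, excluding the controller
samples_per_epoch = 10_000       # average of train_accuracy.num_samples
skipped_batches = 37             # rows where the threshold skipped client backprop

num_batches = math.ceil(samples_per_epoch / num_clients / batch_size) * num_clients
avg_samples_per_batch = samples_per_epoch / num_batches

# Each skipped batch saves one client backward pass, counted as 2x the forward FLOPs.
saved_flops = skipped_batches * client_model_flops * 2 * avg_samples_per_batch
print(f"deducting {saved_flops / 1e9:.2f} GFLOPs")
```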