diff --git a/edml/controllers/fed_controller.py b/edml/controllers/fed_controller.py
index cdf6b6e86d2df4d59f5ea025288fb5e28d3d4efc..4ee55bfd514923864f9b0e5abd3d22d7f8277c95 100644
--- a/edml/controllers/fed_controller.py
+++ b/edml/controllers/fed_controller.py
@@ -5,6 +5,7 @@ from typing import Dict, List
 
 from overrides import override
 
 from edml.controllers.base_controller import BaseController
+from edml.helpers.decorators import Timer
 from edml.helpers.metrics import ModelMetricResultContainer
 
@@ -43,38 +44,38 @@ class FedController(BaseController):
 
     def _fed_train_round(self, round_no: int = -1):
        """Returns new client and server weights."""
-        client_weights_lock = threading.Lock()
-        server_weights_lock = threading.Lock()
-        samples_count_lock = threading.Lock()
-        metrics_lock = threading.Lock()
         client_weights = []
         server_weights = []
         samples_count = []
         metrics_container = ModelMetricResultContainer()
-        with concurrent.futures.ThreadPoolExecutor(
-            max_workers=max(len(self.active_devices), 1)
-        ) as executor:  # avoid exception when setting 0 workers
-            futures = [
-                executor.submit(
-                    self.request_dispatcher.federated_train_on, device_id, round_no
-                )
-                for device_id in self.active_devices
-            ]
-            for future in concurrent.futures.as_completed(futures):
-                response = future.result()
-                if response is not False:
-                    new_client_weights, new_server_weights, num_samples, metrics, _ = (
-                        response  # skip diagnostic metrics
+        parallel_times = []
+        with Timer() as elapsed_time:
+            for device_id in self.active_devices:
+                with Timer() as individual_time:
+                    response = self.request_dispatcher.federated_train_on(
+                        device_id, round_no
                     )
-                    with client_weights_lock:
+                    if response is not False:
+                        (
+                            new_client_weights,
+                            new_server_weights,
+                            num_samples,
+                            metrics,
+                            _,
+                        ) = response  # skip diagnostic metrics
                         client_weights.append(new_client_weights)
-                    with server_weights_lock:
                         server_weights.append(new_server_weights)
-                    with samples_count_lock:
                         samples_count.append(num_samples)
-                    with metrics_lock:
                         metrics_container.merge(metrics)
-
+                parallel_times.append(individual_time.execution_time)
+        self.logger.log(
+            {
+                "parallel_fed_time": {
+                    "elapsed_time": elapsed_time.execution_time,
+                    # default=0.0 avoids ValueError when no devices are active
+                    "parallel_time": max(parallel_times, default=0.0),
+                }
+            }
+        )
         print(f"samples count {samples_count}")
         return (
diff --git a/edml/core/device.py b/edml/core/device.py
index 94bb18ac35241fc86eec778b8ef2d9145a8ed51f..6770e4271e8dc030af1abc7f8ac010a57242f5bb 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -225,6 +225,7 @@ class Device(ABC):
 
 class NetworkDevice(Device):
     @update_battery
+    @log_execution_time("logger", "finalize_gradients")
     def set_gradient_and_finalize_training_on_client_only_on(
         self, client_id: str, gradients: Any
     ):
@@ -404,6 +405,7 @@ class NetworkDevice(Device):
 
     @add_time_to_diagnostic_metrics("evaluate_batch")
     @update_battery
+    @log_execution_time("logger", "evaluate_batch_time")
     def evaluate_batch(self, smashed_data, labels):
         result = self.server.evaluate_batch(smashed_data, labels)
         self._log_current_battery_capacity()
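The change above swaps the executor-based fan-out for a sequential loop: elapsed_time is the total sequential wall time, and max(parallel_times) estimates the wall time an ideally parallel round would need, since the slowest device dominates. A minimal sketch of the pattern in isolation (train_on and the device ids are hypothetical stand-ins; only Timer, added in edml/helpers/decorators.py below, comes from this patch):

    import random
    import time

    from edml.helpers.decorators import Timer

    def train_on(device_id: str) -> None:
        time.sleep(random.uniform(0.01, 0.05))  # stand-in for one training round

    parallel_times = []
    with Timer() as elapsed_time:
        for device_id in ["d0", "d1", "d2"]:
            with Timer() as individual_time:
                train_on(device_id)
            parallel_times.append(individual_time.execution_time)

    # Total sequential wall time vs. the estimated ideal parallel wall time;
    # default=0.0 guards against an empty device list.
    print(elapsed_time.execution_time, max(parallel_times, default=0.0))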
diff --git a/edml/core/server.py b/edml/core/server.py
index 4b9ef80fd876d618075a98a14d61e03d83369b66..629777e3a8969d4a265f11d5aa0f9d975b65ad7a 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -10,7 +10,7 @@ from torch import nn
 from torch.autograd import Variable
 
 from edml.helpers.config_helpers import get_torch_device_id
-from edml.helpers.decorators import check_device_set, simulate_latency_decorator
+from edml.helpers.decorators import check_device_set, simulate_latency_decorator, Timer
 from edml.helpers.executor import create_executor_with_threads
 from edml.helpers.flops import estimate_model_flops
 from edml.helpers.load_optimizer import get_optimizer_and_scheduler
@@ -262,15 +262,24 @@
         print(f":: BATCHES :: {num_batches}")
         for batch_index in range(num_batches):
             client_forward_pass_responses = []
-            futures = [
-                executor.submit(client_training_job, client_id, batch_index)
-                for client_id in clients
-            ]
-            for future in concurrent.futures.as_completed(futures):
-                (client_id, result) = future.result()
-                if result is not None and result is not False:
-                    client_forward_pass_responses.append((client_id, result))
-
+            parallel_times = []
+            with Timer() as elapsed_time:
+                for client_id in clients:
+                    with Timer() as individual_time:
+                        (client_id, result) = client_training_job(
+                            client_id, batch_index
+                        )
+                    if result is not None and result is not False:
+                        client_forward_pass_responses.append((client_id, result))
+                    parallel_times.append(individual_time.execution_time)
+            self.node_device.log(
+                {
+                    "parallel_client_train_time": {
+                        "elapsed_time": elapsed_time.execution_time,
+                        "parallel_time": max(parallel_times, default=0.0),
+                    }
+                }
+            )
             # We want to split up the responses into a list of client IDs and batches again.
             client_ids = [b[0] for b in client_forward_pass_responses]
             client_batches = [b[1] for b in client_forward_pass_responses]
@@ -313,48 +322,74 @@
             t.clone().detach()
             for t in torch.chunk(server_gradients, num_client_gradients)
         ]
-        futures = [
-            executor.submit(
-                client_backpropagation_job, client_id, client_gradients[idx]
-            )
-            for (idx, client_id) in enumerate(client_ids)
-        ]
         client_backpropagation_gradients = []
-        for future in concurrent.futures.as_completed(futures):
-            _, grads = future.result()
-            if grads is not None and grads is not False:
-                client_backpropagation_gradients.append(grads)
+        parallel_times = []
+        with Timer() as elapsed_time:
+            for idx, client_id in enumerate(client_ids):
+                with Timer() as individual_time:
+                    _, grads = client_backpropagation_job(
+                        client_id, client_gradients[idx]
+                    )
+                if grads is not None and grads is not False:
+                    client_backpropagation_gradients.append(grads)
+                parallel_times.append(individual_time.execution_time)
+        self.node_device.log(
+            {
+                "parallel_client_backprop_time": {
+                    "elapsed_time": elapsed_time.execution_time,
+                    "parallel_time": max(parallel_times, default=0.0),
+                }
+            }
+        )
         # We want to average the client's backpropagation gradients and send them over again to finalize the
         # current training step.
         averaged_gradient = _calculate_gradient_mean(
             client_backpropagation_gradients, self._device
         )
-        futures = [
-            executor.submit(
-                client_set_gradient_and_finalize_training_job,
-                client_id,
-                averaged_gradient,
-            )
-            for client_id in clients
-        ]
-        for future in concurrent.futures.as_completed(futures):
-            future.result()
-
-        # Now we have to determine the model metrics for each client.
-        for client_id in clients:
-            train_metrics = self.finalize_metrics(str(client_id), "train")
-
-            evaluation_diagnostics_metrics = self.node_device.evaluate_on(
-                device_id=client_id,
-                server_device=self.node_device.device_id,
-                val=True,
-            )
+        parallel_times = []
+        with Timer() as elapsed_time:
+            for client_id in clients:
+                with Timer() as individual_time:
+                    client_set_gradient_and_finalize_training_job(
+                        client_id, averaged_gradient
+                    )
+                parallel_times.append(individual_time.execution_time)
+        self.node_device.log(
+            {
+                "parallel_client_model_update_time": {
+                    "elapsed_time": elapsed_time.execution_time,
+                    "parallel_time": max(parallel_times, default=0.0),
+                }
+            }
+        )
-            # if evaluation_diagnostics_metrics:
-            #     diagnostic_metrics.merge(evaluation_diagnostics_metrics)
-            val_metrics = self.finalize_metrics(str(client_id), "val")
-            model_metrics.add_results(train_metrics)
-            model_metrics.add_results(val_metrics)
+        # Now we have to determine the model metrics for each client.
+        parallel_times = []
+        with Timer() as elapsed_time:
+            for client_id in clients:
+                with Timer() as individual_time:
+                    train_metrics = self.finalize_metrics(str(client_id), "train")
+
+                    evaluation_diagnostics_metrics = self.node_device.evaluate_on(
+                        device_id=client_id,
+                        server_device=self.node_device.device_id,
+                        val=True,
+                    )
+                    # if evaluation_diagnostics_metrics:
+                    #     diagnostic_metrics.merge(evaluation_diagnostics_metrics)
+                    val_metrics = self.finalize_metrics(str(client_id), "val")
+
+                    model_metrics.add_results(train_metrics)
+                    model_metrics.add_results(val_metrics)
+                parallel_times.append(individual_time.execution_time)
+        self.node_device.log(
+            {
+                "parallel_client_eval_time": {
+                    "elapsed_time": elapsed_time.execution_time,
+                    "parallel_time": max(parallel_times, default=0.0),
+                }
+            }
+        )
 
         optimizer_state = self._optimizer.state_dict()
         # delete references and free GPU memory manually
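All four parallel_client_*_time entries logged above share one shape: {"<name>": {"elapsed_time": ..., "parallel_time": ...}}. A hypothetical consumer, not part of this patch, that turns one such entry into an estimated parallelization speedup:

    def estimated_speedup(entry: dict, key: str = "parallel_client_train_time") -> float:
        """Hypothetical helper: sequential wall time over estimated parallel wall time."""
        timing = entry[key]
        # elapsed_time is the sequential total; parallel_time approximates the
        # slowest single participant, i.e. the ideal parallel wall time.
        return timing["elapsed_time"] / max(timing["parallel_time"], 1e-9)

    entry = {"parallel_client_train_time": {"elapsed_time": 4.0, "parallel_time": 1.0}}
    print(estimated_speedup(entry))  # 4.0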
diff --git a/edml/helpers/decorators.py b/edml/helpers/decorators.py
index 5623648311abd9e63c51281af657649ea950459e..6f4a8da8912f989163ddca9f829b428800db3749 100644
--- a/edml/helpers/decorators.py
+++ b/edml/helpers/decorators.py
@@ -165,6 +165,22 @@ class LatencySimulator:
         time.sleep(self.execution_time * self.latency_factor)
 
 
+class Timer:
+    """
+    Context manager that measures the execution time of its block.
+
+    Notes:
+        Access the measured time via Timer.execution_time after the block exits.
+    """
+
+    def __enter__(self):
+        self.start_time = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.execution_time = time.perf_counter() - self.start_time
+
+
 def add_time_to_diagnostic_metrics(method_name: str):
     """
     A decorator factory that measures the execution time of the wrapped method. It then creates a diagnostic metric
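Usage sketch for the new Timer (the sleep durations are placeholders): nesting works because each instance keeps its own start_time, and execution_time is set even when the block raises, because __exit__ still runs and returns a falsy value, so exceptions propagate:

    import time

    from edml.helpers.decorators import Timer

    with Timer() as outer:
        with Timer() as inner:
            time.sleep(0.01)
        time.sleep(0.01)

    assert inner.execution_time <= outer.execution_time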