From c889925b15d4df07bcba5d5182f790841408d902 Mon Sep 17 00:00:00 2001 From: Tim Bauerle <tim.bauerle@rwth-aachen.de> Date: Wed, 3 Jul 2024 16:00:41 +0200 Subject: [PATCH] Added estimate for transmission latency to smart schedule computation --- edml/controllers/scheduler/smart.py | 10 ++++-- edml/controllers/strategy_optimization.py | 34 ++++++++++++++++++- edml/helpers/metrics.py | 3 ++ edml/tests/controllers/optimization_test.py | 17 +++++++++- .../tests/controllers/scheduler/smart_test.py | 7 ++-- edml/tests/helpers/metrics_test.py | 4 +++ 6 files changed, 69 insertions(+), 6 deletions(-) diff --git a/edml/controllers/scheduler/smart.py b/edml/controllers/scheduler/smart.py index d2de92b..4894033 100644 --- a/edml/controllers/scheduler/smart.py +++ b/edml/controllers/scheduler/smart.py @@ -38,6 +38,7 @@ class SmartNextServerScheduler(NextServerScheduler): def _next_server( self, active_devices: Sequence[str], + last_server_device_id=None, diagnostic_metric_container: Optional[DiagnosticMetricResultContainer] = None, **kwargs, ) -> str: @@ -46,7 +47,7 @@ class SmartNextServerScheduler(NextServerScheduler): else: if self.selection_schedule is None or len(self.selection_schedule) == 0: self.selection_schedule = self._get_selection_schedule( - diagnostic_metric_container + diagnostic_metric_container, last_server_device_id ) print(f"server device schedule: {self.selection_schedule}") try: @@ -58,7 +59,9 @@ class SmartNextServerScheduler(NextServerScheduler): kwargs=kwargs, ) - def _get_selection_schedule(self, diagnostic_metric_container): + def _get_selection_schedule( + self, diagnostic_metric_container, last_server_device_id=None + ): device_params_list = [] device_battery_levels = self._update_batteries_cb() # get num samples and flops per device @@ -101,6 +104,9 @@ class SmartNextServerScheduler(NextServerScheduler): global_params.optimizer_state_size = optimization_metrics[ "optimizer_state_size" ] + if "train_global_time" in optimization_metrics.keys(): + global_params.train_global_time = optimization_metrics["train_global_time"] + global_params.last_server_device_id = last_server_device_id print(f"global params: {vars(global_params)}") print(f"device params: {[vars(device) for device in device_params_list]}") server_choice_optimizer = ServerChoiceOptimizer( diff --git a/edml/controllers/strategy_optimization.py b/edml/controllers/strategy_optimization.py index 022f204..eddbf57 100644 --- a/edml/controllers/strategy_optimization.py +++ b/edml/controllers/strategy_optimization.py @@ -42,6 +42,8 @@ class GlobalParams: client_norm_bw_time=None, server_norm_fw_time=None, server_norm_bw_time=None, + train_global_time=None, + last_server_device_id=None, ): self.cost_per_sec = cost_per_sec self.cost_per_byte_sent = cost_per_byte_sent @@ -54,6 +56,8 @@ class GlobalParams: self.client_weights_size = client_weights_size self.server_weights_size = server_weights_size self.optimizer_state_size = optimizer_state_size + self.train_global_time = train_global_time + self.last_server_device_id = last_server_device_id # metrics per sample self.label_size = label_size self.gradient_size = gradient_size @@ -114,7 +118,20 @@ class ServerChoiceOptimizer: return device return None - def round_runtime_with_server(self, server_device_id): + def _transmission_latency(self): + if ( + self.global_params.train_global_time is not None + and self.global_params.last_server_device_id is not None + ): + return ( + self.global_params.train_global_time + - self._round_runtime_with_server_no_latency( + 
self.global_params.last_server_device_id + ) + ) + return 0 # latency not known + + def _round_runtime_with_server_no_latency(self, server_device_id): """ Computes the runtime of a round with the given server device. Params: @@ -122,6 +139,7 @@ class ServerChoiceOptimizer: Returns: the runtime of a round with the given server device Notes: + Does not consider any transmission latency and may underestimate the runtime """ total_time = 0 for device in self.device_params_list: @@ -144,6 +162,20 @@ class ServerChoiceOptimizer: ) * device.comp_latency_factor return total_time + def round_runtime_with_server(self, server_device_id): + """ + Computes the runtime of a round with the given server device. + Params: + server_device_id: the id of the server device + Returns: + the runtime of a round with the given server device + Notes: + """ + return ( + self._round_runtime_with_server_no_latency(server_device_id) + + self._transmission_latency() + ) + def num_flops_per_round_on_device(self, device_id, server_device_id): device = self._get_device_params(device_id) total_flops = 0 diff --git a/edml/helpers/metrics.py b/edml/helpers/metrics.py index 7dc57e7..d0ba068 100644 --- a/edml/helpers/metrics.py +++ b/edml/helpers/metrics.py @@ -436,6 +436,9 @@ def compute_metrics_for_optimization( ) # all values should be equal anyway # time + result["train_global_time"] = max( + [metric.value for metric in raw_metrics[("comp_time", "train_global")]] + ) # should be only one value anyway avg_server_model_train_time_per_sample = ( sum([metric.value for metric in raw_metrics[("comp_time", "train_batch")]]) / len(raw_metrics[("comp_time", "train_batch")]) diff --git a/edml/tests/controllers/optimization_test.py b/edml/tests/controllers/optimization_test.py index d488cd0..2e8cc39 100644 --- a/edml/tests/controllers/optimization_test.py +++ b/edml/tests/controllers/optimization_test.py @@ -77,11 +77,26 @@ class StrategyOptimizationTest(unittest.TestCase): self.assertEqual(self.optimizer._get_device_params("d1").device_id, "d1") self.assertEqual(self.optimizer._get_device_params("d2").device_id, "d2") - def test_round_runtime_with_server(self): + def test_transmission_latency_no_values(self): + self.assertEqual(self.optimizer._transmission_latency(), 0) + + def test_transmission_latency_with_values(self): + self.global_params.train_global_time = 100 + self.global_params.last_server_device_id = "d0" + self.assertEqual(self.optimizer._transmission_latency(), 6.5) + + def test_round_runtime_with_server_no_latency(self): self.assertEqual(self.optimizer.round_runtime_with_server("d0"), 93.5) self.assertEqual(self.optimizer.round_runtime_with_server("d1"), 153.5) self.assertEqual(self.optimizer.round_runtime_with_server("d2"), 63.5) + def test_round_runtime_with_server_with_latency(self): + self.global_params.train_global_time = 100 + self.global_params.last_server_device_id = "d0" + self.assertEqual(self.optimizer.round_runtime_with_server("d0"), 100) + self.assertEqual(self.optimizer.round_runtime_with_server("d1"), 160) + self.assertEqual(self.optimizer.round_runtime_with_server("d2"), 70) + def test_num_flops_per_round_on_device(self): self.assertEqual(self.optimizer.num_flops_per_round_on_device("d0", "d0"), 670) self.assertEqual(self.optimizer.num_flops_per_round_on_device("d1", "d0"), 100) diff --git a/edml/tests/controllers/scheduler/smart_test.py b/edml/tests/controllers/scheduler/smart_test.py index 4e0178b..6eb7e75 100644 --- a/edml/tests/controllers/scheduler/smart_test.py +++ 
b/edml/tests/controllers/scheduler/smart_test.py @@ -44,7 +44,9 @@ class SmartServerDeviceSelectionTest(unittest.TestCase): def test_select_server_device_smart_first_round(self): # should select according to max battery - self.scheduler.next_server([""], diagnostic_metric_container=None) + self.scheduler.next_server( + [""], diagnostic_metric_container=None, last_server_device_id=None + ) self.scheduler.fallback_scheduler.next_server.assert_called_once() def test_select_server_device_smart_second_round(self): @@ -56,6 +58,7 @@ class SmartServerDeviceSelectionTest(unittest.TestCase): server_device = self.scheduler.next_server( ["d0", "d1"], diagnostic_metric_container=self.diagnostic_metric_container, + last_server_device_id=None, ) self.assertEqual(server_device, "d0") @@ -65,6 +68,6 @@ class SmartServerDeviceSelectionTest(unittest.TestCase): return_value=self.metrics, ): schedule = self.scheduler._get_selection_schedule( - self.diagnostic_metric_container + self.diagnostic_metric_container, ) self.assertEqual(schedule, ["d0", "d1", "d1", "d1"]) diff --git a/edml/tests/helpers/metrics_test.py b/edml/tests/helpers/metrics_test.py index 5810d4e..1d13f2e 100644 --- a/edml/tests/helpers/metrics_test.py +++ b/edml/tests/helpers/metrics_test.py @@ -228,6 +228,9 @@ class MetricsForOptimizationTest(unittest.TestCase): DiagnosticMetricResult( device_id="d0", method="train_batch", name="comp_time", value=0.6 ), + DiagnosticMetricResult( + device_id="d0", method="train_global", name="comp_time", value=10 + ), # train epoch time = train batch time + client model train time DiagnosticMetricResult( device_id="d0", @@ -304,6 +307,7 @@ class MetricsForOptimizationTest(unittest.TestCase): self.assertEqual(normalized_metrics["client_weight_size"], 500) self.assertEqual(normalized_metrics["server_weight_size"], 500) self.assertEqual(normalized_metrics["optimizer_state_size"], 500) + self.assertEqual(normalized_metrics["train_global_time"], 10) self.assertAlmostEqual(normalized_metrics["client_norm_fw_time"], 0.1) self.assertAlmostEqual(normalized_metrics["client_norm_bw_time"], 0.2) self.assertAlmostEqual(normalized_metrics["server_norm_fw_time"], 0.1) -- GitLab
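For context on the change above: the new _transmission_latency helper treats the gap between the measured "train_global" wall-clock time and the analytic no-latency runtime of the server that actually ran the last round as a constant transmission overhead, and round_runtime_with_server then adds that overhead to every candidate server. The sketch below restates that logic outside the diff with simplified stand-in types; RoundTimings and the plain runtime dict are illustrative, not the repository's GlobalParams or ServerChoiceOptimizer, and the numbers in the usage block are taken from the patch's unit tests.

# Hedged sketch of the latency estimate added by this patch (simplified,
# stand-in names; not the repository's actual classes).
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class RoundTimings:
    # Wall-clock time of the previous round, taken from the "train_global"
    # comp_time metric (the patch uses the max over reporting devices).
    train_global_time: Optional[float] = None
    # Device that acted as server in the round that produced the measurement.
    last_server_device_id: Optional[str] = None


def transmission_latency(
    timings: RoundTimings,
    runtime_without_latency: Dict[str, float],
) -> float:
    """Estimate per-round transmission latency.

    The measured round time includes computation and transmission, while the
    analytic runtime model only covers computation. Their difference, for the
    server that actually ran the round, is attributed to transmission and
    assumed to be the same regardless of which device serves next.
    """
    if timings.train_global_time is None or timings.last_server_device_id is None:
        return 0.0  # latency not known yet (e.g. before the first measured round)
    return (
        timings.train_global_time
        - runtime_without_latency[timings.last_server_device_id]
    )


def round_runtime_with_server(
    server_device_id: str,
    timings: RoundTimings,
    runtime_without_latency: Dict[str, float],
) -> float:
    """Predicted round runtime for a candidate server, latency included."""
    return runtime_without_latency[server_device_id] + transmission_latency(
        timings, runtime_without_latency
    )


if __name__ == "__main__":
    # Values from the patch's tests: the no-latency model predicts 93.5 with
    # d0 as server, but 100 was measured, so 6.5 is added to every candidate.
    model = {"d0": 93.5, "d1": 153.5, "d2": 63.5}
    timings = RoundTimings(train_global_time=100, last_server_device_id="d0")
    assert transmission_latency(timings, model) == 6.5
    assert round_runtime_with_server("d1", timings, model) == 160.0
    assert round_runtime_with_server("d2", timings, model) == 70.0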