From c889925b15d4df07bcba5d5182f790841408d902 Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Wed, 3 Jul 2024 16:00:41 +0200
Subject: [PATCH] Added estimate for transmission latency to smart schedule
 computation

---
 edml/controllers/scheduler/smart.py           | 10 ++++--
 edml/controllers/strategy_optimization.py     | 34 ++++++++++++++++++-
 edml/helpers/metrics.py                       |  3 ++
 edml/tests/controllers/optimization_test.py   | 17 +++++++++-
 .../tests/controllers/scheduler/smart_test.py |  7 ++--
 edml/tests/helpers/metrics_test.py            |  4 +++
 6 files changed, 69 insertions(+), 6 deletions(-)

diff --git a/edml/controllers/scheduler/smart.py b/edml/controllers/scheduler/smart.py
index d2de92b..4894033 100644
--- a/edml/controllers/scheduler/smart.py
+++ b/edml/controllers/scheduler/smart.py
@@ -38,6 +38,7 @@ class SmartNextServerScheduler(NextServerScheduler):
     def _next_server(
         self,
         active_devices: Sequence[str],
+        last_server_device_id=None,
         diagnostic_metric_container: Optional[DiagnosticMetricResultContainer] = None,
         **kwargs,
     ) -> str:
@@ -46,7 +47,7 @@ class SmartNextServerScheduler(NextServerScheduler):
         else:
             if self.selection_schedule is None or len(self.selection_schedule) == 0:
                 self.selection_schedule = self._get_selection_schedule(
-                    diagnostic_metric_container
+                    diagnostic_metric_container, last_server_device_id
                 )
                 print(f"server device schedule: {self.selection_schedule}")
             try:
@@ -58,7 +59,9 @@ class SmartNextServerScheduler(NextServerScheduler):
                     kwargs=kwargs,
                 )
 
-    def _get_selection_schedule(self, diagnostic_metric_container):
+    def _get_selection_schedule(
+        self, diagnostic_metric_container, last_server_device_id=None
+    ):
         device_params_list = []
         device_battery_levels = self._update_batteries_cb()
         # get num samples and flops per device
@@ -101,6 +104,9 @@ class SmartNextServerScheduler(NextServerScheduler):
         global_params.optimizer_state_size = optimization_metrics[
             "optimizer_state_size"
         ]
+        if "train_global_time" in optimization_metrics.keys():
+            global_params.train_global_time = optimization_metrics["train_global_time"]
+        global_params.last_server_device_id = last_server_device_id
         print(f"global params: {vars(global_params)}")
         print(f"device params: {[vars(device) for device in device_params_list]}")
         server_choice_optimizer = ServerChoiceOptimizer(
diff --git a/edml/controllers/strategy_optimization.py b/edml/controllers/strategy_optimization.py
index 022f204..eddbf57 100644
--- a/edml/controllers/strategy_optimization.py
+++ b/edml/controllers/strategy_optimization.py
@@ -42,6 +42,8 @@ class GlobalParams:
         client_norm_bw_time=None,
         server_norm_fw_time=None,
         server_norm_bw_time=None,
+        train_global_time=None,
+        last_server_device_id=None,
     ):
         self.cost_per_sec = cost_per_sec
         self.cost_per_byte_sent = cost_per_byte_sent
@@ -54,6 +56,8 @@ class GlobalParams:
         self.client_weights_size = client_weights_size
         self.server_weights_size = server_weights_size
         self.optimizer_state_size = optimizer_state_size
+        self.train_global_time = train_global_time
+        self.last_server_device_id = last_server_device_id
         # metrics per sample
         self.label_size = label_size
         self.gradient_size = gradient_size
@@ -114,7 +118,20 @@ class ServerChoiceOptimizer:
                 return device
         return None
 
-    def round_runtime_with_server(self, server_device_id):
+    def _transmission_latency(self):
+        if (
+            self.global_params.train_global_time is not None
+            and self.global_params.last_server_device_id is not None
+        ):
+            return (
+                self.global_params.train_global_time
+                - self._round_runtime_with_server_no_latency(
+                    self.global_params.last_server_device_id
+                )
+            )
+        return 0  # latency unknown: no previous round measurement available
+
+    def _round_runtime_with_server_no_latency(self, server_device_id):
         """
         Computes the runtime of a round with the given server device.
         Params:
@@ -122,6 +139,7 @@ class ServerChoiceOptimizer:
         Returns:
             the runtime of a round with the given server device
         Notes:
+            Does not consider any transmission latency and may underestimate the runtime
         """
         total_time = 0
         for device in self.device_params_list:
@@ -144,6 +162,20 @@ class ServerChoiceOptimizer:
             ) * device.comp_latency_factor
         return total_time
 
+    def round_runtime_with_server(self, server_device_id):
+        """
+        Computes the runtime of a round with the given server device.
+        Params:
+            server_device_id: the id of the server device
+        Returns:
+            the runtime of a round with the given server device
+            including the estimated transmission latency
+        """
+        return (
+            self._round_runtime_with_server_no_latency(server_device_id)
+            + self._transmission_latency()
+        )
+
     def num_flops_per_round_on_device(self, device_id, server_device_id):
         device = self._get_device_params(device_id)
         total_flops = 0
diff --git a/edml/helpers/metrics.py b/edml/helpers/metrics.py
index 7dc57e7..d0ba068 100644
--- a/edml/helpers/metrics.py
+++ b/edml/helpers/metrics.py
@@ -436,6 +436,9 @@ def compute_metrics_for_optimization(
     )  # all values should be equal anyway
 
     # time
+    result["train_global_time"] = max(
+        [metric.value for metric in raw_metrics[("comp_time", "train_global")]]
+    )  # should be only one value anyway
     avg_server_model_train_time_per_sample = (
         sum([metric.value for metric in raw_metrics[("comp_time", "train_batch")]])
         / len(raw_metrics[("comp_time", "train_batch")])
diff --git a/edml/tests/controllers/optimization_test.py b/edml/tests/controllers/optimization_test.py
index d488cd0..2e8cc39 100644
--- a/edml/tests/controllers/optimization_test.py
+++ b/edml/tests/controllers/optimization_test.py
@@ -77,11 +77,26 @@ class StrategyOptimizationTest(unittest.TestCase):
         self.assertEqual(self.optimizer._get_device_params("d1").device_id, "d1")
         self.assertEqual(self.optimizer._get_device_params("d2").device_id, "d2")
 
-    def test_round_runtime_with_server(self):
+    def test_transmission_latency_no_values(self):
+        self.assertEqual(self.optimizer._transmission_latency(), 0)
+
+    def test_transmission_latency_with_values(self):
+        self.global_params.train_global_time = 100
+        self.global_params.last_server_device_id = "d0"
+        self.assertEqual(self.optimizer._transmission_latency(), 6.5)
+
+    def test_round_runtime_with_server_no_latency(self):
         self.assertEqual(self.optimizer.round_runtime_with_server("d0"), 93.5)
         self.assertEqual(self.optimizer.round_runtime_with_server("d1"), 153.5)
         self.assertEqual(self.optimizer.round_runtime_with_server("d2"), 63.5)
 
+    def test_round_runtime_with_server_with_latency(self):
+        self.global_params.train_global_time = 100
+        self.global_params.last_server_device_id = "d0"
+        self.assertEqual(self.optimizer.round_runtime_with_server("d0"), 100)
+        self.assertEqual(self.optimizer.round_runtime_with_server("d1"), 160)
+        self.assertEqual(self.optimizer.round_runtime_with_server("d2"), 70)
+
     def test_num_flops_per_round_on_device(self):
         self.assertEqual(self.optimizer.num_flops_per_round_on_device("d0", "d0"), 670)
         self.assertEqual(self.optimizer.num_flops_per_round_on_device("d1", "d0"), 100)
diff --git a/edml/tests/controllers/scheduler/smart_test.py b/edml/tests/controllers/scheduler/smart_test.py
index 4e0178b..6eb7e75 100644
--- a/edml/tests/controllers/scheduler/smart_test.py
+++ b/edml/tests/controllers/scheduler/smart_test.py
@@ -44,7 +44,9 @@ class SmartServerDeviceSelectionTest(unittest.TestCase):
 
     def test_select_server_device_smart_first_round(self):
         # should select according to max battery
-        self.scheduler.next_server([""], diagnostic_metric_container=None)
+        self.scheduler.next_server(
+            [""], diagnostic_metric_container=None, last_server_device_id=None
+        )
         self.scheduler.fallback_scheduler.next_server.assert_called_once()
 
     def test_select_server_device_smart_second_round(self):
@@ -56,6 +58,7 @@ class SmartServerDeviceSelectionTest(unittest.TestCase):
             server_device = self.scheduler.next_server(
                 ["d0", "d1"],
                 diagnostic_metric_container=self.diagnostic_metric_container,
+                last_server_device_id=None,
             )
             self.assertEqual(server_device, "d0")
 
@@ -65,6 +68,6 @@ class SmartServerDeviceSelectionTest(unittest.TestCase):
             return_value=self.metrics,
         ):
             schedule = self.scheduler._get_selection_schedule(
-                self.diagnostic_metric_container
+                self.diagnostic_metric_container,
             )
             self.assertEqual(schedule, ["d0", "d1", "d1", "d1"])
diff --git a/edml/tests/helpers/metrics_test.py b/edml/tests/helpers/metrics_test.py
index 5810d4e..1d13f2e 100644
--- a/edml/tests/helpers/metrics_test.py
+++ b/edml/tests/helpers/metrics_test.py
@@ -228,6 +228,9 @@ class MetricsForOptimizationTest(unittest.TestCase):
             DiagnosticMetricResult(
                 device_id="d0", method="train_batch", name="comp_time", value=0.6
             ),
+            DiagnosticMetricResult(
+                device_id="d0", method="train_global", name="comp_time", value=10
+            ),
             # train epoch time = train batch time + client model train time
             DiagnosticMetricResult(
                 device_id="d0",
@@ -304,6 +307,7 @@ class MetricsForOptimizationTest(unittest.TestCase):
         self.assertEqual(normalized_metrics["client_weight_size"], 500)
         self.assertEqual(normalized_metrics["server_weight_size"], 500)
         self.assertEqual(normalized_metrics["optimizer_state_size"], 500)
+        self.assertEqual(normalized_metrics["train_global_time"], 10)
         self.assertAlmostEqual(normalized_metrics["client_norm_fw_time"], 0.1)
         self.assertAlmostEqual(normalized_metrics["client_norm_bw_time"], 0.2)
         self.assertAlmostEqual(normalized_metrics["server_norm_fw_time"], 0.1)
-- 
GitLab