From d0fc72e8f40e34ee020728dc5e0798b606e1a093 Mon Sep 17 00:00:00 2001
From: Tim Tobias Bauerle <tim.bauerle@rwth-aachen.de>
Date: Wed, 3 Jul 2024 13:29:41 +0200
Subject: [PATCH] Smart scheduler integration

---
 edml/config/controller/scheduler/smart.yaml   |  4 +
 edml/controllers/scheduler/smart.py           | 21 +++--
 edml/controllers/swarm_controller.py          | 24 +-----
 .../tests/controllers/scheduler/smart_test.py |  6 +-
 .../controllers/swarm_controller_test.py      | 78 ++-----------------
 5 files changed, 31 insertions(+), 102 deletions(-)
 create mode 100644 edml/config/controller/scheduler/smart.yaml

diff --git a/edml/config/controller/scheduler/smart.yaml b/edml/config/controller/scheduler/smart.yaml
new file mode 100644
index 0000000..04a05ed
--- /dev/null
+++ b/edml/config/controller/scheduler/smart.yaml
@@ -0,0 +1,4 @@
+name: smart
+_target_: edml.controllers.scheduler.smart.SmartNextServerScheduler
+fallback_scheduler:
+  _target_: edml.controllers.scheduler.max_battery.MaxBatteryNextServerScheduler
diff --git a/edml/controllers/scheduler/smart.py b/edml/controllers/scheduler/smart.py
index 6c0fccc..cf66bd1 100644
--- a/edml/controllers/scheduler/smart.py
+++ b/edml/controllers/scheduler/smart.py
@@ -1,7 +1,6 @@
 from typing import Optional, Sequence
 
 from edml.controllers.scheduler.base import NextServerScheduler
-from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler
 from edml.controllers.strategy_optimization import (
     GlobalParams,
     ServerChoiceOptimizer,
@@ -16,32 +15,34 @@ from edml.helpers.metrics import (
 class SmartNextServerScheduler(NextServerScheduler):
     """
     This scheduler optimizes the server device selection so that the number of rounds with all devices participating is maximized.
-    Therefore, training metrics are needed, so in the first round, the device with the maximal battery capacity is picked.
+    Therefore, training metrics are needed, so in the first round, the device is chosen by the specified fallback scheduler.
     Afterward, the optimal selection schedule is computed. If all devices have been picked according to the schedule,
-    i.e. one device will run out of battery in the next round, again the device with the maximal battery capacity is chosen.
+    i.e. one device will run out of battery in the next round, again the server device is chosen by the fallback.
     """
 
     KEY: str = "smart"
 
-    def __init__(self, max_battery_scheduler: MaxBatteryNextServerScheduler, cfg):
+    def __init__(self, fallback_scheduler: NextServerScheduler):
         super().__init__()
-        self.max_battery_scheduler = max_battery_scheduler
+        self.fallback_scheduler = fallback_scheduler
         self.selection_schedule = None
-        self.cfg = cfg
 
     def _initialize(self, controller: "BaseController"):
+        self.cfg = controller.cfg
         self._update_batteries_cb = controller._get_battery_status
         self._data_model_cb = (
             controller._get_active_devices_dataset_sizes_and_model_flops
         )
+        self.fallback_scheduler.initialize(controller=controller)
 
     def _next_server(
         self,
         active_devices: Sequence[str],
         diagnostic_metric_container: Optional[DiagnosticMetricResultContainer] = None,
+        **kwargs,
     ) -> str:
         if diagnostic_metric_container is None:
-            return self.max_battery_scheduler.next_server(active_devices)
+            return self.fallback_scheduler.next_server(active_devices, **kwargs)
         else:
             if self.selection_schedule is None or len(self.selection_schedule) == 0:
                 self.selection_schedule = self._get_selection_schedule(
@@ -51,7 +52,11 @@ class SmartNextServerScheduler(NextServerScheduler):
             try:
                 return self.selection_schedule.pop(0)
             except IndexError:  # no more devices left in schedule
-                return self.max_battery_scheduler.next_server(active_devices)
+                return self.fallback_scheduler.next_server(
+                    active_devices,
+                    diagnostic_metric_container=diagnostic_metric_container,
+                    **kwargs,  # unpack; `kwargs=kwargs` would pass one dict-valued keyword
+                )
 
     def _get_selection_schedule(self, diagnostic_metric_container):
         device_params_list = []
diff --git a/edml/controllers/swarm_controller.py b/edml/controllers/swarm_controller.py
index 595edaf..4463e2c 100644
--- a/edml/controllers/swarm_controller.py
+++ b/edml/controllers/swarm_controller.py
@@ -2,31 +2,15 @@ from typing import Any
 
 from edml.controllers.base_controller import BaseController
 from edml.controllers.scheduler.base import NextServerScheduler
-from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler
-from edml.controllers.scheduler.random import RandomNextServerScheduler
-from edml.controllers.scheduler.sequential import SequentialNextServerScheduler
-from edml.controllers.scheduler.smart import SmartNextServerScheduler
 from edml.helpers.config_helpers import get_device_index_by_id
 
 
 class SwarmController(BaseController):
 
-    def __init__(self, cfg, selection_strategy: str = "sequential"):
+    def __init__(self, cfg, scheduler: NextServerScheduler):
         super().__init__(cfg)
-        self.selection_strategy = selection_strategy
-
-        # Creates the default set of schedulers that swarm learning originally had and initializes them.
-        max_battery_scheduler = MaxBatteryNextServerScheduler()
-        self._next_server_schedulers: [str, NextServerScheduler] = {
-            "max_battery": max_battery_scheduler,
-            "random": RandomNextServerScheduler(),
-            "sequential": SequentialNextServerScheduler(),
-            "smart": SmartNextServerScheduler(
-                max_battery_scheduler=max_battery_scheduler, cfg=cfg
-            ),
-        }
-        for scheduler in self._next_server_schedulers.values():
-            scheduler.initialize(self)
+        scheduler.initialize(self)
+        self._next_server_scheduler = scheduler
 
     def _train(self):
         client_weights = None
@@ -127,7 +111,7 @@ class SwarmController(BaseController):
         """Returns the id of the server device for the given round."""
         if len(self.active_devices) == 0:
             return None
-        return self._next_server_schedulers[self.selection_strategy].next_server(
+        return self._next_server_scheduler.next_server(
             self.active_devices,
             last_server_device_id=last_server_device_id,
             diagnostic_metric_container=diagnostic_metric_container,
diff --git a/edml/tests/controllers/scheduler/smart_test.py b/edml/tests/controllers/scheduler/smart_test.py
index 5522153..bcc0a2e 100644
--- a/edml/tests/controllers/scheduler/smart_test.py
+++ b/edml/tests/controllers/scheduler/smart_test.py
@@ -25,9 +25,9 @@ class SmartServerDeviceSelectionTest(unittest.TestCase):
             "comp_latency_factor": {"d0": 1, "d1": 1.01},
         }
         self.scheduler = SmartNextServerScheduler(
-            max_battery_scheduler=Mock(spec=MaxBatteryNextServerScheduler),
-            cfg=load_sample_config(),
+            fallback_scheduler=Mock(spec=MaxBatteryNextServerScheduler),
         )
+        self.scheduler.cfg = load_sample_config()
         self.scheduler._update_batteries_cb = lambda: {
             "d0": DeviceBatteryStatus.from_tuple((500, 500)),
             "d1": DeviceBatteryStatus.from_tuple((500, 470)),
@@ -44,7 +44,7 @@ class SmartServerDeviceSelectionTest(unittest.TestCase):
     def test_select_server_device_smart_first_round(self):
         # should select according to max battery
         self.scheduler.next_server([""], diagnostic_metric_container=None)
-        self.scheduler.max_battery_scheduler.next_server.assert_called_once()
+        self.scheduler.fallback_scheduler.next_server.assert_called_once()
 
     def test_select_server_device_smart_second_round(self):
         # should select build a schedule and select the first element
diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py
index 56e9a5e..6c025f4 100644
--- a/edml/tests/controllers/swarm_controller_test.py
+++ b/edml/tests/controllers/swarm_controller_test.py
@@ -20,7 +20,9 @@ class SwarmControllerTest(unittest.TestCase):
 
     def setUp(self) -> None:
         self.cfg = load_sample_config()
-        self.swarm_controller = SwarmController(self.cfg)
+        self.swarm_controller = SwarmController(
+            self.cfg, SequentialNextServerScheduler()
+        )
         self.mock = Mock(spec=DeviceRequestDispatcher)
         self.mock.active_devices.return_value = ["d0", "d1"]
         self.swarm_controller.request_dispatcher = self.mock
@@ -79,17 +81,13 @@ class SwarmControllerTest(unittest.TestCase):
 class ServerDeviceSelectionTest(unittest.TestCase):
 
     def setUp(self) -> None:
-        self.swarm_controller = SwarmController(load_sample_config())
+        self.swarm_controller = SwarmController(
+            load_sample_config(), Mock(spec=SequentialNextServerScheduler)
+        )
         self.swarm_controller.request_dispatcher = Mock(spec=DeviceRequestDispatcher)
         self.swarm_controller.devices = ListConfig(
             [{"device_id": "d0"}, {"device_id": "d1"}, {"device_id": "d2"}]
         )  # omitted address etc.
-        self.swarm_controller._next_server_schedulers = {
-            "max_battery": Mock(spec=MaxBatteryNextServerScheduler),
-            "random": Mock(spec=RandomNextServerScheduler),
-            "sequential": Mock(spec=SequentialNextServerScheduler),
-            "smart": Mock(spec=SmartNextServerScheduler),
-        }
 
     def test_select_no_server_device_if_no_active_devices(self):
         self.swarm_controller.request_dispatcher.active_devices.return_value = []
@@ -104,66 +102,4 @@ class ServerDeviceSelectionTest(unittest.TestCase):
     def test_sequential_selection_default(self):
         self.swarm_controller._select_server_device()
 
-        self.swarm_controller._next_server_schedulers[
-            "sequential"
-        ].next_server.assert_called()
-        self.swarm_controller._next_server_schedulers[
-            "max_battery"
-        ].next_server.assert_not_called()
-        self.swarm_controller._next_server_schedulers[
-            "random"
-        ].next_server.assert_not_called()
-        self.swarm_controller._next_server_schedulers[
-            "smart"
-        ].next_server.assert_not_called()
-
-    def test_max_battery_selection(self):
-        self.swarm_controller.selection_strategy = "max_battery"
-        self.swarm_controller._select_server_device()
-
-        self.swarm_controller._next_server_schedulers[
-            "max_battery"
-        ].next_server.assert_called()
-        self.swarm_controller._next_server_schedulers[
-            "sequential"
-        ].next_server.assert_not_called()
-        self.swarm_controller._next_server_schedulers[
-            "random"
-        ].next_server.assert_not_called()
-        self.swarm_controller._next_server_schedulers[
-            "smart"
-        ].next_server.assert_not_called()
-
-    def test_random_selection(self):
-        self.swarm_controller.selection_strategy = "random"
-        self.swarm_controller._select_server_device()
-
-        self.swarm_controller._next_server_schedulers[
-            "random"
-        ].next_server.assert_called()
-        self.swarm_controller._next_server_schedulers[
-            "sequential"
-        ].next_server.assert_not_called()
-        self.swarm_controller._next_server_schedulers[
-            "max_battery"
-        ].next_server.assert_not_called()
-        self.swarm_controller._next_server_schedulers[
-            "smart"
-        ].next_server.assert_not_called()
-
-    def test_smart_selection(self):
-        self.swarm_controller.selection_strategy = "smart"
-        self.swarm_controller._select_server_device()
-
-        self.swarm_controller._next_server_schedulers[
-            "smart"
-        ].next_server.assert_called()
-        self.swarm_controller._next_server_schedulers[
-            "sequential"
-        ].next_server.assert_not_called()
-        self.swarm_controller._next_server_schedulers[
-            "random"
-        ].next_server.assert_not_called()
-        self.swarm_controller._next_server_schedulers[
-            "max_battery"
-        ].next_server.assert_not_called()
+        self.swarm_controller._next_server_scheduler.next_server.assert_called()
-- 
GitLab