From d0fc72e8f40e34ee020728dc5e0798b606e1a093 Mon Sep 17 00:00:00 2001 From: Tim Tobias Bauerle <tim.bauerle@rwth-aachen.de> Date: Wed, 3 Jul 2024 13:29:41 +0200 Subject: [PATCH] Smart scheduler integration --- edml/config/controller/scheduler/smart.yaml | 4 + edml/controllers/scheduler/smart.py | 21 +++-- edml/controllers/swarm_controller.py | 24 +----- .../tests/controllers/scheduler/smart_test.py | 6 +- .../controllers/swarm_controller_test.py | 78 ++----------------- 5 files changed, 31 insertions(+), 102 deletions(-) create mode 100644 edml/config/controller/scheduler/smart.yaml diff --git a/edml/config/controller/scheduler/smart.yaml b/edml/config/controller/scheduler/smart.yaml new file mode 100644 index 0000000..04a05ed --- /dev/null +++ b/edml/config/controller/scheduler/smart.yaml @@ -0,0 +1,4 @@ +name: smart +_target_: edml.controllers.scheduler.smart.SmartNextServerScheduler +fallback_scheduler: + _target_: edml.controllers.scheduler.max_battery.MaxBatteryNextServerScheduler diff --git a/edml/controllers/scheduler/smart.py b/edml/controllers/scheduler/smart.py index 6c0fccc..cf66bd1 100644 --- a/edml/controllers/scheduler/smart.py +++ b/edml/controllers/scheduler/smart.py @@ -1,7 +1,6 @@ from typing import Optional, Sequence from edml.controllers.scheduler.base import NextServerScheduler -from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler from edml.controllers.strategy_optimization import ( GlobalParams, ServerChoiceOptimizer, @@ -16,32 +15,34 @@ from edml.helpers.metrics import ( class SmartNextServerScheduler(NextServerScheduler): """ This scheduler optimizes the server device selection so that the number of rounds with all devices participating is maximized. - Therefore, training metrics are needed, so in the first round, the device with the maximal battery capacity is picked. + Therefore, training metrics are needed, so in the first round, the device is chosen by the specified fallback scheduler. Afterward, the optimal selection schedule is computed. If all devices have been picked according to the schedule, - i.e. one device will run out of battery in the next round, again the device with the maximal battery capacity is chosen. + i.e. one device will run out of battery in the next round, again the server device is chosen by the fallback. """ KEY: str = "smart" - def __init__(self, max_battery_scheduler: MaxBatteryNextServerScheduler, cfg): + def __init__(self, fallback_scheduler: NextServerScheduler): super().__init__() - self.max_battery_scheduler = max_battery_scheduler + self.fallback_scheduler = fallback_scheduler self.selection_schedule = None - self.cfg = cfg def _initialize(self, controller: "BaseController"): + self.cfg = controller.cfg self._update_batteries_cb = controller._get_battery_status self._data_model_cb = ( controller._get_active_devices_dataset_sizes_and_model_flops ) + self.fallback_scheduler.initialize(controller=controller) def _next_server( self, active_devices: Sequence[str], diagnostic_metric_container: Optional[DiagnosticMetricResultContainer] = None, + **kwargs, ) -> str: if diagnostic_metric_container is None: - return self.max_battery_scheduler.next_server(active_devices) + return self.fallback_scheduler.next_server(active_devices) else: if self.selection_schedule is None or len(self.selection_schedule) == 0: self.selection_schedule = self._get_selection_schedule( @@ -51,7 +52,11 @@ class SmartNextServerScheduler(NextServerScheduler): try: return self.selection_schedule.pop(0) except IndexError: # no more devices left in schedule - return self.max_battery_scheduler.next_server(active_devices) + return self.fallback_scheduler.next_server( + active_devices, + diagnostic_metric_container=diagnostic_metric_container, + kwargs=kwargs, + ) def _get_selection_schedule(self, diagnostic_metric_container): device_params_list = [] diff --git a/edml/controllers/swarm_controller.py b/edml/controllers/swarm_controller.py index 595edaf..4463e2c 100644 --- a/edml/controllers/swarm_controller.py +++ b/edml/controllers/swarm_controller.py @@ -2,31 +2,15 @@ from typing import Any from edml.controllers.base_controller import BaseController from edml.controllers.scheduler.base import NextServerScheduler -from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler -from edml.controllers.scheduler.random import RandomNextServerScheduler -from edml.controllers.scheduler.sequential import SequentialNextServerScheduler -from edml.controllers.scheduler.smart import SmartNextServerScheduler from edml.helpers.config_helpers import get_device_index_by_id class SwarmController(BaseController): - def __init__(self, cfg, selection_strategy: str = "sequential"): + def __init__(self, cfg, scheduler: NextServerScheduler): super().__init__(cfg) - self.selection_strategy = selection_strategy - - # Creates the default set of schedulers that swarm learning originally had and initializes them. - max_battery_scheduler = MaxBatteryNextServerScheduler() - self._next_server_schedulers: [str, NextServerScheduler] = { - "max_battery": max_battery_scheduler, - "random": RandomNextServerScheduler(), - "sequential": SequentialNextServerScheduler(), - "smart": SmartNextServerScheduler( - max_battery_scheduler=max_battery_scheduler, cfg=cfg - ), - } - for scheduler in self._next_server_schedulers.values(): - scheduler.initialize(self) + scheduler.initialize(self) + self._next_server_scheduler = scheduler def _train(self): client_weights = None @@ -127,7 +111,7 @@ class SwarmController(BaseController): """Returns the id of the server device for the given round.""" if len(self.active_devices) == 0: return None - return self._next_server_schedulers[self.selection_strategy].next_server( + return self._next_server_scheduler.next_server( self.active_devices, last_server_device_id=last_server_device_id, diagnostic_metric_container=diagnostic_metric_container, diff --git a/edml/tests/controllers/scheduler/smart_test.py b/edml/tests/controllers/scheduler/smart_test.py index 5522153..bcc0a2e 100644 --- a/edml/tests/controllers/scheduler/smart_test.py +++ b/edml/tests/controllers/scheduler/smart_test.py @@ -25,9 +25,9 @@ class SmartServerDeviceSelectionTest(unittest.TestCase): "comp_latency_factor": {"d0": 1, "d1": 1.01}, } self.scheduler = SmartNextServerScheduler( - max_battery_scheduler=Mock(spec=MaxBatteryNextServerScheduler), - cfg=load_sample_config(), + fallback_scheduler=Mock(spec=MaxBatteryNextServerScheduler), ) + self.scheduler.cfg = load_sample_config() self.scheduler._update_batteries_cb = lambda: { "d0": DeviceBatteryStatus.from_tuple((500, 500)), "d1": DeviceBatteryStatus.from_tuple((500, 470)), @@ -44,7 +44,7 @@ class SmartServerDeviceSelectionTest(unittest.TestCase): def test_select_server_device_smart_first_round(self): # should select according to max battery self.scheduler.next_server([""], diagnostic_metric_container=None) - self.scheduler.max_battery_scheduler.next_server.assert_called_once() + self.scheduler.fallback_scheduler.next_server.assert_called_once() def test_select_server_device_smart_second_round(self): # should select build a schedule and select the first element diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py index 56e9a5e..6c025f4 100644 --- a/edml/tests/controllers/swarm_controller_test.py +++ b/edml/tests/controllers/swarm_controller_test.py @@ -20,7 +20,9 @@ class SwarmControllerTest(unittest.TestCase): def setUp(self) -> None: self.cfg = load_sample_config() - self.swarm_controller = SwarmController(self.cfg) + self.swarm_controller = SwarmController( + self.cfg, SequentialNextServerScheduler() + ) self.mock = Mock(spec=DeviceRequestDispatcher) self.mock.active_devices.return_value = ["d0", "d1"] self.swarm_controller.request_dispatcher = self.mock @@ -79,17 +81,13 @@ class SwarmControllerTest(unittest.TestCase): class ServerDeviceSelectionTest(unittest.TestCase): def setUp(self) -> None: - self.swarm_controller = SwarmController(load_sample_config()) + self.swarm_controller = SwarmController( + load_sample_config(), Mock(SequentialNextServerScheduler) + ) self.swarm_controller.request_dispatcher = Mock(spec=DeviceRequestDispatcher) self.swarm_controller.devices = ListConfig( [{"device_id": "d0"}, {"device_id": "d1"}, {"device_id": "d2"}] ) # omitted address etc. - self.swarm_controller._next_server_schedulers = { - "max_battery": Mock(spec=MaxBatteryNextServerScheduler), - "random": Mock(spec=RandomNextServerScheduler), - "sequential": Mock(spec=SequentialNextServerScheduler), - "smart": Mock(spec=SmartNextServerScheduler), - } def test_select_no_server_device_if_no_active_devices(self): self.swarm_controller.request_dispatcher.active_devices.return_value = [] @@ -104,66 +102,4 @@ class ServerDeviceSelectionTest(unittest.TestCase): def test_sequential_selection_default(self): self.swarm_controller._select_server_device() - self.swarm_controller._next_server_schedulers[ - "sequential" - ].next_server.assert_called() - self.swarm_controller._next_server_schedulers[ - "max_battery" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "random" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "smart" - ].next_server.assert_not_called() - - def test_max_battery_selection(self): - self.swarm_controller.selection_strategy = "max_battery" - self.swarm_controller._select_server_device() - - self.swarm_controller._next_server_schedulers[ - "max_battery" - ].next_server.assert_called() - self.swarm_controller._next_server_schedulers[ - "sequential" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "random" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "smart" - ].next_server.assert_not_called() - - def test_random_selection(self): - self.swarm_controller.selection_strategy = "random" - self.swarm_controller._select_server_device() - - self.swarm_controller._next_server_schedulers[ - "random" - ].next_server.assert_called() - self.swarm_controller._next_server_schedulers[ - "sequential" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "max_battery" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "smart" - ].next_server.assert_not_called() - - def test_smart_selection(self): - self.swarm_controller.selection_strategy = "smart" - self.swarm_controller._select_server_device() - - self.swarm_controller._next_server_schedulers[ - "smart" - ].next_server.assert_called() - self.swarm_controller._next_server_schedulers[ - "sequential" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "random" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "max_battery" - ].next_server.assert_not_called() + self.swarm_controller._next_server_scheduler.next_server.assert_called() -- GitLab