diff --git a/edml/config/controller/scheduler/smart.yaml b/edml/config/controller/scheduler/smart.yaml new file mode 100644 index 0000000000000000000000000000000000000000..04a05edbc4c75baa121cfecdd1e9a9b400b4b12e --- /dev/null +++ b/edml/config/controller/scheduler/smart.yaml @@ -0,0 +1,4 @@ +name: smart +_target_: edml.controllers.scheduler.smart.SmartNextServerScheduler +fallback_scheduler: + _target_: edml.controllers.scheduler.max_battery.MaxBatteryNextServerScheduler diff --git a/edml/controllers/scheduler/smart.py b/edml/controllers/scheduler/smart.py index 6c0fccc1f27d639fe7b1bfe414a227c498d32126..cf66bd15eda55e1c6387b09fd13c43c67b389557 100644 --- a/edml/controllers/scheduler/smart.py +++ b/edml/controllers/scheduler/smart.py @@ -1,7 +1,6 @@ from typing import Optional, Sequence from edml.controllers.scheduler.base import NextServerScheduler -from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler from edml.controllers.strategy_optimization import ( GlobalParams, ServerChoiceOptimizer, @@ -16,32 +15,34 @@ from edml.helpers.metrics import ( class SmartNextServerScheduler(NextServerScheduler): """ This scheduler optimizes the server device selection so that the number of rounds with all devices participating is maximized. - Therefore, training metrics are needed, so in the first round, the device with the maximal battery capacity is picked. + Therefore, training metrics are needed, so in the first round, the device is chosen by the specified fallback scheduler. Afterward, the optimal selection schedule is computed. If all devices have been picked according to the schedule, - i.e. one device will run out of battery in the next round, again the device with the maximal battery capacity is chosen. + i.e. one device will run out of battery in the next round, again the server device is chosen by the fallback. """ KEY: str = "smart" - def __init__(self, max_battery_scheduler: MaxBatteryNextServerScheduler, cfg): + def __init__(self, fallback_scheduler: NextServerScheduler): super().__init__() - self.max_battery_scheduler = max_battery_scheduler + self.fallback_scheduler = fallback_scheduler self.selection_schedule = None - self.cfg = cfg def _initialize(self, controller: "BaseController"): + self.cfg = controller.cfg self._update_batteries_cb = controller._get_battery_status self._data_model_cb = ( controller._get_active_devices_dataset_sizes_and_model_flops ) + self.fallback_scheduler.initialize(controller=controller) def _next_server( self, active_devices: Sequence[str], diagnostic_metric_container: Optional[DiagnosticMetricResultContainer] = None, + **kwargs, ) -> str: if diagnostic_metric_container is None: - return self.max_battery_scheduler.next_server(active_devices) + return self.fallback_scheduler.next_server(active_devices) else: if self.selection_schedule is None or len(self.selection_schedule) == 0: self.selection_schedule = self._get_selection_schedule( @@ -51,7 +52,11 @@ class SmartNextServerScheduler(NextServerScheduler): try: return self.selection_schedule.pop(0) except IndexError: # no more devices left in schedule - return self.max_battery_scheduler.next_server(active_devices) + return self.fallback_scheduler.next_server( + active_devices, + diagnostic_metric_container=diagnostic_metric_container, + kwargs=kwargs, + ) def _get_selection_schedule(self, diagnostic_metric_container): device_params_list = [] diff --git a/edml/controllers/swarm_controller.py b/edml/controllers/swarm_controller.py index 595edafadb47382f158866a654053eb4ea54e150..4463e2cc8ffd3a89471e8728ba0a117e163e8fea 100644 --- a/edml/controllers/swarm_controller.py +++ b/edml/controllers/swarm_controller.py @@ -2,31 +2,15 @@ from typing import Any from edml.controllers.base_controller import BaseController from edml.controllers.scheduler.base import NextServerScheduler -from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler -from edml.controllers.scheduler.random import RandomNextServerScheduler -from edml.controllers.scheduler.sequential import SequentialNextServerScheduler -from edml.controllers.scheduler.smart import SmartNextServerScheduler from edml.helpers.config_helpers import get_device_index_by_id class SwarmController(BaseController): - def __init__(self, cfg, selection_strategy: str = "sequential"): + def __init__(self, cfg, scheduler: NextServerScheduler): super().__init__(cfg) - self.selection_strategy = selection_strategy - - # Creates the default set of schedulers that swarm learning originally had and initializes them. - max_battery_scheduler = MaxBatteryNextServerScheduler() - self._next_server_schedulers: [str, NextServerScheduler] = { - "max_battery": max_battery_scheduler, - "random": RandomNextServerScheduler(), - "sequential": SequentialNextServerScheduler(), - "smart": SmartNextServerScheduler( - max_battery_scheduler=max_battery_scheduler, cfg=cfg - ), - } - for scheduler in self._next_server_schedulers.values(): - scheduler.initialize(self) + scheduler.initialize(self) + self._next_server_scheduler = scheduler def _train(self): client_weights = None @@ -127,7 +111,7 @@ class SwarmController(BaseController): """Returns the id of the server device for the given round.""" if len(self.active_devices) == 0: return None - return self._next_server_schedulers[self.selection_strategy].next_server( + return self._next_server_scheduler.next_server( self.active_devices, last_server_device_id=last_server_device_id, diagnostic_metric_container=diagnostic_metric_container, diff --git a/edml/tests/controllers/scheduler/smart_test.py b/edml/tests/controllers/scheduler/smart_test.py index 5522153bd629dec681190d9353e97776ea8d16be..bcc0a2e68a5e1a1694dd4399414081787905eb14 100644 --- a/edml/tests/controllers/scheduler/smart_test.py +++ b/edml/tests/controllers/scheduler/smart_test.py @@ -25,9 +25,9 @@ class SmartServerDeviceSelectionTest(unittest.TestCase): "comp_latency_factor": {"d0": 1, "d1": 1.01}, } self.scheduler = SmartNextServerScheduler( - max_battery_scheduler=Mock(spec=MaxBatteryNextServerScheduler), - cfg=load_sample_config(), + fallback_scheduler=Mock(spec=MaxBatteryNextServerScheduler), ) + self.scheduler.cfg = load_sample_config() self.scheduler._update_batteries_cb = lambda: { "d0": DeviceBatteryStatus.from_tuple((500, 500)), "d1": DeviceBatteryStatus.from_tuple((500, 470)), @@ -44,7 +44,7 @@ class SmartServerDeviceSelectionTest(unittest.TestCase): def test_select_server_device_smart_first_round(self): # should select according to max battery self.scheduler.next_server([""], diagnostic_metric_container=None) - self.scheduler.max_battery_scheduler.next_server.assert_called_once() + self.scheduler.fallback_scheduler.next_server.assert_called_once() def test_select_server_device_smart_second_round(self): # should select build a schedule and select the first element diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py index 56e9a5ee46f6c7f4b36cf2fcfcae1030fc0278cf..6c025f4b78a2a648b6055751cf5875b4b7224494 100644 --- a/edml/tests/controllers/swarm_controller_test.py +++ b/edml/tests/controllers/swarm_controller_test.py @@ -20,7 +20,9 @@ class SwarmControllerTest(unittest.TestCase): def setUp(self) -> None: self.cfg = load_sample_config() - self.swarm_controller = SwarmController(self.cfg) + self.swarm_controller = SwarmController( + self.cfg, SequentialNextServerScheduler() + ) self.mock = Mock(spec=DeviceRequestDispatcher) self.mock.active_devices.return_value = ["d0", "d1"] self.swarm_controller.request_dispatcher = self.mock @@ -79,17 +81,13 @@ class SwarmControllerTest(unittest.TestCase): class ServerDeviceSelectionTest(unittest.TestCase): def setUp(self) -> None: - self.swarm_controller = SwarmController(load_sample_config()) + self.swarm_controller = SwarmController( + load_sample_config(), Mock(SequentialNextServerScheduler) + ) self.swarm_controller.request_dispatcher = Mock(spec=DeviceRequestDispatcher) self.swarm_controller.devices = ListConfig( [{"device_id": "d0"}, {"device_id": "d1"}, {"device_id": "d2"}] ) # omitted address etc. - self.swarm_controller._next_server_schedulers = { - "max_battery": Mock(spec=MaxBatteryNextServerScheduler), - "random": Mock(spec=RandomNextServerScheduler), - "sequential": Mock(spec=SequentialNextServerScheduler), - "smart": Mock(spec=SmartNextServerScheduler), - } def test_select_no_server_device_if_no_active_devices(self): self.swarm_controller.request_dispatcher.active_devices.return_value = [] @@ -104,66 +102,4 @@ class ServerDeviceSelectionTest(unittest.TestCase): def test_sequential_selection_default(self): self.swarm_controller._select_server_device() - self.swarm_controller._next_server_schedulers[ - "sequential" - ].next_server.assert_called() - self.swarm_controller._next_server_schedulers[ - "max_battery" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "random" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "smart" - ].next_server.assert_not_called() - - def test_max_battery_selection(self): - self.swarm_controller.selection_strategy = "max_battery" - self.swarm_controller._select_server_device() - - self.swarm_controller._next_server_schedulers[ - "max_battery" - ].next_server.assert_called() - self.swarm_controller._next_server_schedulers[ - "sequential" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "random" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "smart" - ].next_server.assert_not_called() - - def test_random_selection(self): - self.swarm_controller.selection_strategy = "random" - self.swarm_controller._select_server_device() - - self.swarm_controller._next_server_schedulers[ - "random" - ].next_server.assert_called() - self.swarm_controller._next_server_schedulers[ - "sequential" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "max_battery" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "smart" - ].next_server.assert_not_called() - - def test_smart_selection(self): - self.swarm_controller.selection_strategy = "smart" - self.swarm_controller._select_server_device() - - self.swarm_controller._next_server_schedulers[ - "smart" - ].next_server.assert_called() - self.swarm_controller._next_server_schedulers[ - "sequential" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "random" - ].next_server.assert_not_called() - self.swarm_controller._next_server_schedulers[ - "max_battery" - ].next_server.assert_not_called() + self.swarm_controller._next_server_scheduler.next_server.assert_called()