Skip to content
Snippets Groups Projects
Commit 46e071a4 authored by Tim Tobias Bauerle
Browse files

Merge branch 'smart-scheduler-integration' into 'wip'

Smart scheduler integration

See merge request !13
parents b8bf39b7 d0fc72e8
Branches
No related tags found
2 merge requests: !18 "Merge in main", !13 "Smart scheduler integration"
name: smart
_target_: edml.controllers.scheduler.smart.SmartNextServerScheduler
fallback_scheduler:
_target_: edml.controllers.scheduler.max_battery.MaxBatteryNextServerScheduler
from typing import Optional, Sequence
from edml.controllers.scheduler.base import NextServerScheduler
from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler
from edml.controllers.strategy_optimization import (
GlobalParams,
ServerChoiceOptimizer,
......@@ -16,32 +15,34 @@ from edml.helpers.metrics import (
class SmartNextServerScheduler(NextServerScheduler):
"""
This scheduler optimizes the server device selection so that the number of rounds with all devices participating is maximized.
Therefore, training metrics are needed, so in the first round, the device with the maximal battery capacity is picked.
Therefore, training metrics are needed, so in the first round, the device is chosen by the specified fallback scheduler.
Afterward, the optimal selection schedule is computed. If all devices have been picked according to the schedule,
i.e. one device will run out of battery in the next round, again the device with the maximal battery capacity is chosen.
i.e. one device will run out of battery in the next round, again the server device is chosen by the fallback.
"""
KEY: str = "smart"
def __init__(self, max_battery_scheduler: MaxBatteryNextServerScheduler, cfg):
def __init__(self, fallback_scheduler: NextServerScheduler):
super().__init__()
self.max_battery_scheduler = max_battery_scheduler
self.fallback_scheduler = fallback_scheduler
self.selection_schedule = None
self.cfg = cfg
def _initialize(self, controller: "BaseController"):
    """Bind this scheduler to ``controller``.

    Copies the controller's configuration, keeps references to two
    controller callbacks (by their names: battery status, and dataset sizes /
    model FLOPs of the active devices), and finally initializes the fallback
    scheduler with the same controller.
    """
    self.cfg = controller.cfg

    # Only references are stored here; nothing is called during initialization.
    battery_status_cb = controller._get_battery_status
    self._update_batteries_cb = battery_status_cb
    dataset_and_flops_cb = controller._get_active_devices_dataset_sizes_and_model_flops
    self._data_model_cb = dataset_and_flops_cb

    # The wrapped fallback scheduler is initialized with the same controller.
    self.fallback_scheduler.initialize(controller=controller)
def _next_server(
self,
active_devices: Sequence[str],
diagnostic_metric_container: Optional[DiagnosticMetricResultContainer] = None,
**kwargs,
) -> str:
if diagnostic_metric_container is None:
return self.max_battery_scheduler.next_server(active_devices)
return self.fallback_scheduler.next_server(active_devices)
else:
if self.selection_schedule is None or len(self.selection_schedule) == 0:
self.selection_schedule = self._get_selection_schedule(
......@@ -51,7 +52,11 @@ class SmartNextServerScheduler(NextServerScheduler):
try:
return self.selection_schedule.pop(0)
except IndexError: # no more devices left in schedule
return self.max_battery_scheduler.next_server(active_devices)
return self.fallback_scheduler.next_server(
active_devices,
diagnostic_metric_container=diagnostic_metric_container,
kwargs=kwargs,
)
def _get_selection_schedule(self, diagnostic_metric_container):
device_params_list = []
......
......@@ -2,31 +2,15 @@ from typing import Any
from edml.controllers.base_controller import BaseController
from edml.controllers.scheduler.base import NextServerScheduler
from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler
from edml.controllers.scheduler.random import RandomNextServerScheduler
from edml.controllers.scheduler.sequential import SequentialNextServerScheduler
from edml.controllers.scheduler.smart import SmartNextServerScheduler
from edml.helpers.config_helpers import get_device_index_by_id
class SwarmController(BaseController):
def __init__(self, cfg, selection_strategy: str = "sequential"):
def __init__(self, cfg, scheduler: NextServerScheduler):
super().__init__(cfg)
self.selection_strategy = selection_strategy
# Creates the default set of schedulers that swarm learning originally had and initializes them.
max_battery_scheduler = MaxBatteryNextServerScheduler()
self._next_server_schedulers: [str, NextServerScheduler] = {
"max_battery": max_battery_scheduler,
"random": RandomNextServerScheduler(),
"sequential": SequentialNextServerScheduler(),
"smart": SmartNextServerScheduler(
max_battery_scheduler=max_battery_scheduler, cfg=cfg
),
}
for scheduler in self._next_server_schedulers.values():
scheduler.initialize(self)
scheduler.initialize(self)
self._next_server_scheduler = scheduler
def _train(self):
client_weights = None
......@@ -127,7 +111,7 @@ class SwarmController(BaseController):
"""Returns the id of the server device for the given round."""
if len(self.active_devices) == 0:
return None
return self._next_server_schedulers[self.selection_strategy].next_server(
return self._next_server_scheduler.next_server(
self.active_devices,
last_server_device_id=last_server_device_id,
diagnostic_metric_container=diagnostic_metric_container,
......
......@@ -25,9 +25,9 @@ class SmartServerDeviceSelectionTest(unittest.TestCase):
"comp_latency_factor": {"d0": 1, "d1": 1.01},
}
self.scheduler = SmartNextServerScheduler(
max_battery_scheduler=Mock(spec=MaxBatteryNextServerScheduler),
cfg=load_sample_config(),
fallback_scheduler=Mock(spec=MaxBatteryNextServerScheduler),
)
self.scheduler.cfg = load_sample_config()
self.scheduler._update_batteries_cb = lambda: {
"d0": DeviceBatteryStatus.from_tuple((500, 500)),
"d1": DeviceBatteryStatus.from_tuple((500, 470)),
......@@ -44,7 +44,7 @@ class SmartServerDeviceSelectionTest(unittest.TestCase):
def test_select_server_device_smart_first_round(self):
# should delegate to the fallback scheduler, since no metrics are available in the first round
self.scheduler.next_server([""], diagnostic_metric_container=None)
self.scheduler.max_battery_scheduler.next_server.assert_called_once()
self.scheduler.fallback_scheduler.next_server.assert_called_once()
def test_select_server_device_smart_second_round(self):
# should build a schedule and select the first element
......
......@@ -20,7 +20,9 @@ class SwarmControllerTest(unittest.TestCase):
def setUp(self) -> None:
self.cfg = load_sample_config()
self.swarm_controller = SwarmController(self.cfg)
self.swarm_controller = SwarmController(
self.cfg, SequentialNextServerScheduler()
)
self.mock = Mock(spec=DeviceRequestDispatcher)
self.mock.active_devices.return_value = ["d0", "d1"]
self.swarm_controller.request_dispatcher = self.mock
......@@ -79,17 +81,13 @@ class SwarmControllerTest(unittest.TestCase):
class ServerDeviceSelectionTest(unittest.TestCase):
def setUp(self) -> None:
self.swarm_controller = SwarmController(load_sample_config())
self.swarm_controller = SwarmController(
load_sample_config(), Mock(SequentialNextServerScheduler)
)
self.swarm_controller.request_dispatcher = Mock(spec=DeviceRequestDispatcher)
self.swarm_controller.devices = ListConfig(
[{"device_id": "d0"}, {"device_id": "d1"}, {"device_id": "d2"}]
) # omitted address etc.
self.swarm_controller._next_server_schedulers = {
"max_battery": Mock(spec=MaxBatteryNextServerScheduler),
"random": Mock(spec=RandomNextServerScheduler),
"sequential": Mock(spec=SequentialNextServerScheduler),
"smart": Mock(spec=SmartNextServerScheduler),
}
def test_select_no_server_device_if_no_active_devices(self):
self.swarm_controller.request_dispatcher.active_devices.return_value = []
......@@ -104,66 +102,4 @@ class ServerDeviceSelectionTest(unittest.TestCase):
def test_sequential_selection_default(self):
self.swarm_controller._select_server_device()
self.swarm_controller._next_server_schedulers[
"sequential"
].next_server.assert_called()
self.swarm_controller._next_server_schedulers[
"max_battery"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"random"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"smart"
].next_server.assert_not_called()
def test_max_battery_selection(self):
self.swarm_controller.selection_strategy = "max_battery"
self.swarm_controller._select_server_device()
self.swarm_controller._next_server_schedulers[
"max_battery"
].next_server.assert_called()
self.swarm_controller._next_server_schedulers[
"sequential"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"random"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"smart"
].next_server.assert_not_called()
def test_random_selection(self):
self.swarm_controller.selection_strategy = "random"
self.swarm_controller._select_server_device()
self.swarm_controller._next_server_schedulers[
"random"
].next_server.assert_called()
self.swarm_controller._next_server_schedulers[
"sequential"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"max_battery"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"smart"
].next_server.assert_not_called()
def test_smart_selection(self):
self.swarm_controller.selection_strategy = "smart"
self.swarm_controller._select_server_device()
self.swarm_controller._next_server_schedulers[
"smart"
].next_server.assert_called()
self.swarm_controller._next_server_schedulers[
"sequential"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"random"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"max_battery"
].next_server.assert_not_called()
self.swarm_controller._next_server_scheduler.next_server.assert_called()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment