Skip to content
Snippets Groups Projects
Commit 7dcf3631 authored by Tim Tobias Bauerle's avatar Tim Tobias Bauerle
Browse files

Merge branch 'wip' into 'feat/adaptive-threshold-mechanism'

# Conflicts:
#   edml/helpers/config_helpers.py
#   edml/tests/helpers/config_helpers_test.py
parents 042a75eb 5fede3b6
No related branches found
No related tags found
2 merge requests!18Merge in main,!11chore: experiment files for adaptive threshold mechanism
Showing with 469 additions and 322 deletions
......@@ -34,13 +34,16 @@ class NextServerScheduler(ABC):
def _initialize(self, controller: "BaseController"):
"""Custom hook for implementations to initialize themselves."""
def next_server(self, active_devices: Sequence[str]) -> Optional[str]:
def next_server(self, active_devices: Sequence[str], **kwargs) -> Optional[str]:
"""
Returns the next active server device. This method only calls the `_next_server` method after verifying the
list of active devices is not empty.
For using multiple schedulers interchangeably, pass all required args for all schedulers as kwargs.
Args:
active_devices: The list of currently active device IDs. This list should not be empty.
last_server_device_id: Optional kwarg, the id of the last server device.
diagnostic_metric_container: Optional kwarg, a DiagnosticMetricResultContainer.
Returns:
The next server device ID. In the case that no device is available, `None` is returned.
......@@ -50,9 +53,15 @@ class NextServerScheduler(ABC):
"""
if len(active_devices) == 0:
raise ValueError("list cannot be empty")
return self._next_server(active_devices)
return self._next_server(active_devices, **kwargs)
@abstractmethod
def _next_server(self, active_devices: Sequence[str]) -> Optional[str]:
"""Custom hook to return the next active device."""
def _next_server(self, active_devices: Sequence[str], **kwargs) -> Optional[str]:
"""
Custom hook to return the next active device. Concrete implementations that need more arguments,
should define these as additional keyword args, so the base signature does not have to change.
However, when relying on polymorphism, the next_server template should receive all keyword args for the
different schedulers. Also, the concrete implementations should define **kwargs in the signature,
so it does not have to change when new kwargs are added.
"""
raise NotImplementedError()
......@@ -24,7 +24,7 @@ class MaxBatteryNextServerScheduler(NextServerScheduler):
def _initialize(self, controller: BaseController):
self._update_batteries_cb = controller._get_battery_status
def _next_server(self, _: Sequence[str]) -> Optional[str]:
def _next_server(self, _: Sequence[str], **kwargs) -> Optional[str]:
battery_status = self._get_active_devices_battery_levels()
if len(battery_status) == 0:
return None
......
......@@ -12,5 +12,5 @@ class RandomNextServerScheduler(NextServerScheduler):
KEY: str = "random"
def _next_server(self, active_devices: Sequence[str]) -> str:
def _next_server(self, active_devices: Sequence[str], **kwargs) -> str:
return choice(active_devices)
......@@ -9,17 +9,22 @@ class SequentialNextServerScheduler(NextServerScheduler):
devices.
Args:
first_server_device_id (optional): The device ID of the first server to use. If not provided, the first server
will be the first active device in the list.
last_server_device_id (optional): The device ID of the last server device used. If not provided, the first
server device will be the first active device in the list.
"""
KEY: str = "sequential"
def __init__(self, first_server_device_id: Optional[str] = None):
def __init__(self, last_server_device_id: Optional[str] = None):
super().__init__()
self._last_server_device_id = first_server_device_id
def _next_server(self, active_devices: Sequence[str]) -> str:
self._last_server_device_id = last_server_device_id
def _next_server(
self,
active_devices: Sequence[str],
last_server_device_id: Optional[str] = None,
**kwargs
) -> str:
# Special case if we do not have an initial first server. We simply return the first server ID in the list.
if self._last_server_device_id is None:
next_server = active_devices[0]
......
from typing import Optional, Sequence
from edml.controllers.scheduler.base import NextServerScheduler
from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler
from edml.controllers.strategy_optimization import (
GlobalParams,
ServerChoiceOptimizer,
DeviceParams,
)
from edml.helpers.metrics import (
DiagnosticMetricResultContainer,
compute_metrics_for_optimization,
)
class SmartNextServerScheduler(NextServerScheduler):
    """
    This scheduler optimizes the server device selection so that the number of rounds with all devices
    participating is maximized.

    Therefore, training metrics are needed, so in the first round, the device with the maximal battery capacity
    is picked. Afterward, the optimal selection schedule is computed. If all devices have been picked according
    to the schedule, i.e. one device will run out of battery in the next round, again the device with the maximal
    battery capacity is chosen.
    """

    KEY: str = "smart"

    def __init__(self, max_battery_scheduler: MaxBatteryNextServerScheduler, cfg):
        """
        Args:
            max_battery_scheduler: Scheduler used as a fallback whenever no optimized schedule is available.
            cfg: The experiment configuration; `cfg.experiment.batch_size` is read when computing metrics and the
                whole config is handed to `GlobalParams.fill_values_from_config`.
        """
        super().__init__()
        self.max_battery_scheduler = max_battery_scheduler
        # Remaining server device IDs of the current optimized schedule; `None` until first computed.
        self.selection_schedule = None
        self.cfg = cfg

    def _initialize(self, controller: "BaseController"):
        """Stores the controller callbacks used to query battery levels, dataset sizes and model flops."""
        self._update_batteries_cb = controller._get_battery_status
        self._data_model_cb = (
            controller._get_active_devices_dataset_sizes_and_model_flops
        )

    def _next_server(
        self,
        active_devices: Sequence[str],
        diagnostic_metric_container: Optional[DiagnosticMetricResultContainer] = None,
        **kwargs,
    ) -> Optional[str]:
        """
        Returns the next server device according to the optimized schedule.

        Args:
            active_devices: The list of currently active device IDs.
            diagnostic_metric_container: The diagnostic metrics of the previous round. If `None`, no schedule can
                be computed and the max battery scheduler decides.
            **kwargs: Keyword args intended for other schedulers (e.g. `last_server_device_id`). They are accepted
                and ignored so that all schedulers can be called interchangeably through `next_server`.

        Returns:
            The next server device ID, or whatever the max battery scheduler returns when it is used as fallback.
        """
        if diagnostic_metric_container is None:
            return self.max_battery_scheduler.next_server(active_devices)

        # (Re-)compute the schedule once it has never been built or has been used up.
        if not self.selection_schedule:
            self.selection_schedule = self._get_selection_schedule(
                diagnostic_metric_container
            )
            print(f"server device schedule: {self.selection_schedule}")
        try:
            return self.selection_schedule.pop(0)
        except IndexError:  # no more devices left in schedule
            return self.max_battery_scheduler.next_server(active_devices)

    def _get_selection_schedule(self, diagnostic_metric_container):
        """
        Computes a new server selection schedule using the `ServerChoiceOptimizer`.

        Args:
            diagnostic_metric_container: The diagnostic metrics used to derive the optimization metrics.

        Returns:
            A list of device IDs in which every device appears as often as the optimizer's solution states.
            An empty list is returned if the optimization metrics cannot be computed.
        """
        device_battery_levels = self._update_batteries_cb()
        # get num samples and flops per device
        dataset_sizes, model_flops = self._data_model_cb()
        try:
            optimization_metrics = compute_metrics_for_optimization(
                diagnostic_metric_container,
                dataset_sizes,
                self.cfg.experiment.batch_size,
            )
        except (
            KeyError
        ):  # if some metrics are not available e.g. because a device ran out of battery
            return []  # return empty schedule

        # NOTE(review): every battery entry is assumed to be non-None here - confirm against the callback.
        device_params_list = [
            DeviceParams(
                device_id=device_id,
                initial_battery=battery_level.initial_capacity,
                current_battery=battery_level.current_capacity,
                train_samples=dataset_sizes[device_id][0],
                validation_samples=dataset_sizes[device_id][1],
                comp_latency_factor=optimization_metrics["comp_latency_factor"][
                    device_id
                ],
            )
            for device_id, battery_level in device_battery_levels.items()
        ]

        global_params = GlobalParams()
        global_params.fill_values_from_config(self.cfg)
        global_params.client_model_flops = model_flops["client"]
        global_params.server_model_flops = model_flops["server"]
        global_params.client_norm_fw_time = optimization_metrics["client_norm_fw_time"]
        global_params.client_norm_bw_time = optimization_metrics["client_norm_bw_time"]
        global_params.server_norm_fw_time = optimization_metrics["server_norm_fw_time"]
        global_params.server_norm_bw_time = optimization_metrics["server_norm_bw_time"]
        global_params.gradient_size = optimization_metrics["gradient_size"]
        global_params.label_size = optimization_metrics["label_size"]
        global_params.smashed_data_size = optimization_metrics["smashed_data_size"]
        global_params.client_weights_size = optimization_metrics["client_weight_size"]
        global_params.server_weights_size = optimization_metrics["server_weight_size"]
        print(f"global params: {vars(global_params)}")
        print(f"device params: {[vars(device) for device in device_params_list]}")

        server_choice_optimizer = ServerChoiceOptimizer(
            device_params_list, global_params
        )
        solution, _status = server_choice_optimizer.optimize()

        # Expand the per-device counts of the solution into a flat, ordered schedule.
        schedule = []
        for device_id, rounds_as_server in solution.items():
            schedule.extend([device_id] * int(rounds_as_server))
        return schedule
import random
from edml.controllers.base_controller import BaseController
from edml.controllers.scheduler.base import NextServerScheduler
from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler
from edml.controllers.scheduler.random import RandomNextServerScheduler
from edml.controllers.scheduler.sequential import SequentialNextServerScheduler
from edml.controllers.strategy_optimization import (
DeviceParams,
GlobalParams,
ServerChoiceOptimizer,
)
from edml.controllers.scheduler.smart import SmartNextServerScheduler
from edml.helpers.config_helpers import get_device_index_by_id
from edml.helpers.metrics import compute_metrics_for_optimization
class SwarmController(BaseController):
......@@ -19,13 +12,16 @@ class SwarmController(BaseController):
def __init__(self, cfg, selection_strategy: str = "sequential"):
super().__init__(cfg)
self.selection_strategy = selection_strategy
self.selection_schedule = None
# Creates the default set of schedulers that swarm learning originally had and initializes them.
max_battery_scheduler = MaxBatteryNextServerScheduler()
self._next_server_schedulers: [str, NextServerScheduler] = {
"max_battery": MaxBatteryNextServerScheduler(),
"max_battery": max_battery_scheduler,
"random": RandomNextServerScheduler(),
"sequential": SequentialNextServerScheduler(),
"smart": SmartNextServerScheduler(
max_battery_scheduler=max_battery_scheduler, cfg=cfg
),
}
for scheduler in self._next_server_schedulers.values():
scheduler.initialize(self)
......@@ -111,41 +107,11 @@ class SwarmController(BaseController):
"""Returns the id of the server device for the given round."""
if len(self.active_devices) == 0:
return None
if self.selection_strategy == "random":
return self._next_server_schedulers[self.selection_strategy].next_server(
self.active_devices
)
elif self.selection_strategy == "max_battery":
return self._next_server_schedulers[self.selection_strategy].next_server(
self.active_devices
)
elif self.selection_strategy == "smart":
return self._select_server_device_smart(
last_server_device_id, diagnostic_metric_container
)
else: # sequential selection as default
return self._next_server_schedulers[self.selection_strategy].next_server(
self.active_devices
)
def _select_server_device_random(self):
return self.active_devices[random.randint(0, len(self.active_devices) - 1)]
def _select_server_device_max_battery(self):
"""Selects the device with the highest battery level. Returns None if there are no active devices."""
battery_status = self._get_active_devices_battery_levels()
if len(battery_status) == 0:
return None
# return the device_id with the highest battery level
return max(battery_status, key=lambda device_id: battery_status[device_id])
def _get_active_devices_battery_levels(self):
"""Returns the current battery levels of active devices only."""
return {
key: battery_info[1]
for key, battery_info in self._get_battery_status().items()
if battery_info is not None
}
return self._next_server_schedulers[self.selection_strategy].next_server(
self.active_devices,
last_server_device_id=last_server_device_id,
diagnostic_metric_container=diagnostic_metric_container,
)
def _get_active_devices_dataset_sizes_and_model_flops(self):
"""Returns the dataset sizes and model flops of active devices only."""
......@@ -165,86 +131,3 @@ class SwarmController(BaseController):
model_flops["client"] = max(client_flop_list)
model_flops["server"] = max(server_flop_list)
return dataset_sizes, model_flops
def _select_server_device_sequentially(self, last_server_device_id):
if last_server_device_id is None:
return self.active_devices[0]
device_list = [device.device_id for device in self.devices]
last_server_index = device_list.index(last_server_device_id)
next_server_ordering = (
device_list[last_server_index + 1 :] + device_list[: last_server_index + 1]
)
return next(
server for server in next_server_ordering if server in self.active_devices
)
def _select_server_device_smart(
self, last_server_device_id, diagnostic_metric_container
):
if (
last_server_device_id is None or diagnostic_metric_container is None
): # choose device with most battery in first round
return self._select_server_device_max_battery()
else:
if self.selection_schedule is None or len(self.selection_schedule) == 0:
self.selection_schedule = self._get_selection_schedule(
diagnostic_metric_container
)
self.logger.log(f"server device schedule: {self.selection_schedule}")
try:
return self.selection_schedule.pop(0)
except IndexError: # no more devices left in schedule
return self._select_server_device_max_battery()
def _get_selection_schedule(self, diagnostic_metric_container):
device_params_list = []
device_battery_levels = self._get_battery_status()
# get num samples and flops per device
dataset_sizes, model_flops = (
self._get_active_devices_dataset_sizes_and_model_flops()
)
try:
optimization_metrics = compute_metrics_for_optimization(
diagnostic_metric_container,
dataset_sizes,
self.cfg.experiment.batch_size,
)
except (
KeyError
): # if some metrics are not available e.g. because a device ran out of battery
return [] # return empty schedule
for device_id, battery_level in device_battery_levels.items():
device_params = DeviceParams(
device_id=device_id,
initial_battery=battery_level[0],
current_battery=battery_level[1],
train_samples=dataset_sizes[device_id][0],
validation_samples=dataset_sizes[device_id][1],
comp_latency_factor=optimization_metrics["comp_latency_factor"][
device_id
],
)
device_params_list.append(device_params)
global_params = GlobalParams()
global_params.fill_values_from_config(self.cfg)
global_params.client_model_flops = model_flops["client"]
global_params.server_model_flops = model_flops["server"]
global_params.client_norm_fw_time = optimization_metrics["client_norm_fw_time"]
global_params.client_norm_bw_time = optimization_metrics["client_norm_bw_time"]
global_params.server_norm_fw_time = optimization_metrics["server_norm_fw_time"]
global_params.server_norm_bw_time = optimization_metrics["server_norm_bw_time"]
global_params.gradient_size = optimization_metrics["gradient_size"]
global_params.label_size = optimization_metrics["label_size"]
global_params.smashed_data_size = optimization_metrics["smashed_data_size"]
global_params.client_weights_size = optimization_metrics["client_weight_size"]
global_params.server_weights_size = optimization_metrics["server_weight_size"]
print(f"global params: {vars(global_params)}")
print(f"device params: {[vars(device) for device in device_params_list]}")
server_choice_optimizer = ServerChoiceOptimizer(
device_params_list, global_params
)
solution, status = server_choice_optimizer.optimize()
schedule = []
for device_id in solution.keys():
schedule += [device_id] * int(solution[device_id])
return schedule
from copy import deepcopy
from inspect import signature
import torch
from hydra.utils import get_class, instantiate
from omegaconf import OmegaConf, DictConfig, ListConfig
from omegaconf.errors import ConfigAttributeError
from edml.controllers.base_controller import BaseController
......@@ -206,3 +208,36 @@ def instantiate_controller(cfg: DictConfig) -> BaseController:
# Instantiate the controller.
controller: BaseController = instantiate(cfg.controller)(cfg=original_cfg)
return controller
def get_torch_device_id(cfg: DictConfig) -> str:
    """
    Looks up the torch device configured for the device this process runs as.

    The topology entry whose `device_id` equals `cfg.own_device_id` is searched and its `torch_device` is
    returned. If the lookup fails with a `ConfigAttributeError`, the default torch device is used instead.

    Args:
        cfg (DictConfig): The config loaded from YAML files.

    Returns:
        The id of the configured torch_device for the current device.

    Raises:
        StopIteration: If the device with the given ID cannot be found.
        ConfigAttributeError: If no device id is present in the config.
    """
    # Read outside the `try` on purpose: a missing own device id must not fall back to the default.
    own_device_id = cfg.own_device_id
    try:
        matching_torch_devices = (
            entry.torch_device
            for entry in cfg.topology.devices
            if entry.device_id == own_device_id
        )
        return next(matching_torch_devices)
    except ConfigAttributeError:
        return _default_torch_device()
def _default_torch_device():
    """
    Picks the fallback torch device: "cuda:0" when cuda is available, "cpu" otherwise.
    """
    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"
import unittest
from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler
from edml.helpers.types import DeviceBatteryStatus
class MaxBatteryServerDeviceSelectionTest(unittest.TestCase):
    """Checks which device `MaxBatteryNextServerScheduler` picks for different battery reports."""

    def setUp(self) -> None:
        self.active_devices = ["d0", "d1", "d2"]
        self.scheduler = MaxBatteryNextServerScheduler()

    def _report_batteries(self, raw_levels):
        """Makes the scheduler's battery callback report `raw_levels` (tuple or None per device)."""
        self.scheduler._update_batteries_cb = lambda: {
            device_id: (
                None if raw is None else DeviceBatteryStatus.from_tuple(raw)
            )
            for device_id, raw in raw_levels.items()
        }

    def test_select_first_device_for_equal_batteries(self):
        self._report_batteries({"d0": (100, 50), "d1": (100, 50), "d2": (100, 50)})

        chosen = self.scheduler.next_server(self.active_devices)

        self.assertEqual(chosen, "d0")

    def test_select_device_with_max_battery(self):
        self._report_batteries({"d0": None, "d1": (100, 50), "d2": (100, 100)})

        chosen = self.scheduler.next_server(self.active_devices)

        self.assertEqual(chosen, "d2")

    def test_no_server_for_no_active_device(self):
        # No device reports a battery status at all.
        self._report_batteries({"d0": None, "d1": None, "d2": None})

        chosen = self.scheduler.next_server(self.active_devices)

        self.assertEqual(chosen, None)
import unittest
from edml.controllers.scheduler.sequential import SequentialNextServerScheduler
class SequentialServerDeviceSelectionTest(unittest.TestCase):
    """Checks the order in which `SequentialNextServerScheduler` hands out server devices."""

    def setUp(self) -> None:
        self.active_devices = ["d0", "d1", "d2"]
        self.scheduler = SequentialNextServerScheduler()
        # The scheduler keeps the full device list even when a test narrows `self.active_devices` later on.
        self.scheduler.devices = self.active_devices

    def _next_after(self, last_server_device_id):
        """Asks the scheduler for the successor of `last_server_device_id` among the active devices."""
        return self.scheduler.next_server(
            self.active_devices, last_server_device_id=last_server_device_id
        )

    def test_select_server_device_for_only_active_devices(self):
        server_device = self.scheduler.next_server(self.active_devices)
        self.assertEqual(server_device, "d0")
        print("=1=")
        # Every further call continues from the previous result and wraps around to "d0" in the end.
        for step, expected in enumerate(("d1", "d2", "d0"), start=2):
            server_device = self._next_after(server_device)
            self.assertEqual(server_device, expected)
            print(f"={step}=")

    def test_select_server_device_with_last_server_device_inactive(self):
        self.active_devices = ["d1", "d2"]
        # "d0" was the last server but is not active anymore.
        server_device = "d0"
        for expected in ("d1", "d2", "d1"):
            server_device = self._next_after(server_device)
            self.assertEqual(server_device, expected)

    def test_select_same_server_device_if_all_other_devices_inactive(self):
        self.active_devices = ["d1"]
        server_device = self.scheduler.next_server(self.active_devices)
        self.assertEqual(server_device, "d1")
        server_device = self._next_after(server_device)
        self.assertEqual(server_device, "d1")
import unittest
from unittest.mock import Mock, patch
from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler
from edml.controllers.scheduler.smart import SmartNextServerScheduler
from edml.helpers.metrics import DiagnosticMetricResultContainer
from edml.helpers.types import DeviceBatteryStatus
from edml.tests.controllers.test_helper import load_sample_config
class SmartServerDeviceSelectionTest(unittest.TestCase):
    """Tests for `SmartNextServerScheduler` with stubbed callbacks and patched optimization metrics."""

    def setUp(self):
        self.diagnostic_metric_container = Mock(spec=DiagnosticMetricResultContainer)

        # Canned return value for the patched `compute_metrics_for_optimization`.
        size_metrics = {
            "gradient_size": 100000,
            "label_size": 2000,
            "smashed_data_size": 100000,
            "client_weight_size": 300000,
            "server_weight_size": 300000,
        }
        time_metrics = {
            "client_norm_fw_time": 3,
            "client_norm_bw_time": 3,
            "server_norm_fw_time": 3,
            "server_norm_bw_time": 3,
        }
        self.metrics = {
            **size_metrics,
            **time_metrics,
            "comp_latency_factor": {"d0": 1, "d1": 1.01},
        }

        fallback_scheduler = Mock(spec=MaxBatteryNextServerScheduler)
        self.scheduler = SmartNextServerScheduler(
            max_battery_scheduler=fallback_scheduler,
            cfg=load_sample_config(),
        )

        # Replace the controller callbacks with fixed answers.
        self.scheduler._update_batteries_cb = lambda: {
            "d0": DeviceBatteryStatus.from_tuple((500, 500)),
            "d1": DeviceBatteryStatus.from_tuple((500, 470)),
        }
        self.scheduler._data_model_cb = lambda: [
            {"d0": (2, 1), "d1": (4, 1)},
            {"client": 1000000, "server": 1000000},
        ]

    def _patched_metrics(self):
        """Returns a patcher that makes the metrics computation yield `self.metrics`."""
        return patch(
            "edml.controllers.scheduler.smart.compute_metrics_for_optimization",
            return_value=self.metrics,
        )

    def test_select_server_device_smart_first_round(self):
        # should select according to max battery
        self.scheduler.next_server([""], diagnostic_metric_container=None)

        fallback = self.scheduler.max_battery_scheduler
        fallback.next_server.assert_called_once()

    def test_select_server_device_smart_second_round(self):
        # should select build a schedule and select the first element
        with self._patched_metrics():
            chosen = self.scheduler.next_server(
                ["d0", "d1"],
                diagnostic_metric_container=self.diagnostic_metric_container,
            )
            self.assertEqual(chosen, "d0")

    def test_get_selection_schedule(self):
        with self._patched_metrics():
            computed = self.scheduler._get_selection_schedule(
                self.diagnostic_metric_container
            )
            self.assertEqual(computed, ["d0", "d1", "d1", "d1"])
import unittest
from unittest.mock import Mock, call, patch
from unittest.mock import Mock, call
from omegaconf import ListConfig
from edml.controllers.scheduler.base import NextServerScheduler
from edml.controllers.scheduler.max_battery import MaxBatteryNextServerScheduler
from edml.controllers.scheduler.random import RandomNextServerScheduler
from edml.controllers.scheduler.sequential import SequentialNextServerScheduler
from edml.controllers.scheduler.smart import SmartNextServerScheduler
from edml.controllers.swarm_controller import SwarmController
from edml.core.device import DeviceRequestDispatcher
from edml.helpers.metrics import (
ModelMetricResultContainer,
DiagnosticMetricResultContainer,
)
from edml.helpers.types import DeviceBatteryStatus
from edml.tests.controllers.test_helper import load_sample_config
......@@ -68,136 +69,23 @@ class SwarmControllerTest(unittest.TestCase):
self.mock.train_global_on.assert_called_once_with("d1", epochs=1, round_no=0)
class SequentialServerDeviceSelectionTest(unittest.TestCase):
def setUp(self) -> None:
self.swarm_controller = SwarmController(load_sample_config())
self.mock = Mock(spec=DeviceRequestDispatcher)
self.swarm_controller.request_dispatcher = self.mock
self.swarm_controller.devices = ListConfig(
[{"device_id": "d0"}, {"device_id": "d1"}, {"device_id": "d2"}]
)
# We need to configure the scheduler in a way that allows us to set the `last_server_device_id` during each test
# step.
self.dispatcher = SequentialNextServerScheduler()
self.dispatcher.initialize(self.swarm_controller)
print(f"scheduler devices: {self.dispatcher.devices}")
self.swarm_controller._next_server_schedulers: [str, NextServerScheduler] = {
"sequential": self.dispatcher,
}
def test_select_server_device_for_only_active_devices(self):
self.mock.active_devices.return_value = ["d0", "d1", "d2"]
self.swarm_controller._refresh_active_devices()
print(f"new active devices: {self.swarm_controller.active_devices}")
server_device = self.swarm_controller._select_server_device(None)
self.assertEqual(server_device, "d0")
print("=1=")
server_device = self.swarm_controller._select_server_device("d0")
self.assertEqual(server_device, "d1")
print("=2=")
server_device = self.swarm_controller._select_server_device("d1")
self.assertEqual(server_device, "d2")
print("=3=")
server_device = self.swarm_controller._select_server_device("d2")
self.assertEqual(server_device, "d0")
print("=4=")
def test_select_server_device_with_last_server_device_inactive(self):
self.mock.active_devices.return_value = ["d1", "d2"]
self.swarm_controller._refresh_active_devices()
server_device = self.swarm_controller._select_server_device_sequentially("d0")
self.assertEqual(server_device, "d1")
server_device = self.swarm_controller._select_server_device_sequentially("d1")
self.assertEqual(server_device, "d2")
server_device = self.swarm_controller._select_server_device_sequentially("d2")
self.assertEqual(server_device, "d1")
def test_select_same_server_device_if_all_other_devices_inactive(self):
self.mock.active_devices.return_value = ["d1"]
self.swarm_controller._refresh_active_devices()
server_device = self.swarm_controller._select_server_device_sequentially(None)
self.assertEqual(server_device, "d1")
server_device = self.swarm_controller._select_server_device_sequentially("d1")
self.assertEqual(server_device, "d1")
def _get_battery_level_side_effect(battery_levels):
"""Initializer for the side effect of the get_battery_status_on method.
Takes a dict of battery levels for each device id."""
def side_effect(*args, **kwargs):
return battery_levels[args[0]]
return side_effect
class MaxBatteryServerDeviceSelectionTest(unittest.TestCase):
def setUp(self) -> None:
self.swarm_controller = SwarmController(
load_sample_config(), selection_strategy="max_battery"
)
self.mock = Mock(spec=DeviceRequestDispatcher)
self.swarm_controller.request_dispatcher = self.mock
self.mock.active_devices.return_value = ["d0", "d1", "d2"]
self.swarm_controller._refresh_active_devices()
def test_select_first_device_for_equal_batteries(self):
self.mock.get_battery_status_on.side_effect = _get_battery_level_side_effect(
{
"d0": DeviceBatteryStatus.from_tuple((100, 50)),
"d1": DeviceBatteryStatus.from_tuple((100, 50)),
"d2": DeviceBatteryStatus.from_tuple((100, 50)),
}
)
server_device = self.swarm_controller._select_server_device()
self.assertEqual(server_device, "d0")
def test_select_device_with_max_battery(self):
self.mock.get_battery_status_on.side_effect = _get_battery_level_side_effect(
{
"d0": None,
"d1": DeviceBatteryStatus.from_tuple((100, 50)),
"d2": DeviceBatteryStatus.from_tuple((100, 100)),
}
)
server_device = self.swarm_controller._select_server_device()
self.assertEqual(server_device, "d2")
def test_no_server_for_no_active_device(self):
self.mock.get_battery_status_on.side_effect = _get_battery_level_side_effect(
{"d0": None, "d1": None, "d2": None}
)
server_device = self.swarm_controller._select_server_device()
self.assertEqual(server_device, None)
class ServerDeviceSelectionTest(unittest.TestCase):
def setUp(self) -> None:
self.swarm_controller = SwarmController(load_sample_config())
self.mock = Mock(spec=DeviceRequestDispatcher)
self.swarm_controller.request_dispatcher = self.mock
self.swarm_controller.request_dispatcher = Mock(spec=DeviceRequestDispatcher)
self.swarm_controller.devices = ListConfig(
[{"device_id": "d0"}, {"device_id": "d1"}, {"device_id": "d2"}]
) # omitted address etc.
self.swarm_controller._next_server_schedulers = {
"max_battery": Mock(spec=MaxBatteryNextServerScheduler),
"random": Mock(spec=RandomNextServerScheduler),
"sequential": Mock(spec=SequentialNextServerScheduler),
"smart": Mock(spec=SmartNextServerScheduler),
}
def test_select_no_server_device_if_no_active_devices(self):
self.mock.active_devices.return_value = []
self.swarm_controller.request_dispatcher.active_devices.return_value = []
self.swarm_controller._refresh_active_devices()
server_device = self.swarm_controller._select_server_device(None)
......@@ -206,61 +94,69 @@ class ServerDeviceSelectionTest(unittest.TestCase):
server_device = self.swarm_controller._select_server_device("d1")
self.assertEqual(server_device, None)
class SmartServerDeviceSelectionTest(unittest.TestCase):
def setUp(self):
self.swarm_controller = SwarmController(load_sample_config())
self.mock = Mock(spec=DeviceRequestDispatcher)
self.swarm_controller.request_dispatcher = self.mock
self.swarm_controller.devices = ListConfig(
[{"device_id": "d0"}, {"device_id": "d1"}]
)
self.diagnostic_metric_container = Mock(spec=DiagnosticMetricResultContainer)
self.metrics = {
"gradient_size": 100000,
"label_size": 2000,
"smashed_data_size": 100000,
"client_weight_size": 300000,
"server_weight_size": 300000,
"client_norm_fw_time": 3,
"client_norm_bw_time": 3,
"server_norm_fw_time": 3,
"server_norm_bw_time": 3,
"comp_latency_factor": {"d0": 1, "d1": 1.01},
}
self.mock.get_battery_status_on.side_effect = _get_battery_level_side_effect(
{"d0": (500, 500), "d1": (500, 470)}
)
self.mock.get_dataset_model_info_on.side_effect = [
(2, 1, 1000000, 1000000),
(4, 1, 1000000, 1000000),
]
def test_select_server_device_smart_first_round(self):
# should select according to max battery
server_device = self.swarm_controller._select_server_device_smart(
None, self.diagnostic_metric_container
)
self.assertEqual(server_device, "d0")
def test_select_server_device_smart_second_round(self):
# should select build a schedule and select the first element
with patch(
"edml.controllers.swarm_controller.compute_metrics_for_optimization",
return_value=self.metrics,
):
server_device = self.swarm_controller._select_server_device_smart(
"d1", self.diagnostic_metric_container
)
self.assertEqual(server_device, "d0")
def test_get_selection_schedule(self):
with patch(
"edml.controllers.swarm_controller.compute_metrics_for_optimization",
return_value=self.metrics,
):
schedule = self.swarm_controller._get_selection_schedule(
self.diagnostic_metric_container
)
self.assertEqual(schedule, ["d0", "d1", "d1", "d1"])
def test_sequential_selection_default(self):
self.swarm_controller._select_server_device()
self.swarm_controller._next_server_schedulers[
"sequential"
].next_server.assert_called()
self.swarm_controller._next_server_schedulers[
"max_battery"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"random"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"smart"
].next_server.assert_not_called()
def test_max_battery_selection(self):
self.swarm_controller.selection_strategy = "max_battery"
self.swarm_controller._select_server_device()
self.swarm_controller._next_server_schedulers[
"max_battery"
].next_server.assert_called()
self.swarm_controller._next_server_schedulers[
"sequential"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"random"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"smart"
].next_server.assert_not_called()
def test_random_selection(self):
self.swarm_controller.selection_strategy = "random"
self.swarm_controller._select_server_device()
self.swarm_controller._next_server_schedulers[
"random"
].next_server.assert_called()
self.swarm_controller._next_server_schedulers[
"sequential"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"max_battery"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"smart"
].next_server.assert_not_called()
def test_smart_selection(self):
self.swarm_controller.selection_strategy = "smart"
self.swarm_controller._select_server_device()
self.swarm_controller._next_server_schedulers[
"smart"
].next_server.assert_called()
self.swarm_controller._next_server_schedulers[
"sequential"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"random"
].next_server.assert_not_called()
self.swarm_controller._next_server_schedulers[
"max_battery"
].next_server.assert_not_called()
......@@ -12,6 +12,7 @@ from edml.helpers.config_helpers import (
get_device_index_by_id,
instantiate_controller,
__drop_irrelevant_keys__,
get_torch_device_id,
)
......@@ -64,6 +65,14 @@ class ConfigHelpersTest(unittest.TestCase):
preprocess_config(self.cfg)
self.assertEqual("d1", self.cfg.own_device_id)
def test_get_default_torch_device_if_cuda_available(self):
with patch("torch.cuda.is_available", return_value=True):
self.assertEqual(get_torch_device_id(self.cfg), "cuda:0")
def test_get_default_torch_device_if_cuda_not_available(self):
with patch("torch.cuda.is_available", return_value=False):
self.assertEqual(get_torch_device_id(self.cfg), "cpu")
def test_preprocess_config_group_name(self):
preprocess_config(self.cfg)
self.assertEqual(self.cfg.group, "swarm_max_battery_static")
......@@ -102,3 +111,35 @@ class ControllerInstantiationTest(unittest.TestCase):
self.assertListEqual(
list(reduced_cfg.keys()), ["_target_", "_partial_", "scheduler"]
)
class GetTorchDeviceIdTest(unittest.TestCase):
    """Checks that `get_torch_device_id` returns the torch device configured for the own device."""

    def setUp(self) -> None:
        # Two devices, each with an explicitly configured torch device.
        topology_devices = [
            {
                "device_id": "d0",
                "address": "localhost:50051",
                "torch_device": "my_torch_device1",
            },
            {
                "device_id": "d1",
                "address": "localhost:50052",
                "torch_device": "my_torch_device2",
            },
        ]
        raw_config = {
            "own_device_id": "d0",
            "topology": {"devices": topology_devices},
            "num_devices": "${len:${topology.devices}}",
        }
        self.cfg = DictConfig(raw_config)

    def test_get_torch_device1(self):
        resolved = get_torch_device_id(self.cfg)
        self.assertEqual(resolved, "my_torch_device1")

    def test_get_torch_device2(self):
        self.cfg.own_device_id = "d1"
        resolved = get_torch_device_id(self.cfg)
        self.assertEqual(resolved, "my_torch_device2")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment