# parallel_split_controller.py
from edml.controllers.adaptive_threshold_mechanism import AdaptiveThresholdFn
from edml.controllers.adaptive_threshold_mechanism.static import (
StaticAdaptiveThresholdFn,
)
from edml.controllers.base_controller import BaseController
from edml.controllers.scheduler.base import NextServerScheduler
from edml.helpers.config_helpers import get_device_index_by_id
class ParallelSplitController(BaseController):
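    """Controller that trains all active client devices in parallel against a
    single server device per round.

    The server device can be re-selected each round by the configured
    next-server scheduler, and a per-round adaptive learning threshold is
    computed and passed along to the server.
    """
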
def __init__(
self,
cfg,
scheduler: NextServerScheduler,
adaptive_threshold_fn: AdaptiveThresholdFn = StaticAdaptiveThresholdFn(0.0),
):
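        """
        Args:
            cfg: The experiment and topology configuration (handled by BaseController).
            scheduler: Strategy that selects the next server device each round;
                it is initialized with a reference to this controller.
            adaptive_threshold_fn: Maps the round number to an adaptive learning
                threshold. Defaults to a static threshold of 0.0.
        """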
super().__init__(cfg)
scheduler.initialize(self)
self._next_server_scheduler = scheduler
self._adaptive_threshold_fn = adaptive_threshold_fn
def _train(self):
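        """Runs up to ``cfg.experiment.max_rounds`` rounds of parallel split training.

        Each round refreshes the devices' battery status, re-selects the server
        device if a scheduler is configured, propagates the latest server weights,
        starts parallel training on the server, and logs the resulting metrics.
        The loop stops early if no client devices remain, the server device
        becomes unavailable, or early stopping is triggered.
        """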
client_weights = None
server_weights = None
server_device_id = self.cfg.topology.devices[0].device_id
for i in range(self.cfg.experiment.max_rounds):
print(f"=================================Round {i}")
            # Fetch the latest device information to see which active devices
            # are still available, then update the next server device if applicable.
self._update_devices_battery_status()
# break if no active devices or only server device left
if self._devices_empty_or_only_server_left(server_device_id):
print("No active client devices left.")
break
if self._next_server_scheduler:
server_device_id = self._next_server()
print(f"<> training on server: {server_device_id} <>")
            # Propagate the latest server weights once at least one round of training has completed.
if server_weights is not None:
print(f">>> Propagating newest server weights to {server_device_id}")
self.request_dispatcher.set_weights_on(
device_id=server_device_id,
state_dict=server_weights,
on_client=False,
)
# Start parallel training of all client devices.
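            # The adaptive learning threshold for this round is derived from the
            # round index and forwarded to the server training request below.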
adaptive_threshold = self._adaptive_threshold_fn.invoke(i)
self.logger.log({"adaptive-threshold": adaptive_threshold})
training_response = self.request_dispatcher.train_parallel_on_server(
server_device_id=server_device_id,
epochs=1,
round_no=i,
adaptive_learning_threshold=adaptive_threshold,
)
self._refresh_active_devices()
self.logger.log(
{"remaining_devices": {"devices": len(self.active_devices), "round": i}}
)
self.logger.log(
{
"server_device": {
"device": get_device_index_by_id(self.cfg, server_device_id)
},
"round": i,
}
) # log the server device index for convenience
if training_response is False: # server device unavailable
print(f"Training response was false.")
break
else:
                client_weights, server_weights, metrics, _ = training_response
self._aggregate_and_log_metrics(metrics, i)
early_stop = self.early_stopping(metrics, i)
if early_stop:
print(f"Early stopping triggered.")
break
self._save_weights(
                    client_weights=client_weights,
                    server_weights=server_weights,
                    round_no=i,
)
def _next_server(self) -> str:
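        """Asks the scheduler to select the next server device among the active devices."""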
return self._next_server_scheduler.next_server(self.active_devices)