Skip to content
Snippets Groups Projects
Select Git revision
  • Sprint/2021-08
  • master default protected
  • gitkeep
  • dev protected
  • Issue/2449-GuidPidSlugToProjectSettings
  • Issue/2309-docs
  • Issue/2355-topLevelOrg
  • Issue/2328-noFailOnLog
  • Hotfix/2371-fixGitLabinRCV
  • Issue/2287-guestRole
  • Fix/xxxx-activateGitlab
  • Test/xxxx-enablingGitLab
  • Issue/2349-gitlabHttps
  • Issue/2259-updatePids
  • Issue/2101-gitLabResTypeUi
  • Hotfix/2202-fixNaNQuota
  • Issue/2246-quotaResoval
  • Issue/2221-projectDateCreated
  • Hotfix/2224-quotaSizeAnalytics
  • Fix/xxxx-resourceVisibility
  • Issue/2000-gitlabResourcesAPI
  • v4.4.3
  • v4.4.2
  • v4.4.1
  • v4.4.0
  • v4.3.4
  • v4.3.3
  • v4.3.2
  • v4.3.1
  • v4.3.0
  • v4.2.8
  • v4.2.7
  • v4.2.6
  • v4.2.5
  • v4.2.4
  • v4.2.3
  • v4.2.2
  • v4.2.1
  • v4.2.0
  • v4.1.1
  • v4.1.0
41 results

ProjectController.cs

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    parallel_split_controller.py 3.67 KiB
    from typing import Optional

    from edml.controllers.adaptive_threshold_mechanism import AdaptiveThresholdFn
    from edml.controllers.adaptive_threshold_mechanism.static import (
        StaticAdaptiveThresholdFn,
    )
    from edml.controllers.base_controller import BaseController
    from edml.controllers.scheduler.base import NextServerScheduler
    from edml.helpers.config_helpers import get_device_index_by_id
    
    
    class ParallelSplitController(BaseController):
        """Controller for parallel split learning.

        Each round, all active client devices train in parallel against a single
        server device. The server device may rotate between rounds according to
        the configured scheduler, and an adaptive learning threshold is computed
        per round and passed to the training call.
        """

        def __init__(
            self,
            cfg,
            scheduler: NextServerScheduler,
            adaptive_threshold_fn: Optional[AdaptiveThresholdFn] = None,
        ):
            """Initializes the controller.

            Args:
                cfg: Experiment/topology configuration (consumed by BaseController).
                scheduler: Selects the server device for each round; it is given
                    a back-reference to this controller via ``initialize``.
                adaptive_threshold_fn: Maps a round number to the adaptive
                    learning threshold for that round. Defaults to a static
                    threshold of 0.0.
            """
            super().__init__(cfg)
            scheduler.initialize(self)
            self._next_server_scheduler = scheduler
            # Build the fallback per instance rather than as a signature default,
            # so controller instances never share one default threshold-fn object
            # (signature defaults are evaluated only once, at definition time).
            self._adaptive_threshold_fn = (
                adaptive_threshold_fn
                if adaptive_threshold_fn is not None
                else StaticAdaptiveThresholdFn(0.0)
            )

        def _train(self):
            """Runs up to ``cfg.experiment.max_rounds`` rounds of parallel training.

            Stops early when no active client devices remain, when the server
            device becomes unavailable (training response is ``False``), or when
            early stopping triggers on the aggregated metrics.
            """
            server_weights = None
            # The initial server is the first device in the configured topology.
            server_device_id = self.cfg.topology.devices[0].device_id

            for round_no in range(self.cfg.experiment.max_rounds):
                print(f"=================================Round {round_no}")

                # We fetch the newest device information to check and see what
                # active devices are still available. After that, we can also
                # update the next server device if applicable.
                self._update_devices_battery_status()
                # Break if no active devices or only the server device is left.
                if self._devices_empty_or_only_server_left(server_device_id):
                    print("No active client devices left.")
                    break
                if self._next_server_scheduler:
                    server_device_id = self._next_server()
                    print(f"<> training on server: {server_device_id} <>")

                # After the first round, push the latest server weights to the
                # (possibly newly scheduled) server device.
                if server_weights is not None:
                    print(f">>> Propagating newest server weights to {server_device_id}")
                    self.request_dispatcher.set_weights_on(
                        device_id=server_device_id,
                        state_dict=server_weights,
                        on_client=False,
                    )

                # Start parallel training of all client devices for one epoch.
                adaptive_threshold = self._adaptive_threshold_fn.invoke(round_no)
                self.logger.log({"adaptive-threshold": adaptive_threshold})
                training_response = self.request_dispatcher.train_parallel_on_server(
                    server_device_id=server_device_id,
                    epochs=1,
                    round_no=round_no,
                    adaptive_learning_threshold=adaptive_threshold,
                )

                self._refresh_active_devices()
                self.logger.log(
                    {
                        "remaining_devices": {
                            "devices": len(self.active_devices),
                            "round": round_no,
                        }
                    }
                )
                # Log the server device index for convenience.
                self.logger.log(
                    {
                        "server_device": {
                            "device": get_device_index_by_id(self.cfg, server_device_id)
                        },
                        "round": round_no,
                    }
                )

                if training_response is False:  # server device unavailable
                    print("Training response was false.")
                    break

                client_weights, server_weights, metrics, _ = training_response

                self._aggregate_and_log_metrics(metrics, round_no)

                if self.early_stopping(metrics, round_no):
                    print("Early stopping triggered.")
                    break

                self._save_weights(
                    client_weights=client_weights,
                    server_weights=server_weights,
                    round_no=round_no,
                )

        def _next_server(self) -> str:
            """Returns the device id of the next server chosen by the scheduler."""
            return self._next_server_scheduler.next_server(self.active_devices)