Skip to content
Snippets Groups Projects
Select Git revision
  • v1.3.1
  • master default protected
  • gitkeep
  • dev protected
  • Issue/2353-dropShapeFix
  • Issue/2583-treeBug
  • Hotfix/2562-organizations
  • Issue/2464-invalidateMeta
  • Issue/2484-filterExtracted
  • Issue/2309-docs
  • Issue/2462-removeTraces
  • Hotfix/2459-EncodingPath
  • Hotfix/2452-linkedDeletion
  • Issue/2328-noFailOnLog
  • Issue/1792-newMetadataStructure
  • v2.5.2-Hotfix2365
  • Hotfix/2365-targetClassWorks
  • Issue/2269-niceKpiParser
  • Issue/2295-singleOrganizationFix
  • Issue/1953-owlImports
  • Hotfix/2087-efNet6
  • v2.9.0
  • v2.8.2
  • v2.8.1
  • v2.8.0
  • v2.7.2
  • v2.7.1
  • v2.7.0
  • v2.6.2
  • v2.6.1
  • v2.6.0
  • v2.5.3
  • v2.5.2
  • v2.5.1
  • v2.5.0
  • v2.4.1
  • v2.4.0
  • v2.3.0
  • v2.2.0
  • v2.1.0
  • v2.0.0
41 results

build.cake

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    config_helpers.py 5.06 KiB
    from inspect import signature
    
    import torch
    from hydra.utils import get_class, instantiate
    from omegaconf import OmegaConf, DictConfig
    from omegaconf.errors import ConfigAttributeError
    
    from edml.controllers.base_controller import BaseController
    
    
    def get_device_address_by_id(device_id: str, cfg: DictConfig) -> str:
        """
        Looks up the binding address configured for a device.

        The devices listed in `cfg.topology.devices` are scanned in order and the address of the first entry whose
        `device_id` equals `device_id` is returned.

        Args:
            device_id (str): The ID of the device to look up.
            cfg (DictConfig): The config loaded from YAML files.

        Returns:
            The binding address of the first matching device.

        Raises:
            StopIteration: If no device with the given ID is configured.
        """
        for entry in cfg.topology.devices:
            if entry.device_id == device_id:
                return entry.address
        # Keep the documented contract: an unknown ID surfaces as StopIteration.
        raise StopIteration
    
    
    def get_device_id_by_index(cfg: DictConfig, index: int) -> str:
        """
        Returns the ID of the device at the given position in the network topology.

        Args:
            cfg (DictConfig): The config loaded from YAML files.
            index (int): The zero-based position of the device inside `cfg.topology.devices`.

        Returns:
            The device's ID.

        Raises:
            IndexError: If `index` is negative or not smaller than the number of configured devices.
        """
        devices = cfg.topology.devices
        # Reject negative indices explicitly: Python-style wrap-around would silently select a device from the end of
        # the list, which is never what a caller passing a device position intends.
        if not 0 <= index < len(devices):
            raise IndexError(
                f"Device index {index} is out of range for a topology with {len(devices)} device(s)."
            )
        return devices[index].device_id
    
    
    def get_device_index_by_id(cfg: DictConfig, device_id: str) -> int:
        """
        Finds the position of a device within the configured network topology.

        Args:
            cfg (DictConfig): The config loaded from YAML files.
            device_id (str): The ID of the device to look up.

        Returns:
            The zero-based index of the first entry in `cfg.topology.devices` whose `device_id` matches.

        Raises:
            StopIteration: If no device with the given ID is configured.
        """
        for position, candidate in enumerate(cfg.topology.devices):
            if candidate.device_id == device_id:
                return position
        # Keep the documented contract: an unknown ID surfaces as StopIteration.
        raise StopIteration
    
    
    def preprocess_config(cfg: DictConfig):
        """
        Sets up `OmegaConf` and normalizes `cfg` in place for command line usage.

        Steps:
            1. Registers (replacing any previous registration) a custom `len` resolver, then resolves all
               interpolations in `cfg`.
            2. If `own_device_id` is an integer, it is treated as a position in the list of configured devices. The ID
               of the device at that position is looked up and written back to `own_device_id`.
        """
        OmegaConf.register_new_resolver("len", lambda value: len(value), replace=True)
        OmegaConf.resolve(cfg)

        # Allow `own_device_id=<n>` as a shorthand for "the n-th device of the topology"; proper string IDs are left
        # untouched.
        own_id = cfg.own_device_id
        if not isinstance(own_id, int):
            return
        cfg.own_device_id = get_device_id_by_index(cfg, own_id)
    
    
    def instantiate_controller(cfg: DictConfig) -> BaseController:
        """
        Instantiates a controller based on the configuration. This method filters out extra parameters defined through hydra
        but not accepted by the controller's init method. This allows for hydra's multirun feature to work even if
        controllers have different parameters (like next server schedulers).

        Note that `cfg` is modified in place: `cfg.controller` may be replaced by the filtered parameters, and
        `cfg.own_device_id` is set to "controller" before the controller is constructed.

        Args:
            cfg: The full configuration. `cfg.controller` holds the hydra instantiation config of the controller
                (including `_target_`); the whole `cfg` is passed to the controller as its `cfg` keyword argument.

        Returns:
            An instance of `BaseController`.
        """
        from inspect import Parameter

        # Filter out any arguments not present in the controller constructor. This is a hack required to make multirun work.
        # We want to be able to use different scheduling strategies combined with different controllers, but the
        # constructor call fails if it receives any arguments it does not declare.
        controller_class = get_class(cfg.controller._target_)
        controller_params = signature(controller_class.__init__).parameters

        # A constructor taking **kwargs accepts every key, so filtering would only drop configuration it could use.
        accepts_any_kwarg = any(
            param.kind is Parameter.VAR_KEYWORD for param in controller_params.values()
        )

        if not accepts_any_kwarg:
            # These are special hydra keywords that we do not want to filter out.
            special_keys = {"_target_", "_recursive_", "_partial_"}
            cfg.controller = {
                k: v
                for k, v in cfg.controller.items()
                if k in controller_params or k in special_keys
            }

        # Update the device ID and set it to controller.
        cfg.own_device_id = "controller"

        # `instantiate` must yield a callable here, which is then invoked with the full config (presumably the controller
        # config sets `_partial_: true` — confirm against the YAML files).
        controller: BaseController = instantiate(cfg.controller)(cfg=cfg)
        return controller
    
    
    def get_torch_device_id(cfg: DictConfig) -> str:
        """
        Returns the configured torch_device for the current device.
        Resorts to the default torch device if the current device has no torch_device entry, or if it is set to null.

        Args:
            cfg (DictConfig): The config loaded from YAML files.

        Returns:
            The id of the configured torch_device for the current device.

        Raises:
            StopIteration: If no device in `cfg.topology.devices` has the ID `cfg.own_device_id`.
            ConfigAttributeError: If `own_device_id`, `topology.devices` or a device's `device_id` is missing from the
                config.
        """
        own_device_id = cfg.own_device_id
        device_cfg = next(
            candidate
            for candidate in cfg.topology.devices
            if candidate.device_id == own_device_id
        )
        # Only a missing `torch_device` entry should trigger the fallback; errors elsewhere in the config must not be
        # silently masked by the default device.
        try:
            torch_device = device_cfg.torch_device
        except ConfigAttributeError:
            return _default_torch_device()
        # An explicit `torch_device: null` also means "not configured" rather than a device named None.
        return _default_torch_device() if torch_device is None else torch_device
    
    
    def _default_torch_device():
        """
        Returns the default torch devices, depending on whether cuda is available.
        """
        return "cuda:0" if torch.cuda.is_available() else "cpu"