# --- Repository-UI scrape residue (not code); kept as comments for provenance ---
# Select Git revision
# train_SECRET.py
# Code owners
# Assign users and groups as approvers for specific file changes. Learn more.
# config_helpers.py 5.06 KiB
from inspect import signature
import torch
from hydra.utils import get_class, instantiate
from omegaconf import OmegaConf, DictConfig
from omegaconf.errors import ConfigAttributeError
from edml.controllers.base_controller import BaseController
def get_device_address_by_id(device_id: str, cfg: DictConfig) -> str:
    """
    Looks up the binding address of a device in the configured topology.

    Args:
        device_id (str): The device id.
        cfg (DictConfig): The config loaded from YAML files.

    Returns:
        The device's binding address.

    Raises:
        StopIteration: If no device with the given ID exists.
    """
    for device in cfg.topology.devices:
        if device.device_id == device_id:
            return device.address
    # Mirrors the exhausted-generator behavior of ``next`` on a missing id.
    raise StopIteration
def get_device_id_by_index(cfg: DictConfig, index: int) -> str:
    """
    Returns the ID of the device at the given index in the network topology.

    Args:
        cfg (DictConfig): The config loaded from YAML files.
        index (int): The index of the device inside the configuration file.

    Returns:
        The device's ID.

    Raises:
        IndexError: If the index is out of range of the configured devices.
    """
    # Plain list indexing: out-of-range indices raise IndexError, not
    # StopIteration (the previous docstring documented the wrong exception).
    return cfg.topology.devices[index].device_id
def get_device_index_by_id(cfg: DictConfig, device_id: str) -> int:
    """
    Looks up the position of a device within the configured topology.

    Args:
        cfg (DictConfig): The config loaded from YAML files.
        device_id (str): The device's ID.

    Returns:
        The index of the device inside the configuration file.

    Raises:
        StopIteration: If no device with the given ID exists.
    """
    for position, device in enumerate(cfg.topology.devices):
        if device.device_id == device_id:
            return position
    # Mirrors the exhausted-generator behavior of ``next`` on a missing id.
    raise StopIteration
def preprocess_config(cfg: DictConfig):
    """
    Prepares a freshly loaded configuration for use.

    Registers the custom ``len`` resolver with `OmegaConf`, resolves all
    interpolations in place, and normalizes the configuration for command
    line usage: if `own_device_id` is an integer, it is interpreted as an
    index into the configured topology's device list and replaced with that
    device's string id.
    """
    OmegaConf.register_new_resolver("len", len, replace=True)
    OmegaConf.resolve(cfg)
    # A numeric own_device_id is shorthand for "the i-th device in the
    # topology"; translate it into the proper string id.
    own_id = cfg.own_device_id
    if isinstance(own_id, int):
        cfg.own_device_id = get_device_id_by_index(cfg, own_id)
def instantiate_controller(cfg: DictConfig) -> BaseController:
    """
    Builds the controller described by the configuration.

    Hydra's `instantiate` is strict and rejects keyword arguments that the
    target class does not accept. To keep hydra's multirun feature working
    across controllers with differing constructor signatures (e.g. different
    next server schedulers), every configuration key that is neither a
    constructor parameter nor a special hydra keyword is dropped before
    instantiation.

    Args:
        cfg: The controller configuration.

    Returns:
        An instance of `BaseController`.
    """
    # Determine which keyword arguments the controller's __init__ accepts.
    target_cls = get_class(cfg.controller._target_)
    accepted = signature(target_cls.__init__).parameters.keys()
    # Hydra's own meta keys must survive the filtering step.
    hydra_keys = ["_target_", "_recursive_", "_partial_"]
    filtered = {}
    for key, value in cfg.controller.items():
        if key in accepted or key in hydra_keys:
            filtered[key] = value
    cfg.controller = filtered
    # The controller always identifies itself under this fixed device id.
    cfg.own_device_id = "controller"
    # Instantiate and hand the (now normalized) full config to the controller.
    return instantiate(cfg.controller)(cfg=cfg)
def get_torch_device_id(cfg: DictConfig) -> str:
    """
    Returns the configured torch_device for the current device.

    Falls back to the default torch device when no ``torch_device`` attribute
    is present in the configuration.

    Args:
        cfg (DictConfig): The config loaded from YAML files.

    Returns:
        The id of the configured torch_device for the current device.

    Raises:
        StopIteration: If the device with the given ID cannot be found.
        ConfigAttributeError: If no device id is present in the config.
    """
    own_device_id = cfg.own_device_id
    # The whole lookup sits inside the try so that a missing attribute on
    # any topology entry triggers the default, matching the generator-based
    # original where every attribute access happened within the try body.
    try:
        for device_cfg in cfg.topology.devices:
            if device_cfg.device_id == own_device_id:
                return device_cfg.torch_device
        # No matching device: same exception as an exhausted ``next``.
        raise StopIteration
    except ConfigAttributeError:
        return _default_torch_device()
def _default_torch_device():
"""
Returns the default torch devices, depending on whether cuda is available.
"""
return "cuda:0" if torch.cuda.is_available() else "cpu"