Skip to content
Snippets Groups Projects

Made torch device configurable

Merged Tim Tobias Bauerle requested to merge configurable_torch_device into wip
2 files
+ 77
0
Compare changes
  • Side-by-side
  • Inline
Files
2
from inspect import signature
import torch
from hydra.utils import get_class, instantiate
from omegaconf import OmegaConf, DictConfig
from omegaconf.errors import ConfigAttributeError
from edml.controllers.base_controller import BaseController
@@ -117,3 +119,36 @@ def instantiate_controller(cfg: DictConfig) -> BaseController:
# Instantiate the controller.
controller: BaseController = instantiate(cfg.controller)(cfg=cfg)
return controller
def get_torch_device_id(cfg: DictConfig) -> str:
"""
Returns the configured torch_device for the current device.
Resorts to default if no torch_device is configured.
Args:
cfg (DictConfig): The config loaded from YAML files.
Returns:
The id of the configured torch_device for the current device.
Raises:
StopIteration: If the device with the given ID cannot be found.
ConfigAttributeError: If no device id is present in the config.
"""
own_device_id = cfg.own_device_id
try:
return next(
device_cfg.torch_device
for device_cfg in cfg.topology.devices
if device_cfg.device_id == own_device_id
)
except ConfigAttributeError:
return _default_torch_device()
def _default_torch_device():
"""
Returns the default torch devices, depending on whether cuda is available.
"""
return "cuda:0" if torch.cuda.is_available() else "cpu"
Loading