diff --git a/edml/helpers/config_helpers.py b/edml/helpers/config_helpers.py
index d848472f06e460ec636524869368ce9b91a17553..63c7b6873c63553ced6ec637791070cc1b61258c 100644
--- a/edml/helpers/config_helpers.py
+++ b/edml/helpers/config_helpers.py
@@ -1,7 +1,9 @@
 from inspect import signature
 
+import torch
 from hydra.utils import get_class, instantiate
 from omegaconf import OmegaConf, DictConfig
+from omegaconf.errors import ConfigAttributeError
 
 from edml.controllers.base_controller import BaseController
 
@@ -117,3 +119,36 @@ def instantiate_controller(cfg: DictConfig) -> BaseController:
     # Instantiate the controller.
     controller: BaseController = instantiate(cfg.controller)(cfg=cfg)
     return controller
+
+
+def get_torch_device_id(cfg: DictConfig) -> str:
+    """
+    Returns the configured torch_device for the current device.
+    Resorts to default if no torch_device is configured.
+
+    Args:
+        cfg (DictConfig): The config loaded from YAML files.
+
+    Returns:
+        The id of the configured torch_device for the current device.
+
+    Raises:
+        StopIteration: If the device with the given ID cannot be found.
+        ConfigAttributeError: If no device id is present in the config.
+    """
+    own_device_id = cfg.own_device_id
+    try:
+        return next(
+            device_cfg.torch_device
+            for device_cfg in cfg.topology.devices
+            if device_cfg.device_id == own_device_id
+        )
+    except ConfigAttributeError:
+        return _default_torch_device()
+
+
+def _default_torch_device() -> str:
+    """
+    Returns the default torch device, depending on whether cuda is available.
+    """
+    return "cuda:0" if torch.cuda.is_available() else "cpu"
diff --git a/edml/tests/helpers/config_helpers_test.py b/edml/tests/helpers/config_helpers_test.py
index 3f634fd95083d5402b4ab37b1f4ea7eae270bd91..41f18361a36eb62464202fd2b91c5bc6bb2043a5 100644
--- a/edml/tests/helpers/config_helpers_test.py
+++ b/edml/tests/helpers/config_helpers_test.py
@@ -1,4 +1,5 @@
 import unittest
+from unittest.mock import patch
 
 from omegaconf import DictConfig
 
@@ -7,6 +8,7 @@ from edml.helpers.config_helpers import (
     get_device_address_by_id,
     preprocess_config,
     get_device_index_by_id,
+    get_torch_device_id,
 )
 
 
@@ -46,3 +48,43 @@ class ConfigHelpersTest(unittest.TestCase):
         self.cfg.own_device_id = 1
         preprocess_config(self.cfg)
         self.assertEqual("d1", self.cfg.own_device_id)
+
+    def test_get_default_torch_device_if_cuda_available(self):
+        with patch("torch.cuda.is_available", return_value=True):
+            self.assertEqual(get_torch_device_id(self.cfg), "cuda:0")
+
+    def test_get_default_torch_device_if_cuda_not_available(self):
+        with patch("torch.cuda.is_available", return_value=False):
+            self.assertEqual(get_torch_device_id(self.cfg), "cpu")
+
+
+class GetTorchDeviceIdTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.cfg = DictConfig(
+            {
+                "own_device_id": "d0",
+                "topology": {
+                    "devices": [
+                        {
+                            "device_id": "d0",
+                            "address": "localhost:50051",
+                            "torch_device": "my_torch_device1",
+                        },
+                        {
+                            "device_id": "d1",
+                            "address": "localhost:50052",
+                            "torch_device": "my_torch_device2",
+                        },
+                    ]
+                },
+                "num_devices": "${len:${topology.devices}}",
+            }
+        )
+
+    def test_get_torch_device1(self):
+        self.assertEqual(get_torch_device_id(self.cfg), "my_torch_device1")
+
+    def test_get_torch_device2(self):
+        self.cfg.own_device_id = "d1"
+        self.assertEqual(get_torch_device_id(self.cfg), "my_torch_device2")