Skip to content
Snippets Groups Projects
Commit 5fede3b6 authored by Sven Michael Lechner's avatar Sven Michael Lechner
Browse files

Merge branch 'configurable_torch_device' into 'wip'

Made torch device configurable

See merge request !10
parents 2546b936 66b78ca4
Branches
No related tags found
2 merge requests!18Merge in main,!10Made torch device configurable
from inspect import signature
import torch
from hydra.utils import get_class, instantiate
from omegaconf import OmegaConf, DictConfig
from omegaconf.errors import ConfigAttributeError
from edml.controllers.base_controller import BaseController
......@@ -117,3 +119,36 @@ def instantiate_controller(cfg: DictConfig) -> BaseController:
# Instantiate the controller.
controller: BaseController = instantiate(cfg.controller)(cfg=cfg)
return controller
def get_torch_device_id(cfg: DictConfig) -> str:
    """Look up the torch device configured for the current device in the topology.

    Falls back to a default device when no ``torch_device`` entry is configured
    for the current device (or the topology section is incomplete).

    Args:
        cfg (DictConfig): The config loaded from YAML files.

    Returns:
        The id of the configured torch_device for the current device.

    Raises:
        StopIteration: If the device with the given ID cannot be found.
        ConfigAttributeError: If no device id is present in the config.
    """
    device_id = cfg.own_device_id
    try:
        for entry in cfg.topology.devices:
            if entry.device_id == device_id:
                return entry.torch_device
    except ConfigAttributeError:
        # Missing torch_device (or malformed topology entry): use the default.
        return _default_torch_device()
    # No topology entry matched the current device id.
    raise StopIteration
def _default_torch_device():
    """Pick the default torch device: the first CUDA GPU if available, else the CPU."""
    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"
import unittest
from unittest.mock import patch
from omegaconf import DictConfig
......@@ -7,6 +8,7 @@ from edml.helpers.config_helpers import (
get_device_address_by_id,
preprocess_config,
get_device_index_by_id,
get_torch_device_id,
)
......@@ -46,3 +48,43 @@ class ConfigHelpersTest(unittest.TestCase):
self.cfg.own_device_id = 1
preprocess_config(self.cfg)
self.assertEqual("d1", self.cfg.own_device_id)
def test_get_default_torch_device_if_cuda_available(self):
    """With no torch_device configured, the fallback is the first CUDA GPU when CUDA is reported available."""
    with patch("torch.cuda.is_available", return_value=True):
        self.assertEqual("cuda:0", get_torch_device_id(self.cfg))
def test_get_default_torch_device_if_cuda_not_available(self):
    """With no torch_device configured, the fallback is the CPU when CUDA is reported unavailable."""
    with patch("torch.cuda.is_available", return_value=False):
        self.assertEqual("cpu", get_torch_device_id(self.cfg))
class GetTorchDeviceIdTest(unittest.TestCase):
    """Tests resolving the per-device torch_device entry from the topology config."""

    def setUp(self) -> None:
        # Two devices, each with its own torch_device entry, so the lookup
        # has to pick the one matching own_device_id.
        device_entries = [
            {
                "device_id": "d0",
                "address": "localhost:50051",
                "torch_device": "my_torch_device1",
            },
            {
                "device_id": "d1",
                "address": "localhost:50052",
                "torch_device": "my_torch_device2",
            },
        ]
        self.cfg = DictConfig(
            {
                "own_device_id": "d0",
                "topology": {"devices": device_entries},
                "num_devices": "${len:${topology.devices}}",
            }
        )

    def test_get_torch_device1(self):
        """The entry for the default own_device_id is returned."""
        self.assertEqual("my_torch_device1", get_torch_device_id(self.cfg))

    def test_get_torch_device2(self):
        """Changing own_device_id selects the matching topology entry."""
        self.cfg.own_device_id = "d1"
        self.assertEqual("my_torch_device2", get_torch_device_id(self.cfg))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment