From 66b78ca4114d65e75dbe09fe508fab445529663a Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Thu, 13 Jun 2024 12:34:33 +0200
Subject: [PATCH] Made torch device configurable for each device. Provided
 default implementation so defining the torch device is not mandatory.

---
 edml/helpers/config_helpers.py            | 35 +++++++++++++++++++
 edml/tests/helpers/config_helpers_test.py | 42 +++++++++++++++++++++++
 2 files changed, 77 insertions(+)

diff --git a/edml/helpers/config_helpers.py b/edml/helpers/config_helpers.py
index d848472..63c7b68 100644
--- a/edml/helpers/config_helpers.py
+++ b/edml/helpers/config_helpers.py
@@ -1,7 +1,9 @@
 from inspect import signature
 
+import torch
 from hydra.utils import get_class, instantiate
 from omegaconf import OmegaConf, DictConfig
+from omegaconf.errors import ConfigAttributeError
 
 from edml.controllers.base_controller import BaseController
 
@@ -117,3 +119,36 @@ def instantiate_controller(cfg: DictConfig) -> BaseController:
     # Instantiate the controller.
     controller: BaseController = instantiate(cfg.controller)(cfg=cfg)
     return controller
+
+
+def get_torch_device_id(cfg: DictConfig) -> str:
+    """
+    Returns the configured torch_device for the current device.
+    Resorts to default if no torch_device is configured.
+
+    Args:
+        cfg (DictConfig): The config loaded from YAML files.
+
+    Returns:
+        The id of the configured torch_device for the current device.
+
+    Raises:
+        StopIteration: If the device with the given ID cannot be found.
+        ConfigAttributeError: If no device id is present in the config.
+    """
+    own_device_id = cfg.own_device_id
+    try:
+        return next(
+            device_cfg.torch_device
+            for device_cfg in cfg.topology.devices
+            if device_cfg.device_id == own_device_id
+        )
+    except ConfigAttributeError:
+        return _default_torch_device()
+
+
+def _default_torch_device():
+    """
+    Returns the default torch devices, depending on whether cuda is available.
+    """
+    return "cuda:0" if torch.cuda.is_available() else "cpu"
diff --git a/edml/tests/helpers/config_helpers_test.py b/edml/tests/helpers/config_helpers_test.py
index 3f634fd..41f1836 100644
--- a/edml/tests/helpers/config_helpers_test.py
+++ b/edml/tests/helpers/config_helpers_test.py
@@ -1,4 +1,5 @@
 import unittest
+from unittest.mock import patch
 
 from omegaconf import DictConfig
 
@@ -7,6 +8,7 @@ from edml.helpers.config_helpers import (
     get_device_address_by_id,
     preprocess_config,
     get_device_index_by_id,
+    get_torch_device_id,
 )
 
 
@@ -46,3 +48,43 @@ class ConfigHelpersTest(unittest.TestCase):
         self.cfg.own_device_id = 1
         preprocess_config(self.cfg)
         self.assertEqual("d1", self.cfg.own_device_id)
+
+    def test_get_default_torch_device_if_cuda_available(self):
+        with patch("torch.cuda.is_available", return_value=True):
+            self.assertEqual(get_torch_device_id(self.cfg), "cuda:0")
+
+    def test_get_default_torch_device_if_cuda_not_available(self):
+        with patch("torch.cuda.is_available", return_value=False):
+            self.assertEqual(get_torch_device_id(self.cfg), "cpu")
+
+
+class GetTorchDeviceIdTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.cfg = DictConfig(
+            {
+                "own_device_id": "d0",
+                "topology": {
+                    "devices": [
+                        {
+                            "device_id": "d0",
+                            "address": "localhost:50051",
+                            "torch_device": "my_torch_device1",
+                        },
+                        {
+                            "device_id": "d1",
+                            "address": "localhost:50052",
+                            "torch_device": "my_torch_device2",
+                        },
+                    ]
+                },
+                "num_devices": "${len:${topology.devices}}",
+            }
+        )
+
+    def test_get_torch_device1(self):
+        self.assertEqual(get_torch_device_id(self.cfg), "my_torch_device1")
+
+    def test_get_torch_device2(self):
+        self.cfg.own_device_id = "d1"
+        self.assertEqual(get_torch_device_id(self.cfg), "my_torch_device2")
-- 
GitLab