diff --git a/edml/helpers/config_helpers.py b/edml/helpers/config_helpers.py
index d848472f06e460ec636524869368ce9b91a17553..63c7b6873c63553ced6ec637791070cc1b61258c 100644
--- a/edml/helpers/config_helpers.py
+++ b/edml/helpers/config_helpers.py
@@ -1,7 +1,9 @@
 from inspect import signature
 
+import torch
 from hydra.utils import get_class, instantiate
 from omegaconf import OmegaConf, DictConfig
+from omegaconf.errors import ConfigAttributeError
 
 from edml.controllers.base_controller import BaseController
 
@@ -117,3 +119,50 @@ def instantiate_controller(cfg: DictConfig) -> BaseController:
     # Instantiate the controller.
     controller: BaseController = instantiate(cfg.controller)(cfg=cfg)
     return controller
+
+
+def get_torch_device_id(cfg: DictConfig) -> str:
+    """
+    Returns the configured torch_device for the current device.
+    Resorts to default if no torch_device is configured.
+
+    Args:
+        cfg (DictConfig): The config loaded from YAML files.
+
+    Returns:
+        The id of the configured torch_device for the current device.
+
+    Raises:
+        StopIteration: If the device with the given ID cannot be found.
+        ConfigAttributeError: If no device id is present in the config.
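+
+    Example (an illustrative sketch; ids and device names are placeholders):
+
+        >>> cfg = OmegaConf.create(
+        ...     {
+        ...         "own_device_id": "d0",
+        ...         "topology": {
+        ...             "devices": [{"device_id": "d0", "torch_device": "cuda:1"}]
+        ...         },
+        ...     }
+        ... )
+        >>> get_torch_device_id(cfg)
+        'cuda:1'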
+    """
+    own_device_id = cfg.own_device_id
+    try:
+        return next(
+            device_cfg.torch_device
+            for device_cfg in cfg.topology.devices
+            if device_cfg.device_id == own_device_id
+        )
+    except ConfigAttributeError:
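+        # torch_device (or device_id) is missing on a device entry, so fall
+        # back to the default device.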
+        return _default_torch_device()
+
+
+def _default_torch_device() -> str:
+    """
+    Returns the default torch device, depending on whether CUDA is available.
+    """
+    return "cuda:0" if torch.cuda.is_available() else "cpu"
diff --git a/edml/tests/helpers/config_helpers_test.py b/edml/tests/helpers/config_helpers_test.py
index 3f634fd95083d5402b4ab37b1f4ea7eae270bd91..41f18361a36eb62464202fd2b91c5bc6bb2043a5 100644
--- a/edml/tests/helpers/config_helpers_test.py
+++ b/edml/tests/helpers/config_helpers_test.py
@@ -1,4 +1,5 @@
 import unittest
+from unittest.mock import patch
 
 from omegaconf import DictConfig
 
@@ -7,6 +8,7 @@ from edml.helpers.config_helpers import (
     get_device_address_by_id,
     preprocess_config,
     get_device_index_by_id,
+    get_torch_device_id,
 )
 
 
@@ -46,3 +48,52 @@ class ConfigHelpersTest(unittest.TestCase):
         self.cfg.own_device_id = 1
         preprocess_config(self.cfg)
         self.assertEqual("d1", self.cfg.own_device_id)
+
+    def test_get_default_torch_device_if_cuda_available(self):
+        # The shared cfg configures no torch_device, so the default fallback applies.
+        with patch("torch.cuda.is_available", return_value=True):
+            self.assertEqual("cuda:0", get_torch_device_id(self.cfg))
+
+    def test_get_default_torch_device_if_cuda_not_available(self):
+        with patch("torch.cuda.is_available", return_value=False):
+            self.assertEqual("cpu", get_torch_device_id(self.cfg))
+
+
+class GetTorchDeviceIdTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.cfg = DictConfig(
+            {
+                "own_device_id": "d0",
+                "topology": {
+                    "devices": [
+                        {
+                            "device_id": "d0",
+                            "address": "localhost:50051",
+                            "torch_device": "my_torch_device1",
+                        },
+                        {
+                            "device_id": "d1",
+                            "address": "localhost:50052",
+                            "torch_device": "my_torch_device2",
+                        },
+                    ]
+                },
+                "num_devices": "${len:${topology.devices}}",
+            }
+        )
+
+    def test_get_torch_device_for_first_device(self):
+        self.assertEqual("my_torch_device1", get_torch_device_id(self.cfg))
+
+    def test_get_torch_device_for_second_device(self):
+        self.cfg.own_device_id = "d1"
+        self.assertEqual("my_torch_device2", get_torch_device_id(self.cfg))
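+
+    # A sketch of the error path documented in the docstring: an own_device_id
+    # that matches no topology entry exhausts the generator, so next() raises
+    # StopIteration ("d2" is a made-up id absent from the fixture above).
+    def test_get_torch_device_raises_for_unknown_device_id(self):
+        self.cfg.own_device_id = "d2"
+        with self.assertRaises(StopIteration):
+            get_torch_device_id(self.cfg)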