From 457ca231ffa4abe04c7565041cef5926e8797afd Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Sat, 29 Jun 2024 15:45:04 +0200
Subject: [PATCH] Fixed cuda device assignment

---
 edml/core/client.py | 5 ++---
 edml/core/server.py | 5 ++---
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/edml/core/client.py b/edml/core/client.py
index 3526e03..9211cda 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -9,6 +9,7 @@ from omegaconf import DictConfig
 from torch import nn
 from torch.utils.data import DataLoader
 
+from edml.helpers.config_helpers import get_torch_device_id
 from edml.helpers.decorators import (
     check_device_set,
     simulate_latency_decorator,
@@ -65,9 +66,7 @@ class DeviceClient:
 
         self._train_data, self._val_data, self._test_data = train_dl, val_dl, test_dl
         self._batchable_data_loader = None
-        self._device = torch.device(
-            "cuda:0" if torch.cuda.is_available() else "cpu"
-        )  # cuda:0
+        self._device = torch.device(get_torch_device_id(cfg))
         self._model = model.to(self._device)
         self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
             cfg, self._model.parameters()
diff --git a/edml/core/server.py b/edml/core/server.py
index 61cedcd..f8703de 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -10,6 +10,7 @@ from colorama import init, Fore
 from torch import nn
 from torch.autograd import Variable
 
+from edml.helpers.config_helpers import get_torch_device_id
 from edml.helpers.decorators import check_device_set, simulate_latency_decorator
 from edml.helpers.executor import create_executor_with_threads
 from edml.helpers.flops import estimate_model_flops
@@ -37,9 +38,7 @@ class DeviceServer:
         latency_factor: float = 0.0,
     ):
         """Initializes the server with the given model, loss function, configuration and reference to its device."""
-        self._device = torch.device(
-            "cuda:0" if torch.cuda.is_available() else "cpu"
-        )  # cuda:0
+        self._device = torch.device(get_torch_device_id(cfg))
         self._model = model.to(self._device)
         self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
             cfg, self._model.parameters()
-- 
GitLab