diff --git a/edml/core/client.py b/edml/core/client.py index 3526e03c5e2c3c9852d556d385a8a303514082ea..9211cda8617085bc3189b12904d30d24f51090ba 100644 --- a/edml/core/client.py +++ b/edml/core/client.py @@ -9,6 +9,7 @@ from omegaconf import DictConfig from torch import nn from torch.utils.data import DataLoader +from edml.helpers.config_helpers import get_torch_device_id from edml.helpers.decorators import ( check_device_set, simulate_latency_decorator, @@ -65,9 +66,7 @@ class DeviceClient: self._train_data, self._val_data, self._test_data = train_dl, val_dl, test_dl self._batchable_data_loader = None - self._device = torch.device( - "cuda:0" if torch.cuda.is_available() else "cpu" - ) # cuda:0 + self._device = torch.device(get_torch_device_id(cfg)) self._model = model.to(self._device) self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler( cfg, self._model.parameters() diff --git a/edml/core/server.py b/edml/core/server.py index 61cedcde19389a446e85da31299eacd5711fe0cf..f8703de677bd88a36abbe0aef0f6cd32581e1499 100644 --- a/edml/core/server.py +++ b/edml/core/server.py @@ -10,6 +10,7 @@ from colorama import init, Fore from torch import nn from torch.autograd import Variable +from edml.helpers.config_helpers import get_torch_device_id from edml.helpers.decorators import check_device_set, simulate_latency_decorator from edml.helpers.executor import create_executor_with_threads from edml.helpers.flops import estimate_model_flops @@ -37,9 +38,7 @@ class DeviceServer: latency_factor: float = 0.0, ): """Initializes the server with the given model, loss function, configuration and reference to its device.""" - self._device = torch.device( - "cuda:0" if torch.cuda.is_available() else "cpu" - ) # cuda:0 + self._device = torch.device(get_torch_device_id(cfg)) self._model = model.to(self._device) self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler( cfg, self._model.parameters()