From 29ada85fcf3b4aa33664c2bb436cd6b4d4198ca0 Mon Sep 17 00:00:00 2001 From: Tim Bauerle <tim.bauerle@rwth-aachen.de> Date: Wed, 3 Jul 2024 12:13:59 +0200 Subject: [PATCH] Set cuda device again after clearing cache --- edml/core/client.py | 2 ++ edml/core/server.py | 1 + 2 files changed, 3 insertions(+) diff --git a/edml/core/client.py b/edml/core/client.py index 75fb2bf..05404ed 100644 --- a/edml/core/client.py +++ b/edml/core/client.py @@ -136,6 +136,7 @@ class DeviceClient: @check_device_set() def train_single_batch(self, batch_index: int) -> Optional[SLTrainBatchResult]: + torch.cuda.set_device(self._device) # We have to re-initialize the data loader in the case that we do another epoch. if batch_index == 0: self._batchable_data_loader = iter(self._train_data) @@ -187,6 +188,7 @@ class DeviceClient: @check_device_set() def backward_single_batch(self, gradients) -> DiagnosticMetricResultContainer: + torch.cuda.set_device(self._device) batch_data, smashed_data, start_time, end_time = ( self._psl_cache["batch_data"], self._psl_cache["smashed_data"], diff --git a/edml/core/server.py b/edml/core/server.py index 3c9833e..885b22f 100644 --- a/edml/core/server.py +++ b/edml/core/server.py @@ -343,6 +343,7 @@ class DeviceServer: concatenated_client_gradients = None mean_tensor = None torch.cuda.empty_cache() + torch.cuda.set_device(self._device) return ( self.node_device.client.get_weights(), self.get_weights(), -- GitLab