From 29ada85fcf3b4aa33664c2bb436cd6b4d4198ca0 Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Wed, 3 Jul 2024 12:13:59 +0200
Subject: [PATCH] Set cuda device again after clearing cache

---
 edml/core/client.py | 2 ++
 edml/core/server.py | 1 +
 2 files changed, 3 insertions(+)

diff --git a/edml/core/client.py b/edml/core/client.py
index 75fb2bf..05404ed 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -136,6 +136,7 @@ class DeviceClient:
 
     @check_device_set()
     def train_single_batch(self, batch_index: int) -> Optional[SLTrainBatchResult]:
+        torch.cuda.set_device(self._device)
         # We have to re-initialize the data loader in the case that we do another epoch.
         if batch_index == 0:
             self._batchable_data_loader = iter(self._train_data)
@@ -187,6 +188,7 @@ class DeviceClient:
 
     @check_device_set()
     def backward_single_batch(self, gradients) -> DiagnosticMetricResultContainer:
+        torch.cuda.set_device(self._device)
         batch_data, smashed_data, start_time, end_time = (
             self._psl_cache["batch_data"],
             self._psl_cache["smashed_data"],
diff --git a/edml/core/server.py b/edml/core/server.py
index 3c9833e..885b22f 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -343,6 +343,7 @@ class DeviceServer:
         concatenated_client_gradients = None
         mean_tensor = None
         torch.cuda.empty_cache()
+        torch.cuda.set_device(self._device)
         return (
             self.node_device.client.get_weights(),
             self.get_weights(),
-- 
GitLab