diff --git a/edml/core/client.py b/edml/core/client.py
index 75fb2bfbb360eeca936ec4aad9ba69b3fb68f459..05404ed95f94c91b7392a2139f42f19678172fe4 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -136,6 +136,7 @@ class DeviceClient:
 
     @check_device_set()
     def train_single_batch(self, batch_index: int) -> Optional[SLTrainBatchResult]:
+        torch.cuda.set_device(self._device)
         # We have to re-initialize the data loader in the case that we do another epoch.
         if batch_index == 0:
             self._batchable_data_loader = iter(self._train_data)
@@ -187,6 +188,7 @@ class DeviceClient:
 
     @check_device_set()
     def backward_single_batch(self, gradients) -> DiagnosticMetricResultContainer:
+        torch.cuda.set_device(self._device)
         batch_data, smashed_data, start_time, end_time = (
             self._psl_cache["batch_data"],
             self._psl_cache["smashed_data"],
diff --git a/edml/core/server.py b/edml/core/server.py
index 3c9833e5a965a7afd1d3061719c0c50ffdf536ce..885b22fb0d2aafe93ed2da750f43428d9350dbda 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -343,6 +343,7 @@ class DeviceServer:
         concatenated_client_gradients = None
         mean_tensor = None
         torch.cuda.empty_cache()
+        torch.cuda.set_device(self._device)
         return (
             self.node_device.client.get_weights(),
             self.get_weights(),
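
Context for the change above, with a standalone sketch: torch.cuda.set_device makes the given GPU the current CUDA device for the calling thread, so later allocations that only say "cuda" land on that GPU instead of the process-wide default cuda:0. The snippet below illustrates that behavior in isolation; the pin_device helper and the CPU-device guard are illustrative assumptions and are not part of the patch, which calls torch.cuda.set_device(self._device) directly on the client and server paths.

    # Standalone sketch of the device-pinning behavior (not part of the patch).
    import torch


    def pin_device(device: torch.device) -> None:
        """Make `device` the current CUDA device for the calling thread.

        The guard is an assumption for illustration: torch.cuda.set_device only
        accepts CUDA devices, so CPU-only runs simply skip the call.
        """
        if device.type == "cuda":
            torch.cuda.set_device(device)


    if __name__ == "__main__" and torch.cuda.is_available():
        # Pin the last visible GPU, then allocate with a bare "cuda" spec.
        device = torch.device("cuda", torch.cuda.device_count() - 1)
        pin_device(device)
        x = torch.empty(2, 2, device="cuda")
        assert x.device == device  # lands on the pinned GPU, not cuda:0

Because the current CUDA device is thread-local, pinning it at the start of each RPC-style entry point (as the patch does in train_single_batch, backward_single_batch, and the server path) keeps work on the node's own GPU even when the call arrives on a fresh worker thread.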