diff --git a/edml/core/client.py b/edml/core/client.py
index 75fb2bfbb360eeca936ec4aad9ba69b3fb68f459..05404ed95f94c91b7392a2139f42f19678172fe4 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -136,6 +136,7 @@ class DeviceClient:
 
     @check_device_set()
     def train_single_batch(self, batch_index: int) -> Optional[SLTrainBatchResult]:
+        torch.cuda.set_device(self._device)
         # We have to re-initialize the data loader in the case that we do another epoch.
         if batch_index == 0:
             self._batchable_data_loader = iter(self._train_data)
@@ -187,6 +188,7 @@ class DeviceClient:
 
     @check_device_set()
     def backward_single_batch(self, gradients) -> DiagnosticMetricResultContainer:
+        torch.cuda.set_device(self._device)
        batch_data, smashed_data, start_time, end_time = (
            self._psl_cache["batch_data"],
            self._psl_cache["smashed_data"],
diff --git a/edml/core/server.py b/edml/core/server.py
index 3c9833e5a965a7afd1d3061719c0c50ffdf536ce..885b22fb0d2aafe93ed2da750f43428d9350dbda 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -343,6 +343,7 @@ class DeviceServer:
         concatenated_client_gradients = None
         mean_tensor = None
         torch.cuda.empty_cache()
+        torch.cuda.set_device(self._device)
         return (
             self.node_device.client.get_weights(),
             self.get_weights(),
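
Reviewer note (not part of the patch): `torch.cuda.set_device` pins the calling thread's *current* CUDA device, so CUDA operations that do not name a device explicitly (e.g. `Tensor.cuda()` with no argument, or `device="cuda"` without an index) target the GPU the model actually lives on rather than `cuda:0`, which is the default current device on a fresh worker/RPC thread. The sketch below is a minimal, hypothetical illustration of that behavior; `device` and `run_on` are stand-ins and are not taken from the edml codebase.

```python
# Minimal sketch, assuming a host where PyTorch with CUDA is available.
# `device` plays the role of the `self._device` used in the patch above.
import torch


def run_on(device: torch.device) -> torch.Tensor:
    """Pin the calling thread's current CUDA device before doing work."""
    if device.type == "cuda":
        # A fresh thread's current CUDA device defaults to cuda:0, so without
        # this call the device-less CUDA ops below could land on the wrong GPU.
        torch.cuda.set_device(device)
    x = torch.randn(4, 4, device=device)
    # `.cuda()` with no index moves the tensor to the *current* CUDA device,
    # which now matches `device`, so both matmul operands live on the same GPU.
    y = torch.randn(4, 4).cuda() if device.type == "cuda" else torch.randn(4, 4)
    return x @ y


if __name__ == "__main__":
    dev = torch.device(
        "cuda:1" if torch.cuda.device_count() > 1
        else "cuda:0" if torch.cuda.is_available()
        else "cpu"
    )
    print(run_on(dev).device)
```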