Skip to content
Snippets Groups Projects
Commit 29ada85f authored by Tim Tobias Bauerle's avatar Tim Tobias Bauerle
Browse files

Set cuda device again after clearing cache

parent 6abb3d39
Branches
No related tags found
2 merge requests!18Merge in main,!14Experiment configs
......@@ -136,6 +136,7 @@ class DeviceClient:
@check_device_set()
def train_single_batch(self, batch_index: int) -> Optional[SLTrainBatchResult]:
torch.cuda.set_device(self._device)
# We have to re-initialize the data loader in the case that we do another epoch.
if batch_index == 0:
self._batchable_data_loader = iter(self._train_data)
......@@ -187,6 +188,7 @@ class DeviceClient:
@check_device_set()
def backward_single_batch(self, gradients) -> DiagnosticMetricResultContainer:
torch.cuda.set_device(self._device)
batch_data, smashed_data, start_time, end_time = (
self._psl_cache["batch_data"],
self._psl_cache["smashed_data"],
......
......@@ -343,6 +343,7 @@ class DeviceServer:
concatenated_client_gradients = None
mean_tensor = None
torch.cuda.empty_cache()
torch.cuda.set_device(self._device)
return (
self.node_device.client.get_weights(),
self.get_weights(),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment