From d351e5ad86d13c685f4b783afc433a34c34f2951 Mon Sep 17 00:00:00 2001 From: Tim Bauerle <tim.bauerle@rwth-aachen.de> Date: Sat, 29 Jun 2024 19:24:37 +0200 Subject: [PATCH] Assign smashed data to correct cuda device --- edml/core/client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/edml/core/client.py b/edml/core/client.py index 9211cda..5e99d87 100644 --- a/edml/core/client.py +++ b/edml/core/client.py @@ -199,6 +199,7 @@ class DeviceClient: self.node_device.battery.update_flops( self._model_flops * len(batch_data) * 2 ) # 2x for backward pass + smashed_data.to(self._device) smashed_data.backward(gradients) self._optimizer.step() @@ -279,6 +280,7 @@ class DeviceClient: self.node_device.battery.update_flops( self._model_flops * len(batch_data) * 2 ) # 2x for backward pass + smashed_data.to(self._device) smashed_data.backward(server_grad) self._optimizer.step() -- GitLab