From d351e5ad86d13c685f4b783afc433a34c34f2951 Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Sat, 29 Jun 2024 19:24:37 +0200
Subject: [PATCH] Assign smashed data to correct cuda device

---
 edml/core/client.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/edml/core/client.py b/edml/core/client.py
index 9211cda..5e99d87 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -199,6 +199,7 @@ class DeviceClient:
         self.node_device.battery.update_flops(
             self._model_flops * len(batch_data) * 2
         )  # 2x for backward pass
+        smashed_data.to(self._device)
         smashed_data.backward(gradients)
         self._optimizer.step()
 
@@ -279,6 +280,7 @@ class DeviceClient:
                 self.node_device.battery.update_flops(
                     self._model_flops * len(batch_data) * 2
                 )  # 2x for backward pass
+                smashed_data.to(self._device)
                 smashed_data.backward(server_grad)
                 self._optimizer.step()
 
-- 
GitLab