From bab2754089089c8575eb99f8a8cd24dae3d0e0f6 Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Mon, 1 Jul 2024 11:28:09 +0200
Subject: [PATCH] Assign gradients to correct CUDA device

---
 edml/core/client.py                             |  4 ++--
 edml/core/server.py                             | 10 +++++++---
 edml/tests/controllers/split_controller_test.py |  1 +
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/edml/core/client.py b/edml/core/client.py
index 5e99d87..75fb2bf 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -199,7 +199,7 @@ class DeviceClient:
         self.node_device.battery.update_flops(
             self._model_flops * len(batch_data) * 2
         )  # 2x for backward pass
-        smashed_data.to(self._device)
+        gradients = gradients.to(self._device)
         smashed_data.backward(gradients)
         self._optimizer.step()
 
@@ -280,7 +280,7 @@ class DeviceClient:
                 self.node_device.battery.update_flops(
                     self._model_flops * len(batch_data) * 2
                 )  # 2x for backward pass
-                smashed_data.to(self._device)
+                server_grad = server_grad.to(self._device)
                 smashed_data.backward(server_grad)
                 self._optimizer.step()
 
diff --git a/edml/core/server.py b/edml/core/server.py
index f8703de..a671ba9 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -205,7 +205,7 @@ class DeviceServer:
     def evaluate_batch(self, smashed_data, labels):
         """Evaluates the model on the given batch of data and labels"""
         with torch.no_grad():
-            smashed_data.to(self._device)
+            smashed_data = smashed_data.to(self._device)
             self._set_model_flops(smashed_data)
             self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
             pred = self._model(smashed_data)
@@ -259,8 +259,12 @@ class DeviceServer:
             print(f"\n\n\nBATCHES: {len(batches)}\n\n\n")
             # batches2 = [b for b in batches if b is not None]
             # print(f"\n\n\nBATCHES FILTERED: {len(batches)}\n\n\n")
-            server_batch = _concat_smashed_data([b[0] for b in batches])
-            server_labels = _concat_smashed_data([b[1] for b in batches])
+            server_batch = _concat_smashed_data(
+                [b[0].to(self._device) for b in batches]
+            )
+            server_labels = _concat_smashed_data(
+                [b[1].to(self._device) for b in batches]
+            )
 
             # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need
             # to split the gradients back into batch-sized tensors to average them before sending them to the client.
diff --git a/edml/tests/controllers/split_controller_test.py b/edml/tests/controllers/split_controller_test.py
index 807749d..8531a2d 100644
--- a/edml/tests/controllers/split_controller_test.py
+++ b/edml/tests/controllers/split_controller_test.py
@@ -23,6 +23,7 @@ class SplitControllerTest(unittest.TestCase):
             {"weights": 42},
             {"weights": 43},
             ModelMetricResultContainer(),
+            {"optimizer_state": 44},
             DiagnosticMetricResultContainer(),
         )
 
-- 
GitLab