From bab2754089089c8575eb99f8a8cd24dae3d0e0f6 Mon Sep 17 00:00:00 2001 From: Tim Bauerle <tim.bauerle@rwth-aachen.de> Date: Mon, 1 Jul 2024 11:28:09 +0200 Subject: [PATCH] Assign gradients to correct cuda device --- edml/core/client.py | 4 ++-- edml/core/server.py | 10 +++++++--- edml/tests/controllers/split_controller_test.py | 1 + 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/edml/core/client.py b/edml/core/client.py index 5e99d87..75fb2bf 100644 --- a/edml/core/client.py +++ b/edml/core/client.py @@ -199,7 +199,7 @@ class DeviceClient: self.node_device.battery.update_flops( self._model_flops * len(batch_data) * 2 ) # 2x for backward pass - smashed_data.to(self._device) + gradients = gradients.to(self._device) smashed_data.backward(gradients) self._optimizer.step() @@ -280,7 +280,7 @@ class DeviceClient: self.node_device.battery.update_flops( self._model_flops * len(batch_data) * 2 ) # 2x for backward pass - smashed_data.to(self._device) + server_grad = server_grad.to(self._device) smashed_data.backward(server_grad) self._optimizer.step() diff --git a/edml/core/server.py b/edml/core/server.py index f8703de..a671ba9 100644 --- a/edml/core/server.py +++ b/edml/core/server.py @@ -205,7 +205,7 @@ class DeviceServer: def evaluate_batch(self, smashed_data, labels): """Evaluates the model on the given batch of data and labels""" with torch.no_grad(): - smashed_data.to(self._device) + smashed_data = smashed_data.to(self._device) self._set_model_flops(smashed_data) self.node_device.battery.update_flops(self._model_flops * len(smashed_data)) pred = self._model(smashed_data) @@ -259,8 +259,12 @@ class DeviceServer: print(f"\n\n\nBATCHES: {len(batches)}\n\n\n") # batches2 = [b for b in batches if b is not None] # print(f"\n\n\nBATCHES FILTERED: {len(batches)}\n\n\n") - server_batch = _concat_smashed_data([b[0] for b in batches]) - server_labels = _concat_smashed_data([b[1] for b in batches]) + server_batch = _concat_smashed_data( + [b[0].to(self._device) for b in batches] + ) + server_labels = _concat_smashed_data( + [b[1].to(self._device) for b in batches] + ) # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need # to split the gradients back into batch-sized tensors to average them before sending them to the client. diff --git a/edml/tests/controllers/split_controller_test.py b/edml/tests/controllers/split_controller_test.py index 807749d..8531a2d 100644 --- a/edml/tests/controllers/split_controller_test.py +++ b/edml/tests/controllers/split_controller_test.py @@ -23,6 +23,7 @@ class SplitControllerTest(unittest.TestCase): {"weights": 42}, {"weights": 43}, ModelMetricResultContainer(), + {"optimizer_state": 44}, DiagnosticMetricResultContainer(), ) -- GitLab