diff --git a/edml/core/client.py b/edml/core/client.py index 5e99d875eb155ae165b3dbe750c406f09a4e0e60..75fb2bfbb360eeca936ec4aad9ba69b3fb68f459 100644 --- a/edml/core/client.py +++ b/edml/core/client.py @@ -199,7 +199,7 @@ class DeviceClient: self.node_device.battery.update_flops( self._model_flops * len(batch_data) * 2 ) # 2x for backward pass - smashed_data.to(self._device) + gradients = gradients.to(self._device) smashed_data.backward(gradients) self._optimizer.step() @@ -280,7 +280,7 @@ class DeviceClient: self.node_device.battery.update_flops( self._model_flops * len(batch_data) * 2 ) # 2x for backward pass - smashed_data.to(self._device) + server_grad = server_grad.to(self._device) smashed_data.backward(server_grad) self._optimizer.step() diff --git a/edml/core/server.py b/edml/core/server.py index f8703de677bd88a36abbe0aef0f6cd32581e1499..a671ba98e404c6981b35eb90d3e5bb1c17084304 100644 --- a/edml/core/server.py +++ b/edml/core/server.py @@ -205,7 +205,7 @@ class DeviceServer: def evaluate_batch(self, smashed_data, labels): """Evaluates the model on the given batch of data and labels""" with torch.no_grad(): - smashed_data.to(self._device) + smashed_data = smashed_data.to(self._device) self._set_model_flops(smashed_data) self.node_device.battery.update_flops(self._model_flops * len(smashed_data)) pred = self._model(smashed_data) @@ -259,8 +259,12 @@ class DeviceServer: print(f"\n\n\nBATCHES: {len(batches)}\n\n\n") # batches2 = [b for b in batches if b is not None] # print(f"\n\n\nBATCHES FILTERED: {len(batches)}\n\n\n") - server_batch = _concat_smashed_data([b[0] for b in batches]) - server_labels = _concat_smashed_data([b[1] for b in batches]) + server_batch = _concat_smashed_data( + [b[0].to(self._device) for b in batches] + ) + server_labels = _concat_smashed_data( + [b[1].to(self._device) for b in batches] + ) # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need # to split the gradients back into batch-sized tensors to average them before sending them to the client. diff --git a/edml/tests/controllers/split_controller_test.py b/edml/tests/controllers/split_controller_test.py index 807749d72d338f3c848837c8d9bd8b359966e08b..8531a2d1fe0aa146494d1ffb45317739a1c01dc1 100644 --- a/edml/tests/controllers/split_controller_test.py +++ b/edml/tests/controllers/split_controller_test.py @@ -23,6 +23,7 @@ class SplitControllerTest(unittest.TestCase): {"weights": 42}, {"weights": 43}, ModelMetricResultContainer(), + {"optimizer_state": 44}, DiagnosticMetricResultContainer(), )