diff --git a/edml/core/client.py b/edml/core/client.py
index 4791c219bb99705d134c4ebc6e023910e0485d78..d7ff6723297bc45a9f2b5fd3e3bb002bb3de0111 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -223,7 +223,7 @@ class DeviceClient:
         gradients = []
         for param in self._model.parameters():
-            if param is not None:
-                gradients.append(param)
+            if param.grad is not None:
+                gradients.append(param.grad)
             else:
                 gradients.append(torch.zeros_like(param))
 
diff --git a/edml/core/device.py b/edml/core/device.py
index ab03a450e3104aa35e3dc49c6a178fdc9e10faa1..2ba6833226b537118f8c02399a0c6c4a4967e47f 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -597,9 +597,10 @@ class RPCDeviceServicer(DeviceServicer):
     ):
         gradients = proto_to_tensor(request.gradients.gradients)
-        metrics = self.device.client.backward_single_batch(gradients=gradients)
+        metrics, gradients = self.device.client.backward_single_batch(gradients=gradients)
         return connection_pb2.SingleBatchBackwardResponse(
-            metrics=metrics_to_proto(metrics)
+            metrics=metrics_to_proto(metrics),
+            gradients=Gradients(gradients=tensor_to_proto(gradients)),
         )
 
     def SetGradientsAndFinalizeTrainingStep(
@@ -607,7 +608,7 @@ class RPCDeviceServicer(DeviceServicer):
     ):
         gradients = proto_to_tensor(request.gradients.gradients)
         self.device.client.set_gradient_and_finalize_training(gradients=gradients)
-        return connection_pb2.Empty()
+        return Empty()
 
 
 class DeviceRequestDispatcher:
@@ -1011,7 +1012,7 @@ class DeviceRequestDispatcher:
                 )
             )
             return (
-                proto_to_metrics(response.metrics),
+                None,
                 proto_to_tensor(response.gradients.gradients),
             )
         except grpc.RpcError:
diff --git a/edml/core/server.py b/edml/core/server.py
index 185a616c2bbcf0c41bca54d2e29391cd5ae53338..b78793678f6f77bc84354c93f9d32064eee657c1 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -383,10 +383,15 @@ class DeviceServer:
     )
 
 
-def _calculate_gradient_mean(gradients: List[Variable]) -> Variable:
+def _calculate_gradient_mean(gradients: List[Variable], device: str = "cpu") -> Variable:
     num_devices = len(gradients)
     weights = [1] * num_devices
+    # We need to move all tensors to the same device to do calculations.
+    for i, client_gradients in enumerate(gradients):
+        for j, grad in enumerate(client_gradients):
+            gradients[i][j] = grad.to(device)
+
     return [
         sum(gradients[i][j] * weights[i] for i in range(num_devices))
         for j in range(len(gradients[0]))