diff --git a/edml/core/client.py b/edml/core/client.py
index 9802a84b8227421d0dd7c5a1acf17d7968d81d1f..4791c219bb99705d134c4ebc6e023910e0485d78 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -205,7 +205,6 @@ class DeviceClient:
             )  # 2x for backward pass
             gradients = gradients.to(self._device)
         smashed_data.backward(gradients)
-        print(smashed_data.grad)
         # self._optimizer.step()
 
         # We need to store a reference to the smashed_data to make it possible to finalize the training step.
@@ -221,7 +220,14 @@ class DeviceClient:
         )
 
         metrics_container = DiagnosticMetricResultContainer([metric])
-        return metrics_container, smashed_data.grad
+        gradients = []
+        for param in self._model.parameters():
+            if param.grad is not None:
+                gradients.append(param.grad)
+            else:
+                gradients.append(torch.zeros_like(param))
+
+        return metrics_container, gradients
 
     def get_approximated_num_batches(self) -> int:
         return len(self._train_data)
@@ -372,8 +378,9 @@ class DeviceClient:
         )
         return diagnostic_metric_results
 
-    def set_gradient_and_finalize_training(self, gradients: Any):
-        smashed_data = self._psl_cache["smashed_data"]
-        smashed_data.grad = gradients
+    def set_gradient_and_finalize_training(self, gradients: Any):
+        for param, grad in zip(self._model.parameters(), gradients):
+            param.grad = grad.to(self._device)
+
         self._optimizer.step()
         self._psl_cache = None
diff --git a/edml/core/device.py b/edml/core/device.py
index 694775c1825a7cb51f120a4c743a2498057d66a1..ab03a450e3104aa35e3dc49c6a178fdc9e10faa1 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -15,6 +15,7 @@ from edml.core.client import DeviceClient
 from edml.core.server import DeviceServer
 from edml.generated import connection_pb2
 from edml.generated.connection_pb2 import (
+    SetGradientsRequest,
     SetWeightsRequest,
     TrainBatchRequest,
     TrainGlobalResponse,
@@ -600,6 +601,13 @@ class RPCDeviceServicer(DeviceServicer):
         return connection_pb2.SingleBatchBackwardResponse(
             metrics=metrics_to_proto(metrics)
         )
+
+    def SetGradientsAndFinalizeTrainingStep(
+        self, request: SetGradientsRequest, context
+    ):
+        gradients = proto_to_tensor(request.gradients.gradients)
+        self.device.client.set_gradient_and_finalize_training(gradients=gradients)
+        return connection_pb2.Empty()
 
 
 class DeviceRequestDispatcher:
@@ -1003,8 +1011,8 @@
                 )
             )
             return (
-                response.metrics,
-                response.gradients,
+                proto_to_metrics(response.metrics),
+                proto_to_tensor(response.gradients.gradients),
             )
         except grpc.RpcError:
             self._handle_rpc_error(device_id)
diff --git a/edml/core/server.py b/edml/core/server.py
index bc6ac5b2ede8ecbe86e40129b27d686cc1a3a33e..185a616c2bbcf0c41bca54d2e29391cd5ae53338 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -384,8 +384,14 @@
 
 
 def _calculate_gradient_mean(gradients: List[Variable]) -> Variable:
-    """Calculates the mean of a list of gradients."""
-    return torch.mean(torch.stack(gradients), dim=0)
+    """Calculates the per-parameter mean of a list of client gradients."""
+    num_devices = len(gradients)
+    weights = [1 / num_devices] * num_devices
+
+    return [
+        sum(gradients[i][j] * weights[i] for i in range(num_devices))
+        for j in range(len(gradients[0]))
+    ]
 
 
 def _concat_smashed_data(data: List[Any]) -> Any:
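
Illustrative sketch of the gradient flow introduced by this diff (not part of the patch or the edml codebase): each client now returns one gradient tensor per model parameter instead of the gradient of the smashed data, the server aggregates these lists element-wise across clients, and every client writes the aggregated gradients back into its parameters before calling optimizer.step(). The helper names (collect_parameter_gradients, average_parameter_gradients, apply_gradients), the toy nn.Linear models, and the SGD optimizer below are assumptions made purely for illustration.

# --- Illustrative sketch only; mirrors the technique in the diff, not edml's actual API ---
import torch
import torch.nn as nn
from typing import List


def collect_parameter_gradients(model: nn.Module) -> List[torch.Tensor]:
    # One gradient tensor per parameter; zeros where no gradient was produced.
    return [
        p.grad.detach().clone() if p.grad is not None else torch.zeros_like(p)
        for p in model.parameters()
    ]


def average_parameter_gradients(per_client: List[List[torch.Tensor]]) -> List[torch.Tensor]:
    # Element-wise mean over clients for each parameter position.
    num_clients = len(per_client)
    return [
        sum(per_client[i][j] for i in range(num_clients)) / num_clients
        for j in range(len(per_client[0]))
    ]


def apply_gradients(model: nn.Module, optimizer: torch.optim.Optimizer,
                    gradients: List[torch.Tensor]) -> None:
    # Write the aggregated gradients into the parameters and take an optimizer step.
    for param, grad in zip(model.parameters(), gradients):
        param.grad = grad
    optimizer.step()


# Two toy "clients" with identical architectures.
models = [nn.Linear(4, 2) for _ in range(2)]
optimizers = [torch.optim.SGD(m.parameters(), lr=0.1) for m in models]

per_client_gradients = []
for model in models:
    model.zero_grad()
    model(torch.randn(8, 4)).sum().backward()
    per_client_gradients.append(collect_parameter_gradients(model))

aggregated = average_parameter_gradients(per_client_gradients)
for model, optimizer in zip(models, optimizers):
    apply_gradients(model, optimizer, aggregated)

Aggregating per parameter position, as in the sketch, avoids stacking tensors of different shapes, which is why the server-side aggregation in the diff iterates over parameter indices instead of using torch.stack as before.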