Skip to content
Snippets Groups Projects
Commit 9793d10d authored by Sven Michael Lechner's avatar Sven Michael Lechner
Browse files

fix(psl): gradient accumulation and sending

parent 026fe2c5
No related branches found
No related tags found
2 merge requests!18Merge in main,!16Fix PSL
......@@ -205,7 +205,6 @@ class DeviceClient:
) # 2x for backward pass
gradients = gradients.to(self._device)
smashed_data.backward(gradients)
print(smashed_data.grad)
# self._optimizer.step()
# We need to store a reference to the smashed_data to make it possible to finalize the training step.
......@@ -221,7 +220,14 @@ class DeviceClient:
)
metrics_container = DiagnosticMetricResultContainer([metric])
return metrics_container, smashed_data.grad
gradients = []
for param in self._model.parameters():
if param is not None:
gradients.append(param)
else:
gradients.append(torch.zeros_like(param))
return metrics_container, gradients
def get_approximated_num_batches(self) -> int:
return len(self._train_data)
......@@ -372,8 +378,9 @@ class DeviceClient:
)
return diagnostic_metric_results
def set_gradient_and_finalize_training(self, gradients: Any):
smashed_data = self._psl_cache["smashed_data"]
smashed_data.grad = gradients
def set_gradient_and_finalize_training(self, gradients: Any):
    """Apply externally computed per-parameter gradients and finish the step.

    Args:
        gradients: iterable of gradient tensors, ordered to match
            ``self._model.parameters()`` position-by-position.

    Side effects: writes ``param.grad`` for every model parameter, runs one
    optimizer step, and clears ``self._psl_cache``.
    """
    # Pair each parameter with its incoming gradient; move gradients onto
    # the client's device before assigning them.
    for parameter, incoming in zip(self._model.parameters(), gradients):
        parameter.grad = incoming.to(self._device)
    self._optimizer.step()
    # The cached smashed data is no longer needed once the step is applied.
    self._psl_cache = None
......@@ -15,6 +15,7 @@ from edml.core.client import DeviceClient
from edml.core.server import DeviceServer
from edml.generated import connection_pb2
from edml.generated.connection_pb2 import (
SetGradientsRequest,
SetWeightsRequest,
TrainBatchRequest,
TrainGlobalResponse,
......@@ -600,6 +601,13 @@ class RPCDeviceServicer(DeviceServicer):
return connection_pb2.SingleBatchBackwardResponse(
metrics=metrics_to_proto(metrics)
)
def SetGradientsAndFinalizeTrainingStep(
    self, request: SetGradientsRequest, context
):
    """gRPC handler: deserialize gradients and finalize the client's step.

    Decodes the tensor payload from the request and forwards it to the
    device client, which applies the gradients and runs the optimizer.
    Returns an empty protobuf message as acknowledgement.
    """
    incoming = proto_to_tensor(request.gradients.gradients)
    self.device.client.set_gradient_and_finalize_training(gradients=incoming)
    return connection_pb2.Empty()
class DeviceRequestDispatcher:
......@@ -1003,8 +1011,8 @@ class DeviceRequestDispatcher:
)
)
return (
response.metrics,
response.gradients,
proto_to_metrics(response.metrics),
proto_to_tensor(response.gradients.gradients),
)
except grpc.RpcError:
self._handle_rpc_error(device_id)
......
......@@ -384,8 +384,13 @@ class DeviceServer:
def _calculate_gradient_mean(gradients: List[Variable]) -> Variable:
"""Calculates the mean of a list of gradients."""
return torch.mean(torch.stack(gradients), dim=0)
num_devices = len(gradients)
weights = [1] * num_devices
return [
sum(gradients[i][j] * weights[i] for i in range(num_devices))
for j in range(len(gradients[0]))
]
def _concat_smashed_data(data: List[Any]) -> Any:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment