Skip to content
Snippets Groups Projects
Commit bab27540 authored by Tim Tobias Bauerle's avatar Tim Tobias Bauerle
Browse files

Assign gradients to correct cuda device

parent d351e5ad
Branches
No related tags found
2 merge requests!18Merge in main,!14Experiment configs
......@@ -199,7 +199,7 @@ class DeviceClient:
self.node_device.battery.update_flops(
self._model_flops * len(batch_data) * 2
) # 2x for backward pass
smashed_data.to(self._device)
gradients = gradients.to(self._device)
smashed_data.backward(gradients)
self._optimizer.step()
......@@ -280,7 +280,7 @@ class DeviceClient:
self.node_device.battery.update_flops(
self._model_flops * len(batch_data) * 2
) # 2x for backward pass
smashed_data.to(self._device)
server_grad = server_grad.to(self._device)
smashed_data.backward(server_grad)
self._optimizer.step()
......
......@@ -205,7 +205,7 @@ class DeviceServer:
def evaluate_batch(self, smashed_data, labels):
"""Evaluates the model on the given batch of data and labels"""
with torch.no_grad():
smashed_data.to(self._device)
smashed_data = smashed_data.to(self._device)
self._set_model_flops(smashed_data)
self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
pred = self._model(smashed_data)
......@@ -259,8 +259,12 @@ class DeviceServer:
print(f"\n\n\nBATCHES: {len(batches)}\n\n\n")
# batches2 = [b for b in batches if b is not None]
# print(f"\n\n\nBATCHES FILTERED: {len(batches)}\n\n\n")
server_batch = _concat_smashed_data([b[0] for b in batches])
server_labels = _concat_smashed_data([b[1] for b in batches])
server_batch = _concat_smashed_data(
[b[0].to(self._device) for b in batches]
)
server_labels = _concat_smashed_data(
[b[1].to(self._device) for b in batches]
)
# Train the part on the server. Then send the gradients to each client, continuing the calculation. We need
# to split the gradients back into batch-sized tensors to average them before sending them to the client.
......
......@@ -23,6 +23,7 @@ class SplitControllerTest(unittest.TestCase):
{"weights": 42},
{"weights": 43},
ModelMetricResultContainer(),
{"optimizer_state": 44},
DiagnosticMetricResultContainer(),
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment