Commit bab27540 authored by Tim Tobias Bauerle

Assign gradients to correct cuda device

parent d351e5ad
2 merge requests: !18 Merge in main, !14 Experiment configs
@@ -199,7 +199,7 @@ class DeviceClient:
         self.node_device.battery.update_flops(
             self._model_flops * len(batch_data) * 2
         )  # 2x for backward pass
-        smashed_data.to(self._device)
+        gradients = gradients.to(self._device)
         smashed_data.backward(gradients)
         self._optimizer.step()
@@ -280,7 +280,7 @@ class DeviceClient:
         self.node_device.battery.update_flops(
             self._model_flops * len(batch_data) * 2
         )  # 2x for backward pass
-        smashed_data.to(self._device)
+        server_grad = server_grad.to(self._device)
         smashed_data.backward(server_grad)
         self._optimizer.step()
...
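Both client-side hunks fix the same pitfall: torch.Tensor.to() returns a new tensor rather than moving the tensor in place, so the old smashed_data.to(self._device) call had no effect and the gradients coming back from the server could still sit on a different CUDA device than the client's activations. The following is a minimal, self-contained sketch of that behaviour; the model and tensor names are illustrative and not taken from the project.

import torch

# Minimal sketch (assumed names, not the project's code): Tensor.to() is not
# in-place, so its result has to be reassigned before calling backward().
device = "cuda:0" if torch.cuda.is_available() else "cpu"

client_model = torch.nn.Linear(8, 4).to(device)
batch_data = torch.randn(2, 8, device=device)
smashed_data = client_model(batch_data)          # lives on `device`

gradients = torch.ones_like(smashed_data).cpu()  # e.g. received from the server

gradients.to(device)              # result is discarded: `gradients` itself is unchanged
gradients = gradients.to(device)  # the fix: rebind the moved tensor

smashed_data.backward(gradients)  # devices now match, backward succeeds

The same reassignment pattern applies to server_grad in the second hunk and to smashed_data in evaluate_batch below.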
@@ -205,7 +205,7 @@ class DeviceServer:
     def evaluate_batch(self, smashed_data, labels):
         """Evaluates the model on the given batch of data and labels"""
         with torch.no_grad():
-            smashed_data.to(self._device)
+            smashed_data = smashed_data.to(self._device)
             self._set_model_flops(smashed_data)
             self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
             pred = self._model(smashed_data)
@@ -259,8 +259,12 @@ class DeviceServer:
         print(f"\n\n\nBATCHES: {len(batches)}\n\n\n")
         # batches2 = [b for b in batches if b is not None]
         # print(f"\n\n\nBATCHES FILTERED: {len(batches)}\n\n\n")
-        server_batch = _concat_smashed_data([b[0] for b in batches])
-        server_labels = _concat_smashed_data([b[1] for b in batches])
+        server_batch = _concat_smashed_data(
+            [b[0].to(self._device) for b in batches]
+        )
+        server_labels = _concat_smashed_data(
+            [b[1].to(self._device) for b in batches]
+        )
         # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need
         # to split the gradients back into batch-sized tensors to average them before sending them to the client.
...
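The server-side aggregation change follows the same idea: tensors collected from different clients may live on different devices, so each one is moved to the server's device before being concatenated. Below is a hedged sketch of that step, assuming _concat_smashed_data simply concatenates along the batch dimension (the diff does not show its definition).

import torch

# Assumed stand-in for the project's _concat_smashed_data helper:
# torch.cat requires every input tensor to be on the same device.
def _concat_smashed_data(tensors):
    return torch.cat(tensors, dim=0)

server_device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Two illustrative client batches of (smashed_data, labels), possibly on CPU.
batches = [
    (torch.randn(4, 16), torch.randint(0, 10, (4,))),
    (torch.randn(4, 16), torch.randint(0, 10, (4,))),
]

server_batch = _concat_smashed_data([b[0].to(server_device) for b in batches])
server_labels = _concat_smashed_data([b[1].to(server_device) for b in batches])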
@@ -23,6 +23,7 @@ class SplitControllerTest(unittest.TestCase):
             {"weights": 42},
             {"weights": 43},
             ModelMetricResultContainer(),
+            {"optimizer_state": 44},
             DiagnosticMetricResultContainer(),
         )
...