Skip to content
Snippets Groups Projects
Commit 2c6dd633 authored by Tim Tobias Bauerle's avatar Tim Tobias Bauerle
Browse files

Call LR scheduler before training in case a device misses out a round

parent d0c02fea
Branches
No related tags found
1 merge request: !22 "Fix lr schedule"
......@@ -258,6 +258,11 @@ class DeviceClient:
that, this approach does not require to deduce server batch processing time after a "traditional"
measurement.
"""
if self._lr_scheduler is not None:
if round_no != -1:
self._lr_scheduler.step(round_no)
else:
self._lr_scheduler.step()
client_train_start_time = time.time()
server_train_batch_times = (
[]
......@@ -295,12 +300,6 @@ class DeviceClient:
smashed_data.backward(server_grad)
self._optimizer.step()
if self._lr_scheduler is not None:
if round_no != -1:
self._lr_scheduler.step(round_no)
else:
self._lr_scheduler.step()
client_train_time = (
time.time() - client_train_start_time - sum(server_train_batch_times)
)
......
......@@ -90,6 +90,11 @@ class DeviceServer:
if optimizer_state is not None:
self._optimizer.load_state_dict(optimizer_state)
for epoch in range(epochs):
if self._lr_scheduler is not None:
if round_no != -1:
self._lr_scheduler.step(round_no + epoch)
else:
self._lr_scheduler.step()
for device_id in devices:
print(
f"Train epoch {epoch} on client {device_id} with server {self.node_device.device_id}"
......@@ -120,11 +125,6 @@ class DeviceServer:
metrics.add_results(train_metrics)
metrics.add_results(val_metrics)
if self._lr_scheduler is not None:
if round_no != -1:
self._lr_scheduler.step(round_no + epoch)
else:
self._lr_scheduler.step()
return (
client_weights,
self.get_weights(),
......@@ -241,6 +241,12 @@ class DeviceServer:
if optimizer_state is not None:
self._optimizer.load_state_dict(optimizer_state)
if self._lr_scheduler is not None:
if round_no != -1:
self._lr_scheduler.step(round_no + 1) # epoch=1
else:
self._lr_scheduler.step()
num_threads = len(clients)
executor = create_executor_with_threads(num_threads)
......@@ -346,11 +352,6 @@ class DeviceServer:
model_metrics.add_results(val_metrics)
optimizer_state = self._optimizer.state_dict()
if self._lr_scheduler is not None:
if round_no != -1:
self._lr_scheduler.step(round_no + 1) # epoch=1
else:
self._lr_scheduler.step()
# delete references and free GPU memory manually
server_batch = None
server_labels = None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment