diff --git a/edml/core/client.py b/edml/core/client.py
index b3c3e7491e718cf98c9ccde598e00fd4d2054996..ace69aca8f663b895b5dd8a84393e1bf8c563591 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -258,6 +258,11 @@ class DeviceClient:
         that, this approach does not require to deduce server batch processing time after
         a "traditional" measurement.
         """
+        if self._lr_scheduler is not None:
+            if round_no != -1:
+                self._lr_scheduler.step(round_no)
+            else:
+                self._lr_scheduler.step()
         client_train_start_time = time.time()
         server_train_batch_times = (
             []
@@ -295,12 +300,6 @@ class DeviceClient:
             smashed_data.backward(server_grad)
             self._optimizer.step()
 
-        if self._lr_scheduler is not None:
-            if round_no != -1:
-                self._lr_scheduler.step(round_no)
-            else:
-                self._lr_scheduler.step()
-
         client_train_time = (
             time.time() - client_train_start_time - sum(server_train_batch_times)
         )
diff --git a/edml/core/server.py b/edml/core/server.py
index 5e2f8235de476c5a998f9c8c040b462ac9d3b198..c0e60fb928ddbdaa2022ff08c5fb6192aa36cf10 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -90,6 +90,11 @@ class DeviceServer:
         if optimizer_state is not None:
             self._optimizer.load_state_dict(optimizer_state)
         for epoch in range(epochs):
+            if self._lr_scheduler is not None:
+                if round_no != -1:
+                    self._lr_scheduler.step(round_no + epoch)
+                else:
+                    self._lr_scheduler.step()
             for device_id in devices:
                 print(
                     f"Train epoch {epoch} on client {device_id} with server {self.node_device.device_id}"
@@ -120,11 +125,6 @@ class DeviceServer:
             metrics.add_results(train_metrics)
             metrics.add_results(val_metrics)
 
-            if self._lr_scheduler is not None:
-                if round_no != -1:
-                    self._lr_scheduler.step(round_no + epoch)
-                else:
-                    self._lr_scheduler.step()
         return (
             client_weights,
             self.get_weights(),
@@ -241,6 +241,12 @@ class DeviceServer:
         if optimizer_state is not None:
             self._optimizer.load_state_dict(optimizer_state)
 
+        if self._lr_scheduler is not None:
+            if round_no != -1:
+                self._lr_scheduler.step(round_no + 1)  # epoch=1
+            else:
+                self._lr_scheduler.step()
+
         num_threads = len(clients)
         executor = create_executor_with_threads(num_threads)
 
@@ -346,11 +352,6 @@ class DeviceServer:
         model_metrics.add_results(val_metrics)
         optimizer_state = self._optimizer.state_dict()
 
-        if self._lr_scheduler is not None:
-            if round_no != -1:
-                self._lr_scheduler.step(round_no + 1)  # epoch=1
-            else:
-                self._lr_scheduler.step()
         # delete references and free GPU memory manually
         server_batch = None
         server_labels = None