diff --git a/edml/core/client.py b/edml/core/client.py
index b3c3e7491e718cf98c9ccde598e00fd4d2054996..ace69aca8f663b895b5dd8a84393e1bf8c563591 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -258,6 +258,11 @@ class DeviceClient:
             that, this approach does not require to deduce server batch processing time after a "traditional"
             measurement.
         """
+        if self._lr_scheduler is not None:
+            if round_no != -1:
+                self._lr_scheduler.step(round_no)
+            else:
+                self._lr_scheduler.step()
         client_train_start_time = time.time()
         server_train_batch_times = (
             []
@@ -295,12 +300,6 @@ class DeviceClient:
                 smashed_data.backward(server_grad)
                 self._optimizer.step()
 
-        if self._lr_scheduler is not None:
-            if round_no != -1:
-                self._lr_scheduler.step(round_no)
-            else:
-                self._lr_scheduler.step()
-
         client_train_time = (
             time.time() - client_train_start_time - sum(server_train_batch_times)
         )
diff --git a/edml/core/server.py b/edml/core/server.py
index 5e2f8235de476c5a998f9c8c040b462ac9d3b198..c0e60fb928ddbdaa2022ff08c5fb6192aa36cf10 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -90,6 +90,11 @@ class DeviceServer:
         if optimizer_state is not None:
             self._optimizer.load_state_dict(optimizer_state)
         for epoch in range(epochs):
+            if self._lr_scheduler is not None:
+                if round_no != -1:
+                    self._lr_scheduler.step(round_no + epoch)
+                else:
+                    self._lr_scheduler.step()
             for device_id in devices:
                 print(
                     f"Train epoch {epoch} on client {device_id} with server {self.node_device.device_id}"
@@ -120,11 +125,6 @@ class DeviceServer:
 
                     metrics.add_results(train_metrics)
                     metrics.add_results(val_metrics)
-            if self._lr_scheduler is not None:
-                if round_no != -1:
-                    self._lr_scheduler.step(round_no + epoch)
-                else:
-                    self._lr_scheduler.step()
         return (
             client_weights,
             self.get_weights(),
@@ -241,6 +241,12 @@ class DeviceServer:
         if optimizer_state is not None:
             self._optimizer.load_state_dict(optimizer_state)
 
+        if self._lr_scheduler is not None:
+            if round_no != -1:
+                self._lr_scheduler.step(round_no + 1)  # epoch=1
+            else:
+                self._lr_scheduler.step()
+
         num_threads = len(clients)
         executor = create_executor_with_threads(num_threads)
 
@@ -346,11 +352,6 @@ class DeviceServer:
             model_metrics.add_results(val_metrics)
 
         optimizer_state = self._optimizer.state_dict()
-        if self._lr_scheduler is not None:
-            if round_no != -1:
-                self._lr_scheduler.step(round_no + 1)  # epoch=1
-            else:
-                self._lr_scheduler.step()
         # delete references and free GPU memory manually
         server_batch = None
         server_labels = None