diff --git a/edml/controllers/scheduler/smart.py b/edml/controllers/scheduler/smart.py
index cf66bd15eda55e1c6387b09fd13c43c67b389557..d2de92b02e96f8d39c5250504eb123e54b31c8ec 100644
--- a/edml/controllers/scheduler/smart.py
+++ b/edml/controllers/scheduler/smart.py
@@ -98,6 +98,9 @@ class SmartNextServerScheduler(NextServerScheduler):
         global_params.smashed_data_size = optimization_metrics["smashed_data_size"]
         global_params.client_weights_size = optimization_metrics["client_weight_size"]
         global_params.server_weights_size = optimization_metrics["server_weight_size"]
+        global_params.optimizer_state_size = optimization_metrics[
+            "optimizer_state_size"
+        ]
         print(f"global params: {vars(global_params)}")
         print(f"device params: {[vars(device) for device in device_params_list]}")
         server_choice_optimizer = ServerChoiceOptimizer(
diff --git a/edml/controllers/strategy_optimization.py b/edml/controllers/strategy_optimization.py
index 76572e94d392f9a6922f6bd51d60bde949797086..022f2049f5695b25abe445fa242fbc564efa2ac1 100644
--- a/edml/controllers/strategy_optimization.py
+++ b/edml/controllers/strategy_optimization.py
@@ -37,6 +37,7 @@ class GlobalParams:
         batch_size=None,
         client_weights_size=None,
         server_weights_size=None,
+        optimizer_state_size=None,
         client_norm_fw_time=None,
         client_norm_bw_time=None,
         server_norm_fw_time=None,
@@ -52,6 +53,7 @@ class GlobalParams:
         self.batch_size = batch_size
         self.client_weights_size = client_weights_size
         self.server_weights_size = server_weights_size
+        self.optimizer_state_size = optimizer_state_size
         # metrics per sample
         self.label_size = label_size
         self.gradient_size = gradient_size
@@ -75,6 +77,7 @@ class GlobalParams:
             and self.cost_per_flop is not None
             and self.client_model_flops is not None
             and self.server_model_flops is not None
+            and self.optimizer_state_size is not None
             and self.smashed_data_size is not None
             and self.label_size is not None
             and self.gradient_size is not None
@@ -172,9 +175,9 @@ class ServerChoiceOptimizer:
             total_bytes += self.global_params.client_weights_size * (
                 self._num_devices() - 1
             )  # setting client weights before training
-            total_bytes += (
-                self.global_params.server_weights_size
-            )  # return server weights in the end
+            # return server weights and optimizer state in the end
+            total_bytes += self.global_params.server_weights_size
+            total_bytes += self.global_params.optimizer_state_size
             # exclude server device's own gradients
             total_bytes -= self.global_params.gradient_size * device.train_samples
         else:
@@ -202,9 +205,9 @@ class ServerChoiceOptimizer:
             total_bytes += self.global_params.client_weights_size * (
                 self._num_devices() - 1
             )  # clients return their weights to server
-            total_bytes += (
-                self.global_params.server_weights_size
-            )  # server weights set once in the beginning
+            # server weights and optimizer state set once in the beginning
+            total_bytes += self.global_params.server_weights_size
+            total_bytes += self.global_params.optimizer_state_size
             # exclude server device's own data and labels
             total_bytes -= self.global_params.label_size * (
                 device.train_samples + device.validation_samples
diff --git a/edml/helpers/metrics.py b/edml/helpers/metrics.py
index 2f7ef4021a087a7bea8133bf519f4d0f87abb843..7dc57e7b1757da656fe3ba31f11fdbda2d469621 100644
--- a/edml/helpers/metrics.py
+++ b/edml/helpers/metrics.py
@@ -431,6 +431,9 @@ def compute_metrics_for_optimization(
     result["server_weight_size"] = max(
         [metric.value for metric in raw_metrics[("size", "server_weights")]]
     )  # all values should be equal anyway
+    result["optimizer_state_size"] = max(
+        [metric.value for metric in raw_metrics[("size", "optimizer_state")]]
+    )  # all values should be equal anyway
 
     # time
     avg_server_model_train_time_per_sample = (
diff --git a/edml/tests/controllers/optimization_test.py b/edml/tests/controllers/optimization_test.py
index 614c4e74c1a60f4cd942404720459be83ef10a20..d488cd017429774b8aa950735e99a81e9ee1e818 100644
--- a/edml/tests/controllers/optimization_test.py
+++ b/edml/tests/controllers/optimization_test.py
@@ -54,6 +54,7 @@ class StrategyOptimizationTest(unittest.TestCase):
             server_norm_bw_time=4,
             client_weights_size=10,
             server_weights_size=10,
+            optimizer_state_size=10,
         )
         self.optimizer = ServerChoiceOptimizer(
             self.device_params_list, self.global_params
@@ -88,7 +89,7 @@ class StrategyOptimizationTest(unittest.TestCase):
 
     def test_num_bytes_sent_per_round_on_device(self):
         self.assertEqual(
-            self.optimizer.num_bytes_sent_per_round_on_device("d0", "d0"), 380
+            self.optimizer.num_bytes_sent_per_round_on_device("d0", "d0"), 390
         )
         self.assertEqual(
             self.optimizer.num_bytes_sent_per_round_on_device("d1", "d0"), 230
@@ -99,7 +100,7 @@ class StrategyOptimizationTest(unittest.TestCase):
 
     def test_num_bytes_received_per_round_on_device(self):
         self.assertEqual(
-            self.optimizer.num_bytes_received_per_round_on_device("d0", "d0"), 525
+            self.optimizer.num_bytes_received_per_round_on_device("d0", "d0"), 535
         )
         self.assertEqual(
             self.optimizer.num_bytes_received_per_round_on_device("d1", "d0"), 160
@@ -109,7 +110,7 @@ class StrategyOptimizationTest(unittest.TestCase):
         )
 
     def test_energy_per_round_on_device(self):
-        self.assertEqual(self.optimizer.energy_per_round_on_device("d0", "d0"), 1668.5)
+        self.assertEqual(self.optimizer.energy_per_round_on_device("d0", "d0"), 1688.5)
         self.assertEqual(self.optimizer.energy_per_round_on_device("d1", "d0"), 583.5)
         self.assertEqual(self.optimizer.energy_per_round_on_device("d2", "d0"), 718.5)
 
@@ -164,6 +165,7 @@ class EnergySimulatorTest(unittest.TestCase):
             server_norm_bw_time=4,
             client_weights_size=10,
             server_weights_size=10,
+            optimizer_state_size=10,
         )
         self.simulator = EnergySimulator(self.device_params_list, self.global_params)
 
@@ -175,7 +177,7 @@ class EnergySimulatorTest(unittest.TestCase):
         self.assertEqual(
             schedule, ["d0", "d1", "d2", "d0", "d1", "d2", "d0", "d1", "d0", "d2"]
         )
-        self.assertEqual(remaining_batteries, [545.0, 1045.0, 325.0])
+        self.assertEqual(remaining_batteries, [465.0, 985.0, 265.0])
 
     def test_simulate_smart_selection(self):
         num_rounds, solution, remaining_batteries = (
@@ -183,7 +185,7 @@ class EnergySimulatorTest(unittest.TestCase):
         )
         self.assertEqual(num_rounds, 10)
         self.assertEqual(solution, {"d0": 4.0, "d1": 3.0, "d2": 3.0})
-        self.assertEqual(remaining_batteries, [545.0, 1045.0, 325.0])
+        self.assertEqual(remaining_batteries, [465.0, 985.0, 265.0])
 
     def test_fl_round_time(self):
         self.assertEqual(self.simulator._fl_round_time(), 60.0)
@@ -280,6 +282,7 @@ class TestWithRealData(unittest.TestCase):
             client_weights_size=71678,
             # train global response size 15878758
             server_weights_size=15807080,
+            optimizer_state_size=0,  # was not recorded then
         )
         self.optimizer = ServerChoiceOptimizer(
             self.device_params_list, self.global_params
diff --git a/edml/tests/controllers/scheduler/smart_test.py b/edml/tests/controllers/scheduler/smart_test.py
index bcc0a2e68a5e1a1694dd4399414081787905eb14..4e0178b6226c2e5dabfc055dff5484e7948ea7ef 100644
--- a/edml/tests/controllers/scheduler/smart_test.py
+++ b/edml/tests/controllers/scheduler/smart_test.py
@@ -18,6 +18,7 @@ class SmartServerDeviceSelectionTest(unittest.TestCase):
             "smashed_data_size": 100000,
             "client_weight_size": 300000,
             "server_weight_size": 300000,
+            "optimizer_state_size": 300000,
             "client_norm_fw_time": 3,
             "client_norm_bw_time": 3,
             "server_norm_fw_time": 3,
diff --git a/edml/tests/helpers/metrics_test.py b/edml/tests/helpers/metrics_test.py
index 0bd02a27ab679d9d34defa252ffab08ba240165f..5810d4e0a4ad18313b97765adf3a787e3c2a3b00 100644
--- a/edml/tests/helpers/metrics_test.py
+++ b/edml/tests/helpers/metrics_test.py
@@ -285,6 +285,9 @@ class MetricsForOptimizationTest(unittest.TestCase):
             DiagnosticMetricResult(
                 device_id="d1", method="server_weights", name="size", value=500
             ),
+            DiagnosticMetricResult(
+                device_id="d1", method="optimizer_state", name="size", value=500
+            ),
         ]
         diagnostic_metric_result_container = DiagnosticMetricResultContainer(results)
         batch_size = 1
@@ -300,6 +303,7 @@ class MetricsForOptimizationTest(unittest.TestCase):
         self.assertEqual(normalized_metrics["gradient_size"], 50)
         self.assertEqual(normalized_metrics["client_weight_size"], 500)
         self.assertEqual(normalized_metrics["server_weight_size"], 500)
+        self.assertEqual(normalized_metrics["optimizer_state_size"], 500)
         self.assertAlmostEqual(normalized_metrics["client_norm_fw_time"], 0.1)
         self.assertAlmostEqual(normalized_metrics["client_norm_bw_time"], 0.2)
         self.assertAlmostEqual(normalized_metrics["server_norm_fw_time"], 0.1)