diff --git a/edml/controllers/scheduler/smart.py b/edml/controllers/scheduler/smart.py index cf66bd15eda55e1c6387b09fd13c43c67b389557..d2de92b02e96f8d39c5250504eb123e54b31c8ec 100644 --- a/edml/controllers/scheduler/smart.py +++ b/edml/controllers/scheduler/smart.py @@ -98,6 +98,9 @@ class SmartNextServerScheduler(NextServerScheduler): global_params.smashed_data_size = optimization_metrics["smashed_data_size"] global_params.client_weights_size = optimization_metrics["client_weight_size"] global_params.server_weights_size = optimization_metrics["server_weight_size"] + global_params.optimizer_state_size = optimization_metrics[ + "optimizer_state_size" + ] print(f"global params: {vars(global_params)}") print(f"device params: {[vars(device) for device in device_params_list]}") server_choice_optimizer = ServerChoiceOptimizer( diff --git a/edml/controllers/strategy_optimization.py b/edml/controllers/strategy_optimization.py index 76572e94d392f9a6922f6bd51d60bde949797086..022f2049f5695b25abe445fa242fbc564efa2ac1 100644 --- a/edml/controllers/strategy_optimization.py +++ b/edml/controllers/strategy_optimization.py @@ -37,6 +37,7 @@ class GlobalParams: batch_size=None, client_weights_size=None, server_weights_size=None, + optimizer_state_size=None, client_norm_fw_time=None, client_norm_bw_time=None, server_norm_fw_time=None, @@ -52,6 +53,7 @@ class GlobalParams: self.batch_size = batch_size self.client_weights_size = client_weights_size self.server_weights_size = server_weights_size + self.optimizer_state_size = optimizer_state_size # metrics per sample self.label_size = label_size self.gradient_size = gradient_size @@ -75,6 +77,7 @@ class GlobalParams: and self.cost_per_flop is not None and self.client_model_flops is not None and self.server_model_flops is not None + and self.optimizer_state_size is not None and self.smashed_data_size is not None and self.label_size is not None and self.gradient_size is not None @@ -172,9 +175,9 @@ class ServerChoiceOptimizer: total_bytes += self.global_params.client_weights_size * ( self._num_devices() - 1 ) # setting client weights before training - total_bytes += ( - self.global_params.server_weights_size - ) # return server weights in the end + # return server weights and optimizer state in the end + total_bytes += self.global_params.server_weights_size + total_bytes += self.global_params.optimizer_state_size # exclude server device's own gradients total_bytes -= self.global_params.gradient_size * device.train_samples else: @@ -202,9 +205,9 @@ class ServerChoiceOptimizer: total_bytes += self.global_params.client_weights_size * ( self._num_devices() - 1 ) # clients return their weights to server - total_bytes += ( - self.global_params.server_weights_size - ) # server weights set once in the beginning + # server weights and optimizer state set once in the beginning + total_bytes += self.global_params.server_weights_size + total_bytes += self.global_params.optimizer_state_size # exclude server device's own data and labels total_bytes -= self.global_params.label_size * ( device.train_samples + device.validation_samples diff --git a/edml/helpers/metrics.py b/edml/helpers/metrics.py index 2f7ef4021a087a7bea8133bf519f4d0f87abb843..7dc57e7b1757da656fe3ba31f11fdbda2d469621 100644 --- a/edml/helpers/metrics.py +++ b/edml/helpers/metrics.py @@ -431,6 +431,9 @@ def compute_metrics_for_optimization( result["server_weight_size"] = max( [metric.value for metric in raw_metrics[("size", "server_weights")]] ) # all values should be equal anyway + result["optimizer_state_size"] = max( + [metric.value for metric in raw_metrics[("size", "optimizer_state")]] + ) # all values should be equal anyway # time avg_server_model_train_time_per_sample = ( diff --git a/edml/tests/controllers/optimization_test.py b/edml/tests/controllers/optimization_test.py index 614c4e74c1a60f4cd942404720459be83ef10a20..d488cd017429774b8aa950735e99a81e9ee1e818 100644 --- a/edml/tests/controllers/optimization_test.py +++ b/edml/tests/controllers/optimization_test.py @@ -54,6 +54,7 @@ class StrategyOptimizationTest(unittest.TestCase): server_norm_bw_time=4, client_weights_size=10, server_weights_size=10, + optimizer_state_size=10, ) self.optimizer = ServerChoiceOptimizer( self.device_params_list, self.global_params @@ -88,7 +89,7 @@ class StrategyOptimizationTest(unittest.TestCase): def test_num_bytes_sent_per_round_on_device(self): self.assertEqual( - self.optimizer.num_bytes_sent_per_round_on_device("d0", "d0"), 380 + self.optimizer.num_bytes_sent_per_round_on_device("d0", "d0"), 390 ) self.assertEqual( self.optimizer.num_bytes_sent_per_round_on_device("d1", "d0"), 230 @@ -99,7 +100,7 @@ class StrategyOptimizationTest(unittest.TestCase): def test_num_bytes_received_per_round_on_device(self): self.assertEqual( - self.optimizer.num_bytes_received_per_round_on_device("d0", "d0"), 525 + self.optimizer.num_bytes_received_per_round_on_device("d0", "d0"), 535 ) self.assertEqual( self.optimizer.num_bytes_received_per_round_on_device("d1", "d0"), 160 @@ -109,7 +110,7 @@ class StrategyOptimizationTest(unittest.TestCase): ) def test_energy_per_round_on_device(self): - self.assertEqual(self.optimizer.energy_per_round_on_device("d0", "d0"), 1668.5) + self.assertEqual(self.optimizer.energy_per_round_on_device("d0", "d0"), 1688.5) self.assertEqual(self.optimizer.energy_per_round_on_device("d1", "d0"), 583.5) self.assertEqual(self.optimizer.energy_per_round_on_device("d2", "d0"), 718.5) @@ -164,6 +165,7 @@ class EnergySimulatorTest(unittest.TestCase): server_norm_bw_time=4, client_weights_size=10, server_weights_size=10, + optimizer_state_size=10, ) self.simulator = EnergySimulator(self.device_params_list, self.global_params) @@ -175,7 +177,7 @@ class EnergySimulatorTest(unittest.TestCase): self.assertEqual( schedule, ["d0", "d1", "d2", "d0", "d1", "d2", "d0", "d1", "d0", "d2"] ) - self.assertEqual(remaining_batteries, [545.0, 1045.0, 325.0]) + self.assertEqual(remaining_batteries, [465.0, 985.0, 265.0]) def test_simulate_smart_selection(self): num_rounds, solution, remaining_batteries = ( @@ -183,7 +185,7 @@ class EnergySimulatorTest(unittest.TestCase): ) self.assertEqual(num_rounds, 10) self.assertEqual(solution, {"d0": 4.0, "d1": 3.0, "d2": 3.0}) - self.assertEqual(remaining_batteries, [545.0, 1045.0, 325.0]) + self.assertEqual(remaining_batteries, [465.0, 985.0, 265.0]) def test_fl_round_time(self): self.assertEqual(self.simulator._fl_round_time(), 60.0) @@ -280,6 +282,7 @@ class TestWithRealData(unittest.TestCase): client_weights_size=71678, # train global response size 15878758 server_weights_size=15807080, + optimizer_state_size=0, # was not recorded then ) self.optimizer = ServerChoiceOptimizer( self.device_params_list, self.global_params diff --git a/edml/tests/controllers/scheduler/smart_test.py b/edml/tests/controllers/scheduler/smart_test.py index bcc0a2e68a5e1a1694dd4399414081787905eb14..4e0178b6226c2e5dabfc055dff5484e7948ea7ef 100644 --- a/edml/tests/controllers/scheduler/smart_test.py +++ b/edml/tests/controllers/scheduler/smart_test.py @@ -18,6 +18,7 @@ class SmartServerDeviceSelectionTest(unittest.TestCase): "smashed_data_size": 100000, "client_weight_size": 300000, "server_weight_size": 300000, + "optimizer_state_size": 300000, "client_norm_fw_time": 3, "client_norm_bw_time": 3, "server_norm_fw_time": 3, diff --git a/edml/tests/helpers/metrics_test.py b/edml/tests/helpers/metrics_test.py index 0bd02a27ab679d9d34defa252ffab08ba240165f..5810d4e0a4ad18313b97765adf3a787e3c2a3b00 100644 --- a/edml/tests/helpers/metrics_test.py +++ b/edml/tests/helpers/metrics_test.py @@ -285,6 +285,9 @@ class MetricsForOptimizationTest(unittest.TestCase): DiagnosticMetricResult( device_id="d1", method="server_weights", name="size", value=500 ), + DiagnosticMetricResult( + device_id="d1", method="optimizer_state", name="size", value=500 + ), ] diagnostic_metric_result_container = DiagnosticMetricResultContainer(results) batch_size = 1 @@ -300,6 +303,7 @@ class MetricsForOptimizationTest(unittest.TestCase): self.assertEqual(normalized_metrics["gradient_size"], 50) self.assertEqual(normalized_metrics["client_weight_size"], 500) self.assertEqual(normalized_metrics["server_weight_size"], 500) + self.assertEqual(normalized_metrics["optimizer_state_size"], 500) self.assertAlmostEqual(normalized_metrics["client_norm_fw_time"], 0.1) self.assertAlmostEqual(normalized_metrics["client_norm_bw_time"], 0.2) self.assertAlmostEqual(normalized_metrics["server_norm_fw_time"], 0.1)