diff --git a/README.md b/README.md index 1ee07534df56656bda49b1d153093233017cd005..e61b9064479425699e5b1645d81bff6e434d96c4 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,43 @@ # SwarmSplitLearning -We introduce in this repository our own developed fully distributed variant of the well known split learning algorithm +In this repository, we introduce our own fully distributed variant of the well-known split learning algorithm. + + +## Installation + +For a faster installation, install the libmamba-solver before creating the conda environment: + +```bash +conda install -n base conda-libmamba-solver +conda config --set solver libmamba +``` + +Otherwise, the default conda solver will be used, which may take forever. + +Update the file path to your desired directory for the environment by changing the value of `prefix:` at the end of +the [environment.yml](environment.yml) file. +Then run: + +```bash +conda env create -f environment.yml +conda activate [ENV_NAME] +``` + +For tracking experiments using Weights and Biases, place your API key without any spaces in +the [wandb_key.txt](wandb_key.txt) file and make sure that `wandb: True` is set in +the [wandb config](edml/config/wandb.yaml). +Otherwise, metrics etc. will be printed to the console. + +If you plan to commit to this repository, please install **pre-commit** for consistent code formatting upon committing. 
def run_grid_search_with_variable_devices(
    num_devices_list,
    global_params,
    battery_per_device,
    total_train_samples,
    total_val_samples,
    max_latencies=None,
    max_split=None,
    cost_per_sec=None,
    cost_per_byte_sent=None,
    cost_per_byte_received=None,
    cost_per_flop=None,
):
    """
    Runs a grid search over device counts, device parameters and cost parameters.

    For every combination, the smart, greedy and federated learning simulations of an
    EnergySimulator are run and the resulting number of rounds is recorded.

    Params:
        num_devices_list: list(number of devices)
        global_params: GlobalParams object. Its cost_per_sec, cost_per_byte_sent,
            cost_per_byte_received and cost_per_flop attributes are overwritten in place
            for every cost combination, so after the call it holds the last combination.
        battery_per_device: list(battery per device); each value is used as the current
            battery of every device
        total_train_samples: total number of train samples
        total_val_samples: total number of validation samples
        max_latencies: optional list(latency factor). Every device except the last one
            gets this factor, the last device always gets 1.0. Defaults to [1.0].
        max_split: optional list(share of the data for the last device). The last device
            gets `share * total` samples, every other device gets
            `total // num_devices` samples. Defaults to [1 / num_devices], evaluated
            separately for each entry of num_devices_list.
        cost_per_sec, cost_per_byte_sent, cost_per_byte_received, cost_per_flop:
            optional lists of cost values to search over. If omitted, the current value
            from the global params object is used.
    Returns:
        list of dicts containing the parameters and the number of rounds of each
        strategy for each combination of parameters
    """
    from itertools import product

    if max_latencies is None:
        max_latencies = [1.0]
    if cost_per_sec is None:
        cost_per_sec = [global_params.cost_per_sec]
    if cost_per_byte_sent is None:
        cost_per_byte_sent = [global_params.cost_per_byte_sent]
    if cost_per_byte_received is None:
        cost_per_byte_received = [global_params.cost_per_byte_received]
    if cost_per_flop is None:
        cost_per_flop = [global_params.cost_per_flop]
    results = []
    for num_devices in num_devices_list:
        # Resolve the default into a local so that it is recomputed for every device
        # count. Assigning to max_split itself would freeze the default at the first
        # entry of num_devices_list.
        splits = [1 / num_devices] if max_split is None else max_split
        # product() iterates with the rightmost argument varying fastest, i.e. in the
        # same order as the equivalent nested for loops.
        for battery, latency, partition in product(
            battery_per_device, max_latencies, splits
        ):
            last_device = num_devices - 1
            # NOTE(review): the non-last devices always receive total // num_devices
            # samples, so for partition != 1 / num_devices the shares do not add up to
            # the total -- confirm whether the remainder should be distributed instead.
            device_params_list = [
                DeviceParams(
                    device_id=f"d{i}",
                    initial_battery=0,
                    current_battery=battery,
                    train_samples=(
                        partition * total_train_samples
                        if i == last_device
                        else total_train_samples // num_devices
                    ),
                    validation_samples=(
                        partition * total_val_samples
                        if i == last_device
                        else total_val_samples // num_devices
                    ),
                    comp_latency_factor=latency if i < last_device else 1.0,
                )
                for i in range(num_devices)
            ]
            for cost_sec, cost_sent, cost_received, cost_flop in product(
                cost_per_sec,
                cost_per_byte_sent,
                cost_per_byte_received,
                cost_per_flop,
            ):
                global_params.cost_per_sec = cost_sec
                global_params.cost_per_byte_sent = cost_sent
                global_params.cost_per_byte_received = cost_received
                global_params.cost_per_flop = cost_flop
                energy_simulator = EnergySimulator(device_params_list, global_params)
                num_rounds_smart, _, _ = energy_simulator.simulate_smart_selection()
                num_rounds_greedy, _, _ = energy_simulator.simulate_greedy_selection()
                num_rounds_fl, _ = energy_simulator.simulate_federated_learning()
                results.append(
                    {
                        "num_devices": num_devices,
                        "battery": battery,
                        "latency": latency,
                        "partition": partition,
                        "cost_per_sec": cost_sec,
                        "cost_per_byte_sent": cost_sent,
                        "cost_per_byte_received": cost_received,
                        "cost_per_flop": cost_flop,
                        "num_rounds_smart": num_rounds_smart,
                        "num_rounds_greedy": num_rounds_greedy,
                        "num_rounds_fl": num_rounds_fl,
                    }
                )
    return results
= [1, 1, 1, 1, 1] + self.cost_per_sec = 1 + self.cost_per_mbyte_sent = 1 + self.cost_per_mbyte_received = 1 + self.cost_per_mflop = 1 + + def _init_params_and_optimizer(self): + self.device_params_list = [ + DeviceParams( + device_id="d0", + initial_battery=3750, + current_battery=self.current_batteries[0], + train_samples=self.train_samples[0], + validation_samples=self.validation_samples[0], + comp_latency_factor=self.comp_latency[0], + ), + DeviceParams( + device_id="d1", + initial_battery=3750, + current_battery=self.current_batteries[1], + train_samples=self.train_samples[1], + validation_samples=self.validation_samples[1], + comp_latency_factor=self.comp_latency[1], + ), + DeviceParams( + device_id="d2", + initial_battery=3750, + current_battery=self.current_batteries[2], + train_samples=self.train_samples[2], + validation_samples=self.validation_samples[2], + comp_latency_factor=self.comp_latency[2], + ), + DeviceParams( + device_id="d3", + initial_battery=3750, + current_battery=self.current_batteries[3], + train_samples=self.train_samples[3], + validation_samples=self.validation_samples[3], + comp_latency_factor=self.comp_latency[3], + ), + DeviceParams( + device_id="d4", + initial_battery=3750, + current_battery=self.current_batteries[4], + train_samples=self.train_samples[4], + validation_samples=self.validation_samples[4], + comp_latency_factor=self.comp_latency[4], + ), + ] + self.global_params = GlobalParams( + cost_per_sec=self.cost_per_sec, + cost_per_byte_sent=self.cost_per_mbyte_sent / 1000000, + cost_per_byte_received=self.cost_per_mbyte_received / 1000000, + cost_per_flop=self.cost_per_mflop / 1000000, + client_model_flops=5405760, + server_model_flops=11215800, + smashed_data_size=36871, + label_size=14, + gradient_size=36871, + batch_size=64, + client_norm_fw_time=0.0001318198063394479, + client_norm_bw_time=1.503657614975644e-05, + server_norm_fw_time=2.1353501238321005e-05, + server_norm_bw_time=3.1509113154913254e-05, + 
client_weights_size=71678, + # train global response size 15878758 + server_weights_size=15807080, + ) + self.optimizer = ServerChoiceOptimizer( + self.device_params_list, self.global_params + ) + self.simulator = EnergySimulator(self.device_params_list, self.global_params) + + def test_with_actual_data(self): + self.current_batteries = [3719.654, 3717.608, 3711.923, 3708.051, 3704.294] + self.cost_per_sec = 1 # 0.1 + self.cost_per_mbyte_sent = 0.05 # 0.002 + self.cost_per_mbyte_received = 0.05 # 0.0005 + self.cost_per_mflop = 0.00025 # 0.0005 + self._init_params_and_optimizer() + solution, status = self.optimizer.optimize() + self.assertEqual( + solution, {"d0": 4.0, "d1": 4.0, "d2": 4.0, "d3": 3.0, "d4": 0.0} + ) + + def test_simulate_greedy_selection_with_actual_data(self): + self.current_batteries = [3719.654, 3717.608, 3711.923, 3708.051, 3704.294] + self.cost_per_sec = 1 # 0.1 + self.cost_per_mbyte_sent = 0.05 # 0.002 + self.cost_per_mbyte_received = 0.05 # 0.0005 + self.cost_per_mflop = 0.00025 # 0.0005 + self._init_params_and_optimizer() + num_rounds, schedule, remaining_batteries = ( + self.simulator.simulate_greedy_selection() + ) + self.assertEqual(num_rounds, 15) + self.assertEqual( + schedule, + [ + "d0", + "d1", + "d2", + "d3", + "d4", + "d0", + "d1", + "d2", + "d3", + "d4", + "d0", + "d1", + "d2", + "d3", + "d4", + ], + ) + + def test_unequal_split(self): + self.current_batteries = [5750, 4750, 3750, 2750, 1750] + self.train_samples = [4800, 9600, 19200, 7200, 7200] + self.validation_samples = [1200, 2400, 4800, 1800, 1800] + self.cost_per_sec = 1 + self.cost_per_mbyte_sent = 0.05 + self.cost_per_mbyte_received = 0.05 + self.cost_per_mflop = 0.00025 + self._init_params_and_optimizer() + solution, status = self.optimizer.optimize() + self.assertEqual( + solution, {"d0": 8.0, "d1": 5.0, "d2": 1.0, "d3": 2.0, "d4": 0.0} + ) + + def test_unequal_split_and_batteries_with_high_communication_cost(self): + self.current_batteries = [3750, 3750, 3750, 3750, 
3750] + data_partitions = [0.1, 0.1, 0.6, 0.1, 0.1] + self.train_samples = [partition * 48000 for partition in data_partitions] + self.validation_samples = [partition * 12000 for partition in data_partitions] + self.cost_per_sec = 0.1 + self.cost_per_mbyte_sent = 0.1 + self.cost_per_mbyte_received = 0.1 + self.cost_per_mflop = 0.00001 + self._init_params_and_optimizer() + solution, status = self.optimizer.optimize() + num_rounds, schedule, remaining_batteries = ( + self.simulator.simulate_greedy_selection() + ) + print(solution, sum(solution.values()), num_rounds, remaining_batteries) + self.assertGreater(sum(solution.values()), num_rounds) + + def test_unequal_battery_unequal_processing_high_time_cost(self): + self.current_batteries = [5750, 4750, 3750, 2750, 1750] + data_partitions = [0.2, 0.1, 0.1, 0.1, 0.5] + self.comp_latency = [5, 5, 5, 5, 1] + self.train_samples = [partition * 48000 for partition in data_partitions] + self.validation_samples = [partition * 12000 for partition in data_partitions] + self.cost_per_sec = 10 + self.cost_per_mbyte_sent = 0.0 # 1 + self.cost_per_mbyte_received = 0.0 # 1 + self.cost_per_mflop = 0.0000 # 5 + self._init_params_and_optimizer() + solution, status = self.optimizer.optimize() + num_rounds, schedule, remaining_batteries = ( + self.simulator.simulate_greedy_selection() + ) + print(solution, sum(solution.values()), num_rounds) + self.assertGreater(sum(solution.values()), num_rounds) + + def test_unequal_battery_unequal_processing_high_time_cost2(self): + self.current_batteries = [3750, 3750, 3750, 3750, 3750] + data_partitions = [0.05, 0.1, 0.1, 0.1, 0.65] + self.comp_latency = [10, 10, 10, 10, 1] + self.train_samples = [partition * 48000 for partition in data_partitions] + self.validation_samples = [partition * 12000 for partition in data_partitions] + self.cost_per_sec = 3 + self.cost_per_mbyte_sent = 0.05 + self.cost_per_mbyte_received = 0.05 + self.cost_per_mflop = 0.000005 + self._init_params_and_optimizer() + solution, 
status = self.optimizer.optimize() + num_rounds, schedule, remaining_batteries = ( + self.simulator.simulate_greedy_selection() + ) + print(solution, sum(solution.values()), num_rounds, remaining_batteries) + self.assertGreater(sum(solution.values()), num_rounds) + + def test_unequal_battery_unequal_processing_high_time_cost3(self): + self.current_batteries = [3750, 3750, 3750, 3750, 3750] + data_partitions = [0.01, 0.01, 0.01, 0.01, 0.96] + self.comp_latency = [100, 100, 100, 100, 1] + self.train_samples = [partition * 48000 for partition in data_partitions] + self.validation_samples = [partition * 12000 for partition in data_partitions] + self.cost_per_sec = 3 + self.cost_per_mbyte_sent = 0.05 # 375 + self.cost_per_mbyte_received = 0.05 + self.cost_per_mflop = 0.000005 + self._init_params_and_optimizer() + solution, status = self.optimizer.optimize() + num_rounds, schedule, remaining_batteries = ( + self.simulator.simulate_greedy_selection() + ) + print(solution, sum(solution.values()), num_rounds, remaining_batteries) + self.assertGreater(sum(solution.values()), num_rounds) + + def test_unequal_battery_unequal_processing_high_time_cost4(self): + self.current_batteries = [3750, 3750, 3750, 3750, 3750] + data_partitions = [0.1, 0.1, 0.1, 0.1, 0.6] + self.comp_latency = [10, 10, 10, 10, 1] + self.train_samples = [partition * 48000 for partition in data_partitions] + self.validation_samples = [partition * 12000 for partition in data_partitions] + self.cost_per_sec = 4 + self.cost_per_mbyte_sent = 0.01 + self.cost_per_mbyte_received = 0.01 + self.cost_per_mflop = 0.00001 + self._init_params_and_optimizer() + solution, status = self.optimizer.optimize() + num_rounds, schedule, remaining_batteries = ( + self.simulator.simulate_greedy_selection() + ) + print(solution, sum(solution.values()), num_rounds, remaining_batteries) + self.assertGreater(sum(solution.values()), num_rounds) diff --git a/environment.yml b/environment.yml index 
b7c4e6c10cd35765f9527be7feb52ee37d21861a..1de0f15678d2ab23fe1767682327f3efa2074173 100644 --- a/environment.yml +++ b/environment.yml @@ -15,6 +15,7 @@ dependencies: - numpy=1.24.3 - pandas=1.4.2 - pip=21.2.4 + - pre-commit=3.7.0 - protobuf=3.19.1 - pydantic - pytest=7.4.2 diff --git a/results/result_generation.ipynb b/results/result_generation.ipynb index 35530262d464a28041811a49a9534674e3e68c4c..85b37ddb464b85b1e634dcbcde25c0a221baa9fb 100644 --- a/results/result_generation.ipynb +++ b/results/result_generation.ipynb @@ -9,7 +9,7 @@ }, "outputs": [], "source": [ - "from results.result_generation import save_dataframes, load_dataframes, generate_metric_files, generate_plots" + "from result_generation import save_dataframes, load_dataframes, generate_metric_files, generate_plots" ] }, { diff --git a/results/result_generation.py b/results/result_generation.py index 4895b87704aecaf2ad9fa27af19bd56be11e84b5..b95183ebae7218c3c6ae7453b6bc103d69ba163e 100644 --- a/results/result_generation.py +++ b/results/result_generation.py @@ -9,10 +9,10 @@ import wandb # For plotting STRATEGY_MAPPING = { "split": "Vanilla SL", - "swarm_smart": "SwarmSL (Smart)", - "swarm_seq": "SwarmSL (Seq)", - "swarm_rand": "SwarmSL (Rand)", - "swarm_max": "SwarmSL (Greedy)", + "swarm_smart": "Swarm SL (Smart)", + "swarm_seq": "Swarm SL (Seq)", + "swarm_rand": "Swarm SL (Rand)", + "swarm_max": "Swarm SL (Greedy)", "fed": "Vanilla FL", } @@ -318,18 +318,26 @@ def accuracy_over_epoch(history_groups, phase="train"): round_cols.append(list(run_df[f"{phase}_accuracy.round"].dropna())) value_cols.append(list(run_df[f"{phase}_accuracy.value"].dropna())) # if multiple columns exist (i.e. 
multiple runs) average in each round and if one run was shorter, use last value - max_rounds = max([len(col) for col in round_cols]) - round = range(0, max_rounds) - acc = [] - for i in round: + max_rounds = max([int(col[-1]) for col in round_cols]) + 1 + mean_acc, round_no = [], [] + for i in range(max_rounds): single_run_accs = [] - for j in range(len(value_cols)): - if len(value_cols[j]) > i: - single_run_accs.append(value_cols[j][i]) - else: - single_run_accs.append(value_cols[j][-1]) - acc.append(sum(single_run_accs) / len(single_run_accs)) - results[(strategy, job)] = (round, acc) + for run_idx, value_col in enumerate( + value_cols + ): # round_cols should have same length + if ( + round_cols[run_idx].count(i) > 0 + ): # check if values were logged in this round + round_idx = round_cols[run_idx].index( + i + ) # get the index of the round (could be less than the round number due to missing values) + single_run_accs.append(value_col[round_idx]) + # print(i, single_run_accs) + if len(single_run_accs) > 0: + mean_acc.append(sum(single_run_accs) / len(single_run_accs)) + round_no.append(i) + + results[(strategy, job)] = (round_no, mean_acc) return results @@ -515,6 +523,7 @@ def plot_batteries_over_time( """ if aggregated: plt.figure() + plt.rcParams.update({"font.size": 13}) for (strategy, job), series in batteries_over_time.items(): runtime = max_runtimes[(strategy, job)] x_values = [ @@ -535,6 +544,7 @@ def plot_batteries_over_time( else: for (strategy, job), series_dict in batteries_over_time.items(): plt.figure() + plt.rcParams.update({"font.size": 13}) for device_id, series in series_dict.items(): runtime = max_runtimes[(strategy, job)] x_values = [ @@ -564,6 +574,7 @@ def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T """ if aggregated: plt.figure() + plt.rcParams.update({"font.size": 13}) num_rounds = [0] for (strategy, job), series in batteries_over_epoch.items(): x_values = series.index @@ -584,6 +595,7 @@ def 
def train_time_end(server_train_times, client_train_times):
    """
    Returns the latest "end" timestamp over all server and client training activities.

    Params:
        server_train_times: dict whose values support `value["end"].max()`
            (used like a DataFrame with an "end" column)
        client_train_times: dict of the same form
    Returns:
        the largest "end" maximum found, or 0 if both dicts are empty or no maximum
        is greater than 0
    """
    max_timestamp = 0
    for train_times in (server_train_times, client_train_times):
        for df in train_times.values():
            last_end = df["end"].max()
            # Explicit comparison instead of max(): a NaN maximum (e.g. from an empty
            # column) compares False and is therefore skipped.
            if last_end > max_timestamp:
                max_timestamp = last_end
    return max_timestamp