diff --git a/edml/controllers/strategy_optimization.py b/edml/controllers/strategy_optimization.py index eddbf5742a526285cf6a3793b208900beda9f3ec..cd7bae8ac0db1c8eb4bcf818713ed8b2f1ed0e47 100644 --- a/edml/controllers/strategy_optimization.py +++ b/edml/controllers/strategy_optimization.py @@ -123,13 +123,15 @@ class ServerChoiceOptimizer: self.global_params.train_global_time is not None and self.global_params.last_server_device_id is not None ): - return ( + latency = ( self.global_params.train_global_time - self._round_runtime_with_server_no_latency( self.global_params.last_server_device_id ) ) - return 0 # latency not known + if latency > 0: + return latency + return 0 # latency not known or runtime was overestimated previously def _round_runtime_with_server_no_latency(self, server_device_id): """ @@ -356,20 +358,7 @@ class EnergySimulator: device_params_list, global_params ) - def simulate_greedy_selection(self): - """ - Simulates the greedy server choice algorithm. - Returns: - num_rounds: number of rounds until the first device runs out of battery - server_selection_schedule: list of server device ids for each round - device_batteries: list of battery levels for each device after the last successful round - """ - - def __get_device_with_max_battery__(device_battery_list): - return max( - range(len(device_battery_list)), key=device_battery_list.__getitem__ - ) - + def _simulate_selection(self, selection_callback=None): def __all_devices_alive__(device_battery_list): return all(battery > 0 for battery in device_battery_list) @@ -394,7 +383,9 @@ class EnergySimulator: server_selection_schedule = [] num_rounds = 0 while all_devices_alive: - server_device_idx = __get_device_with_max_battery__(device_batteries) + server_device_idx = selection_callback( + device_battery_list=device_batteries, num_rounds=num_rounds + ) new_batteries = device_batteries.copy() for idx, device in enumerate(self.device_params_list): new_batteries[idx] = new_batteries[idx] - energy[idx][server_device_idx] @@ -409,6 +400,36 @@ class EnergySimulator: break return num_rounds, server_selection_schedule, device_batteries + def simulate_greedy_selection(self): + """ + Simulates the greedy server choice algorithm. + Returns: + num_rounds: number of rounds until the first device runs out of battery + server_selection_schedule: list of server device ids for each round + device_batteries: list of battery levels for each device after the last successful round + """ + + def __get_device_with_max_battery__(device_battery_list, **kwargs): + return max( + range(len(device_battery_list)), key=device_battery_list.__getitem__ + ) + + return self._simulate_selection(__get_device_with_max_battery__) + + def simulate_sequential_selection(self): + """ + Simulates the sequential server choice algorithm. + Returns: + num_rounds: number of rounds until the first device runs out of battery + server_selection_schedule: list of server device ids for each round + device_batteries: list of battery levels for each device after the last successful round + """ + + def __sequential_selection__(device_battery_list, num_rounds): + return num_rounds % len(device_battery_list) + + return self._simulate_selection(__sequential_selection__) + def simulate_smart_selection(self): """ Simulates the smart server choice algorithm. @@ -574,47 +595,52 @@ def run_grid_search( for partition in partitions: for cost_sec in cost_per_sec: for cost_sent in cost_per_byte_sent: - for cost_received in cost_per_byte_received: - for cost_flop in cost_per_flop: - global_params.cost_per_sec = cost_sec - global_params.cost_per_byte_sent = cost_sent - global_params.cost_per_byte_received = cost_received - global_params.cost_per_flop = cost_flop - for idx, device in enumerate(device_params_list): - device.current_battery = battery[idx] - device.comp_latency_factor = latency[idx] - device.train_samples = ( - partition[idx] * total_train_samples - ) - device.validation_samples = ( - partition[idx] * total_val_samples - ) - energy_simulator = EnergySimulator( - device_params_list, global_params - ) - num_rounds_smart, _, _ = ( - energy_simulator.simulate_smart_selection() - ) - num_rounds_greedy, _, _ = ( - energy_simulator.simulate_greedy_selection() - ) - num_rounds_fl, _ = ( - energy_simulator.simulate_federated_learning() + # for cost_received in cost_per_byte_received: + cost_received = cost_sent + for cost_flop in cost_per_flop: + global_params.cost_per_sec = cost_sec + global_params.cost_per_byte_sent = cost_sent + global_params.cost_per_byte_received = cost_received + global_params.cost_per_flop = cost_flop + for idx, device in enumerate(device_params_list): + device.current_battery = battery[idx] + device.comp_latency_factor = latency[idx] + device.train_samples = ( + partition[idx] * total_train_samples ) - results.append( - { - "battery": battery, - "latency": latency, - "partition": partition, - "cost_per_sec": cost_sec, - "cost_per_byte_sent": cost_sent, - "cost_per_byte_received": cost_received, - "cost_per_flop": cost_flop, - "num_rounds_smart": num_rounds_smart, - "num_rounds_greedy": num_rounds_greedy, - "num_rounds_fl": num_rounds_fl, - } + device.validation_samples = ( + partition[idx] * total_val_samples ) + energy_simulator = EnergySimulator( + device_params_list, global_params + ) + num_rounds_smart, _, _ = ( + energy_simulator.simulate_smart_selection() + ) + num_rounds_greedy, _, _ = ( + energy_simulator.simulate_greedy_selection() + ) + num_rounds_seq, _, _ = ( + energy_simulator.simulate_sequential_selection() + ) + num_rounds_fl, _ = ( + energy_simulator.simulate_federated_learning() + ) + results.append( + { + "battery": battery, + "latency": latency, + "partition": partition, + "cost_per_sec": cost_sec, + "cost_per_byte_sent": cost_sent, + "cost_per_byte_received": cost_received, + "cost_per_flop": cost_flop, + "num_rounds_smart": num_rounds_smart, + "num_rounds_seq": num_rounds_seq, + "num_rounds_greedy": num_rounds_greedy, + "num_rounds_fl": num_rounds_fl, + } + ) return results