diff --git a/results/result_generation.py b/results/result_generation.py index 26ea0d79ef13e614af15daf7753597e559edb544..2b0cb3e61e344add382c03abf0838844a3d2ce80 100644 --- a/results/result_generation.py +++ b/results/result_generation.py @@ -429,7 +429,7 @@ def plot_remaining_devices(devices_per_round, save_path=None): ) plt.xlabel(LABEL_MAPPING["round"]) plt.ylabel(LABEL_MAPPING["num_devices"]) - plt.legend() + plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2) plt.tight_layout() if save_path is None: plt.show() @@ -564,7 +564,7 @@ def plot_accuracies(accuracies_per_round, save_path=None, phase="train"): plt.xticks(range(0, max(num_rounds) + 1, max(1, (max(num_rounds) + 1) // 8))) plt.xlabel(LABEL_MAPPING["round"]) plt.ylabel(LABEL_MAPPING[f"{phase} accuracy"]) - plt.legend() + plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2) plt.tight_layout() if save_path is None: plt.show() @@ -586,7 +586,7 @@ def plot_accuracies_over_time(accuracies_per_time, save_path=None, phase="train" plt.plot(df, label=f"{STRATEGY_MAPPING[strategy]}") plt.xlabel(LABEL_MAPPING["runtime"]) plt.ylabel(LABEL_MAPPING[f"{phase} accuracy"]) - plt.legend() + plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2) plt.tight_layout() if save_path is None: plt.show() @@ -763,7 +763,7 @@ def plot_batteries_over_time( plt.plot(x_values, y_values, label=f"{STRATEGY_MAPPING[strategy]}") plt.xlabel(LABEL_MAPPING["runtime"]) plt.ylabel(LABEL_MAPPING["total network battery"]) - plt.legend() + plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2) plt.tight_layout() if save_path is None: plt.show() @@ -784,7 +784,7 @@ def plot_batteries_over_time( plt.plot(x_values, y_values, label=f"{device_id}") plt.xlabel(LABEL_MAPPING["runtime"]) plt.ylabel(LABEL_MAPPING["device battery"]) - plt.legend() + plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2) plt.tight_layout() if save_path is None: plt.show() @@ -814,7 +814,7 @@ def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T plt.xticks(range(0, max(num_rounds) + 1, max(1, (max(num_rounds) + 1) // 8))) plt.ylabel(LABEL_MAPPING["total network battery"]) plt.xlabel(LABEL_MAPPING["round"]) - plt.legend() + plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2) plt.tight_layout() if save_path is None: plt.show() @@ -837,7 +837,7 @@ def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T ) plt.ylabel(LABEL_MAPPING["device battery"]) plt.xlabel(LABEL_MAPPING["round"]) - plt.legend() + plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2) plt.tight_layout() if save_path is None: plt.show() @@ -975,7 +975,7 @@ def plot_batteries_over_time_with_activity( y_values = series.values plt.plot(x_values, y_values, label=f"{device_id}") plt.ylabel(LABEL_MAPPING["device battery"]) - plt.legend() + plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2) plt.tight_layout() plt.xlim(xlim) # plt.ylim(3000, 3800) @@ -1052,7 +1052,7 @@ def plot_batteries_over_epoch_with_activity_at_time_scale( y_values = series.values plt.plot(x_values, y_values, label=f"{device_id}") plt.ylabel(LABEL_MAPPING["device battery"]) - plt.legend() + plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2) plt.tight_layout() plt.xlim(xlim) plt.xticks(xticks[0], labels=xticks[1]) @@ -1123,7 +1123,7 @@ def plot_batteries_over_epoch_with_activity_at_epoch_scale( num_rounds.append(len(series)) plt.xticks(range(0, max(num_rounds) + 1, max(1, (max(num_rounds) + 1) // 8))) plt.ylabel(LABEL_MAPPING["device battery"]) - plt.legend() + plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2) plt.tight_layout() # device plot