diff --git a/results/result_generation.py b/results/result_generation.py
index 26ea0d79ef13e614af15daf7753597e559edb544..2b0cb3e61e344add382c03abf0838844a3d2ce80 100644
--- a/results/result_generation.py
+++ b/results/result_generation.py
@@ -429,7 +429,7 @@ def plot_remaining_devices(devices_per_round, save_path=None):
     )
     plt.xlabel(LABEL_MAPPING["round"])
     plt.ylabel(LABEL_MAPPING["num_devices"])
-    plt.legend()
+    plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2)
     plt.tight_layout()
     if save_path is None:
         plt.show()
@@ -564,7 +564,7 @@ def plot_accuracies(accuracies_per_round, save_path=None, phase="train"):
     plt.xticks(range(0, max(num_rounds) + 1, max(1, (max(num_rounds) + 1) // 8)))
     plt.xlabel(LABEL_MAPPING["round"])
     plt.ylabel(LABEL_MAPPING[f"{phase} accuracy"])
-    plt.legend()
+    plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2)
     plt.tight_layout()
     if save_path is None:
         plt.show()
@@ -586,7 +586,7 @@ def plot_accuracies_over_time(accuracies_per_time, save_path=None, phase="train"
         plt.plot(df, label=f"{STRATEGY_MAPPING[strategy]}")
     plt.xlabel(LABEL_MAPPING["runtime"])
     plt.ylabel(LABEL_MAPPING[f"{phase} accuracy"])
-    plt.legend()
+    plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2)
     plt.tight_layout()
     if save_path is None:
         plt.show()
@@ -763,7 +763,7 @@ def plot_batteries_over_time(
             plt.plot(x_values, y_values, label=f"{STRATEGY_MAPPING[strategy]}")
         plt.xlabel(LABEL_MAPPING["runtime"])
         plt.ylabel(LABEL_MAPPING["total network battery"])
-        plt.legend()
+        plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2)
         plt.tight_layout()
         if save_path is None:
             plt.show()
@@ -784,7 +784,7 @@ def plot_batteries_over_time(
                 plt.plot(x_values, y_values, label=f"{device_id}")
             plt.xlabel(LABEL_MAPPING["runtime"])
             plt.ylabel(LABEL_MAPPING["device battery"])
-            plt.legend()
+            plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2)
             plt.tight_layout()
             if save_path is None:
                 plt.show()
@@ -814,7 +814,7 @@ def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T
         plt.xticks(range(0, max(num_rounds) + 1, max(1, (max(num_rounds) + 1) // 8)))
         plt.ylabel(LABEL_MAPPING["total network battery"])
         plt.xlabel(LABEL_MAPPING["round"])
-        plt.legend()
+        plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2)
         plt.tight_layout()
         if save_path is None:
             plt.show()
@@ -837,7 +837,7 @@ def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T
             )
             plt.ylabel(LABEL_MAPPING["device battery"])
             plt.xlabel(LABEL_MAPPING["round"])
-            plt.legend()
+            plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2)
             plt.tight_layout()
             if save_path is None:
                 plt.show()
@@ -975,7 +975,7 @@ def plot_batteries_over_time_with_activity(
             y_values = series.values
             plt.plot(x_values, y_values, label=f"{device_id}")
         plt.ylabel(LABEL_MAPPING["device battery"])
-        plt.legend()
+        plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2)
         plt.tight_layout()
         plt.xlim(xlim)
         # plt.ylim(3000, 3800)
@@ -1052,7 +1052,7 @@ def plot_batteries_over_epoch_with_activity_at_time_scale(
             y_values = series.values
             plt.plot(x_values, y_values, label=f"{device_id}")
         plt.ylabel(LABEL_MAPPING["device battery"])
-        plt.legend()
+        plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2)
         plt.tight_layout()
         plt.xlim(xlim)
         plt.xticks(xticks[0], labels=xticks[1])
@@ -1123,7 +1123,7 @@ def plot_batteries_over_epoch_with_activity_at_epoch_scale(
         num_rounds.append(len(series))
         plt.xticks(range(0, max(num_rounds) + 1, max(1, (max(num_rounds) + 1) // 8)))
         plt.ylabel(LABEL_MAPPING["device battery"])
-        plt.legend()
+        plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=2)
         plt.tight_layout()
 
         # device plot