diff --git a/README.md b/README.md
index 1ee07534df56656bda49b1d153093233017cd005..e61b9064479425699e5b1645d81bff6e434d96c4 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,43 @@
 # SwarmSplitLearning
 
-We introduce in this repository our own developed fully distributed variant of the well known split learning algorithm
+In this repository, we introduce our own fully distributed variant of the well-known split learning algorithm.
+
+
+## Installation
+
+For a faster installation, install the libmamba solver before creating the conda environment:
+
+```bash
+conda install -n base conda-libmamba-solver
+conda config --set solver libmamba
+```
+
+Otherwise, the default conda solver is used, which can be extremely slow.
+
+Update the file path to your desired directory for the environment by changing the value of `prefix:` at the end of
+the [environment.yml](environment.yml) file.
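+
+For reference, the last line of the file might then look like this (the path is only illustrative, using the same placeholder as below):
+
+```yaml
+prefix: /path/to/conda/envs/[ENV_NAME]
+```
+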
+Then run:
+
+```bash
+conda env create -f environment.yml
+conda activate [ENV_NAME]
+```
+
+For tracking experiments with Weights & Biases, place your API key without any spaces in
+the [wandb_key.txt](wandb_key.txt) file and make sure that `wandb: True` is set in
+the [wandb config](edml/config/wandb.yaml).
+Otherwise, metrics and related output are only printed to the console.
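+
+For reference, a minimal sketch of the setting to check in the [wandb config](edml/config/wandb.yaml):
+
+```yaml
+wandb: True
+```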
+
+If you plan to commit to this repository, please install **pre-commit** so that the code is formatted consistently on every commit.
+To set it up, run the following command in the repository:
+
+```bash
+pre-commit install
+```
+
+Optionally, to format all files without committing, run:
+
+```bash
+pre-commit run --all-files
+```
+
+
diff --git a/diagrams/Basic_CD.png b/diagrams/Basic_CD.png
deleted file mode 100644
index 7b432dbbdecc7cbfd7292f5276a4ec76027231d7..0000000000000000000000000000000000000000
Binary files a/diagrams/Basic_CD.png and /dev/null differ
diff --git a/diagrams/Basic_SL_Interaction.png b/diagrams/Basic_SL_Interaction.png
deleted file mode 100644
index f1d6a5be49745cc070fe26a52555deffd4801370..0000000000000000000000000000000000000000
Binary files a/diagrams/Basic_SL_Interaction.png and /dev/null differ
diff --git a/diagrams/Device_SL_CD.png b/diagrams/Device_SL_CD.png
deleted file mode 100644
index 98ef8d653e079d98a2a2a557d2d8828082c93243..0000000000000000000000000000000000000000
Binary files a/diagrams/Device_SL_CD.png and /dev/null differ
diff --git a/diagrams/Device_SL_interaction.png b/diagrams/Device_SL_interaction.png
deleted file mode 100644
index 96c0fb6f31948ef5c25f8123e9abd33d4bdf64a1..0000000000000000000000000000000000000000
Binary files a/diagrams/Device_SL_interaction.png and /dev/null differ
diff --git a/edml/controllers/strategy_optimization.py b/edml/controllers/strategy_optimization.py
index 81570cb0407898bec8bb5e4d9c2f2405922aedf5..76572e94d392f9a6922f6bd51d60bde949797086 100644
--- a/edml/controllers/strategy_optimization.py
+++ b/edml/controllers/strategy_optimization.py
@@ -581,3 +581,106 @@ def run_grid_search(
                                     }
                                 )
     return results
+
+
+def run_grid_search_with_variable_devices(
+    num_devices_list,
+    global_params,
+    battery_per_device,
+    total_train_samples,
+    total_val_samples,
+    max_latencies=None,
+    max_split=None,
+    cost_per_sec=None,
+    cost_per_byte_sent=None,
+    cost_per_byte_received=None,
+    cost_per_flop=None,
+):
+    """
+    Runs a grid search for the given device parameters and global parameters.
+    Params:
+        num_devices: list(number of devices)
+        global_params: GlobalParams object
+        battery_per_device: list(battery per device)
+        total_train_samples: total number of train samples
+        total_val_samples: total number of validation samples
+        max_latencies: optional list(max latency per device) Sets all devices except the last one to the max latency if provided
+        max_split: optional list(max split per device) Sets the largest split for last device if provided and distributes the rest equally among the other devices
+        **kwargs: optional parameters for the grid search: should be a list of lists for device params and a list for costs
+        If e.g. cost_per_second is provided, the grid search will be run for all values in the list overriding existing values in the global params object.
+        If no cost_per_second is provided, the grid search will use the value from the global params object.
+    Returns:
+        list of dicts containing the results for each combination of parameters
+    """
+    if max_latencies is None:
+        max_latencies = [1.0]
+    if cost_per_sec is None:
+        cost_per_sec = [global_params.cost_per_sec]
+    if cost_per_byte_sent is None:
+        cost_per_byte_sent = [global_params.cost_per_byte_sent]
+    if cost_per_byte_received is None:
+        cost_per_byte_received = [global_params.cost_per_byte_received]
+    if cost_per_flop is None:
+        cost_per_flop = [global_params.cost_per_flop]
+    results = []
+    for num_devices in num_devices_list:
+        # Recompute the default (equal) split for every device count instead of
+        # reusing the list created for the first entry of num_devices_list.
+        splits = max_split if max_split is not None else [1 / num_devices]
+        for battery in battery_per_device:
+            for latency in max_latencies:
+                for partition in splits:
+                    device_params_list = [
+                        DeviceParams(
+                            device_id=f"d{i}",
+                            initial_battery=0,
+                            current_battery=battery,
+                            train_samples=(
+                                partition * total_train_samples
+                                if i == num_devices - 1
+                                else total_train_samples // num_devices
+                            ),
+                            validation_samples=(
+                                partition * total_val_samples
+                                if i == num_devices - 1
+                                else total_val_samples // num_devices
+                            ),
+                            comp_latency_factor=latency if i < num_devices - 1 else 1.0,
+                        )
+                        for i in range(num_devices)
+                    ]
+                    for cost_sec in cost_per_sec:
+                        for cost_sent in cost_per_byte_sent:
+                            for cost_received in cost_per_byte_received:
+                                for cost_flop in cost_per_flop:
+                                    global_params.cost_per_sec = cost_sec
+                                    global_params.cost_per_byte_sent = cost_sent
+                                    global_params.cost_per_byte_received = cost_received
+                                    global_params.cost_per_flop = cost_flop
+                                    energy_simulator = EnergySimulator(
+                                        device_params_list, global_params
+                                    )
+                                    num_rounds_smart, _, _ = (
+                                        energy_simulator.simulate_smart_selection()
+                                    )
+                                    num_rounds_greedy, _, _ = (
+                                        energy_simulator.simulate_greedy_selection()
+                                    )
+                                    num_rounds_fl, _ = (
+                                        energy_simulator.simulate_federated_learning()
+                                    )
+                                    results.append(
+                                        {
+                                            "num_devices": num_devices,
+                                            "battery": battery,
+                                            "latency": latency,
+                                            "partition": partition,
+                                            "cost_per_sec": cost_sec,
+                                            "cost_per_byte_sent": cost_sent,
+                                            "cost_per_byte_received": cost_received,
+                                            "cost_per_flop": cost_flop,
+                                            "num_rounds_smart": num_rounds_smart,
+                                            "num_rounds_greedy": num_rounds_greedy,
+                                            "num_rounds_fl": num_rounds_fl,
+                                        }
+                                    )
+    return results
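+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch only, not invoked by the library code: it shows how
+    # the variable-device grid search above can be called. It assumes GlobalParams
+    # is defined in this module (as the unit tests suggest); the parameter values
+    # mirror the case study in edml/tests/controllers/optimization_test.py, and the
+    # device counts and battery level are example choices rather than defaults.
+    example_global_params = GlobalParams(
+        cost_per_sec=1,
+        cost_per_byte_sent=0.05 / 1000000,
+        cost_per_byte_received=0.05 / 1000000,
+        cost_per_flop=0.00025 / 1000000,
+        client_model_flops=5405760,
+        server_model_flops=11215800,
+        smashed_data_size=36871,
+        label_size=14,
+        gradient_size=36871,
+        batch_size=64,
+        client_norm_fw_time=0.0001318198063394479,
+        client_norm_bw_time=1.503657614975644e-05,
+        server_norm_fw_time=2.1353501238321005e-05,
+        server_norm_bw_time=3.1509113154913254e-05,
+        client_weights_size=71678,
+        server_weights_size=15807080,
+    )
+    example_results = run_grid_search_with_variable_devices(
+        num_devices_list=[3, 5],
+        global_params=example_global_params,
+        battery_per_device=[3750],
+        total_train_samples=48000,
+        total_val_samples=12000,
+    )
+    print(example_results)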
diff --git a/edml/tests/controllers/optimization_test.py b/edml/tests/controllers/optimization_test.py
index ce4e5bdfbd90c78aee6ff5ef39e9cda7001bc063..614c4e74c1a60f4cd942404720459be83ef10a20 100644
--- a/edml/tests/controllers/optimization_test.py
+++ b/edml/tests/controllers/optimization_test.py
@@ -204,3 +204,231 @@ class EnergySimulatorTest(unittest.TestCase):
         num_rounds, remaining_batteries = self.simulator.simulate_federated_learning()
         self.assertEqual(num_rounds, 20)
         self.assertEqual(remaining_batteries, [3800, 2000, 200])
+
+
+class TestWithRealData(unittest.TestCase):
+    """Case studies for estimating the number of rounds of each strategy with given energy constratins."""
+
+    def setUp(self):
+        self.train_samples = [9600, 9600, 9600, 9600, 9600]
+        self.validation_samples = [2400, 2400, 2400, 2400, 2400]
+        self.current_batteries = [3750, 3750, 3750, 3750, 3750]
+        self.comp_latency = [1, 1, 1, 1, 1]
+        self.cost_per_sec = 1
+        self.cost_per_mbyte_sent = 1
+        self.cost_per_mbyte_received = 1
+        self.cost_per_mflop = 1
+
+    def _init_params_and_optimizer(self):
+        self.device_params_list = [
+            DeviceParams(
+                device_id="d0",
+                initial_battery=3750,
+                current_battery=self.current_batteries[0],
+                train_samples=self.train_samples[0],
+                validation_samples=self.validation_samples[0],
+                comp_latency_factor=self.comp_latency[0],
+            ),
+            DeviceParams(
+                device_id="d1",
+                initial_battery=3750,
+                current_battery=self.current_batteries[1],
+                train_samples=self.train_samples[1],
+                validation_samples=self.validation_samples[1],
+                comp_latency_factor=self.comp_latency[1],
+            ),
+            DeviceParams(
+                device_id="d2",
+                initial_battery=3750,
+                current_battery=self.current_batteries[2],
+                train_samples=self.train_samples[2],
+                validation_samples=self.validation_samples[2],
+                comp_latency_factor=self.comp_latency[2],
+            ),
+            DeviceParams(
+                device_id="d3",
+                initial_battery=3750,
+                current_battery=self.current_batteries[3],
+                train_samples=self.train_samples[3],
+                validation_samples=self.validation_samples[3],
+                comp_latency_factor=self.comp_latency[3],
+            ),
+            DeviceParams(
+                device_id="d4",
+                initial_battery=3750,
+                current_battery=self.current_batteries[4],
+                train_samples=self.train_samples[4],
+                validation_samples=self.validation_samples[4],
+                comp_latency_factor=self.comp_latency[4],
+            ),
+        ]
+        self.global_params = GlobalParams(
+            cost_per_sec=self.cost_per_sec,
+            cost_per_byte_sent=self.cost_per_mbyte_sent / 1000000,
+            cost_per_byte_received=self.cost_per_mbyte_received / 1000000,
+            cost_per_flop=self.cost_per_mflop / 1000000,
+            client_model_flops=5405760,
+            server_model_flops=11215800,
+            smashed_data_size=36871,
+            label_size=14,
+            gradient_size=36871,
+            batch_size=64,
+            client_norm_fw_time=0.0001318198063394479,
+            client_norm_bw_time=1.503657614975644e-05,
+            server_norm_fw_time=2.1353501238321005e-05,
+            server_norm_bw_time=3.1509113154913254e-05,
+            client_weights_size=71678,
+            # train global response size 15878758
+            server_weights_size=15807080,
+        )
+        self.optimizer = ServerChoiceOptimizer(
+            self.device_params_list, self.global_params
+        )
+        self.simulator = EnergySimulator(self.device_params_list, self.global_params)
+
+    def test_with_actual_data(self):
+        self.current_batteries = [3719.654, 3717.608, 3711.923, 3708.051, 3704.294]
+        self.cost_per_sec = 1  # 0.1
+        self.cost_per_mbyte_sent = 0.05  # 0.002
+        self.cost_per_mbyte_received = 0.05  # 0.0005
+        self.cost_per_mflop = 0.00025  # 0.0005
+        self._init_params_and_optimizer()
+        solution, status = self.optimizer.optimize()
+        self.assertEqual(
+            solution, {"d0": 4.0, "d1": 4.0, "d2": 4.0, "d3": 3.0, "d4": 0.0}
+        )
+
+    def test_simulate_greedy_selection_with_actual_data(self):
+        self.current_batteries = [3719.654, 3717.608, 3711.923, 3708.051, 3704.294]
+        self.cost_per_sec = 1  # 0.1
+        self.cost_per_mbyte_sent = 0.05  # 0.002
+        self.cost_per_mbyte_received = 0.05  # 0.0005
+        self.cost_per_mflop = 0.00025  # 0.0005
+        self._init_params_and_optimizer()
+        num_rounds, schedule, remaining_batteries = (
+            self.simulator.simulate_greedy_selection()
+        )
+        self.assertEqual(num_rounds, 15)
+        self.assertEqual(
+            schedule,
+            [
+                "d0",
+                "d1",
+                "d2",
+                "d3",
+                "d4",
+                "d0",
+                "d1",
+                "d2",
+                "d3",
+                "d4",
+                "d0",
+                "d1",
+                "d2",
+                "d3",
+                "d4",
+            ],
+        )
+
+    def test_unequal_split(self):
+        self.current_batteries = [5750, 4750, 3750, 2750, 1750]
+        self.train_samples = [4800, 9600, 19200, 7200, 7200]
+        self.validation_samples = [1200, 2400, 4800, 1800, 1800]
+        self.cost_per_sec = 1
+        self.cost_per_mbyte_sent = 0.05
+        self.cost_per_mbyte_received = 0.05
+        self.cost_per_mflop = 0.00025
+        self._init_params_and_optimizer()
+        solution, status = self.optimizer.optimize()
+        self.assertEqual(
+            solution, {"d0": 8.0, "d1": 5.0, "d2": 1.0, "d3": 2.0, "d4": 0.0}
+        )
+
+    def test_unequal_split_and_batteries_with_high_communication_cost(self):
+        self.current_batteries = [3750, 3750, 3750, 3750, 3750]
+        data_partitions = [0.1, 0.1, 0.6, 0.1, 0.1]
+        self.train_samples = [partition * 48000 for partition in data_partitions]
+        self.validation_samples = [partition * 12000 for partition in data_partitions]
+        self.cost_per_sec = 0.1
+        self.cost_per_mbyte_sent = 0.1
+        self.cost_per_mbyte_received = 0.1
+        self.cost_per_mflop = 0.00001
+        self._init_params_and_optimizer()
+        solution, status = self.optimizer.optimize()
+        num_rounds, schedule, remaining_batteries = (
+            self.simulator.simulate_greedy_selection()
+        )
+        print(solution, sum(solution.values()), num_rounds, remaining_batteries)
+        self.assertGreater(sum(solution.values()), num_rounds)
+
+    def test_unequal_battery_unequal_processing_high_time_cost(self):
+        self.current_batteries = [5750, 4750, 3750, 2750, 1750]
+        data_partitions = [0.2, 0.1, 0.1, 0.1, 0.5]
+        self.comp_latency = [5, 5, 5, 5, 1]
+        self.train_samples = [partition * 48000 for partition in data_partitions]
+        self.validation_samples = [partition * 12000 for partition in data_partitions]
+        self.cost_per_sec = 10
+        self.cost_per_mbyte_sent = 0.0  # 1
+        self.cost_per_mbyte_received = 0.0  # 1
+        self.cost_per_mflop = 0.0000  # 5
+        self._init_params_and_optimizer()
+        solution, status = self.optimizer.optimize()
+        num_rounds, schedule, remaining_batteries = (
+            self.simulator.simulate_greedy_selection()
+        )
+        print(solution, sum(solution.values()), num_rounds)
+        self.assertGreater(sum(solution.values()), num_rounds)
+
+    def test_unequal_battery_unequal_processing_high_time_cost2(self):
+        self.current_batteries = [3750, 3750, 3750, 3750, 3750]
+        data_partitions = [0.05, 0.1, 0.1, 0.1, 0.65]
+        self.comp_latency = [10, 10, 10, 10, 1]
+        self.train_samples = [partition * 48000 for partition in data_partitions]
+        self.validation_samples = [partition * 12000 for partition in data_partitions]
+        self.cost_per_sec = 3
+        self.cost_per_mbyte_sent = 0.05
+        self.cost_per_mbyte_received = 0.05
+        self.cost_per_mflop = 0.000005
+        self._init_params_and_optimizer()
+        solution, status = self.optimizer.optimize()
+        num_rounds, schedule, remaining_batteries = (
+            self.simulator.simulate_greedy_selection()
+        )
+        print(solution, sum(solution.values()), num_rounds, remaining_batteries)
+        self.assertGreater(sum(solution.values()), num_rounds)
+
+    def test_unequal_battery_unequal_processing_high_time_cost3(self):
+        self.current_batteries = [3750, 3750, 3750, 3750, 3750]
+        data_partitions = [0.01, 0.01, 0.01, 0.01, 0.96]
+        self.comp_latency = [100, 100, 100, 100, 1]
+        self.train_samples = [partition * 48000 for partition in data_partitions]
+        self.validation_samples = [partition * 12000 for partition in data_partitions]
+        self.cost_per_sec = 3
+        self.cost_per_mbyte_sent = 0.05  # 375
+        self.cost_per_mbyte_received = 0.05
+        self.cost_per_mflop = 0.000005
+        self._init_params_and_optimizer()
+        solution, status = self.optimizer.optimize()
+        num_rounds, schedule, remaining_batteries = (
+            self.simulator.simulate_greedy_selection()
+        )
+        print(solution, sum(solution.values()), num_rounds, remaining_batteries)
+        self.assertGreater(sum(solution.values()), num_rounds)
+
+    def test_unequal_battery_unequal_processing_high_time_cost4(self):
+        self.current_batteries = [3750, 3750, 3750, 3750, 3750]
+        data_partitions = [0.1, 0.1, 0.1, 0.1, 0.6]
+        self.comp_latency = [10, 10, 10, 10, 1]
+        self.train_samples = [partition * 48000 for partition in data_partitions]
+        self.validation_samples = [partition * 12000 for partition in data_partitions]
+        self.cost_per_sec = 4
+        self.cost_per_mbyte_sent = 0.01
+        self.cost_per_mbyte_received = 0.01
+        self.cost_per_mflop = 0.00001
+        self._init_params_and_optimizer()
+        solution, status = self.optimizer.optimize()
+        num_rounds, schedule, remaining_batteries = (
+            self.simulator.simulate_greedy_selection()
+        )
+        print(solution, sum(solution.values()), num_rounds, remaining_batteries)
+        self.assertGreater(sum(solution.values()), num_rounds)
diff --git a/environment.yml b/environment.yml
index b7c4e6c10cd35765f9527be7feb52ee37d21861a..1de0f15678d2ab23fe1767682327f3efa2074173 100644
--- a/environment.yml
+++ b/environment.yml
@@ -15,6 +15,7 @@ dependencies:
   - numpy=1.24.3
   - pandas=1.4.2
   - pip=21.2.4
+  - pre-commit=3.7.0
   - protobuf=3.19.1
   - pydantic
   - pytest=7.4.2
diff --git a/results/result_generation.ipynb b/results/result_generation.ipynb
index 35530262d464a28041811a49a9534674e3e68c4c..85b37ddb464b85b1e634dcbcde25c0a221baa9fb 100644
--- a/results/result_generation.ipynb
+++ b/results/result_generation.ipynb
@@ -9,7 +9,7 @@
    },
    "outputs": [],
    "source": [
-    "from results.result_generation import save_dataframes, load_dataframes, generate_metric_files, generate_plots"
+    "from result_generation import save_dataframes, load_dataframes, generate_metric_files, generate_plots"
    ]
   },
   {
diff --git a/results/result_generation.py b/results/result_generation.py
index 4895b87704aecaf2ad9fa27af19bd56be11e84b5..b95183ebae7218c3c6ae7453b6bc103d69ba163e 100644
--- a/results/result_generation.py
+++ b/results/result_generation.py
@@ -9,10 +9,10 @@ import wandb
 # For plotting
 STRATEGY_MAPPING = {
     "split": "Vanilla SL",
-    "swarm_smart": "SwarmSL (Smart)",
-    "swarm_seq": "SwarmSL (Seq)",
-    "swarm_rand": "SwarmSL (Rand)",
-    "swarm_max": "SwarmSL (Greedy)",
+    "swarm_smart": "Swarm SL (Smart)",
+    "swarm_seq": "Swarm SL (Seq)",
+    "swarm_rand": "Swarm SL (Rand)",
+    "swarm_max": "Swarm SL (Greedy)",
     "fed": "Vanilla FL",
 }
 
@@ -318,18 +318,26 @@ def accuracy_over_epoch(history_groups, phase="train"):
                 round_cols.append(list(run_df[f"{phase}_accuracy.round"].dropna()))
                 value_cols.append(list(run_df[f"{phase}_accuracy.value"].dropna()))
-                # if multiple columns exist (i.e. multiple runs) average in each round and if one run was shorter, use last value
+                # if multiple columns exist (i.e. multiple runs), average within each round; rounds without any logged value are skipped
-            max_rounds = max([len(col) for col in round_cols])
-            round = range(0, max_rounds)
-            acc = []
-            for i in round:
+            max_rounds = max([int(col[-1]) for col in round_cols]) + 1
+            mean_acc, round_no = [], []
+            for i in range(max_rounds):
                 single_run_accs = []
-                for j in range(len(value_cols)):
-                    if len(value_cols[j]) > i:
-                        single_run_accs.append(value_cols[j][i])
-                    else:
-                        single_run_accs.append(value_cols[j][-1])
-                acc.append(sum(single_run_accs) / len(single_run_accs))
-            results[(strategy, job)] = (round, acc)
+                for run_idx, value_col in enumerate(
+                    value_cols
+                ):  # round_cols should have same length
+                    if (
+                        round_cols[run_idx].count(i) > 0
+                    ):  # check if values were logged in this round
+                        round_idx = round_cols[run_idx].index(
+                            i
+                        )  # get the index of the round (could be less than the round number due to missing values)
+                        single_run_accs.append(value_col[round_idx])
+                # print(i, single_run_accs)
+                if len(single_run_accs) > 0:
+                    mean_acc.append(sum(single_run_accs) / len(single_run_accs))
+                    round_no.append(i)
+
+            results[(strategy, job)] = (round_no, mean_acc)
     return results
 
 
@@ -515,6 +523,7 @@ def plot_batteries_over_time(
     """
     if aggregated:
         plt.figure()
+        plt.rcParams.update({"font.size": 13})
         for (strategy, job), series in batteries_over_time.items():
             runtime = max_runtimes[(strategy, job)]
             x_values = [
@@ -535,6 +544,7 @@ def plot_batteries_over_time(
     else:
         for (strategy, job), series_dict in batteries_over_time.items():
             plt.figure()
+            plt.rcParams.update({"font.size": 13})
             for device_id, series in series_dict.items():
                 runtime = max_runtimes[(strategy, job)]
                 x_values = [
@@ -564,6 +574,7 @@ def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T
     """
     if aggregated:
         plt.figure()
+        plt.rcParams.update({"font.size": 13})
         num_rounds = [0]
         for (strategy, job), series in batteries_over_epoch.items():
             x_values = series.index
@@ -584,6 +595,7 @@ def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T
     else:
         for (strategy, job), series_dict in batteries_over_epoch.items():
             plt.figure()
+            plt.rcParams.update({"font.size": 13})
             num_rounds = [0]
             for device_id, series in series_dict.items():
                 x_values = series.index
@@ -643,23 +655,40 @@ def get_train_times(run_groups):
                         run_df["client_train_epoch_time.start"] > 0
                     )
                     valid_eval_times = run_df["client_evaluate_time.end"] > 0
-                    train_epoch_times = run_df[valid_train_epoch_times][
+                    train_epoch_start_times = run_df[valid_train_epoch_times][
                         ["client_train_epoch_time.start"]
                     ].reset_index(drop=True)
-                    eval_times = run_df[valid_eval_times][
+                    train_epoch_end_times = run_df[valid_train_epoch_times][
+                        ["client_train_epoch_time.end"]
+                    ].reset_index(drop=True)
+                    eval_start_times = run_df[valid_eval_times][
+                        ["client_evaluate_time.start"]
+                    ].reset_index(drop=True)
+                    eval_end_times = run_df[valid_eval_times][
                         ["client_evaluate_time.end"]
                     ].reset_index(drop=True)
-                    client_train_times[device_id] = (
-                        pd.concat(
-                            [train_epoch_times[: len(eval_times)], eval_times], axis=1
-                        ).rename(
-                            columns={
-                                "client_train_epoch_time.start": "start",
-                                "client_evaluate_time.end": "end",
-                            }
-                        )
-                        - min_start_time
-                    )
+                    client_train_times[device_id] = pd.concat(
+                        [
+                            pd.concat(
+                                [train_epoch_start_times, train_epoch_end_times], axis=1
+                            ).rename(
+                                columns={
+                                    "client_train_epoch_time.start": "start",
+                                    "client_train_epoch_time.end": "end",
+                                }
+                            )
+                            - min_start_time,
+                            pd.concat(
+                                [eval_start_times, eval_end_times], axis=1
+                            ).rename(
+                                columns={
+                                    "client_evaluate_time.start": "start",
+                                    "client_evaluate_time.end": "end",
+                                }
+                            )
+                            - min_start_time,
+                        ]
+                    ).reset_index(drop=True)
 
                 if "fed_train_time.start" in run_df.columns:
                     valid_fed_train_times = run_df["fed_train_time.start"] > 0
@@ -681,6 +710,17 @@ def get_train_times(run_groups):
     return results
 
 
+def train_time_end(server_train_times, client_train_times):
+    max_timestamp = 0
+    for idx, (device_id, df) in enumerate(server_train_times.items()):
+        if df["end"].max() > max_timestamp:
+            max_timestamp = df["end"].max()
+    for idx, (device_id, df) in enumerate(client_train_times.items()):
+        if df["end"].max() > max_timestamp:
+            max_timestamp = df["end"].max()
+    return max_timestamp
+
+
 def plot_batteries_over_time_with_activity(
     batteries_over_time, max_runtimes, training_times, save_path=None
 ):
@@ -689,7 +729,12 @@ def plot_batteries_over_time_with_activity(
     y_offset = 10
     for (strategy, job), series_dict in batteries_over_time.items():
         server_train_times, client_train_times = training_times[(strategy, job)]
+        xlim = (
+            0,
+            train_time_end(server_train_times, client_train_times) * 1.05,
+        )  # set end 5% after last activity timestamp
         plt.figure()
+        plt.rcParams.update({"font.size": 13})
         # battery_plot
         plt.subplot(2, 1, 1)
         for device_id, series in series_dict.items():
@@ -702,7 +747,7 @@ def plot_batteries_over_time_with_activity(
         plt.ylabel(LABEL_MAPPING["device battery"])
         plt.legend()
         plt.tight_layout()
-        # plt.xlim(xlim)
+        plt.xlim(xlim)
         # plt.ylim(3000, 3800)
         # activity plot
         plt.subplot(2, 1, 2)
@@ -729,7 +774,7 @@ def plot_batteries_over_time_with_activity(
         plt.ylabel(LABEL_MAPPING["device"])
         plt.legend(loc="upper left")
         plt.tight_layout()
-        # plt.xlim(xlim)
+        plt.xlim(xlim)
         plt.yticks(
             [(device_space * x + 2 * bar_height + y_offset) for x in range(0, 5)],
             labels=["d0", "d1", "d2", "d3", "d4"][:5],
@@ -766,6 +811,7 @@ def plot_batteries_over_epoch_with_activity_at_time_scale(
             [str(i) for i in range(0, len(start_times), max(1, len(start_times) // 8))],
         )
         plt.figure()
+        plt.rcParams.update({"font.size": 13})
         # battery_plot
         plt.subplot(2, 1, 1)
         for device_id, series in series_dict.items():
@@ -837,6 +883,7 @@ def plot_batteries_over_epoch_with_activity_at_epoch_scale(
 
         # battery plot
         plt.figure()
+        plt.rcParams.update({"font.size": 13})
         plt.subplot(2, 1, 1)
         num_rounds = []
         for device_id, series in series_dict.items():