diff --git a/config/default.yaml b/config/default.yaml
index 6ff71a77b1310a81a84df1bd7bfd040813ee1f9d..dac99d0ff6f6b963c60d9cc8e16b0c82df22d217 100644
--- a/config/default.yaml
+++ b/config/default.yaml
@@ -14,6 +14,9 @@ defaults:
   - wandb: default
   - _self_
 
+# If true, devices are run sequentially and their runtime is corrected in post-processing to account for the
+# parallelism that is only simulated. If false, devices are run in parallel.
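+# Example: when launching with Hydra, this flag can be toggled from the command line via the override
+# `simulate_parallelism=True`.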
+simulate_parallelism: False
 own_device_id: "d0"
 num_devices: ${len:${topology.devices}}
 
diff --git a/edml/config/battery/resnet110_cifar100_cost.yaml b/edml/config/battery/resnet110_cifar100_cost.yaml
index d1bcdb98d5d0d98d49eedff71d3cf12b9e1bf9d9..424c8d81b8db9ee8898edb91911d5d9b1f6c26cb 100644
--- a/edml/config/battery/resnet110_cifar100_cost.yaml
+++ b/edml/config/battery/resnet110_cifar100_cost.yaml
@@ -1,4 +1,4 @@
-deduction_per_second: 0.005
-deduction_per_mflop: 0.00000005
-deduction_per_mbyte_received: 0.0002
-deduction_per_mbyte_sent: 0.0002
+deduction_per_second: 0.007
+deduction_per_mflop: 0.00000001
+deduction_per_mbyte_received: 0.00005
+deduction_per_mbyte_sent: 0.00005
diff --git a/edml/config/topology/resnet110_cifar100_batteries.yaml b/edml/config/topology/resnet110_cifar100_batteries.yaml
index 3a81e0087447bd83066bf31840321d1cc2b4413d..61838e0fb17a5e720154b4acb0b10496decf48c9 100644
--- a/edml/config/topology/resnet110_cifar100_batteries.yaml
+++ b/edml/config/topology/resnet110_cifar100_batteries.yaml
@@ -2,26 +2,26 @@ devices: [
   {
     device_id: "d0",
     address: "localhost:50051",
-    battery_capacity: 750,
+    battery_capacity: 300,
   },
   {
     device_id: "d1",
     address: "localhost:50052",
-    battery_capacity: 750
+    battery_capacity: 300
   },
   {
     device_id: "d2",
     address: "localhost:50053",
-    battery_capacity: 600
+    battery_capacity: 300
   },
   {
     device_id: "d3",
     address: "localhost:50054",
-    battery_capacity: 600
+    battery_capacity: 300
   },
   {
     device_id: "d4",
     address: "localhost:50055",
-    battery_capacity: 600
+    battery_capacity: 300
   }
 ]
diff --git a/edml/controllers/fed_controller.py b/edml/controllers/fed_controller.py
index cdf6b6e86d2df4d59f5ea025288fb5e28d3d4efc..250e02c0307e002c2d84c50296a752ec651a8d74 100644
--- a/edml/controllers/fed_controller.py
+++ b/edml/controllers/fed_controller.py
@@ -1,10 +1,13 @@
 import concurrent.futures
+import contextlib
+import functools
 import threading
 from typing import Dict, List
 
 from overrides import override
 
 from edml.controllers.base_controller import BaseController
+from edml.helpers.decorators import Timer
 from edml.helpers.metrics import ModelMetricResultContainer
 
 
@@ -35,6 +38,14 @@ def fed_average(model_weights: list[Dict], weighting_scheme: List[float] = None)
     return None
 
 
+def simulate_parallelism(f):
+    """Placeholder decorator for methods whose parallel execution is only simulated.
+
+    Currently a transparent pass-through; the sequential simulation is switched via ``cfg.simulate_parallelism``.
+    """
+
+    @functools.wraps(f)
+    def inner(self, *args, **kwargs):
+        return f(self, *args, **kwargs)
+
+    return inner
+
+
 class FedController(BaseController):
     """Controller for federated learning."""
 
@@ -43,40 +54,66 @@ class FedController(BaseController):
 
     def _fed_train_round(self, round_no: int = -1):
         """Returns new client and server weights."""
-        client_weights_lock = threading.Lock()
-        server_weights_lock = threading.Lock()
-        samples_count_lock = threading.Lock()
-        metrics_lock = threading.Lock()
         client_weights = []
         server_weights = []
         samples_count = []
         metrics_container = ModelMetricResultContainer()
-        with concurrent.futures.ThreadPoolExecutor(
-            max_workers=max(len(self.active_devices), 1)
-        ) as executor:  # avoid exception when setting 0 workers
-            futures = [
-                executor.submit(
-                    self.request_dispatcher.federated_train_on, device_id, round_no
-                )
-                for device_id in self.active_devices
-            ]
-            for future in concurrent.futures.as_completed(futures):
-                response = future.result()
-                if response is not False:
-                    new_client_weights, new_server_weights, num_samples, metrics, _ = (
-                        response  # skip diagnostic metrics
+
+        if self.cfg.simulate_parallelism:
+            parallel_times = []
+            with Timer() as elapsed_time:
+                for device_id in self.active_devices:
+                    with Timer() as individual_time:
+                        response = self.request_dispatcher.federated_train_on(
+                            device_id, round_no
+                        )
+                        if response is not False:
+                            (
+                                new_client_weights,
+                                new_server_weights,
+                                num_samples,
+                                metrics,
+                                _,
+                            ) = response  # skip diagnostic metrics
+                            client_weights.append(new_client_weights)
+                            server_weights.append(new_server_weights)
+                            samples_count.append(num_samples)
+                            metrics_container.merge(metrics)
+                    parallel_times.append(individual_time.execution_time)
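+            # elapsed_time is the sequential wall-clock total, while max(parallel_times) estimates how long
+            # the round would have taken with truly parallel devices; post-processing rescales runtimes with both.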
+            self.logger.log(
+                {
+                    "parallel_fed_time": {
+                        "elapsed_time": elapsed_time.execution_time,
+                        "parallel_time": max(parallel_times),
+                    }
+                }
+            )
+        else:
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=max(len(self.active_devices), 1)
+            ) as executor:  # avoid exception when setting 0 workers
+                futures = [
+                    executor.submit(
+                        self.request_dispatcher.federated_train_on, device_id, round_no
                     )
-                    with client_weights_lock:
+                    for device_id in self.active_devices
+                ]
+                for future in concurrent.futures.as_completed(futures):
+                    response = future.result()
+                    if response is not False:
+                        (
+                            new_client_weights,
+                            new_server_weights,
+                            num_samples,
+                            metrics,
+                            _,
+                        ) = response  # skip diagnostic metrics
                         client_weights.append(new_client_weights)
-                    with server_weights_lock:
                         server_weights.append(new_server_weights)
-                    with samples_count_lock:
                         samples_count.append(num_samples)
-                    with metrics_lock:
                         metrics_container.merge(metrics)
 
         print(f"samples count {samples_count}")
-
         return (
             fed_average(model_weights=client_weights, weighting_scheme=samples_count),
             fed_average(model_weights=server_weights, weighting_scheme=samples_count),
diff --git a/edml/core/device.py b/edml/core/device.py
index 94bb18ac35241fc86eec778b8ef2d9145a8ed51f..6770e4271e8dc030af1abc7f8ac010a57242f5bb 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -225,6 +225,7 @@ class Device(ABC):
 
 class NetworkDevice(Device):
     @update_battery
+    @log_execution_time("logger", "finalize_gradients")
     def set_gradient_and_finalize_training_on_client_only_on(
         self, client_id: str, gradients: Any
     ):
@@ -404,6 +405,7 @@ class NetworkDevice(Device):
 
     @add_time_to_diagnostic_metrics("evaluate_batch")
     @update_battery
+    @log_execution_time("logger", "evaluate_batch_time")
     def evaluate_batch(self, smashed_data, labels):
         result = self.server.evaluate_batch(smashed_data, labels)
         self._log_current_battery_capacity()
diff --git a/edml/core/server.py b/edml/core/server.py
index 4b9ef80fd876d618075a98a14d61e03d83369b66..4279f40301205c361a8648fe907d9084bd159ac5 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -10,7 +10,7 @@ from torch import nn
 from torch.autograd import Variable
 
 from edml.helpers.config_helpers import get_torch_device_id
-from edml.helpers.decorators import check_device_set, simulate_latency_decorator
+from edml.helpers.decorators import check_device_set, simulate_latency_decorator, Timer
 from edml.helpers.executor import create_executor_with_threads
 from edml.helpers.flops import estimate_model_flops
 from edml.helpers.load_optimizer import get_optimizer_and_scheduler
@@ -260,101 +260,245 @@ class DeviceServer:
         # We iterate over each batch, initializing all client training at once and processing the results afterward.
         num_batches = self.node_device.client.get_approximated_num_batches()
         print(f":: BATCHES :: {num_batches}")
-        for batch_index in range(num_batches):
-            client_forward_pass_responses = []
-            futures = [
-                executor.submit(client_training_job, client_id, batch_index)
-                for client_id in clients
-            ]
-            for future in concurrent.futures.as_completed(futures):
-                (client_id, result) = future.result()
-                if result is not None and result is not False:
-                    client_forward_pass_responses.append((client_id, result))
-
-            # We want to split up the responses into a list of client IDs and batches again.
-            client_ids = [b[0] for b in client_forward_pass_responses]
-            client_batches = [b[1] for b in client_forward_pass_responses]
-
-            server_batch = _concat_smashed_data(
-                [b[0].to(self._device) for b in client_batches]
-            )
-            server_labels = _concat_smashed_data(
-                [b[1].to(self._device) for b in client_batches]
-            )
-            # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need
-            # to split the gradients back into batch-sized tensors to average them before sending them to the client.
-            server_gradients, server_loss, server_metrics = (
-                self.node_device.train_batch(server_batch, server_labels)
-            )  # DiagnosticMetricResultContainer
-            # We check if the server should activate the adaptive learning threshold. And if true, we make sure to only
-            # do the client propagation once the current loss value is larger than the threshold.
-            print(
-                f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
-            )
-            if (
-                adaptive_learning_threshold
-                and server_loss < adaptive_learning_threshold
-            ):
+
+        if self._cfg.simulate_parallelism:
+            for batch_index in range(num_batches):
+                client_forward_pass_responses = []
+                parallel_times = []
+                with Timer() as elapsed_time:
+                    for client_id in clients:
+                        with Timer() as individual_time:
+                            (client_id, result) = client_training_job(
+                                client_id, batch_index
+                            )
+                            if result is not None and result is not False:
+                                client_forward_pass_responses.append(
+                                    (client_id, result)
+                                )
+                        parallel_times.append(individual_time.execution_time)
+                self.node_device.log(
+                    {
+                        "parallel_client_train_time": {
+                            "elapsed_time": elapsed_time.execution_time,
+                            "parallel_time": max(parallel_times),
+                        }
+                    }
+                )
+                # We want to split up the responses into a list of client IDs and batches again.
+                client_ids = [b[0] for b in client_forward_pass_responses]
+                client_batches = [b[1] for b in client_forward_pass_responses]
+
+                server_batch = _concat_smashed_data(
+                    [b[0].to(self._device) for b in client_batches]
+                )
+                server_labels = _concat_smashed_data(
+                    [b[1].to(self._device) for b in client_batches]
+                )
+                # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need
+                # to split the gradients back into batch-sized tensors to average them before sending them to the client.
+                server_gradients, server_loss, server_metrics = (
+                    self.node_device.train_batch(server_batch, server_labels)
+                )  # DiagnosticMetricResultContainer
+                # If the adaptive learning threshold is configured, we only continue with the client
+                # propagation when the current loss is above the threshold; otherwise we skip to the next batch.
                 print(
-                    f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
+                    f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
+                )
+                if (
+                    adaptive_learning_threshold
+                    and server_loss < adaptive_learning_threshold
+                ):
+                    print(
+                        f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
+                    )
+                    self.node_device.log(
+                        {
+                            "adaptive_learning_threshold_applied": server_gradients.size(
+                                0
+                            )
+                        }
+                    )
+                    continue
+
+                num_client_gradients = len(client_forward_pass_responses)
+                print(
+                    f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
+                )
+                # clone single client gradients so that client_gradients is not a list of views of server_gradients
+                # if we just use torch.chunk, each client will receive the whole server_gradients
+                client_gradients = [
+                    t.clone().detach()
+                    for t in torch.chunk(server_gradients, num_client_gradients)
+                ]
+                client_backpropagation_gradients = []
+                parallel_times = []
+                with Timer() as elapsed_time:
+                    for idx, client_id in enumerate(client_ids):
+                        with Timer() as individual_time:
+                            _, grads = client_backpropagation_job(
+                                client_id, client_gradients[idx]
+                            )
+                            if grads is not None and grads is not False:
+                                client_backpropagation_gradients.append(grads)
+                        parallel_times.append(individual_time.execution_time)
+                self.node_device.log(
+                    {
+                        "parallel_client_backprop_time": {
+                            "elapsed_time": elapsed_time.execution_time,
+                            "parallel_time": max(parallel_times),
+                        }
+                    }
+                )
+                # We want to average the client's backpropagation gradients and send them over again to finalize the
+                # current training step.
+                averaged_gradient = _calculate_gradient_mean(
+                    client_backpropagation_gradients, self._device
                 )
+                parallel_times = []
+                with Timer() as elapsed_time:
+                    for client_id in clients:
+                        with Timer() as individual_time:
+                            client_set_gradient_and_finalize_training_job(
+                                client_id, averaged_gradient
+                            )
+                        parallel_times.append(individual_time.execution_time)
                 self.node_device.log(
-                    {"adaptive_learning_threshold_applied": server_gradients.size(0)}
+                    {
+                        "parallel_client_model_update_time": {
+                            "elapsed_time": elapsed_time.execution_time,
+                            "parallel_time": max(parallel_times),
+                        }
+                    }
+                )
+        else:
+            for batch_index in range(num_batches):
+                client_forward_pass_responses = []
+                futures = [
+                    executor.submit(client_training_job, client_id, batch_index)
+                    for client_id in clients
+                ]
+                for future in concurrent.futures.as_completed(futures):
+                    (client_id, result) = future.result()
+                    if result is not None and result is not False:
+                        client_forward_pass_responses.append((client_id, result))
+
+                # We want to split up the responses into a list of client IDs and batches again.
+                client_ids = [b[0] for b in client_forward_pass_responses]
+                client_batches = [b[1] for b in client_forward_pass_responses]
+
+                server_batch = _concat_smashed_data(
+                    [b[0].to(self._device) for b in client_batches]
                 )
-                continue
+                server_labels = _concat_smashed_data(
+                    [b[1].to(self._device) for b in client_batches]
+                )
+                # Train the part on the server. Then send the gradients to each client, continuing the calculation. We need
+                # to split the gradients back into batch-sized tensors to average them before sending them to the client.
+                server_gradients, server_loss, server_metrics = (
+                    self.node_device.train_batch(server_batch, server_labels)
+                )  # DiagnosticMetricResultContainer
+                # If the adaptive learning threshold is configured, we only continue with the client
+                # propagation when the current loss is above the threshold; otherwise we skip to the next batch.
+                print(
+                    f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
+                )
+                if (
+                    adaptive_learning_threshold
+                    and server_loss < adaptive_learning_threshold
+                ):
+                    print(
+                        f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
+                    )
+                    self.node_device.log(
+                        {
+                            "adaptive_learning_threshold_applied": server_gradients.size(
+                                0
+                            )
+                        }
+                    )
+                    continue
 
-            num_client_gradients = len(client_forward_pass_responses)
-            print(
-                f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
-            )
-            # clone single client gradients so that client_gradients is not a list of views of server_gradients
-            # if we just use torch.chunk, each client will receive the whole server_gradients
-            client_gradients = [
-                t.clone().detach()
-                for t in torch.chunk(server_gradients, num_client_gradients)
-            ]
-            futures = [
-                executor.submit(
-                    client_backpropagation_job, client_id, client_gradients[idx]
+                num_client_gradients = len(client_forward_pass_responses)
+                print(
+                    f"::: tensor shape: {server_gradients.shape} -> {server_gradients.size(0)} with metrics: {server_metrics is not None}"
                 )
-                for (idx, client_id) in enumerate(client_ids)
-            ]
-            client_backpropagation_gradients = []
-            for future in concurrent.futures.as_completed(futures):
-                _, grads = future.result()
-                if grads is not None and grads is not False:
-                    client_backpropagation_gradients.append(grads)
-            # We want to average the client's backpropagation gradients and send them over again to finalize the
-            # current training step.
-            averaged_gradient = _calculate_gradient_mean(
-                client_backpropagation_gradients, self._device
-            )
-            futures = [
-                executor.submit(
-                    client_set_gradient_and_finalize_training_job,
-                    client_id,
-                    averaged_gradient,
+                # clone single client gradients so that client_gradients is not a list of views of server_gradients
+                # if we just use torch.chunk, each client will receive the whole server_gradients
+                client_gradients = [
+                    t.clone().detach()
+                    for t in torch.chunk(server_gradients, num_client_gradients)
+                ]
+                futures = [
+                    executor.submit(
+                        client_backpropagation_job, client_id, client_gradients[idx]
+                    )
+                    for (idx, client_id) in enumerate(client_ids)
+                ]
+                client_backpropagation_gradients = []
+                for future in concurrent.futures.as_completed(futures):
+                    _, grads = future.result()
+                    if grads is not None and grads is not False:
+                        client_backpropagation_gradients.append(grads)
+                # We want to average the client's backpropagation gradients and send them over again to finalize the
+                # current training step.
+                averaged_gradient = _calculate_gradient_mean(
+                    client_backpropagation_gradients, self._device
                 )
-                for client_id in clients
-            ]
-            for future in concurrent.futures.as_completed(futures):
-                future.result()
+                futures = [
+                    executor.submit(
+                        client_set_gradient_and_finalize_training_job,
+                        client_id,
+                        averaged_gradient,
+                    )
+                    for client_id in clients
+                ]
+                for future in concurrent.futures.as_completed(futures):
+                    future.result()
 
         # Now we have to determine the model metrics for each client.
-        for client_id in clients:
-            train_metrics = self.finalize_metrics(str(client_id), "train")
 
-            evaluation_diagnostics_metrics = self.node_device.evaluate_on(
-                device_id=client_id,
-                server_device=self.node_device.device_id,
-                val=True,
+        if self._cfg.simulate_parallelism:
+            parallel_times = []
+            with Timer() as elapsed_time:
+                for client_id in clients:
+                    with Timer() as individual_time:
+                        train_metrics = self.finalize_metrics(str(client_id), "train")
+
+                        evaluation_diagnostics_metrics = self.node_device.evaluate_on(
+                            device_id=client_id,
+                            server_device=self.node_device.device_id,
+                            val=True,
+                        )
+                        # if evaluation_diagnostics_metrics:
+                        #     diagnostic_metrics.merge(evaluation_diagnostics_metrics)
+                        val_metrics = self.finalize_metrics(str(client_id), "val")
+
+                        model_metrics.add_results(train_metrics)
+                        model_metrics.add_results(val_metrics)
+                    parallel_times.append(individual_time.execution_time)
+            self.node_device.log(
+                {
+                    "parallel_client_eval_time": {
+                        "elapsed_time": elapsed_time.execution_time,
+                        "parallel_time": max(parallel_times),
+                    }
+                }
             )
-            # if evaluation_diagnostics_metrics:
-            #     diagnostic_metrics.merge(evaluation_diagnostics_metrics)
-            val_metrics = self.finalize_metrics(str(client_id), "val")
+        else:
+            for client_id in clients:
+                train_metrics = self.finalize_metrics(str(client_id), "train")
+
+                evaluation_diagnostics_metrics = self.node_device.evaluate_on(
+                    device_id=client_id,
+                    server_device=self.node_device.device_id,
+                    val=True,
+                )
+                # if evaluation_diagnostics_metrics:
+                #     diagnostic_metrics.merge(evaluation_diagnostics_metrics)
+                val_metrics = self.finalize_metrics(str(client_id), "val")
 
-            model_metrics.add_results(train_metrics)
-            model_metrics.add_results(val_metrics)
+                model_metrics.add_results(train_metrics)
+                model_metrics.add_results(val_metrics)
 
         optimizer_state = self._optimizer.state_dict()
         # delete references and free GPU memory manually
diff --git a/edml/helpers/decorators.py b/edml/helpers/decorators.py
index 5623648311abd9e63c51281af657649ea950459e..6f4a8da8912f989163ddca9f829b428800db3749 100644
--- a/edml/helpers/decorators.py
+++ b/edml/helpers/decorators.py
@@ -165,6 +165,22 @@ class LatencySimulator:
         time.sleep(self.execution_time * self.latency_factor)
 
 
+class Timer:
+    """
+    Context Manager to measure execution time.
+
+    Notes:
+        Access execution time via Timer.execution_time.
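+
+    Example:
+        Illustrative usage (``time.sleep`` stands in for real work)::
+
+            with Timer() as t:
+                time.sleep(0.1)
+            print(t.execution_time)  # roughly 0.1 seconds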
+    """
+
+    def __enter__(self):
+        self.start_time = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.execution_time = time.perf_counter() - self.start_time
+
+
 def add_time_to_diagnostic_metrics(method_name: str):
     """
     A decorator factory that measures the execution time of the wrapped method. It then creates a diagnostic metric
diff --git a/results/result_generation.ipynb b/results/result_generation.ipynb
index 85b37ddb464b85b1e634dcbcde25c0a221baa9fb..5c3ed118508856d66ca7b70f0027b99a2cb8052f 100644
--- a/results/result_generation.ipynb
+++ b/results/result_generation.ipynb
@@ -29,30 +29,15 @@
    "source": [
     "df_base_dir = \"./dataframes\"\n",
     "projects_with_model = [\n",
-    "    (\"greedy_vs_smart_ecg-non-iid_RESULT\", \"tcn\"),\n",
-    "    (\"greedy_vs_smart_cifar100_resnet20_RESULT\", \"resnet20\"),\n",
-    "    (\"greedy_vs_smart_ecg-iid_RESULT\", \"tcn\"),\n",
-    "    (\"greedy_vs_smart_PTBXL_equal_devices_RESULT\", \"tcn\"),\n",
-    "    (\"greedy_vs_smart_PTBXL_unequal_processors_RESULT\", \"tcn\"),\n",
-    "    (\"greedy_vs_smart_MNIST_unequal_processors_RESULT\", \"simple_conv\"),\n",
-    "    (\"greedy_vs_smart_MNIST_unequal_batteries_unequal_partition_RESULT\", \"simple_conv\"),\n",
-    "    (\"greedy_vs_smart_MNIST_equal_devices_RESULT\", \"simple_conv\"),\n",
-    "    (\"greedy_vs_smart_MNIST_unequal_batteries_RESULT\", \"simple_conv\"),\n",
-    "    (\"fed_vs_split_MNIST_limited_batteries_RESULT\", \"simple_conv\"),\n",
-    "    (\"fed_vs_split_MNIST_unlimited_batteries_RESULT\", \"simple_conv\"),\n",
-    "    (\"fed_vs_split_PTBXL_limited_batteries_RESULT\", \"tcn\"),\n",
-    "    (\"fed_vs_split_PTBXL_unlimited_batteries_RESULT\", \"tcn\"),\n",
-    "    (\"fed_vs_split_cifar100_unlimited_batteries_RESULT\", \"resnet20\"),\n",
-    "    (\"fed_vs_split_CIFAR100_limited_batteries_RESULT\", \"resnet20\"),\n",
-    "    (\"fed_vs_split_50_devices_RESULT\", \"resnet110\"),\n",
-    "    (\"greedy_vs_smart_CIFAR100_equal_devices_RESULT\", \"resnet20\"),\n",
-    "    (\"greedy_vs_smart_CIFAR100_unequal_processors_RESULT\", \"resnet20\"),\n",
+    "    (\"5_devices_unlimited_new\", \"resnet110\"),\n",
+    "    (\"50_devices_unlimited_new\", \"resnet110\"),\n",
+    "    (\"controller_comparison\", \"resnet110\")\n",
     "]"
    ],
    "metadata": {
     "collapsed": false
    },
-   "id": "6695251b9af7ea4b"
+   "id": "5b81c8c9ba4b483d"
   },
   {
    "cell_type": "code",
@@ -61,18 +46,33 @@
    "source": [
     "for project_name, _ in projects_with_model:\n",
     "    save_dataframes(project_name=project_name, strategies=[\n",
-    "        \"swarm_seq\",\n",
-    "        \"fed\",\n",
-    "        \"swarm_max\",\n",
-    "        \"swarm_rand\",\n",
-    "        \"swarm_smart\",\n",
-    "        \"split\"\n",
+    "        #\"swarm_seq\",\n",
+    "        #\"fed\",\n",
+    "        #\"swarm_max\",\n",
+    "        #\"swarm_rand\",\n",
+    "        #\"swarm_smart\",\n",
+    "        #\"split\",\n",
+    "        #\"psl_rand_\",\n",
+    "        #\"psl_sequential_\",\n",
+    "        #\"psl_max_batteries_\",\n",
+    "        #\"swarm_rand_\",\n",
+    "        #\"swarm_sequential_\",\n",
+    "        #\"swarm_max_batteries_\",\n",
+    "        \"psl_sequential__\",\n",
+    "        \"fed___\",\n",
+    "        \"split___\",\n",
+    "        \"swarm_sequential__\",\n",
+    "        \"swarm_max_battery__\",\n",
+    "        \"swarm_smart__\",\n",
+    "        \"psl_sequential_static_at_resnet_decoderpth\",\n",
+    "        \"psl_sequential__resnet_decoderpth\",\n",
+    "        \"psl_sequential_static_at_\",\n",
     "    ])"
    ],
    "metadata": {
     "collapsed": false
    },
-   "id": "b07913828b33ffcc"
+   "id": "118f1ed9e7537718"
   },
   {
    "cell_type": "markdown",
@@ -82,7 +82,7 @@
    "metadata": {
     "collapsed": false
    },
-   "id": "d8269abd823cdcc7"
+   "id": "bbc47124f3c80f1c"
   },
   {
    "cell_type": "code",
@@ -92,28 +92,45 @@
     "# Required for total number of FLOPs computation\n",
     "model_flops = {\n",
     "    \"resnet20\": 41498880,\n",
+    "    \"resnet20_ae\": 45758720,\n",
     "    \"resnet110\": 258136320,\n",
+    "    \"resnet110_ae\": 262396160,\n",
     "    \"tcn\": 27240000,\n",
     "    \"simple_conv\": 16621560\n",
-    "}"
+    "}\n",
+    "\n",
+    "client_model_flops = {\n",
+    "    \"resnet20\": 15171584,\n",
+    "    \"resnet20_ae\": 19005440,\n",
+    "    \"resnet110\": 88408064,\n",
+    "    \"resnet110_ae\": 92241920,\n",
+    "}\n",
+    "\n",
+    "server_model_flops = {\n",
+    "    \"resnet20\": 26327296,\n",
+    "    \"resnet20_ae\": 26753280,\n",
+    "    \"resnet110\": 169728256,\n",
+    "    \"resnet110_ae\": 170154240,\n",
+    "}\n",
+    "experiment_batch_size = 64"
    ],
    "metadata": {
     "collapsed": false
    },
-   "id": "828bcb4737b21c6d"
+   "id": "6e4e3bd0198fe7e7"
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "outputs": [],
    "source": [
-    "plots_base_path=\"./plots\"\n",
-    "metrics_base_path=\"./metrics\""
+    "plots_base_path = \"./plots\"\n",
+    "metrics_base_path = \"./metrics\""
    ],
    "metadata": {
     "collapsed": false
    },
-   "id": "b13f9e0e98b7ac5b"
+   "id": "ede70693af668f55"
   },
   {
    "cell_type": "code",
@@ -126,14 +143,16 @@
     "        print(\"  loading data from disk\")\n",
     "        dataframes = load_dataframes(proj_name, df_base_dir)\n",
     "        print(\"  generating metrics\")\n",
-    "        generate_metric_files(dataframes, proj_name, model_flops[model_name])\n",
+    "        generate_metric_files(dataframes, proj_name, model_flops[model_name], client_model_flops[model_name],\n",
+    "                              # TODO distinguish AE\n",
+    "                              base_path=metrics_base_path, batch_size=experiment_batch_size)\n",
     "        print(\"  generating plots\")\n",
     "        generate_plots(dataframes, proj_name)"
    ],
    "metadata": {
     "collapsed": false
    },
-   "id": "c4b1ed2d809c54e2"
+   "id": "1c72379feadc98cb"
   },
   {
    "cell_type": "code",
@@ -145,7 +164,17 @@
    "metadata": {
     "collapsed": false
    },
-   "id": "378bf3365dd9fde2"
+   "id": "7927831aecd5a02"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "fc698fc664867532"
   }
  ],
  "metadata": {
diff --git a/results/result_generation.py b/results/result_generation.py
index b95183ebae7218c3c6ae7453b6bc103d69ba163e..caa42642976785d82f0313733b324c9cd63b1166 100644
--- a/results/result_generation.py
+++ b/results/result_generation.py
@@ -5,6 +5,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import wandb
+import math
 
 # For plotting
 STRATEGY_MAPPING = {
@@ -14,6 +15,16 @@ STRATEGY_MAPPING = {
     "swarm_rand": "Swarm SL (Rand)",
     "swarm_max": "Swarm SL (Greedy)",
     "fed": "Vanilla FL",
+    "psl_sequential__": "PSSL (Seq)",
+    "fed___": "Vanilla FL",
+    "swarm_sequential__": "Swarm SL (Seq)",
+    "swarm_smart__": "Swarm SL (Smart)",
+    "swarm_rand__": "Swarm SL (Rand)",
+    "swarm_max_battery__": "Swarm SL (Greedy)",
+    "split___": "Vanilla SL",
+    "psl_sequential_static_at_resnet_decoderpth": "PSSL (Seq) AE Static",
+    "psl_sequential__resnet_decoderpth": "PSSL (Seq) AE",
+    "psl_sequential_static_at_": "PSSL (Seq) Static",
 }
 
 LABEL_MAPPING = {
@@ -29,6 +40,75 @@ LABEL_MAPPING = {
 }
 
 
+def scale_parallel_time(run_df, scale_factor=1.0):
+    """
+    Scales all timestamps and durations in run_df by the provided scale_factor.
+    Args:
+        run_df: the dataframe of a single run
+        scale_factor: (float) the factor by which to shorten the time, e.g. 2 halves the total time
+    Returns:
+        run_df: the dataframe with scaled timestamps
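+    Example:
+        Illustrative numbers: with scale_factor=2, a timestamp recorded 80 seconds after the run start is
+        moved to 40 seconds after the start, and every ".duration" / "_runtime" value is halved.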
+    """
+    if scale_factor == 1:
+        return run_df
+    if "_timestamp" in run_df.columns:
+        start_time = run_df["_timestamp"].min()
+        for col in run_df.columns:
+            if col.endswith(".start") or col.endswith(".end") or col == "_timestamp":
+                run_df[col] = (run_df[col] - start_time) / scale_factor + start_time
+            if col.endswith(".duration") or col == "_runtime":
+                run_df[col] = run_df[col] / scale_factor
+    return run_df
+
+
+def get_scale_factors(group):
+    """
+    Determines the scale factors needed to correct for simulated parallelism.
+    For each set of runs (i.e. one run of controller, d0, d1, ...), the time overhead introduced by running
+    parallel operations sequentially is determined, along with the factor by which those runs must be scaled down.
+    If no parallel operations were simulated, no time is deducted and the scale factor equals 1.
+    Args:
+        group: the group of runs
+    Returns:
+        A list of factors to scale down each set of runs.
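+    Example:
+        Illustrative numbers: if a run's wall-clock time is 100s, of which 30s were spent executing the
+        simulated-parallel operations sequentially (elapsed_time) while their per-step maxima sum to 10s
+        (parallel_time), the corrected runtime is 100 - 30 + 10 = 80s and the scale factor is 100 / 80 = 1.25.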
+    """
+    columns_to_count = [
+        "parallel_client_train_time",
+        "parallel_client_backprop_time",
+        "parallel_client_model_update_time",
+        "parallel_fed_time",
+    ]
+    scale_factors = []
+    num_runs = len(next(iter(group.values())))
+    max_runtime = [0] * num_runs
+    elapsed_time = [0] * num_runs
+    parallel_time = [0] * num_runs
+    for name, runs in group.items():
+        for i, run_df in enumerate(runs):
+            if "_runtime" in run_df.columns:  # assure that run_df is not empty
+                if run_df["_runtime"].max() > max_runtime[i]:
+                    max_runtime[i] = run_df["_runtime"].max()
+                for col_name in columns_to_count:
+                    if f"{col_name}.parallel_time" in run_df.columns:
+                        elapsed_time[i] += run_df[f"{col_name}.elapsed_time"].sum()
+                        parallel_time[i] += run_df[f"{col_name}.parallel_time"].sum()
+                if "parallel_client_eval_time.parallel_time" in run_df.columns:
+                    if "evaluate_batch_time.duration" in run_df.columns:
+                        elapsed_time[i] += run_df[
+                            "parallel_client_eval_time.elapsed_time"
+                        ].sum()
+                        parallel_time[i] += (
+                            run_df["parallel_client_eval_time.parallel_time"].sum()
+                            - run_df["evaluate_batch_time.duration"].sum()
+                        )  # evaluate_batch runs on the server, so it is sequential even with real parallelism
+    for i, max_rt in enumerate(max_runtime):
+        if max_rt > 0:
+            scale_factors.append(max_rt / (max_rt - elapsed_time[i] + parallel_time[i]))
+        else:
+            scale_factors.append(1.0)
+    return scale_factors
+
+
 def save_dataframes(project_name, strategies, base_dir="./dataframes"):
     """
     Fetches the dataframes from wandb and saves them to the base_dir.
@@ -81,13 +161,22 @@ def save_dataframes(project_name, strategies, base_dir="./dataframes"):
     history_groups = {}
     for (strategy, job), group in run_groups.items():
         print(f"  {strategy} {job}")
-        history = defaultdict(list)
+        unscaled_runs = defaultdict(list)
         for name, runs in group.items():
             print(f"    {name}")
             for run in runs:
                 history_df = pd.DataFrame(run.scan_history())
-                history[name].append(history_df)
-        history_groups[(strategy, job)] = history
+                unscaled_runs[name].append(history_df)
+        # rescale if parallelism was only simulated
+        if job == "train" and len(unscaled_runs) > 0:
+            scale_factors = get_scale_factors(unscaled_runs)
+            scaled_runs = defaultdict(list)
+            for name, runs in unscaled_runs.items():
+                for i, run in enumerate(runs):
+                    scaled_runs[name].append(scale_parallel_time(run, scale_factors[i]))
+            history_groups[(strategy, job)] = scaled_runs
+        else:
+            history_groups[(strategy, job)] = unscaled_runs
     # save dataframe
     print("saving data")
     for (strategy, job), group in history_groups.items():
@@ -105,7 +194,7 @@ def save_dataframes(project_name, strategies, base_dir="./dataframes"):
 
 def load_dataframes(project_name, base_dir="./dataframes"):
     """
-    Loades saved dataframes from the given project.
+    Loads saved dataframes from the given project.
     Args:
         project_name: (str) the name of the project folder
         base_dir: (str) the base directory to fetch the dataframes from
@@ -130,7 +219,7 @@ def load_dataframes(project_name, base_dir="./dataframes"):
                 project_dir, strategy, job, device_id = path.split(os.sep)
 
                 # Load dataframe from csv
-                df = pd.read_csv(os.path.join(root, file))
+                df = pd.read_csv(os.path.join(root, file), low_memory=False)
 
                 # Add dataframe to dictionary
                 if (strategy, job) not in history_groups:
@@ -141,12 +230,13 @@ def load_dataframes(project_name, base_dir="./dataframes"):
     return history_groups
 
 
-def get_total_flops(groups, total_model_flops):
+def get_total_flops(groups, total_model_flops, client_model_flops, batch_size=64):
     """
     Returns the total number of FLOPs for each group.
     Args:
         groups: The runs of one project, according to the structure of the wandb project
         total_model_flops: (int) the total number of FLOPs of the model
+        client_model_flops: (int) the number of FLOPs of the client-side part of the model
+        batch_size: (int) the batch size used in the experiments
     Returns:
         flops_per_group: (dict) the total number of FLOPs for each group
     """
@@ -155,6 +245,7 @@ def get_total_flops(groups, total_model_flops):
         if job == "train":
             flops = 0
             num_runs = 1  # avoid division by 0
+            num_clients = len(group.items()) - 1  # minus controller
             for name, runs in group.items():
                 if (
                     name != "controller"
@@ -170,6 +261,38 @@ def get_total_flops(groups, total_model_flops):
                                 flops += (
                                     run_df[col_name].sum() * total_model_flops
                                 )  # 1x forward
+                            if col_name == "adaptive_learning_threshold_applied":
+                                # deduct client model FLOPs twice, as the skipped client backward
+                                # pass is counted as roughly 2x a forward pass
+                                if (
+                                    run_df[col_name].dtype == "object"
+                                ):  # if boolean values were logged
+                                    # assumptions: compute avg number of samples per batch
+                                    avg_samples_per_epoch = sum(
+                                        run_df["train_accuracy.num_samples"].dropna()
+                                    ) / len(
+                                        run_df["train_accuracy.num_samples"].dropna()
+                                    )
+                                    avg_num_batches = (
+                                        math.ceil(
+                                            avg_samples_per_epoch
+                                            / num_clients
+                                            / batch_size
+                                        )
+                                        * num_clients
+                                    )
+                                    avg_samples_per_batch = (
+                                        avg_samples_per_epoch / avg_num_batches
+                                    )
+                                    flops -= (
+                                        len(run_df[col_name].dropna())
+                                        * client_model_flops
+                                        * 2
+                                        * avg_samples_per_batch
+                                    )
+                                else:  # numbers of samples skipped are logged -> sum up
+                                    flops -= (
+                                        run_df[col_name].sum() * client_model_flops * 2
+                                    )
             flops = flops / num_runs
             flops_per_group["strategy"].append(STRATEGY_MAPPING[strategy])
             flops_per_group["flops"].append(round(flops / 1000000000, 3))  # in GFLOPs
@@ -341,6 +464,27 @@ def accuracy_over_epoch(history_groups, phase="train"):
     return results
 
 
+def accuracy_over_time(history_groups, phase="train"):
+    """
+    Returns the accuracy over time for each group. No averaging implemented yet if there are multiple runs per group!
+    Args:
+        history_groups: The runs of one project, according to the structure of the wandb project
+        phase: (str) the phase to get the accuracy for, either "train" or "val"
+    Returns:
+        results: (dict) the accuracy (list(float)) per round (list(int)) for each group
+    """
+    results = {}
+    for (strategy, job), group in history_groups.items():
+        if job == "train":
+            run_df = group["controller"][0]  # no averaging
+            time_acc = run_df[[f"{phase}_accuracy.value", "_runtime"]].dropna()
+            results[(strategy, job)] = (
+                time_acc["_runtime"],
+                time_acc[f"{phase}_accuracy.value"],
+            )
+    return results
+
+
 def plot_accuracies(accuracies_per_round, save_path=None, phase="train"):
     """
     Plots the accuracy over the epoch for each group.
@@ -366,6 +510,28 @@ def plot_accuracies(accuracies_per_round, save_path=None, phase="train"):
         plt.close()
 
 
+def plot_accuracies_over_time(accuracies_per_time, save_path=None, phase="train"):
+    """
+    Plots the accuracy over the time for each group.
+    Args:
+        accuracies_per_time: (dict) the accuracy (list(float)) per time (list(float)) for each group
+        save_path: (str) the path to save the plot to
+        phase: (str) the phase to plot the accuracy for, either "train" or "val"
+    """
+    plt.figure()
+    for (strategy, job), (time, accs) in accuracies_per_time.items():
+        plt.plot(time, accs, label=f"{STRATEGY_MAPPING[strategy]}")
+    plt.xlabel(LABEL_MAPPING["runtime"])
+    plt.ylabel(LABEL_MAPPING[f"{phase} accuracy"])
+    plt.legend()
+    plt.tight_layout()
+    if save_path is None:
+        plt.show()
+    else:
+        plt.savefig(f"{save_path}.pdf", format="pdf")
+        plt.savefig(f"{save_path}.png", format="png")
+        plt.close()
+
+
 def battery_over_time(history_groups, num_intervals=1000):
     """
     Returns the average battery over time for each group.
@@ -1028,20 +1194,6 @@ def generate_plots(history_groups, project_name, base_path="./plots"):
         aggregated=False,
     )
 
-    train_times = get_train_times(history_groups)
-    plot_batteries_over_time_with_activity(
-        batteries_over_time,
-        max_runtimes,
-        train_times,
-        save_path=f"{project_path}/activity_over_time",
-    )
-    plot_batteries_over_epoch_with_activity_at_time_scale(
-        batteries_over_time,
-        max_runtimes,
-        train_times,
-        save_path=f"{project_path}/activity_over_time_with_epoch",
-    )
-
     # batteries over epoch
     batteries_over_epoch = aggregated_battery_over_epoch(
         history_groups, num_intervals=1000
@@ -1058,13 +1210,6 @@ def generate_plots(history_groups, project_name, base_path="./plots"):
         save_path=f"{project_path}/batteries_over_epoch",
         aggregated=False,
     )
-    training_times = get_train_times(history_groups)
-    plot_batteries_over_epoch_with_activity_at_epoch_scale(
-        batteries_over_epoch,
-        training_times=training_times,
-        save_path=f"{project_path}/activity_over_epoch",
-    )
-
     # remaining devices
     remaining_devices = remaining_devices_per_round(history_groups)
     plot_remaining_devices(
@@ -1080,16 +1225,32 @@ def generate_plots(history_groups, project_name, base_path="./plots"):
     val_accs = accuracy_over_epoch(history_groups, "val")
     plot_accuracies(val_accs, save_path=f"{project_path}/val_accuracy", phase="val")
 
+    time_train_accs = accuracy_over_time(history_groups, "train")
+    plot_accuracies_over_time(
+        time_train_accs,
+        save_path=f"{project_path}/train_accuracy_over_time",
+        phase="train",
+    )
+    time_val_accs = accuracy_over_time(history_groups, "val")
+    plot_accuracies_over_time(
+        time_val_accs, save_path=f"{project_path}/val_accuracy_over_time", phase="val"
+    )
+
 
 def generate_metric_files(
-    history_groups, project_name, model_flops, base_path="./metrics"
+    history_groups,
+    project_name,
+    total_model_flops,
+    client_model_flops,
+    base_path="./metrics",
+    batch_size=64,
 ):
     """
     Generates metric file for the given history groups and saves them to the project_name folder.
     Args:
         history_groups: The runs of one project, according to the structure of the wandb project
         project_name: (str) the name of the project
-        model_flops: (int) the total number of FLOPs of the model
+        total_model_flops: (int) the total number of FLOPs of the model
+        client_model_flops: (int) the number of FLOPs of the client-side part of the model
+        batch_size: (int) the batch size used in the experiments
         base_path: (str) the base path to save the metrics to
     """
     project_path = f"{base_path}/{project_name}"
@@ -1103,7 +1264,9 @@ def generate_metric_files(
         get_communication_overhead(history_groups)
     ).set_index("strategy")
     total_flops = pd.DataFrame.from_dict(
-        get_total_flops(history_groups, model_flops)
+        get_total_flops(
+            history_groups, total_model_flops, client_model_flops, batch_size
+        )
     ).set_index("strategy")
     df = pd.concat([test_acc, comm_overhead, total_flops], axis=1)
     df.to_csv(f"{project_path}/metrics.csv")