diff --git a/config/default.yaml b/config/default.yaml
index dac99d0ff6f6b963c60d9cc8e16b0c82df22d217..cf51f4118495971a2b319f335d6844cc6cb1351f 100644
--- a/config/default.yaml
+++ b/config/default.yaml
@@ -14,8 +14,9 @@ defaults:
   - wandb: default
   - _self_
 
-# If true, controllers will run devices in parallel. If false, they will run sequentially and their runtime is corrected
+# If False, controllers will run devices in parallel. If True, they will run sequentially and their runtime is corrected
 # to account for the parallelism in post-processing.
+# Important note: in a limited-energy setting, the wall-clock runtime will not be accounted for correctly if parallelism is only simulated.
 simulate_parallelism: False
 own_device_id: "d0"
 num_devices: ${len:${topology.devices}}
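
For context on the note above: with simulate_parallelism: True the devices run one after another, so the recorded runtimes add up and a post-processing step has to recover the parallel wall time. A minimal sketch of such a correction, with invented names and values (not the edml API); the limited-energy caveat presumably exists because the battery drains over the longer sequential run before any correction can be applied:

```python
# Hypothetical post-processing correction for simulate_parallelism: True.
# Sequential execution sums per-device runtimes; real parallel execution is
# bounded by the slowest device, so the corrected wall time is the maximum.
def corrected_wall_time(device_runtimes: dict[str, float]) -> float:
    return max(device_runtimes.values())


print(corrected_wall_time({"d0": 12.4, "d1": 9.7, "d2": 11.1}))  # 12.4
```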
diff --git a/edml/config/controller/swarm.yaml b/edml/config/controller/swarm.yaml
index f31a60a20c1e4cff7f8eec4bd9b1ff4958753ab5..39bb393eef5bf2b73a6aa5c755c9f41ad765809e 100644
--- a/edml/config/controller/swarm.yaml
+++ b/edml/config/controller/swarm.yaml
@@ -3,3 +3,4 @@ _target_: edml.controllers.swarm_controller.SwarmController
 _partial_: true
 defaults:
   - scheduler: sequential
+  - adaptive_threshold_fn: !!null
diff --git a/edml/controllers/parallel_split_controller.py b/edml/controllers/parallel_split_controller.py
index 82bf519a06dc456d79ba25368914c376236275c9..05e5e49685086fcc47b7af390a64c825118abbe0 100644
--- a/edml/controllers/parallel_split_controller.py
+++ b/edml/controllers/parallel_split_controller.py
@@ -49,13 +49,13 @@ class ParallelSplitController(BaseController):
                 )
 
             # Start parallel training of all client devices.
-            adaptive_threshold = self._adaptive_threshold_fn.invoke(i)
-            self.logger.log({"adaptive-threshold": adaptive_threshold})
+            adaptive_threshold_value = self._adaptive_threshold_fn.invoke(i)
+            self.logger.log({"adaptive-threshold": adaptive_threshold_value})
             training_response = self.request_dispatcher.train_parallel_on_server(
                 server_device_id=server_device_id,
                 epochs=1,
                 round_no=i,
-                adaptive_learning_threshold=adaptive_threshold,
+                adaptive_threshold_value=adaptive_threshold_value,
                 optimizer_state=optimizer_state,
             )
 
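
The controller now resolves the threshold anew each round via invoke(i). Only StaticAdaptiveThresholdFn appears in this diff; the interface below is inferred from the call sites, and the decaying variant is purely illustrative:

```python
# Sketch of the adaptive-threshold interface as implied by the call sites
# (invoke(round_no) -> float). Only StaticAdaptiveThresholdFn is evidenced
# by this diff; ExponentialDecayThresholdFn is an invented example.
from abc import ABC, abstractmethod


class AdaptiveThresholdFn(ABC):
    @abstractmethod
    def invoke(self, round_no: int) -> float: ...


class StaticAdaptiveThresholdFn(AdaptiveThresholdFn):
    def __init__(self, value: float):
        self._value = value

    def invoke(self, round_no: int) -> float:
        return self._value


class ExponentialDecayThresholdFn(AdaptiveThresholdFn):
    """Start permissive and tighten the loss threshold as rounds progress."""

    def __init__(self, start: float, decay: float = 0.9):
        self._start, self._decay = start, decay

    def invoke(self, round_no: int) -> float:
        return self._start * self._decay**round_no
```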
diff --git a/edml/controllers/scheduler/smart.py b/edml/controllers/scheduler/smart.py
index 48940331247f647a855d72bf3e8a54b44f13ba78..eddf1e98f9b18b0589ff0f61782e54a98ccdd130 100644
--- a/edml/controllers/scheduler/smart.py
+++ b/edml/controllers/scheduler/smart.py
@@ -90,8 +90,10 @@ class SmartNextServerScheduler(NextServerScheduler):
             device_params_list.append(device_params)
         global_params = GlobalParams()
         global_params.fill_values_from_config(self.cfg)
-        global_params.client_model_flops = model_flops["client"]
-        global_params.server_model_flops = model_flops["server"]
+        global_params.client_fw_flops = model_flops["client"]["FW"]
+        global_params.server_fw_flops = model_flops["server"]["FW"]
+        global_params.client_bw_flops = model_flops["client"]["BW"]
+        global_params.server_bw_flops = model_flops["server"]["BW"]
         global_params.client_norm_fw_time = optimization_metrics["client_norm_fw_time"]
         global_params.client_norm_bw_time = optimization_metrics["client_norm_bw_time"]
         global_params.server_norm_fw_time = optimization_metrics["server_norm_fw_time"]
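
For reference, model_flops now arrives in the nested FW/BW layout built by the swarm controller (see _get_dataset_sizes_and_model_flops further down); a sketch of the expected shape with invented numbers:

```python
model_flops = {
    "client": {"FW": 1_200_000, "BW": 2_400_000},
    "server": {"FW": 8_500_000, "BW": 17_000_000},
}
```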
diff --git a/edml/controllers/strategy_optimization.py b/edml/controllers/strategy_optimization.py
index cd7bae8ac0db1c8eb4bcf818713ed8b2f1ed0e47..7c11f73652cf1aecd792a26e603a099c488b9943 100644
--- a/edml/controllers/strategy_optimization.py
+++ b/edml/controllers/strategy_optimization.py
@@ -29,8 +29,10 @@ class GlobalParams:
         cost_per_byte_sent=None,
         cost_per_byte_received=None,
         cost_per_flop=None,
-        client_model_flops=None,
-        server_model_flops=None,
+        client_fw_flops=None,  # mandatory forward FLOPs
+        server_fw_flops=None,  # mandatory forward FLOPs
+        client_bw_flops=None,  # optional backward FLOPs if bw FLOPs != 2 * fw FLOPs (e.g., for autoencoders)
+        server_bw_flops=None,  # optional backward FLOPs if bw FLOPs != 2 * fw FLOPs (e.g., for autoencoders)
         smashed_data_size=None,
         label_size=None,
         gradient_size=None,
@@ -49,8 +51,10 @@ class GlobalParams:
         self.cost_per_byte_sent = cost_per_byte_sent
         self.cost_per_byte_received = cost_per_byte_received
         self.cost_per_flop = cost_per_flop
-        self.client_model_flops = client_model_flops
-        self.server_model_flops = server_model_flops
+        self.client_fw_flops = client_fw_flops
+        self.server_fw_flops = server_fw_flops
+        self.client_bw_flops = client_bw_flops  # optional backward FLOPs if bw FLOPs != 2 * fw FLOPs (e.g., for autoencoders)
+        self.server_bw_flops = server_bw_flops  # optional backward FLOPs if bw FLOPs != 2 * fw FLOPs (e.g., for autoencoders)
         self.smashed_data_size = smashed_data_size
         self.batch_size = batch_size
         self.client_weights_size = client_weights_size
@@ -79,8 +83,8 @@ class GlobalParams:
             and self.cost_per_byte_sent is not None
             and self.cost_per_byte_received is not None
             and self.cost_per_flop is not None
-            and self.client_model_flops is not None
-            and self.server_model_flops is not None
+            and self.client_fw_flops is not None
+            and self.server_fw_flops is not None
             and self.optimizer_state_size is not None
             and self.smashed_data_size is not None
             and self.label_size is not None
@@ -181,22 +185,24 @@ class ServerChoiceOptimizer:
     def num_flops_per_round_on_device(self, device_id, server_device_id):
         device = self._get_device_params(device_id)
         total_flops = 0
+        # set bw FLOPs to 2 * fw FLOPs if not provided
+        if self.global_params.client_bw_flops and self.global_params.server_bw_flops:
+            client_bw_flops = self.global_params.client_bw_flops
+            server_bw_flops = self.global_params.server_bw_flops
+        else:
+            client_bw_flops = (
+                2 * self.global_params.client_fw_flops
+            )  # by default bw FLOPs = 2*fw FLOPs
+            server_bw_flops = 2 * self.global_params.server_fw_flops
+        client_fw_flops = self.global_params.client_fw_flops
+        server_fw_flops = self.global_params.server_fw_flops
         if device_id == server_device_id:
             total_flops += (
-                self.global_params.server_model_flops
-                * self._total_train_dataset_size()
-                * 3
-            )  # fw + 2bw
-            total_flops += (
-                self.global_params.server_model_flops
-                * self._total_validation_dataset_size()
-            )  # fw
-        total_flops += (
-            self.global_params.client_model_flops * device.train_samples * 3
-        )  # fw + 2bw
-        total_flops += (
-            self.global_params.client_model_flops * device.validation_samples
-        )  # fw
+                server_fw_flops + server_bw_flops
+            ) * self._total_train_dataset_size()
+            total_flops += server_fw_flops * self._total_validation_dataset_size()
+        total_flops += (client_fw_flops + client_bw_flops) * device.train_samples
+        total_flops += client_fw_flops * device.validation_samples
         return total_flops
 
     def num_bytes_sent_per_round_on_device(self, device_id, server_device_id):
@@ -476,10 +482,10 @@ class EnergySimulator:
     def _fl_flops_on_device(self, device_id):
         device = self._get_device_params(device_id)
         total_flops = 0
-        total_flops += self.global_params.client_model_flops * device.train_samples * 3
-        total_flops += self.global_params.client_model_flops * device.validation_samples
-        total_flops += self.global_params.server_model_flops * device.train_samples * 3
-        total_flops += self.global_params.server_model_flops * device.validation_samples
+        total_flops += self.global_params.client_fw_flops * device.train_samples * 3  # fw + 2bw (the FL baseline keeps the bw = 2 * fw assumption)
+        total_flops += self.global_params.client_fw_flops * device.validation_samples  # fw
+        total_flops += self.global_params.server_fw_flops * device.train_samples * 3  # fw + 2bw
+        total_flops += self.global_params.server_fw_flops * device.validation_samples  # fw
         return total_flops
 
     def _fl_data_sent_per_device(self):
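
A worked example of the new accounting in num_flops_per_round_on_device, with invented numbers. When no backward FLOPs are supplied, it reduces to the old "* 3" (fw + 2bw) formula:

```python
client_fw, server_fw = 1_000, 10_000   # FLOPs per sample
client_bw = 2 * client_fw              # fallback when bw FLOPs are not provided
server_bw = 2 * server_fw
train, val = 600, 100                  # samples on this device
total_train, total_val = 6_000, 1_000  # dataset sizes across all devices

# Any device pays for its client model.
flops = (client_fw + client_bw) * train + client_fw * val
assert flops == client_fw * train * 3 + client_fw * val  # old "* 3" accounting

# The server device additionally runs the server model on every sample.
flops += (server_fw + server_bw) * total_train + server_fw * total_val
```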
diff --git a/edml/controllers/swarm_controller.py b/edml/controllers/swarm_controller.py
index 4463e2cc8ffd3a89471e8728ba0a117e163e8fea..716b73ec20546a83c7b63e34dd9287a113b4a8e4 100644
--- a/edml/controllers/swarm_controller.py
+++ b/edml/controllers/swarm_controller.py
@@ -1,5 +1,9 @@
 from typing import Any
 
+from edml.controllers.adaptive_threshold_mechanism import AdaptiveThresholdFn
+from edml.controllers.adaptive_threshold_mechanism.static import (
+    StaticAdaptiveThresholdFn,
+)
 from edml.controllers.base_controller import BaseController
 from edml.controllers.scheduler.base import NextServerScheduler
 from edml.helpers.config_helpers import get_device_index_by_id
@@ -7,10 +11,16 @@ from edml.helpers.config_helpers import get_device_index_by_id
 
 class SwarmController(BaseController):
 
-    def __init__(self, cfg, scheduler: NextServerScheduler):
+    def __init__(
+        self,
+        cfg,
+        scheduler: NextServerScheduler,
+        adaptive_threshold_fn: AdaptiveThresholdFn = StaticAdaptiveThresholdFn(0.0),  # 0.0 is falsy, i.e. disabled
+    ):
         super().__init__(cfg)
         scheduler.initialize(self)
         self._next_server_scheduler = scheduler
+        self._adaptive_threshold_fn = adaptive_threshold_fn
 
     def _train(self):
         client_weights = None
@@ -87,10 +97,13 @@ class SwarmController(BaseController):
             device_id=server_device_id, state_dict=server_weights, on_client=False
         )
 
+        adaptive_threshold = self._adaptive_threshold_fn.invoke(round_no)
+        self.logger.log({"adaptive-threshold": adaptive_threshold})
         training_response = self.request_dispatcher.train_global_on(
             server_device_id,
             epochs=1,
             round_no=round_no,
+            adaptive_threshold_value=adaptive_threshold,
             optimizer_state=optimizer_state,
         )
 
@@ -121,17 +134,32 @@ class SwarmController(BaseController):
         """Returns the dataset sizes and model flops of active devices only."""
         dataset_sizes = {}
         model_flops = {}
-        client_flop_list = []
-        server_flop_list = []
+        client_fw_flop_list = []
+        server_fw_flop_list = []
+        client_bw_flop_list = []
+        server_bw_flop_list = []
         for device_id in self.active_devices:
-            train_samples, val_samples, client_flops, server_flops = (
-                self.request_dispatcher.get_dataset_model_info_on(device_id)
-            )
+            (
+                train_samples,
+                val_samples,
+                client_fw_flops,
+                server_fw_flops,
+                client_bw_flops,
+                server_bw_flops,
+            ) = self.request_dispatcher.get_dataset_model_info_on(device_id)
             dataset_sizes[device_id] = (train_samples, val_samples)
-            client_flop_list.append(client_flops)
-            server_flop_list.append(server_flops)
+            client_fw_flop_list.append(client_fw_flops)
+            server_fw_flop_list.append(server_fw_flops)
+            client_bw_flop_list.append(client_bw_flops)
+            server_bw_flop_list.append(server_bw_flops)
         # avoid that flops are taken from a device that wasn't used for training and thus has no flops
         # apart from that, FLOPs should be the same everywhere
-        model_flops["client"] = max(client_flop_list)
-        model_flops["server"] = max(server_flop_list)
+        model_flops["client"] = {
+            "FW": max(client_fw_flop_list),
+            "BW": max(client_bw_flop_list),
+        }
+        model_flops["server"] = {
+            "FW": max(server_fw_flop_list),
+            "BW": max(server_bw_flop_list),
+        }
         return dataset_sizes, model_flops
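
On the max() aggregation: a device that never held the server role (or never trained) reports zero FLOPs, while every device that did measure reports the same estimate, so the maximum recovers the measured value. Invented numbers:

```python
server_fw_flop_list = [0, 8_500_000, 8_500_000]  # d0 never acted as server
assert max(server_fw_flop_list) == 8_500_000
```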
diff --git a/edml/core/client.py b/edml/core/client.py
index d8096c7ca5bc479a2aec458f2daf1cc442db4986..03437ce4a4ade8472a889586fb3f82981a6278c8 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import itertools
 import time
 from typing import Optional, Tuple, TYPE_CHECKING, Any
 
@@ -19,6 +18,7 @@ from edml.helpers.flops import estimate_model_flops
 from edml.helpers.load_optimizer import get_optimizer_and_scheduler
 from edml.helpers.metrics import DiagnosticMetricResultContainer, DiagnosticMetricResult
 from edml.helpers.types import StateDict
+from edml.models.provider.base import has_autoencoder
 
 if TYPE_CHECKING:
     from edml.core.device import Device
@@ -69,13 +69,13 @@ class DeviceClient:
         self._device = torch.device(get_torch_device_id(cfg))
         self._model = model.to(self._device)
         self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
-            cfg, self._model.parameters()
+            cfg, self._model.get_optimizer_params()
         )
         # get first sample from train data to estimate model flops
         sample = self._train_data.dataset.__getitem__(0)[0]
         if not isinstance(sample, torch.Tensor):
             sample = torch.tensor(data=sample)
-        self._model_flops = estimate_model_flops(
+        self._model_flops: dict[str, int] = estimate_model_flops(
             self._model, sample.to(self._device).unsqueeze(0)
         )
         self._cfg = cfg
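
estimate_model_flops now returns a {"FW": ..., "BW": ...} dict. A hedged sketch of such an estimator, assuming fvcore for the forward count and the common bw = 2 * fw heuristic; the actual helper in edml.helpers.flops may be implemented differently:

```python
import torch
from fvcore.nn import FlopCountAnalysis


def estimate_model_flops_sketch(
    model: torch.nn.Module, sample: torch.Tensor
) -> dict[str, int]:
    fw = int(FlopCountAnalysis(model, sample).total())  # forward-pass count
    return {"FW": fw, "BW": 2 * fw}  # heuristic; AE models may override this
```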
@@ -93,6 +93,7 @@ class DeviceClient:
             If a latency factor is specified, this function sleeps for said amount before returning.
         """
         self.node_device = node_device
+        self.node_device.logger.log({"client_model_flops": self._model_flops})
 
     @simulate_latency_decorator(latency_factor_attr="latency_factor")
     def set_weights(self, state_dict: StateDict):
@@ -168,10 +169,10 @@ class DeviceClient:
         # Safety check to ensure that we train same-sized batches only.
         batch_data, batch_labels = next(self._batchable_data_loader)
 
-        # Updates the battery capacity by simulating the required energy consumption for conducting the training step.
-        self.node_device.battery.update_flops(self._model_flops * len(batch_data))
+        # Updates the battery capacity by simulating the required energy consumption for conducting the forward pass.
+        self.node_device.battery.update_flops(self._model_flops["FW"] * len(batch_data))
 
-        # We train the model using the single batch and return the activations and labels. These get send over to the
+        # We train the model using the single batch and return the activations and labels. These get sent over to the
         # server to be then further processed
 
         with LatencySimulator(latency_factor=self.latency_factor):
@@ -204,11 +205,12 @@ class DeviceClient:
 
         start_time_2 = time.time()
 
-        self.node_device.battery.update_flops(
-            self._model_flops * len(batch_data) * 2
-        )  # 2x for backward pass
+        self.node_device.battery.update_flops(self._model_flops["BW"] * len(batch_data))
         gradients = gradients.to(self._device)
-        smashed_data.backward(gradients)
+        if has_autoencoder(self._model):
+            self._model.trainable_layers_output.backward(gradients)
+        else:
+            smashed_data.backward(gradients)
         # self._optimizer.step()
 
         # We need to store a reference to the smashed_data to make it possible to finalize the training step.
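
The has_autoencoder branch assumes the client model caches its pre-encoder activation so that backward can start there instead of at the (encoded) smashed data. A hypothetical model shape satisfying that contract; the real classes live under edml.models and may differ:

```python
from typing import Optional

import torch
from torch import nn


class ClientWithEncoder(nn.Module):
    """Illustrative split-learning client with a frozen compression encoder."""

    def __init__(self, trainable_layers: nn.Module, encoder: nn.Module):
        super().__init__()
        self.trainable_layers = trainable_layers
        self.encoder = encoder.requires_grad_(False)  # compressor is not trained
        self.trainable_layers_output: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.trainable_layers(x)
        self.trainable_layers_output = out  # backward target for server gradients
        return self.encoder(out)

    def get_optimizer_params(self):
        return self.trainable_layers.parameters()  # exclude the frozen encoder
```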
@@ -225,7 +227,7 @@ class DeviceClient:
         metrics_container = DiagnosticMetricResultContainer([metric])
 
         gradients = []
-        for param in self._model.parameters():
+        for param in self._model.get_optimizer_params():
             if param.grad is not None:
                 gradients.append(param.grad)
             else:
@@ -276,7 +278,9 @@ class DeviceClient:
         self._model.train()
         diagnostic_metric_container = DiagnosticMetricResultContainer()
         for idx, (batch_data, batch_labels) in enumerate(self._train_data):
-            self.node_device.battery.update_flops(self._model_flops * len(batch_data))
+            self.node_device.battery.update_flops(
+                self._model_flops["FW"] * len(batch_data)
+            )
 
             with LatencySimulator(latency_factor=self.latency_factor):
                 batch_data = batch_data.to(self._device)
@@ -299,12 +303,16 @@ class DeviceClient:
                     break
                 server_grad, _server_loss, diagnostic_metrics = train_batch_response
                 diagnostic_metric_container.merge(diagnostic_metrics)
-                self.node_device.battery.update_flops(
-                    self._model_flops * len(batch_data) * 2
-                )  # 2x for backward pass
-                server_grad = server_grad.to(self._device)
-                smashed_data.backward(server_grad)
-                self._optimizer.step()
+                if server_grad is not None:  # None means the threshold suppressed the server gradients
+                    self.node_device.battery.update_flops(
+                        self._model_flops["BW"] * len(batch_data)
+                    )
+                    server_grad = server_grad.to(self._device)
+                    if has_autoencoder(self._model):
+                        self._model.trainable_layers_output.backward(server_grad)
+                    else:
+                        smashed_data.backward(server_grad)
+                    self._optimizer.step()
 
         client_train_time = (
             time.time() - client_train_start_time - sum(server_train_batch_times)
@@ -355,7 +363,7 @@ class DeviceClient:
             for b, (batch_data, batch_labels) in enumerate(dataloader):
                 with LatencySimulator(latency_factor=self.latency_factor):
                     self.node_device.battery.update_flops(
-                        self._model_flops * len(batch_data)
+                        self._model_flops["FW"] * len(batch_data)
                     )
                     batch_data = batch_data.to(self._device)
                     batch_labels = batch_labels.to(self._device)
@@ -382,7 +390,7 @@ class DeviceClient:
         return diagnostic_metric_results
 
     def set_gradient_and_finalize_training(self, gradients: Any):
-        for param, grad in zip(self._model.parameters(), gradients):
+        for param, grad in zip(self._model.get_optimizer_params(), gradients):
             param.grad = grad.to(self._device)
 
         self._optimizer.step()
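
The switch from parameters() to get_optimizer_params() matters for this gradient exchange: the list built when finalizing a training step and the list consumed here must agree in ordering and length. A minimal sketch, assuming missing grads are padded with zeros to keep positions aligned (the else branch is truncated in the hunk above):

```python
import torch
from torch import nn


def extract_grads(params: list[nn.Parameter]) -> list[torch.Tensor]:
    # zeros keep list positions aligned when a parameter saw no gradient
    return [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]


def apply_grads(params: list[nn.Parameter], grads: list[torch.Tensor]) -> None:
    for param, grad in zip(params, grads):  # relies on identical ordering
        param.grad = grad
```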
diff --git a/edml/core/device.py b/edml/core/device.py
index 6770e4271e8dc030af1abc7f8ac010a57242f5bb..aea1ae95de2735dc4c984a13d49dadd632d66432 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -242,13 +242,13 @@ class NetworkDevice(Device):
         self,
         clients: list[str],
         round_no: int,
-        adaptive_learning_threshold: Optional[float] = None,
+        adaptive_threshold_value: Optional[float] = None,
         optimizer_state: dict[str, Any] = None,
     ):
         return self.server.train_parallel_split_learning(
             clients=clients,
             round_no=round_no,
-            adaptive_learning_threshold=adaptive_learning_threshold,
+            adaptive_threshold_value=adaptive_threshold_value,
             optimizer_state=optimizer_state,
         )
 
@@ -303,7 +303,11 @@ class NetworkDevice(Device):
     @update_battery
     @log_execution_time("logger", "train_global_time")
     def train_global(
-        self, epochs: int, round_no: int = -1, optimizer_state: dict[str, Any] = None
+        self,
+        epochs: int,
+        round_no: int = -1,
+        adaptive_threshold_value: Optional[float] = None,
+        optimizer_state: dict[str, Any] = None,
     ) -> Tuple[
         Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer
     ]:
@@ -311,6 +315,7 @@ class NetworkDevice(Device):
             devices=self.__get_device_ids__(),
             epochs=epochs,
             round_no=round_no,
+            adaptive_threshold_value=adaptive_threshold_value,
             optimizer_state=optimizer_state,
         )
 
@@ -448,7 +453,12 @@ class RPCDeviceServicer(DeviceServicer):
     def TrainGlobal(self, request, context):
         print(f"Called TrainGlobal on device {self.device.device_id}")
         client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
-            self.device.train_global(request.epochs, request.round_no)
+            self.device.train_global(
+                request.epochs,
+                request.round_no,
+                request.adaptive_threshold_value,
+                proto_to_state_dict(request.optimizer_state),
+            )
         )
         response = connection_pb2.TrainGlobalResponse(
             client_weights=Weights(weights=state_dict_to_proto(client_weights)),
@@ -557,21 +567,25 @@ class RPCDeviceServicer(DeviceServicer):
         return connection_pb2.DatasetModelInfoResponse(
             train_samples=len(self.device.client._train_data.dataset),
             validation_samples=len(self.device.client._val_data.dataset),
-            client_model_flops=int(self.device.client._model_flops),
-            server_model_flops=int(self.device.server._model_flops),
+            client_fw_flops=int(self.device.client._model_flops["FW"]),
+            server_fw_flops=int(self.device.server._model_flops["FW"]),
+            client_bw_flops=int(self.device.client._model_flops["BW"]),
+            server_bw_flops=int(self.device.server._model_flops["BW"]),
         )
 
     def TrainGlobalParallelSplitLearning(self, request, context):
         print(f"Starting parallel split learning")
         clients = self.device.__get_device_ids__()
         round_no = request.round_no
-        adaptive_learning_threshold = request.adaptive_learning_threshold
+        adaptive_threshold_value = request.adaptive_threshold_value
+        optimizer_state = proto_to_state_dict(request.optimizer_state)
 
         cw, sw, model_metrics, optimizer_state, diagnostic_metrics = (
             self.device.train_parallel_split_learning(
                 clients=clients,
                 round_no=round_no,
-                adaptive_learning_threshold=adaptive_learning_threshold,
+                adaptive_threshold_value=adaptive_threshold_value,
+                optimizer_state=optimizer_state,
             )
         )
         response = connection_pb2.TrainGlobalParallelSplitLearningResponse(
@@ -651,10 +665,10 @@ class DeviceRequestDispatcher:
         server_device_id: str,
         epochs: int,
         round_no: int,
-        adaptive_learning_threshold: Optional[float] = None,
+        adaptive_threshold_value: Optional[float] = None,
         optimizer_state: dict[str, Any] = None,
     ):
-        print(f"><><><> {adaptive_learning_threshold}")
+        print(f"><><><> {adaptive_threshold_value}")
 
         try:
             response: TrainGlobalParallelSplitLearningResponse = self._get_connection(
@@ -662,7 +676,7 @@ class DeviceRequestDispatcher:
             ).TrainGlobalParallelSplitLearning(
                 connection_pb2.TrainGlobalParallelSplitLearningRequest(
                     round_no=round_no,
-                    adaptive_learning_threshold=adaptive_learning_threshold,
+                    adaptive_threshold_value=adaptive_threshold_value,
                     optimizer_state=state_dict_to_proto(optimizer_state),
                 )
             )
@@ -759,6 +773,7 @@ class DeviceRequestDispatcher:
         device_id: str,
         epochs: int,
         round_no: int = -1,
+        adaptive_threshold_value: Optional[float] = None,
         optimizer_state: dict[str, Any] = None,
     ) -> Union[
         Tuple[
@@ -775,6 +790,7 @@ class DeviceRequestDispatcher:
                 connection_pb2.TrainGlobalRequest(
                     epochs=epochs,
                     round_no=round_no,
+                    adaptive_threshold_value=adaptive_threshold_value,
                     optimizer_state=state_dict_to_proto(optimizer_state),
                 )
             )
@@ -995,7 +1011,7 @@ class DeviceRequestDispatcher:
 
     def get_dataset_model_info_on(
         self, device_id: str
-    ) -> Union[Tuple[int, int, float, float], bool]:
+    ) -> Union[Tuple[int, int, int, int, int, int], bool]:
         try:
             response: DatasetModelInfoResponse = self._get_connection(
                 device_id
@@ -1003,8 +1019,10 @@ class DeviceRequestDispatcher:
             return (
                 response.train_samples,
                 response.validation_samples,
-                response.client_model_flops,
-                response.server_model_flops,
+                response.client_fw_flops,
+                response.server_fw_flops,
+                response.client_bw_flops,
+                response.server_bw_flops,
             )
         except grpc.RpcError:
             self._handle_rpc_error(device_id)
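
Note on the new *_bw_flops fields: they are declared optional in proto3, and an unset scalar reads as 0 on the receiving side. Since 0 is falsy, the bw = 2 * fw fallback in ServerChoiceOptimizer kicks in automatically for peers that do not report backward FLOPs:

```python
client_fw_flops, client_bw_flops = 1_000, 0  # bw field left unset by the peer
client_bw_flops = client_bw_flops or 2 * client_fw_flops
assert client_bw_flops == 2_000
```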
diff --git a/edml/core/server.py b/edml/core/server.py
index 4279f40301205c361a8648fe907d9084bd159ac5..f93f7fb3025095a93d72f08b26151db751032fc0 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Any, TYPE_CHECKING
 
 import torch
 from omegaconf import DictConfig
-from colorama import Fore
 from torch import nn
 from torch.autograd import Variable
 
@@ -20,6 +19,7 @@ from edml.helpers.metrics import (
     DiagnosticMetricResultContainer,
 )
 from edml.helpers.types import StateDict, LossFn
+from edml.models.provider.base import has_autoencoder
 
 if TYPE_CHECKING:
     from edml.core.device import Device
@@ -40,9 +40,9 @@ class DeviceServer:
         self._device = torch.device(get_torch_device_id(cfg))
         self._model = model.to(self._device)
         self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
-            cfg, self._model.parameters()
+            cfg, self._model.get_optimizer_params()
         )
-        self._model_flops = 0  # determine later
+        self._model_flops: dict[str, int] = {"FW": 0, "BW": 0}  # determine later
         self._metrics = create_metrics(
             cfg.experiment.metrics, cfg.dataset.num_classes, cfg.dataset.average_setting
         )
@@ -50,6 +50,7 @@ class DeviceServer:
         self._cfg = cfg
         self.node_device: Optional[Device] = None
         self.latency_factor = latency_factor
+        self.adaptive_threshold_value = None
 
     def set_device(self, node_device: Device):
         """Sets the device reference for the server."""
@@ -72,6 +73,7 @@ class DeviceServer:
         devices: List[str],
         epochs: int = 1,
         round_no: int = -1,
+        adaptive_threshold_value: Optional[float] = None,
         optimizer_state: dict[str, Any] = None,
     ) -> Tuple[
         Any, Any, ModelMetricResultContainer, Any, DiagnosticMetricResultContainer
@@ -82,6 +84,7 @@ class DeviceServer:
             devices: The devices to train on
             epochs: Optionally, the number of epochs to train.
             round_no: Optionally, the current global epoch number if a learning rate scheduler is used.
+            adaptive_threshold_value: Optionally, the loss threshold below which gradients are not sent back to the client
             optimizer_state: Optionally, the optimizer_state to proceed from
         """
         client_weights = None
@@ -89,6 +92,8 @@ class DeviceServer:
         diagnostic_metric_container = DiagnosticMetricResultContainer()
         if optimizer_state is not None:
             self._optimizer.load_state_dict(optimizer_state)
+        if adaptive_threshold_value is not None:
+            self.adaptive_threshold_value = adaptive_threshold_value
         for epoch in range(epochs):
             if self._lr_scheduler is not None:
                 if round_no != -1:
@@ -134,22 +139,26 @@ class DeviceServer:
         )
 
     @simulate_latency_decorator(latency_factor_attr="latency_factor")
-    def train_batch(self, smashed_data, labels) -> Tuple[Variable, float]:
+    def train_batch(self, smashed_data, labels) -> Tuple[Optional[Variable], float]:
         """Train the model on the given batch of data and labels.
         Returns the gradients of the model's parameters."""
         smashed_data, labels = smashed_data.to(self._device), labels.to(self._device)
 
-        self._set_model_flops(smashed_data)
+        self._set_model_flops(smashed_data[0])
 
         self._optimizer.zero_grad()
 
-        self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
+        self.node_device.battery.update_flops(
+            self._model_flops["FW"] * len(smashed_data)
+        )
         smashed_data = Variable(smashed_data, requires_grad=True)
         output_train = self._model(smashed_data)
 
         loss_train = self._loss_fn(output_train, labels)
 
-        self.node_device.battery.update_flops(self._model_flops * len(smashed_data) * 2)
+        self.node_device.battery.update_flops(
+            self._model_flops["BW"] * len(smashed_data)
+        )
         loss_train.backward()
         self._optimizer.step()
 
@@ -157,14 +166,27 @@ class DeviceServer:
         self.node_device.log({"loss": loss_train.item()})
         self._metrics.metrics_on_batch(output_train.cpu(), labels.cpu().int())
 
-        return smashed_data.grad, loss_train.item()
+        if has_autoencoder(self._model):
+            gradients = self._model.trainable_layers_input.grad
+        else:
+            gradients = smashed_data.grad
+        if (
+            self.adaptive_threshold_value
+            and loss_train.item() < self.adaptive_threshold_value
+        ):  # loss already below the threshold: do not send gradients to the client
+            self.node_device.log(
+                {"adaptive_learning_threshold_applied": gradients.size(0)}
+            )
+            return None, loss_train.item()
+        return gradients, loss_train.item()
 
-    def _set_model_flops(self, smashed_data):
+    def _set_model_flops(self, sample):
         """Helper to determine the model flops when smashed data are available for the first time."""
-        if self._model_flops == 0:
-            self._model_flops = estimate_model_flops(self._model, smashed_data) / len(
-                smashed_data
+        if self._model_flops["FW"] == 0 or self._model_flops["BW"] == 0:
+            self._model_flops = estimate_model_flops(
+                self._model, sample.to(self._device).unsqueeze(0)
             )
+            self.node_device.logger.log({"server_model_flops": self._model_flops})
 
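
train_batch now returns (None, loss) when the batch loss is already below adaptive_threshold_value, and callers skip the client backward pass on None. A standalone restatement of the check; note that the truthiness test also means a threshold of 0.0 (the StaticAdaptiveThresholdFn(0.0) default) disables the mechanism:

```python
def maybe_suppress(gradients, loss, threshold):
    # mirrors the check above: only suppress for a non-zero threshold
    if threshold and loss < threshold:
        return None, loss
    return gradients, loss


assert maybe_suppress("g", 0.12, 0.5) == (None, 0.12)  # below: suppressed
assert maybe_suppress("g", 0.80, 0.5) == ("g", 0.80)   # above: sent back
assert maybe_suppress("g", 0.12, None) == ("g", 0.12)  # feature disabled
```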
     @simulate_latency_decorator(latency_factor_attr="latency_factor")
     def finalize_metrics(self, device_id: str, phase: str):
@@ -205,8 +227,10 @@ class DeviceServer:
         """Evaluates the model on the given batch of data and labels"""
         with torch.no_grad():
             smashed_data = smashed_data.to(self._device)
-            self._set_model_flops(smashed_data)
-            self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
+            self._set_model_flops(smashed_data[0])
+            self.node_device.battery.update_flops(
+                self._model_flops["FW"] * len(smashed_data)
+            )
             pred = self._model(smashed_data)
         self._metrics.metrics_on_batch(pred.cpu(), labels.cpu().int())
 
@@ -215,7 +239,7 @@ class DeviceServer:
         self,
         clients: List[str],
         round_no: int,
-        adaptive_learning_threshold: Optional[float] = None,
+        adaptive_threshold_value: Optional[float] = None,
         optimizer_state: dict[str, Any] = None,
     ):
         def client_training_job(client_id: str, batch_index: int):
@@ -249,7 +273,8 @@ class DeviceServer:
                 self._lr_scheduler.step(round_no + 1)  # epoch=1
             else:
                 self._lr_scheduler.step()
-
+        if adaptive_threshold_value is not None:
+            self.adaptive_threshold_value = adaptive_threshold_value
         num_threads = len(clients)
         executor = create_executor_with_threads(num_threads)
 
@@ -299,25 +324,9 @@ class DeviceServer:
                 server_gradients, server_loss, server_metrics = (
                     self.node_device.train_batch(server_batch, server_labels)
                 )  # DiagnosticMetricResultContainer
-                # We check if the server should activate the adaptive learning threshold. And if true, we make sure to only
-                # do the client propagation once the current loss value is larger than the threshold.
-                print(
-                    f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
-                )
                 if (
-                    adaptive_learning_threshold
-                    and server_loss < adaptive_learning_threshold
-                ):
-                    print(
-                        f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
-                    )
-                    self.node_device.log(
-                        {
-                            "adaptive_learning_threshold_applied": server_gradients.size(
-                                0
-                            )
-                        }
-                    )
+                    server_gradients is None
+                ):  # loss threshold was reached, skip client backprop
                     continue
 
                 num_client_gradients = len(client_forward_pass_responses)
@@ -397,25 +406,9 @@ class DeviceServer:
                 server_gradients, server_loss, server_metrics = (
                     self.node_device.train_batch(server_batch, server_labels)
                 )  # DiagnosticMetricResultContainer
-                # We check if the server should activate the adaptive learning threshold. And if true, we make sure to only
-                # do the client propagation once the current loss value is larger than the threshold.
-                print(
-                    f"\n{Fore.GREEN}{adaptive_learning_threshold} <-> {server_loss}\n{Fore.RESET}"
-                )
                 if (
-                    adaptive_learning_threshold
-                    and server_loss < adaptive_learning_threshold
-                ):
-                    print(
-                        f"\n{Fore.RED}ADAPTIVE TRESHOLD REACHED, NEXT BATCH\n{Fore.RESET}"
-                    )
-                    self.node_device.log(
-                        {
-                            "adaptive_learning_threshold_applied": server_gradients.size(
-                                0
-                            )
-                        }
-                    )
+                    server_gradients is None
+                ):  # loss threshold was reached, skip client backprop
                     continue
 
                 num_client_gradients = len(client_forward_pass_responses)
diff --git a/edml/generated/connection_pb2.py b/edml/generated/connection_pb2.py
index f3c72442f44b99d587925a030002e1af92924ac8..3c9a0d7b48b4739017066ae0dfedaa636dd64a62 100644
--- a/edml/generated/connection_pb2.py
+++ b/edml/generated/connection_pb2.py
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
 import datastructures_pb2 as datastructures__pb2
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"4\n\x13SetGradientsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"j\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12\"\n\tgradients\x18\x02 \x01(\x0b\x32\n.GradientsH\x00\x88\x01\x01\x42\x0c\n\n_gradients\"C\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\x12\x10\n\x08round_no\x18\x02 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 
\x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xf8\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n 
TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x12\x45\n#SetGradientsAndFinalizeTrainingStep\x12\x14.SetGradientsRequest\x1a\x06.Empty\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"4\n\x13SetGradientsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"j\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12\"\n\tgradients\x18\x02 \x01(\x0b\x32\n.GradientsH\x00\x88\x01\x01\x42\x0c\n\n_gradients\"C\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\x12\x10\n\x08round_no\x18\x02 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xcf\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12%\n\x18\x61\x64\x61ptive_threshold_value\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1b\n\x19_adaptive_threshold_valueB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\xca\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12%\n\x18\x61\x64\x61ptive_threshold_value\x18\x03 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1b\n\x19_adaptive_threshold_valueB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 
\x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 \x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xa5\x02\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x17\n\x0f\x63lient_fw_flops\x18\x03 \x01(\x05\x12\x17\n\x0fserver_fw_flops\x18\x04 \x01(\x05\x12\x1c\n\x0f\x63lient_bw_flops\x18\x05 \x01(\x05H\x00\x88\x01\x01\x12\x1c\n\x0fserver_bw_flops\x18\x06 \x01(\x05H\x01\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x07 
\x01(\x0b\x32\x08.MetricsH\x02\x88\x01\x01\x42\x12\n\x10_client_bw_flopsB\x12\n\x10_server_bw_flopsB\x15\n\x13_diagnostic_metrics2\xf8\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x12\x45\n#SetGradientsAndFinalizeTrainingStep\x12\x14.SetGradientsRequest\x1a\x06.Empty\"\x00\x62\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -35,57 +35,57 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_start=390
   _globals['_SINGLEBATCHTRAININGRESPONSE']._serialized_end=518
   _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_start=521
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=734
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=737
-  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=1002
-  _globals['_TRAINGLOBALREQUEST']._serialized_start=1005
-  _globals['_TRAINGLOBALREQUEST']._serialized_end=1139
-  _globals['_TRAINGLOBALRESPONSE']._serialized_start=1142
-  _globals['_TRAINGLOBALRESPONSE']._serialized_end=1386
-  _globals['_SETWEIGHTSREQUEST']._serialized_start=1388
-  _globals['_SETWEIGHTSREQUEST']._serialized_end=1453
-  _globals['_SETWEIGHTSRESPONSE']._serialized_start=1455
-  _globals['_SETWEIGHTSRESPONSE']._serialized_end=1541
-  _globals['_TRAINEPOCHREQUEST']._serialized_start=1543
-  _globals['_TRAINEPOCHREQUEST']._serialized_end=1627
-  _globals['_TRAINEPOCHRESPONSE']._serialized_start=1629
-  _globals['_TRAINEPOCHRESPONSE']._serialized_end=1742
-  _globals['_TRAINBATCHREQUEST']._serialized_start=1744
-  _globals['_TRAINBATCHREQUEST']._serialized_end=1824
-  _globals['_TRAINBATCHRESPONSE']._serialized_start=1827
-  _globals['_TRAINBATCHRESPONSE']._serialized_end=1972
-  _globals['_EVALGLOBALREQUEST']._serialized_start=1974
-  _globals['_EVALGLOBALREQUEST']._serialized_end=2032
-  _globals['_EVALGLOBALRESPONSE']._serialized_start=2034
-  _globals['_EVALGLOBALRESPONSE']._serialized_end=2147
-  _globals['_EVALREQUEST']._serialized_start=2149
-  _globals['_EVALREQUEST']._serialized_end=2211
-  _globals['_EVALRESPONSE']._serialized_start=2213
-  _globals['_EVALRESPONSE']._serialized_end=2293
-  _globals['_EVALBATCHREQUEST']._serialized_start=2295
-  _globals['_EVALBATCHREQUEST']._serialized_end=2374
-  _globals['_EVALBATCHRESPONSE']._serialized_start=2376
-  _globals['_EVALBATCHRESPONSE']._serialized_end=2488
-  _globals['_FULLMODELTRAINREQUEST']._serialized_start=2490
-  _globals['_FULLMODELTRAINREQUEST']._serialized_end=2549
-  _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2552
-  _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2758
-  _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2760
-  _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2784
-  _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2786
-  _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2877
-  _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2879
-  _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2901
-  _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2903
-  _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=2992
-  _globals['_BATTERYSTATUSREQUEST']._serialized_start=2994
-  _globals['_BATTERYSTATUSREQUEST']._serialized_end=3016
-  _globals['_BATTERYSTATUSRESPONSE']._serialized_start=3018
-  _globals['_BATTERYSTATUSRESPONSE']._serialized_end=3139
-  _globals['_DATASETMODELINFOREQUEST']._serialized_start=3141
-  _globals['_DATASETMODELINFOREQUEST']._serialized_end=3166
-  _globals['_DATASETMODELINFORESPONSE']._serialized_start=3169
-  _globals['_DATASETMODELINFORESPONSE']._serialized_end=3368
-  _globals['_DEVICE']._serialized_start=3371
-  _globals['_DEVICE']._serialized_end=4515
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGREQUEST']._serialized_end=728
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_start=731
+  _globals['_TRAINGLOBALPARALLELSPLITLEARNINGRESPONSE']._serialized_end=996
+  _globals['_TRAINGLOBALREQUEST']._serialized_start=999
+  _globals['_TRAINGLOBALREQUEST']._serialized_end=1201
+  _globals['_TRAINGLOBALRESPONSE']._serialized_start=1204
+  _globals['_TRAINGLOBALRESPONSE']._serialized_end=1448
+  _globals['_SETWEIGHTSREQUEST']._serialized_start=1450
+  _globals['_SETWEIGHTSREQUEST']._serialized_end=1515
+  _globals['_SETWEIGHTSRESPONSE']._serialized_start=1517
+  _globals['_SETWEIGHTSRESPONSE']._serialized_end=1603
+  _globals['_TRAINEPOCHREQUEST']._serialized_start=1605
+  _globals['_TRAINEPOCHREQUEST']._serialized_end=1689
+  _globals['_TRAINEPOCHRESPONSE']._serialized_start=1691
+  _globals['_TRAINEPOCHRESPONSE']._serialized_end=1804
+  _globals['_TRAINBATCHREQUEST']._serialized_start=1806
+  _globals['_TRAINBATCHREQUEST']._serialized_end=1886
+  _globals['_TRAINBATCHRESPONSE']._serialized_start=1889
+  _globals['_TRAINBATCHRESPONSE']._serialized_end=2034
+  _globals['_EVALGLOBALREQUEST']._serialized_start=2036
+  _globals['_EVALGLOBALREQUEST']._serialized_end=2094
+  _globals['_EVALGLOBALRESPONSE']._serialized_start=2096
+  _globals['_EVALGLOBALRESPONSE']._serialized_end=2209
+  _globals['_EVALREQUEST']._serialized_start=2211
+  _globals['_EVALREQUEST']._serialized_end=2273
+  _globals['_EVALRESPONSE']._serialized_start=2275
+  _globals['_EVALRESPONSE']._serialized_end=2355
+  _globals['_EVALBATCHREQUEST']._serialized_start=2357
+  _globals['_EVALBATCHREQUEST']._serialized_end=2436
+  _globals['_EVALBATCHRESPONSE']._serialized_start=2438
+  _globals['_EVALBATCHRESPONSE']._serialized_end=2550
+  _globals['_FULLMODELTRAINREQUEST']._serialized_start=2552
+  _globals['_FULLMODELTRAINREQUEST']._serialized_end=2611
+  _globals['_FULLMODELTRAINRESPONSE']._serialized_start=2614
+  _globals['_FULLMODELTRAINRESPONSE']._serialized_end=2820
+  _globals['_STARTEXPERIMENTREQUEST']._serialized_start=2822
+  _globals['_STARTEXPERIMENTREQUEST']._serialized_end=2846
+  _globals['_STARTEXPERIMENTRESPONSE']._serialized_start=2848
+  _globals['_STARTEXPERIMENTRESPONSE']._serialized_end=2939
+  _globals['_ENDEXPERIMENTREQUEST']._serialized_start=2941
+  _globals['_ENDEXPERIMENTREQUEST']._serialized_end=2963
+  _globals['_ENDEXPERIMENTRESPONSE']._serialized_start=2965
+  _globals['_ENDEXPERIMENTRESPONSE']._serialized_end=3054
+  _globals['_BATTERYSTATUSREQUEST']._serialized_start=3056
+  _globals['_BATTERYSTATUSREQUEST']._serialized_end=3078
+  _globals['_BATTERYSTATUSRESPONSE']._serialized_start=3080
+  _globals['_BATTERYSTATUSRESPONSE']._serialized_end=3201
+  _globals['_DATASETMODELINFOREQUEST']._serialized_start=3203
+  _globals['_DATASETMODELINFOREQUEST']._serialized_end=3228
+  _globals['_DATASETMODELINFORESPONSE']._serialized_start=3231
+  _globals['_DATASETMODELINFORESPONSE']._serialized_end=3524
+  _globals['_DEVICE']._serialized_start=3527
+  _globals['_DEVICE']._serialized_end=4671
 # @@protoc_insertion_point(module_scope)
diff --git a/edml/generated/connection_pb2.pyi b/edml/generated/connection_pb2.pyi
index 89353343aa1c7e39072bd0ea03c891bd2df7b4df..bc0c09189004803b0d556d0c7eaeeecaf82922e0 100644
--- a/edml/generated/connection_pb2.pyi
+++ b/edml/generated/connection_pb2.pyi
@@ -48,14 +48,14 @@ class SingleBatchTrainingResponse(_message.Message):
     def __init__(self, smashed_data: _Optional[_Union[_datastructures_pb2.Activations, _Mapping]] = ..., labels: _Optional[_Union[_datastructures_pb2.Labels, _Mapping]] = ...) -> None: ...
 
 class TrainGlobalParallelSplitLearningRequest(_message.Message):
-    __slots__ = ["round_no", "adaptive_learning_threshold", "optimizer_state"]
+    __slots__ = ["round_no", "adaptive_threshold_value", "optimizer_state"]
     ROUND_NO_FIELD_NUMBER: _ClassVar[int]
-    ADAPTIVE_LEARNING_THRESHOLD_FIELD_NUMBER: _ClassVar[int]
+    ADAPTIVE_THRESHOLD_VALUE_FIELD_NUMBER: _ClassVar[int]
     OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int]
     round_no: int
-    adaptive_learning_threshold: float
+    adaptive_threshold_value: float
     optimizer_state: _datastructures_pb2.StateDict
-    def __init__(self, round_no: _Optional[int] = ..., adaptive_learning_threshold: _Optional[float] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
+    def __init__(self, round_no: _Optional[int] = ..., adaptive_threshold_value: _Optional[float] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
 
 class TrainGlobalParallelSplitLearningResponse(_message.Message):
     __slots__ = ["client_weights", "server_weights", "metrics", "optimizer_state", "diagnostic_metrics"]
@@ -72,14 +72,16 @@ class TrainGlobalParallelSplitLearningResponse(_message.Message):
     def __init__(self, client_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., server_weights: _Optional[_Union[_datastructures_pb2.Weights, _Mapping]] = ..., metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ...
 
 class TrainGlobalRequest(_message.Message):
-    __slots__ = ["epochs", "round_no", "optimizer_state"]
+    __slots__ = ["epochs", "round_no", "adaptive_threshold_value", "optimizer_state"]
     EPOCHS_FIELD_NUMBER: _ClassVar[int]
     ROUND_NO_FIELD_NUMBER: _ClassVar[int]
+    ADAPTIVE_THRESHOLD_VALUE_FIELD_NUMBER: _ClassVar[int]
     OPTIMIZER_STATE_FIELD_NUMBER: _ClassVar[int]
     epochs: int
     round_no: int
+    adaptive_threshold_value: float
     optimizer_state: _datastructures_pb2.StateDict
-    def __init__(self, epochs: _Optional[int] = ..., round_no: _Optional[int] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
+    def __init__(self, epochs: _Optional[int] = ..., round_no: _Optional[int] = ..., adaptive_threshold_value: _Optional[float] = ..., optimizer_state: _Optional[_Union[_datastructures_pb2.StateDict, _Mapping]] = ...) -> None: ...
 
 class TrainGlobalResponse(_message.Message):
     __slots__ = ["client_weights", "server_weights", "metrics", "optimizer_state", "diagnostic_metrics"]
@@ -246,15 +248,19 @@ class DatasetModelInfoRequest(_message.Message):
     def __init__(self) -> None: ...
 
 class DatasetModelInfoResponse(_message.Message):
-    __slots__ = ["train_samples", "validation_samples", "client_model_flops", "server_model_flops", "diagnostic_metrics"]
+    __slots__ = ["train_samples", "validation_samples", "client_fw_flops", "server_fw_flops", "client_bw_flops", "server_bw_flops", "diagnostic_metrics"]
     TRAIN_SAMPLES_FIELD_NUMBER: _ClassVar[int]
     VALIDATION_SAMPLES_FIELD_NUMBER: _ClassVar[int]
-    CLIENT_MODEL_FLOPS_FIELD_NUMBER: _ClassVar[int]
-    SERVER_MODEL_FLOPS_FIELD_NUMBER: _ClassVar[int]
+    CLIENT_FW_FLOPS_FIELD_NUMBER: _ClassVar[int]
+    SERVER_FW_FLOPS_FIELD_NUMBER: _ClassVar[int]
+    CLIENT_BW_FLOPS_FIELD_NUMBER: _ClassVar[int]
+    SERVER_BW_FLOPS_FIELD_NUMBER: _ClassVar[int]
     DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int]
     train_samples: int
     validation_samples: int
-    client_model_flops: int
-    server_model_flops: int
+    client_fw_flops: int
+    server_fw_flops: int
+    client_bw_flops: int
+    server_bw_flops: int
     diagnostic_metrics: _datastructures_pb2.Metrics
-    def __init__(self, train_samples: _Optional[int] = ..., validation_samples: _Optional[int] = ..., client_model_flops: _Optional[int] = ..., server_model_flops: _Optional[int] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ...
+    def __init__(self, train_samples: _Optional[int] = ..., validation_samples: _Optional[int] = ..., client_fw_flops: _Optional[int] = ..., server_fw_flops: _Optional[int] = ..., client_bw_flops: _Optional[int] = ..., server_bw_flops: _Optional[int] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ...
diff --git a/edml/helpers/flops.py b/edml/helpers/flops.py
index 4388c9754e4c42e33a958c309f8a73a2c21f4265..038f27fd31391a8feac144da37db988c1aa93c69 100644
--- a/edml/helpers/flops.py
+++ b/edml/helpers/flops.py
@@ -3,19 +3,43 @@ from typing import Union, Tuple
 from fvcore.nn import FlopCountAnalysis
 from torch import Tensor, nn
 
+from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
+from edml.models.provider.base import has_autoencoder
+
 
 def estimate_model_flops(
     model: nn.Module, sample: Union[Tensor, Tuple[Tensor, ...]]
-) -> int:
+) -> dict[str, int]:
     """
-    Estimates the FLOPs of one forward pass of the model using the sample data provided.
-
+    Estimates the FLOPs of one forward pass and one backward pass of the model using the sample data provided.
+    If forward and backward pass affect the same parameters, the number of FLOPs of one backward pass is assumed
+    to be two times the number of FLOPs of the forward pass.
+    In case of non-trainable parameters, such as an autoencoder in between, the number of backward FLOPs is
+    assumed to be twice the forward FLOPs of the trainable part only, while the actual number of forward FLOPs
+    is estimated using the full model.
     Args:
         model (nn.Module): the neural network model to calculate the FLOPs for.
-        sample: The data used to calculate the FLOPs.
+        sample: The data used to calculate the FLOPs. The first dimension should be the batch dimension.
 
     Returns:
-        int: the number of FLOPs.
-
+        dict(str, int): {"FW": #ForwardFLOPs, "BW": #BackwardFLOPs} the number of estimated forward and backward FLOPs.
     """
-    return FlopCountAnalysis(model, sample).total()
+    fw_flops = FlopCountAnalysis(model, sample).total()
+    if has_autoencoder(model):
+        bw_flops = 0
+        if isinstance(
+            model, ClientWithAutoencoder
+        ):  # FW FLOPs of the trainable part (up to the encoder), times two
+            bw_flops = FlopCountAnalysis(model.model, sample).total() * 2
+        elif isinstance(
+            model, ServerWithAutoencoder
+        ):  # FW FLOPs of the trainable part (from the decoder output onwards), times two
+            bw_flops = (
+                FlopCountAnalysis(model.model, model.autoencoder(sample)).total() * 2
+            )
+        else:
+            raise NotImplementedError()
+    else:
+        bw_flops = 2 * fw_flops
+
+    return {"FW": fw_flops, "BW": bw_flops}
diff --git a/edml/models/autoencoder.py b/edml/models/autoencoder.py
index ed16f3000d694c53856e901a2634776fcf488433..7814b7e465e6ca8bb218fd4e4798760917999a9e 100644
--- a/edml/models/autoencoder.py
+++ b/edml/models/autoencoder.py
@@ -1,5 +1,5 @@
-import torch
 from torch import nn
+from torch.autograd import Variable
 
 
 class ClientWithAutoencoder(nn.Module):
@@ -7,10 +7,13 @@ class ClientWithAutoencoder(nn.Module):
         super().__init__()
         self.model = model
         self.autoencoder = autoencoder.requires_grad_(False)
+        self.trainable_layers_output = (
+            None  # needed to continue backprop on the client side, skipping the AE
+        )
 
     def forward(self, x):
-        x = self.model(x)
-        return self.autoencoder(x)
+        self.trainable_layers_output = self.model(x)
+        return self.autoencoder(self.trainable_layers_output)
 
 
 class ServerWithAutoencoder(nn.Module):
@@ -18,7 +21,13 @@ class ServerWithAutoencoder(nn.Module):
         super().__init__()
         self.model = model
         self.autoencoder = autoencoder.requires_grad_(False)
+        self.trainable_layers_input = (
+            None  # needed to stop backprop before the AE and send the gradients to the client
+        )
 
     def forward(self, x):
-        x = self.autoencoder(x)
-        return self.model(x)
+        self.trainable_layers_input = self.autoencoder(x)
+        self.trainable_layers_input = Variable(
+            self.trainable_layers_input, requires_grad=True
+        )
+        return self.model(self.trainable_layers_input)
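
# Example sketch (not part of the diff): the split backward pass these buffers
# enable. Gradients from the loss stop at the server's trainable part; the
# client then resumes backprop from its own trainable output, skipping the
# frozen autoencoder entirely. `client`, `server`, `x`, `target` and `loss_fn`
# are assumed to exist, with the models built as ClientWithAutoencoder /
# ServerWithAutoencoder.
smashed_data = client(x)                      # FW: client layers + encoder
loss = loss_fn(server(smashed_data), target)  # FW: decoder + server layers
loss.backward()                               # BW stops at trainable_layers_input
client.trainable_layers_output.backward(server.trainable_layers_input.grad)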
diff --git a/edml/models/provider/base.py b/edml/models/provider/base.py
index e60736bae46fb8f3a2239c9ff016ddb32ce0578e..9fa81abd5a7c58e880f163d2d88eb070c7271e47 100644
--- a/edml/models/provider/base.py
+++ b/edml/models/provider/base.py
@@ -1,10 +1,33 @@
 from torch import nn
 
 
+def has_autoencoder(model: nn.Module) -> bool:
+    """Returns True if the model wraps a trainable part plus a (frozen) autoencoder."""
+    return hasattr(model, "model") and hasattr(model, "autoencoder")
+
+
+def add_optimizer_params_function(model: nn.Module):
+    """Attaches a get_optimizer_params callable yielding only the trainable parameters."""
+    if has_autoencoder(model):
+        # exclude the (frozen) autoencoder params from optimization
+        model.get_optimizer_params = model.model.parameters
+    else:
+        model.get_optimizer_params = model.parameters
+    return model
+
+
 class ModelProvider:
     def __init__(self, client: nn.Module, server: nn.Module):
-        self._client = client
-        self._server = server
+        self._client = add_optimizer_params_function(client)
+        self._server = add_optimizer_params_function(server)
 
     @property
     def models(self) -> tuple[nn.Module, nn.Module]:
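
# Example sketch (not part of the diff): building optimizers through the
# attached hook so that frozen autoencoder weights never receive updates.
# `client_net` and `server_net` are assumed to be the split model halves.
import torch

from edml.models.provider.base import ModelProvider

client, server = ModelProvider(client=client_net, server=server_net).models
client_optimizer = torch.optim.Adam(client.get_optimizer_params())
server_optimizer = torch.optim.Adam(server.get_optimizer_params())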
diff --git a/edml/proto/connection.proto b/edml/proto/connection.proto
index 5755518d01796bdff68957655d4d339dcf02ab1d..f4c7952a62b9db7225760a9b7596a03e4b9f09f4 100644
--- a/edml/proto/connection.proto
+++ b/edml/proto/connection.proto
@@ -51,7 +51,7 @@ message SingleBatchTrainingResponse {
 
 message TrainGlobalParallelSplitLearningRequest {
   optional int32 round_no = 1;
-  optional double adaptive_learning_threshold = 2;
+  optional double adaptive_threshold_value = 2;
   optional StateDict optimizer_state = 3;
 }
 
@@ -66,7 +66,8 @@ message TrainGlobalParallelSplitLearningResponse {
 message TrainGlobalRequest {
   int32 epochs = 1;
   optional int32 round_no = 2;
-  optional StateDict optimizer_state = 3;
+  optional double adaptive_threshold_value = 3;
+  optional StateDict optimizer_state = 4;
 
 }
 
@@ -173,7 +174,9 @@ message DatasetModelInfoRequest {}
 message DatasetModelInfoResponse {
   int32 train_samples = 1;
   int32 validation_samples = 2;
-  int32 client_model_flops = 3;
-  int32 server_model_flops = 4;
-  optional Metrics diagnostic_metrics = 5;
+  int32 client_fw_flops = 3;
+  int32 server_fw_flops = 4;
+  optional int32 client_bw_flops = 5;
+  optional int32 server_bw_flops = 6;
+  optional Metrics diagnostic_metrics = 7;
 }
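
# Example sketch (not part of the diff): the renamed / extended messages as
# they surface on the Python side (mirrors the updated tests; field values
# are placeholders).
from edml.generated import connection_pb2

request = connection_pb2.TrainGlobalRequest(
    epochs=1,
    round_no=0,
    adaptive_threshold_value=0.5,
)
info = connection_pb2.DatasetModelInfoResponse(
    train_samples=42,
    validation_samples=21,
    client_fw_flops=1,
    server_fw_flops=2,
    client_bw_flops=3,
    server_bw_flops=4,
)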
diff --git a/edml/tests/controllers/fed_controller_test.py b/edml/tests/controllers/fed_controller_test.py
index a44090eefc7b1fb3b791cc749c516835e6d8cc67..45c4ef900c74a02035674ad7cf96be12bd95a5af 100644
--- a/edml/tests/controllers/fed_controller_test.py
+++ b/edml/tests/controllers/fed_controller_test.py
@@ -149,6 +149,7 @@ class FedTrainRoundThreadingTest(unittest.TestCase):
                     "num_devices": 0,
                     "wandb": {"enabled": False},
                     "experiment": {"early_stopping": False},
+                    "simulate_parallelism": False,
                 }
             )
         )
diff --git a/edml/tests/controllers/optimization_test.py b/edml/tests/controllers/optimization_test.py
index 2e8cc390c7a18e0d2b5c8e43822930973f32f953..ab89b184b992891dbe7793e31c889ce4ea912212 100644
--- a/edml/tests/controllers/optimization_test.py
+++ b/edml/tests/controllers/optimization_test.py
@@ -42,8 +42,8 @@ class StrategyOptimizationTest(unittest.TestCase):
             cost_per_byte_sent=1,
             cost_per_byte_received=1,
             cost_per_flop=1,
-            client_model_flops=10,
-            server_model_flops=20,
+            client_fw_flops=10,
+            server_fw_flops=20,
             smashed_data_size=50,
             label_size=5,
             gradient_size=50,
@@ -168,8 +168,8 @@ class EnergySimulatorTest(unittest.TestCase):
             cost_per_byte_sent=1,
             cost_per_byte_received=1,
             cost_per_flop=1,
-            client_model_flops=10,
-            server_model_flops=20,
+            client_fw_flops=10,
+            server_fw_flops=20,
             smashed_data_size=50,
             label_size=5,
             gradient_size=50,
@@ -284,8 +284,8 @@ class TestWithRealData(unittest.TestCase):
             cost_per_byte_sent=self.cost_per_mbyte_sent / 1000000,
             cost_per_byte_received=self.cost_per_mbyte_received / 1000000,
             cost_per_flop=self.cost_per_mflop / 1000000,
-            client_model_flops=5405760,
-            server_model_flops=11215800,
+            client_fw_flops=5405760,
+            server_fw_flops=11215800,
             smashed_data_size=36871,
             label_size=14,
             gradient_size=36871,
@@ -450,3 +450,85 @@ class TestWithRealData(unittest.TestCase):
         )
         print(solution, sum(solution.values()), num_rounds, remaining_batteries)
         self.assertGreater(sum(solution.values()), num_rounds)
+
+
+class TestBug(unittest.TestCase):
+    """Case studies for estimating the number of rounds of each strategy with given energy constratins."""
+
+    def setUp(self):
+        self.global_params = GlobalParams(
+            cost_per_sec=0.02,
+            cost_per_byte_sent=2e-10,
+            cost_per_byte_received=2e-10,
+            cost_per_flop=4.9999999999999995e-14,
+            client_fw_flops=88408064,
+            server_fw_flops=169728256,
+            smashed_data_size=65542.875,
+            batch_size=64,
+            client_weights_size=416469,
+            server_weights_size=6769181,
+            optimizer_state_size=6664371,
+            train_global_time=204.80956506729126,
+            last_server_device_id="d0",
+            label_size=14.359375,
+            gradient_size=65542.875,
+            client_norm_fw_time=0.0002889558474222819,
+            client_norm_bw_time=0.00033905270364549425,
+            server_norm_fw_time=0.0035084471996721703,
+            server_norm_bw_time=0.005231082973148723,
+        )
+        self.device_params_list = [
+            DeviceParams(
+                device_id="d0",
+                initial_battery=400.0,
+                current_battery=393.09099611222206,
+                train_samples=4500,
+                validation_samples=500,
+                comp_latency_factor=3.322075197090219,
+            ),
+            DeviceParams(
+                device_id="d1",
+                initial_battery=400.0,
+                current_battery=398.2543529099223,
+                train_samples=4500,
+                validation_samples=500,
+                comp_latency_factor=3.5023548154181494,
+            ),
+            DeviceParams(
+                device_id="d2",
+                initial_battery=300.0,
+                current_battery=297.6957505401916,
+                train_samples=4500,
+                validation_samples=500,
+                comp_latency_factor=3.2899405054194144,
+            ),
+            DeviceParams(
+                device_id="d3",
+                initial_battery=200.0,
+                current_battery=197.12318101733192,
+                train_samples=4500,
+                validation_samples=500,
+                comp_latency_factor=3.4622951339692114,
+            ),
+            DeviceParams(
+                device_id="d4",
+                initial_battery=200.0,
+                current_battery=194.35766488685508,
+                train_samples=27000,
+                validation_samples=3000,
+                comp_latency_factor=1.0,
+            ),
+        ]
+
+        self.optimizer = ServerChoiceOptimizer(
+            self.device_params_list, self.global_params
+        )
+        self.simulator = EnergySimulator(self.device_params_list, self.global_params)
+
+    def test_bug(self):
+        solution, status = self.optimizer.optimize()
+        num_rounds, schedule, remaining_batteries = (
+            self.simulator.simulate_smart_selection()
+        )
+        print(solution, sum(solution.values()), num_rounds, remaining_batteries)
+        self.assertEqual(sum(solution.values()), num_rounds)
diff --git a/edml/tests/controllers/sample_config.yaml b/edml/tests/controllers/sample_config.yaml
index 505bdae73da5d1fd3901b53d94ae2316cb6ae560..ad8e9175cb16b04de23351fa6794a91b2e41ed92 100644
--- a/edml/tests/controllers/sample_config.yaml
+++ b/edml/tests/controllers/sample_config.yaml
@@ -24,3 +24,4 @@ battery:
   deduction_per_mbyte_sent: 1
   deduction_per_mbyte_received: 1
   deduction_per_mflop: 1
+simulate_parallelism: False
diff --git a/edml/tests/controllers/scheduler/smart_test.py b/edml/tests/controllers/scheduler/smart_test.py
index 6eb7e758073a82ff965eaf414fc37946164d2541..a711f7d78deb4795859c09edb67b2b10431ae087 100644
--- a/edml/tests/controllers/scheduler/smart_test.py
+++ b/edml/tests/controllers/scheduler/smart_test.py
@@ -39,7 +39,10 @@ class SmartServerDeviceSelectionTest(unittest.TestCase):
                 "d0": (2, 1),
                 "d1": (4, 1),
             },
-            {"client": 1000000, "server": 1000000},
+            {
+                "client": {"FW": 1000000, "BW": 2000000},
+                "server": {"FW": 1000000, "BW": 2000000},
+            },
         ]
 
     def test_select_server_device_smart_first_round(self):
diff --git a/edml/tests/controllers/swarm_controller_test.py b/edml/tests/controllers/swarm_controller_test.py
index 6c025f4b78a2a648b6055751cf5875b4b7224494..89ad8aa1fe830bc397ff2831cd037df49ba3b4b7 100644
--- a/edml/tests/controllers/swarm_controller_test.py
+++ b/edml/tests/controllers/swarm_controller_test.py
@@ -37,7 +37,9 @@ class SwarmControllerTest(unittest.TestCase):
         )
 
         client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
-            self.swarm_controller._swarm_train_round(None, None, "d1", 0)
+            self.swarm_controller._swarm_train_round(
+                None, None, "d1", 0, optimizer_state={"optimizer_state": 43}
+            )
         )
 
         self.assertEqual(client_weights, {"weights": 42})
@@ -52,7 +54,11 @@ class SwarmControllerTest(unittest.TestCase):
             ]
         )
         self.mock.train_global_on.assert_called_once_with(
-            "d1", epochs=1, round_no=0, optimizer_state=None
+            "d1",
+            epochs=1,
+            round_no=0,
+            adaptive_threshold_value=0.0,
+            optimizer_state={"optimizer_state": 43},
         )
 
     def test_split_train_round_with_inactive_server_device(self):
@@ -74,7 +80,11 @@ class SwarmControllerTest(unittest.TestCase):
             ]
         )
         self.mock.train_global_on.assert_called_once_with(
-            "d1", epochs=1, round_no=0, optimizer_state=None
+            "d1",
+            epochs=1,
+            round_no=0,
+            adaptive_threshold_value=0.0,
+            optimizer_state=None,
         )
 
 
diff --git a/edml/tests/core/device_test.py b/edml/tests/core/device_test.py
index 3827df9c727254bb13482e22bef6899bb2ae2c47..85b0e1f7f7204efa6a66b3c95f89a5c392098855 100644
--- a/edml/tests/core/device_test.py
+++ b/edml/tests/core/device_test.py
@@ -131,7 +131,12 @@ class RPCDeviceServicerTest(unittest.TestCase):
             {"optimizer_state": 44},
             self.diagnostic_metrics,
         )
-        request = connection_pb2.TrainGlobalRequest(epochs=42)
+        request = connection_pb2.TrainGlobalRequest(
+            epochs=42,
+            round_no=1,
+            adaptive_threshold_value=3,
+            optimizer_state=state_dict_to_proto({"optimizer_state": 42}),
+        )
 
         response, metadata, code, details = self.make_call("TrainGlobal", request)
 
@@ -147,7 +152,9 @@ class RPCDeviceServicerTest(unittest.TestCase):
             proto_to_state_dict(response.optimizer_state), {"optimizer_state": 44}
         )
         self.assertEqual(code, StatusCode.OK)
-        self.mock_device.train_global.assert_called_once_with(42, 0)
+        self.mock_device.train_global.assert_called_once_with(
+            42, 1, 3, {"optimizer_state": 42}
+        )
         self.assertEqual(
             proto_to_metrics(response.diagnostic_metrics), self.diagnostic_metrics
         )
@@ -315,8 +322,8 @@ class RPCDeviceServicerTest(unittest.TestCase):
     def test_dataset_model_info(self):
         self.mock_device.client._train_data.dataset = [1]
         self.mock_device.client._val_data.dataset = [2]
-        self.mock_device.client._model_flops = 3
-        self.mock_device.server._model_flops = 4
+        self.mock_device.client._model_flops = {"FW": 3, "BW": 6}
+        self.mock_device.server._model_flops = {"FW": 4, "BW": 8}
         request = connection_pb2.DatasetModelInfoRequest()
 
         response, metadata, code, details = self.make_call(
@@ -326,8 +333,10 @@ class RPCDeviceServicerTest(unittest.TestCase):
         self.assertEqual(code, StatusCode.OK)
         self.assertEqual(response.train_samples, 1)
         self.assertEqual(response.validation_samples, 1)
-        self.assertEqual(response.client_model_flops, 3)
-        self.assertEqual(response.server_model_flops, 4)
+        self.assertEqual(response.client_fw_flops, 3)
+        self.assertEqual(response.server_fw_flops, 4)
+        self.assertEqual(response.client_bw_flops, 6)
+        self.assertEqual(response.server_bw_flops, 8)
 
 
 class RPCDeviceServicerBatteryEmptyTest(unittest.TestCase):
@@ -355,7 +364,9 @@ class RPCDeviceServicerBatteryEmptyTest(unittest.TestCase):
         self.mock_device.train_global.side_effect = BatteryEmptyException(
             "Battery empty"
         )
-        request = connection_pb2.TrainGlobalRequest()
+        request = connection_pb2.TrainGlobalRequest(
+            optimizer_state=state_dict_to_proto(None)
+        )
         self._test_device_out_of_battery_lets_rpc_fail(request, "TrainGlobal")
 
     def test_stop_at_set_weights(self):
@@ -502,7 +513,7 @@ class RequestDispatcherTest(unittest.TestCase):
         )
 
         client_weights, server_weights, metrics, optimizer_state, diagnostic_metrics = (
-            self.dispatcher.train_global_on("1", 42, 43)
+            self.dispatcher.train_global_on("1", 42, 43, 3, {"optimizer_state": 44})
         )
 
         self.assertEqual(client_weights, self.weights)
@@ -513,19 +524,31 @@ class RequestDispatcherTest(unittest.TestCase):
         self._assert_field_size_added_to_diagnostic_metrics(diagnostic_metrics)
         self.mock_stub.TrainGlobal.assert_called_once_with(
             connection_pb2.TrainGlobalRequest(
-                epochs=42, round_no=43, optimizer_state=state_dict_to_proto(None)
+                epochs=42,
+                round_no=43,
+                adaptive_threshold_value=3,
+                optimizer_state=state_dict_to_proto({"optimizer_state": 44}),
             )
         )
 
     def test_train_global_on_with_error(self):
         self.mock_stub.TrainGlobal.side_effect = grpc.RpcError()
 
-        response = self.dispatcher.train_global_on("1", 42, round_no=43)
+        response = self.dispatcher.train_global_on(
+            "1",
+            42,
+            round_no=43,
+            adaptive_threshold_value=3,
+            optimizer_state={"optimizer_state": 44},
+        )
 
         self.assertEqual(response, False)
         self.mock_stub.TrainGlobal.assert_called_once_with(
             connection_pb2.TrainGlobalRequest(
-                epochs=42, round_no=43, optimizer_state=state_dict_to_proto(None)
+                epochs=42,
+                round_no=43,
+                adaptive_threshold_value=3,
+                optimizer_state=state_dict_to_proto({"optimizer_state": 44}),
             )
         )
 
@@ -809,14 +832,16 @@ class RequestDispatcherTest(unittest.TestCase):
             connection_pb2.DatasetModelInfoResponse(
                 train_samples=42,
                 validation_samples=21,
-                client_model_flops=42,
-                server_model_flops=21,
+                client_fw_flops=1,
+                server_fw_flops=2,
+                client_bw_flops=3,
+                server_bw_flops=4,
             )
         )
 
         response = self.dispatcher.get_dataset_model_info_on("1")
 
-        self.assertEqual(response, (42, 21, 42, 21))
+        self.assertEqual(response, (42, 21, 1, 2, 3, 4))
         self.mock_stub.GetDatasetModelInfo.assert_called_once_with(
             connection_pb2.DatasetModelInfoRequest()
         )
diff --git a/edml/tests/core/server_test.py b/edml/tests/core/server_test.py
index 1331f8aa7bad195eb02044bcf25d381fe2ee2a13..c855d82bdb82b26356b91c04960c4c5f8dafd164 100644
--- a/edml/tests/core/server_test.py
+++ b/edml/tests/core/server_test.py
@@ -13,6 +13,7 @@ from edml.core.battery import Battery
 from edml.core.client import DeviceClient
 from edml.core.device import Device
 from edml.core.server import DeviceServer
+from edml.helpers.logging import SimpleLogger
 
 
 class ClientModel(nn.Module):
@@ -76,6 +77,7 @@ class PSLTest(unittest.TestCase):
                     ]
                 },
                 "own_device_id": "d0",
+                "simulate_parallelism": False,
             }
         )
         # init models with fixed weights for repeatability
@@ -90,10 +92,19 @@ class PSLTest(unittest.TestCase):
         )
         server_model = ServerModel()
         server_model.load_state_dict(server_state_dict)
+        server_model.get_optimizer_params = (
+            server_model.parameters
+        )  # using a model provider, this is created automatically
         client_model1 = ClientModel()
         client_model1.load_state_dict(client_state_dict)
+        client_model1.get_optimizer_params = (
+            client_model1.parameters
+        )  # using a model provider, this is created automatically
         client_model2 = ClientModel()
         client_model2.load_state_dict(client_state_dict)
+        client_model2.get_optimizer_params = (
+            client_model2.parameters
+        )  # using a model provider, this is created automatically
         self.server = DeviceServer(
             model=server_model, loss_fn=torch.nn.L1Loss(), cfg=cfg
         )
@@ -157,6 +168,7 @@ class PSLTest(unittest.TestCase):
 
         node_device = Mock(Device)
         node_device.battery = Mock(Battery)
+        node_device.logger = Mock(SimpleLogger)
         node_device.train_batch_on_client_only_on.side_effect = get_client_side_effect(
             "train_single_batch"
         )
diff --git a/edml/tests/core/start_device_test.py b/edml/tests/core/start_device_test.py
index abb5d9d61d68c791657d168f9e266abec366a74c..dd1f40955f492fda856fbb05194b3d22aa1e165f 100644
--- a/edml/tests/core/start_device_test.py
+++ b/edml/tests/core/start_device_test.py
@@ -4,11 +4,51 @@ from copy import deepcopy
 
 import torch
 from omegaconf import OmegaConf
-from torch.autograd import Variable
 
-from edml.core.start_device import _get_models
 from edml.helpers.model_splitting import Part
 from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
+from edml.tests.models.model_loading_helpers import (
+    _get_model_from_model_provider_config,
+)
+
+
+def plot_grad_flow(named_parameters):
+    """Plots the gradients flowing through the different layers of the net during training.
+
+    Can be used to check for possible gradient vanishing / exploding problems.
+    Usage: plug this function into a trainer class after loss.backward() as
+    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow."""
+    import numpy as np
+    from matplotlib import pyplot as plt
+    from matplotlib.lines import Line2D
+
+    ave_grads = []
+    max_grads = []
+    layers = []
+    for n, p in named_parameters:
+        if (p.requires_grad) and ("bias" not in n):
+            layers.append(n)
+            ave_grads.append(p.grad.abs().mean())
+            max_grads.append(p.grad.abs().max())
+    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
+    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
+    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
+    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
+    plt.xlim(left=0, right=len(ave_grads))
+    plt.ylim(bottom=-0.001)  # , top=0.02)  # zoom in on the lower gradient regions
+    plt.xlabel("Layers")
+    plt.ylabel("average gradient")
+    plt.title("Gradient flow")
+    plt.grid(True)
+    plt.legend(
+        [
+            Line2D([0], [0], color="c", lw=4),
+            Line2D([0], [0], color="b", lw=4),
+            Line2D([0], [0], color="k", lw=4),
+        ],
+        ["max-gradient", "mean-gradient", "zero-gradient"],
+    )
+    plt.show()
 
 
 class GetModelsTest(unittest.TestCase):
@@ -22,17 +62,8 @@ class GetModelsTest(unittest.TestCase):
             )
         )
 
-    def _get_model_from_model_provider_config(self, config_name):
-        self.cfg.model_provider = OmegaConf.load(
-            os.path.join(
-                os.path.dirname(__file__),
-                f"../../config/model_provider/{config_name}.yaml",
-            )
-        )
-        return _get_models(self.cfg)
-
     def test_load_resnet20(self):
-        client, server = self._get_model_from_model_provider_config("resnet20")
+        client, server = _get_model_from_model_provider_config(self.cfg, "resnet20")
         self.assertIsInstance(client, Part)
         self.assertIsInstance(server, Part)
         self.assertEqual(len(client.layers), 4)
@@ -40,29 +71,18 @@ class GetModelsTest(unittest.TestCase):
         self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
 
     def test_load_resnet20_with_ae(self):
-        client, server = self._get_model_from_model_provider_config(
-            "resnet20-with-autoencoder"
+        client, server = _get_model_from_model_provider_config(
+            self.cfg, "resnet20-with-autoencoder"
         )
         self.assertIsInstance(client, ClientWithAutoencoder)
         self.assertIsInstance(server, ServerWithAutoencoder)
         self.assertEqual(len(client.model.layers), 4)
         self.assertEqual(len(server.model.layers), 5)
         self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
-        optimizer = torch.optim.Adam(server.parameters())
-        smashed_data = client(torch.zeros(1, 3, 32, 32))
-        server_smashed_data = Variable(smashed_data, requires_grad=True)
-        output_train = server(server_smashed_data)
-        loss_train = torch.nn.functional.cross_entropy(
-            output_train, torch.zeros((1, 100))
-        )
-        loss_train.backward()
-        optimizer.step()
-        smashed_data.backward(server_smashed_data.grad)
-        optimizer.step()
 
     def test_training_resnet20_with_ae_as_non_trainable_layers(self):
-        client_encoder, server_decoder = self._get_model_from_model_provider_config(
-            "resnet20-with-autoencoder"
+        client_encoder, server_decoder = _get_model_from_model_provider_config(
+            self.cfg, "resnet20-with-autoencoder"
         )
         client_params = deepcopy(str(client_encoder.model.state_dict()))
         encoder_params = deepcopy(str(client_encoder.autoencoder.state_dict()))
@@ -70,17 +90,18 @@ class GetModelsTest(unittest.TestCase):
         decoder_params = deepcopy(str(server_decoder.autoencoder.state_dict()))
 
         # Training loop
-        client_optimizer = torch.optim.Adam(client_encoder.parameters())
-        server_optimizer = torch.optim.Adam(server_decoder.parameters())
-        smashed_data = client_encoder(torch.zeros(1, 3, 32, 32))
-        server_smashed_data = Variable(smashed_data, requires_grad=True)
-        output_train = server_decoder(server_smashed_data)
+        client_optimizer = torch.optim.Adam(client_encoder.get_optimizer_params())
+        server_optimizer = torch.optim.Adam(server_decoder.get_optimizer_params())
+        smashed_data = client_encoder(torch.zeros(7, 3, 32, 32))
+        output_train = server_decoder(smashed_data)
         loss_train = torch.nn.functional.cross_entropy(
-            output_train, torch.rand((1, 100))
+            output_train, torch.rand((7, 100))
         )
         loss_train.backward()
         server_optimizer.step()
-        smashed_data.backward(server_smashed_data.grad)
+        client_encoder.trainable_layers_output.backward(
+            server_decoder.trainable_layers_input.grad
+        )
         client_optimizer.step()
 
         # check that AE hasn't changed, but client and server have
@@ -90,7 +111,7 @@ class GetModelsTest(unittest.TestCase):
         self.assertNotEqual(server_params, str(server_decoder.model.state_dict()))
 
     def test_load_resnet110(self):
-        client, server = self._get_model_from_model_provider_config("resnet110")
+        client, server = _get_model_from_model_provider_config(self.cfg, "resnet110")
         self.assertIsInstance(client, Part)
         self.assertIsInstance(server, Part)
         self.assertEqual(len(client.layers), 4)
@@ -98,8 +119,8 @@ class GetModelsTest(unittest.TestCase):
         self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
 
     def test_load_resnet110_with_ae(self):
-        client, server = self._get_model_from_model_provider_config(
-            "resnet110-with-autoencoder"
+        client, server = _get_model_from_model_provider_config(
+            self.cfg, "resnet110-with-autoencoder"
         )
         self.assertIsInstance(client, ClientWithAutoencoder)
         self.assertIsInstance(server, ServerWithAutoencoder)
diff --git a/edml/tests/helpers/flops_test.py b/edml/tests/helpers/flops_test.py
index 4b7f0d66bbfe6c74ba37e517969bb596842277b4..cb966d4864c1b11b16f75e1f1ed94a504db98a1f 100644
--- a/edml/tests/helpers/flops_test.py
+++ b/edml/tests/helpers/flops_test.py
@@ -1,10 +1,15 @@
+import os
 import unittest
 
 import torch
 import torch.nn as nn
+from omegaconf import OmegaConf
 
 from edml.helpers.flops import estimate_model_flops
 from edml.models.mnist_models import ClientNet, ServerNet
+from edml.tests.models.model_loading_helpers import (
+    _get_model_from_model_provider_config,
+)
 
 
 class FullTestModel(nn.Module):
@@ -55,9 +60,15 @@ class FlopsTest(unittest.TestCase):
         client_flops = estimate_model_flops(client_model, inputs)
         server_flops = estimate_model_flops(server_model, server_inputs)
 
-        self.assertEqual(client_flops, 30000)
-        self.assertEqual(server_flops, 101000)
-        self.assertEqual(full_flops, client_flops + server_flops)
+        self.assertEqual(client_flops, {"BW": 60000, "FW": 30000})
+        self.assertEqual(server_flops, {"BW": 202000, "FW": 101000})
+        self.assertEqual(
+            full_flops,
+            {
+                "BW": client_flops["BW"] + server_flops["BW"],
+                "FW": client_flops["FW"] + server_flops["FW"],
+            },
+        )
 
     def test_mnist_split_count(self):
         client_model = ClientNet()
@@ -69,5 +80,34 @@ class FlopsTest(unittest.TestCase):
         client_flops = estimate_model_flops(client_model, inputs)
         server_flops = estimate_model_flops(server_model, server_inputs)
 
-        self.assertEqual(client_flops, 5405760)
-        self.assertEqual(server_flops, 11215800)
+        self.assertEqual(client_flops, {"BW": 10811520, "FW": 5405760})
+        self.assertEqual(server_flops, {"BW": 22431600, "FW": 11215800})
+
+    def test_autoencoder_does_not_affect_backward_flops(self):
+        os.chdir(os.path.join(os.path.dirname(__file__), "../../../"))
+        client, server = _get_model_from_model_provider_config(
+            OmegaConf.create({}), "resnet20"
+        )
+        client_with_encoder, server_with_decoder = (
+            _get_model_from_model_provider_config(
+                OmegaConf.create({}), "resnet20-with-autoencoder"
+            )
+        )
+        input = torch.randn(1, 3, 32, 32)
+
+        ae_smashed_data = client_with_encoder(input)
+        smashed_data = client(input)
+
+        client_flops = estimate_model_flops(client, input)
+        server_flops = estimate_model_flops(server, smashed_data)
+
+        client_with_encoder_flops = estimate_model_flops(client_with_encoder, input)
+        server_with_decoder_flops = estimate_model_flops(
+            server_with_decoder, ae_smashed_data
+        )
+
+        self.assertEqual(client_flops["BW"], client_with_encoder_flops["BW"])
+        self.assertEqual(server_flops["BW"], server_with_decoder_flops["BW"])
+
+        self.assertGreater(client_with_encoder_flops["FW"], client_flops["FW"])
+        self.assertGreater(server_with_decoder_flops["FW"], server_flops["FW"])
diff --git a/edml/tests/models/model_loading_helpers.py b/edml/tests/models/model_loading_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b88a0618c7c3ac490400ee500a019ec19578ae
--- /dev/null
+++ b/edml/tests/models/model_loading_helpers.py
@@ -0,0 +1,15 @@
+import os
+
+from omegaconf import OmegaConf
+
+from edml.core.start_device import _get_models
+
+
+def _get_model_from_model_provider_config(cfg, config_name):
+    cfg.model_provider = OmegaConf.load(
+        os.path.join(
+            os.path.dirname(__file__),
+            f"../../config/model_provider/{config_name}.yaml",
+        )
+    )
+    return _get_models(cfg)
diff --git a/edml/tests/models/model_provider_test.py b/edml/tests/models/model_provider_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf872d0c07d0e40188208978e258fd1298c74eaa
--- /dev/null
+++ b/edml/tests/models/model_provider_test.py
@@ -0,0 +1,43 @@
+import os
+import unittest
+
+from omegaconf import OmegaConf
+
+from edml.models.provider.base import has_autoencoder
+from edml.tests.models.model_loading_helpers import (
+    _get_model_from_model_provider_config,
+)
+
+
+class ModelProviderTest(unittest.TestCase):
+    def setUp(self):
+        os.chdir(os.path.join(os.path.dirname(__file__), "../../../"))
+        self.cfg = OmegaConf.create({"some_key": "some_value"})
+
+    def test_has_autoencoder_for_model_without_autoencoder(self):
+        client, server = _get_model_from_model_provider_config(self.cfg, "resnet20")
+        self.assertFalse(has_autoencoder(client))
+        self.assertFalse(has_autoencoder(server))
+
+    def test_has_autoencoder_for_model_with_autoencoder(self):
+        client, server = _get_model_from_model_provider_config(
+            self.cfg, "resnet20-with-autoencoder"
+        )
+        self.assertTrue(has_autoencoder(client))
+        self.assertTrue(has_autoencoder(server))
+
+    def test_get_optimizer_params_for_model_without_autoencoder(self):
+        client, server = _get_model_from_model_provider_config(self.cfg, "resnet20")
+        self.assertEqual(list(client.get_optimizer_params()), list(client.parameters()))
+        self.assertEqual(list(server.get_optimizer_params()), list(server.parameters()))
+
+    def test_get_optimizer_params_for_model_with_autoencoder(self):
+        client, server = _get_model_from_model_provider_config(
+            self.cfg, "resnet20-with-autoencoder"
+        )
+        self.assertEqual(
+            list(client.get_optimizer_params()), list(client.model.parameters())
+        )
+        self.assertEqual(
+            list(server.get_optimizer_params()), list(server.model.parameters())
+        )
diff --git a/results/result_generation.ipynb b/results/result_generation.ipynb
index 5c3ed118508856d66ca7b70f0027b99a2cb8052f..fe4e16d604eb3d3ffcc2f1bc3f8c731019b6b797 100644
--- a/results/result_generation.ipynb
+++ b/results/result_generation.ipynb
@@ -13,31 +13,56 @@
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
    "source": [
-    "# Downloading the dataframes (takes time)"
+    "df_base_dir = \"./dataframes\"\n",
+    "projects_with_model = [\n",
+    "    (\"5_devices_unlimited_new\", \"resnet110\"),\n",
+    "    (\"50_devices_unlimited_new\", \"resnet110\"),\n",
+    "    (\"controller_comparison\", \"resnet110\"),\n",
+    "    (\"controller_comparison_het_bat\", \"resnet110\"),\n",
+    "    (\"controller_comparison_homogeneous\", \"resnet110\")\n",
+    "]"
    ],
    "metadata": {
     "collapsed": false
    },
-   "id": "3e72d0fe76a55e56"
+   "id": "5b81c8c9ba4b483d"
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "outputs": [],
    "source": [
-    "df_base_dir = \"./dataframes\"\n",
-    "projects_with_model = [\n",
-    "    (\"5_devices_unlimited_new\", \"resnet110\"),\n",
-    "    (\"50_devices_unlimited_new\", \"resnet110\"),\n",
-    "    (\"controller_comparison\", \"resnet110\")\n",
-    "]"
+    "strategy_autoencoder = {\n",
+    "    \"psl_sequential__\": False,\n",
+    "    \"fed___\": False,\n",
+    "    \"split___\": False,\n",
+    "    \"swarm_sequential__\": False,\n",
+    "    \"swarm_max_battery__\": False,\n",
+    "    \"swarm_smart__\": False,\n",
+    "    \"psl_sequential_static_at_resnet_decoderpth\": True,\n",
+    "    \"psl_sequential__resnet_decoderpth\": True,\n",
+    "    \"psl_sequential_static_at_\": False,\n",
+    "}\n",
+    "strategies = list(strategy_autoencoder.keys())"
    ],
    "metadata": {
     "collapsed": false
    },
-   "id": "5b81c8c9ba4b483d"
+   "id": "26e70757d4650fc1"
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "# Downloading the dataframes (takes time)"
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "3e72d0fe76a55e56"
   },
   {
    "cell_type": "code",
@@ -45,29 +70,7 @@
    "outputs": [],
    "source": [
     "for project_name, _ in projects_with_model:\n",
-    "    save_dataframes(project_name=project_name, strategies=[\n",
-    "        #\"swarm_seq\",\n",
-    "        #\"fed\",\n",
-    "        #\"swarm_max\",\n",
-    "        #\"swarm_rand\",\n",
-    "        #\"swarm_smart\",\n",
-    "        #\"split\",\n",
-    "        #\"psl_rand_\",\n",
-    "        #\"psl_sequential_\",\n",
-    "        #\"psl_max_batteries_\",\n",
-    "        #\"swarm_rand_\",\n",
-    "        #\"swarm_sequential_\",\n",
-    "        #\"swarm_max_batteries_\",\n",
-    "        \"psl_sequential__\",\n",
-    "        \"fed___\",\n",
-    "        \"split___\",\n",
-    "        \"swarm_sequential__\",\n",
-    "        \"swarm_max_battery__\",\n",
-    "        \"swarm_smart__\",\n",
-    "        \"psl_sequential_static_at_resnet_decoderpth\",\n",
-    "        \"psl_sequential__resnet_decoderpth\",\n",
-    "        \"psl_sequential_static_at_\",\n",
-    "    ])"
+    "    save_dataframes(project_name=project_name, strategies=strategies)"
    ],
    "metadata": {
     "collapsed": false
@@ -90,27 +93,23 @@
    "outputs": [],
    "source": [
     "# Required for total number of FLOPs computation\n",
+    "# plain = forward FLOPs without AE = 1/2 backward FLOPs with or without AE (number of backward FLOPs equal with and without AE as AE is skipped during BP)\n",
+    "# ae = forward FLOPs with AE\n",
     "model_flops = {\n",
-    "    \"resnet20\": 41498880,\n",
-    "    \"resnet20_ae\": 45758720,\n",
-    "    \"resnet110\": 258136320,\n",
-    "    \"resnet110_ae\": 262396160,\n",
-    "    \"tcn\": 27240000,\n",
-    "    \"simple_conv\": 16621560\n",
+    "    \"resnet20\": {\"plain\": 41498880, \"ae\": 45758720},\n",
+    "    \"resnet110\": {\"plain\": 258136320, \"ae\": 262396160},\n",
+    "    \"tcn\": {\"plain\": 27240000, \"ae\": 27240000},  # no AE implemented yet\n",
+    "    \"simple_conv\": {\"plain\": 16621560, \"ae\": 16621560},  # no AE implemented yet\n",
     "}\n",
     "\n",
     "client_model_flops = {\n",
-    "    \"resnet20\": 15171584,\n",
-    "    \"resnet20_ae\": 19005440,\n",
-    "    \"resnet110\": 88408064,\n",
-    "    \"resnet110_ae\": 92241920,\n",
+    "    \"resnet20\": {\"plain\": 15171584, \"ae\": 19005440},\n",
+    "    \"resnet110\": {\"plain\": 88408064, \"ae\": 92241920},\n",
     "}\n",
     "\n",
     "server_model_flops = {\n",
-    "    \"resnet20\": 26327296,\n",
-    "    \"resnet20_ae\": 26753280,\n",
-    "    \"resnet110\": 169728256,\n",
-    "    \"resnet110_ae\": 170154240,\n",
+    "    \"resnet20\": {\"plain\": 26327296, \"ae\": 26753280},\n",
+    "    \"resnet110\": {\"plain\": 169728256, \"ae\": 170154240},\n",
     "}\n",
     "experiment_batch_size = 64"
    ],
@@ -144,7 +143,7 @@
     "        dataframes = load_dataframes(proj_name, df_base_dir)\n",
     "        print(\"  generating metrics\")\n",
     "        generate_metric_files(dataframes, proj_name, model_flops[model_name], client_model_flops[model_name],\n",
-    "                              # TODO distinguish AE\n",
+    "                              strategy_autoencoder_mapping=strategy_autoencoder,\n",
     "                              base_path=metrics_base_path, batch_size=experiment_batch_size)\n",
     "        print(\"  generating plots\")\n",
     "        generate_plots(dataframes, proj_name)"
diff --git a/results/result_generation.py b/results/result_generation.py
index caa42642976785d82f0313733b324c9cd63b1166..f22bca3d8558982250e4fa8d3376ee41b70d7539 100644
--- a/results/result_generation.py
+++ b/results/result_generation.py
@@ -39,6 +39,8 @@ LABEL_MAPPING = {
     "device": "Device",
 }
 
+DPI = 300
+
 
 def scale_parallel_time(run_df, scale_factor=1.0):
     """
@@ -230,7 +232,13 @@ def load_dataframes(project_name, base_dir="./dataframes"):
     return history_groups
 
 
-def get_total_flops(groups, total_model_flops, client_model_flops, batch_size=64):
+def get_total_flops(
+    groups,
+    total_model_flops,
+    client_model_flops,
+    strategy_autoencoder_mapping,
+    batch_size=64,
+):
     """
     Returns the total number of FLOPs for each group.
     Args:
@@ -242,6 +250,13 @@ def get_total_flops(groups, total_model_flops, client_model_flops, batch_size=64
     """
     flops_per_group = {"strategy": [], "flops": []}
     for (strategy, job), group in groups.items():
+        # determine model FLOPs depending on whether an autoencoder was used
+        if strategy_autoencoder_mapping[strategy]:  # AE
+            total_model_fw_flops = total_model_flops["ae"]
+        else:  # no AE
+            total_model_fw_flops = total_model_flops["plain"]
+        total_model_bw_flops = 2 * total_model_flops["plain"]
+        client_model_bw_flops = 2 * client_model_flops["plain"]
         if job == "train":
             flops = 0
             num_runs = 1  # avoid division by 0
@@ -254,12 +269,12 @@ def get_total_flops(groups, total_model_flops, client_model_flops, batch_size=64
                     for run_df in runs:
                         for col_name in run_df.columns:
                             if col_name == "train_accuracy.num_samples":
-                                flops += (
-                                    run_df[col_name].sum() * total_model_flops * 3
+                                flops += run_df[col_name].sum() * (
+                                    total_model_fw_flops + total_model_bw_flops
                                 )  # 1x forward + 2x backward
                             if col_name == "val_accuracy.num_samples":
                                 flops += (
-                                    run_df[col_name].sum() * total_model_flops
+                                    run_df[col_name].sum() * total_model_fw_flops
                                 )  # 1x forward
                             if col_name == "adaptive_learning_threshold_applied":
                                # deduct client model FLOPs twice, as client backprop is avoided
@@ -285,13 +300,12 @@ def get_total_flops(groups, total_model_flops, client_model_flops, batch_size=64
                                     )
                                     flops -= (
                                         len(run_df[col_name].dropna())
-                                        * client_model_flops
-                                        * 2
+                                        * client_model_bw_flops
                                         * avg_samples_per_batch
                                     )
                                 else:  # numbers of samples skipped are logged -> sum up
                                     flops -= (
-                                        run_df[col_name].sum() * client_model_flops * 2
+                                        run_df[col_name].sum() * client_model_bw_flops
                                     )
             flops = flops / num_runs
             flops_per_group["strategy"].append(STRATEGY_MAPPING[strategy])
@@ -401,7 +415,7 @@ def plot_remaining_devices(devices_per_round, save_path=None):
         save_path: (str) the path to save the plot to
     """
 
-    plt.figure()
+    plt.figure(dpi=DPI)
     num_rounds = [0]
     max_devices = []
     for (strategy, job), (rounds, num_devices) in devices_per_round.items():
@@ -492,7 +506,7 @@ def plot_accuracies(accuracies_per_round, save_path=None, phase="train"):
         accuracies_per_round: (dict) the accuracy (list(float)) per round (list(int)) for each group
         save_path: (str) the path to save the plot to
     """
-    plt.figure()
+    plt.figure(dpi=DPI)
     num_rounds = [0]
     for (strategy, job), (rounds, accs) in accuracies_per_round.items():
         plt.plot(rounds, accs, label=f"{STRATEGY_MAPPING[strategy]}")
@@ -517,7 +531,7 @@ def plot_accuracies_over_time(accuracies_per_time, save_path=None, phase="train"
         accuracies_per_time: (dict) the accuracy (list(float)) per time (list(float)) for each group
         save_path: (str) the path to save the plot to
     """
-    plt.figure()
+    plt.figure(dpi=DPI)
     for (strategy, job), (time, accs) in accuracies_per_time.items():
         plt.plot(time, accs, label=f"{STRATEGY_MAPPING[strategy]}")
     plt.xlabel(LABEL_MAPPING["runtime"])
@@ -688,7 +702,7 @@ def plot_batteries_over_time(
         aggregated: (bool) whether the battery is aggregated or not
     """
     if aggregated:
-        plt.figure()
+        plt.figure(dpi=DPI)
         plt.rcParams.update({"font.size": 13})
         for (strategy, job), series in batteries_over_time.items():
             runtime = max_runtimes[(strategy, job)]
@@ -709,7 +723,7 @@ def plot_batteries_over_time(
             plt.close()
     else:
         for (strategy, job), series_dict in batteries_over_time.items():
-            plt.figure()
+            plt.figure(dpi=DPI)
             plt.rcParams.update({"font.size": 13})
             for device_id, series in series_dict.items():
                 runtime = max_runtimes[(strategy, job)]
@@ -739,7 +753,7 @@ def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T
         aggregated: (bool) whether the battery is aggregated or not
     """
     if aggregated:
-        plt.figure()
+        plt.figure(dpi=DPI)
         plt.rcParams.update({"font.size": 13})
         num_rounds = [0]
         for (strategy, job), series in batteries_over_epoch.items():
@@ -760,7 +774,7 @@ def plot_batteries_over_epoch(batteries_over_epoch, save_path=None, aggregated=T
             plt.close()
     else:
         for (strategy, job), series_dict in batteries_over_epoch.items():
-            plt.figure()
+            plt.figure(dpi=DPI)
             plt.rcParams.update({"font.size": 13})
             num_rounds = [0]
             for device_id, series in series_dict.items():
@@ -899,7 +913,7 @@ def plot_batteries_over_time_with_activity(
             0,
             train_time_end(server_train_times, client_train_times) * 1.05,
         )  # set end 5% after last activity timestamp
-        plt.figure()
+        plt.figure(dpi=DPI)
         plt.rcParams.update({"font.size": 13})
         # battery_plot
         plt.subplot(2, 1, 1)
@@ -976,7 +990,7 @@ def plot_batteries_over_epoch_with_activity_at_time_scale(
             ],
             [str(i) for i in range(0, len(start_times), max(1, len(start_times) // 8))],
         )
-        plt.figure()
+        plt.figure(dpi=DPI)
         plt.rcParams.update({"font.size": 13})
         # battery_plot
         plt.subplot(2, 1, 1)
@@ -1048,7 +1062,7 @@ def plot_batteries_over_epoch_with_activity_at_epoch_scale(
         ).sort_values("start")
 
         # battery plot
-        plt.figure()
+        plt.figure(dpi=DPI)
         plt.rcParams.update({"font.size": 13})
         plt.subplot(2, 1, 1)
         num_rounds = []
@@ -1242,6 +1256,7 @@ def generate_metric_files(
     project_name,
     total_model_flops,
     client_model_flops,
+    strategy_autoencoder_mapping,
     base_path="./metrics",
     batch_size=64,
 ):
@@ -1265,7 +1280,11 @@ def generate_metric_files(
     ).set_index("strategy")
     total_flops = pd.DataFrame.from_dict(
         get_total_flops(
-            history_groups, total_model_flops, client_model_flops, batch_size
+            history_groups,
+            total_model_flops,
+            client_model_flops,
+            strategy_autoencoder_mapping,
+            batch_size,
         )
     ).set_index("strategy")
     df = pd.concat([test_acc, comm_overhead, total_flops], axis=1)