diff --git a/edml/controllers/scheduler/smart.py b/edml/controllers/scheduler/smart.py
index 48940331247f647a855d72bf3e8a54b44f13ba78..eddf1e98f9b18b0589ff0f61782e54a98ccdd130 100644
--- a/edml/controllers/scheduler/smart.py
+++ b/edml/controllers/scheduler/smart.py
@@ -90,8 +90,10 @@ class SmartNextServerScheduler(NextServerScheduler):
             device_params_list.append(device_params)
         global_params = GlobalParams()
         global_params.fill_values_from_config(self.cfg)
-        global_params.client_model_flops = model_flops["client"]
-        global_params.server_model_flops = model_flops["server"]
+        global_params.client_fw_flops = model_flops["client"]["FW"]
+        global_params.server_fw_flops = model_flops["server"]["FW"]
+        global_params.client_bw_flops = model_flops["client"]["BW"]
+        global_params.server_bw_flops = model_flops["server"]["BW"]
         global_params.client_norm_fw_time = optimization_metrics["client_norm_fw_time"]
         global_params.client_norm_bw_time = optimization_metrics["client_norm_bw_time"]
         global_params.server_norm_fw_time = optimization_metrics["server_norm_fw_time"]
diff --git a/edml/controllers/strategy_optimization.py b/edml/controllers/strategy_optimization.py
index cd7bae8ac0db1c8eb4bcf818713ed8b2f1ed0e47..7c11f73652cf1aecd792a26e603a099c488b9943 100644
--- a/edml/controllers/strategy_optimization.py
+++ b/edml/controllers/strategy_optimization.py
@@ -29,8 +29,10 @@ class GlobalParams:
         cost_per_byte_sent=None,
         cost_per_byte_received=None,
         cost_per_flop=None,
-        client_model_flops=None,
-        server_model_flops=None,
+        client_fw_flops=None,  # mandatory forward FLOPs
+        server_fw_flops=None,  # mandatory forward FLOPs
+        client_bw_flops=None,  # optional backward FLOPs if bw FLOPs != 2 * fw FLOPs (e.g. for AE)
+        server_bw_flops=None,  # optional backward FLOPs if bw FLOPs != 2 * fw FLOPs (e.g. for AE)
         smashed_data_size=None,
         label_size=None,
         gradient_size=None,
@@ -49,8 +51,10 @@ class GlobalParams:
         self.cost_per_byte_sent = cost_per_byte_sent
         self.cost_per_byte_received = cost_per_byte_received
         self.cost_per_flop = cost_per_flop
-        self.client_model_flops = client_model_flops
-        self.server_model_flops = server_model_flops
+        self.client_fw_flops = client_fw_flops
+        self.server_fw_flops = server_fw_flops
+        self.client_bw_flops = client_bw_flops  # optional backward FLOPs if bw FLOPs != 2 * fw FLOPs (e.g. for AE)
+        self.server_bw_flops = server_bw_flops  # optional backward FLOPs if bw FLOPs != 2 * fw FLOPs (e.g. for AE)
         self.smashed_data_size = smashed_data_size
         self.batch_size = batch_size
         self.client_weights_size = client_weights_size
@@ -79,8 +83,8 @@ class GlobalParams:
             and self.cost_per_byte_sent is not None
             and self.cost_per_byte_received is not None
             and self.cost_per_flop is not None
-            and self.client_model_flops is not None
-            and self.server_model_flops is not None
+            and self.client_fw_flops is not None
+            and self.server_fw_flops is not None
             and self.optimizer_state_size is not None
             and self.smashed_data_size is not None
             and self.label_size is not None
@@ -181,22 +185,24 @@ class ServerChoiceOptimizer:
     def num_flops_per_round_on_device(self, device_id, server_device_id):
         device = self._get_device_params(device_id)
         total_flops = 0
+        # set bw FLOPs to 2 * fw FLOPs if not provided
+        if self.global_params.client_bw_flops and self.global_params.server_bw_flops:
+            client_bw_flops = self.global_params.client_bw_flops
+            server_bw_flops = self.global_params.server_bw_flops
+        else:
+            client_bw_flops = (
+                2 * self.global_params.client_fw_flops
+            )  # by default bw FLOPs = 2*fw FLOPs
+            server_bw_flops = 2 * self.global_params.server_fw_flops
+        client_fw_flops = self.global_params.client_fw_flops
+        server_fw_flops = self.global_params.server_fw_flops
         if device_id == server_device_id:
             total_flops += (
-                self.global_params.server_model_flops
-                * self._total_train_dataset_size()
-                * 3
-            )  # fw + 2bw
-            total_flops += (
-                self.global_params.server_model_flops
-                * self._total_validation_dataset_size()
-            )  # fw
-        total_flops += (
-            self.global_params.client_model_flops * device.train_samples * 3
-        )  # fw + 2bw
-        total_flops += (
-            self.global_params.client_model_flops * device.validation_samples
-        )  # fw
+                server_fw_flops + server_bw_flops
+            ) * self._total_train_dataset_size()
+            total_flops += server_fw_flops * self._total_validation_dataset_size()
+        total_flops += (client_fw_flops + client_bw_flops) * device.train_samples
+        total_flops += client_fw_flops * device.validation_samples
         return total_flops
 
     def num_bytes_sent_per_round_on_device(self, device_id, server_device_id):
@@ -476,10 +482,10 @@ class EnergySimulator:
     def _fl_flops_on_device(self, device_id):
         device = self._get_device_params(device_id)
         total_flops = 0
-        total_flops += self.global_params.client_model_flops * device.train_samples * 3
-        total_flops += self.global_params.client_model_flops * device.validation_samples
-        total_flops += self.global_params.server_model_flops * device.train_samples * 3
-        total_flops += self.global_params.server_model_flops * device.validation_samples
+        total_flops += self.global_params.client_fw_flops * device.train_samples * 3
+        total_flops += self.global_params.client_fw_flops * device.validation_samples
+        total_flops += self.global_params.server_fw_flops * device.train_samples * 3
+        total_flops += self.global_params.server_fw_flops * device.validation_samples
         return total_flops
 
     def _fl_data_sent_per_device(self):
diff --git a/edml/controllers/swarm_controller.py b/edml/controllers/swarm_controller.py
index 4463e2cc8ffd3a89471e8728ba0a117e163e8fea..1b61f0aac70ae67a5ff0a7df98077dcc4bba3ec8 100644
--- a/edml/controllers/swarm_controller.py
+++ b/edml/controllers/swarm_controller.py
@@ -121,17 +121,32 @@ class SwarmController(BaseController):
         """Returns the dataset sizes and model flops of active devices only."""
         dataset_sizes = {}
         model_flops = {}
-        client_flop_list = []
-        server_flop_list = []
+        client_fw_flop_list = []
+        server_fw_flop_list = []
+        client_bw_flop_list = []
+        server_bw_flop_list = []
         for device_id in self.active_devices:
-            train_samples, val_samples, client_flops, server_flops = (
-                self.request_dispatcher.get_dataset_model_info_on(device_id)
-            )
+            (
+                train_samples,
+                val_samples,
+                client_fw_flops,
+                server_fw_flops,
+                client_bw_flops,
+                server_bw_flops,
+            ) = self.request_dispatcher.get_dataset_model_info_on(device_id)
             dataset_sizes[device_id] = (train_samples, val_samples)
-            client_flop_list.append(client_flops)
-            server_flop_list.append(server_flops)
+            client_fw_flop_list.append(client_fw_flops)
+            server_fw_flop_list.append(server_fw_flops)
+            client_bw_flop_list.append(client_bw_flops)
+            server_bw_flop_list.append(server_bw_flops)
         # avoid that flops are taken from a device that wasn't used for training and thus has no flops
         # apart from that, FLOPs should be the same everywhere
-        model_flops["client"] = max(client_flop_list)
-        model_flops["server"] = max(server_flop_list)
+        model_flops["client"] = {
+            "FW": max(client_fw_flop_list),
+            "BW": max(client_bw_flop_list),
+        }
+        model_flops["server"] = {
+            "FW": max(server_fw_flop_list),
+            "BW": max(server_bw_flop_list),
+        }
         return dataset_sizes, model_flops
diff --git a/edml/core/client.py b/edml/core/client.py
index e609efe2909b448dcbf8fe1a44bf3ab68a8415a5..aae9b9a917a6be55c6c2e273021d0b8289458bcc 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -93,6 +93,7 @@ class DeviceClient:
             If a latency factor is specified, this function sleeps for said amount before returning.
         """
         self.node_device = node_device
+        self.node_device.logger.log({"client_model_flops": self._model_flops})
 
     @simulate_latency_decorator(latency_factor_attr="latency_factor")
     def set_weights(self, state_dict: StateDict):
diff --git a/edml/core/device.py b/edml/core/device.py
index 6770e4271e8dc030af1abc7f8ac010a57242f5bb..961ba4b729f2f854cd2e71970dc27875aef2ffbd 100644
--- a/edml/core/device.py
+++ b/edml/core/device.py
@@ -557,8 +557,10 @@ class RPCDeviceServicer(DeviceServicer):
         return connection_pb2.DatasetModelInfoResponse(
             train_samples=len(self.device.client._train_data.dataset),
             validation_samples=len(self.device.client._val_data.dataset),
-            client_model_flops=int(self.device.client._model_flops),
-            server_model_flops=int(self.device.server._model_flops),
+            client_fw_flops=int(self.device.client._model_flops["FW"]),
+            server_fw_flops=int(self.device.server._model_flops["FW"]),
+            client_bw_flops=int(self.device.client._model_flops["BW"]),
+            server_bw_flops=int(self.device.server._model_flops["BW"]),
         )
 
     def TrainGlobalParallelSplitLearning(self, request, context):
@@ -995,7 +997,7 @@ class DeviceRequestDispatcher:
 
     def get_dataset_model_info_on(
         self, device_id: str
-    ) -> Union[Tuple[int, int, float, float], bool]:
+    ) -> Union[Tuple[int, int, int, int, int, int], bool]:
         try:
             response: DatasetModelInfoResponse = self._get_connection(
                 device_id
@@ -1003,8 +1005,10 @@ class DeviceRequestDispatcher:
             return (
                 response.train_samples,
                 response.validation_samples,
-                response.client_model_flops,
-                response.server_model_flops,
+                response.client_fw_flops,
+                response.server_fw_flops,
+                response.client_bw_flops,
+                response.server_bw_flops,
             )
         except grpc.RpcError:
             self._handle_rpc_error(device_id)
diff --git a/edml/core/server.py b/edml/core/server.py
index 7cd6e1013dc8d0be8b48bed01153aafddffc9068..22b6a599a6d7a067a24f2543bc21470f54ca2bf6 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -174,6 +174,7 @@ class DeviceServer:
             self._model_flops = estimate_model_flops(
                 self._model, sample.to(self._device).unsqueeze(0)
             )
+            self.node_device.logger.log({"server_model_flops": self._model_flops})
 
     @simulate_latency_decorator(latency_factor_attr="latency_factor")
     def finalize_metrics(self, device_id: str, phase: str):
diff --git a/edml/generated/connection_pb2.py b/edml/generated/connection_pb2.py
index f3c72442f44b99d587925a030002e1af92924ac8..3c25ef7670c9e08c02ca4df71de66afa9e55193b 100644
--- a/edml/generated/connection_pb2.py
+++ b/edml/generated/connection_pb2.py
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
 import datastructures_pb2 as datastructures__pb2
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"4\n\x13SetGradientsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"j\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12\"\n\tgradients\x18\x02 \x01(\x0b\x32\n.GradientsH\x00\x88\x01\x01\x42\x0c\n\n_gradients\"C\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\x12\x10\n\x08round_no\x18\x02 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 
\x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 
\x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xc7\x01\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x1a\n\x12\x63lient_model_flops\x18\x03 \x01(\x05\x12\x1a\n\x12server_model_flops\x18\x04 \x01(\x05\x12)\n\x12\x64iagnostic_metrics\x18\x05 
\x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics2\xf8\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x12\x45\n#SetGradientsAndFinalizeTrainingStep\x12\x14.SetGradientsRequest\x1a\x06.Empty\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63onnection.proto\x1a\x14\x64\x61tastructures.proto\"4\n\x13SetGradientsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"5\n\x14UpdateWeightsRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\";\n\x1aSingleBatchBackwardRequest\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\"j\n\x1bSingleBatchBackwardResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12\"\n\tgradients\x18\x02 \x01(\x0b\x32\n.GradientsH\x00\x88\x01\x01\x42\x0c\n\n_gradients\"C\n\x1aSingleBatchTrainingRequest\x12\x13\n\x0b\x62\x61tch_index\x18\x01 \x01(\x05\x12\x10\n\x08round_no\x18\x02 \x01(\x05\"\x80\x01\n\x1bSingleBatchTrainingResponse\x12\'\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.ActivationsH\x00\x88\x01\x01\x12\x1c\n\x06labels\x18\x02 \x01(\x0b\x32\x07.LabelsH\x01\x88\x01\x01\x42\x0f\n\r_smashed_dataB\t\n\x07_labels\"\xd5\x01\n\'TrainGlobalParallelSplitLearningRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12(\n\x1b\x61\x64\x61ptive_learning_threshold\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 \x01(\x0b\x32\n.StateDictH\x02\x88\x01\x01\x42\x0b\n\t_round_noB\x1e\n\x1c_adaptive_learning_thresholdB\x12\n\x10_optimizer_state\"\x89\x02\n(TrainGlobalParallelSplitLearningResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"\x86\x01\n\x12TrainGlobalRequest\x12\x0e\n\x06\x65pochs\x18\x01 \x01(\x05\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12(\n\x0foptimizer_state\x18\x03 
\x01(\x0b\x32\n.StateDictH\x01\x88\x01\x01\x42\x0b\n\t_round_noB\x12\n\x10_optimizer_state\"\xf4\x01\n\x13TrainGlobalResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x19\n\x07metrics\x18\x03 \x01(\x0b\x32\x08.Metrics\x12(\n\x0foptimizer_state\x18\x04 \x01(\x0b\x32\n.StateDictH\x00\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x01\x88\x01\x01\x42\x12\n\x10_optimizer_stateB\x15\n\x13_diagnostic_metrics\"A\n\x11SetWeightsRequest\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12\x11\n\ton_client\x18\x02 \x01(\x08\"V\n\x12SetWeightsResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"T\n\x11TrainEpochRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x15\n\x08round_no\x18\x02 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"q\n\x12TrainEpochResponse\x12\x19\n\x07weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"P\n\x11TrainBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"\x91\x01\n\x12TrainBatchResponse\x12\x1d\n\tgradients\x18\x01 \x01(\x0b\x32\n.Gradients\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x12\x11\n\x04loss\x18\x03 \x01(\x01H\x01\x88\x01\x01\x42\x15\n\x13_diagnostic_metricsB\x07\n\x05_loss\":\n\x11\x45valGlobalRequest\x12\x12\n\nvalidation\x18\x01 \x01(\x08\x12\x11\n\tfederated\x18\x02 \x01(\x08\"q\n\x12\x45valGlobalResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\">\n\x0b\x45valRequest\x12\x1b\n\x06server\x18\x01 \x01(\x0b\x32\x0b.DeviceInfo\x12\x12\n\nvalidation\x18\x02 
\x01(\x08\"P\n\x0c\x45valResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"O\n\x10\x45valBatchRequest\x12\"\n\x0csmashed_data\x18\x01 \x01(\x0b\x32\x0c.Activations\x12\x17\n\x06labels\x18\x02 \x01(\x0b\x32\x07.Labels\"p\n\x11\x45valBatchResponse\x12\x19\n\x07metrics\x18\x01 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\";\n\x15\x46ullModelTrainRequest\x12\x15\n\x08round_no\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0b\n\t_round_no\"\xce\x01\n\x16\x46ullModelTrainResponse\x12 \n\x0e\x63lient_weights\x18\x01 \x01(\x0b\x32\x08.Weights\x12 \n\x0eserver_weights\x18\x02 \x01(\x0b\x32\x08.Weights\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x12\x19\n\x07metrics\x18\x04 \x01(\x0b\x32\x08.Metrics\x12)\n\x12\x64iagnostic_metrics\x18\x05 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x18\n\x16StartExperimentRequest\"[\n\x17StartExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x45ndExperimentRequest\"Y\n\x15\x45ndExperimentResponse\x12)\n\x12\x64iagnostic_metrics\x18\x01 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x16\n\x14\x42\x61tteryStatusRequest\"y\n\x15\x42\x61tteryStatusResponse\x12\x1e\n\x06status\x18\x01 \x01(\x0b\x32\x0e.BatteryStatus\x12)\n\x12\x64iagnostic_metrics\x18\x02 \x01(\x0b\x32\x08.MetricsH\x00\x88\x01\x01\x42\x15\n\x13_diagnostic_metrics\"\x19\n\x17\x44\x61tasetModelInfoRequest\"\xa5\x02\n\x18\x44\x61tasetModelInfoResponse\x12\x15\n\rtrain_samples\x18\x01 \x01(\x05\x12\x1a\n\x12validation_samples\x18\x02 \x01(\x05\x12\x17\n\x0f\x63lient_fw_flops\x18\x03 \x01(\x05\x12\x17\n\x0fserver_fw_flops\x18\x04 \x01(\x05\x12\x1c\n\x0f\x63lient_bw_flops\x18\x05 \x01(\x05H\x00\x88\x01\x01\x12\x1c\n\x0fserver_bw_flops\x18\x06 
\x01(\x05H\x01\x88\x01\x01\x12)\n\x12\x64iagnostic_metrics\x18\x07 \x01(\x0b\x32\x08.MetricsH\x02\x88\x01\x01\x42\x12\n\x10_client_bw_flopsB\x12\n\x10_server_bw_flopsB\x15\n\x13_diagnostic_metrics2\xf8\x08\n\x06\x44\x65vice\x12:\n\x0bTrainGlobal\x12\x13.TrainGlobalRequest\x1a\x14.TrainGlobalResponse\"\x00\x12\x37\n\nSetWeights\x12\x12.SetWeightsRequest\x1a\x13.SetWeightsResponse\"\x00\x12\x37\n\nTrainEpoch\x12\x12.TrainEpochRequest\x1a\x13.TrainEpochResponse\"\x00\x12\x37\n\nTrainBatch\x12\x12.TrainBatchRequest\x1a\x13.TrainBatchResponse\"\x00\x12;\n\x0e\x45valuateGlobal\x12\x12.EvalGlobalRequest\x1a\x13.EvalGlobalResponse\"\x00\x12)\n\x08\x45valuate\x12\x0c.EvalRequest\x1a\r.EvalResponse\"\x00\x12\x38\n\rEvaluateBatch\x12\x11.EvalBatchRequest\x1a\x12.EvalBatchResponse\"\x00\x12\x46\n\x11\x46ullModelTraining\x12\x16.FullModelTrainRequest\x1a\x17.FullModelTrainResponse\"\x00\x12\x46\n\x0fStartExperiment\x12\x17.StartExperimentRequest\x1a\x18.StartExperimentResponse\"\x00\x12@\n\rEndExperiment\x12\x15.EndExperimentRequest\x1a\x16.EndExperimentResponse\"\x00\x12\x43\n\x10GetBatteryStatus\x12\x15.BatteryStatusRequest\x1a\x16.BatteryStatusResponse\"\x00\x12L\n\x13GetDatasetModelInfo\x12\x18.DatasetModelInfoRequest\x1a\x19.DatasetModelInfoResponse\"\x00\x12y\n TrainGlobalParallelSplitLearning\x12(.TrainGlobalParallelSplitLearningRequest\x1a).TrainGlobalParallelSplitLearningResponse\"\x00\x12W\n\x18TrainSingleBatchOnClient\x12\x1b.SingleBatchTrainingRequest\x1a\x1c.SingleBatchTrainingResponse\"\x00\x12\x65\n&BackwardPropagationSingleBatchOnClient\x12\x1b.SingleBatchBackwardRequest\x1a\x1c.SingleBatchBackwardResponse\"\x00\x12\x45\n#SetGradientsAndFinalizeTrainingStep\x12\x14.SetGradientsRequest\x1a\x06.Empty\"\x00\x62\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -85,7 +85,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
   _globals['_DATASETMODELINFOREQUEST']._serialized_start=3141
   _globals['_DATASETMODELINFOREQUEST']._serialized_end=3166
   _globals['_DATASETMODELINFORESPONSE']._serialized_start=3169
-  _globals['_DATASETMODELINFORESPONSE']._serialized_end=3368
-  _globals['_DEVICE']._serialized_start=3371
-  _globals['_DEVICE']._serialized_end=4515
+  _globals['_DATASETMODELINFORESPONSE']._serialized_end=3462
+  _globals['_DEVICE']._serialized_start=3465
+  _globals['_DEVICE']._serialized_end=4609
 # @@protoc_insertion_point(module_scope)
diff --git a/edml/generated/connection_pb2.pyi b/edml/generated/connection_pb2.pyi
index 89353343aa1c7e39072bd0ea03c891bd2df7b4df..11b713d56cdbe69e9f9b68a52bd9aa51693f49b3 100644
--- a/edml/generated/connection_pb2.pyi
+++ b/edml/generated/connection_pb2.pyi
@@ -246,15 +246,19 @@ class DatasetModelInfoRequest(_message.Message):
     def __init__(self) -> None: ...
 
 class DatasetModelInfoResponse(_message.Message):
-    __slots__ = ["train_samples", "validation_samples", "client_model_flops", "server_model_flops", "diagnostic_metrics"]
+    __slots__ = ["train_samples", "validation_samples", "client_fw_flops", "server_fw_flops", "client_bw_flops", "server_bw_flops", "diagnostic_metrics"]
     TRAIN_SAMPLES_FIELD_NUMBER: _ClassVar[int]
     VALIDATION_SAMPLES_FIELD_NUMBER: _ClassVar[int]
-    CLIENT_MODEL_FLOPS_FIELD_NUMBER: _ClassVar[int]
-    SERVER_MODEL_FLOPS_FIELD_NUMBER: _ClassVar[int]
+    CLIENT_FW_FLOPS_FIELD_NUMBER: _ClassVar[int]
+    SERVER_FW_FLOPS_FIELD_NUMBER: _ClassVar[int]
+    CLIENT_BW_FLOPS_FIELD_NUMBER: _ClassVar[int]
+    SERVER_BW_FLOPS_FIELD_NUMBER: _ClassVar[int]
     DIAGNOSTIC_METRICS_FIELD_NUMBER: _ClassVar[int]
     train_samples: int
     validation_samples: int
-    client_model_flops: int
-    server_model_flops: int
+    client_fw_flops: int
+    server_fw_flops: int
+    client_bw_flops: int
+    server_bw_flops: int
     diagnostic_metrics: _datastructures_pb2.Metrics
-    def __init__(self, train_samples: _Optional[int] = ..., validation_samples: _Optional[int] = ..., client_model_flops: _Optional[int] = ..., server_model_flops: _Optional[int] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ...
+    def __init__(self, train_samples: _Optional[int] = ..., validation_samples: _Optional[int] = ..., client_fw_flops: _Optional[int] = ..., server_fw_flops: _Optional[int] = ..., client_bw_flops: _Optional[int] = ..., server_bw_flops: _Optional[int] = ..., diagnostic_metrics: _Optional[_Union[_datastructures_pb2.Metrics, _Mapping]] = ...) -> None: ...
diff --git a/edml/helpers/flops.py b/edml/helpers/flops.py
index b3fc822fe9aaedf359f69ce5cf9dd624e4a66c37..038f27fd31391a8feac144da37db988c1aa93c69 100644
--- a/edml/helpers/flops.py
+++ b/edml/helpers/flops.py
@@ -19,7 +19,7 @@ def estimate_model_flops(
     using the full model.
     Args:
         model (nn.Module): the neural network model to calculate the FLOPs for.
-        sample: The data used to calculate the FLOPs.
+        sample: The data used to calculate the FLOPs. The first dimension should be the batch dimension.
 
     Returns:
         dict(str, int): {"FW": #ForwardFLOPs, "BW": #BackwardFLOPs} the number of estimated forward and backward FLOPs.
diff --git a/edml/proto/connection.proto b/edml/proto/connection.proto
index 5755518d01796bdff68957655d4d339dcf02ab1d..6fecb49b933ebdfed642b247322b2a099fb5901e 100644
--- a/edml/proto/connection.proto
+++ b/edml/proto/connection.proto
@@ -173,7 +173,9 @@ message DatasetModelInfoRequest {}
 message DatasetModelInfoResponse {
   int32 train_samples = 1;
   int32 validation_samples = 2;
-  int32 client_model_flops = 3;
-  int32 server_model_flops = 4;
-  optional Metrics diagnostic_metrics = 5;
+  int32 client_fw_flops = 3;
+  int32 server_fw_flops = 4;
+  optional int32 client_bw_flops = 5;
+  optional int32 server_bw_flops = 6;
+  optional Metrics diagnostic_metrics = 7;
 }
diff --git a/edml/tests/controllers/optimization_test.py b/edml/tests/controllers/optimization_test.py
index 2e8cc390c7a18e0d2b5c8e43822930973f32f953..ab89b184b992891dbe7793e31c889ce4ea912212 100644
--- a/edml/tests/controllers/optimization_test.py
+++ b/edml/tests/controllers/optimization_test.py
@@ -42,8 +42,8 @@ class StrategyOptimizationTest(unittest.TestCase):
             cost_per_byte_sent=1,
             cost_per_byte_received=1,
             cost_per_flop=1,
-            client_model_flops=10,
-            server_model_flops=20,
+            client_fw_flops=10,
+            server_fw_flops=20,
             smashed_data_size=50,
             label_size=5,
             gradient_size=50,
@@ -168,8 +168,8 @@ class EnergySimulatorTest(unittest.TestCase):
             cost_per_byte_sent=1,
             cost_per_byte_received=1,
             cost_per_flop=1,
-            client_model_flops=10,
-            server_model_flops=20,
+            client_fw_flops=10,
+            server_fw_flops=20,
             smashed_data_size=50,
             label_size=5,
             gradient_size=50,
@@ -284,8 +284,8 @@ class TestWithRealData(unittest.TestCase):
             cost_per_byte_sent=self.cost_per_mbyte_sent / 1000000,
             cost_per_byte_received=self.cost_per_mbyte_received / 1000000,
             cost_per_flop=self.cost_per_mflop / 1000000,
-            client_model_flops=5405760,
-            server_model_flops=11215800,
+            client_fw_flops=5405760,
+            server_fw_flops=11215800,
             smashed_data_size=36871,
             label_size=14,
             gradient_size=36871,
@@ -450,3 +450,85 @@ class TestWithRealData(unittest.TestCase):
         )
         print(solution, sum(solution.values()), num_rounds, remaining_batteries)
         self.assertGreater(sum(solution.values()), num_rounds)
+
+
+class TestBug(unittest.TestCase):
+    """Case studies for estimating the number of rounds of each strategy with given energy constraints."""
+
+    def setUp(self):
+        self.global_params = GlobalParams(
+            cost_per_sec=0.02,
+            cost_per_byte_sent=2e-10,
+            cost_per_byte_received=2e-10,
+            cost_per_flop=4.9999999999999995e-14,
+            client_fw_flops=88408064,
+            server_fw_flops=169728256,
+            smashed_data_size=65542.875,
+            batch_size=64,
+            client_weights_size=416469,
+            server_weights_size=6769181,
+            optimizer_state_size=6664371,
+            train_global_time=204.80956506729126,
+            last_server_device_id="d0",
+            label_size=14.359375,
+            gradient_size=65542.875,
+            client_norm_fw_time=0.0002889558474222819,
+            client_norm_bw_time=0.00033905270364549425,
+            server_norm_fw_time=0.0035084471996721703,
+            server_norm_bw_time=0.005231082973148723,
+        )
+        self.device_params_list = [
+            DeviceParams(
+                device_id="d0",
+                initial_battery=400.0,
+                current_battery=393.09099611222206,
+                train_samples=4500,
+                validation_samples=500,
+                comp_latency_factor=3.322075197090219,
+            ),
+            DeviceParams(
+                device_id="d1",
+                initial_battery=400.0,
+                current_battery=398.2543529099223,
+                train_samples=4500,
+                validation_samples=500,
+                comp_latency_factor=3.5023548154181494,
+            ),
+            DeviceParams(
+                device_id="d2",
+                initial_battery=300.0,
+                current_battery=297.6957505401916,
+                train_samples=4500,
+                validation_samples=500,
+                comp_latency_factor=3.2899405054194144,
+            ),
+            DeviceParams(
+                device_id="d3",
+                initial_battery=200.0,
+                current_battery=197.12318101733192,
+                train_samples=4500,
+                validation_samples=500,
+                comp_latency_factor=3.4622951339692114,
+            ),
+            DeviceParams(
+                device_id="d4",
+                initial_battery=200.0,
+                current_battery=194.35766488685508,
+                train_samples=27000,
+                validation_samples=3000,
+                comp_latency_factor=1.0,
+            ),
+        ]
+
+        self.optimizer = ServerChoiceOptimizer(
+            self.device_params_list, self.global_params
+        )
+        self.simulator = EnergySimulator(self.device_params_list, self.global_params)
+
+    def test_bug(self):
+        solution, status = self.optimizer.optimize()
+        num_rounds, schedule, remaining_batteries = (
+            self.simulator.simulate_smart_selection()
+        )
+        print(solution, sum(solution.values()), num_rounds, remaining_batteries)
+        self.assertEqual(sum(solution.values()), num_rounds)
diff --git a/edml/tests/controllers/scheduler/smart_test.py b/edml/tests/controllers/scheduler/smart_test.py
index 6eb7e758073a82ff965eaf414fc37946164d2541..a711f7d78deb4795859c09edb67b2b10431ae087 100644
--- a/edml/tests/controllers/scheduler/smart_test.py
+++ b/edml/tests/controllers/scheduler/smart_test.py
@@ -39,7 +39,10 @@ class SmartServerDeviceSelectionTest(unittest.TestCase):
                 "d0": (2, 1),
                 "d1": (4, 1),
             },
-            {"client": 1000000, "server": 1000000},
+            {
+                "client": {"FW": 1000000, "BW": 2000000},
+                "server": {"FW": 1000000, "BW": 2000000},
+            },
         ]
 
     def test_select_server_device_smart_first_round(self):
diff --git a/edml/tests/core/device_test.py b/edml/tests/core/device_test.py
index 3827df9c727254bb13482e22bef6899bb2ae2c47..a0d2be2eaf7c4f2f1a8e27b3f1182e842658c87d 100644
--- a/edml/tests/core/device_test.py
+++ b/edml/tests/core/device_test.py
@@ -315,8 +315,8 @@ class RPCDeviceServicerTest(unittest.TestCase):
     def test_dataset_model_info(self):
         self.mock_device.client._train_data.dataset = [1]
         self.mock_device.client._val_data.dataset = [2]
-        self.mock_device.client._model_flops = 3
-        self.mock_device.server._model_flops = 4
+        self.mock_device.client._model_flops = {"FW": 3, "BW": 6}
+        self.mock_device.server._model_flops = {"FW": 4, "BW": 8}
         request = connection_pb2.DatasetModelInfoRequest()
 
         response, metadata, code, details = self.make_call(
@@ -326,8 +326,10 @@ class RPCDeviceServicerTest(unittest.TestCase):
         self.assertEqual(code, StatusCode.OK)
         self.assertEqual(response.train_samples, 1)
         self.assertEqual(response.validation_samples, 1)
-        self.assertEqual(response.client_model_flops, 3)
-        self.assertEqual(response.server_model_flops, 4)
+        self.assertEqual(response.client_fw_flops, 3)
+        self.assertEqual(response.server_fw_flops, 4)
+        self.assertEqual(response.client_bw_flops, 6)
+        self.assertEqual(response.server_bw_flops, 8)
 
 
 class RPCDeviceServicerBatteryEmptyTest(unittest.TestCase):
@@ -809,14 +811,16 @@ class RequestDispatcherTest(unittest.TestCase):
             connection_pb2.DatasetModelInfoResponse(
                 train_samples=42,
                 validation_samples=21,
-                client_model_flops=42,
-                server_model_flops=21,
+                client_fw_flops=1,
+                server_fw_flops=2,
+                client_bw_flops=3,
+                server_bw_flops=4,
             )
         )
 
         response = self.dispatcher.get_dataset_model_info_on("1")
 
-        self.assertEqual(response, (42, 21, 42, 21))
+        self.assertEqual(response, (42, 21, 1, 2, 3, 4))
         self.mock_stub.GetDatasetModelInfo.assert_called_once_with(
             connection_pb2.DatasetModelInfoRequest()
         )
diff --git a/edml/tests/core/server_test.py b/edml/tests/core/server_test.py
index 5ed6d1c1cad0b4070e8b213f1f6520b87919bc1f..c855d82bdb82b26356b91c04960c4c5f8dafd164 100644
--- a/edml/tests/core/server_test.py
+++ b/edml/tests/core/server_test.py
@@ -13,6 +13,7 @@ from edml.core.battery import Battery
 from edml.core.client import DeviceClient
 from edml.core.device import Device
 from edml.core.server import DeviceServer
+from edml.helpers.logging import SimpleLogger
 
 
 class ClientModel(nn.Module):
@@ -167,6 +168,7 @@ class PSLTest(unittest.TestCase):
 
         node_device = Mock(Device)
         node_device.battery = Mock(Battery)
+        node_device.logger = Mock(SimpleLogger)
         node_device.train_batch_on_client_only_on.side_effect = get_client_side_effect(
             "train_single_batch"
         )
diff --git a/edml/tests/helpers/flops_test.py b/edml/tests/helpers/flops_test.py
index b19b2b2e590e1f0482ae5c3bc2a7b45b6325ccda..cb966d4864c1b11b16f75e1f1ed94a504db98a1f 100644
--- a/edml/tests/helpers/flops_test.py
+++ b/edml/tests/helpers/flops_test.py
@@ -1,10 +1,15 @@
+import os
 import unittest
 
 import torch
 import torch.nn as nn
+from omegaconf import OmegaConf
 
 from edml.helpers.flops import estimate_model_flops
 from edml.models.mnist_models import ClientNet, ServerNet
+from edml.tests.models.model_loading_helpers import (
+    _get_model_from_model_provider_config,
+)
 
 
 class FullTestModel(nn.Module):
@@ -77,3 +82,32 @@ class FlopsTest(unittest.TestCase):
 
         self.assertEqual(client_flops, {"BW": 10811520, "FW": 5405760})
         self.assertEqual(server_flops, {"BW": 22431600, "FW": 11215800})
+
+    def test_autoencoder_does_not_affect_backward_flops(self):
+        os.chdir(os.path.join(os.path.dirname(__file__), "../../../"))
+        client, server = _get_model_from_model_provider_config(
+            OmegaConf.create({}), "resnet20"
+        )
+        client_with_encoder, server_with_decoder = (
+            _get_model_from_model_provider_config(
+                OmegaConf.create({}), "resnet20-with-autoencoder"
+            )
+        )
+        input = torch.randn(1, 3, 32, 32)
+
+        ae_smashed_data = client_with_encoder(input)
+        smashed_data = client(input)
+
+        client_flops = estimate_model_flops(client, input)
+        server_flops = estimate_model_flops(server, smashed_data)
+
+        client_with_encoder_flops = estimate_model_flops(client_with_encoder, input)
+        server_with_decoder_flops = estimate_model_flops(
+            server_with_decoder, ae_smashed_data
+        )
+
+        self.assertEqual(client_flops["BW"], client_with_encoder_flops["BW"])
+        self.assertEqual(server_flops["BW"], server_with_decoder_flops["BW"])
+
+        self.assertGreater(client_with_encoder_flops["FW"], client_flops["FW"])
+        self.assertGreater(server_with_decoder_flops["FW"], server_flops["FW"])
diff --git a/results/result_generation.py b/results/result_generation.py
index f7db5c83721b4f51fb6cd08b7f14e6a277a4387b..f22bca3d8558982250e4fa8d3376ee41b70d7539 100644
--- a/results/result_generation.py
+++ b/results/result_generation.py
@@ -232,7 +232,13 @@ def load_dataframes(project_name, base_dir="./dataframes"):
     return history_groups
 
 
-def get_total_flops(groups, total_model_flops, client_model_flops, batch_size=64):
+def get_total_flops(
+    groups,
+    total_model_flops,
+    client_model_flops,
+    strategy_autoencoder_mapping,
+    batch_size=64,
+):
     """
     Returns the total number of FLOPs for each group.
     Args:
@@ -244,6 +250,13 @@ def get_total_flops(groups, total_model_flops, client_model_flops, batch_size=64
     """
     flops_per_group = {"strategy": [], "flops": []}
     for (strategy, job), group in groups.items():
+        # determine model FLOPs depending on whether an autoencoder was used
+        if strategy_autoencoder_mapping[strategy]:  # AE
+            total_model_fw_flops = total_model_flops["ae"]
+        else:  # no AE
+            total_model_fw_flops = total_model_flops["plain"]
+        total_model_bw_flops = 2 * total_model_flops["plain"]
+        client_model_bw_flops = 2 * client_model_flops["plain"]
         if job == "train":
             flops = 0
             num_runs = 1  # avoid division by 0
@@ -256,12 +269,12 @@ def get_total_flops(groups, total_model_flops, client_model_flops, batch_size=64
                     for run_df in runs:
                         for col_name in run_df.columns:
                             if col_name == "train_accuracy.num_samples":
-                                flops += (
-                                    run_df[col_name].sum() * total_model_flops * 3
+                                flops += run_df[col_name].sum() * (
+                                    total_model_fw_flops + total_model_bw_flops
                                 )  # 1x forward + 2x backward
                             if col_name == "val_accuracy.num_samples":
                                 flops += (
-                                    run_df[col_name].sum() * total_model_flops
+                                    run_df[col_name].sum() * total_model_fw_flops
                                 )  # 1x forward
                             if col_name == "adaptive_learning_threshold_applied":
                                 # deduce client model flops twice as client backprop is avoided
@@ -287,13 +300,12 @@ def get_total_flops(groups, total_model_flops, client_model_flops, batch_size=64
                                     )
                                     flops -= (
                                         len(run_df[col_name].dropna())
-                                        * client_model_flops
-                                        * 2
+                                        * client_model_bw_flops
                                         * avg_samples_per_batch
                                     )
                                 else:  # numbers of samples skipped are logged -> sum up
                                     flops -= (
-                                        run_df[col_name].sum() * client_model_flops * 2
+                                        run_df[col_name].sum() * client_model_bw_flops
                                     )
             flops = flops / num_runs
             flops_per_group["strategy"].append(STRATEGY_MAPPING[strategy])
@@ -1244,6 +1256,7 @@ def generate_metric_files(
     project_name,
     total_model_flops,
     client_model_flops,
+    strategy_autoencoder_mapping,
     base_path="./metrics",
     batch_size=64,
 ):
@@ -1267,7 +1280,11 @@ def generate_metric_files(
     ).set_index("strategy")
     total_flops = pd.DataFrame.from_dict(
         get_total_flops(
-            history_groups, total_model_flops, client_model_flops, batch_size
+            history_groups,
+            total_model_flops,
+            client_model_flops,
+            strategy_autoencoder_mapping,
+            batch_size,
         )
     ).set_index("strategy")
     df = pd.concat([test_acc, comm_overhead, total_flops], axis=1)