From c50b171a8a8a08556d30e0a814dee4eaefd92fc0 Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Thu, 22 Aug 2024 16:07:16 +0200
Subject: [PATCH] Adjusted FLOP estimation to account for the different parts
 of the model used in the forward and backward passes

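The FLOP estimate is no longer a single int: estimate_model_flops() now
returns {"FW": ..., "BW": ...} and the battery simulation charges the
forward and backward pass separately. A minimal sketch of the new call
pattern (the nn.Linear model, sample shape and batch size are illustrative
placeholders, not part of this patch):

    import torch
    from torch import nn

    from edml.helpers.flops import estimate_model_flops

    # Stand-in for a client or server model part.
    model = nn.Linear(10, 10)
    flops = estimate_model_flops(model, torch.randn(1, 10))

    batch_size = 32
    fw_cost = flops["FW"] * batch_size  # energy charged for the forward pass
    bw_cost = flops["BW"] * batch_size  # energy charged for the backward pass
    # Without a frozen autoencoder in between, flops["BW"] == 2 * flops["FW"].
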
---
 edml/core/client.py              | 23 ++++++++++----------
 edml/core/server.py              | 26 ++++++++++++++---------
 edml/helpers/flops.py            | 36 ++++++++++++++++++++++++++------
 edml/tests/helpers/flops_test.py | 16 +++++++++-----
 4 files changed, 68 insertions(+), 33 deletions(-)

diff --git a/edml/core/client.py b/edml/core/client.py
index 1a81dfb..e609efe 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import itertools
 import time
 from typing import Optional, Tuple, TYPE_CHECKING, Any
 
@@ -76,7 +75,7 @@ class DeviceClient:
         sample = self._train_data.dataset.__getitem__(0)[0]
         if not isinstance(sample, torch.Tensor):
             sample = torch.tensor(data=sample)
-        self._model_flops = estimate_model_flops(
+        self._model_flops: dict[str, int] = estimate_model_flops(
             self._model, sample.to(self._device).unsqueeze(0)
         )
         self._cfg = cfg
@@ -169,10 +168,10 @@ class DeviceClient:
         # Safety check to ensure that we train same-sized batches only.
         batch_data, batch_labels = next(self._batchable_data_loader)
 
-        # Updates the battery capacity by simulating the required energy consumption for conducting the training step.
-        self.node_device.battery.update_flops(self._model_flops * len(batch_data))
+        # Updates the battery capacity by simulating the required energy consumption for conducting the forward pass.
+        self.node_device.battery.update_flops(self._model_flops["FW"] * len(batch_data))
 
-        # We train the model using the single batch and return the activations and labels. These get send over to the
+        # We train the model using the single batch and return the activations and labels. These get sent over to the
         # server to be then further processed
 
         with LatencySimulator(latency_factor=self.latency_factor):
@@ -205,9 +204,7 @@ class DeviceClient:
 
         start_time_2 = time.time()
 
-        self.node_device.battery.update_flops(
-            self._model_flops * len(batch_data) * 2
-        )  # 2x for backward pass
+        self.node_device.battery.update_flops(self._model_flops["BW"] * len(batch_data))
         gradients = gradients.to(self._device)
         if has_autoencoder(self._model):
             self._model.trainable_layers_output.backward(gradients)
@@ -280,7 +277,9 @@ class DeviceClient:
         self._model.train()
         diagnostic_metric_container = DiagnosticMetricResultContainer()
         for idx, (batch_data, batch_labels) in enumerate(self._train_data):
-            self.node_device.battery.update_flops(self._model_flops * len(batch_data))
+            self.node_device.battery.update_flops(
+                self._model_flops["FW"] * len(batch_data)
+            )
 
             with LatencySimulator(latency_factor=self.latency_factor):
                 batch_data = batch_data.to(self._device)
@@ -304,8 +303,8 @@ class DeviceClient:
                 server_grad, _server_loss, diagnostic_metrics = train_batch_response
                 diagnostic_metric_container.merge(diagnostic_metrics)
                 self.node_device.battery.update_flops(
-                    self._model_flops * len(batch_data) * 2
-                )  # 2x for backward pass
+                    self._model_flops["BW"] * len(batch_data)
+                )
                 server_grad = server_grad.to(self._device)
                 if has_autoencoder(self._model):
                     self._model.trainable_layers_output.backward(server_grad)
@@ -362,7 +361,7 @@ class DeviceClient:
             for b, (batch_data, batch_labels) in enumerate(dataloader):
                 with LatencySimulator(latency_factor=self.latency_factor):
                     self.node_device.battery.update_flops(
-                        self._model_flops * len(batch_data)
+                        self._model_flops["FW"] * len(batch_data)
                     )
                     batch_data = batch_data.to(self._device)
                     batch_labels = batch_labels.to(self._device)
diff --git a/edml/core/server.py b/edml/core/server.py
index d13893c..7cd6e10 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -43,7 +43,7 @@ class DeviceServer:
         self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
             cfg, self._model.get_optimizer_params()
         )
-        self._model_flops = 0  # determine later
+        self._model_flops: dict[str, int] = {"FW": 0, "BW": 0}  # determine later
         self._metrics = create_metrics(
             cfg.experiment.metrics, cfg.dataset.num_classes, cfg.dataset.average_setting
         )
@@ -140,17 +140,21 @@ class DeviceServer:
         Returns the gradients of the model's parameters."""
         smashed_data, labels = smashed_data.to(self._device), labels.to(self._device)
 
-        self._set_model_flops(smashed_data)
+        self._set_model_flops(smashed_data[0])
 
         self._optimizer.zero_grad()
 
-        self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
+        self.node_device.battery.update_flops(
+            self._model_flops["FW"] * len(smashed_data)
+        )
         smashed_data = Variable(smashed_data, requires_grad=True)
         output_train = self._model(smashed_data)
 
         loss_train = self._loss_fn(output_train, labels)
 
-        self.node_device.battery.update_flops(self._model_flops * len(smashed_data) * 2)
+        self.node_device.battery.update_flops(
+            self._model_flops["BW"] * len(smashed_data)
+        )
         loss_train.backward()
         self._optimizer.step()
 
@@ -164,11 +168,11 @@ class DeviceServer:
             loss_train.item(),
         )  # for the AE, the gradients of the server model before the decoder should be returned here
 
-    def _set_model_flops(self, smashed_data):
+    def _set_model_flops(self, sample):
         """Helper to determine the model flops when smashed data are available for the first time."""
-        if self._model_flops == 0:
-            self._model_flops = estimate_model_flops(self._model, smashed_data) / len(
-                smashed_data
+        if self._model_flops["FW"] == 0 or self._model_flops["BW"] == 0:
+            self._model_flops = estimate_model_flops(
+                self._model, sample.to(self._device).unsqueeze(0)
             )
 
     @simulate_latency_decorator(latency_factor_attr="latency_factor")
@@ -210,8 +214,10 @@ class DeviceServer:
         """Evaluates the model on the given batch of data and labels"""
         with torch.no_grad():
             smashed_data = smashed_data.to(self._device)
-            self._set_model_flops(smashed_data)
-            self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
+            self._set_model_flops(smashed_data[0])
+            self.node_device.battery.update_flops(
+                self._model_flops["FW"] * len(smashed_data)
+            )
             pred = self._model(smashed_data)
         self._metrics.metrics_on_batch(pred.cpu(), labels.cpu().int())
 
diff --git a/edml/helpers/flops.py b/edml/helpers/flops.py
index 4388c97..b3fc822 100644
--- a/edml/helpers/flops.py
+++ b/edml/helpers/flops.py
@@ -3,19 +3,43 @@ from typing import Union, Tuple
 from fvcore.nn import FlopCountAnalysis
 from torch import Tensor, nn
 
+from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
+from edml.models.provider.base import has_autoencoder
+
 
 def estimate_model_flops(
     model: nn.Module, sample: Union[Tensor, Tuple[Tensor, ...]]
-) -> int:
+) -> dict[str, int]:
     """
-    Estimates the FLOPs of one forward pass of the model using the sample data provided.
-
+    Estimates the FLOPs of one forward pass and one backward pass of the model using the sample data provided.
+    If the forward and backward pass cover the same parameters, the number of FLOPs of one backward pass is assumed
+    to be twice the number of FLOPs of the forward pass.
+    If the model contains non-trainable parts, such as a frozen autoencoder, the backward FLOPs are assumed to be
+    twice the forward FLOPs of the trainable part only, while the forward FLOPs are estimated over the full model.
+
     Args:
         model (nn.Module): the neural network model to calculate the FLOPs for.
         sample: The data used to calculate the FLOPs.
 
     Returns:
-        int: the number of FLOPs.
-
+        dict[str, int]: the estimated forward and backward FLOPs as {"FW": <forward FLOPs>, "BW": <backward FLOPs>}.
     """
-    return FlopCountAnalysis(model, sample).total()
+    fw_flops = FlopCountAnalysis(model, sample).total()
+    if has_autoencoder(model):
+        bw_flops = 0
+        if isinstance(
+            model, ClientWithAutoencoder
+        ):  # FW FLOPs of the trainable client part (up to the encoder), times two
+            bw_flops = FlopCountAnalysis(model.model, sample).total() * 2
+        elif isinstance(
+            model, ServerWithAutoencoder
+        ):  # FW FLOPs of the trainable server part (from the decoder output onward), times two
+            bw_flops = (
+                FlopCountAnalysis(model.model, model.autoencoder(sample)).total() * 2
+            )
+        else:
+            raise NotImplementedError()
+    else:
+        bw_flops = 2 * fw_flops
+
+    return {"FW": fw_flops, "BW": bw_flops}
diff --git a/edml/tests/helpers/flops_test.py b/edml/tests/helpers/flops_test.py
index 4b7f0d6..b19b2b2 100644
--- a/edml/tests/helpers/flops_test.py
+++ b/edml/tests/helpers/flops_test.py
@@ -55,9 +55,15 @@ class FlopsTest(unittest.TestCase):
         client_flops = estimate_model_flops(client_model, inputs)
         server_flops = estimate_model_flops(server_model, server_inputs)
 
-        self.assertEqual(client_flops, 30000)
-        self.assertEqual(server_flops, 101000)
-        self.assertEqual(full_flops, client_flops + server_flops)
+        self.assertEqual(client_flops, {"BW": 60000, "FW": 30000})
+        self.assertEqual(server_flops, {"BW": 202000, "FW": 101000})
+        self.assertEqual(
+            full_flops,
+            {
+                "BW": client_flops["BW"] + server_flops["BW"],
+                "FW": client_flops["FW"] + server_flops["FW"],
+            },
+        )
 
     def test_mnist_split_count(self):
         client_model = ClientNet()
@@ -69,5 +75,5 @@ class FlopsTest(unittest.TestCase):
         client_flops = estimate_model_flops(client_model, inputs)
         server_flops = estimate_model_flops(server_model, server_inputs)
 
-        self.assertEqual(client_flops, 5405760)
-        self.assertEqual(server_flops, 11215800)
+        self.assertEqual(client_flops, {"BW": 10811520, "FW": 5405760})
+        self.assertEqual(server_flops, {"BW": 22431600, "FW": 11215800})
-- 
GitLab