From c50b171a8a8a08556d30e0a814dee4eaefd92fc0 Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Thu, 22 Aug 2024 16:07:16 +0200
Subject: [PATCH] Adjusted FLOP estimation to account for different parts of
 the model used in forward and backward pass

---
 edml/core/client.py              | 23 ++++++++++----------
 edml/core/server.py              | 26 ++++++++++++++---------
 edml/helpers/flops.py            | 36 ++++++++++++++++++++++++++------
 edml/tests/helpers/flops_test.py | 16 +++++++++-----
 4 files changed, 68 insertions(+), 33 deletions(-)

diff --git a/edml/core/client.py b/edml/core/client.py
index 1a81dfb..e609efe 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import itertools
 import time
 from typing import Optional, Tuple, TYPE_CHECKING, Any
 
@@ -76,7 +75,7 @@ class DeviceClient:
         sample = self._train_data.dataset.__getitem__(0)[0]
         if not isinstance(sample, torch.Tensor):
             sample = torch.tensor(data=sample)
-        self._model_flops = estimate_model_flops(
+        self._model_flops: dict[str, int] = estimate_model_flops(
             self._model, sample.to(self._device).unsqueeze(0)
         )
         self._cfg = cfg
@@ -169,10 +168,10 @@ class DeviceClient:
         # Safety check to ensure that we train same-sized batches only.
         batch_data, batch_labels = next(self._batchable_data_loader)
 
-        # Updates the battery capacity by simulating the required energy consumption for conducting the training step.
-        self.node_device.battery.update_flops(self._model_flops * len(batch_data))
+        # Updates the battery capacity by simulating the required energy consumption for conducting the forward pass.
+        self.node_device.battery.update_flops(self._model_flops["FW"] * len(batch_data))
 
-        # We train the model using the single batch and return the activations and labels. These get send over to the
+        # We train the model using the single batch and return the activations and labels. These get sent over to the
         # server to be then further processed
         with LatencySimulator(latency_factor=self.latency_factor):
@@ -205,9 +204,7 @@ class DeviceClient:
 
         start_time_2 = time.time()
 
-        self.node_device.battery.update_flops(
-            self._model_flops * len(batch_data) * 2
-        )  # 2x for backward pass
+        self.node_device.battery.update_flops(self._model_flops["BW"] * len(batch_data))
         gradients = gradients.to(self._device)
         if has_autoencoder(self._model):
             self._model.trainable_layers_output.backward(gradients)
@@ -280,7 +277,9 @@ class DeviceClient:
         self._model.train()
         diagnostic_metric_container = DiagnosticMetricResultContainer()
         for idx, (batch_data, batch_labels) in enumerate(self._train_data):
-            self.node_device.battery.update_flops(self._model_flops * len(batch_data))
+            self.node_device.battery.update_flops(
+                self._model_flops["FW"] * len(batch_data)
+            )
 
             with LatencySimulator(latency_factor=self.latency_factor):
                 batch_data = batch_data.to(self._device)
@@ -304,8 +303,8 @@ class DeviceClient:
             server_grad, _server_loss, diagnostic_metrics = train_batch_response
             diagnostic_metric_container.merge(diagnostic_metrics)
             self.node_device.battery.update_flops(
-                self._model_flops * len(batch_data) * 2
-            )  # 2x for backward pass
+                self._model_flops["BW"] * len(batch_data)
+            )
             server_grad = server_grad.to(self._device)
             if has_autoencoder(self._model):
                 self._model.trainable_layers_output.backward(server_grad)
@@ -362,7 +361,7 @@ class DeviceClient:
         for b, (batch_data, batch_labels) in enumerate(dataloader):
             with LatencySimulator(latency_factor=self.latency_factor):
                 self.node_device.battery.update_flops(
-                    self._model_flops * len(batch_data)
+                    self._model_flops["FW"] * len(batch_data)
                 )
                 batch_data = batch_data.to(self._device)
                 batch_labels = batch_labels.to(self._device)
diff --git a/edml/core/server.py b/edml/core/server.py
index d13893c..7cd6e10 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -43,7 +43,7 @@ class DeviceServer:
         self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
             cfg, self._model.get_optimizer_params()
         )
-        self._model_flops = 0  # determine later
+        self._model_flops: dict[str, int] = {"FW": 0, "BW": 0}  # determine later
         self._metrics = create_metrics(
             cfg.experiment.metrics, cfg.dataset.num_classes, cfg.dataset.average_setting
         )
@@ -140,17 +140,21 @@ class DeviceServer:
 
         Returns the gradients of the model's parameters."""
         smashed_data, labels = smashed_data.to(self._device), labels.to(self._device)
 
-        self._set_model_flops(smashed_data)
+        self._set_model_flops(smashed_data[0])
 
         self._optimizer.zero_grad()
 
-        self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
+        self.node_device.battery.update_flops(
+            self._model_flops["FW"] * len(smashed_data)
+        )
 
         smashed_data = Variable(smashed_data, requires_grad=True)
 
         output_train = self._model(smashed_data)
         loss_train = self._loss_fn(output_train, labels)
-        self.node_device.battery.update_flops(self._model_flops * len(smashed_data) * 2)
+        self.node_device.battery.update_flops(
+            self._model_flops["BW"] * len(smashed_data)
+        )
         loss_train.backward()
         self._optimizer.step()
@@ -164,11 +168,11 @@ class DeviceServer:
                 loss_train.item(),
             )  # for the AE, the gradients of the server model before the decoder should be returned here
 
-    def _set_model_flops(self, smashed_data):
+    def _set_model_flops(self, sample):
         """Helper to determine the model flops when smashed data are available for the first time."""
-        if self._model_flops == 0:
-            self._model_flops = estimate_model_flops(self._model, smashed_data) / len(
-                smashed_data
+        if self._model_flops["FW"] == 0 or self._model_flops["BW"] == 0:
+            self._model_flops = estimate_model_flops(
+                self._model, sample.to(self._device).unsqueeze(0)
             )
 
     @simulate_latency_decorator(latency_factor_attr="latency_factor")
@@ -210,8 +214,10 @@ class DeviceServer:
         """Evaluates the model on the given batch of data and labels"""
         with torch.no_grad():
             smashed_data = smashed_data.to(self._device)
-            self._set_model_flops(smashed_data)
-            self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
+            self._set_model_flops(smashed_data[0])
+            self.node_device.battery.update_flops(
+                self._model_flops["FW"] * len(smashed_data)
+            )
 
             pred = self._model(smashed_data)
             self._metrics.metrics_on_batch(pred.cpu(), labels.cpu().int())
diff --git a/edml/helpers/flops.py b/edml/helpers/flops.py
index 4388c97..b3fc822 100644
--- a/edml/helpers/flops.py
+++ b/edml/helpers/flops.py
@@ -3,19 +3,43 @@ from typing import Union, Tuple
 from fvcore.nn import FlopCountAnalysis
 from torch import Tensor, nn
 
+from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
+from edml.models.provider.base import has_autoencoder
+
 
 def estimate_model_flops(
     model: nn.Module, sample: Union[Tensor, Tuple[Tensor, ...]]
-) -> int:
+) -> dict[str, int]:
     """
-    Estimates the FLOPs of one forward pass of the model using the sample data provided.
-
+    Estimates the FLOPs of one forward pass and one backward pass of the model using the sample data provided.
+    If the forward and the backward pass affect the same parameters, the number of FLOPs of one backward pass is
+    assumed to be two times the number of FLOPs of the forward pass.
+    If a non-trainable part such as an autoencoder sits in between, the number of backward FLOPs is assumed to be
+    twice the number of forward FLOPs of the trainable part, while the actual number of forward FLOPs is estimated
+    using the full model.
     Args:
         model (nn.Module): the neural network model to calculate the FLOPs for.
        sample: The data used to calculate the FLOPs.
 
     Returns:
-        int: the number of FLOPs.
-
+        dict[str, int]: {"FW": #ForwardFLOPs, "BW": #BackwardFLOPs} the estimated number of forward and backward FLOPs.
""" - return FlopCountAnalysis(model, sample).total() + fw_flops = FlopCountAnalysis(model, sample).total() + if has_autoencoder(model): + bw_flops = 0 + if isinstance( + model, ClientWithAutoencoder + ): # run FW pass until Encoder and multiply by 2 + bw_flops = FlopCountAnalysis(model.model, sample).total() * 2 + elif isinstance( + model, ServerWithAutoencoder + ): # run FW pass from decoder output and multiply by 2 + bw_flops = ( + FlopCountAnalysis(model.model, model.autoencoder(sample)).total() * 2 + ) + else: + raise NotImplementedError() + else: + bw_flops = 2 * fw_flops + + return {"FW": fw_flops, "BW": bw_flops} diff --git a/edml/tests/helpers/flops_test.py b/edml/tests/helpers/flops_test.py index 4b7f0d6..b19b2b2 100644 --- a/edml/tests/helpers/flops_test.py +++ b/edml/tests/helpers/flops_test.py @@ -55,9 +55,15 @@ class FlopsTest(unittest.TestCase): client_flops = estimate_model_flops(client_model, inputs) server_flops = estimate_model_flops(server_model, server_inputs) - self.assertEqual(client_flops, 30000) - self.assertEqual(server_flops, 101000) - self.assertEqual(full_flops, client_flops + server_flops) + self.assertEqual(client_flops, {"BW": 60000, "FW": 30000}) + self.assertEqual(server_flops, {"BW": 202000, "FW": 101000}) + self.assertEqual( + full_flops, + { + "BW": client_flops["BW"] + server_flops["BW"], + "FW": client_flops["FW"] + server_flops["FW"], + }, + ) def test_mnist_split_count(self): client_model = ClientNet() @@ -69,5 +75,5 @@ class FlopsTest(unittest.TestCase): client_flops = estimate_model_flops(client_model, inputs) server_flops = estimate_model_flops(server_model, server_inputs) - self.assertEqual(client_flops, 5405760) - self.assertEqual(server_flops, 11215800) + self.assertEqual(client_flops, {"BW": 10811520, "FW": 5405760}) + self.assertEqual(server_flops, {"BW": 22431600, "FW": 11215800}) -- GitLab