Commit c50b171a authored by Tim Tobias Bauerle
Adjusted FLOP estimation to account for different parts of the model used in forward and backward pass
parent a7256d58
1 merge request: !24 Autoencoder, ATM and global optimizer
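In short, the commit replaces the single scalar FLOPs estimate (with the backward pass hard-coded as 2x forward) by a `{"FW": ..., "BW": ...}` dict, so that non-trainable parts such as an autoencoder bridge are charged forward FLOPs only. A condensed sketch of the accounting change; the `Battery` stub, capacity, and batch size are illustrative assumptions (the real classes live in `edml`), while the FLOPs values come from the MNIST test further down:

```python
# Sketch of the per-batch energy accounting before and after this commit.
# `Battery` is a hypothetical stand-in for edml's battery simulation.
class Battery:
    def __init__(self, capacity: float):
        self.capacity = capacity

    def update_flops(self, flops: float) -> None:
        self.capacity -= flops  # simplified: 1 FLOP drains 1 unit of charge


battery = Battery(capacity=1e12)
batch_size = 32

# Before: one scalar estimate; backward pass hard-coded as 2x forward.
model_flops_old = 5405760
battery.update_flops(model_flops_old * batch_size)      # forward pass
battery.update_flops(model_flops_old * batch_size * 2)  # "2x for backward pass"

# After: separate FW/BW estimates, so models with non-trainable parts
# (e.g. an autoencoder) no longer pay backward FLOPs for frozen layers.
model_flops = {"FW": 5405760, "BW": 10811520}
battery.update_flops(model_flops["FW"] * batch_size)  # forward pass
battery.update_flops(model_flops["BW"] * batch_size)  # backward pass
```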
from __future__ import annotations
import itertools
import time
from typing import Optional, Tuple, TYPE_CHECKING, Any
@@ -76,7 +75,7 @@ class DeviceClient:
sample = self._train_data.dataset.__getitem__(0)[0]
if not isinstance(sample, torch.Tensor):
sample = torch.tensor(data=sample)
- self._model_flops = estimate_model_flops(
+ self._model_flops: dict[str, int] = estimate_model_flops(
self._model, sample.to(self._device).unsqueeze(0)
)
self._cfg = cfg
@@ -169,10 +168,10 @@ class DeviceClient:
# Safety check to ensure that we train same-sized batches only.
batch_data, batch_labels = next(self._batchable_data_loader)
- # Updates the battery capacity by simulating the required energy consumption for conducting the training step.
- self.node_device.battery.update_flops(self._model_flops * len(batch_data))
+ # Updates the battery capacity by simulating the required energy consumption for conducting the forward pass.
+ self.node_device.battery.update_flops(self._model_flops["FW"] * len(batch_data))
- # We train the model using the single batch and return the activations and labels. These get send over to the
+ # We train the model using the single batch and return the activations and labels. These get sent over to the
# server to be then further processed
with LatencySimulator(latency_factor=self.latency_factor):
@@ -205,9 +204,7 @@ class DeviceClient:
start_time_2 = time.time()
- self.node_device.battery.update_flops(
- self._model_flops * len(batch_data) * 2
- ) # 2x for backward pass
+ self.node_device.battery.update_flops(self._model_flops["BW"] * len(batch_data))
gradients = gradients.to(self._device)
if has_autoencoder(self._model):
self._model.trainable_layers_output.backward(gradients)
@@ -280,7 +277,9 @@ class DeviceClient:
self._model.train()
diagnostic_metric_container = DiagnosticMetricResultContainer()
for idx, (batch_data, batch_labels) in enumerate(self._train_data):
- self.node_device.battery.update_flops(self._model_flops * len(batch_data))
+ self.node_device.battery.update_flops(
+ self._model_flops["FW"] * len(batch_data)
+ )
with LatencySimulator(latency_factor=self.latency_factor):
batch_data = batch_data.to(self._device)
@@ -304,8 +303,8 @@ class DeviceClient:
server_grad, _server_loss, diagnostic_metrics = train_batch_response
diagnostic_metric_container.merge(diagnostic_metrics)
self.node_device.battery.update_flops(
- self._model_flops * len(batch_data) * 2
- ) # 2x for backward pass
+ self._model_flops["BW"] * len(batch_data)
+ )
server_grad = server_grad.to(self._device)
if has_autoencoder(self._model):
self._model.trainable_layers_output.backward(server_grad)
@@ -362,7 +361,7 @@ class DeviceClient:
for b, (batch_data, batch_labels) in enumerate(dataloader):
with LatencySimulator(latency_factor=self.latency_factor):
self.node_device.battery.update_flops(
- self._model_flops * len(batch_data)
+ self._model_flops["FW"] * len(batch_data)
)
batch_data = batch_data.to(self._device)
batch_labels = batch_labels.to(self._device)
@@ -43,7 +43,7 @@ class DeviceServer:
self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
cfg, self._model.get_optimizer_params()
)
- self._model_flops = 0 # determined later
+ self._model_flops: dict[str, int] = {"FW": 0, "BW": 0} # determined later
self._metrics = create_metrics(
cfg.experiment.metrics, cfg.dataset.num_classes, cfg.dataset.average_setting
)
@@ -140,17 +140,21 @@ class DeviceServer:
Returns the gradients of the model's parameters."""
smashed_data, labels = smashed_data.to(self._device), labels.to(self._device)
- self._set_model_flops(smashed_data)
+ self._set_model_flops(smashed_data[0])
self._optimizer.zero_grad()
- self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
+ self.node_device.battery.update_flops(
+ self._model_flops["FW"] * len(smashed_data)
+ )
smashed_data = Variable(smashed_data, requires_grad=True)
output_train = self._model(smashed_data)
loss_train = self._loss_fn(output_train, labels)
- self.node_device.battery.update_flops(self._model_flops * len(smashed_data) * 2)
+ self.node_device.battery.update_flops(
+ self._model_flops["BW"] * len(smashed_data)
+ )
loss_train.backward()
self._optimizer.step()
@@ -164,11 +168,11 @@ class DeviceServer:
loss_train.item(),
) # for the AE, the gradients of the server model before the decoder should be returned here
- def _set_model_flops(self, smashed_data):
+ def _set_model_flops(self, sample):
"""Helper to determine the model flops when smashed data are available for the first time."""
- if self._model_flops == 0:
- self._model_flops = estimate_model_flops(self._model, smashed_data) / len(
- smashed_data
+ if self._model_flops["FW"] == 0 or self._model_flops["BW"] == 0:
+ self._model_flops = estimate_model_flops(
+ self._model, sample.to(self._device).unsqueeze(0)
)
@simulate_latency_decorator(latency_factor_attr="latency_factor")
@@ -210,8 +214,10 @@ class DeviceServer:
"""Evaluates the model on the given batch of data and labels"""
with torch.no_grad():
smashed_data = smashed_data.to(self._device)
- self._set_model_flops(smashed_data)
- self.node_device.battery.update_flops(self._model_flops * len(smashed_data))
+ self._set_model_flops(smashed_data[0])
+ self.node_device.battery.update_flops(
+ self._model_flops["FW"] * len(smashed_data)
+ )
pred = self._model(smashed_data)
self._metrics.metrics_on_batch(pred.cpu(), labels.cpu().int())
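A side note on the `_set_model_flops` change above: the server now estimates from a single sample (`smashed_data[0]`, with the batch dimension re-added via `unsqueeze(0)`) instead of running the estimator on a whole batch and dividing by its length, which previously yielded a float. A stand-alone sketch of that lazy, per-sample initialization; the stub class is an assumption, and `estimate_model_flops` is the helper defined in the next file of this diff:

```python
# Lazy, per-sample FLOPs initialization (condensed from the diff above;
# the surrounding DeviceServer class is stubbed out for illustration).
import torch
from torch import nn


class ServerStub:
    def __init__(self, model: nn.Module, device: str = "cpu"):
        self._model = model
        self._device = device
        self._model_flops: dict[str, int] = {"FW": 0, "BW": 0}  # determined later

    def _set_model_flops(self, sample: torch.Tensor) -> None:
        """Estimate once, from a single sample with the batch dim re-added."""
        if self._model_flops["FW"] == 0 or self._model_flops["BW"] == 0:
            self._model_flops = estimate_model_flops(
                self._model, sample.to(self._device).unsqueeze(0)
            )
```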
@@ -3,19 +3,43 @@ from typing import Union, Tuple
from fvcore.nn import FlopCountAnalysis
from torch import Tensor, nn
+ from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
+ from edml.models.provider.base import has_autoencoder
def estimate_model_flops(
model: nn.Module, sample: Union[Tensor, Tuple[Tensor, ...]]
- ) -> int:
+ ) -> dict[str, int]:
"""
- Estimates the FLOPs of one forward pass of the model using the sample data provided.
+ Estimates the FLOPs of one forward pass and one backward pass of the model using the sample data provided.
+ If the forward and backward pass affect the same parameters, the number of FLOPs of one backward pass is assumed
+ to be two times the number of FLOPs of the forward pass.
+ In case of non-trainable parts, such as an autoencoder in between, the number of backward FLOPs is assumed to be
+ twice the forward FLOPs of the trainable part, while the actual number of forward FLOPs is estimated using the
+ full model.
Args:
model (nn.Module): the neural network model to calculate the FLOPs for.
sample: The data used to calculate the FLOPs.
Returns:
- int: the number of FLOPs.
+ dict[str, int]: {"FW": #ForwardFLOPs, "BW": #BackwardFLOPs}, the numbers of estimated forward and backward FLOPs.
"""
- return FlopCountAnalysis(model, sample).total()
+ fw_flops = FlopCountAnalysis(model, sample).total()
+ if has_autoencoder(model):
+ bw_flops = 0
+ if isinstance(
+ model, ClientWithAutoencoder
+ ): # run FW pass until the encoder and multiply by 2
+ bw_flops = FlopCountAnalysis(model.model, sample).total() * 2
+ elif isinstance(
+ model, ServerWithAutoencoder
+ ): # run FW pass from the decoder output and multiply by 2
+ bw_flops = (
+ FlopCountAnalysis(model.model, model.autoencoder(sample)).total() * 2
+ )
+ else:
+ raise NotImplementedError()
+ else:
+ bw_flops = 2 * fw_flops
+ return {"FW": fw_flops, "BW": bw_flops}
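For illustration (not part of the commit), a usage sketch for a plain model without an autoencoder, where the `bw_flops = 2 * fw_flops` fallback applies; the toy `nn.Linear` model and input shape are assumptions, and `has_autoencoder` is assumed to return False for a plain module:

```python
import torch
from torch import nn

model = nn.Linear(10, 5)     # hypothetical toy model, no autoencoder
sample = torch.randn(1, 10)  # batch dimension already added

flops = estimate_model_flops(model, sample)
assert flops["BW"] == 2 * flops["FW"]  # fallback branch for plain models
```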
@@ -55,9 +55,15 @@ class FlopsTest(unittest.TestCase):
client_flops = estimate_model_flops(client_model, inputs)
server_flops = estimate_model_flops(server_model, server_inputs)
- self.assertEqual(client_flops, 30000)
- self.assertEqual(server_flops, 101000)
- self.assertEqual(full_flops, client_flops + server_flops)
+ self.assertEqual(client_flops, {"BW": 60000, "FW": 30000})
+ self.assertEqual(server_flops, {"BW": 202000, "FW": 101000})
+ self.assertEqual(
+ full_flops,
+ {
+ "BW": client_flops["BW"] + server_flops["BW"],
+ "FW": client_flops["FW"] + server_flops["FW"],
+ },
+ )
def test_mnist_split_count(self):
client_model = ClientNet()
@@ -69,5 +75,5 @@ class FlopsTest(unittest.TestCase):
client_flops = estimate_model_flops(client_model, inputs)
server_flops = estimate_model_flops(server_model, server_inputs)
- self.assertEqual(client_flops, 5405760)
- self.assertEqual(server_flops, 11215800)
+ self.assertEqual(client_flops, {"BW": 10811520, "FW": 5405760})
+ self.assertEqual(server_flops, {"BW": 22431600, "FW": 11215800})
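The updated expectations are internally consistent with the 2x fallback, since neither test model contains an autoencoder:

```python
# Both test models take the plain branch, so BW is exactly twice FW:
assert 60000 == 2 * 30000 and 202000 == 2 * 101000            # small split model
assert 10811520 == 2 * 5405760 and 22431600 == 2 * 11215800   # MNIST split
```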