From 5d4b3302f2646323936bf000db038dc0d8d6ac89 Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Wed, 21 Aug 2024 16:19:42 +0200
Subject: [PATCH] Fixed autoencoder: skipping AE layers during backprop

The AE layers are frozen (requires_grad_(False)), so backprop now skips
them: the client keeps a reference to the output of its trainable layers
(trainable_layers_output), the server re-enters the graph at the decoder
output (trainable_layers_input) and returns that tensor's gradient to the
client, and both optimizers are built from get_optimizer_params(), which
excludes the AE parameters.

---
 edml/core/client.py                           | 13 ++-
 edml/core/server.py                           | 11 ++-
 edml/models/autoencoder.py                    | 19 +++-
 edml/models/provider/base.py                  | 38 +++++++-
 edml/tests/controllers/fed_controller_test.py |  1 +
 edml/tests/controllers/sample_config.yaml     |  1 +
 edml/tests/core/server_test.py                | 10 ++
 edml/tests/core/start_device_test.py          | 95 +++++++++++--------
 edml/tests/models/model_loading_helpers.py    | 15 +++
 edml/tests/models/model_provider_test.py      | 43 +++++++++
 10 files changed, 196 insertions(+), 50 deletions(-)
 create mode 100644 edml/tests/models/model_loading_helpers.py
 create mode 100644 edml/tests/models/model_provider_test.py

diff --git a/edml/core/client.py b/edml/core/client.py
index d8096c7..7dc1617 100644
--- a/edml/core/client.py
+++ b/edml/core/client.py
@@ -19,6 +19,7 @@ from edml.helpers.flops import estimate_model_flops
 from edml.helpers.load_optimizer import get_optimizer_and_scheduler
 from edml.helpers.metrics import DiagnosticMetricResultContainer, DiagnosticMetricResult
 from edml.helpers.types import StateDict
+from edml.models.provider.base import has_autoencoder
 
 if TYPE_CHECKING:
     from edml.core.device import Device
@@ -69,7 +70,7 @@ class DeviceClient:
         self._device = torch.device(get_torch_device_id(cfg))
         self._model = model.to(self._device)
         self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
-            cfg, self._model.parameters()
+            cfg, self._model.get_optimizer_params()
         )
         # get first sample from train data to estimate model flops
         sample = self._train_data.dataset.__getitem__(0)[0]
@@ -208,7 +209,10 @@ class DeviceClient:
             self._model_flops * len(batch_data) * 2
         )  # 2x for backward pass
         gradients = gradients.to(self._device)
-        smashed_data.backward(gradients)
+        if has_autoencoder(self._model):
+            self._model.trainable_layers_output.backward(gradients)
+        else:
+            smashed_data.backward(gradients)
         # self._optimizer.step()
 
         # We need to store a reference to the smashed_data to make it possible to finalize the training step.
@@ -303,7 +307,10 @@ class DeviceClient:
             self._model_flops * len(batch_data) * 2
         )  # 2x for backward pass
         server_grad = server_grad.to(self._device)
-        smashed_data.backward(server_grad)
+        if has_autoencoder(self._model):
+            self._model.trainable_layers_output.backward(server_grad)
+        else:
+            smashed_data.backward(server_grad)
         self._optimizer.step()
 
         client_train_time = (
diff --git a/edml/core/server.py b/edml/core/server.py
index 4279f40..d13893c 100644
--- a/edml/core/server.py
+++ b/edml/core/server.py
@@ -20,6 +20,7 @@ from edml.helpers.metrics import (
     DiagnosticMetricResultContainer,
 )
 from edml.helpers.types import StateDict, LossFn
+from edml.models.provider.base import has_autoencoder
 
 if TYPE_CHECKING:
     from edml.core.device import Device
@@ -40,7 +41,7 @@ class DeviceServer:
         self._device = torch.device(get_torch_device_id(cfg))
         self._model = model.to(self._device)
         self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
-            cfg, self._model.parameters()
+            cfg, self._model.get_optimizer_params()
         )
         self._model_flops = 0  # determine later
         self._metrics = create_metrics(
@@ -156,8 +157,12 @@ class DeviceServer:
         # Capturing training metrics for the current batch.
         self.node_device.log({"loss": loss_train.item()})
         self._metrics.metrics_on_batch(output_train.cpu(), labels.cpu().int())
-
-        return smashed_data.grad, loss_train.item()
+        if has_autoencoder(self._model):
+            return self._model.trainable_layers_input.grad, loss_train.item()
+        return (
+            smashed_data.grad,
+            loss_train.item(),
+        )  # with an AE, the gradients of the server model before the decoder should be returned here
 
     def _set_model_flops(self, smashed_data):
         """Helper to determine the model flops when smashed data are available for the first time."""
diff --git a/edml/models/autoencoder.py b/edml/models/autoencoder.py
index ed16f30..7814b7e 100644
--- a/edml/models/autoencoder.py
+++ b/edml/models/autoencoder.py
@@ -1,5 +1,5 @@
-import torch
 from torch import nn
+from torch.autograd import Variable
 
 
 class ClientWithAutoencoder(nn.Module):
@@ -7,10 +7,13 @@ class ClientWithAutoencoder(nn.Module):
         super().__init__()
         self.model = model
         self.autoencoder = autoencoder.requires_grad_(False)
+        self.trainable_layers_output = (
+            None  # needed to continue the backprop, skipping the AE
+        )
 
     def forward(self, x):
-        x = self.model(x)
-        return self.autoencoder(x)
+        self.trainable_layers_output = self.model(x)
+        return self.autoencoder(self.trainable_layers_output)
 
 
 class ServerWithAutoencoder(nn.Module):
@@ -18,7 +21,13 @@ class ServerWithAutoencoder(nn.Module):
         super().__init__()
         self.model = model
         self.autoencoder = autoencoder.requires_grad_(False)
+        self.trainable_layers_input = (
+            None  # needed to stop backprop before the AE and send the gradients to the client
+        )
 
     def forward(self, x):
-        x = self.autoencoder(x)
-        return self.model(x)
+        self.trainable_layers_input = self.autoencoder(x)
+        self.trainable_layers_input = Variable(
+            self.trainable_layers_input, requires_grad=True
+        )
+        return self.model(self.trainable_layers_input)
diff --git a/edml/models/provider/base.py b/edml/models/provider/base.py
index e60736b..2b28c7c 100644
--- a/edml/models/provider/base.py
+++ b/edml/models/provider/base.py
@@ -1,10 +1,44 @@
+import torch
 from torch import nn
 
 
+def has_autoencoder(model: nn.Module):
+    if hasattr(model, "model") and hasattr(model, "autoencoder"):
+        return True
+    return False
+
+
+def get_grads(model: nn.Module):
+    gradients = []
+    for param in model.parameters():
+        if param.grad is not None:
+            gradients.append(param.grad)
+        else:
+            gradients.append(torch.zeros_like(param))
+    return gradients
+
+
+def add_optimizer_params_function(model: nn.Module):
+    # exclude AE params
+    if has_autoencoder(model):
+
+        def get_optimizer_params(model: nn.Module):
+            return model.model.parameters
+
+        model.get_optimizer_params = get_optimizer_params(model)
+    else:
+
+        def get_optimizer_params(model: nn.Module):
+            return model.parameters
+
+        model.get_optimizer_params = get_optimizer_params(model)
+    return model
+
+
 class ModelProvider:
     def __init__(self, client: nn.Module, server: nn.Module):
-        self._client = client
-        self._server = server
+        self._client = add_optimizer_params_function(client)
+        self._server = add_optimizer_params_function(server)
 
     @property
     def models(self) -> tuple[nn.Module, nn.Module]:
diff --git a/edml/tests/controllers/fed_controller_test.py b/edml/tests/controllers/fed_controller_test.py
index a44090e..45c4ef9 100644
--- a/edml/tests/controllers/fed_controller_test.py
+++ b/edml/tests/controllers/fed_controller_test.py
@@ -149,6 +149,7 @@ class FedTrainRoundThreadingTest(unittest.TestCase):
                 "num_devices": 0,
                 "wandb": {"enabled": False},
                 "experiment": {"early_stopping": False},
+                "simulate_parallelism": False,
             }
         )
     )
diff --git a/edml/tests/controllers/sample_config.yaml b/edml/tests/controllers/sample_config.yaml
index 505bdae..ad8e917 100644
--- a/edml/tests/controllers/sample_config.yaml
+++ b/edml/tests/controllers/sample_config.yaml
@@ -24,3 +24,4 @@ battery:
   deduction_per_mbyte_sent: 1
   deduction_per_mbyte_received: 1
   deduction_per_mflop: 1
+simulate_parallelism: False
diff --git a/edml/tests/core/server_test.py b/edml/tests/core/server_test.py
index 1331f8a..5ed6d1c 100644
--- a/edml/tests/core/server_test.py
+++ b/edml/tests/core/server_test.py
@@ -76,6 +76,7 @@ class PSLTest(unittest.TestCase):
                     ]
                 },
                 "own_device_id": "d0",
+                "simulate_parallelism": False,
             }
         )
         # init models with fixed weights for repeatability
@@ -90,10 +91,19 @@ class PSLTest(unittest.TestCase):
         )
         server_model = ServerModel()
         server_model.load_state_dict(server_state_dict)
+        server_model.get_optimizer_params = (
+            server_model.parameters
+        )  # when using a model provider, this is created automatically
         client_model1 = ClientModel()
         client_model1.load_state_dict(client_state_dict)
+        client_model1.get_optimizer_params = (
+            client_model1.parameters
+        )  # when using a model provider, this is created automatically
         client_model2 = ClientModel()
         client_model2.load_state_dict(client_state_dict)
+        client_model2.get_optimizer_params = (
+            client_model2.parameters
+        )  # when using a model provider, this is created automatically
         self.server = DeviceServer(
             model=server_model, loss_fn=torch.nn.L1Loss(), cfg=cfg
         )
diff --git a/edml/tests/core/start_device_test.py b/edml/tests/core/start_device_test.py
index abb5d9d..dd1f409 100644
--- a/edml/tests/core/start_device_test.py
+++ b/edml/tests/core/start_device_test.py
@@ -4,11 +4,51 @@ from copy import deepcopy
 
 import torch
 from omegaconf import OmegaConf
-from torch.autograd import Variable
 
-from edml.core.start_device import _get_models
 from edml.helpers.model_splitting import Part
 from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
+from edml.tests.models.model_loading_helpers import (
+    _get_model_from_model_provider_config,
+)
+
+
+def plot_grad_flow(named_parameters):
+    """Plots the gradients flowing through different layers in the net during training.
+    Can be used for checking for possible gradient vanishing / exploding problems.
+
+    Usage: Plug this function in Trainer class after loss.backward() as
+    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow"""
+    from matplotlib import pyplot as plt
+    import numpy as np
+    from matplotlib.lines import Line2D
+
+    ave_grads = []
+    max_grads = []
+    layers = []
+    for n, p in named_parameters:
+        if (p.requires_grad) and ("bias" not in n):
+            layers.append(n)
+            ave_grads.append(p.grad.abs().mean())
+            max_grads.append(p.grad.abs().max())
+    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
+    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
+    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
+    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
+    plt.xlim(left=0, right=len(ave_grads))
+    plt.ylim(bottom=-0.001)  # , top=0.02)  # zoom in on the lower gradient regions
+    plt.xlabel("Layers")
+    plt.ylabel("average gradient")
+    plt.title("Gradient flow")
+    plt.grid(True)
+    plt.legend(
+        [
+            Line2D([0], [0], color="c", lw=4),
+            Line2D([0], [0], color="b", lw=4),
+            Line2D([0], [0], color="k", lw=4),
+        ],
+        ["max-gradient", "mean-gradient", "zero-gradient"],
+    )
+    plt.show()
 
 
 class GetModelsTest(unittest.TestCase):
@@ -22,17 +62,8 @@ class GetModelsTest(unittest.TestCase):
             )
         )
 
-    def _get_model_from_model_provider_config(self, config_name):
-        self.cfg.model_provider = OmegaConf.load(
-            os.path.join(
-                os.path.dirname(__file__),
-                f"../../config/model_provider/{config_name}.yaml",
-            )
-        )
-        return _get_models(self.cfg)
-
     def test_load_resnet20(self):
-        client, server = self._get_model_from_model_provider_config("resnet20")
+        client, server = _get_model_from_model_provider_config(self.cfg, "resnet20")
         self.assertIsInstance(client, Part)
         self.assertIsInstance(server, Part)
         self.assertEqual(len(client.layers), 4)
@@ -40,29 +71,18 @@ class GetModelsTest(unittest.TestCase):
         self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
 
     def test_load_resnet20_with_ae(self):
-        client, server = self._get_model_from_model_provider_config(
-            "resnet20-with-autoencoder"
+        client, server = _get_model_from_model_provider_config(
+            self.cfg, "resnet20-with-autoencoder"
         )
         self.assertIsInstance(client, ClientWithAutoencoder)
         self.assertIsInstance(server, ServerWithAutoencoder)
         self.assertEqual(len(client.model.layers), 4)
         self.assertEqual(len(server.model.layers), 5)
         self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
-        optimizer = torch.optim.Adam(server.parameters())
-        smashed_data = client(torch.zeros(1, 3, 32, 32))
-        server_smashed_data = Variable(smashed_data, requires_grad=True)
-        output_train = server(server_smashed_data)
-        loss_train = torch.nn.functional.cross_entropy(
-            output_train, torch.zeros((1, 100))
-        )
-        loss_train.backward()
-        optimizer.step()
-        smashed_data.backward(server_smashed_data.grad)
-        optimizer.step()
 
     def test_training_resnet20_with_ae_as_non_trainable_layers(self):
-        client_encoder, server_decoder = self._get_model_from_model_provider_config(
-            "resnet20-with-autoencoder"
+        client_encoder, server_decoder = _get_model_from_model_provider_config(
+            self.cfg, "resnet20-with-autoencoder"
         )
         client_params = deepcopy(str(client_encoder.model.state_dict()))
         encoder_params = deepcopy(str(client_encoder.autoencoder.state_dict()))
@@ -70,17 +90,18 @@ class GetModelsTest(unittest.TestCase):
         server_params = deepcopy(str(server_decoder.model.state_dict()))
         decoder_params = deepcopy(str(server_decoder.autoencoder.state_dict()))
         # Training loop
-        client_optimizer = torch.optim.Adam(client_encoder.parameters())
-        server_optimizer = torch.optim.Adam(server_decoder.parameters())
-        smashed_data = client_encoder(torch.zeros(1, 3, 32, 32))
-        server_smashed_data = Variable(smashed_data, requires_grad=True)
-        output_train = server_decoder(server_smashed_data)
+        client_optimizer = torch.optim.Adam(client_encoder.get_optimizer_params())
+        server_optimizer = torch.optim.Adam(server_decoder.get_optimizer_params())
+        smashed_data = client_encoder(torch.zeros(7, 3, 32, 32))
+        output_train = server_decoder(smashed_data)
         loss_train = torch.nn.functional.cross_entropy(
-            output_train, torch.rand((1, 100))
+            output_train, torch.rand((7, 100))
         )
         loss_train.backward()
         server_optimizer.step()
-        smashed_data.backward(server_smashed_data.grad)
+        client_encoder.trainable_layers_output.backward(
+            server_decoder.trainable_layers_input.grad
+        )
         client_optimizer.step()
 
         # check that AE hasn't changed, but client and server have
@@ -90,7 +111,7 @@ class GetModelsTest(unittest.TestCase):
         self.assertNotEqual(server_params, str(server_decoder.model.state_dict()))
 
     def test_load_resnet110(self):
-        client, server = self._get_model_from_model_provider_config("resnet110")
+        client, server = _get_model_from_model_provider_config(self.cfg, "resnet110")
         self.assertIsInstance(client, Part)
         self.assertIsInstance(server, Part)
         self.assertEqual(len(client.layers), 4)
@@ -98,8 +119,8 @@ class GetModelsTest(unittest.TestCase):
         self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
 
     def test_load_resnet110_with_ae(self):
-        client, server = self._get_model_from_model_provider_config(
-            "resnet110-with-autoencoder"
+        client, server = _get_model_from_model_provider_config(
+            self.cfg, "resnet110-with-autoencoder"
         )
         self.assertIsInstance(client, ClientWithAutoencoder)
         self.assertIsInstance(server, ServerWithAutoencoder)
diff --git a/edml/tests/models/model_loading_helpers.py b/edml/tests/models/model_loading_helpers.py
new file mode 100644
index 0000000..18b88a0
--- /dev/null
+++ b/edml/tests/models/model_loading_helpers.py
@@ -0,0 +1,15 @@
+import os
+
+from omegaconf import OmegaConf
+
+from edml.core.start_device import _get_models
+
+
+def _get_model_from_model_provider_config(cfg, config_name):
+    cfg.model_provider = OmegaConf.load(
+        os.path.join(
+            os.path.dirname(__file__),
+            f"../../config/model_provider/{config_name}.yaml",
+        )
+    )
+    return _get_models(cfg)
diff --git a/edml/tests/models/model_provider_test.py b/edml/tests/models/model_provider_test.py
new file mode 100644
index 0000000..bf872d0
--- /dev/null
+++ b/edml/tests/models/model_provider_test.py
@@ -0,0 +1,43 @@
+import os
+import unittest
+
+from omegaconf import OmegaConf
+
+from edml.models.provider.base import has_autoencoder
+from edml.tests.models.model_loading_helpers import (
+    _get_model_from_model_provider_config,
+)
+
+
+class ModelProviderTest(unittest.TestCase):
+    def setUp(self):
+        os.chdir(os.path.join(os.path.dirname(__file__), "../../../"))
+        self.cfg = OmegaConf.create({"some_key": "some_value"})
+
+    def test_has_autoencoder_for_model_without_autoencoder(self):
+        client, server = _get_model_from_model_provider_config(self.cfg, "resnet20")
+        self.assertFalse(has_autoencoder(client))
+        self.assertFalse(has_autoencoder(server))
+
+    def test_has_autoencoder_for_model_with_autoencoder(self):
+        client, server = _get_model_from_model_provider_config(
+            self.cfg, "resnet20-with-autoencoder"
+        )
+        self.assertTrue(has_autoencoder(client))
+        self.assertTrue(has_autoencoder(server))
+
+    def test_get_optimizer_params_for_model_without_autoencoder(self):
+        client, server = _get_model_from_model_provider_config(self.cfg, "resnet20")
+        self.assertEqual(list(client.get_optimizer_params()), list(client.parameters()))
+        self.assertEqual(list(server.get_optimizer_params()), list(server.parameters()))
+
+    def test_get_optimizer_params_for_model_with_autoencoder(self):
+        client, server = _get_model_from_model_provider_config(
+            self.cfg, "resnet20-with-autoencoder"
+        )
+        self.assertEqual(
+            list(client.get_optimizer_params()), list(client.model.parameters())
+        )
+        self.assertEqual(
+            list(server.get_optimizer_params()), list(server.model.parameters())
+        )
-- 
GitLab