Freeze the autoencoder parameters instead of calling it under torch.no_grad().

ClientWithAutoencoder and ServerWithAutoencoder now store the autoencoder
via autoencoder.requires_grad_(False) and call it directly in forward().
The start_device tests resolve the config files relative to the test file
rather than the current working directory, run a backward/optimizer step on
the resnet20-with-autoencoder models, and add a test asserting that the
autoencoder state dicts are unchanged after one training step while the
client and server model state dicts differ.

diff --git a/edml/models/autoencoder.py b/edml/models/autoencoder.py
index d02d9cf0f09332bba205a3e3ce664cbe144fb5a9..ed16f3000d694c53856e901a2634776fcf488433 100644
--- a/edml/models/autoencoder.py
+++ b/edml/models/autoencoder.py
@@ -6,21 +6,19 @@ class ClientWithAutoencoder(nn.Module):
     def __init__(self, model: nn.Module, autoencoder: nn.Module):
         super().__init__()
         self.model = model
-        self.autoencoder = autoencoder
+        self.autoencoder = autoencoder.requires_grad_(False)
 
     def forward(self, x):
         x = self.model(x)
-        with torch.no_grad():
-            return self.autoencoder(x)
+        return self.autoencoder(x)
 
 
 class ServerWithAutoencoder(nn.Module):
     def __init__(self, model: nn.Module, autoencoder: nn.Module):
         super().__init__()
         self.model = model
-        self.autoencoder = autoencoder
+        self.autoencoder = autoencoder.requires_grad_(False)
 
     def forward(self, x):
-        with torch.no_grad():
-            x = self.autoencoder(x)
+        x = self.autoencoder(x)
         return self.model(x)
diff --git a/edml/tests/core/start_device_test.py b/edml/tests/core/start_device_test.py
index e4ae787b9850fa612648fd67d3d60ae129ed7b01..abb5d9d61d68c791657d168f9e266abec366a74c 100644
--- a/edml/tests/core/start_device_test.py
+++ b/edml/tests/core/start_device_test.py
@@ -1,8 +1,10 @@
 import os
 import unittest
+from copy import deepcopy
 
 import torch
 from omegaconf import OmegaConf
+from torch.autograd import Variable
 
 from edml.core.start_device import _get_models
 from edml.helpers.model_splitting import Part
@@ -11,20 +13,20 @@ from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
 
 class GetModelsTest(unittest.TestCase):
     def setUp(self):
-        os.chdir("../../../")
+        os.chdir(os.path.join(os.path.dirname(__file__), "../../../"))
         self.cfg = OmegaConf.create({"some_key": "some_value"})
         self.cfg.seed = OmegaConf.load(
             os.path.join(
-                os.getcwd(),
-                "edml/config/seed/default.yaml",
+                os.path.dirname(__file__),
+                "../../config/seed/default.yaml",
             )
         )
 
     def _get_model_from_model_provider_config(self, config_name):
         self.cfg.model_provider = OmegaConf.load(
             os.path.join(
-                os.getcwd(),
-                f"edml/config/model_provider/{config_name}.yaml",
+                os.path.dirname(__file__),
+                f"../../config/model_provider/{config_name}.yaml",
             )
         )
         return _get_models(self.cfg)
@@ -46,6 +48,46 @@ class GetModelsTest(unittest.TestCase):
         self.assertEqual(len(client.model.layers), 4)
         self.assertEqual(len(server.model.layers), 5)
         self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
+        optimizer = torch.optim.Adam(server.parameters())
+        smashed_data = client(torch.zeros(1, 3, 32, 32))
+        server_smashed_data = Variable(smashed_data, requires_grad=True)
+        output_train = server(server_smashed_data)
+        loss_train = torch.nn.functional.cross_entropy(
+            output_train, torch.zeros((1, 100))
+        )
+        loss_train.backward()
+        optimizer.step()
+        smashed_data.backward(server_smashed_data.grad)
+        optimizer.step()
+
+    def test_training_resnet20_with_ae_as_non_trainable_layers(self):
+        client_encoder, server_decoder = self._get_model_from_model_provider_config(
+            "resnet20-with-autoencoder"
+        )
+        client_params = deepcopy(str(client_encoder.model.state_dict()))
+        encoder_params = deepcopy(str(client_encoder.autoencoder.state_dict()))
+        server_params = deepcopy(str(server_decoder.model.state_dict()))
+        decoder_params = deepcopy(str(server_decoder.autoencoder.state_dict()))
+
+        # Training loop
+        client_optimizer = torch.optim.Adam(client_encoder.parameters())
+        server_optimizer = torch.optim.Adam(server_decoder.parameters())
+        smashed_data = client_encoder(torch.zeros(1, 3, 32, 32))
+        server_smashed_data = Variable(smashed_data, requires_grad=True)
+        output_train = server_decoder(server_smashed_data)
+        loss_train = torch.nn.functional.cross_entropy(
+            output_train, torch.rand((1, 100))
+        )
+        loss_train.backward()
+        server_optimizer.step()
+        smashed_data.backward(server_smashed_data.grad)
+        client_optimizer.step()
+
+        # check that AE hasn't changed, but client and server have
+        self.assertEqual(encoder_params, str(client_encoder.autoencoder.state_dict()))
+        self.assertEqual(decoder_params, str(server_decoder.autoencoder.state_dict()))
+        self.assertNotEqual(client_params, str(client_encoder.model.state_dict()))
+        self.assertNotEqual(server_params, str(server_decoder.model.state_dict()))
 
     def test_load_resnet110(self):
         client, server = self._get_model_from_model_provider_config("resnet110")