Commit 003a2795 authored by Tim Tobias Bauerle

Made the autoencoder layers non-trainable

parent bde85811
2 merge requests: !18 Merge in main, !14 Experiment configs
--- a/edml/models/autoencoder.py
+++ b/edml/models/autoencoder.py
@@ -6,21 +6,19 @@ class ClientWithAutoencoder(nn.Module):
     def __init__(self, model: nn.Module, autoencoder: nn.Module):
         super().__init__()
         self.model = model
-        self.autoencoder = autoencoder
+        self.autoencoder = autoencoder.requires_grad_(False)
 
     def forward(self, x):
         x = self.model(x)
-        with torch.no_grad():
-            return self.autoencoder(x)
+        return self.autoencoder(x)
 
 
 class ServerWithAutoencoder(nn.Module):
     def __init__(self, model: nn.Module, autoencoder: nn.Module):
         super().__init__()
         self.model = model
-        self.autoencoder = autoencoder
+        self.autoencoder = autoencoder.requires_grad_(False)
 
     def forward(self, x):
-        with torch.no_grad():
-            x = self.autoencoder(x)
+        x = self.autoencoder(x)
         return self.model(x)
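Why the change matters: wrapping the autoencoder call in `torch.no_grad()` produces an output with no autograd graph, so during split training no gradient can flow back through the autoencoder to the trainable model behind it. `requires_grad_(False)` instead freezes only the autoencoder's own weights while still recording the graph through its operations. A minimal sketch of the difference (illustrative code, not part of the commit):

```python
import torch
import torch.nn as nn

frozen = nn.Linear(4, 4).requires_grad_(False)  # weights excluded from training
trainable = nn.Linear(4, 4)

x = torch.randn(1, 4)
out = frozen(trainable(x))           # graph is still recorded through `frozen`
out.sum().backward()                 # gradients flow back to `trainable`
assert trainable.weight.grad is not None
assert frozen.weight.grad is None    # frozen weights receive no gradient

with torch.no_grad():                # by contrast, no_grad records no graph at all
    out2 = frozen(trainable(x))
assert out2.grad_fn is None          # nothing upstream can ever receive gradients
```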
(GetModelsTest file; exact path not shown in this view)
@@ -1,8 +1,10 @@
 import os
 import unittest
+from copy import deepcopy
 
 import torch
 from omegaconf import OmegaConf
+from torch.autograd import Variable
 
 from edml.core.start_device import _get_models
 from edml.helpers.model_splitting import Part
@@ -11,20 +13,20 @@ from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
 class GetModelsTest(unittest.TestCase):
     def setUp(self):
-        os.chdir("../../../")
+        os.chdir(os.path.join(os.path.dirname(__file__), "../../../"))
         self.cfg = OmegaConf.create({"some_key": "some_value"})
         self.cfg.seed = OmegaConf.load(
             os.path.join(
-                os.getcwd(),
-                "edml/config/seed/default.yaml",
+                os.path.dirname(__file__),
+                "../../config/seed/default.yaml",
             )
         )
 
     def _get_model_from_model_provider_config(self, config_name):
         self.cfg.model_provider = OmegaConf.load(
             os.path.join(
-                os.getcwd(),
-                f"edml/config/model_provider/{config_name}.yaml",
+                os.path.dirname(__file__),
+                f"../../config/model_provider/{config_name}.yaml",
             )
         )
         return _get_models(self.cfg)
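Aside (not part of the commit): switching from `os.getcwd()` to `os.path.dirname(__file__)` makes the test independent of the directory the test runner was launched from. A minimal sketch of the same idea with `pathlib`, using hypothetical names:

```python
from pathlib import Path

# Resolve paths relative to this test file, not the process CWD,
# so the test passes no matter where unittest/pytest is started.
TESTS_DIR = Path(__file__).resolve().parent       # hypothetical layout
CONFIG_DIR = TESTS_DIR.parent.parent / "config"   # mirrors "../../config"

seed_cfg_path = CONFIG_DIR / "seed" / "default.yaml"
```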
@@ -46,6 +48,46 @@ class GetModelsTest(unittest.TestCase):
         self.assertEqual(len(client.model.layers), 4)
         self.assertEqual(len(server.model.layers), 5)
         self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
+        optimizer = torch.optim.Adam(server.parameters())
+        smashed_data = client(torch.zeros(1, 3, 32, 32))
+        server_smashed_data = Variable(smashed_data, requires_grad=True)
+        output_train = server(server_smashed_data)
+        loss_train = torch.nn.functional.cross_entropy(
+            output_train, torch.zeros((1, 100))
+        )
+        loss_train.backward()
+        optimizer.step()
+        smashed_data.backward(server_smashed_data.grad)
+        optimizer.step()
+
+    def test_training_resnet20_with_ae_as_non_trainable_layers(self):
+        client_encoder, server_decoder = self._get_model_from_model_provider_config(
+            "resnet20-with-autoencoder"
+        )
+        client_params = deepcopy(str(client_encoder.model.state_dict()))
+        encoder_params = deepcopy(str(client_encoder.autoencoder.state_dict()))
+        server_params = deepcopy(str(server_decoder.model.state_dict()))
+        decoder_params = deepcopy(str(server_decoder.autoencoder.state_dict()))
+
+        # Training loop
+        client_optimizer = torch.optim.Adam(client_encoder.parameters())
+        server_optimizer = torch.optim.Adam(server_decoder.parameters())
+        smashed_data = client_encoder(torch.zeros(1, 3, 32, 32))
+        server_smashed_data = Variable(smashed_data, requires_grad=True)
+        output_train = server_decoder(server_smashed_data)
+        loss_train = torch.nn.functional.cross_entropy(
+            output_train, torch.rand((1, 100))
+        )
+        loss_train.backward()
+        server_optimizer.step()
+        smashed_data.backward(server_smashed_data.grad)
+        client_optimizer.step()
+
+        # check that AE hasn't changed, but client and server have
+        self.assertEqual(encoder_params, str(client_encoder.autoencoder.state_dict()))
+        self.assertEqual(decoder_params, str(server_decoder.autoencoder.state_dict()))
+        self.assertNotEqual(client_params, str(client_encoder.model.state_dict()))
+        self.assertNotEqual(server_params, str(server_decoder.model.state_dict()))
 
     def test_load_resnet110(self):
         client, server = self._get_model_from_model_provider_config("resnet110")
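For context, the added test drives one split-learning step: the server backpropagates to its input and the client resumes the backward pass from the returned gradient via `smashed_data.backward(server_smashed_data.grad)`. `torch.autograd.Variable` is deprecated; the same handoff is usually written with `detach().requires_grad_()`. A self-contained sketch under assumed names (`client`, `server`, `x`, `y` are hypothetical, not this repo's API):

```python
import torch

def split_training_step(client, server, x, y):
    client_opt = torch.optim.Adam(client.parameters())
    server_opt = torch.optim.Adam(server.parameters())

    smashed = client(x)                            # client-side forward
    server_in = smashed.detach().requires_grad_()  # cut the graph at the split point
    loss = torch.nn.functional.cross_entropy(server(server_in), y)

    loss.backward()                                # gradients for server parameters
    server_opt.step()

    smashed.backward(server_in.grad)               # resume backward on the client side
    client_opt.step()
    return loss.item()
```

With the frozen autoencoder from this commit, both backward passes traverse the encoder/decoder, yet their `state_dict()` entries stay bit-identical, which is exactly what the test's string comparisons assert.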