Commit 5d4b3302 authored by Tim Tobias Bauerle

Fixed autoencoder: skipping AE layers during backprop
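The autoencoder only compresses the activations for transfer and is itself frozen, so the backward pass now bypasses it: the server stops backprop at the decoder output (trainable_layers_input) and hands that gradient to the client, which resumes from its own pre-encoder activation (trainable_layers_output). A minimal sketch of the mechanism, assuming toy nn.Linear parts and a dummy loss (only the two attribute names mirror this commit):

import torch
from torch import nn

# Toy split model: trainable client part, frozen AE, trainable server part.
client_part = nn.Linear(8, 4)
encoder = nn.Linear(4, 2).requires_grad_(False)  # frozen
decoder = nn.Linear(2, 4).requires_grad_(False)  # frozen
server_part = nn.Linear(4, 1)

x = torch.randn(3, 8)
trainable_layers_output = client_part(x)  # client activation before the encoder
smashed_data = encoder(trainable_layers_output)  # compressed activation sent to the server

# Server side: re-wrap the decoder output as a leaf so backprop stops there.
trainable_layers_input = decoder(smashed_data.detach()).detach().requires_grad_(True)
loss = server_part(trainable_layers_input).sum()
loss.backward()  # fills trainable_layers_input.grad and the server-part grads

# Client side: skip the AE by seeding the backward pass of the pre-encoder
# activation with the gradient taken before the decoder.
trainable_layers_output.backward(trainable_layers_input.grad)
assert client_part.weight.grad is not None

This works because the decoder reconstructs the encoder input, so trainable_layers_input.grad has the same shape as trainable_layers_output and can seed its backward pass directly.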

parent 4127b0ad
1 merge request: !24 Autoencoder, ATM and global optimizer
@@ -19,6 +19,7 @@ from edml.helpers.flops import estimate_model_flops
from edml.helpers.load_optimizer import get_optimizer_and_scheduler
from edml.helpers.metrics import DiagnosticMetricResultContainer, DiagnosticMetricResult
from edml.helpers.types import StateDict
+from edml.models.provider.base import has_autoencoder
if TYPE_CHECKING:
from edml.core.device import Device
@@ -69,7 +70,7 @@ class DeviceClient:
self._device = torch.device(get_torch_device_id(cfg))
self._model = model.to(self._device)
self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
-    cfg, self._model.parameters()
+    cfg, self._model.get_optimizer_params()
)
# get first sample from train data to estimate model flops
sample = self._train_data.dataset.__getitem__(0)[0]
@@ -208,7 +209,10 @@ class DeviceClient:
self._model_flops * len(batch_data) * 2
) # 2x for backward pass
gradients = gradients.to(self._device)
-smashed_data.backward(gradients)
+if has_autoencoder(self._model):
+    self._model.trainable_layers_output.backward(gradients)
+else:
+    smashed_data.backward(gradients)
# self._optimizer.step()
# We need to store a reference to the smashed_data to make it possible to finalize the training step.
@@ -303,7 +307,10 @@ class DeviceClient:
self._model_flops * len(batch_data) * 2
) # 2x for backward pass
server_grad = server_grad.to(self._device)
-smashed_data.backward(server_grad)
+if has_autoencoder(self._model):
+    self._model.trainable_layers_output.backward(server_grad)
+else:
+    smashed_data.backward(server_grad)
self._optimizer.step()
client_train_time = (
......
@@ -20,6 +20,7 @@ from edml.helpers.metrics import (
DiagnosticMetricResultContainer,
)
from edml.helpers.types import StateDict, LossFn
+from edml.models.provider.base import has_autoencoder
if TYPE_CHECKING:
from edml.core.device import Device
@@ -40,7 +41,7 @@ class DeviceServer:
self._device = torch.device(get_torch_device_id(cfg))
self._model = model.to(self._device)
self._optimizer, self._lr_scheduler = get_optimizer_and_scheduler(
-    cfg, self._model.parameters()
+    cfg, self._model.get_optimizer_params()
)
self._model_flops = 0 # determine later
self._metrics = create_metrics(
@@ -156,8 +157,12 @@ class DeviceServer:
# Capturing training metrics for the current batch.
self.node_device.log({"loss": loss_train.item()})
self._metrics.metrics_on_batch(output_train.cpu(), labels.cpu().int())
-return smashed_data.grad, loss_train.item()
+if has_autoencoder(self._model):
+    # for the AE, return the gradients of the server model before the decoder
+    return self._model.trainable_layers_input.grad, loss_train.item()
+return (
+    smashed_data.grad,
+    loss_train.item(),
+)
def _set_model_flops(self, smashed_data):
"""Helper to determine the model flops when smashed data are available for the first time."""
......
import torch
from torch import nn
+from torch.autograd import Variable
class ClientWithAutoencoder(nn.Module):
@@ -7,10 +7,13 @@ class ClientWithAutoencoder(nn.Module):
        super().__init__()
        self.model = model
        self.autoencoder = autoencoder.requires_grad_(False)
+        self.trainable_layers_output = (
+            None  # needed to continue the backprop while skipping the AE
+        )

    def forward(self, x):
-        x = self.model(x)
-        return self.autoencoder(x)
+        self.trainable_layers_output = self.model(x)
+        return self.autoencoder(self.trainable_layers_output)
class ServerWithAutoencoder(nn.Module):
@@ -18,7 +21,13 @@ class ServerWithAutoencoder(nn.Module):
        super().__init__()
        self.model = model
        self.autoencoder = autoencoder.requires_grad_(False)
+        self.trainable_layers_input = (
+            None  # needed to stop backprop before the AE and send the gradients to the client
+        )

    def forward(self, x):
-        x = self.autoencoder(x)
-        return self.model(x)
+        self.trainable_layers_input = self.autoencoder(x)
+        self.trainable_layers_input = Variable(
+            self.trainable_layers_input, requires_grad=True
+        )
+        return self.model(self.trainable_layers_input)
import torch
from torch import nn


def has_autoencoder(model: nn.Module):
    return hasattr(model, "model") and hasattr(model, "autoencoder")


def get_grads(model: nn.Module):
    gradients = []
    for param in model.parameters():
        if param.grad is not None:
            gradients.append(param.grad)
        else:
            gradients.append(torch.zeros_like(param))
    return gradients


def add_optimizer_params_function(model: nn.Module):
    # exclude the (frozen) AE params from what the optimizer sees
    if has_autoencoder(model):
        model.get_optimizer_params = model.model.parameters
    else:
        model.get_optimizer_params = model.parameters
    return model


class ModelProvider:
    def __init__(self, client: nn.Module, server: nn.Module):
-        self._client = client
-        self._server = server
+        self._client = add_optimizer_params_function(client)
+        self._server = add_optimizer_params_function(server)

    @property
    def models(self) -> tuple[nn.Module, nn.Module]:
......
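For reference, a minimal usage sketch of the provider hook (assumptions: the wrapper constructors take (model, autoencoder) and the models property returns the wrapped pair; the nn.Linear parts are toy stand-ins for the resnet splits):

import torch
from torch import nn

from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
from edml.models.provider.base import ModelProvider

client = ClientWithAutoencoder(nn.Linear(8, 4), nn.Linear(4, 4))
server = ServerWithAutoencoder(nn.Linear(4, 1), nn.Linear(4, 4))
client, server = ModelProvider(client, server).models

# The attached hook hands only the trainable (non-AE) parameters to the optimizer.
optimizer = torch.optim.Adam(client.get_optimizer_params())
assert len(list(client.get_optimizer_params())) == len(list(client.model.parameters()))

Models built by hand (as in the PSLTest changes below) don't pass through the provider, which is why those tests assign get_optimizer_params manually.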
@@ -149,6 +149,7 @@ class FedTrainRoundThreadingTest(unittest.TestCase):
"num_devices": 0,
"wandb": {"enabled": False},
"experiment": {"early_stopping": False},
"simulate_parallelism": False,
}
)
)
......
@@ -24,3 +24,4 @@ battery:
deduction_per_mbyte_sent: 1
deduction_per_mbyte_received: 1
deduction_per_mflop: 1
+simulate_parallelism: False
@@ -76,6 +76,7 @@ class PSLTest(unittest.TestCase):
]
},
"own_device_id": "d0",
"simulate_parallelism": False,
}
)
# init models with fixed weights for repeatability
@@ -90,10 +91,19 @@
)
server_model = ServerModel()
server_model.load_state_dict(server_state_dict)
+server_model.get_optimizer_params = (
+    server_model.parameters
+)  # using a model provider, this is created automatically
client_model1 = ClientModel()
client_model1.load_state_dict(client_state_dict)
+client_model1.get_optimizer_params = (
+    client_model1.parameters
+)  # using a model provider, this is created automatically
client_model2 = ClientModel()
client_model2.load_state_dict(client_state_dict)
+client_model2.get_optimizer_params = (
+    client_model2.parameters
+)  # using a model provider, this is created automatically
self.server = DeviceServer(
model=server_model, loss_fn=torch.nn.L1Loss(), cfg=cfg
)
......
@@ -4,11 +4,51 @@ from copy import deepcopy
import torch
from omegaconf import OmegaConf
from torch.autograd import Variable
from edml.core.start_device import _get_models
from edml.helpers.model_splitting import Part
from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
+from edml.tests.models.model_loading_helpers import (
+    _get_model_from_model_provider_config,
+)
+def plot_grad_flow(named_parameters):
+    """Plots the gradients flowing through the different layers in the net during training.
+    Can be used to check for possible gradient vanishing / exploding problems.
+    Usage: plug this function into the Trainer class after loss.backward() as
+    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow."""
+    import numpy as np
+    from matplotlib import pyplot as plt
+    from matplotlib.lines import Line2D
+
+    ave_grads = []
+    max_grads = []
+    layers = []
+    for n, p in named_parameters:
+        if p.requires_grad and ("bias" not in n):
+            layers.append(n)
+            ave_grads.append(p.grad.abs().mean())
+            max_grads.append(p.grad.abs().max())
+    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
+    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
+    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
+    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
+    plt.xlim(left=0, right=len(ave_grads))
+    plt.ylim(bottom=-0.001)  # optionally top=0.02 to zoom in on the lower gradient regions
+    plt.xlabel("Layers")
+    plt.ylabel("average gradient")
+    plt.title("Gradient flow")
+    plt.grid(True)
+    plt.legend(
+        [
+            Line2D([0], [0], color="c", lw=4),
+            Line2D([0], [0], color="b", lw=4),
+            Line2D([0], [0], color="k", lw=4),
+        ],
+        ["max-gradient", "mean-gradient", "zero-gradient"],
+    )
+    plt.show()
class GetModelsTest(unittest.TestCase):
@@ -22,17 +62,8 @@ class GetModelsTest(unittest.TestCase):
)
)
-def _get_model_from_model_provider_config(self, config_name):
-    self.cfg.model_provider = OmegaConf.load(
-        os.path.join(
-            os.path.dirname(__file__),
-            f"../../config/model_provider/{config_name}.yaml",
-        )
-    )
-    return _get_models(self.cfg)

def test_load_resnet20(self):
-    client, server = self._get_model_from_model_provider_config("resnet20")
+    client, server = _get_model_from_model_provider_config(self.cfg, "resnet20")
self.assertIsInstance(client, Part)
self.assertIsInstance(server, Part)
self.assertEqual(len(client.layers), 4)
@@ -40,29 +71,18 @@ class GetModelsTest(unittest.TestCase):
self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
def test_load_resnet20_with_ae(self):
-client, server = self._get_model_from_model_provider_config(
-    "resnet20-with-autoencoder"
+client, server = _get_model_from_model_provider_config(
+    self.cfg, "resnet20-with-autoencoder"
)
self.assertIsInstance(client, ClientWithAutoencoder)
self.assertIsInstance(server, ServerWithAutoencoder)
self.assertEqual(len(client.model.layers), 4)
self.assertEqual(len(server.model.layers), 5)
self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
-optimizer = torch.optim.Adam(server.parameters())
-smashed_data = client(torch.zeros(1, 3, 32, 32))
-server_smashed_data = Variable(smashed_data, requires_grad=True)
-output_train = server(server_smashed_data)
-loss_train = torch.nn.functional.cross_entropy(
-    output_train, torch.zeros((1, 100))
-)
-loss_train.backward()
-optimizer.step()
-smashed_data.backward(server_smashed_data.grad)
-optimizer.step()
def test_training_resnet20_with_ae_as_non_trainable_layers(self):
-client_encoder, server_decoder = self._get_model_from_model_provider_config(
-    "resnet20-with-autoencoder"
+client_encoder, server_decoder = _get_model_from_model_provider_config(
+    self.cfg, "resnet20-with-autoencoder"
)
client_params = deepcopy(str(client_encoder.model.state_dict()))
encoder_params = deepcopy(str(client_encoder.autoencoder.state_dict()))
@@ -70,17 +90,18 @@ class GetModelsTest(unittest.TestCase):
decoder_params = deepcopy(str(server_decoder.autoencoder.state_dict()))
# Training loop
-client_optimizer = torch.optim.Adam(client_encoder.parameters())
-server_optimizer = torch.optim.Adam(server_decoder.parameters())
-smashed_data = client_encoder(torch.zeros(1, 3, 32, 32))
-server_smashed_data = Variable(smashed_data, requires_grad=True)
-output_train = server_decoder(server_smashed_data)
+client_optimizer = torch.optim.Adam(client_encoder.get_optimizer_params())
+server_optimizer = torch.optim.Adam(server_decoder.get_optimizer_params())
+smashed_data = client_encoder(torch.zeros(7, 3, 32, 32))
+output_train = server_decoder(smashed_data)
loss_train = torch.nn.functional.cross_entropy(
-    output_train, torch.rand((1, 100))
+    output_train, torch.rand((7, 100))
)
loss_train.backward()
server_optimizer.step()
-smashed_data.backward(server_smashed_data.grad)
+client_encoder.trainable_layers_output.backward(
+    server_decoder.trainable_layers_input.grad
+)
client_optimizer.step()
# check that AE hasn't changed, but client and server have
@@ -90,7 +111,7 @@ class GetModelsTest(unittest.TestCase):
self.assertNotEqual(server_params, str(server_decoder.model.state_dict()))
def test_load_resnet110(self):
-client, server = self._get_model_from_model_provider_config("resnet110")
+client, server = _get_model_from_model_provider_config(self.cfg, "resnet110")
self.assertIsInstance(client, Part)
self.assertIsInstance(server, Part)
self.assertEqual(len(client.layers), 4)
@@ -98,8 +119,8 @@ class GetModelsTest(unittest.TestCase):
self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
def test_load_resnet110_with_ae(self):
-client, server = self._get_model_from_model_provider_config(
-    "resnet110-with-autoencoder"
+client, server = _get_model_from_model_provider_config(
+    self.cfg, "resnet110-with-autoencoder"
)
self.assertIsInstance(client, ClientWithAutoencoder)
self.assertIsInstance(server, ServerWithAutoencoder)
......
import os

from omegaconf import OmegaConf

from edml.core.start_device import _get_models


def _get_model_from_model_provider_config(cfg, config_name):
    cfg.model_provider = OmegaConf.load(
        os.path.join(
            os.path.dirname(__file__),
            f"../../config/model_provider/{config_name}.yaml",
        )
    )
    return _get_models(cfg)
import os
import unittest

from omegaconf import OmegaConf

from edml.models.provider.base import has_autoencoder
from edml.tests.models.model_loading_helpers import (
    _get_model_from_model_provider_config,
)


class ModelProviderTest(unittest.TestCase):
    def setUp(self):
        os.chdir(os.path.join(os.path.dirname(__file__), "../../../"))
        self.cfg = OmegaConf.create({"some_key": "some_value"})

    def test_has_autoencoder_for_model_without_autoencoder(self):
        client, server = _get_model_from_model_provider_config(self.cfg, "resnet20")
        self.assertFalse(has_autoencoder(client))
        self.assertFalse(has_autoencoder(server))

    def test_has_autoencoder_for_model_with_autoencoder(self):
        client, server = _get_model_from_model_provider_config(
            self.cfg, "resnet20-with-autoencoder"
        )
        self.assertTrue(has_autoencoder(client))
        self.assertTrue(has_autoencoder(server))

    def test_get_optimizer_params_for_model_without_autoencoder(self):
        client, server = _get_model_from_model_provider_config(self.cfg, "resnet20")
        self.assertEqual(list(client.get_optimizer_params()), list(client.parameters()))
        self.assertEqual(list(server.get_optimizer_params()), list(server.parameters()))

    def test_get_optimizer_params_for_model_with_autoencoder(self):
        client, server = _get_model_from_model_provider_config(
            self.cfg, "resnet20-with-autoencoder"
        )
        self.assertEqual(
            list(client.get_optimizer_params()), list(client.model.parameters())
        )
        self.assertEqual(
            list(server.get_optimizer_params()), list(server.model.parameters())
        )