diff --git a/edml/models/autoencoder.py b/edml/models/autoencoder.py
index d02d9cf0f09332bba205a3e3ce664cbe144fb5a9..ed16f3000d694c53856e901a2634776fcf488433 100644
--- a/edml/models/autoencoder.py
+++ b/edml/models/autoencoder.py
@@ -6,21 +6,19 @@ class ClientWithAutoencoder(nn.Module):
     def __init__(self, model: nn.Module, autoencoder: nn.Module):
         super().__init__()
         self.model = model
-        self.autoencoder = autoencoder
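+        # Freeze the autoencoder's weights; unlike wrapping its forward pass in
+        # torch.no_grad(), this keeps it in the autograd graph so gradients can
+        # still flow through it to self.model.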
+        self.autoencoder = autoencoder.requires_grad_(False)
 
     def forward(self, x):
         x = self.model(x)
-        with torch.no_grad():
-            return self.autoencoder(x)
+        return self.autoencoder(x)
 
 
 class ServerWithAutoencoder(nn.Module):
     def __init__(self, model: nn.Module, autoencoder: nn.Module):
         super().__init__()
         self.model = model
-        self.autoencoder = autoencoder
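+        # Freeze the decoder's weights while keeping it in the autograd graph so
+        # gradients can propagate back to the incoming smashed data.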
+        self.autoencoder = autoencoder.requires_grad_(False)
 
     def forward(self, x):
-        with torch.no_grad():
-            x = self.autoencoder(x)
+        x = self.autoencoder(x)
         return self.model(x)
diff --git a/edml/tests/core/start_device_test.py b/edml/tests/core/start_device_test.py
index e4ae787b9850fa612648fd67d3d60ae129ed7b01..abb5d9d61d68c791657d168f9e266abec366a74c 100644
--- a/edml/tests/core/start_device_test.py
+++ b/edml/tests/core/start_device_test.py
@@ -1,8 +1,10 @@
 import os
 import unittest
 
 import torch
 from omegaconf import OmegaConf
+from torch.autograd import Variable
 
 from edml.core.start_device import _get_models
 from edml.helpers.model_splitting import Part
@@ -11,20 +13,20 @@ from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
 
 class GetModelsTest(unittest.TestCase):
     def setUp(self):
-        os.chdir("../../../")
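+        # Resolve the repository root relative to this file so the test does not
+        # depend on the caller's working directory.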
+        os.chdir(os.path.join(os.path.dirname(__file__), "../../../"))
         self.cfg = OmegaConf.create({"some_key": "some_value"})
         self.cfg.seed = OmegaConf.load(
             os.path.join(
-                os.getcwd(),
-                "edml/config/seed/default.yaml",
+                os.path.dirname(__file__),
+                "../../config/seed/default.yaml",
             )
         )
 
     def _get_model_from_model_provider_config(self, config_name):
         self.cfg.model_provider = OmegaConf.load(
             os.path.join(
-                os.getcwd(),
-                f"edml/config/model_provider/{config_name}.yaml",
+                os.path.dirname(__file__),
+                f"../../config/model_provider/{config_name}.yaml",
             )
         )
         return _get_models(self.cfg)
@@ -46,6 +48,46 @@ class GetModelsTest(unittest.TestCase):
         self.assertEqual(len(client.model.layers), 4)
         self.assertEqual(len(server.model.layers), 5)
         self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
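+        # Smoke-test one split-learning step: backpropagate through the server,
+        # then push the gradient of the smashed data through the client model.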
+        client_optimizer = torch.optim.Adam(client.parameters())
+        server_optimizer = torch.optim.Adam(server.parameters())
+        smashed_data = client(torch.zeros(1, 3, 32, 32))
+        server_smashed_data = Variable(smashed_data, requires_grad=True)
+        output_train = server(server_smashed_data)
+        loss_train = torch.nn.functional.cross_entropy(
+            output_train, torch.rand((1, 100))
+        )
+        loss_train.backward()
+        server_optimizer.step()
+        smashed_data.backward(server_smashed_data.grad)
+        client_optimizer.step()
+
+    def test_training_resnet20_with_ae_as_non_trainable_layers(self):
+        client_encoder, server_decoder = self._get_model_from_model_provider_config(
+            "resnet20-with-autoencoder"
+        )
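+        # Snapshot the initial parameters (as strings) for comparison after the update.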
+        client_params = str(client_encoder.model.state_dict())
+        encoder_params = str(client_encoder.autoencoder.state_dict())
+        server_params = str(server_decoder.model.state_dict())
+        decoder_params = str(server_decoder.autoencoder.state_dict())
+
+        # Training loop
+        client_optimizer = torch.optim.Adam(client_encoder.parameters())
+        server_optimizer = torch.optim.Adam(server_decoder.parameters())
+        smashed_data = client_encoder(torch.zeros(1, 3, 32, 32))
+        server_smashed_data = Variable(smashed_data, requires_grad=True)
+        output_train = server_decoder(server_smashed_data)
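+        # Arbitrary soft-label targets; we only need a nonzero gradient signal.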
+        loss_train = torch.nn.functional.cross_entropy(
+            output_train, torch.rand((1, 100))
+        )
+        loss_train.backward()
+        server_optimizer.step()
+        smashed_data.backward(server_smashed_data.grad)
+        client_optimizer.step()
+
+        # check that AE hasn't changed, but client and server have
+        self.assertEqual(encoder_params, str(client_encoder.autoencoder.state_dict()))
+        self.assertEqual(decoder_params, str(server_decoder.autoencoder.state_dict()))
+        self.assertNotEqual(client_params, str(client_encoder.model.state_dict()))
+        self.assertNotEqual(server_params, str(server_decoder.model.state_dict()))
 
     def test_load_resnet110(self):
         client, server = self._get_model_from_model_provider_config("resnet110")