diff --git a/edml/models/provider/path.py b/edml/models/provider/path.py
index c53e893ae0048550887f73b32d8b4297440cd8e8..6245270a9cc43cfea81eec58d03c4ed84fb746c7 100644
--- a/edml/models/provider/path.py
+++ b/edml/models/provider/path.py
@@ -6,3 +6,7 @@ class SerializedModel(nn.Module):
     def __init__(self, model: nn.Module, path: str):
         super().__init__()
         model.load_state_dict(torch.load(path))
+        self.model = model
+
+    def forward(self, x):
+        return self.model(x)
diff --git a/edml/tests/core/start_device_test.py b/edml/tests/core/start_device_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4ae787b9850fa612648fd67d3d60ae129ed7b01
--- /dev/null
+++ b/edml/tests/core/start_device_test.py
@@ -0,0 +1,66 @@
+import os
+import unittest
+
+import torch
+from omegaconf import OmegaConf
+
+from edml.core.start_device import _get_models
+from edml.helpers.model_splitting import Part
+from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
+
+
+class GetModelsTest(unittest.TestCase):
+    def setUp(self):
+        os.chdir("../../../")
+        self.cfg = OmegaConf.create({"some_key": "some_value"})
+        self.cfg.seed = OmegaConf.load(
+            os.path.join(
+                os.getcwd(),
+                "edml/config/seed/default.yaml",
+            )
+        )
+
+    def _get_model_from_model_provider_config(self, config_name):
+        self.cfg.model_provider = OmegaConf.load(
+            os.path.join(
+                os.getcwd(),
+                f"edml/config/model_provider/{config_name}.yaml",
+            )
+        )
+        return _get_models(self.cfg)
+
+    def test_load_resnet20(self):
+        client, server = self._get_model_from_model_provider_config("resnet20")
+        self.assertIsInstance(client, Part)
+        self.assertIsInstance(server, Part)
+        self.assertEqual(len(client.layers), 4)
+        self.assertEqual(len(server.layers), 5)
+        self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
+
+    def test_load_resnet20_with_ae(self):
+        client, server = self._get_model_from_model_provider_config(
+            "resnet20-with-autoencoder"
+        )
+        self.assertIsInstance(client, ClientWithAutoencoder)
+        self.assertIsInstance(server, ServerWithAutoencoder)
+        self.assertEqual(len(client.model.layers), 4)
+        self.assertEqual(len(server.model.layers), 5)
+        self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
+
+    def test_load_resnet110(self):
+        client, server = self._get_model_from_model_provider_config("resnet110")
+        self.assertIsInstance(client, Part)
+        self.assertIsInstance(server, Part)
+        self.assertEqual(len(client.layers), 4)
+        self.assertEqual(len(server.layers), 5)
+        self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
+
+    def test_load_resnet110_with_ae(self):
+        client, server = self._get_model_from_model_provider_config(
+            "resnet110-with-autoencoder"
+        )
+        self.assertIsInstance(client, ClientWithAutoencoder)
+        self.assertIsInstance(server, ServerWithAutoencoder)
+        self.assertEqual(len(client.model.layers), 4)
+        self.assertEqual(len(server.model.layers), 5)
+        self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))