diff --git a/edml/models/provider/path.py b/edml/models/provider/path.py index c53e893ae0048550887f73b32d8b4297440cd8e8..6245270a9cc43cfea81eec58d03c4ed84fb746c7 100644 --- a/edml/models/provider/path.py +++ b/edml/models/provider/path.py @@ -6,3 +6,7 @@ class SerializedModel(nn.Module): def __init__(self, model: nn.Module, path: str): super().__init__() model.load_state_dict(torch.load(path)) + self.model = model + + def forward(self, x): + return self.model(x) diff --git a/edml/tests/core/start_device_test.py b/edml/tests/core/start_device_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ae787b9850fa612648fd67d3d60ae129ed7b01 --- /dev/null +++ b/edml/tests/core/start_device_test.py @@ -0,0 +1,66 @@ +import os +import unittest + +import torch +from omegaconf import OmegaConf + +from edml.core.start_device import _get_models +from edml.helpers.model_splitting import Part +from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder + + +class GetModelsTest(unittest.TestCase): + def setUp(self): + os.chdir("../../../") + self.cfg = OmegaConf.create({"some_key": "some_value"}) + self.cfg.seed = OmegaConf.load( + os.path.join( + os.getcwd(), + "edml/config/seed/default.yaml", + ) + ) + + def _get_model_from_model_provider_config(self, config_name): + self.cfg.model_provider = OmegaConf.load( + os.path.join( + os.getcwd(), + f"edml/config/model_provider/{config_name}.yaml", + ) + ) + return _get_models(self.cfg) + + def test_load_resnet20(self): + client, server = self._get_model_from_model_provider_config("resnet20") + self.assertIsInstance(client, Part) + self.assertIsInstance(server, Part) + self.assertEqual(len(client.layers), 4) + self.assertEqual(len(server.layers), 5) + self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100)) + + def test_load_resnet20_with_ae(self): + client, server = self._get_model_from_model_provider_config( + "resnet20-with-autoencoder" + ) + self.assertIsInstance(client, ClientWithAutoencoder) + self.assertIsInstance(server, ServerWithAutoencoder) + self.assertEqual(len(client.model.layers), 4) + self.assertEqual(len(server.model.layers), 5) + self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100)) + + def test_load_resnet110(self): + client, server = self._get_model_from_model_provider_config("resnet110") + self.assertIsInstance(client, Part) + self.assertIsInstance(server, Part) + self.assertEqual(len(client.layers), 4) + self.assertEqual(len(server.layers), 5) + self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100)) + + def test_load_resnet110_with_ae(self): + client, server = self._get_model_from_model_provider_config( + "resnet110-with-autoencoder" + ) + self.assertIsInstance(client, ClientWithAutoencoder) + self.assertIsInstance(server, ServerWithAutoencoder) + self.assertEqual(len(client.model.layers), 4) + self.assertEqual(len(server.model.layers), 5) + self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))