From bde858114b25c58a5eadd8943d3794905b1f35d6 Mon Sep 17 00:00:00 2001 From: Tim Bauerle <tim.bauerle@rwth-aachen.de> Date: Mon, 1 Jul 2024 17:12:53 +0200 Subject: [PATCH] Fixed loading serialized model --- edml/models/provider/path.py | 4 ++ edml/tests/core/start_device_test.py | 66 ++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 edml/tests/core/start_device_test.py diff --git a/edml/models/provider/path.py b/edml/models/provider/path.py index c53e893..6245270 100644 --- a/edml/models/provider/path.py +++ b/edml/models/provider/path.py @@ -6,3 +6,7 @@ class SerializedModel(nn.Module): def __init__(self, model: nn.Module, path: str): super().__init__() model.load_state_dict(torch.load(path)) + self.model = model + + def forward(self, x): + return self.model(x) diff --git a/edml/tests/core/start_device_test.py b/edml/tests/core/start_device_test.py new file mode 100644 index 0000000..e4ae787 --- /dev/null +++ b/edml/tests/core/start_device_test.py @@ -0,0 +1,66 @@ +import os +import unittest + +import torch +from omegaconf import OmegaConf + +from edml.core.start_device import _get_models +from edml.helpers.model_splitting import Part +from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder + + +class GetModelsTest(unittest.TestCase): + def setUp(self): + os.chdir("../../../") + self.cfg = OmegaConf.create({"some_key": "some_value"}) + self.cfg.seed = OmegaConf.load( + os.path.join( + os.getcwd(), + "edml/config/seed/default.yaml", + ) + ) + + def _get_model_from_model_provider_config(self, config_name): + self.cfg.model_provider = OmegaConf.load( + os.path.join( + os.getcwd(), + f"edml/config/model_provider/{config_name}.yaml", + ) + ) + return _get_models(self.cfg) + + def test_load_resnet20(self): + client, server = self._get_model_from_model_provider_config("resnet20") + self.assertIsInstance(client, Part) + self.assertIsInstance(server, Part) + self.assertEqual(len(client.layers), 4) + self.assertEqual(len(server.layers), 5) + self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100)) + + def test_load_resnet20_with_ae(self): + client, server = self._get_model_from_model_provider_config( + "resnet20-with-autoencoder" + ) + self.assertIsInstance(client, ClientWithAutoencoder) + self.assertIsInstance(server, ServerWithAutoencoder) + self.assertEqual(len(client.model.layers), 4) + self.assertEqual(len(server.model.layers), 5) + self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100)) + + def test_load_resnet110(self): + client, server = self._get_model_from_model_provider_config("resnet110") + self.assertIsInstance(client, Part) + self.assertIsInstance(server, Part) + self.assertEqual(len(client.layers), 4) + self.assertEqual(len(server.layers), 5) + self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100)) + + def test_load_resnet110_with_ae(self): + client, server = self._get_model_from_model_provider_config( + "resnet110-with-autoencoder" + ) + self.assertIsInstance(client, ClientWithAutoencoder) + self.assertIsInstance(server, ServerWithAutoencoder) + self.assertEqual(len(client.model.layers), 4) + self.assertEqual(len(server.model.layers), 5) + self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100)) -- GitLab