Skip to content
Snippets Groups Projects
Commit bde85811 authored by Tim Tobias Bauerle's avatar Tim Tobias Bauerle
Browse files

Fixed loading serialized model

parent b23927ad
Branches
No related tags found
2 merge requests!18Merge in main,!14Experiment configs
......@@ -6,3 +6,7 @@ class SerializedModel(nn.Module):
def __init__(self, model: nn.Module, path: str):
super().__init__()
model.load_state_dict(torch.load(path))
self.model = model
def forward(self, x):
return self.model(x)
import os
import unittest
import torch
from omegaconf import OmegaConf
from edml.core.start_device import _get_models
from edml.helpers.model_splitting import Part
from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
class GetModelsTest(unittest.TestCase):
def setUp(self):
os.chdir("../../../")
self.cfg = OmegaConf.create({"some_key": "some_value"})
self.cfg.seed = OmegaConf.load(
os.path.join(
os.getcwd(),
"edml/config/seed/default.yaml",
)
)
def _get_model_from_model_provider_config(self, config_name):
self.cfg.model_provider = OmegaConf.load(
os.path.join(
os.getcwd(),
f"edml/config/model_provider/{config_name}.yaml",
)
)
return _get_models(self.cfg)
def test_load_resnet20(self):
client, server = self._get_model_from_model_provider_config("resnet20")
self.assertIsInstance(client, Part)
self.assertIsInstance(server, Part)
self.assertEqual(len(client.layers), 4)
self.assertEqual(len(server.layers), 5)
self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
def test_load_resnet20_with_ae(self):
client, server = self._get_model_from_model_provider_config(
"resnet20-with-autoencoder"
)
self.assertIsInstance(client, ClientWithAutoencoder)
self.assertIsInstance(server, ServerWithAutoencoder)
self.assertEqual(len(client.model.layers), 4)
self.assertEqual(len(server.model.layers), 5)
self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
def test_load_resnet110(self):
client, server = self._get_model_from_model_provider_config("resnet110")
self.assertIsInstance(client, Part)
self.assertIsInstance(server, Part)
self.assertEqual(len(client.layers), 4)
self.assertEqual(len(server.layers), 5)
self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
def test_load_resnet110_with_ae(self):
client, server = self._get_model_from_model_provider_config(
"resnet110-with-autoencoder"
)
self.assertIsInstance(client, ClientWithAutoencoder)
self.assertIsInstance(server, ServerWithAutoencoder)
self.assertEqual(len(client.model.layers), 4)
self.assertEqual(len(server.model.layers), 5)
self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment