Skip to content
Snippets Groups Projects
Commit bde85811 authored by Tim Tobias Bauerle's avatar Tim Tobias Bauerle
Browse files

Fixed loading serialized model

parent b23927ad
No related branches found
No related tags found
2 merge requests!18Merge in main,!14Experiment configs
...@@ -6,3 +6,7 @@ class SerializedModel(nn.Module): ...@@ -6,3 +6,7 @@ class SerializedModel(nn.Module):
def __init__(self, model: nn.Module, path: str): def __init__(self, model: nn.Module, path: str):
super().__init__() super().__init__()
model.load_state_dict(torch.load(path)) model.load_state_dict(torch.load(path))
self.model = model
def forward(self, x):
return self.model(x)
import os
import unittest
import torch
from omegaconf import OmegaConf
from edml.core.start_device import _get_models
from edml.helpers.model_splitting import Part
from edml.models.autoencoder import ClientWithAutoencoder, ServerWithAutoencoder
class GetModelsTest(unittest.TestCase):
def setUp(self):
os.chdir("../../../")
self.cfg = OmegaConf.create({"some_key": "some_value"})
self.cfg.seed = OmegaConf.load(
os.path.join(
os.getcwd(),
"edml/config/seed/default.yaml",
)
)
def _get_model_from_model_provider_config(self, config_name):
self.cfg.model_provider = OmegaConf.load(
os.path.join(
os.getcwd(),
f"edml/config/model_provider/{config_name}.yaml",
)
)
return _get_models(self.cfg)
def test_load_resnet20(self):
client, server = self._get_model_from_model_provider_config("resnet20")
self.assertIsInstance(client, Part)
self.assertIsInstance(server, Part)
self.assertEqual(len(client.layers), 4)
self.assertEqual(len(server.layers), 5)
self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
def test_load_resnet20_with_ae(self):
client, server = self._get_model_from_model_provider_config(
"resnet20-with-autoencoder"
)
self.assertIsInstance(client, ClientWithAutoencoder)
self.assertIsInstance(server, ServerWithAutoencoder)
self.assertEqual(len(client.model.layers), 4)
self.assertEqual(len(server.model.layers), 5)
self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
def test_load_resnet110(self):
client, server = self._get_model_from_model_provider_config("resnet110")
self.assertIsInstance(client, Part)
self.assertIsInstance(server, Part)
self.assertEqual(len(client.layers), 4)
self.assertEqual(len(server.layers), 5)
self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
def test_load_resnet110_with_ae(self):
client, server = self._get_model_from_model_provider_config(
"resnet110-with-autoencoder"
)
self.assertIsInstance(client, ClientWithAutoencoder)
self.assertIsInstance(server, ServerWithAutoencoder)
self.assertEqual(len(client.model.layers), 4)
self.assertEqual(len(server.model.layers), 5)
self.assertEqual(server(client(torch.zeros(1, 3, 32, 32))).shape, (1, 100))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment