diff --git a/edml/models/provider/path.py b/edml/models/provider/path.py index 6245270a9cc43cfea81eec58d03c4ed84fb746c7..2fe5a8aa41fe410b13b09f41ceece73896cbedf4 100644 --- a/edml/models/provider/path.py +++ b/edml/models/provider/path.py @@ -5,7 +5,7 @@ from torch import nn class SerializedModel(nn.Module): def __init__(self, model: nn.Module, path: str): super().__init__() - model.load_state_dict(torch.load(path)) + model.load_state_dict(torch.load(path, map_location="cpu")) self.model = model def forward(self, x):