diff --git a/edml/controllers/base_controller.py b/edml/controllers/base_controller.py index 64e763c96197971dfd1f8243405d9107fc7ce577..a2adca85736844b23f40a7c8c5c1a5c4574ef280 100644 --- a/edml/controllers/base_controller.py +++ b/edml/controllers/base_controller.py @@ -6,7 +6,7 @@ import torch from edml.controllers.early_stopping import create_early_stopping_callback from edml.core.device import DeviceRequestDispatcher -from edml.helpers.load_model import get_models +from edml.core.start_device import _get_models from edml.helpers.logging import SimpleLogger, create_logger from edml.helpers.metrics import ModelMetricResultContainer from edml.helpers.types import DeviceBatteryStatus @@ -41,7 +41,7 @@ class BaseController(abc.ABC): ) # if no weights are loaded, initialize the models randomly and set them on all devices - client_model, server_model = get_models(self.cfg) + client_model, server_model = _get_models(self.cfg) self._set_weights_on_all_devices( client_model.state_dict(), on_client=True, wait_for_ready=True )