diff --git a/main.py b/main.py index e8b3cdcb3d0eab1c667bf9a01c4b90e1a000291f..d8fabc2bc41e4226e56ff5cffbd5ebe2642aaeb8 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ import numpy.random import torch from omegaconf import DictConfig +from edml.controllers.test_controller import TestController from edml.core.start_device import launch_device from edml.helpers.config_helpers import preprocess_config, instantiate_controller from multiprocessing import Process @@ -37,6 +38,8 @@ def main(cfg): p.terminate() p.join() + _run_test_evaluation(cfg) + def _start_controller(cfg: DictConfig): controller = instantiate_controller(cfg) @@ -54,6 +57,22 @@ def _start_device(cfg: DictConfig, device_id: str): launch_device(cfg) +def _run_test_evaluation(cfg): + cfg.experiment.job = "test" + cfg.experiment.partition = "False" + cfg.experiment.latency = None + cfg.num_devices = 1 + device_id = cfg.topology.devices[0].device_id + p = Process(target=_start_device, args=(cfg, device_id), name=device_id) + p.start() + + controller = TestController(cfg) + controller.train() + + p.terminate() + p.join() + + def _make_deterministic(seed_cfg: DictConfig): seed = seed_cfg.value