From b23927ad72287dea9c5579b5a1599fc0ddac8dd7 Mon Sep 17 00:00:00 2001
From: Tim Bauerle <tim.bauerle@rwth-aachen.de>
Date: Mon, 1 Jul 2024 12:24:06 +0200
Subject: [PATCH] Added test run after training

---
 main.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/main.py b/main.py
index e8b3cdc..d8fabc2 100644
--- a/main.py
+++ b/main.py
@@ -5,6 +5,7 @@ import numpy.random
 import torch
 from omegaconf import DictConfig
 
+from edml.controllers.test_controller import TestController
 from edml.core.start_device import launch_device
 from edml.helpers.config_helpers import preprocess_config, instantiate_controller
 from multiprocessing import Process
@@ -37,6 +38,8 @@ def main(cfg):
         p.terminate()
         p.join()
 
+    _run_test_evaluation(cfg)
+
 
 def _start_controller(cfg: DictConfig):
     controller = instantiate_controller(cfg)
@@ -54,6 +57,22 @@ def _start_device(cfg: DictConfig, device_id: str):
     launch_device(cfg)
 
 
+def _run_test_evaluation(cfg):
+    cfg.experiment.job = "test"
+    cfg.experiment.partition = "False"
+    cfg.experiment.latency = None
+    cfg.num_devices = 1
+    device_id = cfg.topology.devices[0].device_id
+    p = Process(target=_start_device, args=(cfg, device_id), name=device_id)
+    p.start()
+
+    controller = TestController(cfg)
+    controller.train()
+
+    p.terminate()
+    p.join()
+
+
 def _make_deterministic(seed_cfg: DictConfig):
     seed = seed_cfg.value
 
-- 
GitLab