From da2dd5766b37668d2d3c03b49b05d24befa68422 Mon Sep 17 00:00:00 2001 From: Carlos Yeverino Date: Wed, 22 Aug 2018 18:35:29 +0200 Subject: [PATCH] Add num_epoch and eval_metric training parameters to the CNNTrainer.ftl template and correct target code for corresponding test --- src/main/resources/templates/caffe2/CNNTrainer.ftl | 6 ++++++ src/test/resources/target_code/CNNTrainer_main.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/src/main/resources/templates/caffe2/CNNTrainer.ftl b/src/main/resources/templates/caffe2/CNNTrainer.ftl index 8d190c2..2150a47 100644 --- a/src/main/resources/templates/caffe2/CNNTrainer.ftl +++ b/src/main/resources/templates/caffe2/CNNTrainer.ftl @@ -16,6 +16,9 @@ if __name__ == "__main__": <#if (config.batchSize)??> batch_size = ${config.batchSize}, +<#if (config.numEpoch)??> + num_epoch = ${config.numEpoch}, + <#if (config.loadCheckpoint)??> load_checkpoint = ${config.loadCheckpoint?string("True","False")}, @@ -25,6 +28,9 @@ if __name__ == "__main__": <#if (config.normalize)??> normalize = ${config.normalize?string("True","False")}, +<#if (config.evalMetric)??> + eval_metric = ${config.evalMetric}, + <#if (config.configuration.optimizer)??> optimizer = '${config.optimizerName}', optimizer_params = { diff --git a/src/test/resources/target_code/CNNTrainer_main.py b/src/test/resources/target_code/CNNTrainer_main.py index 1d96b72..d96ed2e 100644 --- a/src/test/resources/target_code/CNNTrainer_main.py +++ b/src/test/resources/target_code/CNNTrainer_main.py @@ -12,6 +12,7 @@ if __name__ == "__main__": main_net1 = CNNCreator_main_net1.CNNCreator_main_net1() main_net1.train( batch_size = 64, + num_epoch = 10, load_checkpoint = False, context = 'gpu', normalize = True, @@ -25,6 +26,7 @@ if __name__ == "__main__": main_net2 = CNNCreator_main_net2.CNNCreator_main_net2() main_net2.train( batch_size = 32, + num_epoch = 10, load_checkpoint = False, context = 'gpu', normalize = True, -- GitLab