diff --git a/src/main/resources/templates/caffe2/CNNTrainer.ftl b/src/main/resources/templates/caffe2/CNNTrainer.ftl index 8d190c26e437624a490ec69e18e6994f0afd9f4d..2150a47418c99954045ad5743b17d4ec0860a19d 100644 --- a/src/main/resources/templates/caffe2/CNNTrainer.ftl +++ b/src/main/resources/templates/caffe2/CNNTrainer.ftl @@ -16,6 +16,9 @@ if __name__ == "__main__": <#if (config.batchSize)??> batch_size = ${config.batchSize}, +<#if (config.numEpoch)??> + num_epoch = ${config.numEpoch}, + <#if (config.loadCheckpoint)??> load_checkpoint = ${config.loadCheckpoint?string("True","False")}, @@ -25,6 +28,9 @@ if __name__ == "__main__": <#if (config.normalize)??> normalize = ${config.normalize?string("True","False")}, +<#if (config.evalMetric)??> + eval_metric = ${config.evalMetric}, + <#if (config.configuration.optimizer)??> optimizer = '${config.optimizerName}', optimizer_params = { diff --git a/src/test/resources/target_code/CNNTrainer_main.py b/src/test/resources/target_code/CNNTrainer_main.py index 1d96b72a780f2ead330e0f086b47f18930a81b80..d96ed2e03d43ebdca219190d65f3a8a5fd1dd59c 100644 --- a/src/test/resources/target_code/CNNTrainer_main.py +++ b/src/test/resources/target_code/CNNTrainer_main.py @@ -12,6 +12,7 @@ if __name__ == "__main__": main_net1 = CNNCreator_main_net1.CNNCreator_main_net1() main_net1.train( batch_size = 64, + num_epoch = 10, load_checkpoint = False, context = 'gpu', normalize = True, @@ -25,6 +26,7 @@ if __name__ == "__main__": main_net2 = CNNCreator_main_net2.CNNCreator_main_net2() main_net2.train( batch_size = 32, + num_epoch = 10, load_checkpoint = False, context = 'gpu', normalize = True,