diff --git a/src/main/resources/templates/caffe2/CNNTrainer.ftl b/src/main/resources/templates/caffe2/CNNTrainer.ftl index 2150a47418c99954045ad5743b17d4ec0860a19d..52fef5fb50ad45bc287b68263305ecfe6a7a83d2 100644 --- a/src/main/resources/templates/caffe2/CNNTrainer.ftl +++ b/src/main/resources/templates/caffe2/CNNTrainer.ftl @@ -29,7 +29,7 @@ if __name__ == "__main__": normalize = ${config.normalize?string("True","False")}, <#if (config.evalMetric)??> - eval_metric = ${config.evalMetric}, + eval_metric = '${config.evalMetric}', <#if (config.configuration.optimizer)??> optimizer = '${config.optimizerName}', diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/caffe2generator/GenerationTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/caffe2generator/GenerationTest.java index 2b84f722df1bc36baf170ebe8209c9bcc860b558..a55af965d7c0b3fefcbbbcc98f7d7d1c00f6c9c1 100644 --- a/src/test/java/de/monticore/lang/monticar/cnnarch/caffe2generator/GenerationTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnnarch/caffe2generator/GenerationTest.java @@ -156,4 +156,107 @@ public class GenerationTest extends AbstractSymtabTest{ Arrays.asList( "CNNTrainer_main.py")); } + + @Test + public void testFullCfgGeneration() throws IOException, TemplateException { + Log.getFindings().clear(); + List configurations = new ArrayList<>(); + List instanceName = Arrays.asList("main_net1", "main_net2"); + + final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests")); + GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage()); + + CNNTrainCompilationUnitSymbol compilationUnit = scope. + resolve("FullConfig", CNNTrainCompilationUnitSymbol.KIND).get(); + CNNTrainCocos.checkAll(compilationUnit); + configurations.add(compilationUnit.getConfiguration()); + + compilationUnit = scope. + resolve("FullConfig2", CNNTrainCompilationUnitSymbol.KIND).get(); + CNNTrainCocos.checkAll(compilationUnit); + configurations.add(compilationUnit.getConfiguration()); + + CNNArch2Caffe2 generator = new CNNArch2Caffe2(); + Map trainerMap = generator.generateTrainer(configurations, instanceName, "mainFull"); + + for (String fileName : trainerMap.keySet()){ + FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName); + writer.write(trainerMap.get(fileName)); + writer.close(); + } + + assertTrue(Log.getFindings().isEmpty()); + checkFilesAreEqual( + Paths.get("./target/generated-sources-cnnarch"), + Paths.get("./src/test/resources/target_code"), + Arrays.asList( + "CNNTrainer_mainFull.py")); + } + + @Test + public void testSimpleCfgGeneration() throws IOException, TemplateException { + Log.getFindings().clear(); + List configurations = new ArrayList<>(); + List instanceName = Arrays.asList("main_net1", "main_net2"); + + final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests")); + GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage()); + + CNNTrainCompilationUnitSymbol compilationUnit = scope. + resolve("SimpleConfig1", CNNTrainCompilationUnitSymbol.KIND).get(); + CNNTrainCocos.checkAll(compilationUnit); + configurations.add(compilationUnit.getConfiguration()); + + compilationUnit = scope. + resolve("SimpleConfig2", CNNTrainCompilationUnitSymbol.KIND).get(); + CNNTrainCocos.checkAll(compilationUnit); + configurations.add(compilationUnit.getConfiguration()); + + CNNArch2Caffe2 generator = new CNNArch2Caffe2(); + Map trainerMap = generator.generateTrainer(configurations, instanceName, "mainSimple"); + + for (String fileName : trainerMap.keySet()){ + FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName); + writer.write(trainerMap.get(fileName)); + writer.close(); + } + + assertTrue(Log.getFindings().isEmpty()); + checkFilesAreEqual( + Paths.get("./target/generated-sources-cnnarch"), + Paths.get("./src/test/resources/target_code"), + Arrays.asList( + "CNNTrainer_mainSimple.py")); + } + + @Test + public void testEmptyCfgGeneration() throws IOException, TemplateException { + Log.getFindings().clear(); + List configurations = new ArrayList<>(); + List instanceName = Arrays.asList("main_net1"); + + final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests")); + GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage()); + + CNNTrainCompilationUnitSymbol compilationUnit = scope. + resolve("EmptyConfig", CNNTrainCompilationUnitSymbol.KIND).get(); + CNNTrainCocos.checkAll(compilationUnit); + configurations.add(compilationUnit.getConfiguration()); + + CNNArch2Caffe2 generator = new CNNArch2Caffe2(); + Map trainerMap = generator.generateTrainer(configurations, instanceName, "mainEmpty"); + + for (String fileName : trainerMap.keySet()){ + FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName); + writer.write(trainerMap.get(fileName)); + writer.close(); + } + + assertTrue(Log.getFindings().isEmpty()); + checkFilesAreEqual( + Paths.get("./target/generated-sources-cnnarch"), + Paths.get("./src/test/resources/target_code"), + Arrays.asList( + "CNNTrainer_mainEmpty.py")); + } } diff --git a/src/test/resources/target_code/CNNTrainer_mainEmpty.py b/src/test/resources/target_code/CNNTrainer_mainEmpty.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb6fb3fc16bdd2ce31c5fffd070294b9de54eea --- /dev/null +++ b/src/test/resources/target_code/CNNTrainer_mainEmpty.py @@ -0,0 +1,13 @@ +import logging +import mxnet as mx +import CNNCreator_main_net1 + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + logger = logging.getLogger() + handler = logging.FileHandler("train.log","w", encoding=None, delay="true") + logger.addHandler(handler) + + main_net1 = CNNCreator_main_net1.CNNCreator_main_net1() + main_net1.train( + ) diff --git a/src/test/resources/target_code/CNNTrainer_mainFull.py b/src/test/resources/target_code/CNNTrainer_mainFull.py new file mode 100644 index 0000000000000000000000000000000000000000..7c12a09ef2235ef5f96eafa8d7fe54794699fdf4 --- /dev/null +++ b/src/test/resources/target_code/CNNTrainer_mainFull.py @@ -0,0 +1,57 @@ +import logging +import mxnet as mx +import CNNCreator_main_net1 +import CNNCreator_main_net2 + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + logger = logging.getLogger() + handler = logging.FileHandler("train.log","w", encoding=None, delay="true") + logger.addHandler(handler) + + main_net1 = CNNCreator_main_net1.CNNCreator_main_net1() + main_net1.train( + batch_size = 100, + num_epoch = 5, + load_checkpoint = True, + context = 'gpu', + normalize = True, + eval_metric = 'mse', + optimizer = 'rmsprop', + optimizer_params = { + 'weight_decay': 0.01, + 'centered': True, + 'gamma2': 0.9, + 'gamma1': 0.9, + 'clip_weights': 10.0, + 'learning_rate_decay': 0.9, + 'epsilon': 1.0E-6, + 'rescale_grad': 1.1, + 'clip_gradient': 10.0, + 'learning_rate_minimum': 1.0E-5, + 'learning_rate_policy': 'step', + 'learning_rate': 0.001, + 'step_size': 1000 } + ) + main_net2 = CNNCreator_main_net2.CNNCreator_main_net2() + main_net2.train( + batch_size = 100, + num_epoch = 10, + load_checkpoint = False, + context = 'gpu', + normalize = False, + eval_metric = 'topKAccuracy', + optimizer = 'adam', + optimizer_params = { + 'epsilon': 1.0E-6, + 'weight_decay': 0.01, + 'rescale_grad': 1.1, + 'beta1': 0.9, + 'clip_gradient': 10.0, + 'beta2': 0.9, + 'learning_rate_minimum': 0.001, + 'learning_rate_policy': 'exp', + 'learning_rate': 0.001, + 'learning_rate_decay': 0.9, + 'step_size': 1000 } + ) diff --git a/src/test/resources/target_code/CNNTrainer_mainSimple.py b/src/test/resources/target_code/CNNTrainer_mainSimple.py new file mode 100644 index 0000000000000000000000000000000000000000..d29da2ea8dd6dc43823852cb15aeb92ce668ea02 --- /dev/null +++ b/src/test/resources/target_code/CNNTrainer_mainSimple.py @@ -0,0 +1,27 @@ +import logging +import mxnet as mx +import CNNCreator_main_net1 +import CNNCreator_main_net2 + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + logger = logging.getLogger() + handler = logging.FileHandler("train.log","w", encoding=None, delay="true") + logger.addHandler(handler) + + main_net1 = CNNCreator_main_net1.CNNCreator_main_net1() + main_net1.train( + batch_size = 100, + num_epoch = 50, + optimizer = 'adam', + optimizer_params = { + 'learning_rate': 0.001 } + ) + main_net2 = CNNCreator_main_net2.CNNCreator_main_net2() + main_net2.train( + batch_size = 100, + num_epoch = 5, + optimizer = 'sgd', + optimizer_params = { + 'learning_rate': 0.1 } + ) diff --git a/src/test/resources/valid_tests/EmptyConfig.cnnt b/src/test/resources/valid_tests/EmptyConfig.cnnt new file mode 100644 index 0000000000000000000000000000000000000000..357c4235f324198fe9e12d5b3bc434fdea27c19e --- /dev/null +++ b/src/test/resources/valid_tests/EmptyConfig.cnnt @@ -0,0 +1,2 @@ +configuration EmptyConfig{ +} diff --git a/src/test/resources/valid_tests/FullConfig.cnnt b/src/test/resources/valid_tests/FullConfig.cnnt new file mode 100644 index 0000000000000000000000000000000000000000..df3313b7263ab850c4398c74d2e9e114f184f629 --- /dev/null +++ b/src/test/resources/valid_tests/FullConfig.cnnt @@ -0,0 +1,23 @@ +configuration FullConfig{ + num_epoch : 5 + batch_size : 100 + load_checkpoint : true + eval_metric : mse + context : gpu + normalize : true + optimizer : rmsprop{ + learning_rate : 0.001 + learning_rate_minimum : 0.00001 + weight_decay : 0.01 + learning_rate_decay : 0.9 + learning_rate_policy : step + step_size : 1000 + rescale_grad : 1.1 + clip_gradient : 10 + gamma1 : 0.9 + gamma2 : 0.9 + epsilon : 0.000001 + centered : true + clip_weights : 10 + } +} diff --git a/src/test/resources/valid_tests/FullConfig2.cnnt b/src/test/resources/valid_tests/FullConfig2.cnnt new file mode 100644 index 0000000000000000000000000000000000000000..3e585e4a87f0b9832d34d38b87a74a866f0baf02 --- /dev/null +++ b/src/test/resources/valid_tests/FullConfig2.cnnt @@ -0,0 +1,21 @@ +configuration FullConfig2{ + num_epoch : 10 + batch_size : 100 + load_checkpoint : false + context : gpu + eval_metric : top_k_accuracy + normalize : false + optimizer : adam{ + learning_rate : 0.001 + learning_rate_minimum : 0.001 + weight_decay : 0.01 + learning_rate_decay : 0.9 + learning_rate_policy : exp + step_size : 1000 + rescale_grad : 1.1 + clip_gradient : 10 + beta1 : 0.9 + beta2 : 0.9 + epsilon : 0.000001 + } +} diff --git a/src/test/resources/valid_tests/SimpleConfig1.cnnt b/src/test/resources/valid_tests/SimpleConfig1.cnnt new file mode 100644 index 0000000000000000000000000000000000000000..d7463b64ae7429d81436dcb49619521905c5778a --- /dev/null +++ b/src/test/resources/valid_tests/SimpleConfig1.cnnt @@ -0,0 +1,7 @@ +configuration SimpleConfig1{ + num_epoch : 50 + batch_size : 100 + optimizer : adam{ + learning_rate : 0.001 + } +} diff --git a/src/test/resources/valid_tests/SimpleConfig2.cnnt b/src/test/resources/valid_tests/SimpleConfig2.cnnt new file mode 100644 index 0000000000000000000000000000000000000000..c8f3693901e1d7c0140be6cc7941ed53b3e0c7ab --- /dev/null +++ b/src/test/resources/valid_tests/SimpleConfig2.cnnt @@ -0,0 +1,7 @@ +configuration SimpleConfig2{ + num_epoch:5 + batch_size:100 + optimizer:sgd{ + learning_rate:0.1 + } +}