Add new tests (adapted from CNNTrainLang) for generation of CNNTrainer.

Fix for missing quotes in eval_metric value.
parent da2dd576
Pipeline #69756 failed with stages
......@@ -29,7 +29,7 @@ if __name__ == "__main__":
normalize = ${config.normalize?string("True","False")},
</#if>
<#if (config.evalMetric)??>
eval_metric = ${config.evalMetric},
eval_metric = '${config.evalMetric}',
</#if>
<#if (config.configuration.optimizer)??>
optimizer = '${config.optimizerName}',
......
......@@ -156,4 +156,107 @@ public class GenerationTest extends AbstractSymtabTest{
Arrays.asList(
"CNNTrainer_main.py"));
}
@Test
public void testFullCfgGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
List<ConfigurationSymbol> configurations = new ArrayList<>();
List<String> instanceName = Arrays.asList("main_net1", "main_net2");
final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests"));
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
CNNTrainCompilationUnitSymbol compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("FullConfig", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("FullConfig2", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
CNNArch2Caffe2 generator = new CNNArch2Caffe2();
Map<String,String> trainerMap = generator.generateTrainer(configurations, instanceName, "mainFull");
for (String fileName : trainerMap.keySet()){
FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName);
writer.write(trainerMap.get(fileName));
writer.close();
}
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNTrainer_mainFull.py"));
}
@Test
public void testSimpleCfgGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
List<ConfigurationSymbol> configurations = new ArrayList<>();
List<String> instanceName = Arrays.asList("main_net1", "main_net2");
final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests"));
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
CNNTrainCompilationUnitSymbol compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("SimpleConfig1", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("SimpleConfig2", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
CNNArch2Caffe2 generator = new CNNArch2Caffe2();
Map<String,String> trainerMap = generator.generateTrainer(configurations, instanceName, "mainSimple");
for (String fileName : trainerMap.keySet()){
FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName);
writer.write(trainerMap.get(fileName));
writer.close();
}
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNTrainer_mainSimple.py"));
}
@Test
public void testEmptyCfgGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
List<ConfigurationSymbol> configurations = new ArrayList<>();
List<String> instanceName = Arrays.asList("main_net1");
final ModelPath mp = new ModelPath(Paths.get("src/test/resources/valid_tests"));
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
CNNTrainCompilationUnitSymbol compilationUnit = scope.<CNNTrainCompilationUnitSymbol>
resolve("EmptyConfig", CNNTrainCompilationUnitSymbol.KIND).get();
CNNTrainCocos.checkAll(compilationUnit);
configurations.add(compilationUnit.getConfiguration());
CNNArch2Caffe2 generator = new CNNArch2Caffe2();
Map<String,String> trainerMap = generator.generateTrainer(configurations, instanceName, "mainEmpty");
for (String fileName : trainerMap.keySet()){
FileWriter writer = new FileWriter(generator.getGenerationTargetPath() + fileName);
writer.write(trainerMap.get(fileName));
writer.close();
}
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNTrainer_mainEmpty.py"));
}
}
import logging
import mxnet as mx
import CNNCreator_main_net1
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log","w", encoding=None, delay="true")
logger.addHandler(handler)
main_net1 = CNNCreator_main_net1.CNNCreator_main_net1()
main_net1.train(
)
import logging
import mxnet as mx
import CNNCreator_main_net1
import CNNCreator_main_net2
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log","w", encoding=None, delay="true")
logger.addHandler(handler)
main_net1 = CNNCreator_main_net1.CNNCreator_main_net1()
main_net1.train(
batch_size = 100,
num_epoch = 5,
load_checkpoint = True,
context = 'gpu',
normalize = True,
eval_metric = 'mse',
optimizer = 'rmsprop',
optimizer_params = {
'weight_decay': 0.01,
'centered': True,
'gamma2': 0.9,
'gamma1': 0.9,
'clip_weights': 10.0,
'learning_rate_decay': 0.9,
'epsilon': 1.0E-6,
'rescale_grad': 1.1,
'clip_gradient': 10.0,
'learning_rate_minimum': 1.0E-5,
'learning_rate_policy': 'step',
'learning_rate': 0.001,
'step_size': 1000 }
)
main_net2 = CNNCreator_main_net2.CNNCreator_main_net2()
main_net2.train(
batch_size = 100,
num_epoch = 10,
load_checkpoint = False,
context = 'gpu',
normalize = False,
eval_metric = 'topKAccuracy',
optimizer = 'adam',
optimizer_params = {
'epsilon': 1.0E-6,
'weight_decay': 0.01,
'rescale_grad': 1.1,
'beta1': 0.9,
'clip_gradient': 10.0,
'beta2': 0.9,
'learning_rate_minimum': 0.001,
'learning_rate_policy': 'exp',
'learning_rate': 0.001,
'learning_rate_decay': 0.9,
'step_size': 1000 }
)
import logging
import mxnet as mx
import CNNCreator_main_net1
import CNNCreator_main_net2
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log","w", encoding=None, delay="true")
logger.addHandler(handler)
main_net1 = CNNCreator_main_net1.CNNCreator_main_net1()
main_net1.train(
batch_size = 100,
num_epoch = 50,
optimizer = 'adam',
optimizer_params = {
'learning_rate': 0.001 }
)
main_net2 = CNNCreator_main_net2.CNNCreator_main_net2()
main_net2.train(
batch_size = 100,
num_epoch = 5,
optimizer = 'sgd',
optimizer_params = {
'learning_rate': 0.1 }
)
configuration FullConfig{
num_epoch : 5
batch_size : 100
load_checkpoint : true
eval_metric : mse
context : gpu
normalize : true
optimizer : rmsprop{
learning_rate : 0.001
learning_rate_minimum : 0.00001
weight_decay : 0.01
learning_rate_decay : 0.9
learning_rate_policy : step
step_size : 1000
rescale_grad : 1.1
clip_gradient : 10
gamma1 : 0.9
gamma2 : 0.9
epsilon : 0.000001
centered : true
clip_weights : 10
}
}
configuration FullConfig2{
num_epoch : 10
batch_size : 100
load_checkpoint : false
context : gpu
eval_metric : top_k_accuracy
normalize : false
optimizer : adam{
learning_rate : 0.001
learning_rate_minimum : 0.001
weight_decay : 0.01
learning_rate_decay : 0.9
learning_rate_policy : exp
step_size : 1000
rescale_grad : 1.1
clip_gradient : 10
beta1 : 0.9
beta2 : 0.9
epsilon : 0.000001
}
}
configuration SimpleConfig1{
num_epoch : 50
batch_size : 100
optimizer : adam{
learning_rate : 0.001
}
}
configuration SimpleConfig2{
num_epoch:5
batch_size:100
optimizer:sgd{
learning_rate:0.1
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment