Set default optimizer: updated CNNTrain2Caffe2 to replace an unsupported optimizer with the default ('adam'), and added a JUnit test for unsupported training parameters
parent 03f422c7
Pipeline #106405 passed with stages
in 3 minutes and 51 seconds
......@@ -8,6 +8,7 @@ import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.OptimizerSymbol;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.symboltable.GlobalScope;
......@@ -49,7 +50,8 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
ASTOptimizerEntry astOptimizer = (ASTOptimizerEntry) configuration.getOptimizer().getAstNode().get();
astOptimizer.accept(funcChecker);
if (funcChecker.getUnsupportedElemList().contains(funcChecker.unsupportedOptFlag)) {
configuration.setOptimizer(null);
OptimizerSymbol adamOptimizer = new OptimizerSymbol("adam");
configuration.setOptimizer(adamOptimizer); //Set default as adam optimizer
}else {
Iterator it = configuration.getOptimizer().getOptimizerParamMap().keySet().iterator();
while (it.hasNext()) {
......
......@@ -33,9 +33,9 @@ if __name__ == "__main__":
loss='${config.loss}',
</#if>
<#if (config.configuration.optimizer)??>
opt_type='${config.optimizerName}',
opt_type='${config.optimizerName}'<#if config.optimizerParams?has_content>,
<#list config.optimizerParams?keys as param>
<#--To adapt parameter names since parameter names in Caffe2 are different than in CNNTrainLang-->
<#--To adapt parameter names to Caffe2 since they are different than in CNNTrainLang-->
<#assign paramName = param>
<#if param == "learning_rate">
<#assign paramName = "base_learning_rate">
......@@ -50,6 +50,7 @@ if __name__ == "__main__":
</#if>
${paramName}=${config.optimizerParams[param]}<#sep>,
</#list>
</#if>
</#if>
)
......
......@@ -169,6 +169,20 @@ public class GenerationTest extends AbstractSymtabTest{
"CNNTrainer_emptyConfig.py"));
}
@Test
public void testUnsupportedTrainParameters() throws IOException {
    Log.getFindings().clear();
    Path modelPath = Paths.get("src/test/resources/valid_tests");
    CNNTrain2Caffe2 trainGenerator = new CNNTrain2Caffe2();
    trainGenerator.generate(modelPath, "UnsupportedConfig");

    // UnsupportedConfig deliberately contains unsupported elements
    // (e.g. load_checkpoint, normalize, the 'nag' optimizer and its
    // parameters), each of which should produce exactly one finding.
    // NOTE(review): 6 assumed to match the element count in
    // UnsupportedConfig.cnnt — confirm if that model changes.
    int findingCount = Log.getFindings().size();
    assertTrue("expected 6 findings for unsupported parameters, got " + findingCount,
            findingCount == 6);

    // The generated trainer must fall back to the default 'adam' optimizer
    // and drop the unsupported parameters entirely.
    checkFilesAreEqual(
            Paths.get("./target/generated-sources-cnnarch"),
            Paths.get("./src/test/resources/target_code"),
            Arrays.asList(
                    "CNNTrainer_unsupportedConfig.py"));
}
@Test
public void testCMakeGeneration() {
......
from caffe2.python import workspace, core, model_helper, brew, optimizer
from caffe2.python.predictor import mobile_exporter
from caffe2.proto import caffe2_pb2
import numpy as np
import logging
import CNNCreator_unsupportedConfig

# Generated trainer entry point: builds the network creator and starts
# training with the parameters taken from the CNNTrainLang configuration.
# NOTE(review): this file is expected generator output compared byte-for-byte
# by GenerationTest.testUnsupportedTrainParameters — keep it in sync with the
# FreeMarker template that produces it.
if __name__ == "__main__":
    # Log everything to train.log; the file is recreated on each run ("w").
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger()
    # NOTE(review): delay expects a bool; the string "true" is merely truthy —
    # consider emitting delay=True from the template.
    handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
    logger.addHandler(handler)

    unsupportedConfig = CNNCreator_unsupportedConfig.CNNCreator_unsupportedConfig()
    # Only the supported parameters survive generation: the unsupported 'nag'
    # optimizer from the source model was replaced by the default 'adam',
    # and its parameters were dropped.
    unsupportedConfig.train(
        num_epoch=5,
        batch_size=100,
        context='gpu',
        opt_type='adam'
    )
// Test fixture: a CNNTrainLang configuration that intentionally mixes
// supported and unsupported entries, used by
// GenerationTest.testUnsupportedTrainParameters to verify that the
// Caffe2 generator warns about and drops the unsupported ones.
configuration UnsupportedConfig{
    num_epoch : 5
    batch_size : 100
    context : gpu
    // Unsupported by the Caffe2 backend — expected to produce findings.
    load_checkpoint : true
    normalize : true
    // 'nag' is not supported by Caffe2; the generator should replace it
    // with the default 'adam' optimizer and discard its parameters.
    optimizer : nag{
        learning_rate_minimum : 0.00001
        rescale_grad : 1.1
        clip_gradient : 10
    }
}
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment