Commit 4b79d7f6 authored by Evgeny Kusmenko

Merge branch 'update_cnntrainer' into 'master'

Set default optimizer, updated cnntrainer to replace unsupported optimizer by...

See merge request !25
parents 03f422c7 12932b3e
Pipeline #106409 passed with stages in 6 minutes and 57 seconds
@@ -8,6 +8,7 @@ import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
 import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
 import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage;
 import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
+import de.monticore.lang.monticar.cnntrain._symboltable.OptimizerSymbol;
 import de.monticore.lang.monticar.generator.FileContent;
 import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
 import de.monticore.symboltable.GlobalScope;
@@ -49,7 +50,8 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
         ASTOptimizerEntry astOptimizer = (ASTOptimizerEntry) configuration.getOptimizer().getAstNode().get();
         astOptimizer.accept(funcChecker);
         if (funcChecker.getUnsupportedElemList().contains(funcChecker.unsupportedOptFlag)) {
-            configuration.setOptimizer(null);
+            OptimizerSymbol adamOptimizer = new OptimizerSymbol("adam");
+            configuration.setOptimizer(adamOptimizer); //Set default as adam optimizer
         }else {
             Iterator it = configuration.getOptimizer().getOptimizerParamMap().keySet().iterator();
             while (it.hasNext()) {
...
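For readers skimming the hunk above: the new branch no longer nulls out the optimizer when the Caffe2 backend cannot handle it, but substitutes a default OptimizerSymbol("adam") instead. Below is a minimal, self-contained sketch of that fallback idea; the Configuration/OptimizerSymbol stubs and the supported-optimizer set are simplified stand-ins (not the CNNTrainLang API), and only the "adam" default comes from the diff itself.

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Hypothetical, simplified stand-ins for the CNNTrainLang symbols used in the diff.
class OptimizerSymbol {
    private final String name;
    OptimizerSymbol(String name) { this.name = name; }
    String getName() { return name; }
}

class Configuration {
    private OptimizerSymbol optimizer;
    Configuration(OptimizerSymbol optimizer) { this.optimizer = optimizer; }
    OptimizerSymbol getOptimizer() { return optimizer; }
    void setOptimizer(OptimizerSymbol optimizer) { this.optimizer = optimizer; }
}

public class OptimizerFallbackSketch {
    // Optimizers the Caffe2 backend is assumed (for this sketch) to support.
    private static final Set<String> SUPPORTED =
            new HashSet<>(Arrays.asList("adam", "sgd", "adagrad", "rmsprop"));

    static void replaceUnsupportedOptimizer(Configuration configuration) {
        String name = configuration.getOptimizer().getName();
        if (!SUPPORTED.contains(name)) {
            // Instead of dropping the optimizer entirely (the previous behaviour,
            // setOptimizer(null)), fall back to the default adam optimizer.
            System.err.println("Optimizer '" + name + "' is not supported; falling back to adam.");
            configuration.setOptimizer(new OptimizerSymbol("adam"));
        }
    }

    public static void main(String[] args) {
        Configuration config = new Configuration(new OptimizerSymbol("nag"));
        replaceUnsupportedOptimizer(config);
        System.out.println("Optimizer after check: " + config.getOptimizer().getName()); // prints: adam
    }
}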
...@@ -33,9 +33,9 @@ if __name__ == "__main__": ...@@ -33,9 +33,9 @@ if __name__ == "__main__":
loss='${config.loss}', loss='${config.loss}',
</#if> </#if>
<#if (config.configuration.optimizer)??> <#if (config.configuration.optimizer)??>
opt_type='${config.optimizerName}', opt_type='${config.optimizerName}'<#if config.optimizerParams?has_content>,
<#list config.optimizerParams?keys as param> <#list config.optimizerParams?keys as param>
<#--To adapt parameter names since parameter names in Caffe2 are different than in CNNTrainLang--> <#--To adapt parameter names to Caffe2 since they are different than in CNNTrainLang-->
<#assign paramName = param> <#assign paramName = param>
<#if param == "learning_rate"> <#if param == "learning_rate">
<#assign paramName = "base_learning_rate"> <#assign paramName = "base_learning_rate">
...@@ -50,6 +50,7 @@ if __name__ == "__main__": ...@@ -50,6 +50,7 @@ if __name__ == "__main__":
</#if> </#if>
${paramName}=${config.optimizerParams[param]}<#sep>, ${paramName}=${config.optimizerParams[param]}<#sep>,
</#list> </#list>
</#if>
</#if> </#if>
) )
......
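The template change wraps the parameter list in <#if config.optimizerParams?has_content> so that a configuration whose parameters were all stripped as unsupported (or whose optimizer was replaced by the adam default) does not leave a dangling comma after opt_type. A rough Java sketch of the same emission logic follows; renderOptimizerArgs is a hypothetical helper and a plain Map stands in for the template model, so this is an illustration of the design choice, not the generator's code.

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

public class OptimizerArgsSketch {

    // Mirrors what the FreeMarker block emits: 'opt_type' first, then a
    // comma-separated parameter list only if at least one parameter exists.
    static String renderOptimizerArgs(String optimizerName, Map<String, Object> params) {
        StringBuilder sb = new StringBuilder();
        sb.append("opt_type='").append(optimizerName).append("'");
        if (!params.isEmpty()) {                       // <#if config.optimizerParams?has_content>
            for (Map.Entry<String, Object> e : params.entrySet()) {
                String paramName = e.getKey();
                if (paramName.equals("learning_rate")) {
                    paramName = "base_learning_rate";  // Caffe2 naming differs from CNNTrainLang
                }
                sb.append(",\n").append(paramName).append("=").append(e.getValue());
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // Empty parameter map: no trailing comma after opt_type.
        System.out.println(renderOptimizerArgs("adam", Collections.emptyMap()));

        Map<String, Object> params = new LinkedHashMap<>();
        params.put("learning_rate", 0.001);
        System.out.println(renderOptimizerArgs("sgd", params));
    }
}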
@@ -169,6 +169,20 @@ public class GenerationTest extends AbstractSymtabTest{
                         "CNNTrainer_emptyConfig.py"));
     }
 
+    @Test
+    public void testUnsupportedTrainParameters() throws IOException {
+        Log.getFindings().clear();
+        Path modelPath = Paths.get("src/test/resources/valid_tests");
+        CNNTrain2Caffe2 trainGenerator = new CNNTrain2Caffe2();
+        trainGenerator.generate(modelPath, "UnsupportedConfig");
+        assertTrue(Log.getFindings().size() == 6);
+        checkFilesAreEqual(
+                Paths.get("./target/generated-sources-cnnarch"),
+                Paths.get("./src/test/resources/target_code"),
+                Arrays.asList(
+                        "CNNTrainer_unsupportedConfig.py"));
+    }
+
     @Test
     public void testCMakeGeneration() {
...
from caffe2.python import workspace, core, model_helper, brew, optimizer
from caffe2.python.predictor import mobile_exporter
from caffe2.proto import caffe2_pb2
import numpy as np
import logging
import CNNCreator_unsupportedConfig

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger()
    handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
    logger.addHandler(handler)

    unsupportedConfig = CNNCreator_unsupportedConfig.CNNCreator_unsupportedConfig()
    unsupportedConfig.train(
        num_epoch=5,
        batch_size=100,
        context='gpu',
        opt_type='adam'
    )
configuration UnsupportedConfig{
    num_epoch : 5
    batch_size : 100
    context : gpu
    load_checkpoint : true
    normalize : true
    optimizer : nag{
        learning_rate_minimum : 0.00001
        rescale_grad : 1.1
        clip_gradient : 10
    }
}
\ No newline at end of file
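This is the model the new test exercises: nag with learning_rate_minimum, rescale_grad, and clip_gradient is flagged as unsupported, so generation still succeeds but emits the adam-based CNNTrainer_unsupportedConfig.py shown above while logging findings. A minimal driver mirroring the test call, assuming the generator project is on the classpath and the same resource layout as in the test:

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;

// CNNTrain2Caffe2 comes from the generator project; package/import omitted here.
public class GenerateUnsupportedConfig {
    public static void main(String[] args) throws IOException {
        Path modelPath = Paths.get("src/test/resources/valid_tests");
        CNNTrain2Caffe2 trainGenerator = new CNNTrain2Caffe2();
        trainGenerator.generate(modelPath, "UnsupportedConfig");
        // Expected output: target/generated-sources-cnnarch/CNNTrainer_unsupportedConfig.py
    }
}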