Commit 4b79d7f6 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'update_cnntrainer' into 'master'

Set default optimizer, updated cnntrainer to replace unsupported optimizer by...

See merge request !25
parents 03f422c7 12932b3e
Pipeline #106409 passed with stages
in 6 minutes and 57 seconds
......@@ -8,6 +8,7 @@ import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.OptimizerSymbol;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.symboltable.GlobalScope;
......@@ -49,7 +50,8 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
ASTOptimizerEntry astOptimizer = (ASTOptimizerEntry) configuration.getOptimizer().getAstNode().get();
astOptimizer.accept(funcChecker);
if (funcChecker.getUnsupportedElemList().contains(funcChecker.unsupportedOptFlag)) {
configuration.setOptimizer(null);
OptimizerSymbol adamOptimizer = new OptimizerSymbol("adam");
configuration.setOptimizer(adamOptimizer); //Set default as adam optimizer
}else {
Iterator it = configuration.getOptimizer().getOptimizerParamMap().keySet().iterator();
while (it.hasNext()) {
......
......@@ -33,9 +33,9 @@ if __name__ == "__main__":
loss='${config.loss}',
</#if>
<#if (config.configuration.optimizer)??>
opt_type='${config.optimizerName}',
opt_type='${config.optimizerName}'<#if config.optimizerParams?has_content>,
<#list config.optimizerParams?keys as param>
<#--To adapt parameter names since parameter names in Caffe2 are different than in CNNTrainLang-->
<#--To adapt parameter names to Caffe2 since they are different than in CNNTrainLang-->
<#assign paramName = param>
<#if param == "learning_rate">
<#assign paramName = "base_learning_rate">
......@@ -50,6 +50,7 @@ if __name__ == "__main__":
</#if>
${paramName}=${config.optimizerParams[param]}<#sep>,
</#list>
</#if>
</#if>
)
......
......@@ -169,6 +169,20 @@ public class GenerationTest extends AbstractSymtabTest{
"CNNTrainer_emptyConfig.py"));
}
@Test
public void testUnsupportedTrainParameters() throws IOException {
Log.getFindings().clear();
Path modelPath = Paths.get("src/test/resources/valid_tests");
CNNTrain2Caffe2 trainGenerator = new CNNTrain2Caffe2();
trainGenerator.generate(modelPath, "UnsupportedConfig");
assertTrue(Log.getFindings().size() == 6);
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNTrainer_unsupportedConfig.py"));
}
@Test
public void testCMakeGeneration() {
......
from caffe2.python import workspace, core, model_helper, brew, optimizer
from caffe2.python.predictor import mobile_exporter
from caffe2.proto import caffe2_pb2
import numpy as np
import logging
import CNNCreator_unsupportedConfig
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
logger.addHandler(handler)
unsupportedConfig = CNNCreator_unsupportedConfig.CNNCreator_unsupportedConfig()
unsupportedConfig.train(
num_epoch=5,
batch_size=100,
context='gpu',
opt_type='adam'
)
configuration UnsupportedConfig{
num_epoch : 5
batch_size : 100
context : gpu
load_checkpoint : true
normalize : true
optimizer : nag{
learning_rate_minimum : 0.00001
rescale_grad : 1.1
clip_gradient : 10
}
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment