Commit 3b3b4884 authored by Sebastian Nickels's avatar Sebastian Nickels

Fixed TrainParamSupportChecker

parent e00a3860
Pipeline #157653 passed with stages
in 3 minutes and 13 seconds
......@@ -16,7 +16,7 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.1-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.0-SNAPSHOT</CNNTrain.version>
<CNNTrain.version>0.3.2-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.1-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.caffe2generator;
import de.monticore.lang.monticar.cnnarch.generator.TrainParamSupportChecker;
import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.lang.monticar.cnntrain._visitor.CNNTrainVisitor;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.List;
public class TrainParamSupportChecker implements CNNTrainVisitor {
private List<String> unsupportedElemList = new ArrayList();
private void printUnsupportedEntryParam(String nodeName){
Log.warn("Unsupported training parameter " + "'" + nodeName + "'" + " for the backend CAFFE2. It will be ignored.");
}
private void printUnsupportedOptimizer(String nodeName){
Log.warn("Unsupported optimizer parameter " + "'" + nodeName + "'" + " for the backend CAFFE2. It will be ignored.");
}
private void printUnsupportedOptimizerParam(String nodeName){
Log.warn("Unsupported training optimizer parameter " + "'" + nodeName + "'" + " for the backend CAFFE2. It will be ignored.");
}
public TrainParamSupportChecker() {
}
public static final String unsupportedOptFlag = "unsupported_optimizer";
public List getUnsupportedElemList(){
return this.unsupportedElemList;
}
//Empty visit method denotes that the corresponding training parameter is supported.
//To set a training parameter as unsupported, add the corresponding node to the unsupportedElemList
public void visit(ASTNumEpochEntry node){}
public void visit(ASTBatchSizeEntry node){}
public class CNNArch2Caffe2TrainParamSupportChecker extends TrainParamSupportChecker {
public void visit(ASTLoadCheckpointEntry node){
printUnsupportedEntryParam(node.getName());
......@@ -47,18 +35,6 @@ public class TrainParamSupportChecker implements CNNTrainVisitor {
this.unsupportedElemList.add(node.getName());
}
public void visit(ASTTrainContextEntry node){}
public void visit(ASTEvalMetricEntry node){}
public void visit(ASTSGDOptimizer node){}
public void visit(ASTAdamOptimizer node){}
public void visit(ASTRmsPropOptimizer node){}
public void visit(ASTAdaGradOptimizer node){}
public void visit(ASTNesterovOptimizer node){
printUnsupportedOptimizer(node.getName());
this.unsupportedElemList.add(this.unsupportedOptFlag);
......@@ -69,19 +45,11 @@ public class TrainParamSupportChecker implements CNNTrainVisitor {
this.unsupportedElemList.add(this.unsupportedOptFlag);
}
public void visit(ASTLearningRateEntry node){}
public void visit(ASTMinimumLearningRateEntry node){
printUnsupportedOptimizerParam(node.getName());
this.unsupportedElemList.add(node.getName());
}
public void visit(ASTWeightDecayEntry node){}
public void visit(ASTLRDecayEntry node){}
public void visit(ASTLRPolicyEntry node){}
public void visit(ASTRescaleGradEntry node){
printUnsupportedOptimizerParam(node.getName());
this.unsupportedElemList.add(node.getName());
......@@ -92,18 +60,6 @@ public class TrainParamSupportChecker implements CNNTrainVisitor {
this.unsupportedElemList.add(node.getName());
}
public void visit(ASTStepSizeEntry node){}
public void visit(ASTMomentumEntry node){}
public void visit(ASTBeta1Entry node){}
public void visit(ASTBeta2Entry node){}
public void visit(ASTEpsilonEntry node){}
public void visit(ASTGamma1Entry node){}
public void visit(ASTGamma2Entry node){
printUnsupportedOptimizerParam(node.getName());
this.unsupportedElemList.add(node.getName());
......@@ -119,6 +75,4 @@ public class TrainParamSupportChecker implements CNNTrainVisitor {
this.unsupportedElemList.add(node.getName());
}
public void visit(ASTRhoEntry node){}
}
}
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.caffe2generator;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator;
import de.monticore.lang.monticar.cnnarch.generator.CNNTrainGenerator;
import de.monticore.lang.monticar.cnntrain._ast.ASTCNNTrainNode;
import de.monticore.lang.monticar.cnntrain._ast.ASTOptimizerEntry;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
......@@ -18,96 +18,13 @@ import java.io.IOException;
import java.nio.file.Path;
import java.util.*;
public class CNNTrain2Caffe2 implements CNNTrainGenerator {
private String generationTargetPath;
private String instanceName;
private void supportCheck(ConfigurationSymbol configuration){
checkEntryParams(configuration);
checkOptimizerParams(configuration);
}
private void checkEntryParams(ConfigurationSymbol configuration){
TrainParamSupportChecker funcChecker = new TrainParamSupportChecker();
Iterator it = configuration.getEntryMap().keySet().iterator();
while (it.hasNext()) {
String key = it.next().toString();
ASTCNNTrainNode astTrainEntryNode = (ASTCNNTrainNode) configuration.getEntryMap().get(key).getAstNode().get();
astTrainEntryNode.accept(funcChecker);
}
it = configuration.getEntryMap().keySet().iterator();
while (it.hasNext()) {
String key = it.next().toString();
if (funcChecker.getUnsupportedElemList().contains(key)) {
it.remove();
}
}
}
private void checkOptimizerParams(ConfigurationSymbol configuration){
TrainParamSupportChecker funcChecker = new TrainParamSupportChecker();
if (configuration.getOptimizer() != null) {
ASTOptimizerEntry astOptimizer = (ASTOptimizerEntry) configuration.getOptimizer().getAstNode().get();
astOptimizer.accept(funcChecker);
if (funcChecker.getUnsupportedElemList().contains(funcChecker.unsupportedOptFlag)) {
OptimizerSymbol adamOptimizer = new OptimizerSymbol("adam");
configuration.setOptimizer(adamOptimizer); //Set default as adam optimizer
}else {
Iterator it = configuration.getOptimizer().getOptimizerParamMap().keySet().iterator();
while (it.hasNext()) {
String key = it.next().toString();
if (funcChecker.getUnsupportedElemList().contains(key)) {
it.remove();
}
}
}
}
}
private static void quitGeneration(){
Log.error("Code generation is aborted");
System.exit(1);
}
public class CNNTrain2Caffe2 extends CNNTrainGenerator {
public CNNTrain2Caffe2() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
public String getInstanceName() {
String parsedInstanceName = this.instanceName.replace('.', '_').replace('[', '_').replace(']', '_');
parsedInstanceName = parsedInstanceName.substring(0, 1).toLowerCase() + parsedInstanceName.substring(1);
return parsedInstanceName;
}
public void setInstanceName(String instanceName) {
this.instanceName = instanceName;
}
public String getGenerationTargetPath() {
if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') {
this.generationTargetPath = generationTargetPath + "/";
}
return generationTargetPath;
}
public void setGenerationTargetPath(String generationTargetPath) {
this.generationTargetPath = generationTargetPath;
}
public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) {
final ModelPath mp = new ModelPath(modelsDirPath);
GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
Optional<CNNTrainCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNTrainCompilationUnitSymbol.KIND);
if (!compilationUnit.isPresent()) {
Log.error("could not resolve training configuration " + rootModelName);
quitGeneration();
}
setInstanceName(compilationUnit.get().getFullName());
CNNTrainCocos.checkAll(compilationUnit.get());
supportCheck(compilationUnit.get().getConfiguration());
return compilationUnit.get().getConfiguration();
trainParamSupportChecker = new CNNArch2Caffe2TrainParamSupportChecker();
}
@Override
public void generate(Path modelsDirPath, String rootModelName) {
ConfigurationSymbol configuration = getConfigurationSymbol(modelsDirPath, rootModelName);
Map<String, String> fileContents = generateStrings(configuration);
......@@ -122,6 +39,7 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
}
}
@Override
public Map<String, String> generateStrings(ConfigurationSymbol configuration) {
ConfigurationData configData = new ConfigurationData(configuration, getInstanceName());
List<ConfigurationData> configDataList = new ArrayList<>();
......@@ -131,5 +49,4 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
String templateContent = TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
return Collections.singletonMap("CNNTrainer_" + getInstanceName() + ".py", templateContent);
}
}
......@@ -17,5 +17,5 @@ if __name__ == "__main__":
num_epoch=5,
batch_size=100,
context='gpu',
opt_type='adam'
)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment