Aufgrund einer Störung des s3 Storage, könnten in nächster Zeit folgende GitLab Funktionen nicht zur Verfügung stehen: LFS, Container Registry, Job Artifacs, Uploads (Wiki, Bilder, Projekt-Exporte). Wir bitten um Verständnis. Es wird mit Hochdruck an der Behebung des Problems gearbeitet. Weitere Informationen zur Störung des Object Storage finden Sie hier: https://maintenance.itc.rwth-aachen.de/ticket/status/messages/59-object-storage-pilot

Commit 5f22c7aa authored by Carlos Alfredo Yeverino Rodriguez's avatar Carlos Alfredo Yeverino Rodriguez
Browse files

Added checker for training parameter support. Visitor pattern used.

parent 28de4da4
Pipeline #79178 passed with stages
in 3 minutes and 54 seconds
......@@ -2,6 +2,8 @@ package de.monticore.lang.monticar.cnnarch.caffe2generator;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator;
import de.monticore.lang.monticar.cnntrain._ast.ASTCNNTrainNode;
import de.monticore.lang.monticar.cnntrain._ast.ASTOptimizerEntry;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage;
......@@ -19,6 +21,43 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
private String generationTargetPath;
private String instanceName;
private void supportCheck(ConfigurationSymbol configuration){
checkEntryParams(configuration);
checkOptimizerParams(configuration);
}
private void checkEntryParams(ConfigurationSymbol configuration){
TrainParamSupportChecker funcChecker = new TrainParamSupportChecker();
Iterator it = configuration.getEntryMap().keySet().iterator();
while (it.hasNext()) {
String key = it.next().toString();
ASTCNNTrainNode astTrainEntryNode = (ASTCNNTrainNode) configuration.getEntryMap().get(key).getAstNode().get();
astTrainEntryNode.accept(funcChecker);
}
it = configuration.getEntryMap().keySet().iterator();
while (it.hasNext()) {
String key = it.next().toString();
if (funcChecker.getUnsupportedElemList().contains(key)) it.remove();
}
}
private void checkOptimizerParams(ConfigurationSymbol configuration){
TrainParamSupportChecker funcChecker = new TrainParamSupportChecker();
if (configuration.getOptimizer() != null) {
ASTOptimizerEntry astOptimizer = (ASTOptimizerEntry) configuration.getOptimizer().getAstNode().get();
astOptimizer.accept(funcChecker);
if (funcChecker.getUnsupportedElemList().contains(funcChecker.unsupportedOptFlag)) {
configuration.setOptimizer(null);
}else {
Iterator it = configuration.getOptimizer().getOptimizerParamMap().keySet().iterator();
while (it.hasNext()) {
String key = it.next().toString();
if (funcChecker.getUnsupportedElemList().contains(key)) it.remove();
}
}
}
}
public CNNTrain2Caffe2() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
......@@ -54,6 +93,7 @@ public class CNNTrain2Caffe2 implements CNNTrainGenerator {
}
setInstanceName(compilationUnit.get().getFullName());
CNNTrainCocos.checkAll(compilationUnit.get());
supportCheck(compilationUnit.get().getConfiguration());
return compilationUnit.get().getConfiguration();
}
......
package de.monticore.lang.monticar.cnnarch.caffe2generator;
import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.lang.monticar.cnntrain._visitor.CNNTrainVisitor;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.List;
public class TrainParamSupportChecker implements CNNTrainVisitor {
private List<String> unsupportedElemList = new ArrayList();
private void printUnsupportedEntryParam(String nodeName){
Log.warn("Unsupported training parameter " + "'" + nodeName + "'" + " for the backend CAFFE2. It will be ignored.");
}
private void printUnsupportedOptimizer(String nodeName){
Log.warn("Unsupported optimizer parameter " + "'" + nodeName + "'" + " for the backend CAFFE2. It will be ignored.");
}
private void printUnsupportedOptimizerParam(String nodeName){
Log.warn("Unsupported training optimizer parameter " + "'" + nodeName + "'" + " for the backend CAFFE2. It will be ignored.");
}
public TrainParamSupportChecker() {
}
public String unsupportedOptFlag = "unsupported_optimizer";
public List getUnsupportedElemList(){
return this.unsupportedElemList;
}
public void visit(ASTNumEpochEntry node){}
public void visit(ASTBatchSizeEntry node){}
public void visit(ASTLoadCheckpointEntry node){
printUnsupportedEntryParam(node.getName());
this.unsupportedElemList.add(node.getName());
}
public void visit(ASTNormalizeEntry node){
printUnsupportedEntryParam(node.getName());
this.unsupportedElemList.add(node.getName());
}
public void visit(ASTTrainContextEntry node){}
public void visit(ASTEvalMetricEntry node){}
public void visit(ASTSGDOptimizer node){}
public void visit(ASTAdamOptimizer node){}
public void visit(ASTRmsPropOptimizer node){}
public void visit(ASTAdaGradOptimizer node){}
public void visit(ASTNesterovOptimizer node){
printUnsupportedOptimizer(node.getName());
this.unsupportedElemList.add(this.unsupportedOptFlag);
}
public void visit(ASTAdaDeltaOptimizer node){
printUnsupportedOptimizer(node.getName());
this.unsupportedElemList.add(this.unsupportedOptFlag);
}
public void visit(ASTLearningRateEntry node){}
public void visit(ASTMinimumLearningRateEntry node){
printUnsupportedOptimizerParam(node.getName());
this.unsupportedElemList.add(node.getName());
}
public void visit(ASTWeightDecayEntry node){}
public void visit(ASTLRDecayEntry node){}
public void visit(ASTLRPolicyEntry node){}
public void visit(ASTRescaleGradEntry node){
printUnsupportedOptimizerParam(node.getName());
this.unsupportedElemList.add(node.getName());
}
public void visit(ASTClipGradEntry node){
printUnsupportedOptimizerParam(node.getName());
this.unsupportedElemList.add(node.getName());
}
public void visit(ASTStepSizeEntry node){}
public void visit(ASTMomentumEntry node){}
public void visit(ASTBeta1Entry node){}
public void visit(ASTBeta2Entry node){}
public void visit(ASTEpsilonEntry node){}
public void visit(ASTGamma1Entry node){}
public void visit(ASTGamma2Entry node){
printUnsupportedOptimizerParam(node.getName());
this.unsupportedElemList.add(node.getName());
}
public void visit(ASTCenteredEntry node){
printUnsupportedOptimizerParam(node.getName());
this.unsupportedElemList.add(node.getName());
}
public void visit(ASTClipWeightsEntry node){
printUnsupportedOptimizerParam(node.getName());
this.unsupportedElemList.add(node.getName());
}
public void visit(ASTRhoEntry node){}
}
......@@ -122,7 +122,7 @@ public class GenerationTest extends AbstractSymtabTest{
CNNTrain2Caffe2 trainGenerator = new CNNTrain2Caffe2();
trainGenerator.generate(Paths.get(sourcePath), "FullConfig");
assertTrue(Log.getFindings().isEmpty());
assertTrue(Log.getFindings().size() == 8);
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment