Commit 3a697803 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'release-candidate' into 'master'

Integrates new Version of CNNArch2Gluon

See merge request !24
parents 16023357 8a07e768
Pipeline #170788 passed with stages
in 10 minutes and 6 seconds
...@@ -5,6 +5,6 @@ nppBackup ...@@ -5,6 +5,6 @@ nppBackup
.classpath .classpath
.idea .idea
.git .git
.vscode
*.iml *.iml
train.log
...@@ -24,7 +24,8 @@ stages: ...@@ -24,7 +24,8 @@ stages:
- linux - linux
- deploy - deploy
masterJobLinux:
git masterJobLinux:
stage: deploy stage: deploy
image: maven:3-jdk-8 image: maven:3-jdk-8
script: script:
...@@ -34,6 +35,7 @@ masterJobLinux: ...@@ -34,6 +35,7 @@ masterJobLinux:
only: only:
- master - master
integrationMXNetJobLinux: integrationMXNetJobLinux:
stage: linux stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.3 image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.3
...@@ -47,17 +49,25 @@ integrationCaffe2JobLinux: ...@@ -47,17 +49,25 @@ integrationCaffe2JobLinux:
script: script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationCaffe2Test - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationCaffe2Test
integrationGluonJobLinux: integrationGluonJobLinux:
stage: linux stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.3 image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.3
script: script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationGluonTest - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationGluonTest
integrationPythonWrapperTest:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2pythonwrapper/tests/mvn-swig:latest
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationPythonWrapperTest
masterJobWindows: masterJobWindows:
stage: windows stage: windows
script: script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest="GenerationTest,SymtabTest" - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B -U clean install --settings settings.xml -Dtest="GenerationTest,SymtabTest"
tags: tags:
- Windows10 - Windows10
......
{
"configurations": [
{
"type": "java",
"name": "CodeLens (Launch) - EMADLGeneratorCli",
"request": "launch",
"mainClass": "de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli",
"projectName": "embedded-montiarc-emadl-generator"
}
]
}
\ No newline at end of file
...@@ -8,18 +8,19 @@ ...@@ -8,18 +8,19 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-emadl-generator</artifactId> <artifactId>embedded-montiarc-emadl-generator</artifactId>
<version>0.3.0</version> <version>0.3.5-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
<properties> <properties>
<!-- .. SE-Libraries .................................................. --> <!-- .. SE-Libraries .................................................. -->
<emadl.version>0.2.6</emadl.version> <emadl.version>0.2.8-SNAPSHOT</emadl.version>
<CNNTrain.version>0.2.6</CNNTrain.version> <CNNTrain.version>0.3.6-SNAPSHOT</CNNTrain.version>
<cnnarch-mxnet-generator.version>0.2.14-SNAPSHOT</cnnarch-mxnet-generator.version> <cnnarch-generator.version>0.0.2-SNAPSHOT</cnnarch-generator.version>
<cnnarch-caffe2-generator.version>0.2.11-SNAPSHOT</cnnarch-caffe2-generator.version> <cnnarch-mxnet-generator.version>0.2.16-SNAPSHOT</cnnarch-mxnet-generator.version>
<cnnarch-gluon-generator.version>0.1.6</cnnarch-gluon-generator.version> <cnnarch-caffe2-generator.version>0.2.12-SNAPSHOT</cnnarch-caffe2-generator.version>
<cnnarch-gluon-generator.version>0.2.6-SNAPSHOT</cnnarch-gluon-generator.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator> <embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. --> <!-- .. Libraries .................................................. -->
...@@ -68,6 +69,12 @@ ...@@ -68,6 +69,12 @@
<version>${embedded-montiarc-math-opt-generator}</version> <version>${embedded-montiarc-math-opt-generator}</version>
</dependency> </dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-generator</artifactId>
<version>${cnnarch-generator.version}</version>
</dependency>
<dependency> <dependency>
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId> <artifactId>cnnarch-mxnet-generator</artifactId>
......
package de.monticore.lang.monticar.emadl.generator; package de.monticore.lang.monticar.emadl.generator;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator; import de.monticore.lang.monticar.cnnarch.generator.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.generator.CNNTrainGenerator;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2Gluon; import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2Gluon;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNTrain2Gluon; import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNTrain2Gluon;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNet; import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNet;
import de.monticore.lang.monticar.cnnarch.caffe2generator.CNNArch2Caffe2; import de.monticore.lang.monticar.cnnarch.caffe2generator.CNNArch2Caffe2;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet; import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet;
import de.monticore.lang.monticar.cnnarch.caffe2generator.CNNTrain2Caffe2; import de.monticore.lang.monticar.cnnarch.caffe2generator.CNNTrain2Caffe2;
import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator; import de.monticore.lang.monticar.emadl.generator.reinforcementlearning.RewardFunctionCppGenerator;
import java.util.Optional; import java.util.Optional;
...@@ -40,7 +41,7 @@ public enum Backend { ...@@ -40,7 +41,7 @@ public enum Backend {
} }
@Override @Override
public CNNTrainGenerator getCNNTrainGenerator() { public CNNTrainGenerator getCNNTrainGenerator() {
return new CNNTrain2Gluon(); return new CNNTrain2Gluon(new RewardFunctionCppGenerator());
} }
}; };
......
...@@ -27,10 +27,15 @@ import com.google.common.io.Resources; ...@@ -27,10 +27,15 @@ import com.google.common.io.Resources;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.cncModel.EMAComponentSymbol; import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.cncModel.EMAComponentSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol; import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.math._symboltable.MathStatementsSymbol; import de.monticore.lang.math._symboltable.MathStatementsSymbol;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator; import de.monticore.lang.monticar.cnnarch.generator.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.DataPathConfigParser; import de.monticore.lang.monticar.cnnarch.generator.DataPathConfigParser;
import de.monticore.lang.monticar.cnnarch.generator.CNNTrainGenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator; import de.monticore.lang.monticar.cnnarch._symboltable.SerialCompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNTrain2Gluon;
import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCoCoChecker;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.emadl._cocos.EMADLCocos; import de.monticore.lang.monticar.emadl._cocos.EMADLCocos;
import de.monticore.lang.monticar.generator.FileContent; import de.monticore.lang.monticar.generator.FileContent;
...@@ -66,8 +71,8 @@ public class EMADLGenerator { ...@@ -66,8 +71,8 @@ public class EMADLGenerator {
private Backend backend; private Backend backend;
private String modelsPath; private String modelsPath;
private Map<String, ArchitectureSymbol> processedArchitecture;
public EMADLGenerator(Backend backend) { public EMADLGenerator(Backend backend) {
this.backend = backend; this.backend = backend;
...@@ -109,8 +114,21 @@ public class EMADLGenerator { ...@@ -109,8 +114,21 @@ public class EMADLGenerator {
} }
public void generate(String modelPath, String qualifiedName, String pythonPath, String forced, boolean doCompile) throws IOException, TemplateException { public void generate(String modelPath, String qualifiedName, String pythonPath, String forced, boolean doCompile) throws IOException, TemplateException {
processedArchitecture = new HashMap<>();
setModelsPath( modelPath ); setModelsPath( modelPath );
TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath()); TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath());
EMAComponentInstanceSymbol instance = resolveComponentInstanceSymbol(qualifiedName, symtab);
generateFiles(symtab, instance, symtab, pythonPath, forced);
if (doCompile) {
compile();
}
processedArchitecture = null;
}
private EMAComponentInstanceSymbol resolveComponentInstanceSymbol(String qualifiedName, TaggingResolver symtab) {
EMAComponentSymbol component = symtab.<EMAComponentSymbol>resolve(qualifiedName, EMAComponentSymbol.KIND).orElse(null); EMAComponentSymbol component = symtab.<EMAComponentSymbol>resolve(qualifiedName, EMAComponentSymbol.KIND).orElse(null);
List<String> splitName = Splitters.DOT.splitToList(qualifiedName); List<String> splitName = Splitters.DOT.splitToList(qualifiedName);
...@@ -122,14 +140,7 @@ public class EMADLGenerator { ...@@ -122,14 +140,7 @@ public class EMADLGenerator {
System.exit(1); System.exit(1);
} }
EMAComponentInstanceSymbol instance = component.getEnclosingScope().<EMAComponentInstanceSymbol>resolve(instanceName, EMAComponentInstanceSymbol.KIND).get(); return component.getEnclosingScope().<EMAComponentInstanceSymbol>resolve(instanceName, EMAComponentInstanceSymbol.KIND).get();
generateFiles(symtab, instance, symtab, pythonPath, forced);
if (doCompile) {
compile();
}
} }
public void compile() throws IOException { public void compile() throws IOException {
...@@ -228,12 +239,14 @@ public class EMADLGenerator { ...@@ -228,12 +239,14 @@ public class EMADLGenerator {
String b = backend.getBackendString(backend); String b = backend.getBackendString(backend);
String trainingDataHash = ""; String trainingDataHash = "";
String testDataHash = ""; String testDataHash = "";
if(b.equals("CAFFE2")){ if (architecture.get().getDataPath() != null) {
trainingDataHash = getChecksumForFile(architecture.get().getDataPath() + "/train_lmdb/data.mdb"); if (b.equals("CAFFE2")) {
testDataHash = getChecksumForFile(architecture.get().getDataPath() + "/test_lmdb/data.mdb"); trainingDataHash = getChecksumForFile(architecture.get().getDataPath() + "/train_lmdb/data.mdb");
}else{ testDataHash = getChecksumForFile(architecture.get().getDataPath() + "/test_lmdb/data.mdb");
trainingDataHash = getChecksumForFile(architecture.get().getDataPath() + "/train.h5"); } else {
testDataHash = getChecksumForFile(architecture.get().getDataPath() + "/test.h5"); trainingDataHash = getChecksumForFile(architecture.get().getDataPath() + "/train.h5");
testDataHash = getChecksumForFile(architecture.get().getDataPath() + "/test.h5");
}
} }
String trainingHash = emadlHash + "#" + cnntHash + "#" + trainingDataHash + "#" + testDataHash; String trainingHash = emadlHash + "#" + cnntHash + "#" + trainingDataHash + "#" + testDataHash;
...@@ -312,6 +325,7 @@ public class EMADLGenerator { ...@@ -312,6 +325,7 @@ public class EMADLGenerator {
public List<FileContent> generateStrings(TaggingResolver taggingResolver, EMAComponentInstanceSymbol componentInstanceSymbol, Scope symtab, Set<EMAComponentInstanceSymbol> allInstances, String forced){ public List<FileContent> generateStrings(TaggingResolver taggingResolver, EMAComponentInstanceSymbol componentInstanceSymbol, Scope symtab, Set<EMAComponentInstanceSymbol> allInstances, String forced){
List<FileContent> fileContents = new ArrayList<>(); List<FileContent> fileContents = new ArrayList<>();
processedArchitecture = new HashMap<>();
generateComponent(fileContents, allInstances, taggingResolver, componentInstanceSymbol, symtab); generateComponent(fileContents, allInstances, taggingResolver, componentInstanceSymbol, symtab);
...@@ -338,6 +352,7 @@ public class EMADLGenerator { ...@@ -338,6 +352,7 @@ public class EMADLGenerator {
fixArmadilloImports(fileContents); fixArmadilloImports(fileContents);
processedArchitecture = null;
return fileContents; return fileContents;
} }
...@@ -360,13 +375,26 @@ public class EMADLGenerator { ...@@ -360,13 +375,26 @@ public class EMADLGenerator {
EMADLCocos.checkAll(componentInstanceSymbol); EMADLCocos.checkAll(componentInstanceSymbol);
if (architecture.isPresent()){ if (architecture.isPresent()){
DataPathConfigParser newParserConfig = new DataPathConfigParser(getModelsPath() + "data_paths.txt"); cnnArchGenerator.check(architecture.get());
String dPath = newParserConfig.getDataPath(EMAComponentSymbol.getFullName());
String dPath = null;
Path dataPathDefinition = Paths.get(getModelsPath(), "data_paths.txt");
if (dataPathDefinition.toFile().exists()) {
DataPathConfigParser newParserConfig = new DataPathConfigParser(getModelsPath() + "data_paths.txt");
dPath = newParserConfig.getDataPath(EMAComponentSymbol.getFullName());
} else {
Log.warn("No data path definition found in " + dataPathDefinition + " found: "
+ "Set data path to default ./data path");
dPath = "data";
}
/*String dPath = DataPathConfigParser.getDataPath(getModelsPath() + "data_paths.txt", componentSymbol.getFullName());*/ /*String dPath = DataPathConfigParser.getDataPath(getModelsPath() + "data_paths.txt", componentSymbol.getFullName());*/
architecture.get().setDataPath(dPath); architecture.get().setDataPath(dPath);
architecture.get().setComponentName(EMAComponentSymbol.getFullName()); architecture.get().setComponentName(EMAComponentSymbol.getFullName());
generateCNN(fileContents, taggingResolver, componentInstanceSymbol, architecture.get()); generateCNN(fileContents, taggingResolver, componentInstanceSymbol, architecture.get());
if (processedArchitecture != null) {
processedArchitecture.put(architecture.get().getComponentName(), architecture.get());
}
} }
else if (mathStatements.isPresent()){ else if (mathStatements.isPresent()){
generateMathComponent(fileContents, taggingResolver, componentInstanceSymbol, mathStatements.get()); generateMathComponent(fileContents, taggingResolver, componentInstanceSymbol, mathStatements.get());
...@@ -398,7 +426,7 @@ public class EMADLGenerator { ...@@ -398,7 +426,7 @@ public class EMADLGenerator {
String component = emamGen.generateString(taggingResolver, instance, (MathStatementsSymbol) null); String component = emamGen.generateString(taggingResolver, instance, (MathStatementsSymbol) null);
FileContent componentFileContent = new FileContent( FileContent componentFileContent = new FileContent(
transformComponent(component, "CNNPredictor_" + fullName, executeMethod), transformComponent(component, "CNNPredictor_" + fullName, executeMethod, architecture),
instance); instance);
for (String fileName : contentMap.keySet()){ for (String fileName : contentMap.keySet()){
...@@ -408,18 +436,26 @@ public class EMADLGenerator { ...@@ -408,18 +436,26 @@ public class EMADLGenerator {
fileContents.add(new FileContent(readResource("CNNTranslator.h", Charsets.UTF_8), "CNNTranslator.h")); fileContents.add(new FileContent(readResource("CNNTranslator.h", Charsets.UTF_8), "CNNTranslator.h"));
} }
protected String transformComponent(String component, String predictorClassName, String executeMethod){ protected String transformComponent(String component, String predictorClassName, String executeMethod, ArchitectureSymbol architecture){
String networkVariableName = "_cnn_";
//insert includes //insert includes
component = component.replaceFirst("using namespace", component = component.replaceFirst("using namespace",
"#include \"" + predictorClassName + ".h" + "\"\n" + "#include \"" + predictorClassName + ".h" + "\"\n" +
"#include \"CNNTranslator.h\"\n" + "#include \"CNNTranslator.h\"\n" +
"using namespace"); "using namespace");
//insert network attribute //insert network attribute for predictor of each network
component = component.replaceFirst("public:", String networkAttributes = "public:";
"public:\n" + predictorClassName + " " + networkVariableName + ";");
int i = 0;
for (SerialCompositeElementSymbol stream : architecture.getStreams()) {
if (stream.isNetwork()) {
networkAttributes += "\n" + predictorClassName + "_" + i + " _predictor_" + i + "_;";
}
++i;
}
component = component.replaceFirst("public:", networkAttributes);
//insert execute method //insert execute method
component = component.replaceFirst("void execute\\(\\)\\s\\{\\s\\}", component = component.replaceFirst("void execute\\(\\)\\s\\{\\s\\}",
...@@ -487,10 +523,46 @@ public class EMADLGenerator { ...@@ -487,10 +523,46 @@ public class EMADLGenerator {
String trainConfigFilename = getConfigFilename(mainComponentName, component.getFullName(), component.getName()); String trainConfigFilename = getConfigFilename(mainComponentName, component.getFullName(), component.getName());
//should be removed when CNNTrain supports packages //should be removed when CNNTrain supports packages
cnnTrainGenerator.setGenerationTargetPath(getGenerationTargetPath());
if (cnnTrainGenerator instanceof CNNTrain2Gluon) {
((CNNTrain2Gluon) cnnTrainGenerator).setRootProjectModelsDir(getModelsPath());
}
List<String> names = Splitter.on("/").splitToList(trainConfigFilename); List<String> names = Splitter.on("/").splitToList(trainConfigFilename);
trainConfigFilename = names.get(names.size()-1); trainConfigFilename = names.get(names.size()-1);
Path modelPath = Paths.get(getModelsPath() + Joiner.on("/").join(names.subList(0,names.size()-1))); Path modelPath = Paths.get(getModelsPath() + Joiner.on("/").join(names.subList(0,names.size()-1)));
ConfigurationSymbol configuration = cnnTrainGenerator.getConfigurationSymbol(modelPath, trainConfigFilename); ConfigurationSymbol configuration = cnnTrainGenerator.getConfigurationSymbol(modelPath, trainConfigFilename);
// Annotate train configuration with architecture
final String fullConfigName = String.join(".", names);
ArchitectureSymbol correspondingArchitecture = this.processedArchitecture.get(fullConfigName);
assert correspondingArchitecture != null : "No architecture found for train " + fullConfigName + " configuration!";
configuration.setTrainedArchitecture(
new ArchitectureAdapter(correspondingArchitecture.getName(), correspondingArchitecture));
CNNTrainCocos.checkTrainedArchitectureCoCos(configuration);
// Resolve critic network if critic is present
if (configuration.getCriticName().isPresent()) {
String fullCriticName = configuration.getCriticName().get();
int indexOfFirstNameCharacter = fullCriticName.lastIndexOf('.') + 1;
fullCriticName = fullCriticName.substring(0, indexOfFirstNameCharacter)
+ fullCriticName.substring(indexOfFirstNameCharacter, indexOfFirstNameCharacter + 1).toUpperCase()
+ fullCriticName.substring(indexOfFirstNameCharacter + 1);
TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath());
EMAComponentInstanceSymbol instanceSymbol = resolveComponentInstanceSymbol(fullCriticName, symtab);
EMADLCocos.checkAll(instanceSymbol);
Optional<ArchitectureSymbol> critic = instanceSymbol.getSpannedScope().resolve("", ArchitectureSymbol.KIND);
if (!critic.isPresent()) {
Log.error("During the resolving of critic component: Critic component "
+ fullCriticName + " does not have a CNN implementation but is required to have one");
System.exit(-1);
}
critic.get().setComponentName(fullCriticName);
configuration.setCriticNetwork(new ArchitectureAdapter(fullCriticName, critic.get()));
CNNTrainCocos.checkCriticCocos(configuration);
}
cnnTrainGenerator.setInstanceName(componentInstance.getFullName().replaceAll("\\.", "_")); cnnTrainGenerator.setInstanceName(componentInstance.getFullName().replaceAll("\\.", "_"));
Map<String, String> fileContentMap = cnnTrainGenerator.generateStrings(configuration); Map<String, String> fileContentMap = cnnTrainGenerator.generateStrings(configuration);
for (String fileName : fileContentMap.keySet()){ for (String fileName : fileContentMap.keySet()){
......
...@@ -52,7 +52,7 @@ public class EMADLGeneratorCli { ...@@ -52,7 +52,7 @@ public class EMADLGeneratorCli {
.build(); .build();
public static final Option OPTION_BACKEND = Option.builder("b") public static final Option OPTION_BACKEND = Option.builder("b")
.longOpt("backend") .longOpt("backend")
.desc("deep-learning-framework backend. Options: MXNET, CAFFE2") .desc("deep-learning-framework backend. Options: MXNET, CAFFE2, GLUON")
.hasArg(true) .hasArg(true)
.required(false) .required(false)
.build(); .build();
......
package de.monticore.lang.monticar.emadl.generator.reinforcementlearning;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
import de.monticore.lang.monticar.emadl.generator.EMADLAbstractSymtab;
import de.monticore.lang.monticar.generator.cpp.GeneratorEMAMOpt2CPP;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
import de.se_rwth.commons.logging.Log;
import java.io.IOException;
import java.util.Optional;
public class RewardFunctionCppGenerator implements RewardFunctionSourceGenerator{
public RewardFunctionCppGenerator() {
}
@Override
public EMAComponentInstanceSymbol resolveSymbol(TaggingResolver taggingResolver, String rootModel) {
Optional<EMAComponentInstanceSymbol> instanceSymbol = taggingResolver
.<EMAComponentInstanceSymbol>resolve(rootModel, EMAComponentInstanceSymbol.KIND);
if (!instanceSymbol.isPresent()) {
Log.error("Generation of reward function is not possible: Cannot resolve component instance "
+ rootModel);
}
return instanceSymbol.get();
}
@Override
public void generate(EMAComponentInstanceSymbol componentInstanceSymbol, TaggingResolver taggingResolver,
String targetPath) {
GeneratorEMAMOpt2CPP generator = new GeneratorEMAMOpt2CPP();
generator.useArmadilloBackend();
generator.setGenerationTargetPath(targetPath);
try {
generator.generate(componentInstanceSymbol, taggingResolver);
} catch (IOException e) {
Log.error("Generation of reward function is not possible: " + e.getMessage());
}
}
@Override
public void generate(String modelPath, String rootModel, String targetPath) {
TaggingResolver taggingResolver = createTaggingResolver(modelPath);
EMAComponentInstanceSymbol instanceSymbol = resolveSymbol(taggingResolver, rootModel);
generate(instanceSymbol, taggingResolver, targetPath);
}
@Override
public TaggingResolver createTaggingResolver(final String modelPath) {
return EMADLAbstractSymtab.createSymTabAndTaggingResolver(modelPath);
}
}
...@@ -23,9 +23,11 @@ package de.monticore.lang.monticar.emadl; ...@@ -23,9 +23,11 @@ package de.monticore.lang.monticar.emadl;
import de.monticore.lang.monticar.emadl.generator.Backend; import de.monticore.lang.monticar.emadl.generator.Backend;
import de.monticore.lang.monticar.emadl.generator.EMADLGenerator; import de.monticore.lang.monticar.emadl.generator.EMADLGenerator;
import de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli; import de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli;
import de.se_rwth.commons.logging.Finding;
import de.se_rwth.commons.logging.Log; import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException; import freemarker.template.TemplateException;
import org.junit.Before; import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import java.io.IOException; import java.io.IOException;
...@@ -35,12 +37,13 @@ import java.nio.file.Path; ...@@ -35,12 +37,13 @@ import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import