Commit 4e6e6673 authored by Svetlana Pavlitskaya's avatar Svetlana Pavlitskaya Committed by Evgeny Kusmenko
Browse files

Adapt different backends, moved to a newer EMADL version

parent 9c3425c2
...@@ -28,20 +28,23 @@ masterJobLinux: ...@@ -28,20 +28,23 @@ masterJobLinux:
image: maven:3-jdk-8 image: maven:3-jdk-8
script: script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
- cat target/site/jacoco/index.html
- mvn package sonar:sonar -s settings.xml
only: only:
- master - master
masterJobWindows: #masterJobWindows:
stage: windows # stage: windows
script: # script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml # - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
tags: # tags:
- Windows10 # - Windows10
BranchJobLinux: BranchJobLinux:
stage: linux stage: linux
image: maven:3-jdk-8 image: maven:3-jdk-8
script: script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
- cat target/site/jacoco/index.html
except: except:
- master - master
![pipeline](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/EMADL2CPP/badges/master/build.svg)
![coverage](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/EMADL2CPP/badges/master/coverage.svg)
# EMADL2CPP # EMADL2CPP
Generates CPP/Python code for EmbeddedMontiArcDL. Generates CPP/Python code for EmbeddedMontiArcDL.
See example project [EMADL-Demo](https://git.rwth-aachen.de/thomas.timmermanns/EMADL-Demo) for more information on how the generated code can be used. See example project [EMADL-Demo](https://git.rwth-aachen.de/thomas.timmermanns/EMADL-Demo) for more information on how the generated code can be used.
\ No newline at end of file
...@@ -8,17 +8,18 @@ ...@@ -8,17 +8,18 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-emadl-generator</artifactId> <artifactId>embedded-montiarc-emadl-generator</artifactId>
<version>0.2.1-SNAPSHOT</version> <version>0.2.2-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
<properties> <properties>
<!-- .. SE-Libraries .................................................. --> <!-- .. SE-Libraries .................................................. -->
<emadl.version>0.2.1-SNAPSHOT</emadl.version> <emadl.version>0.2.2-SNAPSHOT</emadl.version>
<CNNTrain.version>0.2.1-SNAPSHOT</CNNTrain.version> <CNNTrain.version>0.2.4-SNAPSHOT</CNNTrain.version>
<cnnarch-generator.version>0.2.1-SNAPSHOT</cnnarch-generator.version> <cnnarch-mxnet-generator.version>0.2.3-SNAPSHOT</cnnarch-mxnet-generator.version>
<embedded-montiarc-math-generator>0.0.10</embedded-montiarc-math-generator> <cnnarch-caffe2-generator.version>0.2.2-SNAPSHOT</cnnarch-caffe2-generator.version>
<embedded-montiarc-math-generator>0.0.25-SNAPSHOT</embedded-montiarc-math-generator>
<!-- .. Libraries .................................................. --> <!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version> <guava.version>18.0</guava.version>
...@@ -31,6 +32,7 @@ ...@@ -31,6 +32,7 @@
<compiler.plugin>3.3</compiler.plugin> <compiler.plugin>3.3</compiler.plugin>
<source.plugin>2.4</source.plugin> <source.plugin>2.4</source.plugin>
<shade.plugin>2.4.3</shade.plugin> <shade.plugin>2.4.3</shade.plugin>
<jacoco.plugin>0.8.1</jacoco.plugin>
<!-- Classifiers --> <!-- Classifiers -->
<grammars.classifier>grammars</grammars.classifier> <grammars.classifier>grammars</grammars.classifier>
...@@ -68,7 +70,13 @@ ...@@ -68,7 +70,13 @@
<dependency> <dependency>
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId> <artifactId>cnnarch-mxnet-generator</artifactId>
<version>${cnnarch-generator.version}</version> <version>${cnnarch-mxnet-generator.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-caffe2-generator</artifactId>
<version>${cnnarch-caffe2-generator.version}</version>
</dependency> </dependency>
<dependency> <dependency>
...@@ -132,6 +140,27 @@ ...@@ -132,6 +140,27 @@
<version>2.8.1</version> <version>2.8.1</version>
</plugin> </plugin>
<!-- Test coverage -->
<plugin>
<groupId>org.jacoco</groupId>
<artifactId>jacoco-maven-plugin</artifactId>
<version>${jacoco.plugin}</version>
<executions>
<execution>
<id>pre-unit-test</id>
<goals>
<goal>prepare-agent</goal>
</goals>
</execution>
<execution>
<id>post-unit-test</id>
<phase>test</phase>
<goals>
<goal>report</goal>
</goals>
</execution>
</executions>
</plugin>
<!-- Other Configuration --> <!-- Other Configuration -->
<plugin> <plugin>
<artifactId>maven-compiler-plugin</artifactId> <artifactId>maven-compiler-plugin</artifactId>
......
...@@ -57,6 +57,24 @@ ...@@ -57,6 +57,24 @@
</mirrors> </mirrors>
<profiles> <profiles>
<profile>
<id>sonar</id>
<activation>
<activeByDefault>true</activeByDefault>
</activation>
<properties>
<!-- Optional URL to server. Default value is http://localhost:9000 -->
<sonar.host.url>
https://metric.se.rwth-aachen.de
</sonar.host.url>
<sonar.login>
jenkins
</sonar.login>
<sonar.password>
${env.sonar}
</sonar.password>
</properties>
</profile>
<profile> <profile>
<id>se-nexus</id> <id>se-nexus</id>
......
package de.monticore.lang.monticar.emadl.generator;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNet;
import de.monticore.lang.monticar.cnnarch.caffe2generator.CNNArch2Caffe2;
import java.util.Optional;
public enum Backend {
MXNET{
@Override
public CNNArchGenerator getGenerator() {
return new CNNArch2MxNet();
}
},
CAFFE2{
@Override
public CNNArchGenerator getGenerator() {
return new CNNArch2Caffe2();
}
};
public abstract CNNArchGenerator getGenerator();
public static Optional<Backend> getBackendFromString(String backend){
switch (backend){
case "MXNET":
return Optional.of(MXNET);
case "CAFFE2":
return Optional.of(CAFFE2);
default:
return Optional.empty();
}
}
}
...@@ -44,7 +44,10 @@ import de.monticore.symboltable.Scope; ...@@ -44,7 +44,10 @@ import de.monticore.symboltable.Scope;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.util.Arrays; import java.util.Arrays;
public class AbstractSymtab { public class EMADLAbstractSymtab {
public EMADLAbstractSymtab() {
}
public static TaggingResolver createSymTabAndTaggingResolver(String... modelPath) { public static TaggingResolver createSymTabAndTaggingResolver(String... modelPath) {
Scope scope = createSymTab(modelPath); Scope scope = createSymTab(modelPath);
TaggingResolver tagging = new TaggingResolver(scope, Arrays.asList(modelPath)); TaggingResolver tagging = new TaggingResolver(scope, Arrays.asList(modelPath));
......
...@@ -22,14 +22,18 @@ package de.monticore.lang.monticar.emadl.generator; ...@@ -22,14 +22,18 @@ package de.monticore.lang.monticar.emadl.generator;
import com.google.common.base.Joiner; import com.google.common.base.Joiner;
import com.google.common.base.Splitter; import com.google.common.base.Splitter;
import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import de.monticore.io.paths.ModelPath; import de.monticore.io.paths.ModelPath;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ComponentSymbol; import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ComponentSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ExpandedComponentInstanceSymbol; import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ExpandedComponentInstanceSymbol;
import de.monticore.lang.math.math._symboltable.MathStatementsSymbol; import de.monticore.lang.math._symboltable.MathStatementsSymbol;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchGenerator; import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage; import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage;
import de.monticore.lang.monticar.cnntrain.generator.CNNTrainGenerator; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.emadl._cocos.EMADLCocos; import de.monticore.lang.monticar.emadl._cocos.EMADLCocos;
import de.monticore.lang.monticar.generator.FileContent; import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.ArmadilloHelper; import de.monticore.lang.monticar.generator.cpp.ArmadilloHelper;
...@@ -42,11 +46,10 @@ import de.monticore.symboltable.GlobalScope; ...@@ -42,11 +46,10 @@ import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope; import de.monticore.symboltable.Scope;
import de.se_rwth.commons.Splitters; import de.se_rwth.commons.Splitters;
import de.se_rwth.commons.logging.Log; import de.se_rwth.commons.logging.Log;
import freemarker.template.Template;
import freemarker.template.TemplateException; import freemarker.template.TemplateException;
import java.io.IOException; import java.io.*;
import java.io.StringWriter; import java.nio.charset.Charset;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
...@@ -55,19 +58,19 @@ import java.util.*; ...@@ -55,19 +58,19 @@ import java.util.*;
public class EMADLGenerator { public class EMADLGenerator {
public static final String CNN_HELPER = "CNNTranslator";
public static final String CNN_TRAINER = "CNNTrainer";
private GeneratorCPP emamGen; private GeneratorCPP emamGen;
private CNNArchGenerator cnnArchGenerator;
private String modelsPath;
public EMADLGenerator() {
public EMADLGenerator(Backend backend) {
emamGen = new GeneratorCPP(); emamGen = new GeneratorCPP();
emamGen.useArmadilloBackend(); emamGen.useArmadilloBackend();
emamGen.setGenerationTargetPath("./target/generated-sources-emadl/"); emamGen.setGenerationTargetPath("./target/generated-sources-emadl/");
cnnArchGenerator = backend.getGenerator();
} }
private String modelsPath;
public String getModelsPath() { public String getModelsPath() {
return modelsPath; return modelsPath;
} }
...@@ -100,7 +103,7 @@ public class EMADLGenerator { ...@@ -100,7 +103,7 @@ public class EMADLGenerator {
public void generate(String modelPath, String qualifiedName) throws IOException, TemplateException { public void generate(String modelPath, String qualifiedName) throws IOException, TemplateException {
setModelsPath( modelPath ); setModelsPath( modelPath );
TaggingResolver symtab = AbstractSymtab.createSymTabAndTaggingResolver(getModelsPath()); TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath());
ComponentSymbol component = symtab.<ComponentSymbol>resolve(qualifiedName, ComponentSymbol.KIND).orElse(null); ComponentSymbol component = symtab.<ComponentSymbol>resolve(qualifiedName, ComponentSymbol.KIND).orElse(null);
List<String> splitName = Splitters.DOT.splitToList(qualifiedName); List<String> splitName = Splitters.DOT.splitToList(qualifiedName);
...@@ -131,7 +134,7 @@ public class EMADLGenerator { ...@@ -131,7 +134,7 @@ public class EMADLGenerator {
generateComponent(fileContents, allInstances, taggingResolver, componentInstanceSymbol, symtab); generateComponent(fileContents, allInstances, taggingResolver, componentInstanceSymbol, symtab);
fileContents.add(generateCNNTrainer(allInstances, componentInstanceSymbol.getComponentType().getFullName().replaceAll("\\.", "_"))); fileContents.addAll(generateCNNTrainer(allInstances, componentInstanceSymbol.getComponentType().getFullName().replaceAll("\\.", "_")));
fileContents.add(ArmadilloHelper.getArmadilloHelperFileContent()); fileContents.add(ArmadilloHelper.getArmadilloHelperFileContent());
TypesGeneratorCPP tg = new TypesGeneratorCPP(); TypesGeneratorCPP tg = new TypesGeneratorCPP();
fileContents.addAll(tg.generateTypes(TypeConverter.getTypeSymbols())); fileContents.addAll(tg.generateTypes(TypeConverter.getTypeSymbols()));
...@@ -185,7 +188,6 @@ public class EMADLGenerator { ...@@ -185,7 +188,6 @@ public class EMADLGenerator {
} }
public void generateCNN(List<FileContent> fileContents, TaggingResolver taggingResolver, ExpandedComponentInstanceSymbol instance, ArchitectureSymbol architecture){ public void generateCNN(List<FileContent> fileContents, TaggingResolver taggingResolver, ExpandedComponentInstanceSymbol instance, ArchitectureSymbol architecture){
CNNArchGenerator cnnArchGenerator = new CNNArchGenerator();
Map<String,String> contentMap = cnnArchGenerator.generateStrings(architecture); Map<String,String> contentMap = cnnArchGenerator.generateStrings(architecture);
String fullName = instance.getFullName().replaceAll("\\.", "_"); String fullName = instance.getFullName().replaceAll("\\.", "_");
...@@ -206,7 +208,7 @@ public class EMADLGenerator { ...@@ -206,7 +208,7 @@ public class EMADLGenerator {
fileContents.add(new FileContent(contentMap.get(fileName), fileName)); fileContents.add(new FileContent(contentMap.get(fileName), fileName));
} }
fileContents.add(componentFileContent); fileContents.add(componentFileContent);
fileContents.add(new FileContent(processTemplate(new HashMap<>(), CNN_HELPER), CNN_HELPER + ".h")); fileContents.add(new FileContent(readResource("CNNTranslator.h", Charsets.UTF_8), "CNNTranslator.h"));
} }
protected String transformComponent(String component, String predictorClassName, String executeMethod){ protected String transformComponent(String component, String predictorClassName, String executeMethod){
...@@ -215,7 +217,7 @@ public class EMADLGenerator { ...@@ -215,7 +217,7 @@ public class EMADLGenerator {
//insert includes //insert includes
component = component.replaceFirst("using namespace", component = component.replaceFirst("using namespace",
"#include \"" + predictorClassName + ".h" + "\"\n" + "#include \"" + predictorClassName + ".h" + "\"\n" +
"#include \"" + CNN_HELPER + ".h" + "\"\n" + "#include \"CNNTranslator.h\"\n" +
"using namespace"); "using namespace");
//insert network attribute //insert network attribute
...@@ -252,34 +254,28 @@ public class EMADLGenerator { ...@@ -252,34 +254,28 @@ public class EMADLGenerator {
} }
} }
public FileContent generateCNNTrainer(Set<ExpandedComponentInstanceSymbol> allInstances, String mainComponentName){ public List<FileContent> generateCNNTrainer(Set<ExpandedComponentInstanceSymbol> allInstances, String mainComponentName) {
List<ExpandedComponentInstanceSymbol> cnnInstances = new ArrayList<>(); List<String> cnnInstanceNames = new ArrayList<>();
List<String> trainParams = new ArrayList<>(); List<ConfigurationSymbol> configurations = new ArrayList<>();
Set<String> componentNames = new HashSet<>(); for (ExpandedComponentInstanceSymbol componentInstance : allInstances) {
for (ExpandedComponentInstanceSymbol componentInstance : allInstances){
ComponentSymbol component = componentInstance.getComponentType().getReferencedSymbol(); ComponentSymbol component = componentInstance.getComponentType().getReferencedSymbol();
Optional<ArchitectureSymbol> architecture = component.getSpannedScope().resolve("", ArchitectureSymbol.KIND); Optional<ArchitectureSymbol> architecture = component.getSpannedScope().resolve("", ArchitectureSymbol.KIND);
if (architecture.isPresent()){ if (architecture.isPresent()) {
ConfigurationSymbol configuration = getTrainingConfiguration(mainComponentName, component, componentInstance);
String fileContent = getTrainingParamsForComponent(mainComponentName, component, componentInstance); configurations.add(configuration);
if (!fileContent.isEmpty()) { cnnInstanceNames.add(componentInstance.getFullName().replaceAll("\\.", "_"));
trainParams.add(fileContent);
}
cnnInstances.add(componentInstance);
componentNames.add(component.getFullName());
} }
} }
Map<String, Object> ftlContext = new HashMap<>(); List<FileContent> fileContents = new ArrayList<>();
ftlContext.put("instances", cnnInstances); Map<String, String> fileContentMap = cnnArchGenerator.generateTrainer(configurations, cnnInstanceNames, mainComponentName);
ftlContext.put("componentNames", componentNames); for (String fileName : fileContentMap.keySet()){
ftlContext.put("trainParams", trainParams); fileContents.add(new FileContent(fileContentMap.get(fileName), fileName));
return new FileContent(processTemplate(ftlContext, CNN_TRAINER), CNN_TRAINER + "_" + mainComponentName + ".py"); }
return fileContents;
} }
private String getTrainingParamsForComponent(String mainComponentName, ComponentSymbol component, ExpandedComponentInstanceSymbol instance) { public ConfigurationSymbol getTrainingConfiguration(String mainComponentName, ComponentSymbol component, ExpandedComponentInstanceSymbol instance) {
String configFilename; String configFilename;
String mainComponentConfigFilename = mainComponentName.replaceAll("\\.", "/"); String mainComponentConfigFilename = mainComponentName.replaceAll("\\.", "/");
String componentConfigFilename = component.getFullName().replaceAll("\\.", "/"); String componentConfigFilename = component.getFullName().replaceAll("\\.", "/");
...@@ -300,7 +296,7 @@ public class EMADLGenerator { ...@@ -300,7 +296,7 @@ public class EMADLGenerator {
+ getModelsPath() + componentConfigFilename + ".cnnt', '" + getModelsPath() + componentConfigFilename + ".cnnt', '"
+ getModelsPath() + mainComponentConfigFilename + ".cnnt'." + + getModelsPath() + mainComponentConfigFilename + ".cnnt'." +
" These files denote respectively the configuration for the single instance, the component or the whole system."); " These files denote respectively the configuration for the single instance, the component or the whole system.");
return ""; return null;
} }
//should be removed when CNNTrain supports packages //should be removed when CNNTrain supports packages
...@@ -309,29 +305,32 @@ public class EMADLGenerator { ...@@ -309,29 +305,32 @@ public class EMADLGenerator {
Path modelPath = Paths.get(getModelsPath() + Joiner.on("/").join(names.subList(0,names.size()-1))); Path modelPath = Paths.get(getModelsPath() + Joiner.on("/").join(names.subList(0,names.size()-1)));
// //
CNNTrainGenerator cnnTrainGenerator = new CNNTrainGenerator(); //CNNTrainGenerator cnnTrainGenerator = new CNNTrainGenerator(); //No need of cnnTrainGenerator since cnnArchGenerator can also generateTrainer()
final ModelPath mp = new ModelPath(modelPath); final ModelPath mp = new ModelPath(modelPath);
GlobalScope trainScope = new GlobalScope(mp, new CNNTrainLanguage()); GlobalScope trainScope = new GlobalScope(mp, new CNNTrainLanguage());
Map.Entry<String, String> fileContents = cnnTrainGenerator.generateFileContent( trainScope, configFilename ); Optional<CNNTrainCompilationUnitSymbol> compilationUnit = trainScope.resolve(configFilename, CNNTrainCompilationUnitSymbol.KIND);
return fileContents.getValue(); if (!compilationUnit.isPresent()){
Log.error("CNNTrainCompilationUnitSymbol is empty. Could not resolve configuration " + configFilename);
System.exit(1);
}
CNNTrainCocos.checkAll(compilationUnit.get());
ConfigurationSymbol configuration = compilationUnit.get().getConfiguration();
return configuration;
} }
protected String processTemplate(Map<String, Object> ftlContext, String templateNameWithoutEnding){ public String readResource(final String fileName, Charset charset) {
StringWriter writer = new StringWriter(); try {
String templateName = templateNameWithoutEnding + ".ftl"; return Resources.toString(Resources.getResource(fileName), charset);
try{
Template template = TemplateConfiguration.get().getTemplate(templateName); } catch (IllegalArgumentException e) {
template.process(ftlContext, writer); System.err.println("Resource file " + fileName + " not found");
}
catch (IOException e) {
Log.error("Freemarker could not find template " + templateName + " :\n" + e.getMessage());
System.exit(1); System.exit(1);
} return null;
catch (TemplateException e){ } catch (IOException e) {
Log.error("An exception occured in template " + templateName + " :\n" + e.getMessage()); System.err.println("IO Error occurred");
System.exit(1); System.exit(1);
return null;
} }
return writer.toString();
} }
} }
...@@ -26,6 +26,7 @@ import freemarker.template.TemplateException; ...@@ -26,6 +26,7 @@ import freemarker.template.TemplateException;
import org.apache.commons.cli.*; import org.apache.commons.cli.*;
import java.io.IOException; import java.io.IOException;
import java.util.Optional;
public class EMADLGeneratorCli { public class EMADLGeneratorCli {
...@@ -49,6 +50,12 @@ public class EMADLGeneratorCli { ...@@ -49,6 +50,12 @@ public class EMADLGeneratorCli {
.hasArg(true) .hasArg(true)
.required(false) .required(false)
.build(); .build();
public static final Option OPTION_BACKEND = Option.builder("b")
.longOpt("backend")
.desc("deep-learning-framework backend. Options: MXNET, CAFFE2")
.hasArg(true)
.required(false)
.build();
private EMADLGeneratorCli() { private EMADLGeneratorCli() {
} }
...@@ -67,6 +74,7 @@ public class EMADLGeneratorCli { ...@@ -67,6 +74,7 @@ public class EMADLGeneratorCli {
options.addOption(OPTION_MODELS_PATH); options.addOption(OPTION_MODELS_PATH);
options.addOption(OPTION_ROOT_MODEL); options.addOption(OPTION_ROOT_MODEL);
options.addOption(OPTION_OUTPUT_PATH); options.addOption(OPTION_OUTPUT_PATH);
options.addOption(OPTION_BACKEND);
return options; return options;
} }
...@@ -85,7 +93,21 @@ public class EMADLGeneratorCli { ...@@ -85,7 +93,21 @@ public class EMADLGeneratorCli {
private static void runGenerator(CommandLine cliArgs) { private static void runGenerator(CommandLine cliArgs) {
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt()); String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt()); String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
EMADLGenerator generator = new EMADLGenerator(); String backendString = cliArgs.getOptionValue(OPTION_BACKEND.getOpt());
final String DEFAULT_BACKEND = "MXNET";
if (backendString == null) {
Log.warn("backend not specified. backend set to default value " + DEFAULT_BACKEND);
backendString = DEFAULT_BACKEND;
}
Optional<Backend> backend = Backend.getBackendFromString(backendString);
if (!backend.isPresent()){
Log.warn("specified backend " + backendString + " not supported. backend set to default value " + DEFAULT_BACKEND);
backend = Backend.getBackendFromString(DEFAULT_BACKEND);
}
EMADLGenerator generator = new EMADLGenerator(backend.get());
if (outputPath != null){ if (outputPath != null){
generator.setGenerationTargetPath(outputPath); generator.setGenerationTargetPath(outputPath);
} }
......