Commit e785fbc2 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'oneclick_nn_training' into 'master'

Oneclick nn training

Closes #7

See merge request !21
parents abbd6de0 59972bb9
Pipeline #110151 failed with stages
in 7 minutes and 38 seconds
......@@ -22,29 +22,44 @@
stages:
- windows
- linux
- deploy
masterJobLinux:
stage: linux
stage: deploy
image: maven:3-jdk-8
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml -DskipTests
- cat target/site/jacoco/index.html
- mvn package sonar:sonar -s settings.xml
only:
- master
integrationMXNetJobLinux:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.3
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationMXNetTest
integrationCaffe2JobLinux:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/caffe2:v0.0.3
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationCaffe2Test
masterJobWindows:
stage: windows
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest="GenerationTest,SymtabTest"
tags:
- Windows10
BranchJobLinux:
UnitTestJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest="GenerationTest,SymtabTest*"
- cat target/site/jacoco/index.html
except:
- master
{
"configurations": [
{
"type": "java",
"name": "CodeLens (Launch) - EMADLGeneratorCli",
"request": "launch",
"mainClass": "de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli",
"projectName": "embedded-montiarc-emadl-generator"
}
]
}
\ No newline at end of file
......@@ -2,4 +2,84 @@
![coverage](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/EMADL2CPP/badges/master/coverage.svg)
# EMADL2CPP
Generates CPP/Python code for EmbeddedMontiArcDL.
See example project [EMADL-Demo](https://git.rwth-aachen.de/thomas.timmermanns/EMADL-Demo) for more information on how the generated code can be used.
\ No newline at end of file
See example project [EMADL-Demo](https://git.rwth-aachen.de/thomas.timmermanns/EMADL-Demo) for more information on how the generated code can be used.
[How to develop and train a CNN component using EMADL2CPP](#nn)
<a name="nn"></a>
# Development and training of a CNN component using EMADL2CPP
## Prerequisites
* Linux. Ubuntu Linux 16.04 and 18.04 were used during testing.
* Deep learning backend:
* MXNet
* training - the generated code is Python. Required is Python 2.7 or higher, and the Python packages `h5py` plus `mxnet` (for training on CPU) or e.g. `mxnet-cu75` for CUDA 7.5 (for training on GPU with CUDA; the concrete package should be selected according to the CUDA version). Follow the [official instructions on the MXNet site](https://mxnet.incubator.apache.org/install/index.html?platform=Linux&language=Python&processor=CPU)
* prediction - generated code is C++. Install MXNet using [official instructions on MXNet site](https://mxnet.incubator.apache.org) for C++.
* Caffe2
* training - the generated code is Python. Follow the [official instructions on the Caffe2 site](https://caffe2.ai/docs/getting-started.html?platform=ubuntu&configuration=prebuilt)
* Gluon
### HowTo
1. Define an EMADL component containing the architecture of a neural network and save it in a `.emadl` file. For more information on the architecture language please refer to the [CNNArchLang project](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/languages/CNNArchLang). An example of an NN architecture:
```
component VGG16{
ports in Z(0:255)^{3, 224, 224} image,
out Q(0:1)^{1000} predictions;
implementation CNN {
def conv(filter, channels){
Convolution(kernel=(filter,filter), channels=channels) ->
Relu()
}
def fc(){
FullyConnected(units=4096) ->
Relu() ->
Dropout(p=0.5)
}
image ->
conv(filter=3, channels=64, ->=2) ->
Pooling(pool_type="max", kernel=(2,2), stride=(2,2)) ->
conv(filter=3, channels=128, ->=2) ->
Pooling(pool_type="max", kernel=(2,2), stride=(2,2)) ->
conv(filter=3, channels=256, ->=3) ->
Pooling(pool_type="max", kernel=(2,2), stride=(2,2)) ->
conv(filter=3, channels=512, ->=3) ->
Pooling(pool_type="max", kernel=(2,2), stride=(2,2)) ->
conv(filter=3, channels=512, ->=3) ->
Pooling(pool_type="max", kernel=(2,2), stride=(2,2)) ->
fc() ->
fc() ->
FullyConnected(units=1000) ->
Softmax() ->
predictions
}
}
```
2. Define a training configuration for this network and store it in a `.cnnt` file; the name of the file should be the same as that of the corresponding architecture (e.g. `VGG16.emadl` and `VGG16.cnnt`). For more information on the training configuration language please refer to the [CNNTrainLang project](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/languages/CNNTrainLang). An example of a training configuration:
```
configuration VGG16{
num_epoch:10
batch_size:64
normalize:true
load_checkpoint:false
optimizer:adam{
learning_rate:0.01
learning_rate_decay:0.8
step_size:1000
}
}
```
3. Generate C++ code which uses neural networks that were trained using the specified deep learning backend. The generator receives the following command line parameters:
* `-m` path to directory with EMADL models
* `-r` name of the root model
* `-o` output path
* `-b` backend
* `-p` path to python (Not mandatory; Default is `/usr/bin/python`)
* `-f` forced training (Not mandatory; values can be `y` for a forced training and `n` for a skip (a forced no-training)). By default, the hash value (from the training and test data, the structure of the model (.emadl) and the training parameters (.cnnt) of the model) will be compared. The model is retrained only if the hash changes. This can be used to distribute trained models, by distributing the corresponding `.training_hash` file as well, which will prevent a retraining.
Assuming both the architecture definition `VGG16.emadl` and the corresponding training configuration `VGG16.cnnt` are located in a folder `models` and the target code should be generated in a `target` folder using the `MXNet` backend, an example of a command is then:
```java -jar embedded-montiarc-emadl-generator-0.2.10-jar-with-dependencies.jar -m models -r VGG16 -o target -b MXNET```
......@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-emadl-generator</artifactId>
<version>0.2.8</version>
<version>0.2.10</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -17,8 +17,8 @@
<!-- .. SE-Libraries .................................................. -->
<emadl.version>0.2.6</emadl.version>
<CNNTrain.version>0.2.6</CNNTrain.version>
<cnnarch-mxnet-generator.version>0.2.12</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>0.2.9</cnnarch-caffe2-generator.version>
<cnnarch-mxnet-generator.version>0.2.14-SNAPSHOT</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>0.2.10-SNAPSHOT</cnnarch-caffe2-generator.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. -->
......
......@@ -39,7 +39,7 @@
<id>se-nexus</id>
<username>cibuild</username>
<password>${env.cibuild}</password>
</server>
</server>
<server>
<id>github</id>
......@@ -53,7 +53,7 @@
<id>se-nexus</id>
<mirrorOf>external:*</mirrorOf>
<url>https://nexus.se.rwth-aachen.de/content/groups/public</url>
</mirror>
</mirror>
</mirrors>
<profiles>
......@@ -99,5 +99,5 @@
<activeProfiles>
<activeProfile>se-nexus</activeProfile>
</activeProfiles>
</settings>
\ No newline at end of file
</activeProfiles>
</settings>
......@@ -47,4 +47,13 @@ public enum Backend {
return Optional.empty();
}
}
/**
 * Returns the canonical string name for the given backend.
 * Any backend other than CAFFE2 maps to "MXNET".
 */
public static String getBackendString(Backend backend) {
    return backend == CAFFE2 ? "CAFFE2" : "MXNET";
}
}
......@@ -28,6 +28,7 @@ import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.cncModel
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.math._symboltable.MathStatementsSymbol;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.DataPathConfigParser;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
......@@ -44,24 +45,32 @@ import de.se_rwth.commons.Splitters;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import java.io.IOException;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.DigestInputStream;
import javax.xml.bind.DatatypeConverter;
public class EMADLGenerator {
private GeneratorEMAMOpt2CPP emamGen;
private CNNArchGenerator cnnArchGenerator;
private CNNTrainGenerator cnnTrainGenerator;
private Backend backend;
private String modelsPath;
public EMADLGenerator(Backend backend) {
this.backend = backend;
emamGen = new GeneratorEMAMOpt2CPP();
emamGen.useArmadilloBackend();
emamGen.setGenerationTargetPath("./target/generated-sources-emadl/");
......@@ -99,7 +108,7 @@ public class EMADLGenerator {
return emamGen;
}
public void generate(String modelPath, String qualifiedName) throws IOException, TemplateException {
public void generate(String modelPath, String qualifiedName, String pythonPath, String forced) throws IOException, TemplateException {
setModelsPath( modelPath );
TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath());
EMAComponentSymbol component = symtab.<EMAComponentSymbol>resolve(qualifiedName, EMAComponentSymbol.KIND).orElse(null);
......@@ -115,20 +124,194 @@ public class EMADLGenerator {
EMAComponentInstanceSymbol instance = component.getEnclosingScope().<EMAComponentInstanceSymbol>resolve(instanceName, EMAComponentInstanceSymbol.KIND).get();
generateFiles(symtab, instance, symtab);
generateFiles(symtab, instance, symtab, pythonPath, forced);
try{
executeCommands();
}catch(Exception e){
System.out.println(e);
}
}
/**
 * Builds the generated sources by writing a temporary bash script
 * (cd target, cmake, make) and running it with bash.
 *
 * @throws IOException if the temporary script cannot be created
 */
public void executeCommands() throws IOException {
    File tempScript = createTempScript();
    try {
        ProcessBuilder pb = new ProcessBuilder("bash", tempScript.toString());
        // Forward the build's stdout/stderr to this process so the user sees progress.
        pb.inheritIO();
        Process process = pb.start();
        process.waitFor();
    } catch (InterruptedException e) {
        // Restore the interrupt flag so callers can still observe the interruption.
        Thread.currentThread().interrupt();
        System.out.println(e);
    } catch (Exception e) {
        System.out.println(e);
    } finally {
        // Best-effort cleanup of the temporary build script.
        tempScript.delete();
    }
}
/**
 * Creates a temporary bash script that enters the generation target path
 * and runs an out-of-source CMake build (mkdir build; cmake ..; make).
 *
 * @return the temporary script file (may be empty if writing failed)
 * @throws IOException if the temporary file cannot be created
 */
public File createTempScript() throws IOException {
    File tempScript = File.createTempFile("script", null);
    // try-with-resources guarantees the writer is closed even if a println fails;
    // the original code leaked the stream on any exception before close().
    try (PrintWriter printWriter = new PrintWriter(
            new OutputStreamWriter(new FileOutputStream(tempScript)))) {
        printWriter.println("#!/bin/bash");
        printWriter.println("cd " + getGenerationTargetPath());
        printWriter.println("mkdir --parents build");
        printWriter.println("cd build");
        printWriter.println("cmake ..");
        printWriter.println("make");
    } catch (Exception e) {
        System.out.println(e);
    }
    return tempScript;
}
public void generateFiles(TaggingResolver taggingResolver, EMAComponentInstanceSymbol EMAComponentSymbol, Scope symtab) throws IOException {
List<FileContent> fileContents = generateStrings(taggingResolver, EMAComponentSymbol, symtab);
/**
 * Computes the MD5 checksum of a file as an uppercase hex string.
 * Used to detect changes in models, training configs and data sets
 * so that unnecessary retraining can be skipped.
 *
 * @param filePath path of the file to hash
 * @return uppercase hex MD5 digest, or the sentinel
 *         "No_Such_Algorithm_Exception" if MD5 is unavailable
 * @throws IOException if the file cannot be read
 */
public String getChecksumForFile(String filePath) throws IOException {
    Path path = Paths.get(filePath);
    try {
        MessageDigest md5 = MessageDigest.getInstance("MD5");
        md5.update(Files.readAllBytes(path));
        byte[] digest = md5.digest();
        // Hex conversion done locally; javax.xml.bind.DatatypeConverter was
        // removed from the JDK in Java 11. %02X yields the same uppercase output.
        StringBuilder hex = new StringBuilder(digest.length * 2);
        for (byte b : digest) {
            hex.append(String.format("%02X", b));
        }
        return hex.toString();
    } catch (NoSuchAlgorithmException e) {
        // MD5 is guaranteed by every compliant JDK, so this is effectively
        // unreachable; the sentinel is kept to preserve the original contract.
        e.printStackTrace();
        return "No_Such_Algorithm_Exception";
    }
}
/**
 * Generates all files for the component instance and then trains every
 * contained CNN whose training hash has changed (or when training is forced).
 * Training is skipped entirely when forced equals "n", and always performed
 * when forced equals "y". After each successful training the new hash is
 * written to a .training_hash file next to the generated sources.
 */
public void generateFiles(TaggingResolver taggingResolver, EMAComponentInstanceSymbol EMAComponentSymbol, Scope symtab, String pythonPath, String forced) throws IOException {
// Collects every component instance reached during generation (filled by generateStrings).
Set<EMAComponentInstanceSymbol> allInstances = new HashSet<>();
List<FileContent> fileContents = generateStrings(taggingResolver, EMAComponentSymbol, symtab, allInstances, forced);
// Write out all generated sources first; training scripts are among them.
for (FileContent fileContent : fileContents) {
emamGen.generateFile(fileContent);
}
// train
Map<String, String> fileContentMap = new HashMap<>();
for(FileContent f : fileContents) {
fileContentMap.put(f.getFileName(), f.getFileContent());
}
// Hash files to emit after training, and hashes trained during this run.
List<FileContent> fileContentsTrainingHashes = new ArrayList<>();
List<String> newHashes = new ArrayList<>();
for (EMAComponentInstanceSymbol componentInstance : allInstances) {
// Only instances that actually contain a CNN architecture are trainable.
Optional<ArchitectureSymbol> architecture = componentInstance.getSpannedScope().resolve("", ArchitectureSymbol.KIND);
if(!architecture.isPresent()) {
continue;
}
// forced == "n" means: never train, regardless of hash state.
if(forced.equals("n")) {
continue;
}
String configFilename = getConfigFilename(componentInstance.getComponentType().getFullName(), componentInstance.getFullName(), componentInstance.getName());
String emadlPath = getModelsPath() + configFilename + ".emadl";
String cnntPath = getModelsPath() + configFilename + ".cnnt";
// Hash model definition and training configuration ...
String emadlHash = getChecksumForFile(emadlPath);
String cnntHash = getChecksumForFile(cnntPath);
String componentConfigFilename = componentInstance.getComponentType().getReferencedSymbol().getFullName().replaceAll("\\.", "/");
String b = backend.getBackendString(backend);
String trainingDataHash = "";
String testDataHash = "";
// ... plus the train/test data; the data layout depends on the backend
// (Caffe2 uses LMDB directories, MXNet uses HDF5 files).
if(b.equals("CAFFE2")){
trainingDataHash = getChecksumForFile(architecture.get().getDataPath() + "/train_lmdb/data.mdb");
testDataHash = getChecksumForFile(architecture.get().getDataPath() + "/test_lmdb/data.mdb");
}else{
trainingDataHash = getChecksumForFile(architecture.get().getDataPath() + "/train.h5");
testDataHash = getChecksumForFile(architecture.get().getDataPath() + "/test.h5");
}
// The combined hash changes whenever model, config or data changes.
String trainingHash = emadlHash + "#" + cnntHash + "#" + trainingDataHash + "#" + testDataHash;
boolean alreadyTrained = newHashes.contains(trainingHash) || isAlreadyTrained(trainingHash, componentInstance);
// Skip retraining for unchanged models unless training is forced ("y").
if(alreadyTrained && !forced.equals("y")) {
Log.warn("Training of model " + componentInstance.getFullName() + " skipped");
}
else {
// CNNTrainer scripts are named after the lower-camel, underscore-separated instance name.
String parsedFullName = componentInstance.getFullName().substring(0, 1).toLowerCase() + componentInstance.getFullName().substring(1).replaceAll("\\.", "_");
String trainerScriptName = "CNNTrainer_" + parsedFullName + ".py";
String trainingPath = getGenerationTargetPath() + trainerScriptName;
if(Files.exists(Paths.get(trainingPath))){
// Run the generated Python trainer and stream its output to the console.
ProcessBuilder pb = new ProcessBuilder(Arrays.asList(pythonPath, trainingPath)).inheritIO();
Process p = pb.start();
int exitCode = 0;
try {
exitCode = p.waitFor();
}
catch(InterruptedException e) {
//throw new Exception("Error: Training aborted" + e.toString());
System.out.println("Error: Training aborted" + e.toString());
continue;
}
if(exitCode != 0) {
//throw new Exception("Error: Training error");
System.out.println("Error: Training failed" + Integer.toString(exitCode));
continue;
}
// Record the hash only after a successful training run.
fileContentsTrainingHashes.add(new FileContent(trainingHash, componentConfigFilename + ".training_hash"));
newHashes.add(trainingHash);
}
else{
System.out.println("Trainingfile " + trainingPath + " not found.");
}
}
}
// Persist the .training_hash files for all models trained in this run.
for (FileContent fileContent : fileContentsTrainingHashes) {
emamGen.generateFile(fileContent);
}
}
/**
 * Converts a byte array into its lowercase hex representation.
 *
 * @param arrayBytes bytes to convert (e.g. a message digest)
 * @return lowercase hex string, two characters per byte
 */
private static String convertByteArrayToHexString(byte[] arrayBytes) {
    // StringBuilder instead of StringBuffer: no synchronization needed for a local;
    // %02x replaces the obscure (b & 0xff) + 0x100 radix trick with the same output.
    StringBuilder sb = new StringBuilder(arrayBytes.length * 2);
    for (byte b : arrayBytes) {
        sb.append(String.format("%02x", b));
    }
    return sb.toString();
}
/**
 * Checks whether the component instance has already been trained with
 * exactly this hash, i.e. whether its .training_hash file in the
 * generation target path contains a line equal to trainingHash.
 * Any I/O problem is treated as "not trained yet".
 */
private boolean isAlreadyTrained(String trainingHash, EMAComponentInstanceSymbol componentInstance) {
    try {
        EMAComponentSymbol component = componentInstance.getComponentType().getReferencedSymbol();
        String relativeName = component.getFullName().replaceAll("\\.", "/");
        Path hashFile = Paths.get(getGenerationTargetPath() + relativeName + ".training_hash");
        if (!Files.exists(hashFile)) {
            return false;
        }
        // Each line of the file is one previously recorded training hash.
        return Files.readAllLines(hashFile).contains(trainingHash);
    }
    catch (Exception e) {
        return false;
    }
}
public List<FileContent> generateStrings(TaggingResolver taggingResolver, EMAComponentInstanceSymbol componentInstanceSymbol, Scope symtab){
public List<FileContent> generateStrings(TaggingResolver taggingResolver, EMAComponentInstanceSymbol componentInstanceSymbol, Scope symtab, Set<EMAComponentInstanceSymbol> allInstances, String forced){
List<FileContent> fileContents = new ArrayList<>();
Set<EMAComponentInstanceSymbol> allInstances = new HashSet<>();
generateComponent(fileContents, allInstances, taggingResolver, componentInstanceSymbol, symtab);
......@@ -177,6 +360,12 @@ public class EMADLGenerator {
EMADLCocos.checkAll(componentInstanceSymbol);
if (architecture.isPresent()){
DataPathConfigParser newParserConfig = new DataPathConfigParser(getModelsPath() + "data_paths.txt");
String dPath = newParserConfig.getDataPath(EMAComponentSymbol.getFullName());
/*String dPath = DataPathConfigParser.getDataPath(getModelsPath() + "data_paths.txt", componentSymbol.getFullName());*/
architecture.get().setDataPath(dPath);
architecture.get().setComponentName(EMAComponentSymbol.getFullName());
generateCNN(fileContents, taggingResolver, componentInstanceSymbol, architecture.get());
}
else if (mathStatements.isPresent()){
......@@ -262,6 +451,32 @@ public class EMADLGenerator {
}
}
/**
 * Resolves the training configuration (.cnnt) filename for a component.
 * Candidates are checked in order of decreasing specificity: the single
 * instance, the component, then the whole system. Logs an error and
 * returns null when none of the three files exists.
 */
private String getConfigFilename(String mainComponentName, String componentFullName, String componentName) {
    String mainComponentConfigFilename = mainComponentName.replaceAll("\\.", "/");
    String componentConfigFilename = componentFullName.replaceAll("\\.", "/");
    String instanceConfigFilename = componentConfigFilename + "_" + componentName;
    // Most specific candidate wins.
    String[] candidates = { instanceConfigFilename, componentConfigFilename, mainComponentConfigFilename };
    for (String candidate : candidates) {
        if (Files.exists(Paths.get(getModelsPath() + candidate + ".cnnt"))) {
            return candidate;
        }
    }
    Log.error("Missing configuration file. " +
            "Could not find a file with any of the following names (only one needed): '"
            + getModelsPath() + instanceConfigFilename + ".cnnt', '"
            + getModelsPath() + componentConfigFilename + ".cnnt', '"
            + getModelsPath() + mainComponentConfigFilename + ".cnnt'." +
            " These files denote respectively the configuration for the single instance, the component or the whole system.");
    return null;
}
public List<FileContent> generateCNNTrainer(Set<EMAComponentInstanceSymbol> allInstances, String mainComponentName) {
List<FileContent> fileContents = new ArrayList<>();
for (EMAComponentInstanceSymbol componentInstance : allInstances) {
......@@ -269,28 +484,7 @@ public class EMADLGenerator {
Optional<ArchitectureSymbol> architecture = component.getSpannedScope().resolve("", ArchitectureSymbol.KIND);
if (architecture.isPresent()) {
String trainConfigFilename;
String mainComponentConfigFilename = mainComponentName.replaceAll("\\.", "/");
String componentConfigFilename = component.getFullName().replaceAll("\\.", "/");
String instanceConfigFilename = component.getFullName().replaceAll("\\.", "/") + "_" + component.getName();
if (Files.exists(Paths.get( getModelsPath() + instanceConfigFilename + ".cnnt"))) {
trainConfigFilename = instanceConfigFilename;
}
else if (Files.exists(Paths.get( getModelsPath() + componentConfigFilename + ".cnnt"))){
trainConfigFilename = componentConfigFilename;
}
else if (Files.exists(Paths.get( getModelsPath() + mainComponentConfigFilename + ".cnnt"))){
trainConfigFilename = mainComponentConfigFilename;
}
else{
Log.error("Missing configuration file. " +
"Could not find a file with any of the following names (only one needed): '"
+ getModelsPath() + instanceConfigFilename + ".cnnt', '"
+ getModelsPath() + componentConfigFilename + ".cnnt', '"
+ getModelsPath() + mainComponentConfigFilename + ".cnnt'." +
" These files denote respectively the configuration for the single instance, the component or the whole system.");
return null;
}
String trainConfigFilename = getConfigFilename(mainComponentName, component.getFullName(), component.getName());
//should be removed when CNNTrain supports packages
List<String> names = Splitter.on("/").splitToList(trainConfigFilename);
......
......@@ -20,11 +20,11 @@
*/
package de.monticore.lang.monticar.emadl.generator;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.apache.commons.cli.*;
import java.io.File;
import java.io.IOException;
import java.util.Optional;
......@@ -57,6 +57,21 @@ public class EMADLGeneratorCli {
.required(false)
.build();
public static final Option OPTION_TRAINING_PYTHON_PATH = Option.builder("p")
.longOpt("python")
.desc("path to python. Default is /usr/bin/python")
.hasArg(true)
.required(false)
.build();
public static final Option OPTION_RESTRAINED_TRAINING = Option.builder("f")
.longOpt("forced")
.desc("no training or a forced training. Options: y (a forced training), n (no training)")
.hasArg(true)
.required(false)
.build();
// Utility class: private constructor prevents instantiation.
private EMADLGeneratorCli() {
}
......@@ -75,6 +90,8 @@ public class EMADLGeneratorCli {
options.addOption(OPTION_ROOT_MODEL);
options.addOption(OPTION_OUTPUT_PATH);
options.addOption(OPTION_BACKEND);
options.addOption(OPTION_RESTRAINED_TRAINING);
options.addOption(OPTION_TRAINING_PYTHON_PATH);
return options;
}
......@@ -94,7 +111,10 @@ public class EMADLGeneratorCli {
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
String backendString = cliArgs.getOptionValue(OPTION_BACKEND.getOpt());
String forced = cliArgs.getOptionValue(OPTION_RESTRAINED_TRAINING.getOpt());
String pythonPath = cliArgs.getOptionValue(OPTION_TRAINING_PYTHON_PATH.getOpt());
final String DEFAULT_BACKEND = "MXNET";
final String DEFAULT_FORCED = "UNSET";
if (backendString == null) {
Log.warn("backend not specified. backend set to default value " + DEFAULT_BACKEND);
......@@ -106,13 +126,25 @@ public class EMADLGeneratorCli {
Log.warn("specified backend " + backendString + " not supported. backend set to default value " + DEFAULT_BACKEND);
backend = Backend.getBackendFromString(DEFAULT_BACKEND);
}
if (pythonPath == null) {
pythonPath = "/usr/bin/python";
}
if (forced == null) {
forced = DEFAULT_FORCED;
}
else if (!forced.equals("y") && !forced.equals("n")) {
Log.error("specified setting ("+forced+") for forcing/preventing training not supported. set to default value " + DEFAULT_FORCED);
forced = DEFAULT_FORCED;
}
EMADLGenerator generator = new EMADLGenerator(backend.get());
if (outputPath != null){
generator.setGenerationTargetPath(outputPath);
}
try{
generator.generate(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()), rootModelName);
generator.generate(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()), rootModelName, pythonPath, forced);
}
catch (IOException e){
Log.error("io error during generation", e);
......
......@@ -20,6 +20,8 @@
*/
package de.monticore.lang.monticar.emadl;
import de.monticore.lang.monticar.emadl.generator.Backend;
import de.monticore.lang.monticar.emadl.generator.EMADLGenerator;
import de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
......@@ -27,10 +29,15 @@ import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;