Commit d8a0221f authored by Abdallah Atouani's avatar Abdallah Atouani

Merge remote-tracking branch 'origin/develop' into feature/path_tagging

# Conflicts:
#	pom.xml
#	src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java
#	src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java
parents cbeed5e9 d9d9cf5d
......@@ -5,6 +5,6 @@ nppBackup
.classpath
.idea
.git
.vscode
*.iml
train.log
......@@ -24,7 +24,8 @@ stages:
- linux
- deploy
masterJobLinux:
git masterJobLinux:
stage: deploy
image: maven:3-jdk-8
script:
......@@ -34,6 +35,7 @@ masterJobLinux:
only:
- master
integrationMXNetJobLinux:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.3
......@@ -47,17 +49,25 @@ integrationCaffe2JobLinux:
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationCaffe2Test
integrationGluonJobLinux:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.3
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationGluonTest
integrationPythonWrapperTest:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2pythonwrapper/tests/mvn-swig:latest
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationPythonWrapperTest
masterJobWindows:
stage: windows
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest="GenerationTest,SymtabTest"
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B -U clean install --settings settings.xml -Dtest="GenerationTest,SymtabTest"
tags:
- Windows10
......
{
"configurations": [
{
"type": "java",
"name": "CodeLens (Launch) - EMADLGeneratorCli",
"request": "launch",
"mainClass": "de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli",
"projectName": "embedded-montiarc-emadl-generator"
}
]
}
\ No newline at end of file
......@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-emadl-generator</artifactId>
<version>0.3.0</version>
<version>0.3.4-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -16,10 +16,11 @@
<!-- .. SE-Libraries .................................................. -->
<emadl.version>0.2.9-SNAPSHOT</emadl.version>
<CNNTrain.version>0.2.6</CNNTrain.version>
<cnnarch-mxnet-generator.version>0.2.14-SNAPSHOT</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>0.2.11-SNAPSHOT</cnnarch-caffe2-generator.version>
<cnnarch-gluon-generator.version>0.1.6</cnnarch-gluon-generator.version>
<CNNTrain.version>0.3.6-SNAPSHOT</CNNTrain.version>
<cnnarch-generator.version>0.0.2-SNAPSHOT</cnnarch-generator.version>
<cnnarch-mxnet-generator.version>0.2.16-SNAPSHOT</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>0.2.12-SNAPSHOT</cnnarch-caffe2-generator.version>
<cnnarch-gluon-generator.version>0.2.5-SNAPSHOT</cnnarch-gluon-generator.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. -->
......@@ -73,6 +74,12 @@
<version>${embedded-montiarc-math-opt-generator}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-generator</artifactId>
<version>${cnnarch-generator.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId>
......
package de.monticore.lang.monticar.emadl.generator;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.generator.CNNTrainGenerator;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2Gluon;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNTrain2Gluon;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNet;
import de.monticore.lang.monticar.cnnarch.caffe2generator.CNNArch2Caffe2;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet;
import de.monticore.lang.monticar.cnnarch.caffe2generator.CNNTrain2Caffe2;
import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator;
import de.monticore.lang.monticar.emadl.generator.reinforcementlearning.RewardFunctionCppGenerator;
import java.util.Optional;
......@@ -40,7 +41,7 @@ public enum Backend {
}
@Override
public CNNTrainGenerator getCNNTrainGenerator() {
return new CNNTrain2Gluon();
return new CNNTrain2Gluon(new RewardFunctionCppGenerator());
}
};
......
......@@ -28,10 +28,15 @@ import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.cncModel
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.embeddedmontiarcdynamic.embeddedmontiarcdynamic._symboltable.cncModel.EMADynamicComponentSymbol;
import de.monticore.lang.math._symboltable.MathStatementsSymbol;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.DataPathConfigParser;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.generator.DataPathConfigParser;
import de.monticore.lang.monticar.cnnarch.generator.CNNTrainGenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.SerialCompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNTrain2Gluon;
import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCoCoChecker;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.emadl._cocos.EMADLCocos;
import de.monticore.lang.monticar.emadl.tagging.dltag.DataPathSymbol;
......@@ -70,7 +75,7 @@ public class EMADLGenerator {
private String modelsPath;
private Map<String, ArchitectureSymbol> processedArchitecture;
public EMADLGenerator(Backend backend) {
this.backend = backend;
......@@ -112,8 +117,21 @@ public class EMADLGenerator {
}
public void generate(String modelPath, String qualifiedName, String pythonPath, String forced, boolean doCompile) throws IOException, TemplateException {
processedArchitecture = new HashMap<>();
setModelsPath( modelPath );
TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath());
EMAComponentInstanceSymbol instance = resolveComponentInstanceSymbol(qualifiedName, symtab);
generateFiles(symtab, instance, symtab, pythonPath, forced);
if (doCompile) {
compile();
}
processedArchitecture = null;
}
private EMAComponentInstanceSymbol resolveComponentInstanceSymbol(String qualifiedName, TaggingResolver symtab) {
EMAComponentSymbol component = symtab.<EMAComponentSymbol>resolve(qualifiedName, EMAComponentSymbol.KIND).orElse(null);
List<String> splitName = Splitters.DOT.splitToList(qualifiedName);
......@@ -125,13 +143,7 @@ public class EMADLGenerator {
System.exit(1);
}
EMAComponentInstanceSymbol instance = component.getEnclosingScope().<EMAComponentInstanceSymbol>resolve(instanceName, EMAComponentInstanceSymbol.KIND).get();
generateFiles(symtab, instance, symtab, pythonPath, forced);
if (doCompile) {
compile();
}
return component.getEnclosingScope().<EMAComponentInstanceSymbol>resolve(instanceName, EMAComponentInstanceSymbol.KIND).get();
}
public void compile() throws IOException {
......@@ -230,12 +242,14 @@ public class EMADLGenerator {
String b = backend.getBackendString(backend);
String trainingDataHash = "";
String testDataHash = "";
if(b.equals("CAFFE2")){
trainingDataHash = getChecksumForFile(architecture.get().getDataPath() + "/train_lmdb/data.mdb");
testDataHash = getChecksumForFile(architecture.get().getDataPath() + "/test_lmdb/data.mdb");
}else{
trainingDataHash = getChecksumForFile(architecture.get().getDataPath() + "/train.h5");
testDataHash = getChecksumForFile(architecture.get().getDataPath() + "/test.h5");
if (architecture.get().getDataPath() != null) {
if (b.equals("CAFFE2")) {
trainingDataHash = getChecksumForFile(architecture.get().getDataPath() + "/train_lmdb/data.mdb");
testDataHash = getChecksumForFile(architecture.get().getDataPath() + "/test_lmdb/data.mdb");
} else {
trainingDataHash = getChecksumForFile(architecture.get().getDataPath() + "/train.h5");
testDataHash = getChecksumForFile(architecture.get().getDataPath() + "/test.h5");
}
}
String trainingHash = emadlHash + "#" + cnntHash + "#" + trainingDataHash + "#" + testDataHash;
......@@ -314,6 +328,7 @@ public class EMADLGenerator {
public List<FileContent> generateStrings(TaggingResolver taggingResolver, EMAComponentInstanceSymbol componentInstanceSymbol, Scope symtab, Set<EMAComponentInstanceSymbol> allInstances, String forced){
List<FileContent> fileContents = new ArrayList<>();
processedArchitecture = new HashMap<>();
generateComponent(fileContents, allInstances, taggingResolver, componentInstanceSymbol, symtab);
......@@ -340,6 +355,7 @@ public class EMADLGenerator {
fixArmadilloImports(fileContents);
processedArchitecture = null;
return fileContents;
}
......@@ -380,10 +396,14 @@ public class EMADLGenerator {
EMADLCocos.checkAll(componentInstanceSymbol);
if (architecture.isPresent()){
cnnArchGenerator.check(architecture.get());
String dPath = getDataPath(taggingResolver, EMAComponentSymbol, componentInstanceSymbol);
architecture.get().setDataPath(dPath);
architecture.get().setComponentName(EMAComponentSymbol.getFullName());
generateCNN(fileContents, taggingResolver, componentInstanceSymbol, architecture.get());
if (processedArchitecture != null) {
processedArchitecture.put(architecture.get().getComponentName(), architecture.get());
}
}
else if (mathStatements.isPresent()){
generateMathComponent(fileContents, taggingResolver, componentInstanceSymbol, mathStatements.get());
......@@ -415,7 +435,7 @@ public class EMADLGenerator {
String component = emamGen.generateString(taggingResolver, instance, (MathStatementsSymbol) null);
FileContent componentFileContent = new FileContent(
transformComponent(component, "CNNPredictor_" + fullName, executeMethod),
transformComponent(component, "CNNPredictor_" + fullName, executeMethod, architecture),
instance);
for (String fileName : contentMap.keySet()){
......@@ -425,18 +445,26 @@ public class EMADLGenerator {
fileContents.add(new FileContent(readResource("CNNTranslator.h", Charsets.UTF_8), "CNNTranslator.h"));
}
protected String transformComponent(String component, String predictorClassName, String executeMethod){
String networkVariableName = "_cnn_";
protected String transformComponent(String component, String predictorClassName, String executeMethod, ArchitectureSymbol architecture){
//insert includes
component = component.replaceFirst("using namespace",
"#include \"" + predictorClassName + ".h" + "\"\n" +
"#include \"CNNTranslator.h\"\n" +
"using namespace");
//insert network attribute
component = component.replaceFirst("public:",
"public:\n" + predictorClassName + " " + networkVariableName + ";");
//insert network attribute for predictor of each network
String networkAttributes = "public:";
int i = 0;
for (SerialCompositeElementSymbol stream : architecture.getStreams()) {
if (stream.isNetwork()) {
networkAttributes += "\n" + predictorClassName + "_" + i + " _predictor_" + i + "_;";
}
++i;
}
component = component.replaceFirst("public:", networkAttributes);
//insert execute method
component = component.replaceFirst("void execute\\(\\)\\s\\{\\s\\}",
......@@ -504,10 +532,46 @@ public class EMADLGenerator {
String trainConfigFilename = getConfigFilename(mainComponentName, component.getFullName(), component.getName());
//should be removed when CNNTrain supports packages
cnnTrainGenerator.setGenerationTargetPath(getGenerationTargetPath());
if (cnnTrainGenerator instanceof CNNTrain2Gluon) {
((CNNTrain2Gluon) cnnTrainGenerator).setRootProjectModelsDir(getModelsPath());
}
List<String> names = Splitter.on("/").splitToList(trainConfigFilename);
trainConfigFilename = names.get(names.size()-1);
Path modelPath = Paths.get(getModelsPath() + Joiner.on("/").join(names.subList(0,names.size()-1)));
ConfigurationSymbol configuration = cnnTrainGenerator.getConfigurationSymbol(modelPath, trainConfigFilename);
// Annotate train configuration with architecture
final String fullConfigName = String.join(".", names);
ArchitectureSymbol correspondingArchitecture = this.processedArchitecture.get(fullConfigName);
assert correspondingArchitecture != null : "No architecture found for train " + fullConfigName + " configuration!";
configuration.setTrainedArchitecture(
new ArchitectureAdapter(correspondingArchitecture.getName(), correspondingArchitecture));
CNNTrainCocos.checkTrainedArchitectureCoCos(configuration);
// Resolve critic network if critic is present
if (configuration.getCriticName().isPresent()) {
String fullCriticName = configuration.getCriticName().get();
int indexOfFirstNameCharacter = fullCriticName.lastIndexOf('.') + 1;
fullCriticName = fullCriticName.substring(0, indexOfFirstNameCharacter)
+ fullCriticName.substring(indexOfFirstNameCharacter, indexOfFirstNameCharacter + 1).toUpperCase()
+ fullCriticName.substring(indexOfFirstNameCharacter + 1);
TaggingResolver symtab = EMADLAbstractSymtab.createSymTabAndTaggingResolver(getModelsPath());
EMAComponentInstanceSymbol instanceSymbol = resolveComponentInstanceSymbol(fullCriticName, symtab);
EMADLCocos.checkAll(instanceSymbol);
Optional<ArchitectureSymbol> critic = instanceSymbol.getSpannedScope().resolve("", ArchitectureSymbol.KIND);
if (!critic.isPresent()) {
Log.error("During the resolving of critic component: Critic component "
+ fullCriticName + " does not have a CNN implementation but is required to have one");
System.exit(-1);
}
critic.get().setComponentName(fullCriticName);
configuration.setCriticNetwork(new ArchitectureAdapter(fullCriticName, critic.get()));
CNNTrainCocos.checkCriticCocos(configuration);
}
cnnTrainGenerator.setInstanceName(componentInstance.getFullName().replaceAll("\\.", "_"));
Map<String, String> fileContentMap = cnnTrainGenerator.generateStrings(configuration);
for (String fileName : fileContentMap.keySet()){
......
......@@ -52,7 +52,7 @@ public class EMADLGeneratorCli {
.build();
public static final Option OPTION_BACKEND = Option.builder("b")
.longOpt("backend")
.desc("deep-learning-framework backend. Options: MXNET, CAFFE2")
.desc("deep-learning-framework backend. Options: MXNET, CAFFE2, GLUON")
.hasArg(true)
.required(false)
.build();
......
package de.monticore.lang.monticar.emadl.generator.reinforcementlearning;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
import de.monticore.lang.monticar.emadl.generator.EMADLAbstractSymtab;
import de.monticore.lang.monticar.generator.cpp.GeneratorEMAMOpt2CPP;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
import de.se_rwth.commons.logging.Log;
import java.io.IOException;
import java.util.Optional;
public class RewardFunctionCppGenerator implements RewardFunctionSourceGenerator{
public RewardFunctionCppGenerator() {
}
@Override
public EMAComponentInstanceSymbol resolveSymbol(TaggingResolver taggingResolver, String rootModel) {
Optional<EMAComponentInstanceSymbol> instanceSymbol = taggingResolver
.<EMAComponentInstanceSymbol>resolve(rootModel, EMAComponentInstanceSymbol.KIND);
if (!instanceSymbol.isPresent()) {
Log.error("Generation of reward function is not possible: Cannot resolve component instance "
+ rootModel);
}
return instanceSymbol.get();
}
@Override
public void generate(EMAComponentInstanceSymbol componentInstanceSymbol, TaggingResolver taggingResolver,
String targetPath) {
GeneratorEMAMOpt2CPP generator = new GeneratorEMAMOpt2CPP();
generator.useArmadilloBackend();
generator.setGenerationTargetPath(targetPath);
try {
generator.generate(componentInstanceSymbol, taggingResolver);
} catch (IOException e) {
Log.error("Generation of reward function is not possible: " + e.getMessage());
}
}
@Override
public void generate(String modelPath, String rootModel, String targetPath) {
TaggingResolver taggingResolver = createTaggingResolver(modelPath);
EMAComponentInstanceSymbol instanceSymbol = resolveSymbol(taggingResolver, rootModel);
generate(instanceSymbol, taggingResolver, targetPath);
}
@Override
public TaggingResolver createTaggingResolver(final String modelPath) {
return EMADLAbstractSymtab.createSymTabAndTaggingResolver(modelPath);
}
}
......@@ -23,9 +23,11 @@ package de.monticore.lang.monticar.emadl;
import de.monticore.lang.monticar.emadl.generator.Backend;
import de.monticore.lang.monticar.emadl.generator.EMADLGenerator;
import de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli;
import de.se_rwth.commons.logging.Finding;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import java.io.IOException;
......@@ -35,14 +37,15 @@ import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import static junit.framework.TestCase.assertEquals;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import org.junit.rules.ExpectedException;
public class GenerationTest extends AbstractSymtabTest {
@Before
public void setUp() {
// ensure an empty log
......@@ -58,19 +61,25 @@ public class GenerationTest extends AbstractSymtabTest {
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(Paths.get("./target/generated-sources-emadl"), Paths.get("./src/test/resources/target_code"),
Arrays.asList("cifar10_cifar10Classifier.cpp", "cifar10_cifar10Classifier.h",
"CNNCreator_cifar10_cifar10Classifier_net.py", "CNNBufferFile.h",
"CNNPredictor_cifar10_cifar10Classifier_net.h", "cifar10_cifar10Classifier_net.h",
"CNNTranslator.h", "cifar10_cifar10Classifier_calculateClass.h",
checkFilesAreEqual(
Paths.get("./target/generated-sources-emadl"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"cifar10_cifar10Classifier.cpp",
"cifar10_cifar10Classifier.h",
"CNNCreator_cifar10_cifar10Classifier_net.py",
"CNNBufferFile.h",
"CNNPredictor_cifar10_cifar10Classifier_net.h",
"cifar10_cifar10Classifier_net.h",
"CNNTranslator.h",
"cifar10_cifar10Classifier_calculateClass.h",
"CNNTrainer_cifar10_cifar10Classifier_net.py"));
}
@Test
public void testSimulatorGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = { "-m", "src/test/resources/models/", "-r", "simulator.MainController", "-b", "MXNET", "-f",
"n", "-c", "n" };
String[] args = {"-m", "src/test/resources/models/", "-r", "simulator.MainController", "-b", "MXNET", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
......@@ -78,7 +87,7 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testAddGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = { "-m", "src/test/resources/models/", "-r", "Add", "-b", "MXNET", "-f", "n", "-c", "n" };
String[] args = {"-m", "src/test/resources/models/", "-r", "Add", "-b", "MXNET", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
......@@ -86,7 +95,7 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testAlexnetGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = { "-m", "src/test/resources/models/", "-r", "Alexnet", "-b", "MXNET", "-f", "n", "-c", "n" };
String[] args = {"-m", "src/test/resources/models/", "-r", "Alexnet", "-b", "MXNET", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
......@@ -94,7 +103,7 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testResNeXtGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = { "-m", "src/test/resources/models/", "-r", "ResNeXt50", "-b", "MXNET", "-f", "n", "-c", "n" };
String[] args = {"-m", "src/test/resources/models/", "-r", "ResNeXt50", "-b", "MXNET", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
......@@ -102,8 +111,7 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testThreeInputGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = { "-m", "src/test/resources/models/", "-r", "ThreeInputCNN_M14", "-b", "MXNET", "-f", "n", "-c",
"n" };
String[] args = {"-m", "src/test/resources/models/", "-r", "ThreeInputCNN_M14", "-b", "MXNET", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().size() == 1);
}
......@@ -111,8 +119,7 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testMultipleOutputsGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = { "-m", "src/test/resources/models/", "-r", "MultipleOutputs", "-b", "MXNET", "-f", "n", "-c",
"n" };
String[] args = {"-m", "src/test/resources/models/", "-r", "MultipleOutputs", "-b", "MXNET", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().size() == 1);
}
......@@ -120,7 +127,7 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testVGGGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = { "-m", "src/test/resources/models/", "-r", "VGG16", "-b", "MXNET", "-f", "n", "-c", "n" };
String[] args = {"-m", "src/test/resources/models/", "-r", "VGG16", "-b", "MXNET", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
......@@ -129,59 +136,152 @@ public class GenerationTest extends AbstractSymtabTest {
public void testMultipleInstances() throws IOException, TemplateException {
try {
Log.getFindings().clear();
String[] args = { "-m", "src/test/resources/models/", "-r", "InstanceTest.MainB", "-b", "MXNET", "-f", "n",
"-c", "n" };
String[] args = {"-m", "src/test/resources/models/", "-r", "InstanceTest.MainB", "-b", "MXNET", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
} catch (Exception e) {
}
catch(Exception e) {
e.printStackTrace();
}
}
@Test
public void testMnistClassifier() throws IOException, TemplateException {
public void testMnistClassifierForCaffe2() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = { "-m", "src/test/resources/models/", "-r", "mnist.MnistClassifier", "-b", "CAFFE2", "-f", "n",
"-c", "n" };
String[] args = {"-m", "src/test/resources/models/", "-r", "mnist.MnistClassifier", "-b", "CAFFE2", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(Paths.get("./target/generated-sources-emadl"), Paths.get("./src/test/resources/target_code"),
Arrays.asList("mnist_mnistClassifier.cpp", "mnist_mnistClassifier.h",
"CNNCreator_mnist_mnistClassifier_net.py", "CNNPredictor_mnist_mnistClassifier_net.h",
"mnist_mnistClassifier_net.h", "CNNTranslator.h", "mnist_mnistClassifier_calculateClass.h",
checkFilesAreEqual(
Paths.get("./target/generated-sources-emadl"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"mnist_mnistClassifier.cpp",
"mnist_mnistClassifier.h",
"CNNCreator_mnist_mnistClassifier_net.py",
"CNNPredictor_mnist_mnistClassifier_net.h",
"mnist_mnistClassifier_net.h",
"CNNTranslator.h",
"mnist_mnistClassifier_calculateClass.h",
"CNNTrainer_mnist_mnistClassifier_net.py"));
}
@Test
public void testMnistClassifierForGluon() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = { "-m", "src/test/resources/models/", "-r", "mnist.MnistClassifier", "-b", "GLUON", "-f", "n",
"-c", "n" };
String[] args = {"-m", "src/test/resources/models/", "-r", "mnist.MnistClassifier", "-b", "GLUON", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(Paths.get("./target/generated-sources-emadl"),
checkFilesAreEqual(
Paths.get("./target/generated-sources-emadl"),
Paths.get("./src/test/resources/target_code/gluon"),
Arrays.asList("CNNBufferFile.h", "CNNNet_mnist_mnistClassifier_net.py", "mnist_mnistClassifier.cpp",
"mnist_mnistClassifier.h", "CNNCreator_mnist_mnistClassifier_net.py",
"CNNPredictor_mnist_mnistClassifier_net.h", "CNNDataLoader_mnist_mnistClassifier_net.py",
"supervised_trainer.py", "mnist_mnistClassifier_net.h", "HelperA.h", "CNNTranslator.h",
"mnist_mnistClassifier_calculateClass.h", "CNNTrainer_mnist_mnistClassifier_net.py",
Arrays.asList(
"CNNBufferFile.h",
"CNNNet_mnist_mnistClassifier_net.py",
"mnist_mnistClassifier.cpp",
"mnist_mnistClassifier.h",
"CNNCreator_mnist_mnistClassifier_net.py",
"CNNPredictor_mnist_mnistClassifier_net.h",
"CNNDataLoader_mnist_mnistClassifier_net.py",
"CNNSupervisedTrainer_mnist_mnistClassifier_net.py",
"mnist_mnistClassifier_net.h",
"HelperA.h",
"CNNTranslator.h",
"mnist_mnistClassifier_calculateClass.h",
"CNNTrainer_mnist_mnistClassifier_net.py",
"mnist_mnistClassifier_net.h"));
}
@Test
public void testInvariantForGluon() throws IOException