Commit 646a35d3 authored by Thomas Michael Timmermanns's avatar Thomas Michael Timmermanns
Browse files

Refactoring. Added Tests.

parent 7fb61c34
......@@ -3,42 +3,38 @@
[![Build Status](https://circleci.com/gh/EmbeddedMontiArc/EmbeddedMontiArcDL/tree/master.svg?style=shield&circle-token=:circle-token)](https://circleci.com/gh/EmbeddedMontiArc/EmbeddedMontiArcDL/tree/master)
# EmbeddedMontiArcDL
**work in progress**
##Examples
```
package mnist;
Embeds [CNNArch](https://github.com/EmbeddedMontiArc/CNNArchLang), CNNTrain and MontiMath into EmbeddedMontiArc.
component SimpleCNN{
## Examples
In the following, we list common CNN architectures that are modeled inside an EMA component.
```
component LeNet{
ports in Z(0:255)^{1,28,28} data,
out Q(0:1)^{10} predictions;
implementation CNN {
data ->
Convolution(kernel=(5,5), channels=20) ->
Tanh() ->
Pooling(pool_type="max", kernel=(2,2), stride=(2,2)) ->
Convolution(kernel=(5,5), channels=20) ->
Convolution(kernel=(5,5), channels=50) ->
Tanh() ->
Pooling(pool_type="max", kernel=(2,2), stride=(2,2)) ->
FullyConnected(units=1000) ->
FullyConnected(units=500) ->
Tanh() ->
Dropout() ->
FullyConnected(units=10) ->
Softmax() ->
predictions
}
}
```
```
component VGG16{
component VGG16(Z(2:oo) classes){
ports in Z(0:255)^{3, 224, 224} image,
out Q(0:1)^{1000} predictions;
out Q(0:1)^{classes} predictions;
implementation CNN {
def conv(filter, channels){
Convolution(kernel=(filter,filter), channels=channels) ->
Relu()
......@@ -48,7 +44,6 @@ component VGG16{
Relu() ->
Dropout(p=0.5)
}
image ->
conv(filter=3, channels=64, ->=2) ->
Pooling(pool_type="max", kernel=(2,2), stride=(2,2)) ->
......@@ -62,17 +57,16 @@ component VGG16{
Pooling(pool_type="max", kernel=(2,2), stride=(2,2)) ->
fc() ->
fc() ->
FullyConnected(units=1000) ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
}
}
```
```
component ResNet34{
component ResNet34(Z(2:oo) classes){
ports in Z(0:255)^{3, 224, 224} image,
out Q(0:1)^{1000} predictions;
out Q(0:1)^{classes} predictions;
implementation CNN {
def conv(filter, channels, stride=1, act=true){
......@@ -106,20 +100,18 @@ component ResNet34{
resLayer(channels=512, stride=2) ->
resLayer(channels=512, ->=2) ->
GlobalPooling(pool_type="avg") ->
FullyConnected(units=1000) ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
}
}
```
```
component Alexnet{
component Alexnet(Z(2:oo) classes){
ports in Z(0:255)^{3, 224, 224} image,
out Q(0:1)^{10} predictions;
out Q(0:1)^{classes} predictions;
implementation CNN {
def split1(i){
[i] ->
Convolution(kernel=(5,5), channels=128) ->
......@@ -140,7 +132,6 @@ component Alexnet{
Relu() ->
Dropout()
}
image ->
Convolution(kernel=(11,11), channels=96, stride=(4,4), padding="no_loss") ->
Lrn(nsize=5, alpha=0.0001, beta=0.75) ->
......@@ -155,21 +146,18 @@ component Alexnet{
split2(i=[0|1]) ->
Concatenate() ->
fc(->=2) ->
FullyConnected(units=10) ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
}
}
```
```
component ResNeXt50{
component ResNeXt50(Z(2:oo) classes){
ports in Z(0:255)^{3, 224, 224} image,
out Q(0:1)^{1000} predictions;
out Q(0:1)^{classes} predictions;
implementation CNN {
def conv(filter, channels, stride=1, act=true){
Convolution(kernel=filter, channels=channels, stride=(stride,stride)) ->
BatchNorm() ->
......@@ -209,7 +197,7 @@ component ResNeXt50{
resLayer(innerChannels=32, outChannels=2048, stride=2) ->
resLayer(innerChannels=32, outChannels=2048, -> = 2) ->
GlobalPooling(pool_type="avg") ->
FullyConnected(units=1000) ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
}
......
File added
......@@ -61,7 +61,7 @@ public class PortArraySymbol2IODeclarationSymbol extends IODeclarationSymbol
if (shape.size() >= 3){
type.setWidthIndex(2);
}
type.setElementType(getElementType(ps));
type.setDomain(getElementType(ps));
type.setDimensionSymbols(shape);
setType(type);
......
......@@ -45,7 +45,7 @@ import java.nio.file.Paths;
import java.util.Arrays;
public class AbstractSymtab {
protected static TaggingResolver createSymTabAndTaggingResolver(String... modelPath) {
public static TaggingResolver createSymTabAndTaggingResolver(String... modelPath) {
Scope scope = createSymTab(modelPath);
TaggingResolver tagging = new TaggingResolver(scope, Arrays.asList(modelPath));
TagMinMaxTagSchema.registerTagTypes(tagging);
......
......@@ -52,18 +52,16 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class Generator {
public class EMADLGenerator {
public static final String CNN_HELPER = "CNNTranslator";
public static final String CNN_TRAINER = "CNNTrainer";
private GeneratorCPP emamGen;
public Generator() {
public EMADLGenerator() {
emamGen = new GeneratorCPP();
emamGen.useArmadilloBackend();
emamGen.setGenerationTargetPath("./target/generated-sources-emadl/");
......@@ -91,11 +89,38 @@ public class Generator {
return emamGen;
}
public void generate(String modelPath, String qualifiedName) throws IOException, TemplateException {
setModelsPath( modelPath );
TaggingResolver symtab = AbstractSymtab.createSymTabAndTaggingResolver(getModelsPath());
ComponentSymbol component = symtab.<ComponentSymbol>resolve(qualifiedName, ComponentSymbol.KIND).orElse(null);
List<String> splitName = Splitters.DOT.splitToList(qualifiedName);
String componentName = splitName.get(splitName.size() - 1);
String instanceName = componentName.substring(0, 1).toLowerCase() + componentName.substring(1);
if (component == null){
Log.error("Component with name '" + componentName + "' does not exist.");
System.exit(1);
}
ExpandedComponentInstanceSymbol instance = component.getEnclosingScope().<ExpandedComponentInstanceSymbol>resolve(instanceName, ExpandedComponentInstanceSymbol.KIND).get();
generateFiles(symtab, instance, symtab);
}
public void generateFiles(TaggingResolver taggingResolver, ExpandedComponentInstanceSymbol componentSymbol, Scope symtab) throws IOException {
List<FileContent> fileContents = generateStrings(taggingResolver, componentSymbol, symtab);
for (FileContent fileContent : fileContents) {
emamGen.generateFile(fileContent);
}
}
public List<FileContent> generateStrings(TaggingResolver taggingResolver, ExpandedComponentInstanceSymbol componentInstanceSymbol, Scope symtab){
List<FileContent> fileContents = new ArrayList<>();
Set<ExpandedComponentInstanceSymbol> allInstances = new HashSet<>();
generateStrings(fileContents, allInstances, taggingResolver, componentInstanceSymbol, symtab);
generateComponent(fileContents, allInstances, taggingResolver, componentInstanceSymbol, symtab);
fileContents.add(generateCNNTrainer(allInstances, componentInstanceSymbol.getComponentType().getFullName().replaceAll("\\.", "_")));
fileContents.add(ArmadilloHelper.getArmadilloHelperFileContent());
......@@ -113,11 +138,11 @@ public class Generator {
return fileContents;
}
protected void generateStrings(List<FileContent> fileContents,
Set<ExpandedComponentInstanceSymbol> allInstances,
TaggingResolver taggingResolver,
ExpandedComponentInstanceSymbol componentInstanceSymbol,
Scope symtab){
protected void generateComponent(List<FileContent> fileContents,
Set<ExpandedComponentInstanceSymbol> allInstances,
TaggingResolver taggingResolver,
ExpandedComponentInstanceSymbol componentInstanceSymbol,
Scope symtab){
allInstances.add(componentInstanceSymbol);
ASTComponent astComponent = (ASTComponent) componentInstanceSymbol.getComponentType().getReferencedSymbol().getAstNode().get();
......@@ -218,7 +243,7 @@ public class Generator {
Log.info(generateComponentInstance + "", "Bool:");
}
if (generateComponentInstance) {
generateStrings(fileContents, allInstances, taggingResolver, instanceSymbol, symtab);
generateComponent(fileContents, allInstances, taggingResolver, instanceSymbol, symtab);
}
}
}
......@@ -252,9 +277,9 @@ public class Generator {
private String getTrainingParamsForComponent(String mainComponentName, ComponentSymbol component, ExpandedComponentInstanceSymbol instance) {
String configFilename;
String mainComponentConfigFilename = mainComponentName + "Config";
String componentConfigFilename = component.getFullName().replaceAll("\\.", "/") + "Config";
String instanceConfigFilename = component.getFullName().replaceAll("\\.", "/") + "_" + instance.getName() + "Config";
String mainComponentConfigFilename = mainComponentName.replaceAll("\\.", "/");
String componentConfigFilename = component.getFullName().replaceAll("\\.", "/");
String instanceConfigFilename = component.getFullName().replaceAll("\\.", "/") + "_" + instance.getName();
if (Files.exists(Paths.get( getModelsPath() + instanceConfigFilename + ".cnnt"))) {
configFilename = instanceConfigFilename;
}
......@@ -285,33 +310,6 @@ public class Generator {
return fileContents.getValue();
}
public void generateFiles(TaggingResolver taggingResolver, ExpandedComponentInstanceSymbol componentSymbol, Scope symtab) throws IOException {
List<FileContent> fileContents = generateStrings(taggingResolver, componentSymbol, symtab);
for (FileContent fileContent : fileContents) {
emamGen.generateFile(fileContent);
}
}
public void generate(String modelPath, String qualifiedName) throws IOException, TemplateException {
setModelsPath( modelPath );
TaggingResolver symtab = AbstractSymtab.createSymTabAndTaggingResolver(getModelsPath());
ComponentSymbol component = symtab.<ComponentSymbol>resolve(qualifiedName, ComponentSymbol.KIND).orElse(null);
List<String> splitName = Splitters.DOT.splitToList(qualifiedName);
String componentName = splitName.get(splitName.size() - 1);
String instanceName = componentName.substring(0, 1).toLowerCase() + componentName.substring(1);
if (component == null){
Log.error("Component with name '" + componentName + "' does not exist.");
System.exit(1);
}
ExpandedComponentInstanceSymbol instance = component.getEnclosingScope().<ExpandedComponentInstanceSymbol>resolve(instanceName, ExpandedComponentInstanceSymbol.KIND).get();
generateFiles(symtab, instance, symtab);
}
protected String processTemplate(Map<String, Object> ftlContext, String templateNameWithoutEnding){
StringWriter writer = new StringWriter();
String templateName = templateNameWithoutEnding + ".ftl";
......
......@@ -27,11 +27,11 @@ import org.apache.commons.cli.*;
import java.io.IOException;
public class GeneratorCli{
public class EMADLGeneratorCli {
public static final Option OPTION_MODELS_PATH = Option.builder("m")
.longOpt("models-dir")
.desc("full path to directory with EMAM models e.g. C:\\Users\\vpupkin\\proj\\MyAwesomeAutopilot\\src\\main\\emam")
.desc("full path to directory with EMADL models e.g. C:\\Users\\vpupkin\\proj\\MyAwesomeAutopilot\\src\\main\\emam")
.hasArg(true)
.required(true)
.build();
......@@ -50,7 +50,7 @@ public class GeneratorCli{
.required(false)
.build();
private GeneratorCli() {
private EMADLGeneratorCli() {
}
public static void main(String[] args) {
......@@ -85,7 +85,7 @@ public class GeneratorCli{
private static void runGenerator(CommandLine cliArgs) {
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
Generator generator = new Generator();
EMADLGenerator generator = new EMADLGenerator();
if (outputPath != null){
generator.setGenerationTargetPath(outputPath);
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.emadl.helper;
/*
contains toString functions which translate AST nodes into a String for generation
*/
public interface ASTPrinter {
/*static String toString(ASTNumber number) {
if (number.getUnitNumber().isPresent()) {
return number.getUnitNumber().get().toString();
} else if (number.getFloatPointUnitNumber().isPresent()) {
return number.getFloatPointUnitNumber().get().getTFloatPointUnitNumber();
} else if (number.getComplexNumber().isPresent()) {
return number.getComplexNumber().get().toString();
} else if (number.getHexUnitNumber().isPresent()) {
return number.getHexUnitNumber().get().getTHexUnitNumber();
} else {
return null;
}
}*/
/*static String toString(ASTTuple tuple) {
List<String> stringList = new LinkedList<>();
for (ASTNumber number : tuple.getValues()) {
stringList.add(toString(number));
}
String res = Joiners.COMMA.join(stringList);
res = "(" + res + ")";
return res;
}*/
/*static String toString(ASTArgumentRhs rhs) {
if (rhs == null){
return null;
}
if (rhs.getType().isPresent()) {
return rhs.getType().get().getName().get();
}
else if(rhs.getNumber().isPresent()) {
return toString(rhs.getNumber().get());
}
else if(rhs.getTuple().isPresent()) {
return toString(rhs.getTuple().get());
}
else if (rhs.getBooleanVal().isPresent()){
return rhs.getBooleanVal().get().name().toLowerCase();
}
else {
//should never be reached
return null;
}
}
static String toString(ASTParameterRhs rhs) {
if (rhs == null){
return null;
}
if (rhs.getStringVal().isPresent()) {
return rhs.getStringVal().get();
}
else if(rhs.getNumber().isPresent()) {
return toString(rhs.getNumber().get());
}
else if (rhs.getBooleanVal().isPresent()){
return rhs.getBooleanVal().get();
}
else {
return rhs.getRef().get();
}
}*/
}
......@@ -20,20 +20,27 @@
*/
package de.monticore.lang.monticar.emadl;
import de.monticore.lang.monticar.emadl.generator.Generator;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ComponentSymbol;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ExpandedComponentInstanceSymbol;
import de.monticore.lang.monticar.emadl.generator.AbstractSymtab;
import de.monticore.lang.monticar.emadl.generator.EMADLGenerator;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
import de.se_rwth.commons.Splitters;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Scanner;
import static de.monticore.lang.monticar.emadl.ParserTest.ENABLE_FAIL_QUICK;
import static junit.framework.TestCase.assertEquals;
import static junit.framework.TestCase.assertTrue;
import static de.se_rwth.commons.logging.Log.getFindings;
public class GenerationTest {
......@@ -45,21 +52,86 @@ public class GenerationTest {
}
private void generate(String qualifiedName) throws IOException, TemplateException{
Generator gen = new Generator();
gen.generate("src/test/resources/", qualifiedName);
EMADLGenerator gen = new EMADLGenerator();
gen.generate("src/test/resources/models/", qualifiedName);
}
@Test
public void testMnistGeneration() throws IOException, TemplateException {
generate("mnist.MnistClassifier");
assertTrue(Log.getFindings().isEmpty());
private List<FileContent> generateStrings(String modelsDirPath, String qualifiedName) throws IOException, TemplateException {
EMADLGenerator gen = new EMADLGenerator();
gen.setModelsPath( modelsDirPath );
TaggingResolver symtab = AbstractSymtab.createSymTabAndTaggingResolver(gen.getModelsPath());
ComponentSymbol component = symtab.<ComponentSymbol>resolve(qualifiedName, ComponentSymbol.KIND).orElse(null);
List<String> splitName = Splitters.DOT.splitToList(qualifiedName);
String componentName = splitName.get(splitName.size() - 1);
String instanceName = componentName.substring(0, 1).toLowerCase() + componentName.substring(1);
if (component == null){
Log.error("Component with name '" + componentName + "' does not exist.");
System.exit(1);
}
ExpandedComponentInstanceSymbol instance = component.getEnclosingScope().<ExpandedComponentInstanceSymbol>resolve(instanceName, ExpandedComponentInstanceSymbol.KIND).get();
return gen.generateStrings(symtab, instance, symtab);
}
private String readFileFromResources(String relativePath) throws IOException{
ClassLoader classLoader = getClass().getClassLoader();
File file = new File(classLoader.getResource(relativePath).getFile());
Scanner scanner = new Scanner(file);
scanner.useDelimiter("\\Z");
String content = scanner.next() + "\n";
scanner.close();
return content;
}
@Test
@Ignore
public void testCifar10Generation() throws IOException, TemplateException {
generate("cifar10.Cifar10Classifier");
//generate("cifar10.Cifar10Classifier");
Log.getFindings().clear();
List<FileContent> fileContents = generateStrings(
"src/test/resources/models/",
"cifar10.Cifar10Classifier");
assertTrue(Log.getFindings().isEmpty());
for (FileContent fileContent : fileContents){
switch (fileContent.getFileName()){
case "cifar10_cifar10Classifier.h":
assertEquals(fileContent.getFileContent(),
readFileFromResources("target_code/cifar10_cifar10Classifier.h"));
break;
case "CNNCreator_cifar10_cifar10Classifier_net.py":
assertEquals(fileContent.getFileContent(),
readFileFromResources("target_code/CNNCreator_cifar10_cifar10Classifier_net.py"));
break;
case "CNNBufferFile.h":
assertEquals(fileContent.getFileContent(),
readFileFromResources("target_code/CNNBufferFile.h"));
break;
case "CNNPredictor_cifar10_cifar10Classifier_net.h":
assertEquals(fileContent.getFileContent(),
readFileFromResources("target_code/CNNPredictor_cifar10_cifar10Classifier_net.h"));
break;
case "cifar10_cifar10Classifier_net.h":
assertEquals(fileContent.getFileContent(),
readFileFromResources("target_code/cifar10_cifar10Classifier_net.h"));
break;
case "CNNTranslator.h":
assertEquals(fileContent.getFileContent(),
readFileFromResources("target_code/CNNTranslator.h"));
break;
case "cifar10_cifar10Classifier_calculateClass.h":
assertEquals(fileContent.getFileContent(),
readFileFromResources("target_code/cifar10_cifar10Classifier_calculateClass.h"));
break;
case "CNNTrainer_cifar10_Cifar10Classifier.py":
assertEquals(fileContent.getFileContent(),
readFileFromResources("target_code/CNNTrainer_cifar10_Cifar10Classifier.py"));
break;
}
}
}
@Test
......
......@@ -22,8 +22,8 @@ package de.monticore.lang.monticar.emadl;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.ExpandedComponentInstanceSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeLayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.MethodLayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
import org.junit.Before;
......@@ -47,7 +47,7 @@ public class InstanceTest extends AbstractSymtabTest {
@Test
public void testInstances(){
Scope symtab = createSymTab("src/test/resources/");
Scope symtab = createSymTab("src/test/resources/models/");
ExpandedComponentInstanceSymbol mainInstance = symtab.<ExpandedComponentInstanceSymbol>
resolve("InstanceTest.mainB", ExpandedComponentInstanceSymbol.KIND).get();
ExpandedComponentInstanceSymbol net1 = mainInstance.getSpannedScope().<ExpandedComponentInstanceSymbol>
......@@ -64,8 +64,8 @@ public class InstanceTest extends AbstractSymtabTest {
arch1.resolve();
arch2.resolve();
int convChannels1 = ((MethodLayerSymbol)((CompositeLayerSymbol)arch1.getBody()).getLayers().get(1)).getArgument("channels").get().getRhs().getIntValue().get();
int convChannels2 = ((MethodLayerSymbol)((CompositeLayerSymbol)arch2.getBody()).getLayers().get(1)).getArgument("channels").get().getRhs().getIntValue().get();
int convChannels1 = ((LayerSymbol)((CompositeElementSymbol)arch1.getBody()).getElements().get(1)).getArgument("channels").get().getRhs().getIntValue().get();
int convChannels2 = ((LayerSymbol)((CompositeElementSymbol)arch2.getBody()).getElements().get(1)).getArgument("channels").get().getRhs().getIntValue().get();
assertEquals(20, convChannels1);
assertEquals(40, convChannels2);
......
......@@ -41,11 +41,7 @@ import static junit.framework.TestCase.assertTrue;
public class ParserTest {