Commit 5f326fc2 authored by Sebastian Nickels's avatar Sebastian Nickels

Merge rnn into develop

parents 20d0b24c 9f5fdbf4
Pipeline #150600 canceled with stages
......@@ -8,15 +8,15 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId>
<version>0.2.15-SNAPSHOT</version>
<version>0.2.16-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.0-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.2.6</CNNTrain.version>
<CNNArch.version>0.3.1-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.2-SNAPSHOT</CNNTrain.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. -->
......@@ -102,6 +102,12 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.github.stefanbirkner</groupId>
<artifactId>system-rules</artifactId>
<version>1.3.0</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
......
......@@ -90,6 +90,9 @@ public class ArchitectureElementData {
return getTemplateController().isSoftmaxOutput(getElement());
}
public boolean isOneHotOutput(){
return getTemplateController().isOneHotOutput(getElement());
}
......@@ -158,6 +161,11 @@ public class ArchitectureElementData {
.getDoubleValue(AllPredefinedLayers.BETA_NAME).get();
}
public int getSize(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
}
@Nullable
public String getPoolType(){
return ((LayerSymbol) getElement())
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.se_rwth.commons.logging.Log;
public class ArchitectureSupportChecker {
public ArchitectureSupportChecker() {}
// Overload functions returning always true to enable the features
protected boolean checkMultipleStreams(ArchitectureSymbol architecture) {
if (architecture.getStreams().size() != 1) {
Log.error("This cnn architecture has multiple instructions, " +
"which is currently not supported by the code generator. "
, architecture.getSourcePosition());
return false;
}
return true;
}
protected boolean checkMultipleInputs(ArchitectureSymbol architecture) {
if (architecture.getInputs().size() > 1) {
Log.error("This cnn architecture has multiple inputs, " +
"which is currently not supported by the code generator. "
, architecture.getSourcePosition());
return false;
}
return true;
}
protected boolean checkMultipleOutputs(ArchitectureSymbol architecture) {
if (architecture.getOutputs().size() > 1) {
Log.error("This cnn architecture has multiple outputs, " +
"which is currently not supported by the code generator. "
, architecture.getSourcePosition());
return false;
}
return true;
}
protected boolean checkMultiDimensionalOutput(ArchitectureSymbol architecture) {
if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 ||
architecture.getOutputs().get(0).getDefinition().getType().getHeight() != 1) {
Log.error("This cnn architecture has a multi-dimensional output, " +
"which is currently not supported by the code generator."
, architecture.getSourcePosition());
return false;
}
return true;
}
public boolean check(ArchitectureSymbol architecture) {
return checkMultipleStreams(architecture)
&& checkMultipleInputs(architecture)
&& checkMultipleOutputs(architecture)
&& checkMultiDimensionalOutput(architecture);
}
}
......@@ -23,7 +23,6 @@ package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.DataPathConfigParser;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.checker.AllowAllLayerSupportChecker;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cmake.CMakeConfig;
import de.monticore.lang.monticar.generator.cmake.CMakeFindModule;
......@@ -36,12 +35,14 @@ import java.util.HashMap;
import java.util.Map;
public class CNNArch2MxNet extends CNNArchGenerator {
public CNNArch2MxNet() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
public void generate(Scope scope, String rootModelName){
CNNArchSymbolCompiler symbolCompiler = new CNNArchSymbolCompiler(new AllowAllLayerSupportChecker());
CNNArchSymbolCompiler symbolCompiler = new CNNArchSymbolCompiler(new CNNArch2MxNetArchitectureSupportChecker(),
new CNNArch2MxNetLayerSupportChecker());
ArchitectureSymbol architectureSymbol = symbolCompiler.compileArchitectureSymbol(scope, rootModelName);
try{
......@@ -58,11 +59,8 @@ public class CNNArch2MxNet extends CNNArchGenerator {
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public Map<String, String> generateStrings(ArchitectureSymbol architecture){
TemplateConfiguration templateConfiguration = new MxNetTemplateConfiguration();
Map<String, String> fileContentMap = new HashMap<>();
CNNArch2MxNetTemplateController archTc
= new CNNArch2MxNetTemplateController(architecture, templateConfiguration);
CNNArch2MxNetTemplateController archTc = new CNNArch2MxNetTemplateController(architecture);
Map.Entry<String, String> temp;
temp = archTc.process("CNNPredictor", Target.CPP);
......@@ -77,8 +75,6 @@ public class CNNArch2MxNet extends CNNArchGenerator {
temp = archTc.process("CNNBufferFile", Target.CPP);
fileContentMap.put("CNNBufferFile.h", temp.getValue());
checkValidGeneration(architecture);
return fileContentMap;
}
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
public class CNNArch2MxNetArchitectureSupportChecker extends ArchitectureSupportChecker {
public CNNArch2MxNetArchitectureSupportChecker() {}
}
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
public class CNNArch2MxNetLayerSupportChecker extends LayerSupportChecker {
public CNNArch2MxNetLayerSupportChecker() {
supportedLayerList.add(AllPredefinedLayers.FULLY_CONNECTED_NAME);
supportedLayerList.add(AllPredefinedLayers.CONVOLUTION_NAME);
supportedLayerList.add(AllPredefinedLayers.SOFTMAX_NAME);
supportedLayerList.add(AllPredefinedLayers.SIGMOID_NAME);
supportedLayerList.add(AllPredefinedLayers.TANH_NAME);
supportedLayerList.add(AllPredefinedLayers.RELU_NAME);
supportedLayerList.add(AllPredefinedLayers.DROPOUT_NAME);
supportedLayerList.add(AllPredefinedLayers.POOLING_NAME);
supportedLayerList.add(AllPredefinedLayers.GLOBAL_POOLING_NAME);
supportedLayerList.add(AllPredefinedLayers.LRN_NAME);
supportedLayerList.add(AllPredefinedLayers.BATCHNORM_NAME);
supportedLayerList.add(AllPredefinedLayers.SPLIT_NAME);
supportedLayerList.add(AllPredefinedLayers.GET_NAME);
supportedLayerList.add(AllPredefinedLayers.ADD_NAME);
supportedLayerList.add(AllPredefinedLayers.CONCATENATE_NAME);
supportedLayerList.add(AllPredefinedLayers.FLATTEN_NAME);
supportedLayerList.add(AllPredefinedLayers.ONE_HOT_NAME);
}
}
......@@ -9,9 +9,8 @@ import java.io.Writer;
*/
public class CNNArch2MxNetTemplateController extends CNNArchTemplateController {
public CNNArch2MxNetTemplateController(ArchitectureSymbol architecture,
TemplateConfiguration templateConfiguration) {
super(architecture, templateConfiguration);
public CNNArch2MxNetTemplateController(ArchitectureSymbol architecture) {
super(architecture, new MxNetTemplateConfiguration());
}
public void include(IOSymbol ioElement, Writer writer){
......@@ -37,7 +36,7 @@ public class CNNArch2MxNetTemplateController extends CNNArchTemplateController {
if (layer.isAtomic()){
ArchitectureElementSymbol nextElement = layer.getOutputElement().get();
if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement)){
if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement) && !isOneHotOutput(nextElement)){
String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer);
}
......
......@@ -3,7 +3,6 @@ package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.checker.LayerSupportChecker;
import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
......@@ -13,10 +12,13 @@ import java.util.List;
import java.util.Optional;
public class CNNArchSymbolCompiler {
private final LayerSupportChecker layerChecker;
private final ArchitectureSupportChecker architectureSupportChecker;
private final LayerSupportChecker layerSupportChecker;
public CNNArchSymbolCompiler(final LayerSupportChecker layerChecker) {
this.layerChecker = layerChecker;
public CNNArchSymbolCompiler(final ArchitectureSupportChecker architectureSupportChecker,
final LayerSupportChecker layerSupportChecker) {
this.architectureSupportChecker = architectureSupportChecker;
this.layerSupportChecker = layerSupportChecker;
}
public ArchitectureSymbol compileArchitectureSymbolFromModelsDir(
......@@ -30,47 +32,21 @@ public class CNNArchSymbolCompiler {
Optional<CNNArchCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNArchCompilationUnitSymbol.KIND);
if (!compilationUnit.isPresent()){
failWithMessage("Could not resolve architecture " + rootModelName);
}
CNNArchCocos.checkAll(compilationUnit.get());
if (!supportCheck(compilationUnit.get().getArchitecture())){
ArchitectureSymbol architecture = compilationUnit.get().getArchitecture();
if (!architectureSupportChecker.check(architecture) || !layerSupportChecker.check(architecture)) {
failWithMessage("Architecture not supported by generator");
}
return compilationUnit.get().getArchitecture();
return architecture;
}
private void failWithMessage(final String message) {
Log.error(message);
System.exit(1);
}
private boolean supportCheck(ArchitectureSymbol architecture){
for (ArchitectureElementSymbol element : ((CompositeElementSymbol)architecture.getBody()).getElements()){
if(!isSupportedLayer(element, layerChecker)) {
return false;
}
}
return true;
}
private boolean isSupportedLayer(ArchitectureElementSymbol element, LayerSupportChecker layerChecker){
List<ArchitectureElementSymbol> constructLayerElemList;
if (element.getResolvedThis().get() instanceof CompositeElementSymbol) {
constructLayerElemList = ((CompositeElementSymbol)element.getResolvedThis().get()).getElements();
for (ArchitectureElementSymbol constructedLayerElement : constructLayerElemList) {
if (!isSupportedLayer(constructedLayerElement, layerChecker)) {
return false;
}
}
}
if (!layerChecker.isSupported(element.toString())) {
Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the backend.");
return false;
} else {
return true;
}
}
}
......@@ -23,6 +23,7 @@ package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid;
import de.monticore.lang.monticar.cnnarch.predefined.Softmax;
import de.monticore.lang.monticar.cnnarch.predefined.OneHot;
import java.io.StringWriter;
import java.io.Writer;
......@@ -139,7 +140,7 @@ public abstract class CNNArchTemplateController {
public List<String> getLayerInputs(ArchitectureElementSymbol layer){
List<String> inputNames = new ArrayList<>();
if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer)){
if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer) || isOneHotOutput(layer)){
inputNames = getLayerInputs(layer.getInputElement().get());
} else {
for (ArchitectureElementSymbol input : layer.getPrevious()) {
......@@ -228,9 +229,13 @@ public abstract class CNNArchTemplateController {
public boolean isLinearRegressionOutput(ArchitectureElementSymbol architectureElement){
return architectureElement.isOutput()
&& !isLogisticRegressionOutput(architectureElement)
&& !isSoftmaxOutput(architectureElement);
&& !isSoftmaxOutput(architectureElement)
&& !isOneHotOutput(architectureElement);
}
public boolean isOneHotOutput(ArchitectureElementSymbol architectureElement){
return isTOutput(OneHot.class, architectureElement);
}
public boolean isSoftmaxOutput(ArchitectureElementSymbol architectureElement){
return isTOutput(Softmax.class, architectureElement);
......
......@@ -33,7 +33,10 @@ public class LayerNameCreator {
private Map<String, ArchitectureElementSymbol> nameToElement = new HashMap<>();
public LayerNameCreator(ArchitectureSymbol architecture) {
name(architecture.getBody(), 1, new ArrayList<>());
int stage = 1;
for (CompositeElementSymbol stream : architecture.getStreams()) {
stage = name(stream, stage, new ArrayList<>());
}
}
public ArchitectureElementSymbol getArchitectureElement(String name){
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.List;
public abstract class LayerSupportChecker {
protected List<String> supportedLayerList = new ArrayList<>();
private boolean isSupportedLayer(ArchitectureElementSymbol element){
ArchitectureElementSymbol resolvedElement = element.getResolvedThis().get();
List<ArchitectureElementSymbol> constructLayerElemList;
if (resolvedElement instanceof CompositeElementSymbol) {
constructLayerElemList = ((CompositeElementSymbol) resolvedElement).getElements();
for (ArchitectureElementSymbol constructedLayerElement : constructLayerElemList) {
if (!isSupportedLayer(constructedLayerElement)) {
return false;
}
}
return true;
}
// Support all inputs and outputs
if (resolvedElement.isInput() || resolvedElement.isOutput()) {
return true;
}
// Support all layer declarations
if (resolvedElement instanceof LayerSymbol) {
if (!((LayerSymbol) resolvedElement).getDeclaration().isPredefined()) {
return true;
}
}
if (!supportedLayerList.contains(element.toString())) {
Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the current backend.");
return false;
} else {
return true;
}
}
public boolean check(ArchitectureSymbol architecture) {
for (CompositeElementSymbol stream : architecture.getStreams()) {
for (ArchitectureElementSymbol element : stream.getElements()) {
if (!isSupportedLayer(element)) {
return false;
}
}
}
return true;
}
}
......@@ -9,7 +9,6 @@ public class MxNetTemplateConfiguration extends TemplateConfiguration {
private static Configuration configuration;
public MxNetTemplateConfiguration() {
super();
if (configuration == null) {
configuration = super.createConfiguration();
}
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator.checker;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.checker.LayerSupportChecker;
import java.util.ArrayList;
import java.util.List;
public class AllowAllLayerSupportChecker implements LayerSupportChecker {
private List<String> unsupportedLayerList = new ArrayList<>();
public AllowAllLayerSupportChecker() {
//Set the unsupported layers for the backend
//this.unsupportedLayerList.add(PREDEFINED_LAYER_NAME);
}
@Override
public boolean isSupported(String element) {
return !this.unsupportedLayerList.contains(element);
}
}
package de.monticore.lang.monticar.cnnarch.mxnetgenerator.checker;
public interface LayerSupportChecker {
boolean isSupported(String element);
}
\ No newline at end of file
......@@ -172,7 +172,7 @@ class ${tc.fileNameWithoutEnding}:
def construct(self, context, data_mean=None, data_std=None):
${tc.include(tc.architecture.body)}
${tc.include(tc.architecture.streams[0])}
self.module = mx.mod.Module(symbol=mx.symbol.Group([${tc.join(tc.architectureOutputs, ",")}]),
data_names=self._input_names_,
label_names=self._output_names_,
......
${element.name} = mx.symbol.one_hot(data=${element.inputs[0]},
indices=mx.symbol.argmax(data=${element.inputs[0]}, axis=1), depth=${element.size}))
<#include "OutputShape.ftl">
\ No newline at end of file
......@@ -8,4 +8,7 @@
<#elseif element.linearRegressionOutput>
${element.name} = mx.symbol.LinearRegressionOutput(data=${element.inputs[0]},
name="${element.name}")
<#elseif element.oneHotOutput>
${element.name} = mx.symbol.SoftmaxOutput(data=${element.inputs[0]},
name="${element.name}")
</#if>
\ No newline at end of file
......@@ -23,6 +23,7 @@ package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import java.io.IOException;
......@@ -30,9 +31,13 @@ import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import org.junit.contrib.java.lang.system.Assertion;
import org.junit.contrib.java.lang.system.ExpectedSystemExit;
import static junit.framework.TestCase.assertTrue;
public class GenerationTest extends AbstractSymtabTest{
@Rule
public final ExpectedSystemExit exit = ExpectedSystemExit.none();
@Before
public void setUp() {
......@@ -90,13 +95,17 @@ public class GenerationTest extends AbstractSymtabTest{
"execute_VGG16"));
}
@Test
public void testThreeInputCNNGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/architectures", "-r", "ThreeInputCNN_M14"};
exit.expectSystemExit();
exit.checkAssertionAfterwards(new Assertion() {
public void checkAssertion() {
assertTrue(Log.getFindings().size() == 2);
}
});
CNNArch2MxNetCli.main(args);
assertTrue(Log.getFindings().size() == 1);
}
@Test
......@@ -107,12 +116,30 @@ public class GenerationTest extends AbstractSymtabTest{
assertTrue(Log.getFindings().isEmpty());
}
@Test
public void testMultipleStreams() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/invalid_tests", "-r", "MultipleStreams"};
exit.expectSystemExit();
exit.checkAssertionAfterwards(new Assertion() {
public void checkAssertion() {
assertTrue(Log.getFindings().size() == 2);
}
});
CNNArch2MxNetCli.main(args);
}
@Test
public void testMultipleOutputs() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/valid_tests", "-r", "MultipleOutputs"};
String[] args = {"-m", "src/test/resources/invalid_tests", "-r", "MultipleOutputs"};
exit.expectSystemExit();
exit.checkAssertionAfterwards(new Assertion() {
public void checkAssertion() {
assertTrue(Log.getFindings().size() == 2);
}
});
CNNArch2MxNetCli.main(args);
assertTrue(Log.getFindings().size() == 3);
}
@Test
......
......@@ -55,7 +55,7 @@ public class SymtabTest extends AbstractSymtabTest {
CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull(a);
a.resolve();
a.getArchitecture().getBody().getOutputTypes();
a.getArchitecture().getStreams().get(0).getOutputTypes();
}
@Ignore
......@@ -67,7 +67,7 @@ public class SymtabTest extends AbstractSymtabTest {
CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull(a);
a.resolve();
a.getArchitecture().getBody().getOutputTypes();
a.getArchitecture().getStreams().get(0).getOutputTypes();
}
@Ignore
......@@ -79,7 +79,7 @@ public class SymtabTest extends AbstractSymtabTest {
CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull(a);
a.resolve();
a.getArchitecture().getBody().getOutputTypes();
a.getArchitecture().getStreams().get(0).getOutputTypes();
}
}
......@@ -39,5 +39,5 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
fc(->=2) ->
FullyConnected(units=10) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
......@@ -40,5 +40,5 @@ architecture ResNeXt50(img_height=224, img_width=224, img_channels=3, classes=10
GlobalPooling(pool_type="avg") ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
......@@ -33,5 +33,5 @@ architecture ResNet152(img_height=224, img_width=224, img_channels=3, classes=10
GlobalPooling(pool_type="avg") ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
......@@ -31,5 +31,5 @@ architecture ResNet34(img_height=224, img_width=224, img_channels=3, classes=100
GlobalPooling(pool_type="avg") ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions;
}
......@@ -25,5 +25,5 @@ architecture SequentialAlexnet(img_height=224, img_width=224, img_channels=3, cl
fc() ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions;
}
......@@ -28,5 +28,5 @@ architecture ThreeInputCNN_M14(img_height=200, img_width=300, img_channels=3, cl
Relu() ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
......@@ -27,5 +27,5 @@ architecture VGG16(img_height=224, img_width=224, img_channels=3, classes=1000){
fc() ->
FullyConnected(units=classes) ->
Softmax() ->
predictions
predictions;
}
\ No newline at end of file
architecture ArgumentConstraintTest1(img_height=224, img_width=224, img_channels=3, classes=1000){
def input Z(0:255)^{img_channels, img_height, img_width} image
def output Q(0:1)^{classes} predictions
def conv(kernel, channels, stride=1, act=true){
Convolution(kernel=(kernel,kernel), channels=channels, stride=(stride,stride)) ->
BatchNorm() ->
Relu(?=act)
}
def skip(channels, stride){
Convolution(kernel=(1,1), channels=75, stride=(stride,stride)) ->
BatchNorm()
}
def resLayer(channels, stride=1){
(
conv(kernel=3, channels=channels, stride=stride) ->