Commit a5585368 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'develop' into 'master'

Implemented layer variables

See merge request !3
parents 6f26fc6a acf49043
Pipeline #171086 passed with stages
in 4 minutes and 51 seconds
......@@ -8,14 +8,14 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-generator</artifactId>
<version>0.0.2-SNAPSHOT</version>
<version>0.0.3-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.1-SNAPSHOT</CNNArch.version>
<CNNArch.version>0.3.2-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.4-SNAPSHOT</CNNTrain.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
......
......@@ -24,6 +24,8 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerVariableDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.se_rwth.commons.logging.Log;
......@@ -67,94 +69,101 @@ public class ArchitectureElementData {
this.templateController = templateController;
}
private LayerSymbol getLayerSymbol() {
if (getElement() instanceof VariableSymbol) {
return ((VariableSymbol) getElement()).getLayerVariableDeclaration().getLayer();
}
else {
assert getElement() instanceof LayerSymbol;
return (LayerSymbol) getElement();
}
}
public List<String> getInputs(){
return getTemplateController().getLayerInputs(getElement());
}
public String getMember() {
assert getElement() instanceof VariableSymbol;
return ((VariableSymbol) getElement()).getMember().toString();
}
public int getConstValue() {
ConstantSymbol constant = (ConstantSymbol) getElement();
return constant.getExpression().getIntValue().get();
assert getElement() instanceof ConstantSymbol;
return ((ConstantSymbol) getElement()).getExpression().getIntValue().get();
}
public List<Integer> getKernel(){
return ((LayerSymbol) getElement())
.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get();
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get();
}
public int getChannels(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.CHANNELS_NAME).get();
return getLayerSymbol().getIntValue(AllPredefinedLayers.CHANNELS_NAME).get();
}
public List<Integer> getStride(){
return ((LayerSymbol) getElement())
.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
}
public int getUnits(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.UNITS_NAME).get();
return getLayerSymbol().getIntValue(AllPredefinedLayers.UNITS_NAME).get();
}
public boolean getNoBias(){
return ((LayerSymbol) getElement())
.getBooleanValue(AllPredefinedLayers.NOBIAS_NAME).get();
return getLayerSymbol().getBooleanValue(AllPredefinedLayers.NOBIAS_NAME).get();
}
public double getP(){
return ((LayerSymbol) getElement())
.getDoubleValue(AllPredefinedLayers.P_NAME).get();
return getLayerSymbol().getDoubleValue(AllPredefinedLayers.P_NAME).get();
}
public int getIndex(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.INDEX_NAME).get();
return getLayerSymbol().getIntValue(AllPredefinedLayers.INDEX_NAME).get();
}
public int getNumOutputs(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.NUM_SPLITS_NAME).get();
return getLayerSymbol().getIntValue(AllPredefinedLayers.NUM_SPLITS_NAME).get();
}
public boolean getFixGamma(){
return ((LayerSymbol) getElement())
.getBooleanValue(AllPredefinedLayers.FIX_GAMMA_NAME).get();
return getLayerSymbol().getBooleanValue(AllPredefinedLayers.FIX_GAMMA_NAME).get();
}
public int getNsize(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.NSIZE_NAME).get();
return getLayerSymbol().getIntValue(AllPredefinedLayers.NSIZE_NAME).get();
}
public double getKnorm(){
return ((LayerSymbol) getElement())
.getDoubleValue(AllPredefinedLayers.KNORM_NAME).get();
return getLayerSymbol().getDoubleValue(AllPredefinedLayers.KNORM_NAME).get();
}
public double getAlpha(){
return ((LayerSymbol) getElement())
.getDoubleValue(AllPredefinedLayers.ALPHA_NAME).get();
return getLayerSymbol().getDoubleValue(AllPredefinedLayers.ALPHA_NAME).get();
}
public double getBeta(){
return ((LayerSymbol) getElement())
.getDoubleValue(AllPredefinedLayers.BETA_NAME).get();
return getLayerSymbol().getDoubleValue(AllPredefinedLayers.BETA_NAME).get();
}
public int getSize(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
return getLayerSymbol().getIntValue(AllPredefinedLayers.SIZE_NAME).get();
}
public int getLayers(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.LAYERS_NAME).get();
}
@Nullable
public String getPoolType(){
return ((LayerSymbol) getElement())
.getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get();
return getLayerSymbol().getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get();
}
@Nullable
public List<Integer> getPadding(){
return getPadding((LayerSymbol) getElement());
return getPadding(getLayerSymbol());
}
@Nullable
......
......@@ -4,10 +4,12 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.IODeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.SerialCompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableDeclarationSymbol;
import de.se_rwth.commons.logging.Log;
import java.util.List;
import java.util.*;
public abstract class ArchitectureSupportChecker {
......@@ -49,8 +51,9 @@ public abstract class ArchitectureSupportChecker {
}
protected boolean checkMultiDimensionalOutput(ArchitectureSymbol architecture) {
if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 ||
architecture.getOutputs().get(0).getDefinition().getType().getHeight() != 1) {
IODeclarationSymbol ioDeclaration = (IODeclarationSymbol) architecture.getOutputs().get(0).getDeclaration();
if (ioDeclaration.getType().getWidth() != 1 || ioDeclaration.getType().getHeight() != 1) {
Log.error("This cnn architecture has a multi-dimensional output, " +
"which is currently not supported by the code generator."
, architecture.getSourcePosition());
......@@ -61,7 +64,7 @@ public abstract class ArchitectureSupportChecker {
return true;
}
protected boolean hasConstant(ArchitectureElementSymbol element) {
private boolean hasConstant(ArchitectureElementSymbol element) {
ArchitectureElementSymbol resolvedElement = element.getResolvedThis().get();
if (resolvedElement instanceof CompositeElementSymbol) {
......@@ -94,11 +97,22 @@ public abstract class ArchitectureSupportChecker {
return true;
}
protected boolean checkLayerVariables(ArchitectureSymbol architecture) {
if (!architecture.getLayerVariableDeclarations().isEmpty()) {
Log.error("This cnn architecture uses layer variables, which are currently not supported by the code generator."
, architecture.getSourcePosition());
return false;
}
return true;
}
public boolean check(ArchitectureSymbol architecture) {
return checkMultipleStreams(architecture)
&& checkMultipleInputs(architecture)
&& checkMultipleOutputs(architecture)
&& checkMultiDimensionalOutput(architecture)
&& checkConstants(architecture);
&& checkConstants(architecture)
&& checkLayerVariables(architecture);
}
}
......@@ -126,18 +126,5 @@ public abstract class CNNArchGenerator {
}
}
public Map<String, String> generateCMakeContent(String rootModelName) {
// model name should start with a lower case letter. If it is a component, replace dot . by _
rootModelName = rootModelName.replace('.', '_').replace('[', '_').replace(']', '_');
rootModelName = rootModelName.substring(0, 1).toLowerCase() + rootModelName.substring(1);
CMakeConfig cMakeConfig = new CMakeConfig(rootModelName);
cMakeConfig.addModuleDependency(new CMakeFindModule("Armadillo", true));
cMakeConfig.addCMakeCommand("set(LIBS ${LIBS} mxnet)");
Map<String,String> fileContentMap = new HashMap<>();
for (FileContent fileContent : cMakeConfig.generateCMakeFiles()){
fileContentMap.put(fileContent.getFileName(), fileContent.getFileContent());
}
return fileContentMap;
}}
public abstract Map<String, String> generateCMakeContent(String rootModelName);
}
......@@ -154,16 +154,16 @@ public abstract class CNNArchTemplateController {
public List<String> getArchitectureInputs(){
List<String> list = new ArrayList<>();
for (IOSymbol ioElement : getArchitecture().getInputs()){
list.add(nameManager.getName(ioElement));
for (VariableSymbol element : getArchitecture().getInputs()){
list.add(nameManager.getName(element));
}
return list;
}
public List<String> getArchitectureOutputs(){
List<String> list = new ArrayList<>();
for (IOSymbol ioElement : getArchitecture().getOutputs()){
list.add(nameManager.getName(ioElement));
for (VariableSymbol element : getArchitecture().getOutputs()){
list.add(nameManager.getName(element));
}
return list;
}
......
......@@ -100,19 +100,43 @@ public class LayerNameCreator {
}
elementToName.put(architectureElement, name);
nameToElement.put(name, architectureElement);
boolean isLayerVariable = false;
if (architectureElement instanceof VariableSymbol) {
isLayerVariable = ((VariableSymbol) architectureElement).getType() == VariableSymbol.Type.LAYER;
}
// Do not map names of layer variables to their respective element since the names are not unique
// for now the name to element mapping is not used anywhere so it doesn't matter
if (!isLayerVariable) {
nameToElement.put(name, architectureElement);
}
}
return endStage;
}
protected String createName(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
if (architectureElement instanceof IOSymbol){
if (architectureElement instanceof VariableSymbol) {
VariableSymbol element = (VariableSymbol) architectureElement;
String name = createBaseName(architectureElement);
IOSymbol ioElement = (IOSymbol) architectureElement;
if (ioElement.getArrayAccess().isPresent()){
int arrayAccess = ioElement.getArrayAccess().get().getIntValue().get();
name = name + "_" + arrayAccess + "_";
if (element.getType() == VariableSymbol.Type.IO) {
if (element.getArrayAccess().isPresent()){
int arrayAccess = element.getArrayAccess().get().getIntValue().get();
name = name + "_" + arrayAccess + "_";
} else {
name = name + "_";
}
} else if (element.getType() == VariableSymbol.Type.LAYER) {
if (element.getMember() == VariableSymbol.Member.STATE) {
name = name + "_state_";
} else {
name = name + "_output_";
}
}
return name;
} else {
return createBaseName(architectureElement) + stage + createStreamPostfix(streamIndices) + "_";
......
......@@ -7,6 +7,7 @@ import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
......@@ -31,8 +32,13 @@ public abstract class LayerSupportChecker {
}
// Support all inputs and outputs
if (resolvedElement.isInput() || resolvedElement.isOutput()) {
return true;
if (resolvedElement instanceof VariableSymbol) {
if (((VariableSymbol) resolvedElement).getType() == VariableSymbol.Type.LAYER) {
return isSupportedLayer(((VariableSymbol) resolvedElement).getLayerVariableDeclaration().getLayer());
}
else if (resolvedElement.isInput() || resolvedElement.isOutput()) {
return true;
}
}
// Support for constants is checked in ArchitectureSupportChecker
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment