Aufgrund einer Wartung wird GitLab am 17.08. zwischen 8:30 und 9:00 Uhr kurzzeitig nicht zur Verfügung stehen. / Due to maintenance, GitLab will be temporarily unavailable on 17.08. between 8:30 and 9:00 am.

Commit bbf3a334 authored by Sebastian N.'s avatar Sebastian N.

Renamed VariableSymbol to ParameterSymbol, introduced layer variable...

Renamed VariableSymbol to ParameterSymbol, introduced layer variable declarations and changed IOSymbol to VariableSymbol which now combines IO variables and layer variables
parent 8e6562a8
Pipeline #170396 failed with stages
in 2 minutes and 14 seconds
...@@ -8,14 +8,14 @@ ...@@ -8,14 +8,14 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-generator</artifactId> <artifactId>cnnarch-generator</artifactId>
<version>0.0.2-SNAPSHOT</version> <version>0.0.3-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
<properties> <properties>
<!-- .. SE-Libraries .................................................. --> <!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.1-SNAPSHOT</CNNArch.version> <CNNArch.version>0.3.2-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.4-SNAPSHOT</CNNTrain.version> <CNNTrain.version>0.3.4-SNAPSHOT</CNNTrain.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator> <embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
......
...@@ -24,6 +24,8 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol; ...@@ -24,6 +24,8 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerVariableDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers; import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.se_rwth.commons.logging.Log; import de.se_rwth.commons.logging.Log;
...@@ -67,94 +69,101 @@ public class ArchitectureElementData { ...@@ -67,94 +69,101 @@ public class ArchitectureElementData {
this.templateController = templateController; this.templateController = templateController;
} }
private LayerSymbol getLayerSymbol() {
if (getElement() instanceof VariableSymbol) {
return ((VariableSymbol) getElement()).getLayerVariableDeclaration().getLayer();
}
else {
assert getElement() instanceof LayerSymbol;
return (LayerSymbol) getElement();
}
}
public List<String> getInputs(){ public List<String> getInputs(){
return getTemplateController().getLayerInputs(getElement()); return getTemplateController().getLayerInputs(getElement());
} }
public String getMember() {
assert getElement() instanceof VariableSymbol;
return ((VariableSymbol) getElement()).getMember().toString();
}
public int getConstValue() { public int getConstValue() {
ConstantSymbol constant = (ConstantSymbol) getElement(); assert getElement() instanceof ConstantSymbol;
return constant.getExpression().getIntValue().get();
return ((ConstantSymbol) getElement()).getExpression().getIntValue().get();
} }
public List<Integer> getKernel(){ public List<Integer> getKernel(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get();
.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get();
} }
public int getChannels(){ public int getChannels(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getIntValue(AllPredefinedLayers.CHANNELS_NAME).get();
.getIntValue(AllPredefinedLayers.CHANNELS_NAME).get();
} }
public List<Integer> getStride(){ public List<Integer> getStride(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
} }
public int getUnits(){ public int getUnits(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getIntValue(AllPredefinedLayers.UNITS_NAME).get();
.getIntValue(AllPredefinedLayers.UNITS_NAME).get();
} }
public boolean getNoBias(){ public boolean getNoBias(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getBooleanValue(AllPredefinedLayers.NOBIAS_NAME).get();
.getBooleanValue(AllPredefinedLayers.NOBIAS_NAME).get();
} }
public double getP(){ public double getP(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getDoubleValue(AllPredefinedLayers.P_NAME).get();
.getDoubleValue(AllPredefinedLayers.P_NAME).get();
} }
public int getIndex(){ public int getIndex(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getIntValue(AllPredefinedLayers.INDEX_NAME).get();
.getIntValue(AllPredefinedLayers.INDEX_NAME).get();
} }
public int getNumOutputs(){ public int getNumOutputs(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getIntValue(AllPredefinedLayers.NUM_SPLITS_NAME).get();
.getIntValue(AllPredefinedLayers.NUM_SPLITS_NAME).get();
} }
public boolean getFixGamma(){ public boolean getFixGamma(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getBooleanValue(AllPredefinedLayers.FIX_GAMMA_NAME).get();
.getBooleanValue(AllPredefinedLayers.FIX_GAMMA_NAME).get();
} }
public int getNsize(){ public int getNsize(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getIntValue(AllPredefinedLayers.NSIZE_NAME).get();
.getIntValue(AllPredefinedLayers.NSIZE_NAME).get();
} }
public double getKnorm(){ public double getKnorm(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getDoubleValue(AllPredefinedLayers.KNORM_NAME).get();
.getDoubleValue(AllPredefinedLayers.KNORM_NAME).get();
} }
public double getAlpha(){ public double getAlpha(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getDoubleValue(AllPredefinedLayers.ALPHA_NAME).get();
.getDoubleValue(AllPredefinedLayers.ALPHA_NAME).get();
} }
public double getBeta(){ public double getBeta(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getDoubleValue(AllPredefinedLayers.BETA_NAME).get();
.getDoubleValue(AllPredefinedLayers.BETA_NAME).get();
} }
public int getSize(){ public int getSize(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getIntValue(AllPredefinedLayers.SIZE_NAME).get();
.getIntValue(AllPredefinedLayers.SIZE_NAME).get(); }
public int getLayers(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.LAYERS_NAME).get();
} }
@Nullable @Nullable
public String getPoolType(){ public String getPoolType(){
return ((LayerSymbol) getElement()) return getLayerSymbol().getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get();
.getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get();
} }
@Nullable @Nullable
public List<Integer> getPadding(){ public List<Integer> getPadding(){
return getPadding((LayerSymbol) getElement()); return getPadding(getLayerSymbol());
} }
@Nullable @Nullable
......
...@@ -4,10 +4,12 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol; ...@@ -4,10 +4,12 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.IODeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.SerialCompositeElementSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.SerialCompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableDeclarationSymbol;
import de.se_rwth.commons.logging.Log; import de.se_rwth.commons.logging.Log;
import java.util.List; import java.util.*;
public abstract class ArchitectureSupportChecker { public abstract class ArchitectureSupportChecker {
...@@ -49,8 +51,9 @@ public abstract class ArchitectureSupportChecker { ...@@ -49,8 +51,9 @@ public abstract class ArchitectureSupportChecker {
} }
protected boolean checkMultiDimensionalOutput(ArchitectureSymbol architecture) { protected boolean checkMultiDimensionalOutput(ArchitectureSymbol architecture) {
if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 || IODeclarationSymbol ioDeclaration = (IODeclarationSymbol) architecture.getOutputs().get(0).getDeclaration();
architecture.getOutputs().get(0).getDefinition().getType().getHeight() != 1) {
if (ioDeclaration.getType().getWidth() != 1 || ioDeclaration.getType().getHeight() != 1) {
Log.error("This cnn architecture has a multi-dimensional output, " + Log.error("This cnn architecture has a multi-dimensional output, " +
"which is currently not supported by the code generator." "which is currently not supported by the code generator."
, architecture.getSourcePosition()); , architecture.getSourcePosition());
...@@ -61,7 +64,7 @@ public abstract class ArchitectureSupportChecker { ...@@ -61,7 +64,7 @@ public abstract class ArchitectureSupportChecker {
return true; return true;
} }
protected boolean hasConstant(ArchitectureElementSymbol element) { private boolean hasConstant(ArchitectureElementSymbol element) {
ArchitectureElementSymbol resolvedElement = element.getResolvedThis().get(); ArchitectureElementSymbol resolvedElement = element.getResolvedThis().get();
if (resolvedElement instanceof CompositeElementSymbol) { if (resolvedElement instanceof CompositeElementSymbol) {
...@@ -94,11 +97,22 @@ public abstract class ArchitectureSupportChecker { ...@@ -94,11 +97,22 @@ public abstract class ArchitectureSupportChecker {
return true; return true;
} }
protected boolean checkLayerVariables(ArchitectureSymbol architecture) {
if (!architecture.getLayerVariableDeclarations().isEmpty()) {
Log.error("This cnn architecture uses layer variables, which are currently not supported by the code generator."
, architecture.getSourcePosition());
return false;
}
return true;
}
public boolean check(ArchitectureSymbol architecture) { public boolean check(ArchitectureSymbol architecture) {
return checkMultipleStreams(architecture) return checkMultipleStreams(architecture)
&& checkMultipleInputs(architecture) && checkMultipleInputs(architecture)
&& checkMultipleOutputs(architecture) && checkMultipleOutputs(architecture)
&& checkMultiDimensionalOutput(architecture) && checkMultiDimensionalOutput(architecture)
&& checkConstants(architecture); && checkConstants(architecture)
&& checkLayerVariables(architecture);
} }
} }
...@@ -154,16 +154,16 @@ public abstract class CNNArchTemplateController { ...@@ -154,16 +154,16 @@ public abstract class CNNArchTemplateController {
public List<String> getArchitectureInputs(){ public List<String> getArchitectureInputs(){
List<String> list = new ArrayList<>(); List<String> list = new ArrayList<>();
for (IOSymbol ioElement : getArchitecture().getInputs()){ for (VariableSymbol element : getArchitecture().getInputs()){
list.add(nameManager.getName(ioElement)); list.add(nameManager.getName(element));
} }
return list; return list;
} }
public List<String> getArchitectureOutputs(){ public List<String> getArchitectureOutputs(){
List<String> list = new ArrayList<>(); List<String> list = new ArrayList<>();
for (IOSymbol ioElement : getArchitecture().getOutputs()){ for (VariableSymbol element : getArchitecture().getOutputs()){
list.add(nameManager.getName(ioElement)); list.add(nameManager.getName(element));
} }
return list; return list;
} }
......
...@@ -100,19 +100,43 @@ public class LayerNameCreator { ...@@ -100,19 +100,43 @@ public class LayerNameCreator {
} }
elementToName.put(architectureElement, name); elementToName.put(architectureElement, name);
nameToElement.put(name, architectureElement);
boolean isLayerVariable = false;
if (architectureElement instanceof VariableSymbol) {
isLayerVariable = ((VariableSymbol) architectureElement).getType() == VariableSymbol.Type.LAYER;
}
// Do not map names of layer variables to their respective element since the names are not unique
// for now the name to element mapping is not used anywhere so it doesn't matter
if (!isLayerVariable) {
nameToElement.put(name, architectureElement);
}
} }
return endStage; return endStage;
} }
protected String createName(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){ protected String createName(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
if (architectureElement instanceof IOSymbol){ if (architectureElement instanceof VariableSymbol) {
VariableSymbol element = (VariableSymbol) architectureElement;
String name = createBaseName(architectureElement); String name = createBaseName(architectureElement);
IOSymbol ioElement = (IOSymbol) architectureElement;
if (ioElement.getArrayAccess().isPresent()){ if (element.getType() == VariableSymbol.Type.IO) {
int arrayAccess = ioElement.getArrayAccess().get().getIntValue().get(); if (element.getArrayAccess().isPresent()){
name = name + "_" + arrayAccess + "_"; int arrayAccess = element.getArrayAccess().get().getIntValue().get();
name = name + "_" + arrayAccess + "_";
} else {
name = name + "_";
}
} else if (element.getType() == VariableSymbol.Type.LAYER) {
if (element.getMember() == VariableSymbol.Member.STATE) {
name = name + "_state_";
} else {
name = name + "_output_";
}
} }
return name; return name;
} else { } else {
return createBaseName(architectureElement) + stage + createStreamPostfix(streamIndices) + "_"; return createBaseName(architectureElement) + stage + createStreamPostfix(streamIndices) + "_";
......
...@@ -7,6 +7,7 @@ import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol; ...@@ -7,6 +7,7 @@ import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.se_rwth.commons.logging.Log; import de.se_rwth.commons.logging.Log;
import java.util.ArrayList; import java.util.ArrayList;
...@@ -31,8 +32,13 @@ public abstract class LayerSupportChecker { ...@@ -31,8 +32,13 @@ public abstract class LayerSupportChecker {
} }
// Support all inputs and outputs // Support all inputs and outputs
if (resolvedElement.isInput() || resolvedElement.isOutput()) { if (resolvedElement instanceof VariableSymbol) {
return true; if (((VariableSymbol) resolvedElement).getType() == VariableSymbol.Type.LAYER) {
return isSupportedLayer(((VariableSymbol) resolvedElement).getLayerVariableDeclaration().getLayer());
}
else if (resolvedElement.isInput() || resolvedElement.isOutput()) {
return true;
}
} }
// Support for constants is checked in ArchitectureSupportChecker // Support for constants is checked in ArchitectureSupportChecker
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment