Commit bbf3a334 authored by Sebastian N.'s avatar Sebastian N.

Renamed VariableSymbol to ParameterSymbol, introduced layer variable...

Renamed VariableSymbol to ParameterSymbol, introduced layer variable declarations and changed IOSymbol to VariableSymbol which now combines IO variables and layer variables
parent 8e6562a8
Pipeline #170396 failed with stages
in 2 minutes and 14 seconds
......@@ -8,14 +8,14 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-generator</artifactId>
<version>0.0.2-SNAPSHOT</version>
<version>0.0.3-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.1-SNAPSHOT</CNNArch.version>
<CNNArch.version>0.3.2-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.4-SNAPSHOT</CNNTrain.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
......
......@@ -24,6 +24,8 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerVariableDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.se_rwth.commons.logging.Log;
......@@ -67,94 +69,101 @@ public class ArchitectureElementData {
this.templateController = templateController;
}
private LayerSymbol getLayerSymbol() {
if (getElement() instanceof VariableSymbol) {
return ((VariableSymbol) getElement()).getLayerVariableDeclaration().getLayer();
}
else {
assert getElement() instanceof LayerSymbol;
return (LayerSymbol) getElement();
}
}
public List<String> getInputs(){
return getTemplateController().getLayerInputs(getElement());
}
public String getMember() {
assert getElement() instanceof VariableSymbol;
return ((VariableSymbol) getElement()).getMember().toString();
}
public int getConstValue() {
ConstantSymbol constant = (ConstantSymbol) getElement();
return constant.getExpression().getIntValue().get();
assert getElement() instanceof ConstantSymbol;
return ((ConstantSymbol) getElement()).getExpression().getIntValue().get();
}
public List<Integer> getKernel(){
return ((LayerSymbol) getElement())
.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get();
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get();
}
public int getChannels(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.CHANNELS_NAME).get();
return getLayerSymbol().getIntValue(AllPredefinedLayers.CHANNELS_NAME).get();
}
public List<Integer> getStride(){
return ((LayerSymbol) getElement())
.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
}
public int getUnits(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.UNITS_NAME).get();
return getLayerSymbol().getIntValue(AllPredefinedLayers.UNITS_NAME).get();
}
public boolean getNoBias(){
return ((LayerSymbol) getElement())
.getBooleanValue(AllPredefinedLayers.NOBIAS_NAME).get();
return getLayerSymbol().getBooleanValue(AllPredefinedLayers.NOBIAS_NAME).get();
}
public double getP(){
return ((LayerSymbol) getElement())
.getDoubleValue(AllPredefinedLayers.P_NAME).get();
return getLayerSymbol().getDoubleValue(AllPredefinedLayers.P_NAME).get();
}
public int getIndex(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.INDEX_NAME).get();
return getLayerSymbol().getIntValue(AllPredefinedLayers.INDEX_NAME).get();
}
public int getNumOutputs(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.NUM_SPLITS_NAME).get();
return getLayerSymbol().getIntValue(AllPredefinedLayers.NUM_SPLITS_NAME).get();
}
public boolean getFixGamma(){
return ((LayerSymbol) getElement())
.getBooleanValue(AllPredefinedLayers.FIX_GAMMA_NAME).get();
return getLayerSymbol().getBooleanValue(AllPredefinedLayers.FIX_GAMMA_NAME).get();
}
public int getNsize(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.NSIZE_NAME).get();
return getLayerSymbol().getIntValue(AllPredefinedLayers.NSIZE_NAME).get();
}
public double getKnorm(){
return ((LayerSymbol) getElement())
.getDoubleValue(AllPredefinedLayers.KNORM_NAME).get();
return getLayerSymbol().getDoubleValue(AllPredefinedLayers.KNORM_NAME).get();
}
public double getAlpha(){
return ((LayerSymbol) getElement())
.getDoubleValue(AllPredefinedLayers.ALPHA_NAME).get();
return getLayerSymbol().getDoubleValue(AllPredefinedLayers.ALPHA_NAME).get();
}
public double getBeta(){
return ((LayerSymbol) getElement())
.getDoubleValue(AllPredefinedLayers.BETA_NAME).get();
return getLayerSymbol().getDoubleValue(AllPredefinedLayers.BETA_NAME).get();
}
public int getSize(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
return getLayerSymbol().getIntValue(AllPredefinedLayers.SIZE_NAME).get();
}
public int getLayers(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.LAYERS_NAME).get();
}
@Nullable
public String getPoolType(){
return ((LayerSymbol) getElement())
.getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get();
return getLayerSymbol().getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get();
}
@Nullable
public List<Integer> getPadding(){
return getPadding((LayerSymbol) getElement());
return getPadding(getLayerSymbol());
}
@Nullable
......
......@@ -4,10 +4,12 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.IODeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.SerialCompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableDeclarationSymbol;
import de.se_rwth.commons.logging.Log;
import java.util.List;
import java.util.*;
public abstract class ArchitectureSupportChecker {
......@@ -49,8 +51,9 @@ public abstract class ArchitectureSupportChecker {
}
protected boolean checkMultiDimensionalOutput(ArchitectureSymbol architecture) {
if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 ||
architecture.getOutputs().get(0).getDefinition().getType().getHeight() != 1) {
IODeclarationSymbol ioDeclaration = (IODeclarationSymbol) architecture.getOutputs().get(0).getDeclaration();
if (ioDeclaration.getType().getWidth() != 1 || ioDeclaration.getType().getHeight() != 1) {
Log.error("This cnn architecture has a multi-dimensional output, " +
"which is currently not supported by the code generator."
, architecture.getSourcePosition());
......@@ -61,7 +64,7 @@ public abstract class ArchitectureSupportChecker {
return true;
}
protected boolean hasConstant(ArchitectureElementSymbol element) {
private boolean hasConstant(ArchitectureElementSymbol element) {
ArchitectureElementSymbol resolvedElement = element.getResolvedThis().get();
if (resolvedElement instanceof CompositeElementSymbol) {
......@@ -94,11 +97,22 @@ public abstract class ArchitectureSupportChecker {
return true;
}
protected boolean checkLayerVariables(ArchitectureSymbol architecture) {
if (!architecture.getLayerVariableDeclarations().isEmpty()) {
Log.error("This cnn architecture uses layer variables, which are currently not supported by the code generator."
, architecture.getSourcePosition());
return false;
}
return true;
}
public boolean check(ArchitectureSymbol architecture) {
return checkMultipleStreams(architecture)
&& checkMultipleInputs(architecture)
&& checkMultipleOutputs(architecture)
&& checkMultiDimensionalOutput(architecture)
&& checkConstants(architecture);
&& checkConstants(architecture)
&& checkLayerVariables(architecture);
}
}
......@@ -154,16 +154,16 @@ public abstract class CNNArchTemplateController {
public List<String> getArchitectureInputs(){
List<String> list = new ArrayList<>();
for (IOSymbol ioElement : getArchitecture().getInputs()){
list.add(nameManager.getName(ioElement));
for (VariableSymbol element : getArchitecture().getInputs()){
list.add(nameManager.getName(element));
}
return list;
}
public List<String> getArchitectureOutputs(){
List<String> list = new ArrayList<>();
for (IOSymbol ioElement : getArchitecture().getOutputs()){
list.add(nameManager.getName(ioElement));
for (VariableSymbol element : getArchitecture().getOutputs()){
list.add(nameManager.getName(element));
}
return list;
}
......
......@@ -100,19 +100,43 @@ public class LayerNameCreator {
}
elementToName.put(architectureElement, name);
nameToElement.put(name, architectureElement);
boolean isLayerVariable = false;
if (architectureElement instanceof VariableSymbol) {
isLayerVariable = ((VariableSymbol) architectureElement).getType() == VariableSymbol.Type.LAYER;
}
// Do not map names of layer variables to their respective element since the names are not unique
// for now the name to element mapping is not used anywhere so it doesn't matter
if (!isLayerVariable) {
nameToElement.put(name, architectureElement);
}
}
return endStage;
}
protected String createName(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
if (architectureElement instanceof IOSymbol){
if (architectureElement instanceof VariableSymbol) {
VariableSymbol element = (VariableSymbol) architectureElement;
String name = createBaseName(architectureElement);
IOSymbol ioElement = (IOSymbol) architectureElement;
if (ioElement.getArrayAccess().isPresent()){
int arrayAccess = ioElement.getArrayAccess().get().getIntValue().get();
name = name + "_" + arrayAccess + "_";
if (element.getType() == VariableSymbol.Type.IO) {
if (element.getArrayAccess().isPresent()){
int arrayAccess = element.getArrayAccess().get().getIntValue().get();
name = name + "_" + arrayAccess + "_";
} else {
name = name + "_";
}
} else if (element.getType() == VariableSymbol.Type.LAYER) {
if (element.getMember() == VariableSymbol.Member.STATE) {
name = name + "_state_";
} else {
name = name + "_output_";
}
}
return name;
} else {
return createBaseName(architectureElement) + stage + createStreamPostfix(streamIndices) + "_";
......
......@@ -7,6 +7,7 @@ import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
......@@ -31,8 +32,13 @@ public abstract class LayerSupportChecker {
}
// Support all inputs and outputs
if (resolvedElement.isInput() || resolvedElement.isOutput()) {
return true;
if (resolvedElement instanceof VariableSymbol) {
if (((VariableSymbol) resolvedElement).getType() == VariableSymbol.Type.LAYER) {
return isSupportedLayer(((VariableSymbol) resolvedElement).getLayerVariableDeclaration().getLayer());
}
else if (resolvedElement.isInput() || resolvedElement.isOutput()) {
return true;
}
}
// Support for constants is checked in ArchitectureSupportChecker
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment