Commit d4163365 authored by Thomas Michael Timmermanns's avatar Thomas Michael Timmermanns Committed by Thomas Michael Timmermanns

Changed ShapeSymbol to ArchTypeSymbol which combines Shape and Type.

IOLayer arrays do not exist anymore at symbol level. They all are now split up at symbol table creation.
Added output type check and test.
parent e9136aec
......@@ -5,8 +5,8 @@ grammar CNNArch extends de.monticore.lang.math.Math {
CNNArchCompilationUnit = Architecture;
symbol scope Architecture = "architecture"
name:Name& "("
(ArchitectureParameter || ",")* ")" "{"
name:Name&
( "(" (ArchitectureParameter || ",")* ")" )? "{"
declarations:ArchDeclaration*
body:ArchBody "}";
......
......@@ -20,23 +20,23 @@
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchType;
import de.monticore.lang.monticar.cnnarch._ast.ASTDimensionArgument;
import de.monticore.lang.monticar.cnnarch._ast.ASTShape;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchSimpleExpressionSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ShapeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Optional;
public class CheckIOShape implements CNNArchASTShapeCoCo {
public class CheckIOShape implements CNNArchASTArchTypeCoCo {
@Override
public void check(ASTShape node) {
public void check(ASTArchType node) {
boolean hasHeight = false;
boolean hasWidth = false;
boolean hasChannels = false;
for (ASTDimensionArgument dimensionArg : node.getDimensions()){
for (ASTDimensionArgument dimensionArg : node.getShape().getDimensions()){
if (dimensionArg.getWidth().isPresent()){
if (hasWidth){
repetitionError(dimensionArg);
......@@ -58,7 +58,7 @@ public class CheckIOShape implements CNNArchASTShapeCoCo {
}
ShapeSymbol shape = (ShapeSymbol) node.getSymbol().get();
ArchTypeSymbol shape = (ArchTypeSymbol) node.getSymbol().get();
for (ArchSimpleExpressionSymbol dimension : shape.getDimensionSymbols()){
Optional<Integer> value = dimension.getIntValue();
if (!value.isPresent() || value.get() <= 0){
......
......@@ -27,6 +27,6 @@ public class CheckLayerInputs implements CNNArchASTArchitectureCoCo {
@Override
public void check(ASTArchitecture node) {
ArchitectureSymbol architecture = (ArchitectureSymbol) node.getSymbol().get();
architecture.getBody().checkInputAndOutput();
architecture.getBody().checkInput();
}
}
......@@ -26,7 +26,7 @@ import de.monticore.lang.monticar.cnnarch._ast.ASTArchSimpleExpression;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchSimpleExpressionSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.cnnarch.helper.ExpressionHelper;
import de.monticore.lang.monticar.cnnarch.helper.Utils;
import de.se_rwth.commons.logging.Log;
import java.util.Collection;
......@@ -39,7 +39,7 @@ public class CheckNameExpression implements CNNArchASTArchSimpleExpressionCoCo {
if (expression.getMathExpression().isPresent()){
MathExpressionSymbol mathExpression = expression.getMathExpression().get();
for (MathExpressionSymbol subMathExp : ExpressionHelper.createSubExpressionList(mathExpression)){
for (MathExpressionSymbol subMathExp : Utils.createSubExpressionList(mathExpression)){
if (subMathExp instanceof MathNameExpressionSymbol){
String name = ((MathNameExpressionSymbol) subMathExp).getNameToAccess();
Collection<VariableSymbol> variableCollection = node.getEnclosingScope().get().resolveMany(name, VariableSymbol.KIND);
......
......@@ -24,25 +24,22 @@ import de.monticore.lang.monticar.cnnarch._ast.ASTIOLayer;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeLayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.IODeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.IOLayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.symboltable.Symbol;
import de.se_rwth.commons.logging.Log;
import java.util.Collection;
import java.util.Collections;
public class CheckUnknownIO implements CNNArchASTIOLayerCoCo {
@Override
public void check(ASTIOLayer node) {
Symbol symbol = node.getSymbol().get();
IODeclarationSymbol ioDeclaration = null;
if (symbol instanceof IOLayerSymbol){
ioDeclaration = ((IOLayerSymbol) symbol).getDefinition();
}
else if (symbol instanceof CompositeLayerSymbol){
IOLayerSymbol layer = (IOLayerSymbol) ((CompositeLayerSymbol) symbol).getLayers().get(0);
ioDeclaration = layer.getDefinition();
}
Collection<IODeclarationSymbol> ioDeclarations = node.getEnclosingScope().get().<IODeclarationSymbol>resolveMany(node.getName(), IODeclarationSymbol.KIND);
if (ioDeclaration == null){
if (ioDeclarations.isEmpty()){
Log.error("0" + ErrorCodes.UNKNOWN_IO + " Unknown input or output name. " +
"The input or output '" + node.getName() + "' does not exist"
, node.get_SourcePositionStart());
......
......@@ -22,7 +22,7 @@ package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.math.math._symboltable.expression.*;
import de.monticore.lang.monticar.cnnarch.helper.Calculator;
import de.monticore.lang.monticar.cnnarch.helper.ExpressionHelper;
import de.monticore.lang.monticar.cnnarch.helper.Utils;
import java.util.*;
......@@ -100,7 +100,7 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol {
protected void computeUnresolvableVariables(Set<VariableSymbol> unresolvableVariables, Set<VariableSymbol> allVariables) {
if (getMathExpression().isPresent()) {
for (MathExpressionSymbol exp : ExpressionHelper.createSubExpressionList(getMathExpression().get())) {
for (MathExpressionSymbol exp : Utils.createSubExpressionList(getMathExpression().get())) {
if (exp instanceof MathNameExpressionSymbol) {
String name = ((MathNameExpressionSymbol) exp).getNameToAccess();
Optional<VariableSymbol> variable = getEnclosingScope().resolve(name, VariableSymbol.KIND);
......@@ -164,7 +164,7 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol {
}
else {
Map<String, String> replacementMap = new HashMap<>();
for (MathExpressionSymbol exp : ExpressionHelper.createSubExpressionList(getMathExpression().get())) {
for (MathExpressionSymbol exp : Utils.createSubExpressionList(getMathExpression().get())) {
if (exp instanceof MathNameExpressionSymbol) {
String name = ((MathNameExpressionSymbol) exp).getNameToAccess();
VariableSymbol variable = (VariableSymbol) getEnclosingScope().resolve(name, VariableSymbol.KIND).get();
......@@ -174,7 +174,7 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol {
}
}
String resolvedString = ExpressionHelper.replace(getTextualRepresentation(), replacementMap);
String resolvedString = Utils.replace(getTextualRepresentation(), replacementMap);
return Calculator.getInstance().calculate(resolvedString);
}
}
......@@ -191,7 +191,7 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol {
public String getTextualRepresentation() {
if (isResolved()){
if (isTuple()){
return ExpressionHelper.createTupleTextualRepresentation(getTupleValues().get(), Object::toString);
return Utils.createTupleTextualRepresentation(getTupleValues().get(), Object::toString);
}
else {
return getValue().get().toString();
......
......@@ -22,9 +22,9 @@ package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.SymbolKind;
public class ShapeKind implements SymbolKind {
public class ArchTypeKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnnarch._symboltable.ShapeKind";
private static final String NAME = "de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeKind";
@Override
public String getName() {
......
......@@ -20,24 +20,40 @@
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.monticar.types2._ast.ASTElementType;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.MutableScope;
import java.util.*;
public class ShapeSymbol extends CommonSymbol {
public class ArchTypeSymbol extends CommonSymbol {
public static final ShapeKind KIND = new ShapeKind();
public static final ArchTypeKind KIND = new ArchTypeKind();
protected static final String DEFAULT_ELEMENT_TYPE = "Q(-oo:oo)";
private ASTElementType elementType;
private int channelIndex = -1;
private int heightIndex = -1;
private int widthIndex = -1;
private List<ArchSimpleExpressionSymbol> dimensions = new ArrayList<>();
public ShapeSymbol() {
public ArchTypeSymbol() {
super("", KIND);
ASTElementType elementType = new ASTElementType();
elementType.setTElementType(DEFAULT_ELEMENT_TYPE);
setElementType(elementType);
}
public ASTElementType getElementType() {
return elementType;
}
public void setElementType(ASTElementType elementType) {
this.elementType = elementType;
}
public int getHeightIndex() {
......@@ -178,6 +194,7 @@ public class ShapeSymbol extends CommonSymbol {
private int height = 1;
private int width = 1;
private int channels = 1;
private ASTElementType elementType = null;
public Builder height(int height){
this.height = height;
......@@ -191,13 +208,28 @@ public class ShapeSymbol extends CommonSymbol {
this.channels = channels;
return this;
}
public Builder elementType(ASTElementType elementType){
this.elementType = elementType;
return this;
}
public Builder elementType(String start, String end){
elementType = new ASTElementType();
elementType.setTElementType("Q(" + start + ":" + end +")");
return this;
}
public ShapeSymbol build(){
ShapeSymbol sym = new ShapeSymbol();
public ArchTypeSymbol build(){
ArchTypeSymbol sym = new ArchTypeSymbol();
sym.setChannelIndex(0);
sym.setHeightIndex(1);
sym.setWidthIndex(2);
sym.setDimensions(Arrays.asList(channels, height, width));
if (elementType == null){
elementType = new ASTElementType();
elementType.setTElementType(DEFAULT_ELEMENT_TYPE);
}
sym.setElementType(elementType);
return sym;
}
}
......
......@@ -98,6 +98,8 @@ public class ArchitectureSymbol extends ArchitectureSymbolTOP {
}
}
//todo: deep copy method for instances
public List<LayerSymbol> getFirstLayers(){
if (!getBody().isResolved()){
resolve();
......
......@@ -61,10 +61,6 @@ public class ArgumentSymbol extends CommonSymbol {
return rhs;
}
public Optional<Object> getValue(){
return getRhs().getValue();
}
protected void setRhs(ArchExpressionSymbol rhs) {
if (getName().equals(AllPredefinedVariables.FOR_NAME)
&& rhs instanceof ArchSimpleExpressionSymbol
......@@ -192,6 +188,7 @@ public class ArgumentSymbol extends CommonSymbol {
public static class Builder{
private String name;
private VariableSymbol parameter;
private ArchExpressionSymbol value;
......@@ -200,19 +197,31 @@ public class ArgumentSymbol extends CommonSymbol {
return this;
}
public Builder parameter(String name) {
this.name = name;
return this;
}
public Builder value(ArchExpressionSymbol value) {
this.value = value;
return this;
}
public ArgumentSymbol build(){
if (parameter == null){
if (parameter == null && name == null){
throw new IllegalStateException("Missing parameter for ArgumentSymbol");
}
ArgumentSymbol sym = new ArgumentSymbol(parameter.getName());
sym.setParameter(parameter);
sym.setRhs(value);
return sym;
if (parameter == null){
ArgumentSymbol sym = new ArgumentSymbol(name);
sym.setRhs(value);
return sym;
}
else {
ArgumentSymbol sym = new ArgumentSymbol(parameter.getName());
sym.setParameter(parameter);
sym.setRhs(value);
return sym;
}
}
}
}
......@@ -30,7 +30,6 @@ import de.monticore.lang.monticar.cnnarch._visitor.CNNArchVisitor;
import de.monticore.lang.monticar.cnnarch._visitor.CommonCNNArchDelegatorVisitor;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedMethods;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables;
import de.monticore.lang.monticar.types2._ast.ASTType;
import de.monticore.symboltable.*;
import de.se_rwth.commons.logging.Log;
......@@ -173,9 +172,8 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
if (ast.getArrayDeclaration().isPresent()){
iODeclaration.setArrayLength(ast.getArrayDeclaration().get().getIntLiteral().getNumber().get().getDividend().intValue());
}
iODeclaration.setShape((ShapeSymbol) ast.getType().getShape().getSymbol().get());
iODeclaration.setType((ArchTypeSymbol) ast.getType().getSymbol().get());
iODeclaration.setInput(ast.getIn().isPresent());
iODeclaration.setType(ast.getType().getElementType());
if (iODeclaration.isInput()){
inputs.add(iODeclaration);
}
......@@ -185,23 +183,19 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
@Override
public void endVisit(ASTType node) {
//todo
}
@Override
public void visit(ASTShape ast) {
ShapeSymbol sym = new ShapeSymbol();
public void visit(ASTArchType ast) {
ArchTypeSymbol sym = new ArchTypeSymbol();
addToScopeAndLinkWithNode(sym, ast);
}
@Override
public void endVisit(ASTShape node) {
ShapeSymbol sym = (ShapeSymbol) node.getSymbol().get();
public void endVisit(ASTArchType node) {
ArchTypeSymbol sym = (ArchTypeSymbol) node.getSymbol().get();
List<ASTDimensionArgument> astDimensions = node.getShape().getDimensions();
List<ArchSimpleExpressionSymbol> dimensionList = new ArrayList<>(3);
for (int i = 0; i < node.getDimensions().size(); i++){
ASTDimensionArgument dimensionArg = node.getDimensions().get(i);
for (int i = 0; i < astDimensions.size(); i++){
ASTDimensionArgument dimensionArg = astDimensions.get(i);
if (dimensionArg.getHeight().isPresent()){
sym.setHeightIndex(i);
ArchSimpleExpressionSymbol exp = (ArchSimpleExpressionSymbol) dimensionArg.getHeight().get().getSymbol().get();
......@@ -219,6 +213,7 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
}
sym.setDimensionSymbols(dimensionList);
sym.setElementType(node.getElementType());
addToScopeAndLinkWithNode(sym, node);
}
......@@ -405,26 +400,21 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
isInput = ioDeclaration.isInput();
}
if (!node.getIndex().isPresent() && arrayLength > 1 && isInput){
List<LayerSymbol> ioLayers = new ArrayList<>(arrayLength);
IOLayerSymbol ioLayer;
for (int i = 0; i < arrayLength; i++){
ioLayer = new IOLayerSymbol(node.getName());
ioLayer.setArrayAccess(i);
ioLayers.add(ioLayer);
}
if (!node.getIndex().isPresent() && arrayLength > 1){
//transform io array into parallel composite
List<LayerSymbol> parallelLayers = createSerialIOLayerPart(node, arrayLength, isInput);
CompositeLayerSymbol composite = new CompositeLayerSymbol.Builder()
.parallel(true)
.layers(ioLayers)
.layers(parallelLayers)
.build();
addToScopeAndLinkWithNode(composite, node);
for (LayerSymbol layer : ioLayers){
addToScope(layer);
for (LayerSymbol layer : parallelLayers){
layer.putInScope(composite.getSpannedScope());
layer.setAstNode(node);
}
}
else {
IOLayerSymbol ioLayer = new IOLayerSymbol(node.getName());
......@@ -432,6 +422,40 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
}
private List<LayerSymbol> createSerialIOLayerPart(ASTIOLayer node, int arrayLength, boolean isInput){
List<LayerSymbol> parallelLayers = new ArrayList<>(arrayLength);
if (isInput){
for (int i = 0; i < arrayLength; i++){
IOLayerSymbol ioLayer = new IOLayerSymbol(node.getName());
ioLayer.setArrayAccess(i);
parallelLayers.add(ioLayer);
}
}
else {
for (int i = 0; i < arrayLength; i++){
CompositeLayerSymbol serialComposite = new CompositeLayerSymbol();
serialComposite.setParallel(false);
IOLayerSymbol ioLayer = new IOLayerSymbol(node.getName());
ioLayer.setArrayAccess(i);
ioLayer.setAstNode(node);
MethodLayerSymbol getLayer = new MethodLayerSymbol(AllPredefinedMethods.GET_NAME);
getLayer.setArguments(Collections.singletonList(
new ArgumentSymbol.Builder()
.parameter(AllPredefinedMethods.INDEX_NAME)
.value(ArchSimpleExpressionSymbol.of(i))
.build()));
getLayer.setAstNode(node);
serialComposite.setLayers(Arrays.asList(getLayer, ioLayer));
parallelLayers.add(serialComposite);
}
}
return parallelLayers;
}
@Override
public void endVisit(ASTIOLayer node) {
if (node.getIndex().isPresent()){
......
......@@ -177,44 +177,49 @@ public class CompositeLayerSymbol extends LayerSymbol {
}
@Override
public List<ShapeSymbol> computeOutputShapes() {
public List<ArchTypeSymbol> computeOutputTypes() {
if (getLayers().isEmpty()){
return getInputLayer().get().getOutputShapes();
if (getInputLayer().isPresent()){
return getInputLayer().get().getOutputTypes();
}
else {
return Collections.emptyList();
}
}
else {
if (isParallel()){
List<ShapeSymbol> outputShapes = new ArrayList<>(getLayers().size());
List<ArchTypeSymbol> outputShapes = new ArrayList<>(getLayers().size());
for (LayerSymbol layer : getLayers()){
if (layer.getOutputShapes().size() != 0){
outputShapes.add(layer.getOutputShapes().get(0));
if (layer.getOutputTypes().size() != 0){
outputShapes.add(layer.getOutputTypes().get(0));
}
}
return outputShapes;
}
else {
for (LayerSymbol layer : getLayers()){
layer.getOutputShapes();
layer.getOutputTypes();
}
return getLayers().get(getLayers().size() - 1).getOutputShapes();
return getLayers().get(getLayers().size() - 1).getOutputTypes();
}
}
}
@Override
public void checkInputAndOutput() {
public void checkInput() {
if (!getLayers().isEmpty()){
for (LayerSymbol layer : getLayers()){
layer.checkInput();
}
if (isParallel()){
for (LayerSymbol layer : getLayers()){
if (layer.getOutputShapes().size() > 1){
Log.error("0" + ErrorCodes.MISSING_MERGE + " Missing merge layer (Add(), Concatenate, [i]). " +
if (layer.getOutputTypes().size() > 1){
Log.error("0" + ErrorCodes.MISSING_MERGE + " Missing merge layer (Add(), Concatenate() or [i]). " +
"Each stream at the end of a parallel layer can only have one output stream. "
, getSourcePosition());
}
}
}
for (LayerSymbol layer : getLayers()){
layer.checkInputAndOutput();
}
}
}
......
......@@ -23,7 +23,6 @@
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.monticar.types2._ast.ASTElementType;
import de.monticore.symboltable.CommonSymbol;
import java.util.HashSet;
......@@ -33,8 +32,7 @@ public class IODeclarationSymbol extends CommonSymbol {
public static final IODeclarationKind KIND = new IODeclarationKind();
private ASTElementType type;
private ShapeSymbol shape;
private ArchTypeSymbol type;
private boolean input; //true->input, false->output
private int arrayLength = 1;
private Set<IOLayerSymbol> connectedLayers = new HashSet<>();
......@@ -44,22 +42,14 @@ public class IODeclarationSymbol extends CommonSymbol {
super(name, KIND);
}
public ASTElementType getType() {
public ArchTypeSymbol getType() {
return type;
}
protected void setType(ASTElementType type) {
protected void setType(ArchTypeSymbol type) {
this.type = type;
}
public ShapeSymbol getShape() {
return shape;
}
protected void setShape(ShapeSymbol shape) {
this.shape = shape;
}
public Set<IOLayerSymbol> getConnectedLayers() {
return connectedLayers;
}
......@@ -87,22 +77,16 @@ public class IODeclarationSymbol extends CommonSymbol {
public static class Builder{
private ASTElementType type;
private ShapeSymbol shape;
private ArchTypeSymbol type;
private boolean input; //true->input, false->output
private int arrayLength = 0;
private String name;
public Builder type(ASTElementType type){
public Builder type(ArchTypeSymbol type){
this.type = type;
return this;
}
public Builder shape(ShapeSymbol shape){
this.shape = shape;
return this;
}
public Builder input(boolean input){