From 6096077b4ceda1167f0e94375e5a39381afaff58 Mon Sep 17 00:00:00 2001 From: Thomas Michael Timmermanns Date: Wed, 3 Jan 2018 00:56:37 +0100 Subject: [PATCH] Completed resolve mechanism and output shape computation. --- pom.xml | 2 +- .../de/monticore/lang/monticar/CNNArch.mc4 | 6 +- .../lang/monticar/cnnarch/Constraint.java | 42 ++++- .../monticar/cnnarch/PredefinedMethods.java | 44 +++-- .../monticar/cnnarch/PredefinedVariables.java | 8 + .../_symboltable/ArchExpressionSymbol.java | 56 +++--- .../ArchRangeExpressionSymbol.java | 28 +-- .../ArchSequenceExpressionSymbol.java | 29 +-- .../ArchSimpleExpressionSymbol.java | 122 ++++++++----- .../_symboltable/ArchitectureSymbol.java | 16 ++ .../cnnarch/_symboltable/ArgumentSymbol.java | 30 +++- .../CNNArchSymbolTableCreator.java | 29 +-- .../_symboltable/CompositeLayerSymbol.java | 35 ++-- .../cnnarch/_symboltable/DimensionKind.java | 39 ---- .../cnnarch/_symboltable/DimensionSymbol.java | 71 -------- .../_symboltable/IODeclarationSymbol.java | 2 +- .../cnnarch/_symboltable/IOLayerSymbol.java | 58 +++--- .../cnnarch/_symboltable/LayerSymbol.java | 37 ++-- .../_symboltable/MethodDeclarationSymbol.java | 56 +++--- .../_symboltable/MethodLayerSymbol.java | 170 ++++++++++-------- .../cnnarch/_symboltable/ShapeSymbol.java | 156 +++++++++------- .../cnnarch/_symboltable/VariableSymbol.java | 5 +- .../lang/monticar/cnnarch/SymtabTest.java | 39 +++- .../monticar/cnnarch/cocos/AllCoCoTest.java | 4 + src/test/resources/architectures/Alexnet.cnna | 4 +- .../resources/architectures/Alexnet_alt.cnna | 8 +- .../resources/architectures/ResNeXt50.cnna | 4 +- .../architectures/ThreeInputCNN_M14.cnna | 8 +- .../resources/valid_tests/Fixed_Alexnet.cnna | 10 +- .../valid_tests/Fixed_ResNeXt50.cnna | 22 +-- .../valid_tests/Fixed_ThreeInputCNN_M14.cnna | 27 +++ 31 files changed, 666 insertions(+), 501 deletions(-) delete mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/DimensionKind.java delete mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/DimensionSymbol.java create mode 100644 src/test/resources/valid_tests/Fixed_ThreeInputCNN_M14.cnna diff --git a/pom.xml b/pom.xml index a3b022a..73446e5 100644 --- a/pom.xml +++ b/pom.xml @@ -30,7 +30,7 @@ de.monticore.lang.monticar cnn-arch - 0.0.2-SNAPSHOT + 0.1.0-SNAPSHOT diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNArch.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNArch.mc4 index bb28851..8731064 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNArch.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNArch.mc4 @@ -57,9 +57,11 @@ grammar CNNArch extends de.monticore.lang.math.Math { ArchSerialSequence = serialValues:(ArchSimpleExpression || "->")+; - ArchValueRange implements ArchValueSequence = "[" start:ArchSimpleExpression + ArchValueRange implements ArchValueSequence = start:ArchSimpleExpression (serial:"->" | parallel:"|") - ":" end:ArchSimpleExpression "]"; + ".." + (serial2:"->" | parallel2:"|") + end:ArchSimpleExpression; ArchSimpleExpression = (arithmeticExpression:MathArithmeticExpression diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/Constraint.java b/src/main/java/de/monticore/lang/monticar/cnnarch/Constraint.java index 9415e29..f2f0841 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/Constraint.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/Constraint.java @@ -51,28 +51,60 @@ public enum Constraint { INTEGER_TUPLE { @Override public boolean check(ArchSimpleExpressionSymbol exp) { - boolean res = false; - if (exp.isTuple()){ - //todo - } - return false; + return exp.isIntTuple().get(); } }, POSITIVE { @Override public boolean check(ArchSimpleExpressionSymbol exp) { + if (exp.getDoubleValue().isPresent()){ + return exp.getDoubleValue().get() > 0; + } + else if (exp.getDoubleTupleValues().isPresent()){ + boolean isPositive = true; + for (double value : exp.getDoubleTupleValues().get()){ + if (value <= 0){ + isPositive = false; + } + } + return isPositive; + } return false; } }, NON_NEGATIVE { @Override public boolean check(ArchSimpleExpressionSymbol exp) { + if (exp.getDoubleValue().isPresent()){ + return exp.getDoubleValue().get() >= 0; + } + else if (exp.getDoubleTupleValues().isPresent()){ + boolean isPositive = true; + for (double value : exp.getDoubleTupleValues().get()){ + if (value < 0){ + isPositive = false; + } + } + return isPositive; + } return false; } }, BETWEEN_ZERO_AND_ONE { @Override public boolean check(ArchSimpleExpressionSymbol exp) { + if (exp.getDoubleValue().isPresent()){ + return exp.getDoubleValue().get() >= 0 && exp.getDoubleValue().get() <= 1; + } + else if (exp.getDoubleTupleValues().isPresent()){ + boolean isPositive = true; + for (double value : exp.getDoubleTupleValues().get()){ + if (value < 0 || value > 1){ + isPositive = false; + } + } + return isPositive; + } return false; } }; diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/PredefinedMethods.java b/src/main/java/de/monticore/lang/monticar/cnnarch/PredefinedMethods.java index a2524df..dab8b7e 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/PredefinedMethods.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/PredefinedMethods.java @@ -41,7 +41,7 @@ public class PredefinedMethods { public static final String AVG_POOLING_NAME = "AveragePooling"; public static final String LRN_NAME = "Lrn"; public static final String BATCHNORM_NAME = "BatchNorm"; - public static final String SPLIT_NAME = "Split"; + public static final String SPLIT_NAME = "SplitData"; public static final String GET_NAME = "Get"; public static final String ADD_NAME = "Add"; public static final String CONCATENATE_NAME = "Concatenate"; @@ -323,22 +323,36 @@ public class PredefinedMethods { } private static List strideShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method, int channels) { - int strideHeight = method.getIntTupleValue("stride").get().get(0); - int strideWidth = method.getIntTupleValue("stride").get().get(1); - int kernelHeight = method.getIntTupleValue("kernel").get().get(0); - int kernelWidth = method.getIntTupleValue("kernel").get().get(1); - int inputHeight = inputShape.getHeight().get(); - int inputWidth = inputShape.getWidth().get(); + Optional optGlobal = method.getBooleanValue("global"); + if (optGlobal.isPresent() && optGlobal.get()){ + return Collections.singletonList(new ShapeSymbol.Builder() + .height(1) + .width(1) + .channels(channels) + .build()); + } + else{ + int strideHeight = method.getIntTupleValue("stride").get().get(0); + int strideWidth = method.getIntTupleValue("stride").get().get(1); + int kernelHeight = method.getIntTupleValue("kernel").get().get(0); + int kernelWidth = method.getIntTupleValue("kernel").get().get(1); + int inputHeight = inputShape.getHeight().get(); + int inputWidth = inputShape.getWidth().get(); - //assume padding with border_mode='same' - int outputWidth = 1 + ((inputWidth - kernelWidth + strideWidth - 1) / strideWidth); - int outputHeight = 1 + ((inputHeight - kernelHeight + strideHeight - 1) / strideHeight); + //assume padding with border_mode='same' + int outputWidth = inputWidth / strideWidth; + int outputHeight = inputHeight / strideHeight; - return Collections.singletonList(new ShapeSymbol.Builder() - .height(outputHeight) - .width(outputWidth) - .channels(channels) - .build()); + //border_mode=valid + //int outputWidth = 1 + Math.max(0, ((inputWidth - kernelWidth + strideWidth - 1) / strideWidth)); + //int outputHeight = 1 + Math.max(0, ((inputHeight - kernelHeight + strideHeight - 1) / strideHeight)); + + return Collections.singletonList(new ShapeSymbol.Builder() + .height(outputHeight) + .width(outputWidth) + .channels(channels) + .build()); + } } private static List splitShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method) { diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/PredefinedVariables.java b/src/main/java/de/monticore/lang/monticar/cnnarch/PredefinedVariables.java index 24eeb91..b50ab09 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/PredefinedVariables.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/PredefinedVariables.java @@ -27,6 +27,7 @@ public class PredefinedVariables { public static final String IF_NAME = "_if"; public static final String FOR_NAME = "_for"; + public static final String CARDINALITY_NAME = "_cardinality"; public static final String TRUE_NAME = "true"; public static final String FALSE_NAME = "false"; @@ -45,6 +46,13 @@ public class PredefinedVariables { .build(); } + public static VariableSymbol createCardinalityParameter(){ + return new VariableSymbol.Builder() + .name(CARDINALITY_NAME) + .defaultValue(1) + .build(); + } + //necessary because true is currently only a name in MontiMath and it needs to be evaluated at compile time for this language public static VariableSymbol createTrueConstant(){ return new VariableSymbol.Builder() diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchExpressionSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchExpressionSymbol.java index 0f2c93f..b86c8c5 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchExpressionSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchExpressionSymbol.java @@ -22,9 +22,7 @@ package de.monticore.lang.monticar.cnnarch._symboltable; import de.monticore.symboltable.CommonSymbol; import de.monticore.symboltable.MutableScope; -import de.monticore.symboltable.Scope; import de.monticore.symboltable.Symbol; -import de.se_rwth.commons.logging.Log; import java.util.*; @@ -32,7 +30,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { public static final ArchExpressionKind KIND = new ArchExpressionKind(); - private Set unresolvableNames = null; + private Set unresolvableVariables = null; public ArchExpressionSymbol() { super("", KIND); @@ -40,25 +38,28 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { protected Boolean isResolvable(){ - Set set = getUnresolvableNames(); + Set set = getUnresolvableVariables(); return set != null && set.isEmpty(); } - public Set getUnresolvableNames() { - if (unresolvableNames == null){ - checkIfResolvable(); + public Set getUnresolvableVariables() { + if (unresolvableVariables == null){ + checkIfResolvable(new HashSet<>()); } - return unresolvableNames; + return unresolvableVariables; } - protected void setUnresolvableNames(Set unresolvableNames){ - this.unresolvableNames = unresolvableNames; + protected void setUnresolvableVariables(Set unresolvableVariables){ + this.unresolvableVariables = unresolvableVariables; } - public void checkIfResolvable(){ - setUnresolvableNames(computeUnresolvableNames()); + public void checkIfResolvable(Set seenVariables){ + Set unresolvableVariables = new HashSet<>(); + computeUnresolvableVariables(unresolvableVariables, seenVariables); + setUnresolvableVariables(unresolvableVariables); } + /** * Checks whether the value is a boolean. If true getValue() will return a Boolean if present. * @@ -99,21 +100,21 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { public Optional isIntTuple(){ if (getValue().isPresent()){ - return Optional.of(getIntTupleValue().isPresent()); + return Optional.of(getIntTupleValues().isPresent()); } return Optional.empty(); } public Optional isNumberTuple(){ if (getValue().isPresent()){ - return Optional.of(getDoubleTupleValue().isPresent()); + return Optional.of(getDoubleTupleValues().isPresent()); } return Optional.empty(); } public Optional isBooleanTuple(){ if (getValue().isPresent()){ - return Optional.of(getBooleanTupleValue().isPresent()); + return Optional.of(getBooleanTupleValues().isPresent()); } return Optional.empty(); } @@ -194,8 +195,8 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { return Optional.empty(); } - public Optional> getIntTupleValue(){ - Optional> optValue = getTupleValue(); + public Optional> getIntTupleValues(){ + Optional> optValue = getTupleValues(); if (optValue.isPresent()){ List list = new ArrayList<>(); for (Object value : optValue.get()) { @@ -211,8 +212,8 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { return Optional.empty(); } - public Optional> getDoubleTupleValue() { - Optional> optValue = getTupleValue(); + public Optional> getDoubleTupleValues() { + Optional> optValue = getTupleValues(); if (optValue.isPresent()){ List list = new ArrayList<>(); for (Object value : optValue.get()) { @@ -231,8 +232,8 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { return Optional.empty(); } - public Optional> getBooleanTupleValue() { - Optional> optValue = getTupleValue(); + public Optional> getBooleanTupleValues() { + Optional> optValue = getTupleValues(); if (optValue.isPresent()){ List list = new ArrayList<>(); for (Object value : optValue.get()) { @@ -248,9 +249,10 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { return Optional.empty(); } - public Optional> getTupleValue(){ + public Optional> getTupleValues(){ if (getValue().isPresent()){ - if (isTuple()){ + Optional optValue = getValue(); + if (optValue.isPresent() && (optValue.get() instanceof List)){ @SuppressWarnings("unchecked") List list = (List) getValue().get(); return Optional.of(list); @@ -300,7 +302,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { public void resolveOrError(){ resolve(); if (!isResolved()){ - throw new IllegalStateException("The following names could not be resolved: " + getUnresolvableNames()); + throw new IllegalStateException("The following names could not be resolved: " + getUnresolvableVariables()); } } @@ -315,13 +317,15 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { */ abstract public Optional getValue(); + abstract public void reset(); + /** * Replaces all variable names in this values expression if possible. * The values of the variables depend on the current scope. The replacement is irreversible if successful. * * @return returns a set of all names which could not be resolved. */ - abstract public Set resolve(); + abstract public Set resolve(); /** * @return returns a optional of a list(parallel) of lists(serial) of simple expressions in this sequence. @@ -330,7 +334,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { */ abstract public Optional>> getElements(); - abstract protected Set computeUnresolvableNames(); + abstract protected void computeUnresolvableVariables(Set unresolvableVariables, Set allVariables); /** * @return returns true if the expression is resolved. diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchRangeExpressionSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchRangeExpressionSymbol.java index 5de3449..f4201fe 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchRangeExpressionSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchRangeExpressionSymbol.java @@ -21,7 +21,6 @@ package de.monticore.lang.monticar.cnnarch._symboltable; import de.monticore.symboltable.MutableScope; -import de.monticore.symboltable.Scope; import java.util.*; import java.util.stream.Collectors; @@ -63,6 +62,13 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression { this.parallel = parallel; } + @Override + public void reset() { + getStartSymbol().reset(); + getEndSymbol().reset(); + setUnresolvableVariables(null); + } + @Override public boolean isParallelSequence() { return isParallel(); @@ -88,16 +94,15 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression { }*/ @Override - public Set resolve() { + public Set resolve() { if (!isResolved()){ - checkIfResolvable(); if (isResolvable()){ getStartSymbol().resolveOrError(); getEndSymbol().resolveOrError(); } } - return getUnresolvableNames(); + return getUnresolvableVariables(); } @Override @@ -142,11 +147,11 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression { } @Override - protected Set computeUnresolvableNames() { - Set unresolvableNames = new HashSet<>(); - unresolvableNames.addAll(getStartSymbol().computeUnresolvableNames()); - unresolvableNames.addAll(getEndSymbol().computeUnresolvableNames()); - return unresolvableNames; + protected void computeUnresolvableVariables(Set unresolvableVariables, Set allVariables) { + getStartSymbol().checkIfResolvable(allVariables); + unresolvableVariables.addAll(getStartSymbol().getUnresolvableVariables()); + getEndSymbol().checkIfResolvable(allVariables); + unresolvableVariables.addAll(getEndSymbol().getUnresolvableVariables()); } public ArchRangeExpressionSymbol copy(){ @@ -154,7 +159,7 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression { copy.setParallel(isParallel()); copy.setStartSymbol(getStartSymbol().copy()); copy.setEndSymbol(getEndSymbol().copy()); - copy.setUnresolvableNames(getUnresolvableNames()); + copy.setUnresolvableVariables(getUnresolvableVariables()); return copy; } @@ -165,10 +170,11 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression { getEndSymbol().putInScope(scope); } - public static ArchRangeExpressionSymbol of(ArchSimpleExpressionSymbol start, ArchSimpleExpressionSymbol end){ + public static ArchRangeExpressionSymbol of(ArchSimpleExpressionSymbol start, ArchSimpleExpressionSymbol end, boolean parallel){ ArchRangeExpressionSymbol sym = new ArchRangeExpressionSymbol(); sym.setStartSymbol(start); sym.setEndSymbol(end); + sym.setParallel(parallel); return sym; } } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchSequenceExpressionSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchSequenceExpressionSymbol.java index 37b69b7..661fc5e 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchSequenceExpressionSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchSequenceExpressionSymbol.java @@ -21,7 +21,6 @@ package de.monticore.lang.monticar.cnnarch._symboltable; import de.monticore.symboltable.MutableScope; -import de.monticore.symboltable.Scope; import java.util.*; @@ -47,6 +46,16 @@ public class ArchSequenceExpressionSymbol extends ArchAbstractSequenceExpression this.elements = elements; } + @Override + public void reset() { + for (List serialElements : _getElements()){ + for (ArchSimpleExpressionSymbol element : serialElements){ + element.reset(); + } + } + setUnresolvableVariables(null); + } + @Override public boolean isSerialSequence(){ boolean isSerial = !isParallelSequence(); @@ -64,10 +73,9 @@ public class ArchSequenceExpressionSymbol extends ArchAbstractSequenceExpression } @Override - public Set resolve() { - if (!isResolved()){ - checkIfResolvable(); - if (isResolvable()){ + public Set resolve() { + if (!isResolved()) { + if (isResolvable()) { for (List serialList : _getElements()) { for (ArchSimpleExpressionSymbol element : serialList) { @@ -76,7 +84,7 @@ public class ArchSequenceExpressionSymbol extends ArchAbstractSequenceExpression } } } - return getUnresolvableNames(); + return getUnresolvableVariables(); } @Override @@ -93,14 +101,13 @@ public class ArchSequenceExpressionSymbol extends ArchAbstractSequenceExpression } @Override - protected Set computeUnresolvableNames() { - Set unresolvableNames = new HashSet<>(); + protected void computeUnresolvableVariables(Set unresolvableVariables, Set allVariables) { for (List serialElements : _getElements()){ for (ArchSimpleExpressionSymbol element : serialElements){ - unresolvableNames.addAll(element.computeUnresolvableNames()); + element.checkIfResolvable(allVariables); + unresolvableVariables.addAll(element.getUnresolvableVariables()); } } - return unresolvableNames; } public ArchSequenceExpressionSymbol copy(){ @@ -114,7 +121,7 @@ public class ArchSequenceExpressionSymbol extends ArchAbstractSequenceExpression elementsCopy.add(serialListCopy); } copy.setElements(getElements().get()); - copy.setUnresolvableNames(getUnresolvableNames()); + copy.setUnresolvableVariables(getUnresolvableVariables()); return copy; } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchSimpleExpressionSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchSimpleExpressionSymbol.java index 0f93985..4195149 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchSimpleExpressionSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchSimpleExpressionSymbol.java @@ -53,6 +53,14 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol implements this.value = value; } + @Override + public void reset(){ + if (getMathExpression().isPresent()){ + setValue(null); + setUnresolvableVariables(null); + } + } + @Override public boolean isSimpleValue() { return true; @@ -60,7 +68,7 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol implements @Override public boolean isBoolean() { - if (getMathExpression().isPresent()){ + if (getMathExpression().isPresent() && !(getMathExpression().get() instanceof MathNameExpressionSymbol)){ return getMathExpression().get() instanceof MathCompareExpressionSymbol; } else { @@ -70,7 +78,7 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol implements @Override public boolean isNumber() { - if (getMathExpression().isPresent()){ + if (getMathExpression().isPresent() && !(getMathExpression().get() instanceof MathNameExpressionSymbol)){ return getMathExpression().get() instanceof MathArithmeticExpressionSymbol; } else { @@ -80,37 +88,31 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol implements @Override public boolean isTuple() { - if (getMathExpression().isPresent()){ + if (getMathExpression().isPresent() && !(getMathExpression().get() instanceof MathNameExpressionSymbol)){ return getMathExpression().get() instanceof TupleExpressionSymbol; } else { - return getValue().get() instanceof List; + return getTupleValues().isPresent(); } } - @Override - protected Set computeUnresolvableNames() { - Set unresolvableNames = new HashSet<>(); - Set allNames = new HashSet<>(); - computeUnresolvableNames(unresolvableNames, allNames); - return unresolvableNames; - } - - protected void computeUnresolvableNames(Set unresolvableNames, Set allNames) { + protected void computeUnresolvableVariables(Set unresolvableVariables, Set allVariables) { if (getMathExpression().isPresent()) { for (MathExpressionSymbol exp : ExpressionHelper.createSubExpressionList(getMathExpression().get())) { if (exp instanceof MathNameExpressionSymbol) { String name = ((MathNameExpressionSymbol) exp).getNameToAccess(); - if (!allNames.contains(name)) { - allNames.add(name); - Optional variable = getEnclosingScope().resolve(name, VariableSymbol.KIND); - if (variable.isPresent() && !variable.get().getExpression().isResolved()) { - if (variable.get().hasValue()) { - variable.get().getExpression().computeUnresolvableNames(unresolvableNames, allNames); - } else { - unresolvableNames.add(name); + Optional variable = getEnclosingScope().resolve(name, VariableSymbol.KIND); + //todo: implement coco to check isPresent() + if (!allVariables.contains(variable.get())) { + allVariables.add(variable.get()); + if (variable.get().hasValue()) { + if (!variable.get().getExpression().isResolved()) { + variable.get().getExpression().computeUnresolvableVariables(unresolvableVariables, allVariables); } } + else { + unresolvableVariables.add(variable.get()); + } } } } @@ -118,49 +120,68 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol implements } @Override - public Set resolve() { - checkIfResolvable(); - if (getMathExpression().isPresent() && isResolvable()) { - Object value; - if (isTuple()){ - TupleExpressionSymbol tuple = (TupleExpressionSymbol) getMathExpression().get(); - List tupleValues = new ArrayList<>(tuple.getExpressions().size()); - for (MathExpressionSymbol exp : tuple.getExpressions()){ - tupleValues.add(computeValue()); - } - value = tupleValues; - } - else { - value = computeValue(); + public Set resolve() { + if (!isResolved()) { + if (getMathExpression().isPresent() && isResolvable()) { + Object value = computeValue(); + setValue(value); } - setValue(value); } - return getUnresolvableNames(); + return getUnresolvableVariables(); } private Object computeValue(){ - Map replacementMap = new HashMap<>(); - for (MathExpressionSymbol exp : ExpressionHelper.createSubExpressionList(getMathExpression().get())) { - if (exp instanceof MathNameExpressionSymbol) { - String name = ((MathNameExpressionSymbol) exp).getNameToAccess(); - VariableSymbol variable = (VariableSymbol) getEnclosingScope().resolve(name, VariableSymbol.KIND).get(); - if (!variable.getExpression().isResolved()) { - variable.getExpression().resolveOrError(); + if (getMathExpression().get() instanceof MathNameExpressionSymbol){ + return computeValue((MathNameExpressionSymbol) getMathExpression().get()); + } + else if (getMathExpression().get() instanceof TupleExpressionSymbol){ + Map replacementMap = new HashMap<>(); + List valueList = new ArrayList<>(); + TupleExpressionSymbol tuple = (TupleExpressionSymbol) getMathExpression().get(); + for (MathExpressionSymbol mathExp : tuple.getExpressions()){ + if (mathExp instanceof MathNameExpressionSymbol){ + valueList.add(computeValue((MathNameExpressionSymbol) mathExp)); } + else { + ArchSimpleExpressionSymbol temp = ArchSimpleExpressionSymbol.of(mathExp); + temp.setEnclosingScope(getEnclosingScope().getAsMutableScope()); + temp.resolveOrError(); + valueList.add(temp.getValue().get()); + getEnclosingScope().getAsMutableScope().remove(temp); + } + } + return valueList; + } + else { + Map replacementMap = new HashMap<>(); + for (MathExpressionSymbol exp : ExpressionHelper.createSubExpressionList(getMathExpression().get())) { + if (exp instanceof MathNameExpressionSymbol) { + String name = ((MathNameExpressionSymbol) exp).getNameToAccess(); + VariableSymbol variable = (VariableSymbol) getEnclosingScope().resolve(name, VariableSymbol.KIND).get(); + variable.getExpression().resolveOrError(); - replacementMap.put(name, variable.getExpression().getTextualRepresentation()); + replacementMap.put(name, variable.getExpression().getTextualRepresentation()); + } } + + String resolvedString = ExpressionHelper.replace(getTextualRepresentation(), replacementMap); + return Calculator.getInstance().calculate(resolvedString); } + } + + private Object computeValue(MathNameExpressionSymbol mathExpression){ + String name = ((MathNameExpressionSymbol) mathExpression).getNameToAccess(); + VariableSymbol variable = (VariableSymbol) getEnclosingScope().resolve(name, VariableSymbol.KIND).get(); + variable.getExpression().resolveOrError(); - String resolvedString = ExpressionHelper.replace(getTextualRepresentation(), replacementMap); - return Calculator.getInstance().calculate(resolvedString); + return variable.getExpression().getValue().get(); } @Override public String getTextualRepresentation() { if (isResolved()){ if (isTuple()){ - return ExpressionHelper.createTupleTextualRepresentation(getTupleValue().get(), Object::toString); + return ExpressionHelper.createTupleTextualRepresentation(getTupleValues().get(), Object::toString); } else { return getValue().get().toString(); @@ -184,8 +205,11 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol implements public ArchSimpleExpressionSymbol copy(){ ArchSimpleExpressionSymbol copy = new ArchSimpleExpressionSymbol(); //copy.setMathExpression(mathExpression); + if (!getValue().isPresent()){ + throw new IllegalStateException(); + } copy.setValue(value); - copy.setUnresolvableNames(getUnresolvableNames()); + copy.setUnresolvableVariables(getUnresolvableVariables()); return copy; } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java index 7c86518..72bdb16 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArchitectureSymbol.java @@ -23,7 +23,11 @@ package de.monticore.lang.monticar.cnnarch._symboltable; +import de.monticore.symboltable.SymbolKind; + +import java.util.HashSet; import java.util.List; +import java.util.Set; public class ArchitectureSymbol extends ArchitectureSymbolTOP { @@ -64,5 +68,17 @@ public class ArchitectureSymbol extends ArchitectureSymbolTOP { this.outputs = outputs; } + public Set resolve(){ + getBody().checkIfResolvable(); + Set set = getBody().resolve(); + return set; + } + public boolean isResolved(){ + return getBody().isResolved(); + } + + public Set getUnresolvableVariables(){ + return getBody().getUnresolvableVariables(); + } } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArgumentSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArgumentSymbol.java index 67b12bd..6e828dc 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArgumentSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ArgumentSymbol.java @@ -20,6 +20,7 @@ */ package de.monticore.lang.monticar.cnnarch._symboltable; +import de.monticore.lang.monticar.cnnarch.PredefinedVariables; import de.monticore.symboltable.CommonSymbol; import de.monticore.symboltable.MutableScope; import de.monticore.symboltable.Symbol; @@ -57,13 +58,31 @@ public class ArgumentSymbol extends CommonSymbol { } protected void setRhs(ArchExpressionSymbol rhs) { - this.rhs = rhs; + if (getName().equals(PredefinedVariables.FOR_NAME) + && rhs instanceof ArchSimpleExpressionSymbol + && (!rhs.getValue().isPresent() || !rhs.getValue().get().equals(1))){ + this.rhs = ArchRangeExpressionSymbol.of( + ArchSimpleExpressionSymbol.of(1), + (ArchSimpleExpressionSymbol) rhs, + false); + } + else if (getName().equals(PredefinedVariables.CARDINALITY_NAME) + && rhs instanceof ArchSimpleExpressionSymbol + && (!rhs.getValue().isPresent() || !rhs.getValue().get().equals(1))) { + this.rhs = ArchRangeExpressionSymbol.of( + ArchSimpleExpressionSymbol.of(1), + (ArchSimpleExpressionSymbol) rhs, + true); + } + else { + this.rhs = rhs; + } } //do not call if value is a sequence public void set(){ if (getRhs().isSimpleValue()){ - getParameter().setExpression((ArchSimpleExpressionSymbol) getRhs()); + getParameter().setExpression((ArchSimpleExpressionSymbol) getRhs().copy()); } else { throw new IllegalStateException("The value of the parameter is set to a sequence. This should never happen."); @@ -80,7 +99,12 @@ public class ArgumentSymbol extends CommonSymbol { for (List serialElementList : elements){ List serialArgumentList = new ArrayList<>(serialElementList.size()); for (ArchSimpleExpressionSymbol element : serialElementList){ - ArgumentSymbol argument = new Builder().parameter(getParameter()).value(element).build(); + ArchSimpleExpressionSymbol value = element; + if (getName().equals(PredefinedVariables.FOR_NAME) || getName().equals(PredefinedVariables.CARDINALITY_NAME)){ + value = ArchSimpleExpressionSymbol.of(1); + } + + ArgumentSymbol argument = new Builder().parameter(getParameter()).value(value).build(); serialArgumentList.add(argument); } arguments.add(serialArgumentList); diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/CNNArchSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/CNNArchSymbolTableCreator.java index 7202a7b..9e856b1 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/CNNArchSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/CNNArchSymbolTableCreator.java @@ -172,36 +172,38 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy //todo } + @Override + public void visit(ASTShape ast) { + ShapeSymbol sym = new ShapeSymbol(); + addToScopeAndLinkWithNode(sym, ast); + } + @Override public void endVisit(ASTShape node) { - ShapeSymbol sym; + ShapeSymbol sym = (ShapeSymbol) node.getSymbol().get(); if (node.getDimensions().size() == 1){ - sym = new ShapeSymbol.Builder() - .channels((DimensionSymbol) node.getDimensions().get(0).getSymbol().get()) - .build(); + sym.setChannels((ArchSimpleExpressionSymbol) node.getDimensions().get(0).getSymbol().get()); } else if (node.getDimensions().size() == 3){ - sym = new ShapeSymbol.Builder() - .height((DimensionSymbol) node.getDimensions().get(ShapeSymbol.HEIGHT_INDEX - 1).getSymbol().get()) - .width((DimensionSymbol) node.getDimensions().get(ShapeSymbol.WIDTH_INDEX - 1).getSymbol().get()) - .channels((DimensionSymbol) node.getDimensions().get(ShapeSymbol.CHANNEL_INDEX - 1).getSymbol().get()) - .build(); + sym.setHeight((ArchSimpleExpressionSymbol) node.getDimensions().get(ShapeSymbol.HEIGHT_INDEX - 1).getSymbol().get()); + sym.setWidth((ArchSimpleExpressionSymbol) node.getDimensions().get(ShapeSymbol.WIDTH_INDEX - 1).getSymbol().get()); + sym.setChannels((ArchSimpleExpressionSymbol) node.getDimensions().get(ShapeSymbol.CHANNEL_INDEX - 1).getSymbol().get()); } else { //todo - throw new IllegalStateException(); + throw new IllegalStateException("todo: incorrect shape"); } addToScopeAndLinkWithNode(sym, node); } @Override public void endVisit(ASTDimension node) { - DimensionSymbol sym; + ArchSimpleExpressionSymbol sym; if (node.getIntLiteral().isPresent()){ - sym = DimensionSymbol.of(node.getIntLiteral().get().getNumber().get().getDividend().intValue()); + sym = ArchSimpleExpressionSymbol.of(node.getIntLiteral().get().getNumber().get().getDividend().intValue()); } else { - sym = DimensionSymbol.of((VariableSymbol) node.getIOVariable().get().getSymbol().get()); + sym = ArchSimpleExpressionSymbol.of((VariableSymbol) node.getIOVariable().get().getSymbol().get()); } addToScopeAndLinkWithNode(sym, node); } @@ -212,6 +214,7 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy .name(node.getName()) .type(VariableType.IOVariable) .build(); + addToScope(ArchSimpleExpressionSymbol.of(ioVariable)); addToScopeAndLinkWithNode(ioVariable, node); } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/CompositeLayerSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/CompositeLayerSymbol.java index 8c0061e..df8cf74 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/CompositeLayerSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/CompositeLayerSymbol.java @@ -20,7 +20,6 @@ */ package de.monticore.lang.monticar.cnnarch._symboltable; -import de.monticore.symboltable.MutableScope; import de.monticore.symboltable.Symbol; import java.util.*; @@ -83,15 +82,24 @@ public class CompositeLayerSymbol extends LayerSymbol { } @Override - public Set resolve() { - checkIfResolvable(); - if (isResolvable()){ - List resolvedLayers = new ArrayList<>(); - for (LayerSymbol layer : getLayers()){ - layer.resolve(); + public void reset() { + for (LayerSymbol layer : getLayers()){ + layer.reset(); + } + setUnresolvableVariables(null); + } + + @Override + public Set resolve() { + if (!isResolved()) { + if (isResolvable()) { + List resolvedLayers = new ArrayList<>(); + for (LayerSymbol layer : getLayers()) { + layer.resolve(); + } } } - return getUnresolvableNames(); + return getUnresolvableVariables(); } @Override @@ -113,13 +121,11 @@ public class CompositeLayerSymbol extends LayerSymbol { } @Override - protected Set computeUnresolvableNames() { - Set unresolvableSet = new HashSet<>(); + protected void computeUnresolvableVariables(Set unresolvableVariables, Set allVariables) { for (LayerSymbol layer : getLayers()){ - layer.checkIfResolvable(); - unresolvableSet.addAll(layer.getUnresolvableNames()); + layer.checkIfResolvable(allVariables); + unresolvableVariables.addAll(layer.getUnresolvableVariables()); } - return unresolvableSet; } @Override @@ -139,6 +145,9 @@ public class CompositeLayerSymbol extends LayerSymbol { return outputShapes; } else { + for (LayerSymbol layer : getLayers()){ + layer.getOutputShapes(); + } return getLayers().get(getLayers().size() - 1).getOutputShapes(); } } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/DimensionKind.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/DimensionKind.java deleted file mode 100644 index 4126dfb..0000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/DimensionKind.java +++ /dev/null @@ -1,39 +0,0 @@ -/** - * - * ****************************************************************************** - * MontiCAR Modeling Family, www.se-rwth.de - * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, - * All rights reserved. - * - * This project is free software; you can redistribute it and/or - * modify it under the terms of the GNU Lesser General Public - * License as published by the Free Software Foundation; either - * version 3.0 of the License, or (at your option) any later version. - * This library is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * Lesser General Public License for more details. - * - * You should have received a copy of the GNU Lesser General Public - * License along with this project. If not, see . - * ******************************************************************************* - */ -package de.monticore.lang.monticar.cnnarch._symboltable; - -import de.monticore.symboltable.SymbolKind; - -public class DimensionKind implements SymbolKind { - - private static final String NAME = "de.monticore.lang.monticar.cnnarch._symboltable.DimensionKind"; - - @Override - public String getName() { - return NAME; - } - - @Override - public boolean isKindOf(SymbolKind kind) { - return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind); - } - -} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/DimensionSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/DimensionSymbol.java deleted file mode 100644 index d2d778d..0000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/DimensionSymbol.java +++ /dev/null @@ -1,71 +0,0 @@ -/** - * - * ****************************************************************************** - * MontiCAR Modeling Family, www.se-rwth.de - * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, - * All rights reserved. - * - * This project is free software; you can redistribute it and/or - * modify it under the terms of the GNU Lesser General Public - * License as published by the Free Software Foundation; either - * version 3.0 of the License, or (at your option) any later version. - * This library is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * Lesser General Public License for more details. - * - * You should have received a copy of the GNU Lesser General Public - * License along with this project. If not, see . - * ******************************************************************************* - */ -package de.monticore.lang.monticar.cnnarch._symboltable; - -import de.monticore.symboltable.CommonSymbol; - -import java.util.Optional; - -public class DimensionSymbol extends CommonSymbol { - - public static final DimensionKind KIND = new DimensionKind(); - - private ArchSimpleExpressionSymbol valueExpression; - private VariableSymbol ioVariable; - - protected DimensionSymbol() { - super("", KIND); - } - - public ArchSimpleExpressionSymbol getValueExpression() { - return valueExpression; - } - - public void setValueExpression(ArchSimpleExpressionSymbol valueExpression) { - this.valueExpression = valueExpression; - } - - public Optional getIoVariable() { - return Optional.ofNullable(ioVariable); - } - - public void setIoVariable(VariableSymbol ioVariable) { - this.ioVariable = ioVariable; - } - - public Optional getValue(){ - Optional optObj = getValueExpression().getValue(); - return optObj.map(o -> (Integer) o); - } - - public static DimensionSymbol of(int value){ - DimensionSymbol sym = new DimensionSymbol(); - sym.setValueExpression(ArchSimpleExpressionSymbol.of(value)); - return sym; - } - - public static DimensionSymbol of(VariableSymbol variable){ - DimensionSymbol sym = new DimensionSymbol(); - sym.setValueExpression(ArchSimpleExpressionSymbol.of(variable)); - sym.setIoVariable(variable); - return sym; - } -} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/IODeclarationSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/IODeclarationSymbol.java index 91e7672..a8aef9e 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/IODeclarationSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/IODeclarationSymbol.java @@ -33,7 +33,7 @@ public class IODeclarationSymbol extends CommonSymbol { private ASTElementType type; private ShapeSymbol shape; private boolean input; //true->input, false->output - private int arrayLength = 0; + private int arrayLength = 1; protected IODeclarationSymbol(String name) { diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/IOLayerSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/IOLayerSymbol.java index 5bda953..8516cf8 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/IOLayerSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/IOLayerSymbol.java @@ -21,7 +21,6 @@ package de.monticore.lang.monticar.cnnarch._symboltable; import de.monticore.lang.monticar.cnnarch.ErrorMessages; -import de.monticore.symboltable.MutableScope; import de.monticore.symboltable.Symbol; import de.se_rwth.commons.logging.Log; @@ -66,13 +65,32 @@ public class IOLayerSymbol extends LayerSymbol { } @Override - public Set resolve() { - checkIfResolvable(); - if (isResolvable()){ - resolveExpressions(); - getDefinition().getShape().resolve(); + public boolean isInput(){ + return getDefinition().isInput(); + } + + @Override + public boolean isOutput(){ + return getDefinition().isOutput(); + } + + @Override + public void reset() { + setUnresolvableVariables(null); + if (getArrayAccess().isPresent()){ + getArrayAccess().get().reset(); + } + } + + @Override + public Set resolve() { + if (!isResolved()) { + if (isResolvable()) { + resolveExpressions(); + getDefinition().getShape().resolve(); + } } - return getUnresolvableNames(); + return getUnresolvableVariables(); } @Override @@ -84,20 +102,20 @@ public class IOLayerSymbol extends LayerSymbol { } } //todo getShape().isResolved - if (!getDefinition().getShape().computeUnresolvableNames().isEmpty()){ + if (!getDefinition().getShape().isResolved()){ isResolved = false; } return isResolved; } @Override - protected Set computeUnresolvableNames() { - HashSet unresolvableNames = new HashSet<>(); + protected void computeUnresolvableVariables(Set unresolvableVariables, Set allVariables) { if (getArrayAccess().isPresent()){ - unresolvableNames.addAll(getArrayAccess().get().computeUnresolvableNames()); + getArrayAccess().get().checkIfResolvable(allVariables); + unresolvableVariables.addAll(getArrayAccess().get().getUnresolvableVariables()); } - unresolvableNames.addAll(getDefinition().getShape().computeUnresolvableNames()); - return unresolvableNames; + getDefinition().getShape().checkIfResolvable(allVariables); + unresolvableVariables.addAll(getDefinition().getShape().getUnresolvableVariables()); } @Override @@ -145,7 +163,7 @@ public class IOLayerSymbol extends LayerSymbol { arrayAccessCopy = getArrayAccess().get().copy(); } IOLayerSymbol copy = new Builder() - .name(getName()) + .definition(getDefinition()) .arrayAccess(arrayAccessCopy) .build(); return copy; @@ -160,7 +178,7 @@ public class IOLayerSymbol extends LayerSymbol { public static class Builder{ private ArchSimpleExpressionSymbol arrayAccess = null; - private String name; + private IODeclarationSymbol definition; public Builder arrayAccess(ArchSimpleExpressionSymbol arrayAccess){ this.arrayAccess = arrayAccess; @@ -172,16 +190,16 @@ public class IOLayerSymbol extends LayerSymbol { return this; } - public Builder name(String name){ - this.name = name; + public Builder definition(IODeclarationSymbol definition){ + this.definition = definition; return this; } public IOLayerSymbol build(){ - if (name == null || name.equals("")){ - throw new IllegalStateException("Missing or empty name for IOLayerSymbol"); + if (definition == null){ + throw new IllegalStateException("Missing or definition for IOLayerSymbol"); } - IOLayerSymbol sym = new IOLayerSymbol(name); + IOLayerSymbol sym = new IOLayerSymbol(definition.getName()); sym.setArrayAccess(arrayAccess); return sym; } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/LayerSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/LayerSymbol.java index a011f51..f6df94a 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/LayerSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/LayerSymbol.java @@ -21,8 +21,7 @@ package de.monticore.lang.monticar.cnnarch._symboltable; import de.monticore.symboltable.CommonScopeSpanningSymbol; -import de.monticore.symboltable.MutableScope; -import de.monticore.symboltable.Scope; +import de.se_rwth.commons.Joiners; import java.util.HashSet; import java.util.List; @@ -35,7 +34,7 @@ public abstract class LayerSymbol extends CommonScopeSpanningSymbol { private LayerSymbol inputLayer; private List outputShapes = null; - private Set unresolvableNames = null; + private Set unresolvableVariables = null; protected LayerSymbol(String name) { super(name, KIND); @@ -86,37 +85,43 @@ public abstract class LayerSymbol extends CommonScopeSpanningSymbol { return false; } - public Set getUnresolvableNames() { - if (unresolvableNames == null){ + public Set getUnresolvableVariables() { + if (unresolvableVariables == null){ checkIfResolvable(); } - return unresolvableNames; + return unresolvableVariables; } - protected void setUnresolvableNames(Set unresolvableNames) { - this.unresolvableNames = unresolvableNames; + protected void setUnresolvableVariables(Set unresolvableVariables) { + this.unresolvableVariables = unresolvableVariables; } public boolean isResolvable(){ - return getUnresolvableNames().isEmpty(); + return getUnresolvableVariables().isEmpty(); } public void checkIfResolvable(){ - setUnresolvableNames(computeUnresolvableNames()); + checkIfResolvable(new HashSet<>()); + } + + protected void checkIfResolvable(Set occurringVariables){ + Set unresolvableVariables = new HashSet<>(); + computeUnresolvableVariables(unresolvableVariables, occurringVariables); + setUnresolvableVariables(unresolvableVariables); } public void resolveOrError(){ - Set names = resolve(); + Set names = resolve(); if (!isResolved()){ - throw new IllegalStateException("The following names could not be resolved: " + getUnresolvableNames()); + throw new IllegalStateException("The following names could not be resolved: " + Joiners.COMMA.join(getUnresolvableVariables())); } } - abstract public Set resolve(); + abstract public Set resolve(); abstract protected List computeOutputShapes(); - abstract protected Set computeUnresolvableNames(); + abstract protected void computeUnresolvableVariables(Set unresolvableVariables, Set allVariables); abstract public boolean isResolved(); @@ -140,12 +145,12 @@ public abstract class LayerSymbol extends CommonScopeSpanningSymbol { } } - /*abstract protected void putInScope(MutableScope scope);*/ - //deepCopy for LayerSymbols, ArgumentSymbol and ArchExpressionSymbols but does not copy math expressions or scope and ast information. abstract public LayerSymbol copy(); abstract protected void putInScope(LayerScope scope); abstract protected void resolveExpressions(); + + abstract public void reset(); } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/MethodDeclarationSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/MethodDeclarationSymbol.java index ef7595a..cc0af09 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/MethodDeclarationSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/MethodDeclarationSymbol.java @@ -25,8 +25,6 @@ package de.monticore.lang.monticar.cnnarch._symboltable; import de.monticore.lang.monticar.cnnarch.PredefinedVariables; import de.monticore.symboltable.CommonScopeSpanningSymbol; -import de.monticore.symboltable.MutableScope; -import de.monticore.symboltable.Symbol; import java.util.*; import java.util.function.BiFunction; @@ -69,6 +67,11 @@ public class MethodDeclarationSymbol extends CommonScopeSpanningSymbol { this.parameters.add(forParam); forParam.putInScope(getSpannedScope()); } + if (!getParameter(PredefinedVariables.CARDINALITY_NAME).isPresent()){ + VariableSymbol forParam = PredefinedVariables.createCardinalityParameter(); + this.parameters.add(forParam); + forParam.putInScope(getSpannedScope()); + } } public CompositeLayerSymbol getBody() { @@ -103,30 +106,41 @@ public class MethodDeclarationSymbol extends CommonScopeSpanningSymbol { public LayerSymbol call(MethodLayerSymbol layer) { + checkForSequence(layer.getArguments()); + + if (isPredefined()){ + return layer; + } + else { + set(layer.getArguments()); + getBody().resolveOrError(); + CompositeLayerSymbol copy = getBody().copy(); + reset(); + return copy; + } + } + + private void reset(){ + for (VariableSymbol param : getParameters()){ + param.reset(); + } + getBody().reset(); + } + + private void set(List arguments){ + for (ArgumentSymbol arg : arguments){ + arg.set(); + } + } + + private void checkForSequence(List arguments){ boolean valid = true; - for (ArgumentSymbol arg : layer.getArguments()){ + for (ArgumentSymbol arg : arguments){ if (arg.getRhs() instanceof ArchAbstractSequenceExpression){ valid = false; } } - - if (valid){ - if (isPredefined()){ - return layer; - } - else { - for (VariableSymbol param : getParameters()){ - param.reset(); - } - for (ArgumentSymbol arg : layer.getArguments()){ - arg.set(); - } - - getBody().resolveOrError(); - return getBody().copy(); - } - } - else { + if (!valid){ throw new IllegalArgumentException("Arguments with sequence expressions have to be resolved first before calling the method."); } } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/MethodLayerSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/MethodLayerSymbol.java index 06cf5ff..56f4a6b 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/MethodLayerSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/MethodLayerSymbol.java @@ -24,10 +24,13 @@ package de.monticore.lang.monticar.cnnarch._symboltable; import de.monticore.lang.monticar.cnnarch.ErrorMessages; import de.monticore.lang.monticar.cnnarch.PredefinedVariables; import de.monticore.symboltable.Symbol; +import de.se_rwth.commons.Joiners; import de.se_rwth.commons.logging.Log; import java.util.*; import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; public class MethodLayerSymbol extends LayerSymbol { @@ -60,9 +63,9 @@ public class MethodLayerSymbol extends LayerSymbol { } private void setMethod(MethodDeclarationSymbol method) { - if (method.isPredefined()){ + /*if (method.isPredefined()){ setResolvedThis(this); - } + }*/ this.method = method; } @@ -75,11 +78,13 @@ public class MethodLayerSymbol extends LayerSymbol { } public ArchExpressionSymbol getIfExpression(){ - return getMethod().getParameter(PredefinedVariables.IF_NAME).get().getExpression(); - } - - public ArchExpressionSymbol getForExpression(){ - return getMethod().getParameter(PredefinedVariables.FOR_NAME).get().getExpression(); + Optional argument = getArgument(PredefinedVariables.IF_NAME); + if (argument.isPresent()){ + return argument.get().getRhs(); + } + else { + return ArchSimpleExpressionSymbol.of(true); + } } public Optional getResolvedThis() { @@ -89,21 +94,47 @@ public class MethodLayerSymbol extends LayerSymbol { protected void setResolvedThis(LayerSymbol resolvedThis) { if (resolvedThis != null && resolvedThis != this){ resolvedThis.putInScope(getSpannedScope()); + if (getInputLayer().isPresent()){ + resolvedThis.setInputLayer(getInputLayer().get()); + } } this.resolvedThis = resolvedThis; } + @Override + public void setInputLayer(LayerSymbol inputLayer) { + super.setInputLayer(inputLayer); + if (getResolvedThis().isPresent()){ + getResolvedThis().get().setInputLayer(inputLayer); + } + } + @Override protected void putInScope(LayerScope scope){ Collection symbolsInScope = scope.getLocalSymbols().get(getName()); if (symbolsInScope == null || !symbolsInScope.contains(this)){ scope.add(this); + /*if (getResolvedThis().isPresent()){ + getResolvedThis().get().putInScope(getSpannedScope()); + }*/ for (ArgumentSymbol argument : getArguments()){ argument.putInScope(getSpannedScope()); } } } + @Override + public void reset() { + /*if (getResolvedThis().isPresent() && getResolvedThis().get() != this && getResolvedThis().get().getMaxSerialLength().get() != 0){ + getSpannedScope().remove(getResolvedThis().get()); + }*/ + setResolvedThis(null); + setUnresolvableVariables(null); + for (ArgumentSymbol arg : getArguments()){ + arg.getRhs().reset(); + } + } + @Override public boolean isMethod(){ return true; @@ -114,55 +145,34 @@ public class MethodLayerSymbol extends LayerSymbol { } @Override - public Set resolve() { + public Set resolve() { //todo checkForRecursion() - checkIfResolvable(); - if (isResolvable()){ - resolveExpressions(); - int parallelLength = getParallelLength().get(); - int maxSerialLength = getMaxSerialLength().get(); - - if (!isActive() || maxSerialLength == 0){ - //set resolvedThis to empty composite. This practically removes this method call. - setResolvedThis(new CompositeLayerSymbol.Builder().build()); - } - else if (parallelLength == 1 && maxSerialLength == 1){ - //resolve the method call - LayerSymbol resolvedMethod = getMethod().call(this); - setResolvedThis(resolvedMethod); - } - else { - //split the method if it contains an argument sequence - LayerSymbol splitComposite = resolveSequences(parallelLength, getSerialLengths().get()); - setResolvedThis(splitComposite); - splitComposite.resolveOrError(); + if (!isResolved()) { + if (isResolvable()) { + resolveExpressions(); + int parallelLength = getParallelLength().get(); + int maxSerialLength = getMaxSerialLength().get(); + + if (!isActive() || maxSerialLength == 0) { + //set resolvedThis to empty composite. This practically removes this method call. + setResolvedThis(new CompositeLayerSymbol.Builder().build()); + } + else if (parallelLength == 1 && maxSerialLength == 1) { + //resolve the method call + LayerSymbol resolvedMethod = getMethod().call(this); + setResolvedThis(resolvedMethod); + } + else { + //split the method if it contains an argument sequence + LayerSymbol splitComposite = resolveSequences(parallelLength, getSerialLengths().get()); + setResolvedThis(splitComposite); + splitComposite.resolveOrError(); + } } } - return getUnresolvableNames(); + return getUnresolvableVariables(); } - /*protected void computeResolvedThis(){ - int parallelLength = getParallelLength().get(); - int maxSerialLength = getMaxSerialLength().get(); - - if (!isActive() || maxSerialLength == 0){ - //set resolvedThis to empty composite. This practically removes this method call. - setResolvedThis(new CompositeLayerSymbol.Builder().build()); - } - else if (parallelLength == 1 && maxSerialLength == 1){ - //resolve the method call - LayerSymbol resolvedMethod = getMethod().call(this); - setResolvedThis(resolvedMethod); - resolvedMethod.computeResolvedThis(); - } - else { - //split the method if it contains an argument sequence - LayerSymbol splitComposite = resolveSequences(parallelLength, getSerialLengths().get()); - setResolvedThis(splitComposite); - splitComposite.resolveOrError(); - } - }*/ - private boolean isActive(){ if (getIfExpression().isSimpleValue() && !getIfExpression().getBooleanValue().get()){ return false; @@ -225,12 +235,11 @@ public class MethodLayerSymbol extends LayerSymbol { } @Override - protected Set computeUnresolvableNames() { - Set unresolvableNames = new HashSet<>(); + protected void computeUnresolvableVariables(Set unresolvableVariables, Set allVariables) { for (ArgumentSymbol argument : getArguments()){ - unresolvableNames.addAll(argument.getRhs().computeUnresolvableNames()); + argument.getRhs().checkIfResolvable(allVariables); + unresolvableVariables.addAll(argument.getRhs().getUnresolvableVariables()); } - return unresolvableNames; } @Override @@ -240,13 +249,13 @@ public class MethodLayerSymbol extends LayerSymbol { return shapeFunction.apply(getInputLayer().get().getOutputShapes(), this); } else { - Set unresolvableNames = resolve(); - if (unresolvableNames.isEmpty()){ + Set unresolvableVariables = resolve(); + if (unresolvableVariables.isEmpty()){ return getResolvedThis().get().computeOutputShapes(); } else { throw new IllegalStateException("The output shape can only be computed if this and all previous layer are resolvable. " + - "The following names cannot be resolved: " + String.join(", ", unresolvableNames)); + "The following names cannot be resolved: " + Joiners.COMMA.join(unresolvableVariables)); } } @@ -262,32 +271,29 @@ public class MethodLayerSymbol extends LayerSymbol { } public Optional getIntValue(String argumentName){ - Optional arg = getArgument(argumentName); - if (arg.isPresent()){ - Optional val = arg.get().getValue(); - if (val.isPresent() && val.get() instanceof Integer){ - return val.map(o -> (Integer) o); - } - } - return Optional.empty(); + return getTValue(argumentName, ArchExpressionSymbol::getIntValue); } public Optional> getIntTupleValue(String argumentName){ + return getTValue(argumentName, ArchExpressionSymbol::getIntTupleValues); + } + + public Optional getBooleanValue(String argumentName){ + return getTValue(argumentName, ArchExpressionSymbol::getBooleanValue); + } + + public Optional getValue(String argumentName){ + return getTValue(argumentName, ArchExpressionSymbol::getValue); + } + + private Optional getTValue(String argumentName, Function> getMethod){ Optional arg = getArgument(argumentName); + Optional param = getMethod().getParameter(argumentName); if (arg.isPresent()){ - Optional val = arg.get().getValue(); - if (val.isPresent() && val.get() instanceof List){ - List list = new ArrayList<>(); - for (Object obj : (List) val.get()){ - if (obj instanceof Integer){ - list.add((Integer) obj); - } - else{ - return Optional.empty(); - } - } - return Optional.of(list); - } + return getMethod.apply(arg.get().getRhs()); + } + else if (param.isPresent() && param.get().getDefaultExpression().isPresent()){ + return getMethod.apply(param.get().getDefaultExpression().get()); } return Optional.empty(); } @@ -335,6 +341,9 @@ public class MethodLayerSymbol extends LayerSymbol { return Optional.empty(); } } + if (getArguments().isEmpty()){ + max = 1; + } return Optional.of(max); } @@ -394,6 +403,9 @@ public class MethodLayerSymbol extends LayerSymbol { return Optional.empty(); } } + if (getArguments().isEmpty()){ + argumentLengths.add(Collections.singletonList(1)); + } return Optional.of(argumentLengths); } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ShapeSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ShapeSymbol.java index bdbd6cb..af043cd 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ShapeSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/ShapeSymbol.java @@ -20,7 +20,10 @@ */ package de.monticore.lang.monticar.cnnarch._symboltable; +import de.monticore.lang.math.math._symboltable.expression.MathNameExpressionSymbol; import de.monticore.symboltable.CommonSymbol; +import de.monticore.symboltable.MutableScope; +import de.monticore.symboltable.Symbol; import java.util.*; @@ -29,92 +32,108 @@ public class ShapeSymbol extends CommonSymbol { public static final ShapeKind KIND = new ShapeKind(); + public static final int BATCH_SIZE_INDEX = 0; public static final int HEIGHT_INDEX = 1; public static final int WIDTH_INDEX = 2; public static final int CHANNEL_INDEX = 3; - private List dimensions = Arrays.asList(DimensionSymbol.of(-1),DimensionSymbol.of(-1), DimensionSymbol.of(-1), DimensionSymbol.of(-1)); + private List dimensions = + Arrays.asList(ArchSimpleExpressionSymbol.of(-1), + ArchSimpleExpressionSymbol.of(-1), + ArchSimpleExpressionSymbol.of(-1), + ArchSimpleExpressionSymbol.of(-1)); public ShapeSymbol() { super("", KIND); } - public DimensionSymbol getHeightSymbol() { + public ArchSimpleExpressionSymbol getBatchSizeSymbol() { + return dimensions.get(BATCH_SIZE_INDEX); + } + + public void setBatchSize(int batchSize) { + dimensions.get(BATCH_SIZE_INDEX).reset(); + dimensions.get(BATCH_SIZE_INDEX).setValue(batchSize); + dimensions.get(BATCH_SIZE_INDEX).setMathExpression(null); + } + + public void setBatchSize(ArchSimpleExpressionSymbol batchSize) { + getDimensionSymbols().set(BATCH_SIZE_INDEX, batchSize); + } + + public ArchSimpleExpressionSymbol getHeightSymbol() { return dimensions.get(HEIGHT_INDEX); } - public void setHeightSymbol(DimensionSymbol heightSymbol) { - dimensions.set(HEIGHT_INDEX, heightSymbol); + public void setHeight(int height) { + dimensions.get(HEIGHT_INDEX).reset(); + dimensions.get(HEIGHT_INDEX).setValue(height); + dimensions.get(HEIGHT_INDEX).setMathExpression(null); + } + + public void setHeight(ArchSimpleExpressionSymbol height) { + getDimensionSymbols().set(HEIGHT_INDEX, height); } - public DimensionSymbol getWidthSymbol() { + public ArchSimpleExpressionSymbol getWidthSymbol() { return dimensions.get(WIDTH_INDEX); } - public void setWidthSymbol(DimensionSymbol widthSymbol) { - dimensions.set(WIDTH_INDEX, widthSymbol); + public void setWidth(int width) { + dimensions.get(WIDTH_INDEX).reset(); + dimensions.get(WIDTH_INDEX).setValue(width); + dimensions.get(WIDTH_INDEX).setMathExpression(null); + } + + public void setWidth(ArchSimpleExpressionSymbol width) { + getDimensionSymbols().set(WIDTH_INDEX, width); } - public DimensionSymbol getChannelsSymbol() { + public ArchSimpleExpressionSymbol getChannelsSymbol() { return dimensions.get(CHANNEL_INDEX); } - public void setChannelsSymbol(DimensionSymbol channelsSymbol) { - dimensions.set(CHANNEL_INDEX, channelsSymbol); + public void setChannels(int channels) { + dimensions.get(CHANNEL_INDEX).reset(); + dimensions.get(CHANNEL_INDEX).setValue(channels); + dimensions.get(CHANNEL_INDEX).setMathExpression(null); + } + + public void setChannels(ArchSimpleExpressionSymbol channels) { + getDimensionSymbols().set(CHANNEL_INDEX, channels); } public Optional getWidth(){ - return getWidthSymbol().getValue(); + return getWidthSymbol().getIntValue(); } public Optional getHeight(){ - return getHeightSymbol().getValue(); + return getHeightSymbol().getIntValue(); } public Optional getChannels(){ - return getChannelsSymbol().getValue(); + return getChannelsSymbol().getIntValue(); } - public List getDimensionSymbols() { + public List getDimensionSymbols() { return dimensions; } - - public List getIOVariables(){ - List vars = new ArrayList<>(4); - for (DimensionSymbol dim : getDimensionSymbols()){ - if (dim.getIoVariable().isPresent()){ - vars.add(dim.getIoVariable().get()); - } - } - return vars; - } - - public Set computeUnresolvableNames(){ - Set unresolvableNames = new HashSet<>(); - for (VariableSymbol variable : getIOVariables()){ - if (!variable.hasValue()){ - unresolvableNames.add(variable.getName()); - } - } - return unresolvableNames; - } - - public Set resolve() { + public Set resolve() { if (!isResolved()){ if (isResolvable()){ - for (DimensionSymbol dimension : getDimensionSymbols()){ - dimension.getValueExpression().resolveOrError(); + for (ArchSimpleExpressionSymbol dimension : getDimensionSymbols()){ + dimension.resolveOrError(); } } } - return getUnresolvableNames(); + return getUnresolvableVariables(); } public boolean isResolvable(){ boolean isResolvable = true; - for (DimensionSymbol dimension : getDimensionSymbols()){ - if (!dimension.getValueExpression().isResolvable()){ + for (ArchSimpleExpressionSymbol dimension : getDimensionSymbols()){ + if (!dimension.isResolvable()){ isResolvable = false; } } @@ -123,62 +142,61 @@ public class ShapeSymbol extends CommonSymbol { public boolean isResolved(){ boolean isResolved = true; - for (DimensionSymbol dimension : getDimensionSymbols()){ - if (!dimension.getValueExpression().isResolved()){ + for (ArchSimpleExpressionSymbol dimension : getDimensionSymbols()){ + if (!dimension.isResolved()){ isResolved = false; } } return isResolved; } - public Set getUnresolvableNames(){ - Set unresolvableNames = new HashSet<>(); - for (DimensionSymbol dimension : getDimensionSymbols()){ - unresolvableNames.addAll(dimension.getValueExpression().getUnresolvableNames()); + public Set getUnresolvableVariables(){ + Set unresolvableVariables = new HashSet<>(); + for (ArchSimpleExpressionSymbol dimension : getDimensionSymbols()){ + unresolvableVariables.addAll(dimension.getUnresolvableVariables()); } - return unresolvableNames; + return unresolvableVariables; } - public static class Builder{ - private DimensionSymbol height = DimensionSymbol.of(1); - private DimensionSymbol width = DimensionSymbol.of(1); - private DimensionSymbol channels = DimensionSymbol.of(1); + public void checkIfResolvable(Set seenVariables) { + for (ArchSimpleExpressionSymbol dimension : getDimensionSymbols()){ + dimension.checkIfResolvable(seenVariables); + } + } - public Builder height(int height){ - this.height = DimensionSymbol.of(height); - return this; + @Override + public void setEnclosingScope(MutableScope scope) { + super.setEnclosingScope(scope); + for (ArchSimpleExpressionSymbol dimension : getDimensionSymbols()){ + dimension.putInScope(scope); } + } - public Builder height(DimensionSymbol height){ + public static class Builder{ + private int height; + private int width; + private int channels; + + public Builder height(int height){ this.height = height; return this; } public Builder width(int width){ - this.width = DimensionSymbol.of(width); - return this; - } - - public Builder width(DimensionSymbol width){ this.width = width; return this; } public Builder channels(int channels){ - this.channels = DimensionSymbol.of(channels); - return this; - } - - public Builder channels(DimensionSymbol channels){ this.channels = channels; return this; } public ShapeSymbol build(){ ShapeSymbol sym = new ShapeSymbol(); - sym.setHeightSymbol(height); - sym.setChannelsSymbol(channels); - sym.setWidthSymbol(width); + sym.setHeight(height); + sym.setChannels(channels); + sym.setWidth(width); return sym; } } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/VariableSymbol.java b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/VariableSymbol.java index ed9963f..8275f7f 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/VariableSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/_symboltable/VariableSymbol.java @@ -21,6 +21,7 @@ package de.monticore.lang.monticar.cnnarch._symboltable; import de.monticore.lang.monticar.cnnarch.Constraint; +import de.monticore.lang.monticar.cnnarch.PredefinedVariables; import de.monticore.symboltable.CommonSymbol; import de.monticore.symboltable.MutableScope; import de.monticore.symboltable.Symbol; @@ -90,8 +91,8 @@ public class VariableSymbol extends CommonSymbol { return getCurrentExpression().isPresent() || getDefaultExpression().isPresent(); } - public void setExpression(ArchSimpleExpressionSymbol value){ - currentExpression = value; + protected void setExpression(ArchSimpleExpressionSymbol expression){ + currentExpression = expression; } public ArchSimpleExpressionSymbol getExpression(){ diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java index e2167cc..653bcc4 100644 --- a/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnnarch/SymtabTest.java @@ -41,28 +41,55 @@ public class SymtabTest extends AbstractSymtabTest { @Test public void testAlexnet(){ + Scope symTab = createSymTab("src/test/resources/architectures"); + ArchitectureSymbol a = symTab.resolve( + "Alexnet", + ArchitectureSymbol.KIND).orElse(null); + assertNotNull(a); + a.resolve(); + } + + @Test + public void testThreeInput(){ + Scope symTab = createSymTab("src/test/resources/architectures"); + ArchitectureSymbol a = symTab.resolve( + "ThreeInputCNN_M14", + ArchitectureSymbol.KIND).orElse(null); + assertNotNull(a); + a.resolve(); + } + + @Test + public void testFixedAlexnet(){ Scope symTab = createSymTab("src/test/resources/valid_tests"); ArchitectureSymbol a = symTab.resolve( "Fixed_Alexnet", ArchitectureSymbol.KIND).orElse(null); assertNotNull(a); - a.getBody().resolve(); + a.resolve(); List asd = a.getBody().getOutputShapes(); - boolean f = true; } @Test - public void testRes(){ + public void testFixedResNeXt(){ Scope symTab = createSymTab("src/test/resources/valid_tests"); ArchitectureSymbol a = symTab.resolve( "Fixed_ResNeXt50", ArchitectureSymbol.KIND).orElse(null); assertNotNull(a); - a.getBody().resolve(); + a.resolve(); List asd = a.getBody().getOutputShapes(); - boolean f = true; } - + @Test + public void testFixedThreeInput(){ + Scope symTab = createSymTab("src/test/resources/valid_tests"); + ArchitectureSymbol a = symTab.resolve( + "Fixed_ThreeInputCNN_M14", + ArchitectureSymbol.KIND).orElse(null); + assertNotNull(a); + a.resolve(); + List asd = a.getBody().getOutputShapes(); + } } diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/cocos/AllCoCoTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/cocos/AllCoCoTest.java index d0eed44..efe8678 100644 --- a/src/test/java/de/monticore/lang/monticar/cnnarch/cocos/AllCoCoTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnnarch/cocos/AllCoCoTest.java @@ -40,6 +40,10 @@ public class AllCoCoTest extends AbstractCoCoTest { checkValid("architectures", "Alexnet_alt"); checkValid("architectures", "ResNeXt50"); checkValid("architectures", "ThreeInputCNN_M14"); + + checkValid("valid_tests", "Fixed_Alexnet"); + checkValid("valid_tests", "Fixed_ResNeXt50"); + checkValid("valid_tests", "Fixed_ThreeInputCNN_M14"); /*checkValid("architectures", "Alexnet"); checkValid("architectures", "Resnet34"); checkValid("architectures", "ResNeXt50"); diff --git a/src/test/resources/architectures/Alexnet.cnna b/src/test/resources/architectures/Alexnet.cnna index b51089a..e8154b6 100644 --- a/src/test/resources/architectures/Alexnet.cnna +++ b/src/test/resources/architectures/Alexnet.cnna @@ -13,12 +13,12 @@ architecture Alexnet{ Dropout() } def split1(i, groups){ - Split(index=i, n=groups) -> + SplitData(index=i, n=groups) -> conv(filter=(5,5), channels=128) -> Lrn(nsize=5, alpha=0.0001, beta=0.75) } def split2(i, groups){ - Split(index=i, n=groups) -> + SplitData(index=i, n=groups) -> conv(filter=(3,3), channels=192, hasPool=false) -> conv(filter=(3,3), channels=128) } diff --git a/src/test/resources/architectures/Alexnet_alt.cnna b/src/test/resources/architectures/Alexnet_alt.cnna index 169e124..21f2732 100644 --- a/src/test/resources/architectures/Alexnet_alt.cnna +++ b/src/test/resources/architectures/Alexnet_alt.cnna @@ -18,11 +18,11 @@ architecture Alexnet_alt{ Lrn(nsize=5, alpha=0.0001, beta=0.75) -> ( - Split(index=0, n=2) -> + SplitData(index=0, n=2) -> conv(filter=(5,5), channels=128) -> Lrn(nsize=5, alpha=0.0001, beta=0.75) | - Split(index=1, n=2) -> + SplitData(index=1, n=2) -> conv(filter=(5,5), channels=128) -> Lrn(nsize=5, alpha=0.0001, beta=0.75) ) -> @@ -30,11 +30,11 @@ architecture Alexnet_alt{ conv(filter=(3,3), channels=384 ,hasPool=false) -> ( - Split(index=0, n=2) -> + SplitData(index=0, n=2) -> conv(filter=(3,3), channels=192, hasPool=false) -> conv(filter=(3,3), channels=128) | - Split(index=1, n=2) -> + SplitData(index=1, n=2) -> conv(filter=(3,3), channels=192, hasPool=false) -> conv(filter=(3,3), channels=128) ) -> diff --git a/src/test/resources/architectures/ResNeXt50.cnna b/src/test/resources/architectures/ResNeXt50.cnna index 1ca4435..bcdf872 100644 --- a/src/test/resources/architectures/ResNeXt50.cnna +++ b/src/test/resources/architectures/ResNeXt50.cnna @@ -21,7 +21,7 @@ architecture ResNeXt50{ resGroup(innerChannels=innerChannels, outChannels=outChannels, stride=stride, - _for=[1|:cardinality]) -> + _cardinality=cardinality) -> Add() | skip(outChannels=outChannels, stride=stride, _if=(stride!=1)) @@ -31,7 +31,7 @@ architecture ResNeXt50{ } def resStructure(innerChannels, outChannels, resLayers){ resLayer(innerChannels=innerChannels, outChannels=outChannels, stride=2) -> - resLayer(innerChannels=innerChannels, outChannels=outChannels, _for=[2->:resLayers]) + resLayer(innerChannels=innerChannels, outChannels=outChannels, _for=resLayers - 1) } image -> diff --git a/src/test/resources/architectures/ThreeInputCNN_M14.cnna b/src/test/resources/architectures/ThreeInputCNN_M14.cnna index fbddd8a..f926978 100644 --- a/src/test/resources/architectures/ThreeInputCNN_M14.cnna +++ b/src/test/resources/architectures/ThreeInputCNN_M14.cnna @@ -7,14 +7,14 @@ architecture ThreeInputCNN_M14{ Relu() } - def inputGroup(group){ - [group] -> - conv(filter=(3,3), channels=32, _for=[1->:3]) -> + def inputGroup(index){ + [index] -> + conv(filter=(3,3), channels=32, _for=3) -> MaxPooling(kernel=(2,2), stride=2) } image -> - inputGroup(group=[0|:2]) -> + inputGroup(index=0|..|2) -> Concatenate() -> conv(filter=(3,3), channels=64) -> MaxPooling(kernel=(2,2), stride=2) -> diff --git a/src/test/resources/valid_tests/Fixed_Alexnet.cnna b/src/test/resources/valid_tests/Fixed_Alexnet.cnna index 735f5c6..bb58c21 100644 --- a/src/test/resources/valid_tests/Fixed_Alexnet.cnna +++ b/src/test/resources/valid_tests/Fixed_Alexnet.cnna @@ -2,10 +2,10 @@ architecture Fixed_Alexnet{ def input Z(0:255)^{256,256,3} image def output Q(0:1)^{10} predictions - def conv(filter, channels, hasPool=true, convStride=1){ + def conv(filter, channels, hasPool=true, convStride=(1,1)){ Convolution(kernel=filter, channels=channels, stride=convStride) -> Relu() -> - MaxPooling(kernel=(3,3), stride=2, _if=hasPool) + MaxPooling(kernel=(3,3), stride=(2,2), _if=hasPool) } def fc(){ FullyConnected(units=4096) -> @@ -13,18 +13,18 @@ architecture Fixed_Alexnet{ Dropout() } def split1(i, groups){ - Split(index=i, n=groups) -> + SplitData(index=i, n=groups) -> conv(filter=(5,5), channels=128) -> Lrn(nsize=5, alpha=0.0001, beta=0.75) } def split2(i, groups){ - Split(index=i, n=groups) -> + SplitData(index=i, n=groups) -> conv(filter=(3,3), channels=192, hasPool=false) -> conv(filter=(3,3), channels=128) } image -> - conv(filter=(11,11), channels=96, convStride=4) -> + conv(filter=(11,11), channels=96, convStride=(4,4)) -> Lrn(nsize=5, alpha=0.0001, beta=0.75) -> split1(i=[0|1], groups=2) -> diff --git a/src/test/resources/valid_tests/Fixed_ResNeXt50.cnna b/src/test/resources/valid_tests/Fixed_ResNeXt50.cnna index 270cf1e..302385d 100644 --- a/src/test/resources/valid_tests/Fixed_ResNeXt50.cnna +++ b/src/test/resources/valid_tests/Fixed_ResNeXt50.cnna @@ -1,9 +1,9 @@ architecture Fixed_ResNeXt50{ - def input Z(0:255)^{256,256,3} image - def output Q(0:1)^{10} predictions + def input Z(0:255)^{224,224,3} image + def output Q(0:1)^{1000} predictions def conv(filter, channels, stride=1, act=true){ - Convolution(kernel=filter, channels=channels, stride=stride) -> + Convolution(kernel=filter, channels=channels, stride=(stride, stride)) -> BatchNorm() -> Relu(_if=act) } @@ -13,7 +13,7 @@ architecture Fixed_ResNeXt50{ conv(filter=(1,1), channels=outChannels, act=false) } def skip(outChannels, stride){ - Convolution(kernel=(1,1), channels=outChannels, stride=stride) -> + Convolution(kernel=(1,1), channels=outChannels, stride=(stride,stride)) -> BatchNorm() } def resLayer(innerChannels, outChannels, stride=1, cardinality=32){ @@ -21,7 +21,7 @@ architecture Fixed_ResNeXt50{ resGroup(innerChannels=innerChannels, outChannels=outChannels, stride=stride, - _for=[1|:cardinality]) -> + _cardinality=cardinality) -> Add() | skip(outChannels=outChannels, stride=stride, _if=(stride!=1)) @@ -29,17 +29,17 @@ architecture Fixed_ResNeXt50{ Add() -> Relu() } - def resStructure(innerChannels, outChannels, resLayers){ - resLayer(innerChannels=innerChannels, outChannels=outChannels, stride=2) -> - resLayer(innerChannels=innerChannels, outChannels=outChannels, _for=[2->:resLayers]) + def resStructure(innerChannels, outChannels, resLayers, firstLayerStride){ + resLayer(innerChannels=innerChannels, outChannels=outChannels, stride=firstLayerStride) -> + resLayer(innerChannels=innerChannels, outChannels=outChannels, _for= resLayers - 1) } image -> conv(filter=(7,7), channels=64, stride=2) -> - MaxPooling(kernel=(3,3), stride=2) -> - resStructure(innerChannels=[4->8->16->32], outChannels=[256->512->1024->2048], resLayers=[3->4->6->3]) -> + MaxPooling(kernel=(3,3), stride=(2,2)) -> + resStructure(innerChannels=[4->8->16->32], outChannels=[256->512->1024->2048], resLayers=[3->4->6->3], firstLayerStride=[1->2->2->2]) -> AveragePooling(global=true) -> - FullyConnected(units=10) -> + FullyConnected(units=1000) -> Softmax() -> predictions } \ No newline at end of file diff --git a/src/test/resources/valid_tests/Fixed_ThreeInputCNN_M14.cnna b/src/test/resources/valid_tests/Fixed_ThreeInputCNN_M14.cnna new file mode 100644 index 0000000..9865cbf --- /dev/null +++ b/src/test/resources/valid_tests/Fixed_ThreeInputCNN_M14.cnna @@ -0,0 +1,27 @@ +architecture Fixed_ThreeInputCNN_M14{ + def input Z(0:255)^{256,256,3} image[3] + def output Q(0:1)^{3} predictions + + def conv(filter, channels){ + Convolution(kernel=filter, channels=channels) -> + Relu() + } + + def inputGroup(index){ + [index] -> + conv(filter=(3,3), channels=32, _for=3) -> + MaxPooling(kernel=(2,2), stride=(2,2)) + } + + image -> + inputGroup(index=0|..|2) -> + Concatenate() -> + conv(filter=(3,3), channels=64) -> + MaxPooling(kernel=(2,2), stride=(2,2)) -> + + FullyConnected(units=32) -> + Relu() -> + FullyConnected(units=3) -> + Softmax() -> + predictions +} \ No newline at end of file -- GitLab