Completed resolve mechanism and output shape computation.

parent b2d40dcd
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-arch</artifactId> <artifactId>cnn-arch</artifactId>
<version>0.0.2-SNAPSHOT</version> <version>0.1.0-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
......
...@@ -57,9 +57,11 @@ grammar CNNArch extends de.monticore.lang.math.Math { ...@@ -57,9 +57,11 @@ grammar CNNArch extends de.monticore.lang.math.Math {
ArchSerialSequence = serialValues:(ArchSimpleExpression || "->")+; ArchSerialSequence = serialValues:(ArchSimpleExpression || "->")+;
ArchValueRange implements ArchValueSequence = "[" start:ArchSimpleExpression ArchValueRange implements ArchValueSequence = start:ArchSimpleExpression
(serial:"->" | parallel:"|") (serial:"->" | parallel:"|")
":" end:ArchSimpleExpression "]"; ".."
(serial2:"->" | parallel2:"|")
end:ArchSimpleExpression;
ArchSimpleExpression = (arithmeticExpression:MathArithmeticExpression ArchSimpleExpression = (arithmeticExpression:MathArithmeticExpression
......
...@@ -51,28 +51,60 @@ public enum Constraint { ...@@ -51,28 +51,60 @@ public enum Constraint {
INTEGER_TUPLE { INTEGER_TUPLE {
@Override @Override
public boolean check(ArchSimpleExpressionSymbol exp) { public boolean check(ArchSimpleExpressionSymbol exp) {
boolean res = false; return exp.isIntTuple().get();
if (exp.isTuple()){
//todo
}
return false;
} }
}, },
POSITIVE { POSITIVE {
@Override @Override
public boolean check(ArchSimpleExpressionSymbol exp) { public boolean check(ArchSimpleExpressionSymbol exp) {
if (exp.getDoubleValue().isPresent()){
return exp.getDoubleValue().get() > 0;
}
else if (exp.getDoubleTupleValues().isPresent()){
boolean isPositive = true;
for (double value : exp.getDoubleTupleValues().get()){
if (value <= 0){
isPositive = false;
}
}
return isPositive;
}
return false; return false;
} }
}, },
NON_NEGATIVE { NON_NEGATIVE {
@Override @Override
public boolean check(ArchSimpleExpressionSymbol exp) { public boolean check(ArchSimpleExpressionSymbol exp) {
if (exp.getDoubleValue().isPresent()){
return exp.getDoubleValue().get() >= 0;
}
else if (exp.getDoubleTupleValues().isPresent()){
boolean isPositive = true;
for (double value : exp.getDoubleTupleValues().get()){
if (value < 0){
isPositive = false;
}
}
return isPositive;
}
return false; return false;
} }
}, },
BETWEEN_ZERO_AND_ONE { BETWEEN_ZERO_AND_ONE {
@Override @Override
public boolean check(ArchSimpleExpressionSymbol exp) { public boolean check(ArchSimpleExpressionSymbol exp) {
if (exp.getDoubleValue().isPresent()){
return exp.getDoubleValue().get() >= 0 && exp.getDoubleValue().get() <= 1;
}
else if (exp.getDoubleTupleValues().isPresent()){
boolean isPositive = true;
for (double value : exp.getDoubleTupleValues().get()){
if (value < 0 || value > 1){
isPositive = false;
}
}
return isPositive;
}
return false; return false;
} }
}; };
......
...@@ -41,7 +41,7 @@ public class PredefinedMethods { ...@@ -41,7 +41,7 @@ public class PredefinedMethods {
public static final String AVG_POOLING_NAME = "AveragePooling"; public static final String AVG_POOLING_NAME = "AveragePooling";
public static final String LRN_NAME = "Lrn"; public static final String LRN_NAME = "Lrn";
public static final String BATCHNORM_NAME = "BatchNorm"; public static final String BATCHNORM_NAME = "BatchNorm";
public static final String SPLIT_NAME = "Split"; public static final String SPLIT_NAME = "SplitData";
public static final String GET_NAME = "Get"; public static final String GET_NAME = "Get";
public static final String ADD_NAME = "Add"; public static final String ADD_NAME = "Add";
public static final String CONCATENATE_NAME = "Concatenate"; public static final String CONCATENATE_NAME = "Concatenate";
...@@ -323,22 +323,36 @@ public class PredefinedMethods { ...@@ -323,22 +323,36 @@ public class PredefinedMethods {
} }
private static List<ShapeSymbol> strideShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method, int channels) { private static List<ShapeSymbol> strideShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method, int channels) {
int strideHeight = method.getIntTupleValue("stride").get().get(0); Optional<Boolean> optGlobal = method.getBooleanValue("global");
int strideWidth = method.getIntTupleValue("stride").get().get(1); if (optGlobal.isPresent() && optGlobal.get()){
int kernelHeight = method.getIntTupleValue("kernel").get().get(0); return Collections.singletonList(new ShapeSymbol.Builder()
int kernelWidth = method.getIntTupleValue("kernel").get().get(1); .height(1)
int inputHeight = inputShape.getHeight().get(); .width(1)
int inputWidth = inputShape.getWidth().get(); .channels(channels)
.build());
}
else{
int strideHeight = method.getIntTupleValue("stride").get().get(0);
int strideWidth = method.getIntTupleValue("stride").get().get(1);
int kernelHeight = method.getIntTupleValue("kernel").get().get(0);
int kernelWidth = method.getIntTupleValue("kernel").get().get(1);
int inputHeight = inputShape.getHeight().get();
int inputWidth = inputShape.getWidth().get();
//assume padding with border_mode='same' //assume padding with border_mode='same'
int outputWidth = 1 + ((inputWidth - kernelWidth + strideWidth - 1) / strideWidth); int outputWidth = inputWidth / strideWidth;
int outputHeight = 1 + ((inputHeight - kernelHeight + strideHeight - 1) / strideHeight); int outputHeight = inputHeight / strideHeight;
return Collections.singletonList(new ShapeSymbol.Builder() //border_mode=valid
.height(outputHeight) //int outputWidth = 1 + Math.max(0, ((inputWidth - kernelWidth + strideWidth - 1) / strideWidth));
.width(outputWidth) //int outputHeight = 1 + Math.max(0, ((inputHeight - kernelHeight + strideHeight - 1) / strideHeight));
.channels(channels)
.build()); return Collections.singletonList(new ShapeSymbol.Builder()
.height(outputHeight)
.width(outputWidth)
.channels(channels)
.build());
}
} }
private static List<ShapeSymbol> splitShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method) { private static List<ShapeSymbol> splitShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method) {
......
...@@ -27,6 +27,7 @@ public class PredefinedVariables { ...@@ -27,6 +27,7 @@ public class PredefinedVariables {
public static final String IF_NAME = "_if"; public static final String IF_NAME = "_if";
public static final String FOR_NAME = "_for"; public static final String FOR_NAME = "_for";
public static final String CARDINALITY_NAME = "_cardinality";
public static final String TRUE_NAME = "true"; public static final String TRUE_NAME = "true";
public static final String FALSE_NAME = "false"; public static final String FALSE_NAME = "false";
...@@ -45,6 +46,13 @@ public class PredefinedVariables { ...@@ -45,6 +46,13 @@ public class PredefinedVariables {
.build(); .build();
} }
public static VariableSymbol createCardinalityParameter(){
return new VariableSymbol.Builder()
.name(CARDINALITY_NAME)
.defaultValue(1)
.build();
}
//necessary because true is currently only a name in MontiMath and it needs to be evaluated at compile time for this language //necessary because true is currently only a name in MontiMath and it needs to be evaluated at compile time for this language
public static VariableSymbol createTrueConstant(){ public static VariableSymbol createTrueConstant(){
return new VariableSymbol.Builder() return new VariableSymbol.Builder()
......
...@@ -22,9 +22,7 @@ package de.monticore.lang.monticar.cnnarch._symboltable; ...@@ -22,9 +22,7 @@ package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.CommonSymbol; import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.MutableScope; import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.Scope;
import de.monticore.symboltable.Symbol; import de.monticore.symboltable.Symbol;
import de.se_rwth.commons.logging.Log;
import java.util.*; import java.util.*;
...@@ -32,7 +30,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { ...@@ -32,7 +30,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
public static final ArchExpressionKind KIND = new ArchExpressionKind(); public static final ArchExpressionKind KIND = new ArchExpressionKind();
private Set<String> unresolvableNames = null; private Set<VariableSymbol> unresolvableVariables = null;
public ArchExpressionSymbol() { public ArchExpressionSymbol() {
super("", KIND); super("", KIND);
...@@ -40,25 +38,28 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { ...@@ -40,25 +38,28 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
protected Boolean isResolvable(){ protected Boolean isResolvable(){
Set<String> set = getUnresolvableNames(); Set<VariableSymbol> set = getUnresolvableVariables();
return set != null && set.isEmpty(); return set != null && set.isEmpty();
} }
public Set<String> getUnresolvableNames() { public Set<VariableSymbol> getUnresolvableVariables() {
if (unresolvableNames == null){ if (unresolvableVariables == null){
checkIfResolvable(); checkIfResolvable(new HashSet<>());
} }
return unresolvableNames; return unresolvableVariables;
} }
protected void setUnresolvableNames(Set<String> unresolvableNames){ protected void setUnresolvableVariables(Set<VariableSymbol> unresolvableVariables){
this.unresolvableNames = unresolvableNames; this.unresolvableVariables = unresolvableVariables;
} }
public void checkIfResolvable(){ public void checkIfResolvable(Set<VariableSymbol> seenVariables){
setUnresolvableNames(computeUnresolvableNames()); Set<VariableSymbol> unresolvableVariables = new HashSet<>();
computeUnresolvableVariables(unresolvableVariables, seenVariables);
setUnresolvableVariables(unresolvableVariables);
} }
/** /**
* Checks whether the value is a boolean. If true getValue() will return a Boolean if present. * Checks whether the value is a boolean. If true getValue() will return a Boolean if present.
* *
...@@ -99,21 +100,21 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { ...@@ -99,21 +100,21 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
public Optional<Boolean> isIntTuple(){ public Optional<Boolean> isIntTuple(){
if (getValue().isPresent()){ if (getValue().isPresent()){
return Optional.of(getIntTupleValue().isPresent()); return Optional.of(getIntTupleValues().isPresent());
} }
return Optional.empty(); return Optional.empty();
} }
public Optional<Boolean> isNumberTuple(){ public Optional<Boolean> isNumberTuple(){
if (getValue().isPresent()){ if (getValue().isPresent()){
return Optional.of(getDoubleTupleValue().isPresent()); return Optional.of(getDoubleTupleValues().isPresent());
} }
return Optional.empty(); return Optional.empty();
} }
public Optional<Boolean> isBooleanTuple(){ public Optional<Boolean> isBooleanTuple(){
if (getValue().isPresent()){ if (getValue().isPresent()){
return Optional.of(getBooleanTupleValue().isPresent()); return Optional.of(getBooleanTupleValues().isPresent());
} }
return Optional.empty(); return Optional.empty();
} }
...@@ -194,8 +195,8 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { ...@@ -194,8 +195,8 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
return Optional.empty(); return Optional.empty();
} }
public Optional<List<Integer>> getIntTupleValue(){ public Optional<List<Integer>> getIntTupleValues(){
Optional<List<Object>> optValue = getTupleValue(); Optional<List<Object>> optValue = getTupleValues();
if (optValue.isPresent()){ if (optValue.isPresent()){
List<Integer> list = new ArrayList<>(); List<Integer> list = new ArrayList<>();
for (Object value : optValue.get()) { for (Object value : optValue.get()) {
...@@ -211,8 +212,8 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { ...@@ -211,8 +212,8 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
return Optional.empty(); return Optional.empty();
} }
public Optional<List<Double>> getDoubleTupleValue() { public Optional<List<Double>> getDoubleTupleValues() {
Optional<List<Object>> optValue = getTupleValue(); Optional<List<Object>> optValue = getTupleValues();
if (optValue.isPresent()){ if (optValue.isPresent()){
List<Double> list = new ArrayList<>(); List<Double> list = new ArrayList<>();
for (Object value : optValue.get()) { for (Object value : optValue.get()) {
...@@ -231,8 +232,8 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { ...@@ -231,8 +232,8 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
return Optional.empty(); return Optional.empty();
} }
public Optional<List<Boolean>> getBooleanTupleValue() { public Optional<List<Boolean>> getBooleanTupleValues() {
Optional<List<Object>> optValue = getTupleValue(); Optional<List<Object>> optValue = getTupleValues();
if (optValue.isPresent()){ if (optValue.isPresent()){
List<Boolean> list = new ArrayList<>(); List<Boolean> list = new ArrayList<>();
for (Object value : optValue.get()) { for (Object value : optValue.get()) {
...@@ -248,9 +249,10 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { ...@@ -248,9 +249,10 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
return Optional.empty(); return Optional.empty();
} }
public Optional<List<Object>> getTupleValue(){ public Optional<List<Object>> getTupleValues(){
if (getValue().isPresent()){ if (getValue().isPresent()){
if (isTuple()){ Optional<Object> optValue = getValue();
if (optValue.isPresent() && (optValue.get() instanceof List)){
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
List<Object> list = (List<Object>) getValue().get(); List<Object> list = (List<Object>) getValue().get();
return Optional.of(list); return Optional.of(list);
...@@ -300,7 +302,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { ...@@ -300,7 +302,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
public void resolveOrError(){ public void resolveOrError(){
resolve(); resolve();
if (!isResolved()){ if (!isResolved()){
throw new IllegalStateException("The following names could not be resolved: " + getUnresolvableNames()); throw new IllegalStateException("The following names could not be resolved: " + getUnresolvableVariables());
} }
} }
...@@ -315,13 +317,15 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { ...@@ -315,13 +317,15 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
*/ */
abstract public Optional<Object> getValue(); abstract public Optional<Object> getValue();
abstract public void reset();
/** /**
* Replaces all variable names in this values expression if possible. * Replaces all variable names in this values expression if possible.
* The values of the variables depend on the current scope. The replacement is irreversible if successful. * The values of the variables depend on the current scope. The replacement is irreversible if successful.
* *
* @return returns a set of all names which could not be resolved. * @return returns a set of all names which could not be resolved.
*/ */
abstract public Set<String> resolve(); abstract public Set<VariableSymbol> resolve();
/** /**
* @return returns a optional of a list(parallel) of lists(serial) of simple expressions in this sequence. * @return returns a optional of a list(parallel) of lists(serial) of simple expressions in this sequence.
...@@ -330,7 +334,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol { ...@@ -330,7 +334,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
*/ */
abstract public Optional<List<List<ArchSimpleExpressionSymbol>>> getElements(); abstract public Optional<List<List<ArchSimpleExpressionSymbol>>> getElements();
abstract protected Set<String> computeUnresolvableNames(); abstract protected void computeUnresolvableVariables(Set<VariableSymbol> unresolvableVariables, Set<VariableSymbol> allVariables);
/** /**
* @return returns true if the expression is resolved. * @return returns true if the expression is resolved.
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
package de.monticore.lang.monticar.cnnarch._symboltable; package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.MutableScope; import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.Scope;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
...@@ -63,6 +62,13 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression { ...@@ -63,6 +62,13 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
this.parallel = parallel; this.parallel = parallel;
} }
@Override
public void reset() {
getStartSymbol().reset();
getEndSymbol().reset();
setUnresolvableVariables(null);
}
@Override @Override
public boolean isParallelSequence() { public boolean isParallelSequence() {
return isParallel(); return isParallel();
...@@ -88,16 +94,15 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression { ...@@ -88,16 +94,15 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
}*/ }*/
@Override @Override
public Set<String> resolve() { public Set<VariableSymbol> resolve() {
if (!isResolved()){ if (!isResolved()){
checkIfResolvable();
if (isResolvable()){ if (isResolvable()){
getStartSymbol().resolveOrError(); getStartSymbol().resolveOrError();
getEndSymbol().resolveOrError(); getEndSymbol().resolveOrError();
} }
} }
return getUnresolvableNames(); return getUnresolvableVariables();
} }
@Override @Override
...@@ -142,11 +147,11 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression { ...@@ -142,11 +147,11 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
} }
@Override @Override
protected Set<String> computeUnresolvableNames() { protected void computeUnresolvableVariables(Set<VariableSymbol> unresolvableVariables, Set<VariableSymbol> allVariables) {
Set<String> unresolvableNames = new HashSet<>(); getStartSymbol().checkIfResolvable(allVariables);
unresolvableNames.addAll(getStartSymbol().computeUnresolvableNames()); unresolvableVariables.addAll(getStartSymbol().getUnresolvableVariables());
unresolvableNames.addAll(getEndSymbol().computeUnresolvableNames()); getEndSymbol().checkIfResolvable(allVariables);
return unresolvableNames; unresolvableVariables.addAll(getEndSymbol().getUnresolvableVariables());
} }
public ArchRangeExpressionSymbol copy(){ public ArchRangeExpressionSymbol copy(){
...@@ -154,7 +159,7 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression { ...@@ -154,7 +159,7 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
copy.setParallel(isParallel()); copy.setParallel(isParallel());
copy.setStartSymbol(getStartSymbol().copy()); copy.setStartSymbol(getStartSymbol().copy());
copy.setEndSymbol(getEndSymbol().copy()); copy.setEndSymbol(getEndSymbol().copy());
copy.setUnresolvableNames(getUnresolvableNames()); copy.setUnresolvableVariables(getUnresolvableVariables());
return copy; return copy;
} }
...@@ -165,10 +170,11 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression { ...@@ -165,10 +170,11 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
getEndSymbol().putInScope(scope); getEndSymbol().putInScope(scope);
} }
public static ArchRangeExpressionSymbol of(ArchSimpleExpressionSymbol start, ArchSimpleExpressionSymbol end){ public static ArchRangeExpressionSymbol of(ArchSimpleExpressionSymbol start, ArchSimpleExpressionSymbol end, boolean parallel){
ArchRangeExpressionSymbol sym = new ArchRangeExpressionSymbol(); ArchRangeExpressionSymbol sym = new ArchRangeExpressionSymbol();
sym.setStartSymbol(start); sym.setStartSymbol(start);
sym.setEndSymbol(end); sym.setEndSymbol(end);
sym.setParallel(parallel);
return sym; return sym;
} }
} }
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
package de.monticore.lang.monticar.cnnarch._symboltable; package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.MutableScope; import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.Scope;
import java.util.*; import java.util.*;