Completed resolve mechanism and output shape computation.

parent b2d40dcd
......@@ -30,7 +30,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-arch</artifactId>
<version>0.0.2-SNAPSHOT</version>
<version>0.1.0-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......
......@@ -57,9 +57,11 @@ grammar CNNArch extends de.monticore.lang.math.Math {
ArchSerialSequence = serialValues:(ArchSimpleExpression || "->")+;
ArchValueRange implements ArchValueSequence = "[" start:ArchSimpleExpression
ArchValueRange implements ArchValueSequence = start:ArchSimpleExpression
(serial:"->" | parallel:"|")
":" end:ArchSimpleExpression "]";
".."
(serial2:"->" | parallel2:"|")
end:ArchSimpleExpression;
ArchSimpleExpression = (arithmeticExpression:MathArithmeticExpression
......
......@@ -51,28 +51,60 @@ public enum Constraint {
INTEGER_TUPLE {
@Override
public boolean check(ArchSimpleExpressionSymbol exp) {
boolean res = false;
if (exp.isTuple()){
//todo
}
return false;
return exp.isIntTuple().get();
}
},
POSITIVE {
@Override
public boolean check(ArchSimpleExpressionSymbol exp) {
if (exp.getDoubleValue().isPresent()){
return exp.getDoubleValue().get() > 0;
}
else if (exp.getDoubleTupleValues().isPresent()){
boolean isPositive = true;
for (double value : exp.getDoubleTupleValues().get()){
if (value <= 0){
isPositive = false;
}
}
return isPositive;
}
return false;
}
},
NON_NEGATIVE {
@Override
public boolean check(ArchSimpleExpressionSymbol exp) {
if (exp.getDoubleValue().isPresent()){
return exp.getDoubleValue().get() >= 0;
}
else if (exp.getDoubleTupleValues().isPresent()){
boolean isPositive = true;
for (double value : exp.getDoubleTupleValues().get()){
if (value < 0){
isPositive = false;
}
}
return isPositive;
}
return false;
}
},
BETWEEN_ZERO_AND_ONE {
@Override
public boolean check(ArchSimpleExpressionSymbol exp) {
if (exp.getDoubleValue().isPresent()){
return exp.getDoubleValue().get() >= 0 && exp.getDoubleValue().get() <= 1;
}
else if (exp.getDoubleTupleValues().isPresent()){
boolean isPositive = true;
for (double value : exp.getDoubleTupleValues().get()){
if (value < 0 || value > 1){
isPositive = false;
}
}
return isPositive;
}
return false;
}
};
......
......@@ -41,7 +41,7 @@ public class PredefinedMethods {
public static final String AVG_POOLING_NAME = "AveragePooling";
public static final String LRN_NAME = "Lrn";
public static final String BATCHNORM_NAME = "BatchNorm";
public static final String SPLIT_NAME = "Split";
public static final String SPLIT_NAME = "SplitData";
public static final String GET_NAME = "Get";
public static final String ADD_NAME = "Add";
public static final String CONCATENATE_NAME = "Concatenate";
......@@ -323,6 +323,15 @@ public class PredefinedMethods {
}
private static List<ShapeSymbol> strideShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method, int channels) {
Optional<Boolean> optGlobal = method.getBooleanValue("global");
if (optGlobal.isPresent() && optGlobal.get()){
return Collections.singletonList(new ShapeSymbol.Builder()
.height(1)
.width(1)
.channels(channels)
.build());
}
else{
int strideHeight = method.getIntTupleValue("stride").get().get(0);
int strideWidth = method.getIntTupleValue("stride").get().get(1);
int kernelHeight = method.getIntTupleValue("kernel").get().get(0);
......@@ -331,8 +340,12 @@ public class PredefinedMethods {
int inputWidth = inputShape.getWidth().get();
//assume padding with border_mode='same'
int outputWidth = 1 + ((inputWidth - kernelWidth + strideWidth - 1) / strideWidth);
int outputHeight = 1 + ((inputHeight - kernelHeight + strideHeight - 1) / strideHeight);
int outputWidth = inputWidth / strideWidth;
int outputHeight = inputHeight / strideHeight;
//border_mode=valid
//int outputWidth = 1 + Math.max(0, ((inputWidth - kernelWidth + strideWidth - 1) / strideWidth));
//int outputHeight = 1 + Math.max(0, ((inputHeight - kernelHeight + strideHeight - 1) / strideHeight));
return Collections.singletonList(new ShapeSymbol.Builder()
.height(outputHeight)
......@@ -340,6 +353,7 @@ public class PredefinedMethods {
.channels(channels)
.build());
}
}
private static List<ShapeSymbol> splitShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method) {
int numberOfSplits = method.getIntValue("n").get();
......
......@@ -27,6 +27,7 @@ public class PredefinedVariables {
public static final String IF_NAME = "_if";
public static final String FOR_NAME = "_for";
public static final String CARDINALITY_NAME = "_cardinality";
public static final String TRUE_NAME = "true";
public static final String FALSE_NAME = "false";
......@@ -45,6 +46,13 @@ public class PredefinedVariables {
.build();
}
public static VariableSymbol createCardinalityParameter(){
return new VariableSymbol.Builder()
.name(CARDINALITY_NAME)
.defaultValue(1)
.build();
}
//necessary because true is currently only a name in MontiMath and it needs to be evaluated at compile time for this language
public static VariableSymbol createTrueConstant(){
return new VariableSymbol.Builder()
......
......@@ -22,9 +22,7 @@ package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.Scope;
import de.monticore.symboltable.Symbol;
import de.se_rwth.commons.logging.Log;
import java.util.*;
......@@ -32,7 +30,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
public static final ArchExpressionKind KIND = new ArchExpressionKind();
private Set<String> unresolvableNames = null;
private Set<VariableSymbol> unresolvableVariables = null;
public ArchExpressionSymbol() {
super("", KIND);
......@@ -40,25 +38,28 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
protected Boolean isResolvable(){
Set<String> set = getUnresolvableNames();
Set<VariableSymbol> set = getUnresolvableVariables();
return set != null && set.isEmpty();
}
public Set<String> getUnresolvableNames() {
if (unresolvableNames == null){
checkIfResolvable();
public Set<VariableSymbol> getUnresolvableVariables() {
if (unresolvableVariables == null){
checkIfResolvable(new HashSet<>());
}
return unresolvableNames;
return unresolvableVariables;
}
protected void setUnresolvableNames(Set<String> unresolvableNames){
this.unresolvableNames = unresolvableNames;
protected void setUnresolvableVariables(Set<VariableSymbol> unresolvableVariables){
this.unresolvableVariables = unresolvableVariables;
}
public void checkIfResolvable(){
setUnresolvableNames(computeUnresolvableNames());
public void checkIfResolvable(Set<VariableSymbol> seenVariables){
Set<VariableSymbol> unresolvableVariables = new HashSet<>();
computeUnresolvableVariables(unresolvableVariables, seenVariables);
setUnresolvableVariables(unresolvableVariables);
}
/**
* Checks whether the value is a boolean. If true getValue() will return a Boolean if present.
*
......@@ -99,21 +100,21 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
public Optional<Boolean> isIntTuple(){
if (getValue().isPresent()){
return Optional.of(getIntTupleValue().isPresent());
return Optional.of(getIntTupleValues().isPresent());
}
return Optional.empty();
}
public Optional<Boolean> isNumberTuple(){
if (getValue().isPresent()){
return Optional.of(getDoubleTupleValue().isPresent());
return Optional.of(getDoubleTupleValues().isPresent());
}
return Optional.empty();
}
public Optional<Boolean> isBooleanTuple(){
if (getValue().isPresent()){
return Optional.of(getBooleanTupleValue().isPresent());
return Optional.of(getBooleanTupleValues().isPresent());
}
return Optional.empty();
}
......@@ -194,8 +195,8 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
return Optional.empty();
}
public Optional<List<Integer>> getIntTupleValue(){
Optional<List<Object>> optValue = getTupleValue();
public Optional<List<Integer>> getIntTupleValues(){
Optional<List<Object>> optValue = getTupleValues();
if (optValue.isPresent()){
List<Integer> list = new ArrayList<>();
for (Object value : optValue.get()) {
......@@ -211,8 +212,8 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
return Optional.empty();
}
public Optional<List<Double>> getDoubleTupleValue() {
Optional<List<Object>> optValue = getTupleValue();
public Optional<List<Double>> getDoubleTupleValues() {
Optional<List<Object>> optValue = getTupleValues();
if (optValue.isPresent()){
List<Double> list = new ArrayList<>();
for (Object value : optValue.get()) {
......@@ -231,8 +232,8 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
return Optional.empty();
}
public Optional<List<Boolean>> getBooleanTupleValue() {
Optional<List<Object>> optValue = getTupleValue();
public Optional<List<Boolean>> getBooleanTupleValues() {
Optional<List<Object>> optValue = getTupleValues();
if (optValue.isPresent()){
List<Boolean> list = new ArrayList<>();
for (Object value : optValue.get()) {
......@@ -248,9 +249,10 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
return Optional.empty();
}
public Optional<List<Object>> getTupleValue(){
public Optional<List<Object>> getTupleValues(){
if (getValue().isPresent()){
if (isTuple()){
Optional<Object> optValue = getValue();
if (optValue.isPresent() && (optValue.get() instanceof List)){
@SuppressWarnings("unchecked")
List<Object> list = (List<Object>) getValue().get();
return Optional.of(list);
......@@ -300,7 +302,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
public void resolveOrError(){
resolve();
if (!isResolved()){
throw new IllegalStateException("The following names could not be resolved: " + getUnresolvableNames());
throw new IllegalStateException("The following names could not be resolved: " + getUnresolvableVariables());
}
}
......@@ -315,13 +317,15 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
*/
abstract public Optional<Object> getValue();
abstract public void reset();
/**
* Replaces all variable names in this values expression if possible.
* The values of the variables depend on the current scope. The replacement is irreversible if successful.
*
* @return returns a set of all names which could not be resolved.
*/
abstract public Set<String> resolve();
abstract public Set<VariableSymbol> resolve();
/**
* @return returns a optional of a list(parallel) of lists(serial) of simple expressions in this sequence.
......@@ -330,7 +334,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
*/
abstract public Optional<List<List<ArchSimpleExpressionSymbol>>> getElements();
abstract protected Set<String> computeUnresolvableNames();
abstract protected void computeUnresolvableVariables(Set<VariableSymbol> unresolvableVariables, Set<VariableSymbol> allVariables);
/**
* @return returns true if the expression is resolved.
......
......@@ -21,7 +21,6 @@
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.Scope;
import java.util.*;
import java.util.stream.Collectors;
......@@ -63,6 +62,13 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
this.parallel = parallel;
}
@Override
public void reset() {
getStartSymbol().reset();
getEndSymbol().reset();
setUnresolvableVariables(null);
}
@Override
public boolean isParallelSequence() {
return isParallel();
......@@ -88,16 +94,15 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
}*/
@Override
public Set<String> resolve() {
public Set<VariableSymbol> resolve() {
if (!isResolved()){
checkIfResolvable();
if (isResolvable()){
getStartSymbol().resolveOrError();
getEndSymbol().resolveOrError();
}
}
return getUnresolvableNames();
return getUnresolvableVariables();
}
@Override
......@@ -142,11 +147,11 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
}
@Override
protected Set<String> computeUnresolvableNames() {
Set<String> unresolvableNames = new HashSet<>();
unresolvableNames.addAll(getStartSymbol().computeUnresolvableNames());
unresolvableNames.addAll(getEndSymbol().computeUnresolvableNames());
return unresolvableNames;
protected void computeUnresolvableVariables(Set<VariableSymbol> unresolvableVariables, Set<VariableSymbol> allVariables) {
getStartSymbol().checkIfResolvable(allVariables);
unresolvableVariables.addAll(getStartSymbol().getUnresolvableVariables());
getEndSymbol().checkIfResolvable(allVariables);
unresolvableVariables.addAll(getEndSymbol().getUnresolvableVariables());
}
public ArchRangeExpressionSymbol copy(){
......@@ -154,7 +159,7 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
copy.setParallel(isParallel());
copy.setStartSymbol(getStartSymbol().copy());
copy.setEndSymbol(getEndSymbol().copy());
copy.setUnresolvableNames(getUnresolvableNames());
copy.setUnresolvableVariables(getUnresolvableVariables());
return copy;
}
......@@ -165,10 +170,11 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
getEndSymbol().putInScope(scope);
}
public static ArchRangeExpressionSymbol of(ArchSimpleExpressionSymbol start, ArchSimpleExpressionSymbol end){
public static ArchRangeExpressionSymbol of(ArchSimpleExpressionSymbol start, ArchSimpleExpressionSymbol end, boolean parallel){
ArchRangeExpressionSymbol sym = new ArchRangeExpressionSymbol();
sym.setStartSymbol(start);
sym.setEndSymbol(end);
sym.setParallel(parallel);
return sym;
}
}
......@@ -21,7 +21,6 @@
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.Scope;
import java.util.*;
......@@ -47,6 +46,16 @@ public class ArchSequenceExpressionSymbol extends ArchAbstractSequenceExpression
this.elements = elements;
}
@Override
public void reset() {
for (List<ArchSimpleExpressionSymbol> serialElements : _getElements()){
for (ArchSimpleExpressionSymbol element : serialElements){
element.reset();
}
}
setUnresolvableVariables(null);
}
@Override
public boolean isSerialSequence(){
boolean isSerial = !isParallelSequence();
......@@ -64,10 +73,9 @@ public class ArchSequenceExpressionSymbol extends ArchAbstractSequenceExpression
}
@Override
public Set<String> resolve() {
if (!isResolved()){
checkIfResolvable();
if (isResolvable()){
public Set<VariableSymbol> resolve() {
if (!isResolved()) {
if (isResolvable()) {
for (List<ArchSimpleExpressionSymbol> serialList : _getElements()) {
for (ArchSimpleExpressionSymbol element : serialList) {
......@@ -76,7 +84,7 @@ public class ArchSequenceExpressionSymbol extends ArchAbstractSequenceExpression
}
}
}
return getUnresolvableNames();
return getUnresolvableVariables();
}
@Override
......@@ -93,14 +101,13 @@ public class ArchSequenceExpressionSymbol extends ArchAbstractSequenceExpression
}
@Override
protected Set<String> computeUnresolvableNames() {
Set<String> unresolvableNames = new HashSet<>();
protected void computeUnresolvableVariables(Set<VariableSymbol> unresolvableVariables, Set<VariableSymbol> allVariables) {
for (List<ArchSimpleExpressionSymbol> serialElements : _getElements()){
for (ArchSimpleExpressionSymbol element : serialElements){
unresolvableNames.addAll(element.computeUnresolvableNames());
element.checkIfResolvable(allVariables);
unresolvableVariables.addAll(element.getUnresolvableVariables());
}
}
return unresolvableNames;
}
public ArchSequenceExpressionSymbol copy(){
......@@ -114,7 +121,7 @@ public class ArchSequenceExpressionSymbol extends ArchAbstractSequenceExpression
elementsCopy.add(serialListCopy);
}
copy.setElements(getElements().get());
copy.setUnresolvableNames(getUnresolvableNames());
copy.setUnresolvableVariables(getUnresolvableVariables());
return copy;
}
......
......@@ -53,6 +53,14 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol implements
this.value = value;
}
@Override
public void reset(){
if (getMathExpression().isPresent()){
setValue(null);
setUnresolvableVariables(null);
}
}
@Override
public boolean isSimpleValue() {
return true;
......@@ -60,7 +68,7 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol implements
@Override
public boolean isBoolean() {
if (getMathExpression().isPresent()){
if (getMathExpression().isPresent() && !(getMathExpression().get() instanceof MathNameExpressionSymbol)){
return getMathExpression().get() instanceof MathCompareExpressionSymbol;
}
else {
......@@ -70,7 +78,7 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol implements
@Override
public boolean isNumber() {
if (getMathExpression().isPresent()){
if (getMathExpression().isPresent() && !(getMathExpression().get() instanceof MathNameExpressionSymbol)){
return getMathExpression().get() instanceof MathArithmeticExpressionSymbol;
}
else {
......@@ -80,37 +88,31 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol implements
@Override
public boolean isTuple() {
if (getMathExpression().isPresent()){
if (getMathExpression().isPresent() && !(getMathExpression().get() instanceof MathNameExpressionSymbol)){
return getMathExpression().get() instanceof TupleExpressionSymbol;
}
else {
return getValue().get() instanceof List;
}
return getTupleValues().isPresent();
}
@Override
protected Set<String> computeUnresolvableNames() {
Set<String> unresolvableNames = new HashSet<>();
Set<String> allNames = new HashSet<>();
computeUnresolvableNames(unresolvableNames, allNames);
return unresolvableNames;
}
protected void computeUnresolvableNames(Set<String> unresolvableNames, Set<String> allNames) {
protected void computeUnresolvableVariables(Set<VariableSymbol> unresolvableVariables, Set<VariableSymbol> allVariables) {
if (getMathExpression().isPresent()) {
for (MathExpressionSymbol exp : ExpressionHelper.createSubExpressionList(getMathExpression().get())) {
if (exp instanceof MathNameExpressionSymbol) {
String name = ((MathNameExpressionSymbol) exp).getNameToAccess();
if (!allNames.contains(name)) {
allNames.add(name);
Optional<VariableSymbol> variable = getEnclosingScope().resolve(name, VariableSymbol.KIND);
if (variable.isPresent() && !variable.get().getExpression().isResolved()) {
//todo: implement coco to check isPresent()
if (!allVariables.contains(variable.get())) {
allVariables.add(variable.get());
if (variable.get().hasValue()) {
variable.get().getExpression().computeUnresolvableNames(unresolvableNames, allNames);
} else {
unresolvableNames.add(name);
if (!variable.get().getExpression().isResolved()) {
variable.get().getExpression().computeUnresolvableVariables(unresolvableVariables, allVariables);
}
}
else {
unresolvableVariables.add(variable.get());
}
}
}
}
......@@ -118,35 +120,45 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol implements
}
@Override
public Set<String> resolve() {
checkIfResolvable();
public Set<VariableSymbol> resolve() {
if (!isResolved()) {
if (getMathExpression().isPresent() && isResolvable()) {
Object value;
if (isTuple()){
TupleExpressionSymbol tuple = (TupleExpressionSymbol) getMathExpression().get();
List<Object> tupleValues = new ArrayList<>(tuple.getExpressions().size());
for (MathExpressionSymbol exp : tuple.getExpressions()){
tupleValues.add(computeValue());
Object value = computeValue();
setValue(value);
}
value = tupleValues;
}
return getUnresolvableVariables();
}
private Object computeValue(){
if (getMathExpression().get() instanceof MathNameExpressionSymbol){
return computeValue((MathNameExpressionSymbol) getMathExpression().get());
}
else if (getMathExpression().get() instanceof TupleExpressionSymbol){
Map<String, String> replacementMap = new HashMap<>();
List<Object> valueList = new ArrayList<>();
TupleExpressionSymbol tuple = (TupleExpressionSymbol) getMathExpression().get();
for (MathExpressionSymbol mathExp : tuple.getExpressions()){