fully working resolve method

parent 8a7f93d4
......@@ -42,7 +42,7 @@
<mc.grammars.assembly.version>0.0.6-SNAPSHOT</mc.grammars.assembly.version>
<SIUnit.version>0.0.6-SNAPSHOT</SIUnit.version>
<Common-MontiCar.version>0.0.3-SNAPSHOT</Common-MontiCar.version>
<Math.version>0.0.3-SNAPSHOT-REWORK</Math.version>
<Math.version>0.0.7.1</Math.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
<junit.version>4.12</junit.version>
......
......@@ -30,8 +30,63 @@ import java.util.*;
public class PredefinedMethods {
public static final MethodDeclarationSymbol FULLY_CONNECTED = new MethodDeclarationSymbol.Builder()
.name("FullyConnected")
public static final String FULLY_CONNECTED_NAME = "FullyConnected";
public static final String CONVOLUTION_NAME = "Convolution";
public static final String SOFTMAX_NAME = "Softmax";
public static final String SIGMOID_NAME = "Sigmoid";
public static final String TANH_NAME = "Tanh";
public static final String RELU_NAME = "Relu";
public static final String DROPOUT_NAME = "Dropout";
public static final String MAX_POOLING_NAME = "MaxPooling";
public static final String AVG_POOLING_NAME = "AveragePooling";
public static final String LRN_NAME = "Lrn";
public static final String BATCHNORM_NAME = "BatchNorm";
public static final String SPLIT_NAME = "Split";
public static final String GET_NAME = "Get";
public static final String ADD_NAME = "Add";
public static final String CONCATENATE_NAME = "Concatenate";
public static final List<String> NAME_LIST = Arrays.asList(
FULLY_CONNECTED_NAME,
CONVOLUTION_NAME,
SOFTMAX_NAME,
SIGMOID_NAME,
TANH_NAME,
RELU_NAME,
DROPOUT_NAME,
MAX_POOLING_NAME,
AVG_POOLING_NAME,
LRN_NAME,
BATCHNORM_NAME,
SPLIT_NAME,
GET_NAME,
ADD_NAME,
CONCATENATE_NAME);
public static List<MethodDeclarationSymbol> createList(){
return Arrays.asList(
createFullyConnected(),
createConvolution(),
createSoftmax(),
createSigmoid(),
createTanh(),
createRelu(),
createDropout(),
createMaxPooling(),
createAveragePooling(),
createLrn(),
createBatchNorm(),
createSplit(),
createGet(),
createAdd(),
createConcatenate());
}
public static MethodDeclarationSymbol createFullyConnected(){
return new MethodDeclarationSymbol.Builder()
.name(FULLY_CONNECTED_NAME)
.parameters(
new VariableSymbol.Builder()
.name("units")
......@@ -49,9 +104,11 @@ public class PredefinedMethods {
.channels(method.getIntValue("units").get())
.build()))
.build();
}
public static final MethodDeclarationSymbol CONVOLUTION = new MethodDeclarationSymbol.Builder()
.name("Convolution")
public static MethodDeclarationSymbol createConvolution(){
return new MethodDeclarationSymbol.Builder()
.name(CONVOLUTION_NAME)
.parameters(
new VariableSymbol.Builder()
.name("kernel")
......@@ -64,7 +121,7 @@ public class PredefinedMethods {
new VariableSymbol.Builder()
.name("stride")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.defaultValue(1, 1)
.defaultValue(Arrays.asList(1, 1))
.build(),
new VariableSymbol.Builder()
.name("no_bias")
......@@ -77,41 +134,53 @@ public class PredefinedMethods {
method,
method.getIntValue("channels").get()))
.build();
}
public static final MethodDeclarationSymbol SOFTMAX = new MethodDeclarationSymbol.Builder()
.name("Softmax")
public static MethodDeclarationSymbol createSoftmax(){
return new MethodDeclarationSymbol.Builder()
.name(SOFTMAX_NAME)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
}
public static final MethodDeclarationSymbol SIGMOID = new MethodDeclarationSymbol.Builder()
.name("Sigmoid")
public static MethodDeclarationSymbol createSigmoid(){
return new MethodDeclarationSymbol.Builder()
.name(SIGMOID_NAME)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
}
public static final MethodDeclarationSymbol TANH = new MethodDeclarationSymbol.Builder()
.name("Tanh")
public static MethodDeclarationSymbol createTanh(){
return new MethodDeclarationSymbol.Builder()
.name(TANH_NAME)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
}
public static final MethodDeclarationSymbol RELU = new MethodDeclarationSymbol.Builder()
.name("Relu")
public static MethodDeclarationSymbol createRelu(){
return new MethodDeclarationSymbol.Builder()
.name(RELU_NAME)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
}
public static final MethodDeclarationSymbol DROPOUT = new MethodDeclarationSymbol.Builder()
.name("Dropout")
public static MethodDeclarationSymbol createDropout(){
return new MethodDeclarationSymbol.Builder()
.name(DROPOUT_NAME)
.parameters(
new VariableSymbol.Builder()
.name("p")
.constraints(Constraint.NUMBER, Constraint.BETWEEN_ZERO_AND_ONE)
.defaultValue(Rational.valueOf(1, 2))//0.5
.defaultValue(0.5)
.build()
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
}
public static final MethodDeclarationSymbol MAX_POOLING = new MethodDeclarationSymbol.Builder()
.name("MaxPooling")
public static MethodDeclarationSymbol createMaxPooling(){
return new MethodDeclarationSymbol.Builder()
.name(MAX_POOLING_NAME)
.parameters(
new VariableSymbol.Builder()
.name("kernel")
......@@ -120,7 +189,7 @@ public class PredefinedMethods {
new VariableSymbol.Builder()
.name("stride")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.defaultValue(1, 1)
.defaultValue(Arrays.asList(1, 1))
.build(),
new VariableSymbol.Builder()
.name("global")
......@@ -133,9 +202,11 @@ public class PredefinedMethods {
method,
inputShapes.get(0).getChannels().get()))
.build();
}
public static final MethodDeclarationSymbol AVERAGE_POOLING = new MethodDeclarationSymbol.Builder()
.name("AveragePooling")
public static MethodDeclarationSymbol createAveragePooling(){
return new MethodDeclarationSymbol.Builder()
.name(AVG_POOLING_NAME)
.parameters(
new VariableSymbol.Builder()
.name("kernel")
......@@ -144,7 +215,7 @@ public class PredefinedMethods {
new VariableSymbol.Builder()
.name("stride")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.defaultValue(1, 1)
.defaultValue(Arrays.asList(1, 1))
.build(),
new VariableSymbol.Builder()
.name("global")
......@@ -157,9 +228,11 @@ public class PredefinedMethods {
method,
inputShapes.get(0).getChannels().get()))
.build();
}
public static final MethodDeclarationSymbol LRN = new MethodDeclarationSymbol.Builder()
.name("Lrn")
public static MethodDeclarationSymbol createLrn(){
return new MethodDeclarationSymbol.Builder()
.name(LRN_NAME)
.parameters(
new VariableSymbol.Builder()
.name("nsize")
......@@ -173,19 +246,21 @@ public class PredefinedMethods {
new VariableSymbol.Builder()
.name("alpha")
.constraints(Constraint.NUMBER)
.defaultValue(Rational.valueOf(1, 10000))//0.0001
.defaultValue(0.0001)
.build(),
new VariableSymbol.Builder()
.name("beta")
.constraints(Constraint.NUMBER)
.defaultValue(Rational.valueOf(3, 4))//0.75
.defaultValue(0.75)
.build()
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
}
public static final MethodDeclarationSymbol BATCHNORM = new MethodDeclarationSymbol.Builder()
.name("BatchNorm")
public static MethodDeclarationSymbol createBatchNorm(){
return new MethodDeclarationSymbol.Builder()
.name(BATCHNORM_NAME)
.parameters(
new VariableSymbol.Builder()
.name("fix_gamma")
......@@ -200,9 +275,11 @@ public class PredefinedMethods {
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
}
public static final MethodDeclarationSymbol SPLIT = new MethodDeclarationSymbol.Builder()
.name("Split")
public static MethodDeclarationSymbol createSplit(){
return new MethodDeclarationSymbol.Builder()
.name(SPLIT_NAME)
.parameters(
new VariableSymbol.Builder()
.name("index")
......@@ -215,9 +292,11 @@ public class PredefinedMethods {
)
.shapeFunction((inputShapes, method) -> splitShapeFunction(inputShapes.get(0), method))
.build();
}
public static final MethodDeclarationSymbol GET = new MethodDeclarationSymbol.Builder()
.name("Get")
public static MethodDeclarationSymbol createGet(){
return new MethodDeclarationSymbol.Builder()
.name(GET_NAME)
.parameters(
new VariableSymbol.Builder()
.name("index")
......@@ -227,45 +306,20 @@ public class PredefinedMethods {
.shapeFunction((inputShapes, method) ->
Collections.singletonList(inputShapes.get(method.getIntValue("index").get())))
.build();
}
public static final MethodDeclarationSymbol ADD = new MethodDeclarationSymbol.Builder()
.name("Add")
public static MethodDeclarationSymbol createAdd(){
return new MethodDeclarationSymbol.Builder()
.name(ADD_NAME)
.shapeFunction((inputShapes, method) -> Collections.singletonList(inputShapes.get(0)))
.build();
}
public static final MethodDeclarationSymbol CONCATENATE = new MethodDeclarationSymbol.Builder()
.name("Concatenate")
public static MethodDeclarationSymbol createConcatenate(){
return new MethodDeclarationSymbol.Builder()
.name(CONCATENATE_NAME)
.shapeFunction(PredefinedMethods::concatenateShapeFunction)
.build();
public static final List<MethodDeclarationSymbol> LIST = Arrays.asList(
FULLY_CONNECTED,
CONVOLUTION,
SOFTMAX,
SIGMOID,
TANH,
RELU,
DROPOUT,
MAX_POOLING,
AVERAGE_POOLING,
LRN,
BATCHNORM,
SPLIT,
GET,
ADD,
CONCATENATE);
public static final Map<String, MethodDeclarationSymbol> MAP = createPredefinedMap();
private static Map<String, MethodDeclarationSymbol> createPredefinedMap() {
Map<String, MethodDeclarationSymbol> map = new HashMap<>();
for (MethodDeclarationSymbol method : LIST) {
map.put(method.getName(), method);
}
return map;
}
private static List<ShapeSymbol> strideShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method, int channels) {
......
......@@ -31,20 +31,23 @@ abstract public class ArchAbstractSequenceExpression extends ArchExpressionSymbo
super();
}
abstract public Optional<List<List<ArchSimpleExpressionSymbol>>> getElements();
abstract public boolean isParallelSequence();
abstract public boolean isSerialSequence();
abstract public Optional<Integer> getParallelLength();
abstract public Optional<Integer> getMaxSerialLength();
@Override
public boolean isSequence(){
return true;
}
@Override
public Optional<Object> getValue() {
if (isResolved()){
List<List<Object>> parallelValues = new ArrayList<>(getParallelLength().get());
for (List<ArchSimpleExpressionSymbol> serialElements : getElements().get()){
List<Object> serialValues = new ArrayList<>(getMaxSerialLength().get());
List<Object> serialValues = new ArrayList<>();
for (ArchSimpleExpressionSymbol element : serialElements){
serialValues.add(element.getValue().get());
}
......@@ -57,4 +60,19 @@ abstract public class ArchAbstractSequenceExpression extends ArchExpressionSymbo
}
}
@Override
public boolean isBoolean(){
return false;
}
@Override
public boolean isNumber(){
return false;
}
@Override
public boolean isTuple(){
return false;
}
}
......@@ -21,7 +21,10 @@
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.Scope;
import de.monticore.symboltable.Symbol;
import de.se_rwth.commons.logging.Log;
import java.util.*;
......@@ -31,12 +34,16 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
private Set<String> unresolvableNames = null;
public ArchExpressionSymbol() {
super("", KIND);
}
protected Boolean isResolvable(){
Set<String> set = getUnresolvableNames();
return set != null && set.isEmpty();
}
public Set<String> getUnresolvableNames() {
if (unresolvableNames == null){
checkIfResolvable();
......@@ -44,17 +51,12 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
return unresolvableNames;
}
public boolean isResolvable(){
return getUnresolvableNames().isEmpty();
protected void setUnresolvableNames(Set<String> unresolvableNames){
this.unresolvableNames = unresolvableNames;
}
public void checkIfResolvable(){
if (isResolved()){
unresolvableNames = new HashSet<>();
}
else {
unresolvableNames = computeUnresolvableNames();
}
setUnresolvableNames(computeUnresolvableNames());
}
/**
......@@ -62,9 +64,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
*
* @return returns true iff the value of the resolved expression will be a boolean.
*/
public boolean isBoolean(){
return false;
}
abstract public boolean isBoolean();
/**
* Checks whether the value is a number.
......@@ -72,9 +72,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
*
* @return returns true iff the value of the resolved expression will be a number.
*/
public boolean isNumber(){
return false;
}
abstract public boolean isNumber();
/**
* Checks whether the value is a Tuple.
......@@ -83,9 +81,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
*
* @return returns true iff the value of the expression will be a tuple.
*/
public boolean isTuple(){
return false;
}
abstract public boolean isTuple();
/**
* Checks whether the value is an integer. This can only be checked if the expression is resolved.
......@@ -124,7 +120,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
/**
* Checks whether the value is a parallel Sequence.
* If true, getRhs() will return (if present) a List of Lists of Objects.
* If true, getValue() will return (if present) a List of Lists of Objects.
* These Objects can either be Integer, Double or Boolean.
* If isSerialSequence() returns false, the second List will always have a size smaller than 2.
* Sequences of size 1 or 0 cannot be parallel sequences.
......@@ -137,7 +133,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
/**
* Checks whether the value is a serial Sequence.
* If true, getRhs() will return (if present) a List(parallel) of Lists(serial) of Objects.
* If true, getValue() will return (if present) a List(parallel) of Lists(serial) of Objects.
* If isParallelSequence() is false, the first list will be of size 1.
* These Objects can either be Integer, Double or Boolean.
* Sequences of size 1 or 0 are counted as serial sequences.
......@@ -157,6 +153,10 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
return false;
}
public boolean isSequence(){
return false;
}
/**
*
* @return returns true if this object is instance of ArchSimpleExpressionSymbol
......@@ -259,11 +259,46 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
return Optional.empty();
}
public Optional<Integer> getParallelLength(){
Optional<List<List<ArchSimpleExpressionSymbol>>> elements = getElements();
return elements.map(e -> e.isEmpty() ? 1 : e.size());
}
public Optional<List<Integer>> getSerialLengths(){
Optional<List<List<ArchSimpleExpressionSymbol>>> elements = getElements();
if (elements.isPresent()){
List<Integer> serialLengths = new ArrayList<>();
for (List<ArchSimpleExpressionSymbol> serialList : getElements().get()){
serialLengths.add(serialList.size());
}
return Optional.of(serialLengths);
}
else {
return Optional.empty();
}
}
public Optional<Integer> getMaxSerialLength(){
int max = 0;
Optional<List<Integer>> optLens = getSerialLengths();
if (optLens.isPresent()){
for (int len : optLens.get()){
if (len > max){
max = len;
}
}
}
else {
return Optional.empty();
}
return Optional.of(max);
}
/**
* Same as resolve() but throws an error if it was not successful.
*/
public void resolveOrError(Scope resolvingScope){
resolve(resolvingScope);
public void resolveOrError(){
resolve();
if (!isResolved()){
throw new IllegalStateException("The following names could not be resolved: " + getUnresolvableNames());
}
......@@ -286,12 +321,29 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
*
* @return returns a set of all names which could not be resolved.
*/
abstract public Set<String> resolve(Scope resolvingScope);
abstract public Set<String> resolve();
/**
* @return returns a optional of a list(parallel) of lists(serial) of simple expressions in this sequence.
* These lists will only contain one element if this is not a sequence.
* If the optional is not present that means this expression is a range which is not resolved.
*/
abstract public Optional<List<List<ArchSimpleExpressionSymbol>>> getElements();
abstract protected Set<String> computeUnresolvableNames();
/**
* @return returns true if the expression is resolved.
*/
abstract public boolean isResolved();
abstract public ArchExpressionSymbol copy();
protected void putInScope(MutableScope scope){
Collection<Symbol> symbolsInScope = scope.getLocalSymbols().get(getName());
if (symbolsInScope == null || !symbolsInScope.contains(this)) {
scope.add(this);
}
}
}
......@@ -20,9 +20,12 @@
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.Scope;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
......@@ -32,7 +35,7 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
private List<List<ArchSimpleExpressionSymbol>> elements = null;
public ArchRangeExpressionSymbol() {
protected ArchRangeExpressionSymbol() {
super();
}
......@@ -40,7 +43,7 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
return startSymbol;
}
public void setStartSymbol(ArchSimpleExpressionSymbol startSymbol) {
protected void setStartSymbol(ArchSimpleExpressionSymbol startSymbol) {
this.startSymbol = startSymbol;
}
......@@ -48,11 +51,11 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
return endSymbol;
}
public void setEndSymbol(ArchSimpleExpressionSymbol endSymbol) {
protected void setEndSymbol(ArchSimpleExpressionSymbol endSymbol) {
this.endSymbol = endSymbol;
}
protected boolean isParallel() {
public boolean isParallel() {
return parallel;
}
......@@ -70,7 +73,7 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
return !isParallel();
}
private Optional<Integer> getLength(){
/*private Optional<Integer> getLength(){
Optional<Integer> optLength = Optional.empty();
if (isResolved()){
Object startValue = getEndSymbol().getValue().get();
......@@ -82,36 +85,16 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
}
}
return optLength;
}
@Override
public Optional<Integer> getParallelLength() {
if (isParallelSequence()) {
return getLength();
}
else {
return Optional.of(1);
}
}
@Override
public Optional<Integer> getMaxSerialLength() {
if (isSerialSequence()) {
return getLength();
}
else {
return Optional.of(1);
}
}
}*/
@Override
public Set<String> resolve(Scope resolvingScope) {
public Set<String> resolve() {