fully working resolve method

parent 8a7f93d4
......@@ -42,7 +42,7 @@
<mc.grammars.assembly.version>0.0.6-SNAPSHOT</mc.grammars.assembly.version>
<SIUnit.version>0.0.6-SNAPSHOT</SIUnit.version>
<Common-MontiCar.version>0.0.3-SNAPSHOT</Common-MontiCar.version>
<Math.version>0.0.3-SNAPSHOT-REWORK</Math.version>
<Math.version>0.0.7.1</Math.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
<junit.version>4.12</junit.version>
......
......@@ -30,242 +30,296 @@ import java.util.*;
public class PredefinedMethods {
public static final MethodDeclarationSymbol FULLY_CONNECTED = new MethodDeclarationSymbol.Builder()
.name("FullyConnected")
.parameters(
new VariableSymbol.Builder()
.name("units")
.constraints(Constraint.INTEGER, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("no_bias")
.constraints(Constraint.BOOLEAN)
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) -> Collections.singletonList(new ShapeSymbol.Builder()
.height(1)
.width(1)
.channels(method.getIntValue("units").get())
.build()))
.build();
public static final MethodDeclarationSymbol CONVOLUTION = new MethodDeclarationSymbol.Builder()
.name("Convolution")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("channels")
.constraints(Constraint.INTEGER, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("stride")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.defaultValue(1, 1)
.build(),
new VariableSymbol.Builder()
.name("no_bias")
.constraints(Constraint.BOOLEAN)
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
method.getIntValue("channels").get()))
.build();
public static final String FULLY_CONNECTED_NAME = "FullyConnected";
public static final String CONVOLUTION_NAME = "Convolution";
public static final String SOFTMAX_NAME = "Softmax";
public static final String SIGMOID_NAME = "Sigmoid";
public static final String TANH_NAME = "Tanh";
public static final String RELU_NAME = "Relu";
public static final String DROPOUT_NAME = "Dropout";
public static final String MAX_POOLING_NAME = "MaxPooling";
public static final String AVG_POOLING_NAME = "AveragePooling";
public static final String LRN_NAME = "Lrn";
public static final String BATCHNORM_NAME = "BatchNorm";
public static final String SPLIT_NAME = "Split";
public static final String GET_NAME = "Get";
public static final String ADD_NAME = "Add";
public static final String CONCATENATE_NAME = "Concatenate";
public static final MethodDeclarationSymbol SOFTMAX = new MethodDeclarationSymbol.Builder()
.name("Softmax")
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static final List<String> NAME_LIST = Arrays.asList(
FULLY_CONNECTED_NAME,
CONVOLUTION_NAME,
SOFTMAX_NAME,
SIGMOID_NAME,
TANH_NAME,
RELU_NAME,
DROPOUT_NAME,
MAX_POOLING_NAME,
AVG_POOLING_NAME,
LRN_NAME,
BATCHNORM_NAME,
SPLIT_NAME,
GET_NAME,
ADD_NAME,
CONCATENATE_NAME);
public static final MethodDeclarationSymbol SIGMOID = new MethodDeclarationSymbol.Builder()
.name("Sigmoid")
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static final MethodDeclarationSymbol TANH = new MethodDeclarationSymbol.Builder()
.name("Tanh")
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static List<MethodDeclarationSymbol> createList(){
return Arrays.asList(
createFullyConnected(),
createConvolution(),
createSoftmax(),
createSigmoid(),
createTanh(),
createRelu(),
createDropout(),
createMaxPooling(),
createAveragePooling(),
createLrn(),
createBatchNorm(),
createSplit(),
createGet(),
createAdd(),
createConcatenate());
}
public static final MethodDeclarationSymbol RELU = new MethodDeclarationSymbol.Builder()
.name("Relu")
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static final MethodDeclarationSymbol DROPOUT = new MethodDeclarationSymbol.Builder()
.name("Dropout")
.parameters(
new VariableSymbol.Builder()
.name("p")
.constraints(Constraint.NUMBER, Constraint.BETWEEN_ZERO_AND_ONE)
.defaultValue(Rational.valueOf(1, 2))//0.5
.build()
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol createFullyConnected(){
return new MethodDeclarationSymbol.Builder()
.name(FULLY_CONNECTED_NAME)
.parameters(
new VariableSymbol.Builder()
.name("units")
.constraints(Constraint.INTEGER, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("no_bias")
.constraints(Constraint.BOOLEAN)
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) -> Collections.singletonList(new ShapeSymbol.Builder()
.height(1)
.width(1)
.channels(method.getIntValue("units").get())
.build()))
.build();
}
public static final MethodDeclarationSymbol MAX_POOLING = new MethodDeclarationSymbol.Builder()
.name("MaxPooling")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("stride")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.defaultValue(1, 1)
.build(),
new VariableSymbol.Builder()
.name("global")
.constraints(Constraint.BOOLEAN)
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
inputShapes.get(0).getChannels().get()))
.build();
public static MethodDeclarationSymbol createConvolution(){
return new MethodDeclarationSymbol.Builder()
.name(CONVOLUTION_NAME)
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("channels")
.constraints(Constraint.INTEGER, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("stride")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.defaultValue(Arrays.asList(1, 1))
.build(),
new VariableSymbol.Builder()
.name("no_bias")
.constraints(Constraint.BOOLEAN)
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
method.getIntValue("channels").get()))
.build();
}
public static final MethodDeclarationSymbol AVERAGE_POOLING = new MethodDeclarationSymbol.Builder()
.name("AveragePooling")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("stride")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.defaultValue(1, 1)
.build(),
new VariableSymbol.Builder()
.name("global")
.constraints(Constraint.BOOLEAN)
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
inputShapes.get(0).getChannels().get()))
.build();
public static MethodDeclarationSymbol createSoftmax(){
return new MethodDeclarationSymbol.Builder()
.name(SOFTMAX_NAME)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
}
public static final MethodDeclarationSymbol LRN = new MethodDeclarationSymbol.Builder()
.name("Lrn")
.parameters(
new VariableSymbol.Builder()
.name("nsize")
.constraints(Constraint.INTEGER, Constraint.NON_NEGATIVE)
.build(),
new VariableSymbol.Builder()
.name("knorm")
.constraints(Constraint.NUMBER)
.defaultValue(2)
.build(),
new VariableSymbol.Builder()
.name("alpha")
.constraints(Constraint.NUMBER)
.defaultValue(Rational.valueOf(1, 10000))//0.0001
.build(),
new VariableSymbol.Builder()
.name("beta")
.constraints(Constraint.NUMBER)
.defaultValue(Rational.valueOf(3, 4))//0.75
.build()
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol createSigmoid(){
return new MethodDeclarationSymbol.Builder()
.name(SIGMOID_NAME)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
}
public static final MethodDeclarationSymbol BATCHNORM = new MethodDeclarationSymbol.Builder()
.name("BatchNorm")
.parameters(
new VariableSymbol.Builder()
.name("fix_gamma")
.constraints(Constraint.BOOLEAN)
.defaultValue(true)
.build(),
new VariableSymbol.Builder()
.name("axis")
.constraints(Constraint.INTEGER, Constraint.NON_NEGATIVE)
.defaultValue(ShapeSymbol.CHANNEL_INDEX)
.build()
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol createTanh(){
return new MethodDeclarationSymbol.Builder()
.name(TANH_NAME)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
}
public static final MethodDeclarationSymbol SPLIT = new MethodDeclarationSymbol.Builder()
.name("Split")
.parameters(
new VariableSymbol.Builder()
.name("index")
.constraints(Constraint.INTEGER, Constraint.NON_NEGATIVE)
.build(),
new VariableSymbol.Builder()
.name("n")
.constraints(Constraint.INTEGER, Constraint.POSITIVE)
.build()
)
.shapeFunction((inputShapes, method) -> splitShapeFunction(inputShapes.get(0), method))
.build();
public static MethodDeclarationSymbol createRelu(){
return new MethodDeclarationSymbol.Builder()
.name(RELU_NAME)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
}
public static final MethodDeclarationSymbol GET = new MethodDeclarationSymbol.Builder()
.name("Get")
.parameters(
new VariableSymbol.Builder()
.name("index")
.constraints(Constraint.INTEGER, Constraint.NON_NEGATIVE)
.build()
)
.shapeFunction((inputShapes, method) ->
Collections.singletonList(inputShapes.get(method.getIntValue("index").get())))
.build();
public static MethodDeclarationSymbol createDropout(){
return new MethodDeclarationSymbol.Builder()
.name(DROPOUT_NAME)
.parameters(
new VariableSymbol.Builder()
.name("p")
.constraints(Constraint.NUMBER, Constraint.BETWEEN_ZERO_AND_ONE)
.defaultValue(0.5)
.build()
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
}
public static final MethodDeclarationSymbol ADD = new MethodDeclarationSymbol.Builder()
.name("Add")
.shapeFunction((inputShapes, method) -> Collections.singletonList(inputShapes.get(0)))
.build();
public static MethodDeclarationSymbol createMaxPooling(){
return new MethodDeclarationSymbol.Builder()
.name(MAX_POOLING_NAME)
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("stride")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.defaultValue(Arrays.asList(1, 1))
.build(),
new VariableSymbol.Builder()
.name("global")
.constraints(Constraint.BOOLEAN)
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
inputShapes.get(0).getChannels().get()))
.build();
}
public static final MethodDeclarationSymbol CONCATENATE = new MethodDeclarationSymbol.Builder()
.name("Concatenate")
.shapeFunction(PredefinedMethods::concatenateShapeFunction)
.build();
public static MethodDeclarationSymbol createAveragePooling(){
return new MethodDeclarationSymbol.Builder()
.name(AVG_POOLING_NAME)
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("stride")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.defaultValue(Arrays.asList(1, 1))
.build(),
new VariableSymbol.Builder()
.name("global")
.constraints(Constraint.BOOLEAN)
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
inputShapes.get(0).getChannels().get()))
.build();
}
public static final List<MethodDeclarationSymbol> LIST = Arrays.asList(
FULLY_CONNECTED,
CONVOLUTION,
SOFTMAX,
SIGMOID,
TANH,
RELU,
DROPOUT,
MAX_POOLING,
AVERAGE_POOLING,
LRN,
BATCHNORM,
SPLIT,
GET,
ADD,
CONCATENATE);
public static MethodDeclarationSymbol createLrn(){
return new MethodDeclarationSymbol.Builder()
.name(LRN_NAME)
.parameters(
new VariableSymbol.Builder()
.name("nsize")
.constraints(Constraint.INTEGER, Constraint.NON_NEGATIVE)
.build(),
new VariableSymbol.Builder()
.name("knorm")
.constraints(Constraint.NUMBER)
.defaultValue(2)
.build(),
new VariableSymbol.Builder()
.name("alpha")
.constraints(Constraint.NUMBER)
.defaultValue(0.0001)
.build(),
new VariableSymbol.Builder()
.name("beta")
.constraints(Constraint.NUMBER)
.defaultValue(0.75)
.build()
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
}
public static final Map<String, MethodDeclarationSymbol> MAP = createPredefinedMap();
public static MethodDeclarationSymbol createBatchNorm(){
return new MethodDeclarationSymbol.Builder()
.name(BATCHNORM_NAME)
.parameters(
new VariableSymbol.Builder()
.name("fix_gamma")
.constraints(Constraint.BOOLEAN)
.defaultValue(true)
.build(),
new VariableSymbol.Builder()
.name("axis")
.constraints(Constraint.INTEGER, Constraint.NON_NEGATIVE)
.defaultValue(ShapeSymbol.CHANNEL_INDEX)
.build()
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
}
public static MethodDeclarationSymbol createSplit(){
return new MethodDeclarationSymbol.Builder()
.name(SPLIT_NAME)
.parameters(
new VariableSymbol.Builder()
.name("index")
.constraints(Constraint.INTEGER, Constraint.NON_NEGATIVE)
.build(),
new VariableSymbol.Builder()
.name("n")
.constraints(Constraint.INTEGER, Constraint.POSITIVE)
.build()
)
.shapeFunction((inputShapes, method) -> splitShapeFunction(inputShapes.get(0), method))
.build();
}
public static MethodDeclarationSymbol createGet(){
return new MethodDeclarationSymbol.Builder()
.name(GET_NAME)
.parameters(
new VariableSymbol.Builder()
.name("index")
.constraints(Constraint.INTEGER, Constraint.NON_NEGATIVE)
.build()
)
.shapeFunction((inputShapes, method) ->
Collections.singletonList(inputShapes.get(method.getIntValue("index").get())))
.build();
}
public static MethodDeclarationSymbol createAdd(){
return new MethodDeclarationSymbol.Builder()
.name(ADD_NAME)
.shapeFunction((inputShapes, method) -> Collections.singletonList(inputShapes.get(0)))
.build();
}
private static Map<String, MethodDeclarationSymbol> createPredefinedMap() {
Map<String, MethodDeclarationSymbol> map = new HashMap<>();
for (MethodDeclarationSymbol method : LIST) {
map.put(method.getName(), method);
}
return map;
public static MethodDeclarationSymbol createConcatenate(){
return new MethodDeclarationSymbol.Builder()
.name(CONCATENATE_NAME)
.shapeFunction(PredefinedMethods::concatenateShapeFunction)
.build();
}
private static List<ShapeSymbol> strideShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method, int channels) {
......
......@@ -31,20 +31,23 @@ abstract public class ArchAbstractSequenceExpression extends ArchExpressionSymbo
super();
}
abstract public Optional<List<List<ArchSimpleExpressionSymbol>>> getElements();
abstract public boolean isParallelSequence();
abstract public boolean isSerialSequence();
abstract public Optional<Integer> getParallelLength();
abstract public Optional<Integer> getMaxSerialLength();
@Override
public boolean isSequence(){
return true;
}
@Override
public Optional<Object> getValue() {
if (isResolved()){
List<List<Object>> parallelValues = new ArrayList<>(getParallelLength().get());
for (List<ArchSimpleExpressionSymbol> serialElements : getElements().get()){
List<Object> serialValues = new ArrayList<>(getMaxSerialLength().get());
List<Object> serialValues = new ArrayList<>();
for (ArchSimpleExpressionSymbol element : serialElements){
serialValues.add(element.getValue().get());
}
......@@ -57,4 +60,19 @@ abstract public class ArchAbstractSequenceExpression extends ArchExpressionSymbo
}
}
@Override
public boolean isBoolean(){
return false;
}
@Override
public boolean isNumber(){
return false;
}
@Override
public boolean isTuple(){
return false;
}
}
......@@ -21,7 +21,10 @@
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.Scope;
import de.monticore.symboltable.Symbol;
import de.se_rwth.commons.logging.Log;
import java.util.*;
......@@ -31,12 +34,16 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
private Set<String> unresolvableNames = null;
public ArchExpressionSymbol() {
super("", KIND);
}
protected Boolean isResolvable(){
Set<String> set = getUnresolvableNames();
return set != null && set.isEmpty();
}
public Set<String> getUnresolvableNames() {
if (unresolvableNames == null){
checkIfResolvable();
......@@ -44,17 +51,12 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
return unresolvableNames;
}
public boolean isResolvable(){
return getUnresolvableNames().isEmpty();
protected void setUnresolvableNames(Set<String> unresolvableNames){
this.unresolvableNames = unresolvableNames;
}
public void checkIfResolvable(){
if (isResolved()){
unresolvableNames = new HashSet<>();
}
else {