Commit 19649197 authored by Thomas Michael Timmermanns's avatar Thomas Michael Timmermanns
Browse files

Started implementation of resolve method and ioLayer and changed Predefined methods.

parent 9b0850da
......@@ -51,6 +51,8 @@ public class ErrorMessages {
"This is because the number of outputs is variable for all architectures. " +
"Example: output{ fullyConnected() activation.softmax() } -> out";
public static final String UNKNOWN_NAME_CODE = "x32585";
public static final String UNKNOWN_NAME_MSG = "0" + UNKNOWN_NAME_CODE + " Unknown method error. ";
}
......@@ -26,59 +26,220 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ShapeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import org.jscience.mathematics.number.Rational;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.*;
public class PredefinedMethods {
public static MethodDeclarationSymbol createFullyConnected(){
return new MethodDeclarationSymbol.Builder()
.name("FullyConnected")
.parameters(
new VariableSymbol.Builder()
.name("units")
.build(),
new VariableSymbol.Builder()
.name("no_bias")
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) -> Collections.singletonList(new ShapeSymbol.Builder()
.height(1)
.width(1)
.channels(method.getIntValue("units").get())
.build()))
.build();
}
public static MethodDeclarationSymbol FULLY_CONNECTED = new MethodDeclarationSymbol.Builder()
.name("FullyConnected")
.parameters(
new VariableSymbol.Builder()
.name("units")
.build(),
new VariableSymbol.Builder()
.name("no_bias")
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) -> Collections.singletonList(new ShapeSymbol.Builder()
.height(1)
.width(1)
.channels(method.getIntValue("units").get())
.build()))
.build();
public static MethodDeclarationSymbol CONVOLUTION = new MethodDeclarationSymbol.Builder()
.name("Convolution")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.build(),
new VariableSymbol.Builder()
.name("channels")
.build(),
new VariableSymbol.Builder()
.name("stride")
.defaultValue(1, 1)
.build(),
new VariableSymbol.Builder()
.name("no_bias")
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
method.getIntValue("channels").get()))
.build();
public static MethodDeclarationSymbol SOFTMAX = new MethodDeclarationSymbol.Builder()
.name("Softmax")
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol SIGMOID = new MethodDeclarationSymbol.Builder()
.name("Sigmoid")
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol TANH = new MethodDeclarationSymbol.Builder()
.name("Tanh")
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol RELU = new MethodDeclarationSymbol.Builder()
.name("Relu")
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol DROPOUT = new MethodDeclarationSymbol.Builder()
.name("Dropout")
.parameters(
new VariableSymbol.Builder()
.name("p")
.defaultValue(Rational.valueOf(1, 2))//0.5
.build()
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol MAX_POOLING = new MethodDeclarationSymbol.Builder()
.name("MaxPooling")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.build(),
new VariableSymbol.Builder()
.name("stride")
.defaultValue(1, 1)
.build(),
new VariableSymbol.Builder()
.name("global")
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
inputShapes.get(0).getChannels().get()))
.build();
public static MethodDeclarationSymbol AVERAGE_POOLING = new MethodDeclarationSymbol.Builder()
.name("AveragePooling")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.build(),
new VariableSymbol.Builder()
.name("stride")
.defaultValue(1, 1)
.build(),
new VariableSymbol.Builder()
.name("global")
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
inputShapes.get(0).getChannels().get()))
.build();
public static MethodDeclarationSymbol LRN = new MethodDeclarationSymbol.Builder()
.name("Lrn")
.parameters(
new VariableSymbol.Builder()
.name("nsize")
.build(),
new VariableSymbol.Builder()
.name("knorm")
.defaultValue(2)
.build(),
new VariableSymbol.Builder()
.name("alpha")
.defaultValue(Rational.valueOf(1, 10000))//0.0001
.build(),
new VariableSymbol.Builder()
.name("beta")
.defaultValue(Rational.valueOf(3, 4))//0.75
.build()
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol BATCHNORM = new MethodDeclarationSymbol.Builder()
.name("BatchNorm")
.parameters(
//todo
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol SPLIT = new MethodDeclarationSymbol.Builder()
.name("Split")
.parameters(
new VariableSymbol.Builder()
.name("index")
.build(),
new VariableSymbol.Builder()
.name("n")
.build()
)
.shapeFunction((inputShapes, method) -> splitShapeFunction(inputShapes.get(0), method))
.build();
public static MethodDeclarationSymbol createConvolution(){
return new MethodDeclarationSymbol.Builder()
.name("Convolution")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.build(),
new VariableSymbol.Builder()
.name("channels")
.build(),
new VariableSymbol.Builder()
.name("stride")
.defaultValue(1,1)
.build(),
new VariableSymbol.Builder()
.name("no_bias")
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
method.getIntValue("channels").get()))
.build();
public static MethodDeclarationSymbol GET = new MethodDeclarationSymbol.Builder()
.name("Get")
.parameters(
new VariableSymbol.Builder()
.name("index")
.build()
)
.shapeFunction((inputShapes, method) ->
Collections.singletonList(inputShapes.get(method.getIntValue("index").get())))
.build();
public static MethodDeclarationSymbol ADD = new MethodDeclarationSymbol.Builder()
.name("Add")
.shapeFunction((inputShapes, method) -> Collections.singletonList(inputShapes.get(0)))
.build();
public static MethodDeclarationSymbol CONCATENATE = new MethodDeclarationSymbol.Builder()
.name("Concatenate")
.shapeFunction(PredefinedMethods::concatenateShapeFunction)
.build();
public static List<MethodDeclarationSymbol> LIST = Arrays.asList(
FULLY_CONNECTED,
CONVOLUTION,
SOFTMAX,
SIGMOID,
TANH,
RELU,
DROPOUT,
MAX_POOLING,
AVERAGE_POOLING,
LRN,
BATCHNORM,
SPLIT,
GET,
ADD,
CONCATENATE);
public static Map<String, MethodDeclarationSymbol> MAP = createPredefinedMap();
private static Map<String, MethodDeclarationSymbol> createPredefinedMap() {
Map<String, MethodDeclarationSymbol> map = new HashMap<>();
for (MethodDeclarationSymbol method : LIST) {
map.put(method.getName(), method);
}
return map;
}
private static List<ShapeSymbol> strideShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method, int channels){
private static List<ShapeSymbol> strideShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method, int channels) {
int strideHeight = method.getIntTupleValue("stride").get().get(0);
int strideWidth = method.getIntTupleValue("stride").get().get(1);
int kernelHeight = method.getIntTupleValue("kernel").get().get(0);
......@@ -97,151 +258,21 @@ public class PredefinedMethods {
.build());
}
public static MethodDeclarationSymbol createSoftmax(){
return new MethodDeclarationSymbol.Builder()
.name("Softmax")
.build();
}
public static MethodDeclarationSymbol createSigmoid(){
return new MethodDeclarationSymbol.Builder()
.name("Sigmoid")
.build();
}
public static MethodDeclarationSymbol createTanh(){
return new MethodDeclarationSymbol.Builder()
.name("Tanh")
.build();
}
public static MethodDeclarationSymbol createRelu(){
return new MethodDeclarationSymbol.Builder()
.name("Relu")
.build();
}
public static MethodDeclarationSymbol createDropout(){
return new MethodDeclarationSymbol.Builder()
.name("Dropout")
.parameters(
new VariableSymbol.Builder()
.name("p")
.defaultValue(Rational.valueOf(1,2))//0.5
.build()
)
.build();
}
public static MethodDeclarationSymbol createMaxPooling(){
return new MethodDeclarationSymbol.Builder()
.name("MaxPooling")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.build(),
new VariableSymbol.Builder()
.name("stride")
.defaultValue(1,1)
.build(),
new VariableSymbol.Builder()
.name("global")
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
inputShapes.get(0).getChannels().get()))
.build();
}
public static MethodDeclarationSymbol createAveragePooling(){
return new MethodDeclarationSymbol.Builder()
.name("AveragePooling")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.build(),
new VariableSymbol.Builder()
.name("stride")
.defaultValue(1,1)
.build(),
new VariableSymbol.Builder()
.name("global")
.defaultValue(false)
.build()
)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
inputShapes.get(0).getChannels().get()))
.build();
}
public static MethodDeclarationSymbol createLrn(){
return new MethodDeclarationSymbol.Builder()
.name("Lrn")
.parameters(
new VariableSymbol.Builder()
.name("nsize")
.build(),
new VariableSymbol.Builder()
.name("knorm")
.defaultValue(2)
.build(),
new VariableSymbol.Builder()
.name("alpha")
.defaultValue(Rational.valueOf(1,10000))//0.0001
.build(),
new VariableSymbol.Builder()
.name("beta")
.defaultValue(Rational.valueOf(3,4))//0.75
.build()
)
.build();
}
public static MethodDeclarationSymbol createBatchNorm(){
return new MethodDeclarationSymbol.Builder()
.name("BatchNorm")
.parameters(
//todo
)
.build();
}
public static MethodDeclarationSymbol createSplit(){
return new MethodDeclarationSymbol.Builder()
.name("Split")
.parameters(
new VariableSymbol.Builder()
.name("index")
.build(),
new VariableSymbol.Builder()
.name("n")
.build()
)
.shapeFunction((inputShapes, method) -> splitShapeFunction(inputShapes.get(0), method))
.build();
}
private static List<ShapeSymbol> splitShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method){
private static List<ShapeSymbol> splitShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method) {
int numberOfSplits = method.getIntValue("n").get();
int groupIndex = method.getIntValue("index").get();
int inputChannels = inputShape.getChannels().get();
int outputChannels = inputChannels / numberOfSplits;
int outputChannelsLast = inputChannels - numberOfSplits*outputChannels;
int outputChannelsLast = inputChannels - numberOfSplits * outputChannels;
if (groupIndex == numberOfSplits - 1){
if (groupIndex == numberOfSplits - 1) {
return Collections.singletonList(new ShapeSymbol.Builder()
.height(1)
.width(1)
.channels(outputChannelsLast)
.build());
}
else {
} else {
return Collections.singletonList(new ShapeSymbol.Builder()
.height(1)
.width(1)
......@@ -250,36 +281,9 @@ public class PredefinedMethods {
}
}
public static MethodDeclarationSymbol createGet(){
return new MethodDeclarationSymbol.Builder()
.name("Get")
.parameters(
new VariableSymbol.Builder()
.name("index")
.build()
)
.shapeFunction((inputShapes, method) ->
Collections.singletonList(inputShapes.get(method.getIntValue("index").get())))
.build();
}
public static MethodDeclarationSymbol createAdd(){
return new MethodDeclarationSymbol.Builder()
.name("Add")
.shapeFunction((inputShapes, method) -> Collections.singletonList(inputShapes.get(0)))
.build();
}
public static MethodDeclarationSymbol createConcatenate(){
return new MethodDeclarationSymbol.Builder()
.name("Concatenate")
.shapeFunction(PredefinedMethods::concatenateShapeFunction)
.build();
}
private static List<ShapeSymbol> concatenateShapeFunction(List<ShapeSymbol> inputShapes, MethodLayerSymbol method){
private static List<ShapeSymbol> concatenateShapeFunction(List<ShapeSymbol> inputShapes, MethodLayerSymbol method) {
int channels = 0;
for (ShapeSymbol inputShape : inputShapes){
for (ShapeSymbol inputShape : inputShapes) {
channels += inputShape.getChannels().get();
}
return Collections.singletonList(new ShapeSymbol.Builder()
......@@ -288,24 +292,4 @@ public class PredefinedMethods {
.channels(channels)
.build());
}
public static List<MethodDeclarationSymbol> createList(){
return Arrays.asList(
createFullyConnected(),
createConvolution(),
createSoftmax(),
createSigmoid(),
createTanh(),
createRelu(),
createDropout(),
createMaxPooling(),
createAveragePooling(),
createLrn(),
createBatchNorm(),
createSplit(),
createGet(),
createAdd(),
createConcatenate());
}
}
......@@ -97,8 +97,17 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol implements
@Override
public Optional<Object> getValue() {
if (isFullyResolved()){
Object value = Calculator.getInstance().calculate(getExpression());
return Optional.of(value);
if (isTuple()){
List<Object> tupleValues = new ArrayList<>();
for (MathExpressionSymbol element : ((TupleExpressionSymbol) getExpression()).getExpressions()){
tupleValues.add(Calculator.getInstance().calculate(element));
}
return Optional.of(tupleValues);
}
else {
Object value = Calculator.getInstance().calculate(getExpression());
return Optional.of(value);
}
}
else {
return Optional.empty();
......@@ -186,6 +195,11 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol implements
return new ArchSimpleExpressionSymbol(tupleExpression);
}
public static ArchSimpleExpressionSymbol of(VariableSymbol variable){
MathExpressionSymbol exp = new MathNameExpressionSymbol(variable.getName());
return new ArchSimpleExpressionSymbol(exp);
}
/*
only used to create 'true' and 'false' expression.
This is necessary because true and false are at the moment just names in the MontiMath SMI.
......
......@@ -38,15 +38,6 @@ public class ArgumentSymbol extends CommonSymbol {
}
public VariableSymbol getParameter() {
if (parameter == null){
Optional<VariableSymbol> optParam = getEnclosingScope().resolve(getName(), VariableSymbol.KIND);
if (optParam.isPresent()){
parameter = optParam.get();
}
else {
Log.error("Parameter with name " + getName() + " could not be resolved", getSourcePosition());
}
}
return parameter;
}
......@@ -80,9 +71,7 @@ public class ArgumentSymbol extends CommonSymbol {
public static class Builder{
private VariableSymbol parameter;
private ArchExpressionSymbol value;
private String name;
//will be assigned automatically by name if not set
public Builder parameter(VariableSymbol parameter) {
this.parameter = parameter;
return this;
......@@ -93,16 +82,11 @@ public class ArgumentSymbol extends CommonSymbol {
return this;
}
public Builder name(String name) {
this.name = name;
return this;
}
public ArgumentSymbol build(){
if (name == null || name.equals("")){
throw new IllegalStateException("Missing name for ArgumentSymbol");