Commit 0e823454 authored by Thomas Michael Timmermanns's avatar Thomas Michael Timmermanns
Browse files

Implemented resolve method for LayerSymbol classes and started implementation...

Implemented resolve method for LayerSymbol classes and started implementation of CoCo for parameter type check.
parent 19649197
......@@ -19,7 +19,9 @@ grammar CNNArch extends de.monticore.lang.math.Math {
Name&
(ArrayDeclaration)?;
ArchType implements Type = (ElementType ("^" "{" (Dimension || ",")+ "}")?)?;
ArchType implements Type = ElementType "^" Shape;
Shape = "{" (Dimension || ",")+ "}";
Dimension = IOVariable | intLiteral:UnitNumberResolution;
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchSimpleExpressionSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.TupleExpressionSymbol;
public enum Constraint {
NUMBER {
@Override
public boolean check(ArchSimpleExpressionSymbol exp) {
return exp.isNumber();
}
},
INTEGER {
@Override
public boolean check(ArchSimpleExpressionSymbol exp) {
return exp.isInt().get();
}
},
BOOLEAN {
@Override
public boolean check(ArchSimpleExpressionSymbol exp) {
return exp.isBoolean();
}
},
TUPLE {
@Override
public boolean check(ArchSimpleExpressionSymbol exp) {
return exp.isTuple();
}
},
INTEGER_TUPLE {
@Override
public boolean check(ArchSimpleExpressionSymbol exp) {
boolean res = false;
if (exp.isTuple()){
//todo
}
return false;
}
},
POSITIVE {
@Override
public boolean check(ArchSimpleExpressionSymbol exp) {
return false;
}
},
NON_NEGATIVE {
@Override
public boolean check(ArchSimpleExpressionSymbol exp) {
return false;
}
},
BETWEEN_ZERO_AND_ONE {
@Override
public boolean check(ArchSimpleExpressionSymbol exp) {
return false;
}
};
abstract public boolean check(ArchSimpleExpressionSymbol exp);
public boolean check(Constraint constraint, ArchSimpleExpressionSymbol exp){
return constraint.check(exp);
}
}
......@@ -54,5 +54,9 @@ public class ErrorMessages {
public static final String UNKNOWN_NAME_CODE = "x32585";
public static final String UNKNOWN_NAME_MSG = "0" + UNKNOWN_NAME_CODE + " Unknown method error. ";
public static final String ILLEGAL_SEQUENCE_LENGTH_CODE = "x24772";
public static final String ILLEGAL_SEQUENCE_LENGTH_MSG = "0" + ILLEGAL_SEQUENCE_LENGTH_CODE + " Illegal sequence length. ";
}
......@@ -30,14 +30,16 @@ import java.util.*;
public class PredefinedMethods {
public static MethodDeclarationSymbol FULLY_CONNECTED = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol FULLY_CONNECTED = new MethodDeclarationSymbol.Builder()
.name("FullyConnected")
.parameters(
new VariableSymbol.Builder()
.name("units")
.constraints(Constraint.INTEGER, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("no_bias")
.constraints(Constraint.BOOLEAN)
.defaultValue(false)
.build()
)
......@@ -48,21 +50,25 @@ public class PredefinedMethods {
.build()))
.build();
public static MethodDeclarationSymbol CONVOLUTION = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol CONVOLUTION = new MethodDeclarationSymbol.Builder()
.name("Convolution")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("channels")
.constraints(Constraint.INTEGER, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("stride")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.defaultValue(1, 1)
.build(),
new VariableSymbol.Builder()
.name("no_bias")
.constraints(Constraint.BOOLEAN)
.defaultValue(false)
.build()
)
......@@ -72,49 +78,53 @@ public class PredefinedMethods {
method.getIntValue("channels").get()))
.build();
public static MethodDeclarationSymbol SOFTMAX = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol SOFTMAX = new MethodDeclarationSymbol.Builder()
.name("Softmax")
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol SIGMOID = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol SIGMOID = new MethodDeclarationSymbol.Builder()
.name("Sigmoid")
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol TANH = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol TANH = new MethodDeclarationSymbol.Builder()
.name("Tanh")
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol RELU = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol RELU = new MethodDeclarationSymbol.Builder()
.name("Relu")
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol DROPOUT = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol DROPOUT = new MethodDeclarationSymbol.Builder()
.name("Dropout")
.parameters(
new VariableSymbol.Builder()
.name("p")
.constraints(Constraint.NUMBER, Constraint.BETWEEN_ZERO_AND_ONE)
.defaultValue(Rational.valueOf(1, 2))//0.5
.build()
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol MAX_POOLING = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol MAX_POOLING = new MethodDeclarationSymbol.Builder()
.name("MaxPooling")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("stride")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.defaultValue(1, 1)
.build(),
new VariableSymbol.Builder()
.name("global")
.constraints(Constraint.BOOLEAN)
.defaultValue(false)
.build()
)
......@@ -124,18 +134,21 @@ public class PredefinedMethods {
inputShapes.get(0).getChannels().get()))
.build();
public static MethodDeclarationSymbol AVERAGE_POOLING = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol AVERAGE_POOLING = new MethodDeclarationSymbol.Builder()
.name("AveragePooling")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name("stride")
.constraints(Constraint.INTEGER_TUPLE, Constraint.POSITIVE)
.defaultValue(1, 1)
.build(),
new VariableSymbol.Builder()
.name("global")
.constraints(Constraint.BOOLEAN)
.defaultValue(false)
.build()
)
......@@ -145,71 +158,87 @@ public class PredefinedMethods {
inputShapes.get(0).getChannels().get()))
.build();
public static MethodDeclarationSymbol LRN = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol LRN = new MethodDeclarationSymbol.Builder()
.name("Lrn")
.parameters(
new VariableSymbol.Builder()
.name("nsize")
.constraints(Constraint.INTEGER, Constraint.NON_NEGATIVE)
.build(),
new VariableSymbol.Builder()
.name("knorm")
.constraints(Constraint.NUMBER)
.defaultValue(2)
.build(),
new VariableSymbol.Builder()
.name("alpha")
.constraints(Constraint.NUMBER)
.defaultValue(Rational.valueOf(1, 10000))//0.0001
.build(),
new VariableSymbol.Builder()
.name("beta")
.constraints(Constraint.NUMBER)
.defaultValue(Rational.valueOf(3, 4))//0.75
.build()
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol BATCHNORM = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol BATCHNORM = new MethodDeclarationSymbol.Builder()
.name("BatchNorm")
.parameters(
//todo
new VariableSymbol.Builder()
.name("fix_gamma")
.constraints(Constraint.BOOLEAN)
.defaultValue(true)
.build(),
new VariableSymbol.Builder()
.name("axis")
.constraints(Constraint.INTEGER, Constraint.NON_NEGATIVE)
.defaultValue(ShapeSymbol.CHANNEL_INDEX)
.build()
)
.shapeFunction((inputShapes, method) -> inputShapes)
.build();
public static MethodDeclarationSymbol SPLIT = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol SPLIT = new MethodDeclarationSymbol.Builder()
.name("Split")
.parameters(
new VariableSymbol.Builder()
.name("index")
.constraints(Constraint.INTEGER, Constraint.NON_NEGATIVE)
.build(),
new VariableSymbol.Builder()
.name("n")
.constraints(Constraint.INTEGER, Constraint.POSITIVE)
.build()
)
.shapeFunction((inputShapes, method) -> splitShapeFunction(inputShapes.get(0), method))
.build();
public static MethodDeclarationSymbol GET = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol GET = new MethodDeclarationSymbol.Builder()
.name("Get")
.parameters(
new VariableSymbol.Builder()
.name("index")
.constraints(Constraint.INTEGER, Constraint.NON_NEGATIVE)
.build()
)
.shapeFunction((inputShapes, method) ->
Collections.singletonList(inputShapes.get(method.getIntValue("index").get())))
.build();
public static MethodDeclarationSymbol ADD = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol ADD = new MethodDeclarationSymbol.Builder()
.name("Add")
.shapeFunction((inputShapes, method) -> Collections.singletonList(inputShapes.get(0)))
.build();
public static MethodDeclarationSymbol CONCATENATE = new MethodDeclarationSymbol.Builder()
public static final MethodDeclarationSymbol CONCATENATE = new MethodDeclarationSymbol.Builder()
.name("Concatenate")
.shapeFunction(PredefinedMethods::concatenateShapeFunction)
.build();
public static List<MethodDeclarationSymbol> LIST = Arrays.asList(
public static final List<MethodDeclarationSymbol> LIST = Arrays.asList(
FULLY_CONNECTED,
CONVOLUTION,
SOFTMAX,
......@@ -226,7 +255,7 @@ public class PredefinedMethods {
ADD,
CONCATENATE);
public static Map<String, MethodDeclarationSymbol> MAP = createPredefinedMap();
public static final Map<String, MethodDeclarationSymbol> MAP = createPredefinedMap();
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
public class PredefinedVariables {
public static final String IF_NAME = "_if";
public static final String FOR_NAME = "_for";
public static VariableSymbol createIfParameter(){
return new VariableSymbol.Builder()
.name(IF_NAME)
.constraints(Constraint.BOOLEAN)
.defaultValue(true)
.build();
}
public static VariableSymbol createForParameter(){
return new VariableSymbol.Builder()
.name(FOR_NAME)
.defaultValue(1)
.build();
}
//todo true and false
}
......@@ -24,6 +24,7 @@ package de.monticore.lang.monticar.cnnarch._cocos;
public class CNNArchCocos {
public static CNNArchCoCoChecker createChecker() {
return new CNNArchCoCoChecker();
return new CNNArchCoCoChecker()
.addCoCo(new CheckVariableConstraints());
}
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch.Constraint;
import de.monticore.lang.monticar.cnnarch._ast.ASTVariable;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
public class CheckVariableConstraints implements CNNArchASTVariableCoCo {
@Override
public void check(ASTVariable node) {
if (node == null || !node.getSymbol().isPresent()){
throw new IllegalArgumentException();
}
VariableSymbol variable = (VariableSymbol) node.getSymbol().get();
for (Constraint constraint : variable.getConstraints()){
constraint.check(variable.getValueSymbol());
}
}
}
......@@ -30,13 +30,14 @@ abstract public class ArchAbstractSequenceExpression extends ArchExpressionSymbo
super();
}
abstract public boolean isParallelSequence();
abstract public boolean isSerialSequence();
//todo no Optional
abstract public Optional<Integer> getParallelLength();
//todo no Optional
abstract public Optional<Integer> getSerialLength();
}
......@@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.CommonSymbol;
import java.util.List;
import java.util.Optional;
import java.util.Set;
......@@ -159,5 +160,51 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
*/
abstract public Set<String> resolve();
//todo remove
abstract protected void checkIfResolved();
abstract public boolean isResolved();
abstract public List<List<ArchSimpleExpressionSymbol>> getElements();
public void resolveOrError(){
resolve();
if (isResolved()){
throw new IllegalStateException("The following names could not be resolved: " + getUnresolvableNames());
}
}
public boolean isResolvable(){
//todo
return true;
}
public Set<String> getUnresolvableNames() {
//todo
return null;
}
public void checkIfResolvable(){
//todo: unresolvableNames = computeUnresolvableNames();
}
protected Set<String> computeUnresolvableNames(){
//todo
return null;
}
public boolean isIntTuple(){
return false;
}
public boolean isNumberTuple(){
return false;
}
public boolean isBooleanTuple(){
return false;
}
}
......@@ -157,6 +157,15 @@ public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
setFullyResolved(startSymbol.isFullyResolved() && endSymbol.isFullyResolved());
}
@Override
public boolean isResolved() {
//todo
return false;
}
@Override
public List<List<ArchSimpleExpressionSymbol>> getElements() {
//todo
return null;
}
}
......@@ -133,4 +133,9 @@ public class ArchSequenceExpressionSymbol extends ArchAbstractSequenceExpression
setFullyResolved(isResolved);
}
@Override
public boolean isResolved() {
//todo
return false;
}
}
......@@ -89,13 +89,94 @@ public class ArchSimpleExpressionSymbol extends ArchExpressionSymbol implements
return isInt;
}
@Override
public boolean isIntTuple(){
//todo
return false;
}
@Override
public boolean isNumberTuple(){
//todo
return false;
}
@Override
public boolean isBooleanTuple(){
//todo
return false;
}
@Override
public boolean isSimpleValue() {
return true;
}