Commit 2cfd0f0d authored by Thomas Michael Timmermanns's avatar Thomas Michael Timmermanns
Browse files

Implemented most classes of the SMI.

Still missing ShapeSymbol, DimensionSymbol, output shape calculation and method resolve method.
parent e6700112
......@@ -13,11 +13,11 @@ grammar CNNArch extends de.monticore.lang.math.Math {
interface ArchitectureElement;
interface Variable;
symbol IODeclaration implements ArchDeclaration = "def"
(in:"input" | out:"output")
type:ArchType
Name&
(ArrayDeclaration)?;
IODeclaration implements ArchDeclaration = "def"
(in:"input" | out:"output")
type:ArchType
Name&
(ArrayDeclaration)?;
ArchType implements Type = (ElementType ("^" "{" (Dimension || ",")+ "}")?)?;
......@@ -25,14 +25,14 @@ grammar CNNArch extends de.monticore.lang.math.Math {
IOVariable implements Variable = Name&;
Constant implements Variable = "def" Name& "=" rhs:ArchExpression;
Constant implements Variable = Name& "=" rhs:ArchSimpleExpression;
symbol scope MethodDeclaration implements ArchDeclaration = "def"
Name& "("
parameters:(Parameter || ",")* ")" "{"
body:ArchBody "}";
MethodDeclaration implements ArchDeclaration = "def"
Name& "("
parameters:(Parameter || ",")* ")" "{"
body:ArchBody "}";
Parameter implements Variable = Name& ("=" default:ArchExpression)?;
Parameter implements Variable = Name& ("=" default:ArchSimpleExpression)?;
scope ArchBody = elements:(ArchitectureElement || "->")*;
......@@ -42,21 +42,22 @@ grammar CNNArch extends de.monticore.lang.math.Math {
Argument = Name "=" rhs:ArchExpression;
ParallelLayer implements ArchitectureElement = "(" groups:(ArchBody || "|")+ ")";
ParallelLayer implements ArchitectureElement = "(" groups:ArchBody "|" groups:(ArchBody || "|")+ ")";
ArrayAccessLayer implements ArchitectureElement = "[" index:ArchSimpleExpression "]";
ArchExpression = (expression:ArchSimpleExpression | sequence:ArchValueSequence);
ArchValueSequence = "[" parallelValues:(ArchSerialSequence || "|")* "]";
interface ArchValueSequence;
ArchParallelSequence implements ArchValueSequence = "[" parallelValues:(ArchSerialSequence || "|")* "]";
ArchSerialSequence = serialValues:(ArchSimpleExpression || "->")+;
ArchValueRange extends ArchValueSequence = "[" start:ArchSimpleExpression
(serial:"->" | parallel:"|")
(":" step:ArchSimpleExpression)?
":" end:ArchSimpleExpression "]";
ArchValueRange implements ArchValueSequence = "[" start:ArchSimpleExpression
(serial:"->" | parallel:"|")
":" end:ArchSimpleExpression "]";
ArchSimpleExpression = (arithmeticExpression:MathArithmeticExpression
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch;
import de.monticore.lang.monticar.cnnarch._symboltable.MethodDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import org.jscience.mathematics.number.Rational;
import java.util.Arrays;
import java.util.List;
public class PredefinedMethods {
public static MethodDeclarationSymbol createFullyConnected(){
return new MethodDeclarationSymbol.Builder()
.name("FullyConnected")
.parameters(
new VariableSymbol.Builder()
.name("units")
.build(),
new VariableSymbol.Builder()
.name("no_bias")
.defaultValue(false)
.build()
)
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createConvolution(){
return new MethodDeclarationSymbol.Builder()
.name("Convolution")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.build(),
new VariableSymbol.Builder()
.name("channels")
.build(),
new VariableSymbol.Builder()
.name("stride")
.defaultValue(1,1)
.build(),
new VariableSymbol.Builder()
.name("no_bias")
.defaultValue(false)
.build()
)
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createSoftmax(){
return new MethodDeclarationSymbol.Builder()
.name("Softmax")
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createSigmoid(){
return new MethodDeclarationSymbol.Builder()
.name("Sigmoid")
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createTanh(){
return new MethodDeclarationSymbol.Builder()
.name("Tanh")
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createRelu(){
return new MethodDeclarationSymbol.Builder()
.name("Relu")
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createDropout(){
return new MethodDeclarationSymbol.Builder()
.name("Dropout")
.parameters(
new VariableSymbol.Builder()
.name("p")
.defaultValue(Rational.valueOf(1,2))//0.5
.build()
)
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createMaxPooling(){
return new MethodDeclarationSymbol.Builder()
.name("MaxPooling")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.build(),
new VariableSymbol.Builder()
.name("stride")
.defaultValue(1,1)
.build(),
new VariableSymbol.Builder()
.name("global")
.defaultValue(false)
.build()
)
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createAveragePooling(){
return new MethodDeclarationSymbol.Builder()
.name("AveragePooling")
.parameters(
new VariableSymbol.Builder()
.name("kernel")
.build(),
new VariableSymbol.Builder()
.name("stride")
.defaultValue(1,1)
.build(),
new VariableSymbol.Builder()
.name("global")
.defaultValue(false)
.build()
)
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createLrn(){
return new MethodDeclarationSymbol.Builder()
.name("Lrn")
.parameters(
new VariableSymbol.Builder()
.name("nsize")
.build(),
new VariableSymbol.Builder()
.name("knorm")
.defaultValue(2)
.build(),
new VariableSymbol.Builder()
.name("alpha")
.defaultValue(Rational.valueOf(1,10000))//0.0001
.build(),
new VariableSymbol.Builder()
.name("beta")
.defaultValue(Rational.valueOf(3,4))//0.75
.build()
)
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createBatchNorm(){
return new MethodDeclarationSymbol.Builder()
.name("BatchNorm")
.parameters(
//todo
)
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createSplit(){
return new MethodDeclarationSymbol.Builder()
.name("Split")
.parameters(
new VariableSymbol.Builder()
.name("index")
.build(),
new VariableSymbol.Builder()
.name("n")
.build()
)
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createGet(){
return new MethodDeclarationSymbol.Builder()
.name("Get")
.parameters(
new VariableSymbol.Builder()
.name("index")
.build()
)
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createAdd(){
return new MethodDeclarationSymbol.Builder()
.name("Add")
.parameters(
)
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createConcatenate(){
return new MethodDeclarationSymbol.Builder()
.name("Concatenate")
.parameters(
)
.predefined(true)
.build();
}
public static List<MethodDeclarationSymbol> createList(){
return Arrays.asList(
createFullyConnected(),
createConvolution(),
createSoftmax(),
createSigmoid(),
createTanh(),
createRelu(),
createDropout(),
createMaxPooling(),
createAveragePooling(),
createLrn(),
createBatchNorm(),
createSplit(),
createGet(),
createAdd(),
createConcatenate());
}
}
......@@ -20,12 +20,13 @@
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import java.util.List;
import java.util.Optional;
abstract public class ArchAbstractSequenceValue extends ArchValueSymbol {
abstract public class ArchAbstractSequenceExpression extends ArchExpressionSymbol {
public ArchAbstractSequenceValue() {
public ArchAbstractSequenceExpression() {
super();
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.SymbolKind;
public class ArchExpressionKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnnarch._symboltable.ArchExpressionKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
\ No newline at end of file
......@@ -21,33 +21,36 @@
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
import java.util.Optional;
import java.util.Set;
abstract public class ArchValueSymbol extends CommonSymbol {
abstract public class ArchExpressionSymbol extends CommonSymbol {
public static final ArchValueKind KIND = new ArchValueKind();
public static final ArchExpressionKind KIND = new ArchExpressionKind();
private boolean fullyResolved = false;
public ArchValueSymbol() {
public ArchExpressionSymbol() {
super("", KIND);
}
/**
* Getter for the fullyResolved attribute.
* If it is still false for the return of resolve()
* then the value contains a dimension variable for input and output which has to be set.
* then the value contains a dimension variable for input
* and output which has to be set to succesfully resolve thr expression.
*
* @return returns true if the value contains no variables.
* @return returns true iff the expression is resolved.
*/
public boolean isFullyResolved() {
return fullyResolved;
}
public void setFullyResolved(boolean fullyResolved) {
//todo: change to isResolvable()
protected void setFullyResolved(boolean fullyResolved) {
this.fullyResolved = fullyResolved;
}
......@@ -97,6 +100,7 @@ abstract public class ArchValueSymbol extends CommonSymbol {
* If true, getValue() will return (if present) a List of Lists of Objects.
* These Objects can either be Integer, Double or Boolean.
* If isSerialSequence() returns false, the second List will always have a size smaller than 2.
* Sequences of size 1 or 0 cannot be parallel sequences.
*
* @return returns true iff the value contains a parallel sequence.
*/
......@@ -106,8 +110,8 @@ abstract public class ArchValueSymbol extends CommonSymbol {
/**
* Checks whether the value is a serial Sequence.
* If true, getValue() will either return (if present) a List of Objects
* or a List(parallel) of Lists(serial) of Objects if isParallelSequence() is also true.
* If true, getValue() will return (if present) a List(parallel) of Lists(serial) of Objects.
* If isParallelSequence() is false, the first list will be of size 1.
* These Objects can either be Integer, Double or Boolean.
* Sequences of size 1 or 0 are counted as serial sequences.
* Therefore, this returns always true if isParallelSequence() returns false.
......@@ -120,7 +124,7 @@ abstract public class ArchValueSymbol extends CommonSymbol {
/**
*
* @return returns true if this object is instance of ArchRangeValueSymbol
* @return returns true if this object is instance of ArchRangeExpressionSymbol
*/
public boolean isRange(){
return false;
......@@ -128,7 +132,7 @@ abstract public class ArchValueSymbol extends CommonSymbol {
/**
*
* @return returns true if this object is instance of ArchSimpleValueSymbol
* @return returns true if this object is instance of ArchSimpleExpressionSymbol
*/
public boolean isSimpleValue(){
return false;
......@@ -139,7 +143,6 @@ abstract public class ArchValueSymbol extends CommonSymbol {
/**
* This method returns the result of the expression.
* This can be a primitive object (Integer, Double or Boolean)
* or List of primitive objects
* or a list of lists of primitive objects. (See other methods for more information)
*
* @return returns the value as Object or Optional.empty if the expression cannot be completely resolved yet.
......@@ -148,14 +151,13 @@ abstract public class ArchValueSymbol extends CommonSymbol {
abstract public Optional<Object> getValue();
/**
* Creates a copy of this symbol where all Variables are replaced by expressions without variables.
* Replaces all variable names in this values expression.
* If the expression contains an IOVariable which has not yet been set
* then the expression is resolved as much as possible and the attribute fullyResolved of the return object remains false.
* then the expression is resolved as much as possible and the attribute fullyResolved of this object remains false.
*
* @return returns a copy of this object where the expression is resolved as much as possible
* or itself if attribute fullyResolved is true.
* @return returns a set of all names which could not be resolved.
*/
abstract public ArchValueSymbol resolve();
abstract public Set<String> resolve();
//abstract public ArchValueSymbol resolveCopy()
abstract protected void checkIfResolved();
}
......@@ -20,37 +20,32 @@
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import java.util.Optional;
import java.util.*;
public class ArchRangeValueSymbol extends ArchAbstractSequenceValue {
public class ArchRangeExpressionSymbol extends ArchAbstractSequenceExpression {
private ArchSimpleValueSymbol startSymbol;
private ArchSimpleValueSymbol endSymbol;
private ArchSimpleExpressionSymbol startSymbol;
private ArchSimpleExpressionSymbol endSymbol;
private boolean parallel;
public ArchRangeValueSymbol() {
super();
}
public ArchRangeValueSymbol(ArchSimpleValueSymbol startSymbol, ArchSimpleValueSymbol endSymbol, boolean parallel) {
this.startSymbol = startSymbol;
this.endSymbol = endSymbol;
this.parallel = parallel;
public ArchRangeExpressionSymbol() {
super();
}
public ArchSimpleValueSymbol getStartSymbol() {
public ArchSimpleExpressionSymbol getStartSymbol() {
return startSymbol;
}
public void setStartSymbol(ArchSimpleValueSymbol startSymbol) {
public void setStartSymbol(ArchSimpleExpressionSymbol startSymbol) {
this.startSymbol = startSymbol;
}
public ArchSimpleValueSymbol getEndSymbol() {
public ArchSimpleExpressionSymbol getEndSymbol() {
return endSymbol;
}
public void setEndSymbol(ArchSimpleValueSymbol endSymbol) {
public void setEndSymbol(ArchSimpleExpressionSymbol endSymbol) {
this.endSymbol = endSymbol;
}
......@@ -74,12 +69,11 @@ public class ArchRangeValueSymbol extends ArchAbstractSequenceValue {
private Optional<Integer> getLength(){
Optional<Integer> optLength = Optional.empty();
ArchRangeValueSymbol resolvedSymbol = resolve();
if (resolvedSymbol.isFullyResolved()){
Object startValue = resolvedSymbol.getEndSymbol().getValue().get();
Object endValue = resolvedSymbol.getEndSymbol().getValue().get();
if (startValue instanceof Integer && endValue instanceof Integer){
int start = (Integer)startValue;
if (isFullyResolved()) {
Object startValue = getEndSymbol().getValue().get();
Object endValue = getEndSymbol().getValue().get();
if (startValue instanceof Integer && endValue instanceof Integer) {
int start = (Integer) startValue;
int end = (Integer) endValue;
optLength = Optional.of(Math.abs(end - start) + 1);
}
......@@ -109,11 +103,60 @@ public class ArchRangeValueSymbol extends ArchAbstractSequenceValue {
@Override
public Optional<Object> getValue() {
return null;
if (isFullyResolved()){
//todo check in CoCo: startSymbol.isInt() && endSymbol.isInt()
int startInt = (Integer) startSymbol.getValue().get();
int endInt = (Integer) endSymbol.getValue().get();
int step = 1;
if (endInt < startInt){
step = -1;
}
List<List<Integer>> valueLists = new ArrayList<>();
if (isParallel()){
for (int i = startInt; i <= endInt; i = i + step){
List<Integer> values = new ArrayList<>(1);
values.add(i);
valueLists.add(values);
}