Commit 9b0850da authored by Thomas Michael Timmermanns's avatar Thomas Michael Timmermanns
Browse files

Added output shape functions for methods to calculate the dimensions after each layer.

parent d8be1af8
# CNNArch
[![Build Status](https://travis-ci.org/EmbeddedMontiArc/CNNArchLang.svg?branch=master)](https://travis-ci.org/EmbeddedMontiArc/CNNArchLang)
[![Build Status](https://circleci.com/gh/EmbeddedMontiArc/CNNArchLang/tree/master.svg?style=shield&circle-token=:circle-token)](https://circleci.com/gh/EmbeddedMontiArc/CNNArchLang/tree/master)
[![Coverage Status](https://coveralls.io/repos/github/EmbeddedMontiArc/CNNArchLang/badge.svg?branch=master)](https://coveralls.io/github/EmbeddedMontiArc/CNNArchLang?branch=master)
[![Build Status](https://travis-ci.org/EmbeddedMontiArc/CNNArchLang.svg?branch=timmermanns)](https://travis-ci.org/EmbeddedMontiArc/CNNArchLang)
[![Build Status](https://circleci.com/gh/EmbeddedMontiArc/CNNArchLang/tree/master.svg?style=shield&circle-token=:circle-token)](https://circleci.com/gh/EmbeddedMontiArc/CNNArchLang/tree/timmermanns)
[![Coverage Status](https://coveralls.io/repos/github/EmbeddedMontiArc/CNNArchLang/badge.svg?branch=timmermanns)](https://coveralls.io/github/EmbeddedMontiArc/CNNArchLang?branch=timmermanns)
......@@ -57,7 +57,7 @@
<!-- interactiveMode
| This will determine whether maven prompts you when it needs input. If set to false,
| maven will use a sensible default value, perhaps based on some other setting, for
| maven will use a sensible default rhs, perhaps based on some other setting, for
| the parameter in question.
|
| Default: true
......@@ -205,9 +205,9 @@
|
| As noted above, profiles can be activated in a variety of ways. One way - the activeProfiles
| section of this document (settings.xml) - will be discussed later. Another way essentially
| relies on the detection of a system property, either matching a particular value for the property,
| relies on the detection of a system property, either matching a particular rhs for the property,
| or merely testing its existence. Profiles can also be activated by JDK version prefix, where a
| value of '1.4' might activate a profile when the build is executed on a JDK version of '1.4.2_07'.
| rhs of '1.4' might activate a profile when the build is executed on a JDK version of '1.4.2_07'.
| Finally, the list of active profiles can be specified directly from the command line.
|
| NOTE: For profiles defined in the settings.xml, you are restricted to specifying only artifact
......@@ -245,7 +245,7 @@
-->
<!--
| Here is another profile, activated by the system property 'target-env' with a value of 'dev',
| Here is another profile, activated by the system property 'target-env' with a rhs of 'dev',
| which provides a specific path to the Tomcat instance. To use this, your plugin configuration
| might hypothetically look like:
|
......@@ -261,14 +261,14 @@
| ...
|
| NOTE: If you just wanted to inject this configuration whenever someone set 'target-env' to
| anything, you could just leave off the <value/> inside the activation-property.
| anything, you could just leave off the <rhs/> inside the activation-property.
|
<profile>
<id>env-dev</id>
<activation>
<property>
<name>target-env</name>
<value>dev</value>
<rhs>dev</rhs>
</property>
</activation>
<properties>
......
......@@ -21,16 +21,19 @@
package de.monticore.lang.monticar.cnnarch;
import de.monticore.lang.monticar.cnnarch._symboltable.MethodDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.MethodLayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ShapeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import org.jscience.mathematics.number.Rational;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class PredefinedMethods {
public static MethodDeclarationSymbol createFullyConnected(){
return new MethodDeclarationSymbol.Builder()
return new MethodDeclarationSymbol.Builder()
.name("FullyConnected")
.parameters(
new VariableSymbol.Builder()
......@@ -41,7 +44,11 @@ public class PredefinedMethods {
.defaultValue(false)
.build()
)
.predefined(true)
.shapeFunction((inputShapes, method) -> Collections.singletonList(new ShapeSymbol.Builder()
.height(1)
.width(1)
.channels(method.getIntValue("units").get())
.build()))
.build();
}
......@@ -64,35 +71,53 @@ public class PredefinedMethods {
.defaultValue(false)
.build()
)
.predefined(true)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
method.getIntValue("channels").get()))
.build();
}
private static List<ShapeSymbol> strideShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue("stride").get().get(0);
int strideWidth = method.getIntTupleValue("stride").get().get(1);
int kernelHeight = method.getIntTupleValue("kernel").get().get(0);
int kernelWidth = method.getIntTupleValue("kernel").get().get(1);
int inputHeight = inputShape.getHeight().get();
int inputWidth = inputShape.getWidth().get();
//assume padding with border_mode='same'
int outputWidth = 1 + ((inputWidth - kernelWidth + strideWidth - 1) / strideWidth);
int outputHeight = 1 + ((inputHeight - kernelHeight + strideHeight - 1) / strideHeight);
return Collections.singletonList(new ShapeSymbol.Builder()
.height(outputHeight)
.width(outputWidth)
.channels(channels)
.build());
}
public static MethodDeclarationSymbol createSoftmax(){
return new MethodDeclarationSymbol.Builder()
.name("Softmax")
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createSigmoid(){
return new MethodDeclarationSymbol.Builder()
.name("Sigmoid")
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createTanh(){
return new MethodDeclarationSymbol.Builder()
.name("Tanh")
.predefined(true)
.build();
}
public static MethodDeclarationSymbol createRelu(){
return new MethodDeclarationSymbol.Builder()
.name("Relu")
.predefined(true)
.build();
}
......@@ -105,7 +130,6 @@ public class PredefinedMethods {
.defaultValue(Rational.valueOf(1,2))//0.5
.build()
)
.predefined(true)
.build();
}
......@@ -125,7 +149,10 @@ public class PredefinedMethods {
.defaultValue(false)
.build()
)
.predefined(true)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
inputShapes.get(0).getChannels().get()))
.build();
}
......@@ -145,7 +172,10 @@ public class PredefinedMethods {
.defaultValue(false)
.build()
)
.predefined(true)
.shapeFunction((inputShapes, method) ->
strideShapeFunction(inputShapes.get(0),
method,
inputShapes.get(0).getChannels().get()))
.build();
}
......@@ -169,7 +199,6 @@ public class PredefinedMethods {
.defaultValue(Rational.valueOf(3,4))//0.75
.build()
)
.predefined(true)
.build();
}
......@@ -179,7 +208,6 @@ public class PredefinedMethods {
.parameters(
//todo
)
.predefined(true)
.build();
}
......@@ -194,10 +222,34 @@ public class PredefinedMethods {
.name("n")
.build()
)
.predefined(true)
.shapeFunction((inputShapes, method) -> splitShapeFunction(inputShapes.get(0), method))
.build();
}
private static List<ShapeSymbol> splitShapeFunction(ShapeSymbol inputShape, MethodLayerSymbol method){
int numberOfSplits = method.getIntValue("n").get();
int groupIndex = method.getIntValue("index").get();
int inputChannels = inputShape.getChannels().get();
int outputChannels = inputChannels / numberOfSplits;
int outputChannelsLast = inputChannels - numberOfSplits*outputChannels;
if (groupIndex == numberOfSplits - 1){
return Collections.singletonList(new ShapeSymbol.Builder()
.height(1)
.width(1)
.channels(outputChannelsLast)
.build());
}
else {
return Collections.singletonList(new ShapeSymbol.Builder()
.height(1)
.width(1)
.channels(outputChannels)
.build());
}
}
public static MethodDeclarationSymbol createGet(){
return new MethodDeclarationSymbol.Builder()
.name("Get")
......@@ -206,30 +258,37 @@ public class PredefinedMethods {
.name("index")
.build()
)
.predefined(true)
.shapeFunction((inputShapes, method) ->
Collections.singletonList(inputShapes.get(method.getIntValue("index").get())))
.build();
}
public static MethodDeclarationSymbol createAdd(){
return new MethodDeclarationSymbol.Builder()
.name("Add")
.parameters(
)
.predefined(true)
.shapeFunction((inputShapes, method) -> Collections.singletonList(inputShapes.get(0)))
.build();
}
public static MethodDeclarationSymbol createConcatenate(){
return new MethodDeclarationSymbol.Builder()
.name("Concatenate")
.parameters(
)
.predefined(true)
.shapeFunction(PredefinedMethods::concatenateShapeFunction)
.build();
}
private static List<ShapeSymbol> concatenateShapeFunction(List<ShapeSymbol> inputShapes, MethodLayerSymbol method){
int channels = 0;
for (ShapeSymbol inputShape : inputShapes){
channels += inputShape.getChannels().get();
}
return Collections.singletonList(new ShapeSymbol.Builder()
.height(inputShapes.get(0).getHeight().get())
.width(inputShapes.get(0).getWidth().get())
.channels(channels)
.build());
}
public static List<MethodDeclarationSymbol> createList(){
return Arrays.asList(
......
......@@ -55,7 +55,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
}
/**
* Checks whether the value is a boolean. If true getValue() will return a Boolean if present.
* Checks whether the value is a boolean. If true getRhs() will return a Boolean if present.
*
* @return returns true iff the value of the resolved expression will be a boolean.
*/
......@@ -65,7 +65,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
/**
* Checks whether the value is a number.
* Note that the return of getValue() can be either a Double or an Integer if present.
* Note that the return of getRhs() can be either a Double or an Integer if present.
*
* @return returns true iff the value of the resolved expression will be a number.
*/
......@@ -75,7 +75,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
/**
* Checks whether the value is a Tuple.
* If true getValue() will return (if present) a List of Objects.
* If true getRhs() will return (if present) a List of Objects.
* These Objects can either be Integer, Double or Boolean.
*
* @return returns true iff the value of the expression will be a tuple.
......@@ -86,7 +86,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
/**
* Checks whether the value is an integer. This can only be checked if the expression is resolvable.
* If true getValue() will return an Integer.
* If true getRhs() will return an Integer.
*
* @return returns Optional.of(true) iff the value of the expression is an integer.
* The Optional is present if the expression can be resolved.
......@@ -97,7 +97,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
/**
* Checks whether the value is a parallel Sequence.
* If true, getValue() will return (if present) a List of Lists of Objects.
* If true, getRhs() will return (if present) a List of Lists of Objects.
* These Objects can either be Integer, Double or Boolean.
* If isSerialSequence() returns false, the second List will always have a size smaller than 2.
* Sequences of size 1 or 0 cannot be parallel sequences.
......@@ -110,7 +110,7 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
/**
* Checks whether the value is a serial Sequence.
* If true, getValue() will return (if present) a List(parallel) of Lists(serial) of Objects.
* If true, getRhs() will return (if present) a List(parallel) of Lists(serial) of Objects.
* If isParallelSequence() is false, the first list will be of size 1.
* These Objects can either be Integer, Double or Boolean.
* Sequences of size 1 or 0 are counted as serial sequences.
......
......@@ -21,9 +21,9 @@
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.CommonSymbol;
import de.monticore.symboltable.SymbolKind;
import de.se_rwth.commons.logging.Log;
import java.util.List;
import java.util.Optional;
public class ArgumentSymbol extends CommonSymbol {
......@@ -31,7 +31,7 @@ public class ArgumentSymbol extends CommonSymbol {
public static final ArgumentKind KIND = new ArgumentKind();
private VariableSymbol parameter;
private ArchExpressionSymbol value;
private ArchExpressionSymbol rhs;
protected ArgumentSymbol(String name) {
super(name, KIND);
......@@ -54,12 +54,26 @@ public class ArgumentSymbol extends CommonSymbol {
this.parameter = parameter;
}
public ArchExpressionSymbol getValue() {
return value;
public ArchExpressionSymbol getRhs() {
return rhs;
}
protected void setValue(ArchExpressionSymbol value) {
this.value = value;
public Optional<Object> getValue(){
return getRhs().getValue();
}
protected void setRhs(ArchExpressionSymbol rhs) {
this.rhs = rhs;
}
public List<List<ArgumentSymbol>> split(){
//todo
return null;
}
public List<List<ArgumentSymbol>> expandedSplit(int parallelLength, int serialLength){
//todo
return null;
}
......@@ -90,7 +104,7 @@ public class ArgumentSymbol extends CommonSymbol {
}
ArgumentSymbol sym = new ArgumentSymbol(name);
sym.setParameter(parameter);
sym.setValue(value);
sym.setRhs(value);
return sym;
}
}
......
......@@ -21,7 +21,6 @@
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.ast.ASTNode;
import de.monticore.lang.math.math._ast.ASTMathExpression;
import de.monticore.lang.math.math._symboltable.MathSymbolTableCreator;
import de.monticore.lang.math.math._symboltable.expression.MathExpressionSymbol;
......@@ -368,7 +367,7 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
MethodLayerSymbol methodLayer = (MethodLayerSymbol) currentScope().get().getSpanningSymbol().get();
ArchExpressionSymbol value = (ArchExpressionSymbol) node.getRhs().getSymbol().get();
ArgumentSymbol argument = (ArgumentSymbol) node.getSymbol().get();
argument.setValue(value);
argument.setRhs(value);
VariableSymbol parameter = (VariableSymbol) methodLayer.getMethod().getSpannedScope()
.resolveLocally(argument.getName(), VariableSymbol.KIND).get();
......
......@@ -55,30 +55,41 @@ public class CompositeLayerSymbol extends LayerSymbol {
@Override
public Set<String> resolve() {
Set<String> unresolvableSet = new HashSet<>();
if (!isFullyResolved()) {
for (LayerSymbol layer : layers) {
unresolvableSet.addAll(layer.resolve());
}
}
//todo
return unresolvableSet;
}
@Override
protected void checkIfResolved() {
boolean isResolved = true;
for (LayerSymbol layer : layers){
layer.checkIfResolved();
if (!layer.isFullyResolved()){
isResolved = false;
//todo
}
@Override
protected List<ShapeSymbol> computeOutputShape() {
if (isParallel()){
List<ShapeSymbol> outputShapes = new ArrayList<>(getLayers().size());
for (LayerSymbol layer : getLayers()){
//todo: assure that last layer in each parallel group has only one outputShape
outputShapes.add(layer.getOutputShapes().get(0));
}
return outputShapes;
}
else {
return getLayers().get(getLayers().size() - 1).getOutputShapes();
}
setFullyResolved(isResolved);
}
@Override
public boolean isResolvable() {
//todo
return false;
}
public static class Builder{
private boolean parallel = false;
private List<LayerSymbol> layers = new ArrayList<>();
private LayerSymbol inputLayer;
public Builder parallel(boolean parallel){
this.parallel = parallel;
......@@ -95,10 +106,16 @@ public class CompositeLayerSymbol extends LayerSymbol {
return this;
}
public Builder inputLayer(LayerSymbol inputLayer){
this.inputLayer = inputLayer;
return this;
}
public CompositeLayerSymbol build(){
CompositeLayerSymbol sym = new CompositeLayerSymbol();
sym.setParallel(parallel);
sym.setLayers(layers);
sym.setInputLayer(inputLayer);
return sym;
}
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
public class DimensionKind {
import de.monticore.symboltable.SymbolKind;
public class DimensionKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnnarch._symboltable.DimensionKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
public class DimensionSymbol {
import de.monticore.symboltable.CommonSymbol;
import java.util.Optional;
public class DimensionSymbol extends CommonSymbol {
public static final DimensionKind KIND = new DimensionKind();
private ArchSimpleExpressionSymbol valueExpression;
private VariableSymbol ioVariable;
public DimensionSymbol() {
super("", KIND);
}
public ArchSimpleExpressionSymbol getValueExpression() {
return valueExpression;
}
public void setValueExpression(ArchSimpleExpressionSymbol valueExpression) {
this.valueExpression = valueExpression;
}
public Optional<VariableSymbol> getIoVariable() {
return Optional.ofNullable(ioVariable);
}
public void setIoVariable(VariableSymbol ioVariable) {
this.ioVariable = ioVariable;
}
public Optional<Integer> getValue(){
Optional<Object> optObj = getValueExpression().getValue();
return optObj.map(o -> (Integer) o);
}
public static DimensionSymbol of(int value){
DimensionSymbol sym = new DimensionSymbol();
sym.setValueExpression(ArchSimpleExpressionSymbol.of(value));
return sym;
}
}