Commit d9c07007 authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Cleaned up code, renamed BeamSearchStart to BeamSearch

parent 6793aa59
Pipeline #180213 failed with stages
in 18 minutes and 31 seconds
......@@ -359,6 +359,7 @@
<version>2.19.1</version>
<configuration>
<argLine>-Xmx1024m -Xms1024m -XX:MaxPermSize=512m -Djdk.net.URLClassPath.disableClassPathURLCheck=true</argLine>
<trimStackTrace>false</trimStackTrace>
</configuration>
</plugin>
<plugin>
......
......@@ -62,7 +62,7 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Stream = elements:(ArchitectureElement || "->")+;
Unroll = "timed" "<" timeParameter:LayerParameter ">"
Unroll = "timed" "<" timeParameter:TimeParameter ">"
Name "(" arguments:(ArchArgument || ",")* ")"
"{" body:Stream "}";
......@@ -90,6 +90,8 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
LayerParameter implements ArchParameter = Name ("=" default:ArchSimpleExpression)? ;
TimeParameter implements ArchParameter = Name ("=" default:ArchSimpleExpression)? ;
interface ArchArgument;
ArchParameterArgument implements ArchArgument = Name "=" rhs:ArchExpression ;
......
......@@ -269,6 +269,26 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
}
/**
 * Creates the symbol for a time parameter node: a {@link ParameterSymbol} named after
 * the node, typed as {@link ParameterType#TIME_PARAMETER}, which is then added to the
 * scope and linked with the node.
 *
 * @param ast the time parameter node being visited
 */
@Override
public void visit(ASTTimeParameter ast) {
    ParameterSymbol timeSymbol = new ParameterSymbol(ast.getName());
    timeSymbol.setType(ParameterType.TIME_PARAMETER);
    addToScopeAndLinkWithNode(timeSymbol, ast);
}
/**
 * Assigns the default expression of a time parameter once its children were visited.
 * If the node declares a default, the symbol of that expression is used; otherwise a
 * fresh expression with the value 1 becomes the default.
 *
 * @param ast the time parameter node whose symbol was created in {@code visit}
 */
@Override
public void endVisit(ASTTimeParameter ast) {
    ParameterSymbol timeSymbol = (ParameterSymbol) ast.getSymbolOpt().get();

    ArchSimpleExpressionSymbol defaultExpression;
    if (!ast.isPresentDefault()) {
        // no explicit default in the model: fall back to 1
        defaultExpression = new ArchSimpleExpressionSymbol();
        defaultExpression.setValue(1);
    } else {
        defaultExpression = (ArchSimpleExpressionSymbol) ast.getDefault().getSymbolOpt().get();
    }
    timeSymbol.setDefaultExpression(defaultExpression);
}
@Override
public void endVisit(ASTArchSimpleExpression ast) {
ArchSimpleExpressionSymbol sym = new ArchSimpleExpressionSymbol();
......@@ -348,10 +368,7 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
layer.setArguments(arguments);
ArchSimpleExpressionSymbol t_value = new ArchSimpleExpressionSymbol();
t_value.setValue(1);
((ParameterSymbol)ast.getTimeParameter().getSymbolOpt().get()).setDefaultExpression(t_value);
layer.setTimeParameter((ParameterSymbol)ast.getTimeParameter().getSymbolOpt().get());
layer.setTimeParameter((ParameterSymbol) ast.getTimeParameter().getSymbolOpt().get());
removeCurrentScope();
}
......
......@@ -38,7 +38,6 @@ public class ParameterSymbol extends CommonSymbol {
private ArchSimpleExpressionSymbol currentExpression = null; //Optional
private Set<Constraints> constraints = new HashSet<>();
protected ParameterSymbol(String name) {
super(name, KIND);
}
......@@ -94,6 +93,9 @@ public class ParameterSymbol extends CommonSymbol {
return type == ParameterType.LAYER_PARAMETER;
}
/**
 * @return true if the type of this parameter is {@link ParameterType#TIME_PARAMETER}
 */
public boolean isTimeParameter() {
    return ParameterType.TIME_PARAMETER == type;
}
public boolean hasExpression(){
return getCurrentExpression().isPresent() || getDefaultExpression().isPresent();
......
......@@ -39,6 +39,12 @@ public enum ParameterType {
return "constant";
}
},
TIME_PARAMETER {
//describes a parameter whose symbol was created from a TimeParameter AST node.
@Override
public String toString(){
return "time parameter";
}
},
UNKNOWN {
//describes a parameter which does not exist. Only used to avoid exceptions while checking Cocos.
@Override
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import com.google.gson.internal.Streams;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.ranges._ast.ASTRange;
import de.se_rwth.commons.logging.Log;
import org.jscience.mathematics.number.Rational;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.BinaryOperator;
import java.util.stream.Stream;
/**
 * Base class for unroll declarations that report themselves as predefined
 * ({@link #isPredefined()} returns true).
 *
 * <p>Parameters and layers handed to this declaration are additionally put into the
 * scope spanned by the declaration. Subclasses implement the type computation and the
 * input check; the remaining methods are helpers shared by several subclasses.
 */
public abstract class PredefinedUnrollDeclaration extends UnrollDeclarationSymbol {

    public PredefinedUnrollDeclaration(String name) {
        super(name);
    }

    /**
     * Stores the parameters and puts each of them into the spanned scope.
     *
     * @param parameters the parameters of this declaration
     */
    @Override
    protected void setParameters(List<ParameterSymbol> parameters) {
        super.setParameters(parameters);
        for (ParameterSymbol param : parameters) {
            param.putInScope(getSpannedScope());
        }
    }

    /**
     * Stores the layers and puts each of them into the spanned scope.
     *
     * @param layers the layers of this declaration
     */
    @Override
    protected void setLayers(List<LayerSymbol> layers) {
        super.setLayers(layers);
        for (LayerSymbol layer : layers) {
            layer.putInScope(getSpannedScope());
        }
    }

    @Override
    public boolean isPredefined() {
        return true;
    }

    /**
     * Computes the output types of the given unroll element for the given input types.
     */
    public abstract List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer, VariableSymbol.Member member);

    /**
     * Checks whether the given input types are valid for the given unroll element.
     */
    public abstract void checkInput(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer, VariableSymbol.Member member);

    /**
     * Not supported for predefined declarations.
     *
     * @throws IllegalStateException always
     */
    @Override
    public PredefinedUnrollDeclaration deepCopy() {
        throw new IllegalStateException("Copy method should not be called for predefined layer declarations.");
    }

    //the following methods are only here to avoid duplication. They are used by multiple subclasses.

    /**
     * Logs an error at the source position of {@code layer} if {@code inputTypes}
     * does not contain exactly one element.
     */
    protected void errorIfInputSizeIsNotOne(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer) {
        if (inputTypes.size() != 1) {
            Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. " +
                            getName() + " layer can only handle one input stream. " +
                            "Current number of input streams " + inputTypes.size() + "."
                    , layer.getSourcePosition());
        }
    }

    /**
     * Logs an error at the source position of {@code layer} if {@code inputTypes} is empty.
     */
    protected void errorIfInputIsEmpty(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
        if (inputTypes.isEmpty()) {
            Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. Number of input streams is 0"
                    , layer.getSourcePosition());
        }
    }

    /**
     * Input check for convolution and pooling: compares the resolution of the first
     * input type with the kernel of {@code layer}. If the kernel is larger in either
     * dimension, an error is logged when the padding is 'valid', and a warning
     * otherwise. Does nothing for an empty input list.
     */
    protected static void errorIfInputSmallerThanKernel(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
        if (inputTypes.isEmpty()) {
            return;
        }
        int inputHeight = inputTypes.get(0).getHeight();
        int inputWidth = inputTypes.get(0).getWidth();
        int kernelHeight = layer.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(0);
        int kernelWidth = layer.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(1);

        if (kernelHeight > inputHeight || kernelWidth > inputWidth) {
            // getStringValue returns an Optional, so its content has to be compared,
            // not the Optional itself (which is never equal to a String).
            Optional<String> padding = layer.getStringValue(AllPredefinedLayers.PADDING_NAME);
            if (padding.isPresent() && AllPredefinedLayers.PADDING_VALID.equals(padding.get())) {
                Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. " +
                                "The input resolution is smaller than the kernel and the padding mode is 'valid'." +
                                "This would result in an output resolution of 0x0."
                        , layer.getSourcePosition());
            }
            else {
                Log.warn("The input resolution is smaller than the kernel. " +
                                "This results in an output resolution of 1x1. " +
                                "If this warning appears multiple times, consider changing your architecture"
                        , layer.getSourcePosition());
            }
        }
    }

    /**
     * Output type function for convolution and pooling. Dispatches on the padding
     * setting of {@code method} ('same', 'valid' or 'no_loss' constants of
     * {@link AllPredefinedLayers}).
     *
     * @param inputType the single input type
     * @param method    the layer providing the padding, stride and kernel values
     * @param channels  number of channels of the output type
     * @return a singleton list with the computed output type
     * @throws IllegalStateException if the padding setting is none of the known values
     */
    protected static List<ArchTypeSymbol> computeConvAndPoolOutputShape(ArchTypeSymbol inputType, LayerSymbol method, int channels) {
        String borderModeSetting = method.getStringValue(AllPredefinedLayers.PADDING_NAME).get();
        if (borderModeSetting.equals(AllPredefinedLayers.PADDING_SAME)) {
            return computeOutputShapeWithSamePadding(inputType, method, channels);
        }
        else if (borderModeSetting.equals(AllPredefinedLayers.PADDING_VALID)) {
            return computeOutputShapeWithValidPadding(inputType, method, channels);
        }
        else if (borderModeSetting.equals(AllPredefinedLayers.PADDING_NO_LOSS)) {
            return computeOutputShapeWithNoLossPadding(inputType, method, channels);
        }
        else {
            throw new IllegalStateException("border_mode is " + borderModeSetting + ". This should never happen.");
        }
    }

    /**
     * Padding with border_mode=valid, i.e. no padding. The output resolution is
     * {@code 1 + (input - kernel) / stride} per dimension, or 0x0 if the input is
     * smaller than the kernel in any dimension.
     */
    private static List<ArchTypeSymbol> computeOutputShapeWithValidPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels) {
        int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
        int strideWidth = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(1);
        int kernelHeight = method.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(0);
        int kernelWidth = method.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(1);
        int inputHeight = inputType.getHeight();
        int inputWidth = inputType.getWidth();

        int outputWidth;
        int outputHeight;
        if (inputWidth < kernelWidth || inputHeight < kernelHeight) {
            outputWidth = 0;
            outputHeight = 0;
        }
        else {
            outputWidth = 1 + (inputWidth - kernelWidth) / strideWidth;
            outputHeight = 1 + (inputHeight - kernelHeight) / strideHeight;
        }

        return Collections.singletonList(new ArchTypeSymbol.Builder()
                .height(outputHeight)
                .width(outputWidth)
                .channels(channels)
                .elementType("-oo", "oo")
                .build());
    }

    /**
     * Padding until no data gets discarded, same as valid with a stride of 1.
     * The output resolution is {@code 1 + max(0, (input - kernel + stride - 1) / stride)}
     * per dimension.
     */
    private static List<ArchTypeSymbol> computeOutputShapeWithNoLossPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels) {
        int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
        int strideWidth = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(1);
        int kernelHeight = method.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(0);
        int kernelWidth = method.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(1);
        int inputHeight = inputType.getHeight();
        int inputWidth = inputType.getWidth();

        int outputWidth = 1 + Math.max(0, ((inputWidth - kernelWidth + strideWidth - 1) / strideWidth));
        int outputHeight = 1 + Math.max(0, ((inputHeight - kernelHeight + strideHeight - 1) / strideHeight));

        return Collections.singletonList(new ArchTypeSymbol.Builder()
                .height(outputHeight)
                .width(outputWidth)
                .channels(channels)
                .elementType("-oo", "oo")
                .build());
    }

    /**
     * Padding with border_mode='same'. The output resolution is the input resolution
     * divided by the stride of the same dimension, rounded up.
     */
    private static List<ArchTypeSymbol> computeOutputShapeWithSamePadding(ArchTypeSymbol inputType, LayerSymbol method, int channels) {
        int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
        int strideWidth = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(1);
        int inputHeight = inputType.getHeight();
        int inputWidth = inputType.getWidth();

        int outputWidth = (inputWidth + strideWidth - 1) / strideWidth;
        // the rounding term has to use the height stride, not the width stride
        int outputHeight = (inputHeight + strideHeight - 1) / strideHeight;

        return Collections.singletonList(new ArchTypeSymbol.Builder()
                .height(outputHeight)
                .width(outputWidth)
                .channels(channels)
                .elementType("-oo", "oo")
                .build());
    }

    /**
     * Combines the range limits of all input types that have a range.
     *
     * <p>If any range has no lower limit, the start is "-oo"; otherwise the start
     * values are reduced with {@code startValAccumulator}. The end is determined the
     * same way with "oo" and {@code endValAccumulator}.
     *
     * @param inputTypes          the input types whose domains are inspected
     * @param startValAccumulator reduction applied to the collected start values
     * @param endValAccumulator   reduction applied to the collected end values
     * @return a two-element list: start and end as strings
     * @throws java.util.NoSuchElementException if a limit has to be reduced but no
     *         value was collected for it (e.g. no input type has a range)
     */
    protected List<String> computeStartAndEndValue(List<ArchTypeSymbol> inputTypes, BinaryOperator<Rational> startValAccumulator, BinaryOperator<Rational> endValAccumulator) {
        Stream.Builder<Rational> startValues = Stream.builder();
        Stream.Builder<Rational> endValues = Stream.builder();
        String start = null;
        String end = null;
        for (ArchTypeSymbol inputType : inputTypes) {
            Optional<ASTRange> range = inputType.getDomain().getRangeOpt();
            if (range.isPresent()) {
                if (range.get().hasNoLowerLimit()) {
                    start = "-oo";
                }
                else {
                    startValues.add(range.get().getStartValue());
                }
                if (range.get().hasNoUpperLimit()) {
                    end = "oo";
                }
                else {
                    endValues.add(range.get().getEndValue());
                }
            }
        }
        if (start == null) {
            start = "" + startValues.build().reduce(startValAccumulator).get().doubleValue();
        }
        if (end == null) {
            end = "" + endValues.build().reduce(endValAccumulator).get().doubleValue();
        }

        return Arrays.asList(start, end);
    }
}
......@@ -35,9 +35,6 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
public static final UnrollDeclarationKind KIND = new UnrollDeclarationKind();
private List<ParameterSymbol> parameters;
private List<LayerSymbol> layers;
private SerialCompositeElementSymbol body;
protected UnrollDeclarationSymbol(String name) {
super(name, KIND);
......@@ -53,14 +50,6 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
return (UnrollDeclarationScope) super.getSpannedScope();
}
protected void setLayers(List<LayerSymbol> layers) {
this.layers = layers;
}
public List<LayerSymbol> getLayers() {
return layers;
}
public List<ParameterSymbol> getParameters() {
return parameters;
}
......@@ -82,23 +71,9 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
this.parameters.add(forParam);
forParam.putInScope(getSpannedScope());
}
}
public SerialCompositeElementSymbol getBody() {
return body;
}
protected void setBody(SerialCompositeElementSymbol body) {
this.body = body;
}
public boolean isPredefined() {
//Override by PredefinedUnrollDeclaration
return false;
}
public boolean isTrainable() {
return body.isTrainable();
for (ParameterSymbol param : parameters){
param.putInScope(getSpannedScope());
}
}
public Optional<ParameterSymbol> getParameter(String name) {
......@@ -141,66 +116,4 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
throw new IllegalArgumentException("Arguments with sequence expressions have to be resolved first before calling the layer method.");
}
}
public UnrollDeclarationSymbol deepCopy() {
UnrollDeclarationSymbol copy = new UnrollDeclarationSymbol(getName());
if (getAstNode().isPresent()){
copy.setAstNode(getAstNode().get());
}
System.err.println("parameters to copy: " + getParameters());
List<ParameterSymbol> parameterCopies = new ArrayList<>(getParameters().size());
for (ParameterSymbol parameter : getParameters()){
ParameterSymbol parameterCopy = parameter.deepCopy();
parameterCopies.add(parameterCopy);
parameterCopy.putInScope(copy.getSpannedScope());
}
copy.setParameters(parameterCopies);
copy.setBody(getBody().preResolveDeepCopy());
copy.getBody().putInScope(copy.getSpannedScope());
return copy;
}
/*public static class Builder{
private List<VariableSymbol> parameters = new ArrayList<>();
private CompositeElementSymbol body;
private String name = "";
public Builder parameters(List<VariableSymbol> parameters) {
this.parameters = parameters;
return this;
}
public Builder parameters(VariableSymbol... parameters) {
this.parameters = new ArrayList<>(Arrays.asList(parameters));
return this;
}
public Builder body(CompositeElementSymbol body) {
this.body = body;
return this;
}
public Builder name(String name) {
this.name = name;
return this;
}
public UnrollDeclarationSymbol build(){
if (name == null || name.equals("")){
throw new IllegalStateException("Missing or empty name for UnrollDeclarationSymbol");
}
UnrollDeclarationSymbol sym = new UnrollDeclarationSymbol(name);
sym.setBody(body);
if (body != null){
body.putInScope(sym.getSpannedScope());
}
for (VariableSymbol param : parameters){
param.putInScope(sym.getSpannedScope());
}
sym.setParameters(parameters);
return sym;
}
}*/
}
......@@ -93,6 +93,7 @@ public class UnrollSymbol extends ResolvableSymbol {
/**
 * Stores the given time parameter and puts it into the scope spanned by this symbol.
 *
 * @param timeParameter the time parameter of this unroll element
 */
protected void setTimeParameter(ParameterSymbol timeParameter) {
    this.timeParameter = timeParameter;
    timeParameter.putInScope(getSpannedScope());
}
protected void putInScope(Scope scope){
......@@ -114,45 +115,25 @@ public class UnrollSymbol extends ResolvableSymbol {
getDeclaration();
resolveExpressions();
getBody().resolveOrError();
for (int timestep = this.getIntValue(AllPredefinedLayers.T_NAME).get(); timestep < this.getIntValue(AllPredefinedLayers.MAX_LENGTH_NAME).get(); timestep++) {
SerialCompositeElementSymbol newBody = new SerialCompositeElementSymbol();
List<ArchitectureElementSymbol> newBodyList = new ArrayList<>();
SerialCompositeElementSymbol body = getBody().preResolveDeepCopy();
body.putInScope(getBody().getSpannedScope());
for (ArchitectureElementSymbol element : body.getElements()) {
if (element.getEnclosingScope() == null) {
element.setEnclosingScope(getEnclosingScope().getAsMutableScope());
}
((ParameterSymbol)((ArrayList<Symbol>)getSpannedScope().getLocalSymbols().get(AllPredefinedLayers.T_NAME)).get(0)).getExpression().setValue(timestep);
int startValue = getTimeParameter().getDefaultExpression().get().getIntValue().get();
int endValue = getIntValue(AllPredefinedLayers.MAX_LENGTH_NAME).get();
try {
this.resolveExpressions();
for (ParameterSymbol p:declaration.getParameters()) {
if (p.getEnclosingScope() == null) {
p.putInScope(getSpannedScope());
}
}
getTimeParameter().getExpression().setValue(1); // TODO: Change constant
getBody().resolveOrError();
element.resolve();
}
catch (ArchResolveException e) {
e.printStackTrace();
}
for (int timestep = startValue; timestep < endValue; timestep++) {
SerialCompositeElementSymbol currentBody = getBody().preResolveDeepCopy();
currentBody.putInScope(getBody().getSpannedScope());
newBodyList.add(element);
}
getTimeParameter().getExpression().setValue(timestep);
getTimeParameter().putInScope(currentBody.getEnclosingScope());
newBody.putInScope(this.getBody().getSpannedScope());
newBody.setElements(newBodyList);
currentBody.resolveOrError();
bodies.add(newBody);
bodies.add(currentBody);
}
UnrollSymbol resolvedUnroll = getDeclaration().call(this);
setResolvedThis(resolvedUnroll);
}
......
......@@ -46,7 +46,7 @@ public class AllPredefinedLayers {
public static final String CONCATENATE_NAME = "Concatenate";
public static final String FLATTEN_NAME = "Flatten";
public static final String ONE_HOT_NAME = "OneHot";
public static final String BEAMSEARCH_NAME = "BeamSearchStart";
public static final String BEAMSEARCH_NAME = "BeamSearch";
public static final String GREEDYSEARCH_NAME = "GreedySearch";
public static final String RNN_NAME = "RNN";
public static final String LSTM_NAME = "LSTM";
......@@ -78,7 +78,6 @@ public class AllPredefinedLayers {
public static final String FLATTEN_PARAMETER_NAME = "flatten";
public static final String MAX_LENGTH_NAME = "max_length";
public static final String WIDTH_NAME = "width";
public static final String T_NAME = "t";
//possible String values
public static final String PADDING_VALID = "valid";
......@@ -118,7 +117,7 @@ public class AllPredefinedLayers {
public static List<UnrollDeclarationSymbol> createUnrollList(){
return Arrays.asList(
GreedySearch.create(),
BeamSearchStart.create());
BeamSearch.create());
}
}
......@@ -26,34 +26,19 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class BeamSearchStart extends PredefinedUnrollDeclaration {
public class BeamSearch extends UnrollDeclarationSymbol {
private BeamSearchStart() {
private BeamSearch() {
super(AllPredefinedLayers.BEAMSEARCH_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer, VariableSymbol.Member member) {
return new ArrayList<ArchTypeSymbol>();
}