Commit 09af254b authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Added Squeeze layer, updated RNN models, cleaned up some layers

parent cbbb702c
......@@ -184,6 +184,70 @@ public enum Constraints {
return AllPredefinedLayers.POOL_MAX + " or "
+ AllPredefinedLayers.POOL_AVG;
}
},
NULLABLE_AXIS {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
if (exp.getIntValue().isPresent()){
int intValue = exp.getIntValue().get();
return intValue >= -1 && intValue <= 2; // -1 is null
}
return false;
}
@Override
public String msgString() {
return "an axis between 0 and 2 or -1";
}
},
NULLABLE_AXIS_WITHOUT_2 {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
if (exp.getIntValue().isPresent()){
int intValue = exp.getIntValue().get();
return intValue >= -1 && intValue <= 2; // -1 is null
}
return false;
}
@Override
public String msgString() {
return "an axis between 0 and 1 or -1";
}
},
AXIS {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
if (exp.getIntValue().isPresent()){
int intValue = exp.getIntValue().get();
return intValue >= 0 && intValue <= 2;
}
return false;
}
@Override
public String msgString() {
return "an axis between 0 and 2";
}
},
AXIS_WITHOUT_2 {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
if (exp.getIntValue().isPresent()){
int intValue = exp.getIntValue().get();
return intValue >= 0 && intValue <= 1;
}
return false;
}
@Override
public String msgString() {
return "an axis between 0 and 1";
}
};
protected abstract boolean isValid(ArchSimpleExpressionSymbol exp);
......
......@@ -121,20 +121,6 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
}
protected void errorIfAxisNotFeasible(List<ArchTypeSymbol> inputTypes, LayerSymbol layer){
if ((!layer.getStringValue(AllPredefinedLayers.AXIS_NAME).isPresent() && layer.getIntValue(AllPredefinedLayers.AXIS_NAME).get() > 1)){
Log.error("0" + ErrorCodes.ILLEGAL_PARAMETER_VALUE + " Illegal value for parameter axis. Value must be None, 0 or 1"
, layer.getSourcePosition());
}
}
protected void errorIfDimNotFeasible(List<ArchTypeSymbol> inputTypes, LayerSymbol layer){
if (layer.getIntValue(AllPredefinedLayers.DIM_NAME).get() > 1){
Log.error("0" + ErrorCodes.ILLEGAL_PARAMETER_VALUE + " Illegal value for parameter dim. Value must be 0 or 1"
, layer.getSourcePosition());
}
}
protected void errorIfInputNotFeasibleForDotProduct(List<ArchTypeSymbol> inputTypes, LayerSymbol layer){
if(!(layer.getInputTypes().get(1).getHeight() == layer.getInputTypes().get(0).getWidth())){
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. Dot Product cannot be applied to input 1 with height " +
......
......@@ -55,6 +55,7 @@ public class AllPredefinedLayers {
public static final String ARG_MAX_NAME = "ArgMax";
public static final String DOT_NAME = "Dot";
public static final String REPEAT_NAME = "Repeat";
public static final String SQUEEZE_NAME = "Squeeze";
public static final String REDUCE_SUM_NAME = "ReduceSum";
public static final String EXPAND_DIMS_NAME = "ExpandDims";
public static final String MULTIPLY_NAME = "Multiply";
......@@ -80,7 +81,6 @@ public class AllPredefinedLayers {
public static final String LAYERS_NAME = "layers";
public static final String INPUT_DIM_NAME = "input_dim";
public static final String OUTPUT_DIM_NAME = "output_dim";
public static final String DIM_NAME = "dim";
public static final String BIDIRECTIONAL_NAME = "bidirectional";
public static final String FLATTEN_PARAMETER_NAME = "flatten";
public static final String MAX_LENGTH_NAME = "max_length";
......@@ -124,6 +124,7 @@ public class AllPredefinedLayers {
ArgMax.create(),
Dot.create(),
Repeat.create(),
Squeeze.create(),
ReduceSum.create(),
ExpandDims.create(),
Multiply.create(),
......
......@@ -29,6 +29,8 @@ import java.util.List;
abstract public class BaseRNN extends PredefinedLayerDeclaration {
protected int numberOfStates = 1;
public BaseRNN(String name) {
super(name);
}
......@@ -47,8 +49,9 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
int layers = layer.getIntValue(AllPredefinedLayers.LAYERS_NAME).get();
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(bidirectional ? 2 * layers : layers)
.height(units)
.channels(numberOfStates)
.height(bidirectional ? 2 * layers : layers)
.width(units)
.elementType("-oo", "oo")
.build());
}
......@@ -69,9 +72,9 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
if (member == VariableSymbol.Member.STATE) {
errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfInputChannelSizeIsInvalid(inputTypes, layer, bidirectional ? 2 * layers : layers);
errorIfInputHeightIsInvalid(inputTypes, layer, units);
errorIfInputWidthIsInvalid(inputTypes, layer, 1);
errorIfInputChannelSizeIsInvalid(inputTypes, layer, numberOfStates);
errorIfInputHeightIsInvalid(inputTypes, layer, bidirectional ? 2 * layers : layers);
errorIfInputWidthIsInvalid(inputTypes, layer, units);
}
else {
errorIfInputSizeIsNotOne(inputTypes, layer);
......
......@@ -38,74 +38,61 @@ public class Concatenate extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
int height = 0;
int width = 0;
int channels = 0;
int channels = layer.getInputTypes().get(0).getChannels();
int height = layer.getInputTypes().get(0).getHeight();
int width = layer.getInputTypes().get(0).getWidth();
int dim = layer.getIntValue(AllPredefinedLayers.DIM_NAME).get();
int axis = layer.getIntValue(AllPredefinedLayers.AXIS_NAME).get();
List<String> range = computeStartAndEndValue(layer.getInputTypes(), (x,y) -> x.isLessThan(y) ? x : y, (x,y) -> x.isLessThan(y) ? y : x);
if(dim==0){
for (ArchTypeSymbol inputShape : layer.getInputTypes()) {
List<ArchTypeSymbol> types = layer.getInputTypes();
types.remove(0);
if (axis == 0) {
for (ArchTypeSymbol inputShape : types) {
channels += inputShape.getChannels();
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(channels)
.height(layer.getInputTypes().get(0).getHeight())
.width(layer.getInputTypes().get(0).getWidth())
.elementType(range.get(0), range.get(1))
.build());
}else if(dim==1){
for (ArchTypeSymbol inputShape : layer.getInputTypes()) {
} else if (axis == 1) {
for (ArchTypeSymbol inputShape : types) {
height += inputShape.getHeight();
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getChannels())
.height(height)
.width(layer.getInputTypes().get(0).getWidth())
.elementType(range.get(0), range.get(1))
.build());
} else if(dim==2){
for (ArchTypeSymbol inputShape : layer.getInputTypes()) {
} else {
for (ArchTypeSymbol inputShape : types) {
width += inputShape.getWidth();
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getChannels())
.height(layer.getInputTypes().get(0).getHeight())
.width(width)
.elementType(range.get(0), range.get(1))
.build());
}else{
return new ArrayList<>();
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(channels)
.height(height)
.width(width)
.elementType(range.get(0), range.get(1))
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
if (!inputTypes.isEmpty()) {
List<Integer> channelList = new ArrayList<>();
List<Integer> heightList = new ArrayList<>();
List<Integer> widthList = new ArrayList<>();
for (ArchTypeSymbol shape : inputTypes){
heightList.add(shape.getHeight());
widthList.add(shape.getWidth());
channelList.add(shape.getChannels());
}
int countEqualcHannels = (int)channelList.stream().distinct().count();
int countEqualHeights = (int)heightList.stream().distinct().count();
int countEqualWidths = (int)widthList.stream().distinct().count();
if (countEqualHeights != 1 && countEqualWidths != 1 && countEqualcHannels != 1){
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. " +
"Concatenation of inputs with different resolutions is not possible. " +
"Input channels: " + Joiners.COMMA.join(channelList) + ". " +
"Input heights: " + Joiners.COMMA.join(heightList) + ". " +
"Input widths: " + Joiners.COMMA.join(widthList) + ". "
, layer.getSourcePosition());
}
errorIfInputIsEmpty(inputTypes, layer);
List<Integer> channelList = new ArrayList<>();
List<Integer> heightList = new ArrayList<>();
List<Integer> widthList = new ArrayList<>();
for (ArchTypeSymbol shape : inputTypes){
heightList.add(shape.getHeight());
widthList.add(shape.getWidth());
channelList.add(shape.getChannels());
}
else {
errorIfInputIsEmpty(inputTypes, layer);
int countEqualChannels = (int)channelList.stream().distinct().count();
int countEqualHeights = (int)heightList.stream().distinct().count();
int countEqualWidths = (int)widthList.stream().distinct().count();
if (countEqualHeights != 1 && countEqualWidths != 1 && countEqualChannels != 1){
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. " +
"Concatenation of inputs with different resolutions is not possible. " +
"Input channels: " + Joiners.COMMA.join(channelList) + ". " +
"Input heights: " + Joiners.COMMA.join(heightList) + ". " +
"Input widths: " + Joiners.COMMA.join(widthList) + ". "
, layer.getSourcePosition());
}
}
......@@ -113,9 +100,9 @@ public class Concatenate extends PredefinedLayerDeclaration {
Concatenate declaration = new Concatenate();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.DIM_NAME)
.constraints(Constraints.INTEGER, Constraints.NON_NEGATIVE)
.defaultValue(1)
.name(AllPredefinedLayers.AXIS_NAME)
.constraints(Constraints.AXIS)
.defaultValue(0)
.build()));
declaration.setParameters(parameters);
return declaration;
......
......@@ -37,17 +37,17 @@ public class ExpandDims extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
int axis = layer.getIntValue(AllPredefinedLayers.AXIS_NAME).get();
int dim = layer.getIntValue(AllPredefinedLayers.DIM_NAME).get();
int channels = layer.getInputTypes().get(0).getChannels();
int height = layer.getInputTypes().get(0).getHeight();
int width = layer.getInputTypes().get(0).getWidth();
if (dim == 0) {
if (axis == 0) {
width = height;
height = channels;
channels = 1;
}else if (dim == 1) {
} else if (axis == 1) {
width = height;
height = 1;
}
......@@ -56,22 +56,22 @@ public class ExpandDims extends PredefinedLayerDeclaration {
.channels(channels)
.height(height)
.width(width)
.elementType("-oo", "oo")
.elementType(layer.getInputTypes().get(0).getDomain())
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputIsEmpty(inputTypes, layer);
errorIfDimNotFeasible(inputTypes, layer);
errorIfInputWidthIsInvalid(inputTypes, layer, 1);
}
public static ExpandDims create(){
ExpandDims declaration = new ExpandDims();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.DIM_NAME)
.constraints(Constraints.INTEGER, Constraints.NON_NEGATIVE)
.name(AllPredefinedLayers.AXIS_NAME)
.constraints(Constraints.AXIS_WITHOUT_2)
.build()));
declaration.setParameters(parameters);
return declaration;
......
......@@ -24,6 +24,8 @@ public class LSTM extends BaseRNN {
private LSTM() {
super(AllPredefinedLayers.LSTM_NAME);
numberOfStates = 2;
}
public static LSTM create() {
......
......@@ -23,6 +23,7 @@ package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import org.fusesource.jansi.internal.Kernel32;
import java.util.ArrayList;
import java.util.Arrays;
......@@ -37,44 +38,33 @@ public class ReduceSum extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
int axis;
boolean axisIsNone = layer.getStringValue(AllPredefinedLayers.AXIS_NAME).isPresent();
int axis = layer.getIntValue(AllPredefinedLayers.AXIS_NAME).get();
int channels = layer.getInputTypes().get(0).getChannels();
int height = layer.getInputTypes().get(0).getHeight();
int width = layer.getInputTypes().get(0).getWidth();
if (axisIsNone) {
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(1)
.height(1)
.width(1)
.elementType("-oo", "oo")
.build());
if (axis == 0) {
height = 1;
} else if (axis == 1) {
width = 1;
} else {
axis = layer.getIntValue(AllPredefinedLayers.AXIS_NAME).get();
if (axis == 0) {
height = 1;
}else{
width = 1;
}
channels = 1;
height = 1;
width = 1;
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(channels)
.height(height)
.width(width)
.elementType("-oo", "oo")
.build());
}
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputIsEmpty(inputTypes, layer);
errorIfAxisNotFeasible(inputTypes, layer);
}
public static ReduceSum create(){
......@@ -82,6 +72,8 @@ public class ReduceSum extends PredefinedLayerDeclaration {
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.AXIS_NAME)
.constraints(Constraints.NULLABLE_AXIS_WITHOUT_2)
.defaultValue(-1)
.build()));
declaration.setParameters(parameters);
return declaration;
......
......@@ -37,45 +37,44 @@ public class Repeat extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
int repeats = layer.getIntValue(AllPredefinedLayers.REPEATS_NAME).get();
int axis;
boolean axisIsNone = layer.getStringValue(AllPredefinedLayers.AXIS_NAME).isPresent();
int axis = layer.getIntValue(AllPredefinedLayers.AXIS_NAME).get();
int channels = layer.getInputTypes().get(0).getChannels();
int height = layer.getInputTypes().get(0).getHeight();
int width = layer.getInputTypes().get(0).getWidth();
if(axisIsNone){
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(channels *= repeats)
.height(1)
.width(1)
.elementType("-oo", "oo")
.build());
}else {
axis = layer.getIntValue(AllPredefinedLayers.AXIS_NAME).get();
if(axis == 0){
height *= repeats;
}else{
width *= repeats;
}
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(channels)
.height(height)
.width(width)
.elementType("-oo", "oo")
.build());
if (axis == 0) {
channels *= repeats;
} else if (axis == 1) {
height *= repeats;
} else if (axis == 2) {
width *= repeats;
} else {
// when no axis is given, expand dimension and repeat in new, first dimension
width = height;
height = channels;
channels = repeats;
}
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(channels)
.height(height)
.width(width)
.elementType(layer.getInputTypes().get(0).getDomain())
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfAxisNotFeasible(inputTypes, layer);
int axis = layer.getIntValue(AllPredefinedLayers.AXIS_NAME).get();
if (axis == -1) {
errorIfInputWidthIsInvalid(inputTypes, layer, 1);
}
}
public static Repeat create(){
......@@ -85,10 +84,10 @@ public class Repeat extends PredefinedLayerDeclaration {
.name(AllPredefinedLayers.REPEATS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
// no constraints in order to allow 'None' value
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.AXIS_NAME)
.defaultValue(1)
.constraints(Constraints.NULLABLE_AXIS)
.defaultValue(-1)
.build()));
declaration.setParameters(parameters);
return declaration;
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class Squeeze extends PredefinedLayerDeclaration {
private Squeeze() {
super(AllPredefinedLayers.SQUEEZE_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
List<Integer> dimensions = layer.getInputTypes().get(0).getDimensions();
int axis = layer.getIntValue(AllPredefinedLayers.AXIS_NAME).get();
if (axis == -1) {
dimensions.remove(new Integer(1));
} else {
dimensions.remove(axis);
}
while (dimensions.size() < 3) {
dimensions.add(1);
}
int channels = dimensions.get(0);
int height = dimensions.get(1);
int width = dimensions.get(2);
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(channels)
.height(height)
.width(width)
.elementType(layer.getInputTypes().get(0).getDomain())
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol