Commit cbbb702c authored by Christian Fuß's avatar Christian Fuß
Browse files

added SwapAxes layer. Adjusted FullyConnected layer to work with RNN states.

parent 13b87e01
Pipeline #191051 failed with stages
in 37 seconds
......@@ -63,7 +63,11 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
public boolean isTrainable(VariableSymbol.Member member) {
return true;
if(member == VariableSymbol.Member.STATE || member == VariableSymbol.Member.OUTPUT){
return false;
}else {
return true;
}
}
/**
......
......@@ -58,6 +58,7 @@ public class AllPredefinedLayers {
public static final String REDUCE_SUM_NAME = "ReduceSum";
public static final String EXPAND_DIMS_NAME = "ExpandDims";
public static final String MULTIPLY_NAME = "Multiply";
public static final String SWAPAXES_NAME = "SwapAxes";
//predefined argument names
public static final String KERNEL_NAME = "kernel";
......@@ -86,6 +87,7 @@ public class AllPredefinedLayers {
public static final String WIDTH_NAME = "width";
public static final String REPEATS_NAME = "n";
public static final String AXIS_NAME = "axis";
public static final String AXES_NAME = "axes";
//possible String values
public static final String PADDING_VALID = "valid";
......@@ -124,7 +126,8 @@ public class AllPredefinedLayers {
Repeat.create(),
ReduceSum.create(),
ExpandDims.create(),
Multiply.create());
Multiply.create(),
SwapAxes.create());
}
public static List<UnrollDeclarationSymbol> createUnrollList(){
......
......@@ -38,37 +38,67 @@ public class Concatenate extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
int height = layer.getInputTypes().get(0).getHeight();
int width = layer.getInputTypes().get(0).getWidth();
int height = 0;
int width = 0;
int channels = 0;
for (ArchTypeSymbol inputShape : layer.getInputTypes()) {
channels += inputShape.getChannels();
}
int dim = layer.getIntValue(AllPredefinedLayers.DIM_NAME).get();
List<String> range = computeStartAndEndValue(layer.getInputTypes(), (x,y) -> x.isLessThan(y) ? x : y, (x,y) -> x.isLessThan(y) ? y : x);
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(channels)
.height(height)
.width(width)
.elementType(range.get(0), range.get(1))
.build());
if(dim==0){
for (ArchTypeSymbol inputShape : layer.getInputTypes()) {
channels += inputShape.getChannels();
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(channels)
.height(layer.getInputTypes().get(0).getHeight())
.width(layer.getInputTypes().get(0).getWidth())
.elementType(range.get(0), range.get(1))
.build());
}else if(dim==1){
for (ArchTypeSymbol inputShape : layer.getInputTypes()) {
height += inputShape.getHeight();
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getChannels())
.height(height)
.width(layer.getInputTypes().get(0).getWidth())
.elementType(range.get(0), range.get(1))
.build());
} else if(dim==2){
for (ArchTypeSymbol inputShape : layer.getInputTypes()) {
width += inputShape.getWidth();
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getChannels())
.height(layer.getInputTypes().get(0).getHeight())
.width(width)
.elementType(range.get(0), range.get(1))
.build());
}else{
return new ArrayList<>();
}
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
if (!inputTypes.isEmpty()) {
List<Integer> channelList = new ArrayList<>();
List<Integer> heightList = new ArrayList<>();
List<Integer> widthList = new ArrayList<>();
for (ArchTypeSymbol shape : inputTypes){
heightList.add(shape.getHeight());
widthList.add(shape.getWidth());
channelList.add(shape.getChannels());
}
int countEqualcHannels = (int)channelList.stream().distinct().count();
int countEqualHeights = (int)heightList.stream().distinct().count();
int countEqualWidths = (int)widthList.stream().distinct().count();
if (countEqualHeights != 1 || countEqualWidths != 1){
if (countEqualHeights != 1 && countEqualWidths != 1 && countEqualcHannels != 1){
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. " +
"Concatenation of inputs with different resolutions is not possible. " +
"Input channels: " + Joiners.COMMA.join(channelList) + ". " +
"Input heights: " + Joiners.COMMA.join(heightList) + ". " +
"Input widths: " + Joiners.COMMA.join(widthList) + ". "
, layer.getSourcePosition());
......
......@@ -59,12 +59,21 @@ public class FullyConnected extends PredefinedLayerDeclaration {
.build());
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(inputType.getChannels())
.height(units)
.width(1)
.elementType("-oo", "oo")
.build());
// if input is an RNN state or output. Can be used to store states in layer variables
if(layer.getInputElement().get() instanceof VariableSymbol && (((VariableSymbol)(layer.getInputElement().get())).getMember() == VariableSymbol.Member.STATE || ((VariableSymbol)((ArchitectureElementSymbol)layer.getInputElement().get())).getMember() == VariableSymbol.Member.OUTPUT)){
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(inputType.getChannels())
.height(units)
.elementType("-oo", "oo")
.build());
}else {
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(inputType.getChannels())
.height(units)
.width(1)
.elementType("-oo", "oo")
.build());
}
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
......
......@@ -34,6 +34,11 @@ public class Split extends PredefinedLayerDeclaration {
super(AllPredefinedLayers.SPLIT_NAME);
}
@Override
public boolean isTrainable() {
return false;
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
ArchTypeSymbol inputShape = layer.getInputTypes().get(0);
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class SwapAxes extends PredefinedLayerDeclaration {
private SwapAxes() {
super(AllPredefinedLayers.SWAPAXES_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
int firstAxis = layer.getIntTupleValue(AllPredefinedLayers.AXES_NAME).get().get(0);
int secondAxis = layer.getIntTupleValue(AllPredefinedLayers.AXES_NAME).get().get(1);
if((firstAxis == 0 && secondAxis == 1) || (firstAxis == 1 && secondAxis == 0)){
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getHeight())
.height(layer.getInputTypes().get(0).getChannels())
.width(layer.getInputTypes().get(0).getWidth())
.elementType(layer.getInputTypes().get(0).getDomain())
.build());
}else if((firstAxis == 0 && secondAxis == 2) || (firstAxis == 2 && secondAxis == 0)){
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getWidth())
.height(layer.getInputTypes().get(0).getHeight())
.width(layer.getInputTypes().get(0).getChannels())
.elementType(layer.getInputTypes().get(0).getDomain())
.build());
}else if((firstAxis == 1 && secondAxis == 2) || (firstAxis == 2 && secondAxis == 1)){
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getChannels())
.height(layer.getInputTypes().get(0).getWidth())
.width(layer.getInputTypes().get(0).getHeight())
.elementType(layer.getInputTypes().get(0).getDomain())
.build());
}else{
if ((firstAxis < 0 || firstAxis > 2 || secondAxis < 0 || secondAxis > 2)){
Log.error("0" + ErrorCodes.ILLEGAL_PARAMETER_VALUE + " Illegal value for parameter axes. Values must be between 0 and 2"
, layer.getSourcePosition());
}
return new ArrayList<>();
}
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
public static SwapAxes create(){
SwapAxes declaration = new SwapAxes();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.AXES_NAME)
.constraints(Constraints.INTEGER_TUPLE)
.build()));
declaration.setParameters(parameters);
return declaration;
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment