Commit d0b670bd authored by Christian Fuß's avatar Christian Fuß
Browse files

Added BroadcastAdd and Reshape layer. Renamed Multiply layer to...

Added BroadcastAdd and Reshape layer. Renamed Multiply layer to BroadcastMultiply. Added optional axis parameter to Softmax layer
parent cbbb702c
Pipeline #199118 passed with stages
in 18 minutes and 28 seconds
......@@ -57,8 +57,10 @@ public class AllPredefinedLayers {
public static final String REPEAT_NAME = "Repeat";
public static final String REDUCE_SUM_NAME = "ReduceSum";
public static final String EXPAND_DIMS_NAME = "ExpandDims";
public static final String MULTIPLY_NAME = "Multiply";
public static final String BROADCAST_MULTIPLY_NAME = "BroadcastMultiply";
public static final String SWAPAXES_NAME = "SwapAxes";
public static final String BROADCAST_ADD_NAME = "BroadcastAdd";
public static final String RESHAPE_NAME = "Reshape";
//predefined argument names
public static final String KERNEL_NAME = "kernel";
......@@ -88,6 +90,7 @@ public class AllPredefinedLayers {
public static final String REPEATS_NAME = "n";
public static final String AXIS_NAME = "axis";
public static final String AXES_NAME = "axes";
public static final String SHAPE_NAME = "shape";
//possible String values
public static final String PADDING_VALID = "valid";
......@@ -126,8 +129,10 @@ public class AllPredefinedLayers {
Repeat.create(),
ReduceSum.create(),
ExpandDims.create(),
Multiply.create(),
SwapAxes.create());
BroadcastMultiply.create(),
SwapAxes.create(),
BroadcastAdd.create(),
Reshape.create());
}
public static List<UnrollDeclarationSymbol> createUnrollList(){
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import org.jscience.mathematics.number.Rational;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class BroadcastAdd extends PredefinedLayerDeclaration {
private BroadcastAdd() {
super(AllPredefinedLayers.BROADCAST_ADD_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
List<String> range = computeStartAndEndValue(layer.getInputTypes(), Rational::plus, Rational::plus);
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(Math.max(layer.getInputTypes().get(0).getChannels(), layer.getInputTypes().get(1).getChannels()))
.height(Math.max(layer.getInputTypes().get(0).getHeight(), layer.getInputTypes().get(1).getHeight()))
.width(Math.max(layer.getInputTypes().get(0).getWidth(), layer.getInputTypes().get(1).getWidth()))
.elementType(range.get(0), range.get(1))
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputIsEmpty(inputTypes, layer);
}
public static BroadcastAdd create(){
BroadcastAdd declaration = new BroadcastAdd();
declaration.setParameters(new ArrayList<>());
return declaration;
}
}
\ No newline at end of file
......@@ -30,55 +30,35 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class Multiply extends PredefinedLayerDeclaration {
public class BroadcastMultiply extends PredefinedLayerDeclaration {
private Multiply() {
super(AllPredefinedLayers.MULTIPLY_NAME);
private BroadcastMultiply() {
super(AllPredefinedLayers.BROADCAST_MULTIPLY_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
List<String> range = computeStartAndEndValue(layer.getInputTypes(), Rational::times, Rational::times);
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getChannels())
.height(layer.getInputTypes().get(0).getHeight())
.width(layer.getInputTypes().get(0).getWidth())
.elementType(range.get(0), range.get(1))
.build());
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(Math.max(layer.getInputTypes().get(0).getChannels(), layer.getInputTypes().get(1).getChannels()))
.height(Math.max(layer.getInputTypes().get(0).getHeight(), layer.getInputTypes().get(1).getHeight()))
.width(Math.max(layer.getInputTypes().get(0).getWidth(), layer.getInputTypes().get(1).getWidth()))
.elementType(range.get(0), range.get(1))
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputIsEmpty(inputTypes, layer);
if (inputTypes.size() == 1){
Log.warn("Multiply layer has only one input stream. Layer can be removed." , layer.getSourcePosition());
}
else if (inputTypes.size() > 1){
List<Integer> heightList = new ArrayList<>();
List<Integer> widthList = new ArrayList<>();
List<Integer> channelsList = new ArrayList<>();
for (ArchTypeSymbol shape : inputTypes){
heightList.add(shape.getHeight());
widthList.add(shape.getWidth());
channelsList.add(shape.getChannels());
}
int countEqualHeights = (int)heightList.stream().distinct().count();
int countEqualWidths = (int)widthList.stream().distinct().count();
int countEqualNumberOfChannels = (int)channelsList.stream().distinct().count();
if (countEqualHeights != 1 || countEqualWidths != 1 || countEqualNumberOfChannels != 1){
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. " +
"Shapes of all input streams must be equal. " +
"Input heights: " + Joiners.COMMA.join(heightList) + ". " +
"Input widths: " + Joiners.COMMA.join(widthList) + ". " +
"Number of input channels: " + Joiners.COMMA.join(channelsList) + ". "
, layer.getSourcePosition());
}
Log.warn("BroadcastMultiply layer has only one input stream. Layer can be removed." , layer.getSourcePosition());
}
}
public static Multiply create(){
Multiply declaration = new Multiply();
public static BroadcastMultiply create(){
BroadcastMultiply declaration = new BroadcastMultiply();
declaration.setParameters(new ArrayList<>());
return declaration;
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class Reshape extends PredefinedLayerDeclaration {
private Reshape() {
super(AllPredefinedLayers.RESHAPE_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
List<Integer> shape = layer.getIntTupleValue(AllPredefinedLayers.SHAPE_NAME).get();
int channels = -1;
int height = -1;
int width = -1;
// shape.get(0) will be the batch size, which we ignore here
if(shape.size() >=4){
width = shape.get(3);
}else{
width = layer.getInputTypes().get(0).getWidth();
}
if(shape.size() >=3){
height = shape.get(2);
}else{
height = layer.getInputTypes().get(0).getHeight();
}
if(shape.size() >=2){
channels = shape.get(1);
}else{
Log.error("0" + ErrorCodes.ILLEGAL_PARAMETER_VALUE + "\"Shape\" argument needs to contain at least two entries (batchSize and channels)"
, layer.getSourcePosition());
}
int totalSize = layer.getInputTypes().get(0).getChannels() * layer.getInputTypes().get(0).getHeight() * layer.getInputTypes().get(0).getWidth();
int newTotalSize = shape.stream().reduce(1, (x,y) -> x*y);
if(totalSize != newTotalSize && newTotalSize != 0){
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + "The input of Reshape layer cannot be reshaped to the given shape. "
+ "Source and target shape have a different amount of total values", layer.getSourcePosition());
}
if(newTotalSize != 0) {
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(channels)
.height(height)
.width(width)
.elementType(layer.getInputTypes().get(0).getDomain())
.build());
}else{
if(height == -1){
channels = newTotalSize;
}else if(width == -1){
if(height == 0){
height = newTotalSize/channels;
}else if(channels == 0){
channels = newTotalSize/height;
}
}else{
if(width == 0){
width = newTotalSize/(channels*height);
}else if(height == 0){
height = newTotalSize/(channels*width);
}else if(channels == 0){
channels = newTotalSize/(width*height);
}
}
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(channels)
.height(height)
.width(width)
.elementType(layer.getInputTypes().get(0).getDomain())
.build());
}
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
public static Reshape create(){
Reshape declaration = new Reshape();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.SHAPE_NAME)
.constraints(Constraints.INTEGER_TUPLE)
.build()));
declaration.setParameters(parameters);
return declaration;
}
}
\ No newline at end of file
......@@ -20,12 +20,10 @@
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.PredefinedLayerDeclaration;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
......@@ -53,7 +51,13 @@ public class Softmax extends PredefinedLayerDeclaration {
public static Softmax create(){
Softmax declaration = new Softmax();
declaration.setParameters(new ArrayList<>());
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.AXIS_NAME)
.constraints(Constraints.INTEGER)
.defaultValue(-1)
.build()));
declaration.setParameters(parameters);
return declaration;
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment