Unverified Commit 99bdbead authored by Thomas Michael Timmermanns's avatar Thomas Michael Timmermanns Committed by GitHub

Changed pooling and shape syntax. (#17)

* Syntax change.
Changed MaxPooling and AveragePooling to Pooling with type argument.
Added method GlobalPooling and removed 'global' argument of Pooling.
Changed special argument 'If' to '?'.

* Changed IO Shapes.
The architecture can now handle other data formats (until now: NHWC only).
parent 5f2575c7
This diff is collapsed.
......@@ -23,7 +23,11 @@ grammar CNNArch extends de.monticore.lang.math.Math {
ArchType implements Type = ElementType "^" Shape;
Shape = "{" dimensions:(ArchSimpleExpression || ",")+ "}";
Shape = "{" dimensions:(DimensionArgument || ",")* "}";
DimensionArgument = (name:"H" ":" height:ArchSimpleExpression
| name:"W" ":" width:ArchSimpleExpression
| name:"C" ":" channels:ArchSimpleExpression);
ArchitectureParameter implements Variable = Name& ("=" default:ArchSimpleExpression)?;
......@@ -46,7 +50,8 @@ grammar CNNArch extends de.monticore.lang.math.Math {
ArchParameterArgument implements ArchArgument = Name "=" rhs:ArchExpression;
ArchSpecialArgument implements ArchArgument = (serial:"->" | parallel:"|") "=" rhs:ArchExpression;
ArchSpecialArgument implements ArchArgument = (serial:"->" | parallel:"|" | conditional:"?") "="
rhs:ArchExpression;
ast ArchSpecialArgument = method public String getName(){return "";};
ParallelLayer implements ArchitectureElement = "(" groups:ArchBody "|" groups:(ArchBody || "|")+ ")";
......
......@@ -27,8 +27,8 @@ public class ASTArchSpecialArgument extends ASTArchSpecialArgumentTOP {
public ASTArchSpecialArgument() {
}
public ASTArchSpecialArgument(ASTArchExpression rhs, String serial, String parallel) {
super(rhs, serial, parallel);
public ASTArchSpecialArgument(ASTArchExpression rhs, String serial, String parallel, String conditional) {
super(rhs, serial, parallel, conditional);
}
@Override
......@@ -39,7 +39,12 @@ public class ASTArchSpecialArgument extends ASTArchSpecialArgumentTOP {
else if (getSerial().isPresent()) {
return AllPredefinedVariables.FOR_NAME;
}
return null;
else if (getConditional().isPresent()){
return AllPredefinedVariables.IF_NAME;
}
else {
throw new IllegalStateException();
}
}
}
......@@ -26,7 +26,6 @@ public class CNNArchPostResolveCocos {
return new CNNArchCoCoChecker()
.addCoCo(new CheckLayerInputs())
.addCoCo(new CheckIOAccessAndIOMissing())
.addCoCo(new CheckIOType())
.addCoCo(new CheckIOShape());
}
......
......@@ -31,7 +31,7 @@ public class CheckArgument implements CNNArchASTArchArgumentCoCo {
public void check(ASTArchArgument node) {
ArgumentSymbol argument = (ArgumentSymbol) node.getSymbol().get();
if (argument.getParameter() == null){
Log.error("0"+ ErrorCodes.UNKNOWN_ARGUMENT_CODE + " Unknown Argument. " +
Log.error("0"+ ErrorCodes.UNKNOWN_ARGUMENT + " Unknown Argument. " +
"Parameter with name '" + node.getName() + "' does not exist."
, node.get_SourcePositionStart());
}
......
......@@ -34,7 +34,7 @@ public class CheckIOName implements CNNArchASTIODeclarationCoCo {
@Override
public void check(ASTIODeclaration node) {
if (ioNames.contains(node.getName())){
Log.error("0" + ErrorCodes.DUPLICATED_NAME_CODE + " Duplicated IO name. " +
Log.error("0" + ErrorCodes.DUPLICATED_NAME + " Duplicated IO name. " +
"The name '" + node.getName() + "' is already used."
, node.get_SourcePositionStart());
}
......
......@@ -20,36 +20,59 @@
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTIODeclaration;
import de.monticore.lang.monticar.cnnarch._ast.ASTDimensionArgument;
import de.monticore.lang.monticar.cnnarch._ast.ASTShape;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchSimpleExpressionSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.IODeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ShapeSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Optional;
public class CheckIOShape implements CNNArchASTIODeclarationCoCo {
public class CheckIOShape implements CNNArchASTShapeCoCo {
@Override
public void check(ASTIODeclaration node) {
int shapeSize = node.getType().getShape().getDimensions().size();
if (shapeSize != 1 && shapeSize != 3){
Log.error("0" + ErrorCodes.INVALID_IO_SHAPE + " Invalid shape. " +
"IO Shape has to be either {height, width, channels} or {channels}."
, node.getType().getShape().get_SourcePositionStart());
}
else {
IODeclarationSymbol ioDeclaration = (IODeclarationSymbol) node.getSymbol().get();
for (ArchSimpleExpressionSymbol dimension : ioDeclaration.getShape().getDimensionSymbols()){
Optional<Integer> value = dimension.getIntValue();
if (!value.isPresent() || value.get() <= 0){
Log.error("0" + ErrorCodes.INVALID_IO_SHAPE + " Invalid shape. " +
"The dimension can only be defined by a positive integer."
, dimension.getSourcePosition());
public void check(ASTShape node) {
boolean hasHeight = false;
boolean hasWidth = false;
boolean hasChannels = false;
for (ASTDimensionArgument dimensionArg : node.getDimensions()){
if (dimensionArg.getWidth().isPresent()){
if (hasWidth){
repetitionError(dimensionArg);
}
hasWidth = true;
}
else if (dimensionArg.getHeight().isPresent()){
if (hasHeight){
repetitionError(dimensionArg);
}
hasHeight = true;
}
else {
if (hasChannels){
repetitionError(dimensionArg);
}
hasChannels = true;
}
}
ShapeSymbol shape = (ShapeSymbol) node.getSymbol().get();
for (ArchSimpleExpressionSymbol dimension : shape.getDimensionSymbols()){
Optional<Integer> value = dimension.getIntValue();
if (!value.isPresent() || value.get() <= 0){
Log.error("0" + ErrorCodes.INVALID_IO_SHAPE + " Invalid shape. " +
"The dimensions can only be defined by a positive integer."
, dimension.getSourcePosition());
}
}
}
private void repetitionError(ASTDimensionArgument node){
Log.error("0" + ErrorCodes.INVALID_IO_SHAPE + " Invalid shape. " +
"The dimension '" + node.getName().get() + "' was defined multiple times. "
, node.get_SourcePositionStart());
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchitecture;
public class CheckIOType implements CNNArchASTArchitectureCoCo {
@Override
public void check(ASTArchitecture node) {
//todo:
}
}
......@@ -26,7 +26,6 @@ import de.monticore.lang.monticar.cnnarch._symboltable.MethodDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.MethodLayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedMethods;
import de.se_rwth.commons.logging.Log;
import java.util.HashSet;
......@@ -40,7 +39,7 @@ public class CheckMethodLayer implements CNNArchASTMethodLayerCoCo{
for (ASTArchArgument argument : node.getArguments()){
String name = argument.getName();
if (nameSet.contains(name)){
Log.error("0" + ErrorCodes.DUPLICATED_ARG_CODE + " Duplicated name: " + name +
Log.error("0" + ErrorCodes.DUPLICATED_ARG + " Duplicated name: " + name +
". Multiple values assigned to the same argument."
, argument.get_SourcePositionStart());
}
......@@ -51,7 +50,7 @@ public class CheckMethodLayer implements CNNArchASTMethodLayerCoCo{
MethodDeclarationSymbol method = ((MethodLayerSymbol) node.getSymbol().get()).getMethod();
if (method == null){
Log.error("0" + ErrorCodes.UNKNOWN_METHOD_CODE + " Unknown method error. " +
Log.error("0" + ErrorCodes.UNKNOWN_METHOD + " Unknown method error. " +
"Method with name '" + node.getName() + "' does not exist"
, node.get_SourcePositionStart());
}
......@@ -64,13 +63,10 @@ public class CheckMethodLayer implements CNNArchASTMethodLayerCoCo{
}
for (ASTArchArgument argument : node.getArguments()){
requiredArguments.remove(argument.getName());
if (argument.getName().equals(AllPredefinedMethods.GLOBAL_NAME)){
requiredArguments.remove(AllPredefinedMethods.KERNEL_NAME);
}
}
for (String missingArgumentName : requiredArguments){
Log.error("0"+ErrorCodes.MISSING_ARGUMENT_CODE + " Missing argument. " +
Log.error("0"+ErrorCodes.MISSING_ARGUMENT + " Missing argument. " +
"The argument '" + missingArgumentName + "' is required."
, node.get_SourcePositionStart());
}
......
......@@ -21,10 +21,6 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTMethodDeclaration;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeLayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.MethodDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.MethodLayerSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -39,13 +35,13 @@ public class CheckMethodName implements CNNArchASTMethodDeclarationCoCo {
public void check(ASTMethodDeclaration node) {
String name = node.getName();
if (name.isEmpty() || !Character.isLowerCase(name.codePointAt(0))){
Log.error("0" + ErrorCodes.ILLEGAL_NAME_CODE + " Illegal name: " + name +
Log.error("0" + ErrorCodes.ILLEGAL_NAME + " Illegal name: " + name +
". All new variable and method names have to start with a lowercase letter. "
, node.get_SourcePositionStart());
}
if (methodNames.contains(name)){
Log.error("0" + ErrorCodes.DUPLICATED_NAME_CODE + " Duplicated method name. " +
Log.error("0" + ErrorCodes.DUPLICATED_NAME + " Duplicated method name. " +
"The name '" + name + "' is already used."
, node.get_SourcePositionStart());
}
......
......@@ -55,7 +55,7 @@ public class CheckMethodRecursion implements CNNArchASTMethodDeclarationCoCo {
if (method != null && !method.isPredefined() && !seenMethods.contains(method)) {
seenMethods.add(method);
if (startingMethod == method) {
Log.error("0" + ErrorCodes.RECURSION_ERROR_CODE + " Recursion is not allowed. " +
Log.error("0" + ErrorCodes.RECURSION_ERROR + " Recursion is not allowed. " +
"The method '" + startingMethod.getName() + "' creates a recursive cycle."
, startingMethod.getSourcePosition());
done = true;
......
......@@ -43,7 +43,7 @@ public class CheckUnknownIO implements CNNArchASTIOLayerCoCo {
}
if (ioDeclaration == null){
Log.error("0" + ErrorCodes.UNKNOWN_IO_CODE + " Unknown input or output name. " +
Log.error("0" + ErrorCodes.UNKNOWN_IO + " Unknown input or output name. " +
"The input or output '" + node.getName() + "' does not exist"
, node.get_SourcePositionStart());
}
......
......@@ -45,17 +45,17 @@ public class CheckVariableName implements CNNArchASTVariableCoCo {
private void checkForIllegalNames(ASTVariable node){
String name = node.getName();
if (name.isEmpty() || !Character.isLowerCase(name.codePointAt(0))){
Log.error("0" + ErrorCodes.ILLEGAL_NAME_CODE + " Illegal name: " + name +
Log.error("0" + ErrorCodes.ILLEGAL_NAME + " Illegal name: " + name +
". All new variable and method names have to start with a lowercase letter. "
, node.get_SourcePositionStart());
}
else if (name.equals(AllPredefinedVariables.TRUE_NAME) || name.equals(AllPredefinedVariables.FALSE_NAME)){
Log.error("0" + ErrorCodes.ILLEGAL_NAME_CODE + " Illegal name: " + name +
Log.error("0" + ErrorCodes.ILLEGAL_NAME + " Illegal name: " + name +
". No variable can be named 'true' or 'false'"
, node.get_SourcePositionStart());
}
else if (name.equals(AllPredefinedVariables.IF_NAME.toLowerCase())){
Log.error("0" + ErrorCodes.ILLEGAL_NAME_CODE + " Illegal name: " + name +
Log.error("0" + ErrorCodes.ILLEGAL_NAME + " Illegal name: " + name +
". No variable can be named 'if'"
, node.get_SourcePositionStart());
}
......@@ -80,7 +80,7 @@ public class CheckVariableName implements CNNArchASTVariableCoCo {
}
private void duplicationError(ASTVariable node){
Log.error("0" + ErrorCodes.DUPLICATED_NAME_CODE + " Duplicated variable name. " +
Log.error("0" + ErrorCodes.DUPLICATED_NAME + " Duplicated variable name. " +
"The name '" + node.getName() + "' is already used."
, node.get_SourcePositionStart());
}
......
......@@ -120,7 +120,7 @@ public class ArchitectureSymbol extends ArchitectureSymbolTOP {
public void checkParameters(){
for (VariableSymbol parameter : getParameters()){
if (!parameter.hasExpression()){
Log.error("0" + ErrorCodes.MISSING_VAR_VALUE_CODE + " Missing architecture argument. " +
Log.error("0" + ErrorCodes.MISSING_VAR_VALUE + " Missing architecture argument. " +
"The parameter '" + parameter.getName() + "' has no value.");
}
}
......
......@@ -198,21 +198,28 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
@Override
public void endVisit(ASTShape node) {
ShapeSymbol sym = (ShapeSymbol) node.getSymbol().get();
if (node.getDimensions().size() == 1){
ArchSimpleExpressionSymbol channels = (ArchSimpleExpressionSymbol) node.getDimensions().get(0).getSymbol().get();
sym.setChannels(channels);
}
else if (node.getDimensions().size() == 3){
ArchSimpleExpressionSymbol height = (ArchSimpleExpressionSymbol) node.getDimensions().get(ShapeSymbol.HEIGHT_INDEX - 1).getSymbol().get();
ArchSimpleExpressionSymbol width = (ArchSimpleExpressionSymbol) node.getDimensions().get(ShapeSymbol.WIDTH_INDEX - 1).getSymbol().get();
ArchSimpleExpressionSymbol channels = (ArchSimpleExpressionSymbol) node.getDimensions().get(ShapeSymbol.CHANNEL_INDEX - 1).getSymbol().get();
sym.setHeight(height);
sym.setWidth(width);
sym.setChannels(channels);
}
else {
//do nothing; will be checked in coco
List<ArchSimpleExpressionSymbol> dimensionList = new ArrayList<>(3);
for (int i = 0; i < node.getDimensions().size(); i++){
ASTDimensionArgument dimensionArg = node.getDimensions().get(i);
if (dimensionArg.getHeight().isPresent()){
sym.setHeightIndex(i);
ArchSimpleExpressionSymbol exp = (ArchSimpleExpressionSymbol) dimensionArg.getHeight().get().getSymbol().get();
dimensionList.add(exp);
}
else if (dimensionArg.getWidth().isPresent()){
sym.setWidthIndex(i);
ArchSimpleExpressionSymbol exp = (ArchSimpleExpressionSymbol) dimensionArg.getWidth().get().getSymbol().get();
dimensionList.add(exp);
}
else {
sym.setChannelIndex(i);
ArchSimpleExpressionSymbol exp = (ArchSimpleExpressionSymbol) dimensionArg.getChannels().get().getSymbol().get();
dimensionList.add(exp);
}
}
sym.setDimensionSymbols(dimensionList);
addToScopeAndLinkWithNode(sym, node);
}
......
......@@ -28,7 +28,7 @@ import de.se_rwth.commons.logging.Log;
import java.util.List;
import java.util.Optional;
import static de.monticore.lang.monticar.cnnarch.helper.ErrorCodes.ILLEGAL_ASSIGNMENT_CODE;
import static de.monticore.lang.monticar.cnnarch.helper.ErrorCodes.ILLEGAL_ASSIGNMENT;
public enum Constraints {
NUMBER {
......@@ -163,7 +163,28 @@ public enum Constraints {
@Override
protected String msgString() {
return AllPredefinedMethods.PADDING_VALID + " or " + AllPredefinedMethods.PADDING_SAME;
return AllPredefinedMethods.PADDING_VALID + ", "
+ AllPredefinedMethods.PADDING_SAME + " or "
+ AllPredefinedMethods.PADDING_NO_LOSS;
}
},
POOL_TYPE {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
Optional<String> optString= exp.getStringValue();
if (optString.isPresent()){
if (optString.get().equals(AllPredefinedMethods.POOL_MAX)
|| optString.get().equals(AllPredefinedMethods.POOL_AVG)){
return true;
}
}
return false;
}
@Override
protected String msgString() {
return AllPredefinedMethods.POOL_MAX + " or "
+ AllPredefinedMethods.POOL_AVG;
}
};
......@@ -201,7 +222,7 @@ public enum Constraints {
for (List<ArchSimpleExpressionSymbol> expList : exp.getElements().get()) {
for (ArchSimpleExpressionSymbol singleExp : expList) {
if (!isValid(singleExp)) {
Log.error("0" + ILLEGAL_ASSIGNMENT_CODE + " Illegal assignment of '" + printName(name) + "'. " +
Log.error("0" + ILLEGAL_ASSIGNMENT + " Illegal assignment of '" + printName(name) + "'. " +
"Expression must be " + msgString() + "."
, sourcePosition);
return false;
......
......@@ -354,7 +354,7 @@ public class MethodLayerSymbol extends LayerSymbol {
length = argLength;
}
else if (length != argLength) {
Log.error("0" + ErrorCodes.ILLEGAL_SEQUENCE_LENGTH_CODE + " Illegal sequence length. " +
Log.error("0" + ErrorCodes.ILLEGAL_SEQUENCE_LENGTH + " Illegal sequence length. " +
"Length is " + argLength + " but it should be " + length + " or not a sequence. " +
"All parallel sequences in the same method layer must be of the same size. "
, argument.getSourcePosition());
......@@ -419,7 +419,7 @@ public class MethodLayerSymbol extends LayerSymbol {
}
}
else if (argLength != 1 && argLength != serialLength){
Log.error("0" + ErrorCodes.ILLEGAL_SEQUENCE_LENGTH_CODE + " Illegal sequence length. " +
Log.error("0" + ErrorCodes.ILLEGAL_SEQUENCE_LENGTH + " Illegal sequence length. " +
"Length of sequence dimension "+ serialIndex +" is " + argLength + " but it should be " + serialLength + " or not a sequence. " +
"All serial sequences of the same paralle dimension in the same method layer must be of the same size. "
, getSourcePosition());
......
......@@ -67,29 +67,23 @@ abstract public class PredefinedMethodDeclaration extends MethodDeclarationSymbo
//check input for convolution and pooling
protected static void errorIfInputSmallerThanKernel(List<ShapeSymbol> inputShapes, MethodLayerSymbol layer){
if (!inputShapes.isEmpty()) {
Optional<Boolean> optGlobal = layer.getBooleanValue(AllPredefinedMethods.GLOBAL_NAME);
if (optGlobal.isPresent() && optGlobal.get()){
//do nothing
}
else{
int inputHeight = inputShapes.get(0).getHeight().get();
int inputWidth = inputShapes.get(0).getWidth().get();
int kernelHeight = layer.getIntTupleValue(AllPredefinedMethods.KERNEL_NAME).get().get(0);
int kernelWidth = layer.getIntTupleValue(AllPredefinedMethods.KERNEL_NAME).get().get(1);
if (kernelHeight > inputHeight || kernelWidth > inputWidth){
if (layer.getStringValue(AllPredefinedMethods.PADDING_NAME).equals(AllPredefinedMethods.PADDING_VALID)){
Log.error("0" + ErrorCodes.INVALID_LAYER_INPUT + " Invalid layer input. " +
"The input resolution is smaller than the kernel and the padding mode is 'valid'." +
"This would result in an output resolution of 0x0."
, layer.getSourcePosition());
}
else {
Log.warn("The input resolution is smaller than the kernel. " +
"This results in an output resolution of 1x1. " +
"If this warning appears multiple times, consider changing your architecture"
, layer.getSourcePosition());
}
int inputHeight = inputShapes.get(0).getHeight().get();
int inputWidth = inputShapes.get(0).getWidth().get();
int kernelHeight = layer.getIntTupleValue(AllPredefinedMethods.KERNEL_NAME).get().get(0);
int kernelWidth = layer.getIntTupleValue(AllPredefinedMethods.KERNEL_NAME).get().get(1);
if (kernelHeight > inputHeight || kernelWidth > inputWidth){
if (layer.getStringValue(AllPredefinedMethods.PADDING_NAME).equals(AllPredefinedMethods.PADDING_VALID)){
Log.error("0" + ErrorCodes.INVALID_LAYER_INPUT + " Invalid layer input. " +
"The input resolution is smaller than the kernel and the padding mode is 'valid'." +
"This would result in an output resolution of 0x0."
, layer.getSourcePosition());
}
else {
Log.warn("The input resolution is smaller than the kernel. " +
"This results in an output resolution of 1x1. " +
"If this warning appears multiple times, consider changing your architecture"
, layer.getSourcePosition());
}
}
}
......@@ -97,36 +91,21 @@ abstract public class PredefinedMethodDeclaration extends MethodDeclarationSymbo
//output shape function for convolution and pooling
protected static List<ShapeSymbol> computeConvAndPoolOutputShape(ShapeSymbol inputShape, MethodLayerSymbol method, int channels) {
Optional<Boolean> optGlobal = method.getBooleanValue(AllPredefinedMethods.GLOBAL_NAME);
if (optGlobal.isPresent() && optGlobal.get()){
//argument global is true which means the pooling is applied to the whole input and is flattened. kernel, stride and border_mode is ignored.
return computeOutputShapeForGlobalPooling(channels);
String borderModeSetting = method.getStringValue(AllPredefinedMethods.PADDING_NAME).get();
if (borderModeSetting.equals(AllPredefinedMethods.PADDING_SAME)){
return computeOutputShapeWithSamePadding(inputShape, method, channels);
}
else if (borderModeSetting.equals(AllPredefinedMethods.PADDING_VALID)){
return computeOutputShapeWithValidPadding(inputShape, method, channels);
}
else if (borderModeSetting.equals(AllPredefinedMethods.PADDING_NO_LOSS)){
return computeOutputShapeWithNoLossPadding(inputShape, method, channels);
}
else{
String borderModeSetting = method.getStringValue(AllPredefinedMethods.PADDING_NAME).get();
if (borderModeSetting.equals(AllPredefinedMethods.PADDING_SAME)){
return computeOutputShapeWithSamePadding(inputShape, method, channels);
}
else if (borderModeSetting.equals(AllPredefinedMethods.PADDING_VALID)){
return computeOutputShapeWithValidPadding(inputShape, method, channels);
}
else if (borderModeSetting.equals(AllPredefinedMethods.PADDING_NO_LOSS)){
return computeOutputShapeWithNoLossPadding(inputShape, method, channels);
}
else{
throw new IllegalStateException("border_mode is " + borderModeSetting + ". This should never happen.");
}
throw new IllegalStateException("border_mode is " + borderModeSetting + ". This should never happen.");
}
}
private static List<ShapeSymbol> computeOutputShapeForGlobalPooling(int channels){
return Collections.singletonList(new ShapeSymbol.Builder()
.height(1)
.width(1)
.channels(channels)
.build());
}
//padding with border_mode=valid, no padding
private static List<ShapeSymbol> computeOutputShapeWithValidPadding(ShapeSymbol inputShape, MethodLayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue(AllPredefinedMethods.STRIDE_NAME).get().get(0);
......
......@@ -30,75 +30,59 @@ public class ShapeSymbol extends CommonSymbol {
public static final ShapeKind KIND = new ShapeKind();
public static final int BATCH_SIZE_INDEX = 0;
public static final int HEIGHT_INDEX = 1;
public static final int WIDTH_INDEX = 2;
public static final int CHANNEL_INDEX = 3;
private int channelIndex = -1;
private int heightIndex = -1;
private int widthIndex = -1;
private List<ArchSimpleExpressionSymbol> dimensions =
Arrays.asList(ArchSimpleExpressionSymbol.of(1),
ArchSimpleExpressionSymbol.of(1),
ArchSimpleExpressionSymbol.of(1),
ArchSimpleExpressionSymbol.of(1));
private List<ArchSimpleExpressionSymbol> dimensions = new ArrayList<>();
public ShapeSymbol() {
super("", KIND);
}
public ArchSimpleExpressionSymbol getBatchSizeSymbol() {
return dimensions.get(BATCH_SIZE_INDEX);
public int getHeightIndex() {
return heightIndex;
}
public void setBatchSize(int batchSize) {
dimensions.get(BATCH_SIZE_INDEX).reset();
dimensions.get(BATCH_SIZE_INDEX).setValue(batchSize);
dimensions.get(BATCH_SIZE_INDEX).setMathExpression(null);
protected void setHeightIndex(int heightIndex) {