Commit 128601f9 authored by Christian Fuß's avatar Christian Fuß
Browse files

adjusted Reshape layer to abstract batch size

parent 63534968
Pipeline #211278 passed with stages
in 16 minutes and 46 seconds
......@@ -44,68 +44,56 @@ public class Reshape extends PredefinedLayerDeclaration {
int height = -1;
int width = -1;
// shape.get(0) will be the batch size, which we ignore here
if(shape.size() >=4){
width = shape.get(3);
}else{
width = layer.getInputTypes().get(0).getWidth();
if (shape.size() >= 3) {
width = shape.get(2);
}
if(shape.size() >=3){
height = shape.get(2);
}else{
height = layer.getInputTypes().get(0).getHeight();
if (shape.size() >= 2) {
height = shape.get(1);
}
if(shape.size() >=2){
channels = shape.get(1);
}else{
Log.error("0" + ErrorCodes.ILLEGAL_PARAMETER_VALUE + "\"Shape\" argument needs to contain at least two entries (batchSize and channels)"
if (shape.size() >= 1) {
channels = shape.get(0);
} else {
Log.error("0" + ErrorCodes.ILLEGAL_PARAMETER_VALUE + "\"Shape\" argument needs to contain at least one entry"
, layer.getSourcePosition());
}
int totalSize = layer.getInputTypes().get(0).getChannels() * layer.getInputTypes().get(0).getHeight() * layer.getInputTypes().get(0).getWidth();
int newTotalSize = shape.stream().reduce(1, (x,y) -> x*y);
int newTotalSize = shape.stream().reduce(1, (x, y) -> x * y);
if(totalSize != newTotalSize && newTotalSize != 0){
if (totalSize != newTotalSize && newTotalSize != 0) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + "The input of Reshape layer cannot be reshaped to the given shape. "
+ "Source and target shape have a different amount of total values", layer.getSourcePosition());
}
if(newTotalSize != 0) {
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(channels)
.height(height)
.width(width)
.elementType(layer.getInputTypes().get(0).getDomain())
.build());
}else{
if(height == -1){
channels = newTotalSize;
}else if(width == -1){
if(height == 0){
height = newTotalSize/channels;
}else if(channels == 0){
channels = newTotalSize/height;
}
}else{
if(width == 0){
width = newTotalSize/(channels*height);
}else if(height == 0){
height = newTotalSize/(channels*width);
}else if(channels == 0){
channels = newTotalSize/(width*height);
}
if (newTotalSize != 0) {
if (width != -1) {
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(channels)
.height(height)
.width(width)
.elementType(layer.getInputTypes().get(0).getDomain())
.build());
} else if (height != -1) {
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(channels)
.height(height)
.elementType(layer.getInputTypes().get(0).getDomain())
.build());
} else{
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(channels)
.elementType(layer.getInputTypes().get(0).getDomain())
.build());
}
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(channels)
.height(height)
.width(width)
.elementType(layer.getInputTypes().get(0).getDomain())
.build());
}
return Collections.emptyList();
}
@Override
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment