Commit 0bef9f58 authored by Sebastian Nickels's avatar Sebastian Nickels

Added flatten to FullyConnected layer

parent 0eb98f6f
Pipeline #173122 passed with stages
in 19 minutes and 44 seconds
......@@ -72,6 +72,7 @@ public class AllPredefinedLayers {
public static final String INPUT_DIM_NAME = "input_dim";
public static final String OUTPUT_DIM_NAME = "output_dim";
public static final String BIDIRECTIONAL_NAME = "bidirectional";
public static final String FLATTEN_PARAMETER_NAME = "flatten";
public static final String BEAMSEARCH_MAX_LENGTH = "max_length";
public static final String BEAMSEARCH_WIDTH_NAME = "width";
......
......@@ -35,12 +35,45 @@ public class FullyConnected extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
return Collections.singletonList(new ArchTypeSymbol.Builder()
.height(1)
.width(1)
.channels(layer.getIntValue(AllPredefinedLayers.UNITS_NAME).get())
.elementType("-oo", "oo")
.build());
boolean flatten = layer.getBooleanValue(AllPredefinedLayers.FLATTEN_PARAMETER_NAME).get();
int units = layer.getIntValue(AllPredefinedLayers.UNITS_NAME).get();
if (flatten) {
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(units)
.height(1)
.width(1)
.elementType("-oo", "oo")
.build());
}
else {
ArchTypeSymbol inputType = layer.getInputTypes().get(0);
if (inputType.getWidth() == 1) {
if (inputType.getHeight() == 1) {
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(units)
.height(1)
.width(1)
.elementType("-oo", "oo")
.build());
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(inputType.getChannels())
.height(units)
.width(1)
.elementType("-oo", "oo")
.build());
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(inputType.getChannels())
.height(inputType.getHeight())
.width(units)
.elementType("-oo", "oo")
.build());
}
}
@Override
......@@ -59,6 +92,11 @@ public class FullyConnected extends PredefinedLayerDeclaration {
.name(AllPredefinedLayers.NOBIAS_NAME)
.constraints(Constraints.BOOLEAN)
.defaultValue(false)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.FLATTEN_PARAMETER_NAME)
.constraints(Constraints.BOOLEAN)
.defaultValue(true)
.build()));
declaration.setParameters(parameters);
return declaration;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment