Commit d1b7e652 authored by Christian Fuß's avatar Christian Fuß

changed size parameter for OneHot layer to automatically use architecture...

changed size parameter for OneHot layer to automatically use architecture output size, if not specified elsewise.
parent 51c6e628
......@@ -39,10 +39,15 @@ public class OneHot extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
channels = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
if(layer.getOutputElement().get() instanceof IOSymbol && layer.getOutputElement().get().isOutput()) {
channels = ((IOSymbol) layer.getOutputElement().get()).getDefinition().getType().getChannels();
}else{
channels = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get())
.channels(channels)
.height(1)
.width(1)
.elementType("0", "1")
......
......@@ -55,5 +55,6 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
Concatenate() ->
FullyConnected(units=10) ->
Softmax() ->
OneHot() ->
predictions;
}
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment