Commit 65062aef authored by Christian Fuß's avatar Christian Fuß

fixed a problem with inferring size for OneHotLayer, when no 'size' parameter is given

parent 844877c0
Pipeline #155402 failed with stages
in 3 minutes and 22 seconds
......@@ -162,13 +162,8 @@ public class ArchitectureElementData {
}
public int getSize(){
if(getElement().isOutput()) {
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get();
}else{
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get();
}
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get();
}
@Nullable
......
<#assign size = element.size>
${element.name} = mx.symbol.one_hot(data=${element.inputs[0]},
indices=mx.symbol.argmax(data=${element.inputs[0]}, axis=1), depth=${size})
<#include "OutputShape.ftl">
\ No newline at end of file
<#assign size = element.size>
${element.name} = mx.symbol.one_hot(data=${element.inputs[0]},
indices=mx.symbol.argmax(data=${element.inputs[0]}, axis=1), depth=${element.element.outputTypes[0].dimensions[0]})
<#include "OutputShape.ftl">
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment