Skip to content
Snippets Groups Projects
Commit ae55fd73 authored by Christian Fuß's avatar Christian Fuß
Browse files

fixed problem with expected size of argmax outputs being too high in C++ code

parent 87aff72b
No related branches found
No related tags found
1 merge request!23Added Unroll-related features and layers
Pipeline #206452 failed
......@@ -386,6 +386,10 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return AllAttentionModels.getAttentionModels().contains(getComponentName());
}
public boolean isArchitectureOutput(String element){
return getArchitectureOutputs().contains(element.replace("1000000", "0"));
}
public int getBeamSearchMaxLength(UnrollInstructionSymbol unroll){
return unroll.getIntValue(AllPredefinedLayers.MAX_LENGTH_NAME).get();
}
......
......@@ -52,7 +52,10 @@ public:
output_index = ${variable?index?c};
MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
size = 1;
<#if !(tc.isArchitectureOutput(variable) && tc.endsWithArgmax(networkInstruction.body))>
for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
</#if>
assert(size == out_${variable}.size());
MXPredGetOutput(handle, output_index, &(out_${variable}[0]), out_${variable}.size());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment