Commit f259b302 authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Added support for ArgMax in CNNPredictor

parent 6ccdf121
...@@ -53,7 +53,7 @@ public: ...@@ -53,7 +53,7 @@ public:
MXPredGetOutputShape(handle, output_index, &shape, &shape_len); MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
size = 1; size = 1;
for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i]; for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
assert(size == out_${variable}.size()); assert(out_${variable}.size() == 1 || size == out_${variable}.size());
MXPredGetOutput(handle, output_index, &(out_${variable}[0]), out_${variable}.size()); MXPredGetOutput(handle, output_index, &(out_${variable}[0]), out_${variable}.size());
</#list> </#list>
......
...@@ -48,6 +48,10 @@ ...@@ -48,6 +48,10 @@
_predictor_${networkInstruction?index}_.predict(${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")}, ${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")}); _predictor_${networkInstruction?index}_.predict(${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")}, ${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")});
<#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName> <#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName>
<#if tc.getNameWithoutIndex(outputName) == tc.outputName> <#if tc.getNameWithoutIndex(outputName) == tc.outputName>
<#if tc.endsWithArgmax(networkInstruction.body)>
std::vector<float>::iterator maxElement = std::max_element(${outputName}.begin(), ${outputName}.end());
${outputName} = std::vector<float>{static_cast<float>(std::distance(${outputName}.begin(), maxElement))};
</#if>
vector<float> out = ${outputName}; vector<float> out = ${outputName};
</#if> </#if>
</#list> </#list>
...@@ -85,6 +89,14 @@ ...@@ -85,6 +89,14 @@
} }
<#else> <#else>
_predictor_${networkInstruction?index}_.predict(${tc.join(tc.getStreamInputNames(networkInstruction.body, true), ", ")}, ${tc.join(tc.getStreamOutputNames(networkInstruction.body, true), ", ")}); _predictor_${networkInstruction?index}_.predict(${tc.join(tc.getStreamInputNames(networkInstruction.body, true), ", ")}, ${tc.join(tc.getStreamOutputNames(networkInstruction.body, true), ", ")});
<#list tc.getStreamOutputNames(networkInstruction.body, true) as outputName>
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
<#if tc.endsWithArgmax(networkInstruction.body)>
std::vector<float>::iterator maxElement = std::max_element(${outputName}.begin(), ${outputName}.end());
${outputName} = std::vector<float>{static_cast<float>(std::distance(${outputName}.begin(), maxElement))};
</#if>
</#if>
</#list>
</#if> </#if>
</#list> </#list>
......
...@@ -44,7 +44,7 @@ public: ...@@ -44,7 +44,7 @@ public:
MXPredGetOutputShape(handle, output_index, &shape, &shape_len); MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
size = 1; size = 1;
for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i]; for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
assert(size == out_predictions_.size()); assert(out_predictions_.size() == 1 || size == out_predictions_.size());
MXPredGetOutput(handle, output_index, &(out_predictions_[0]), out_predictions_.size()); MXPredGetOutput(handle, output_index, &(out_predictions_[0]), out_predictions_.size());
} }
......
...@@ -44,7 +44,7 @@ public: ...@@ -44,7 +44,7 @@ public:
MXPredGetOutputShape(handle, output_index, &shape, &shape_len); MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
size = 1; size = 1;
for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i]; for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
assert(size == out_softmax_.size()); assert(out_softmax_.size() == 1 || size == out_softmax_.size());
MXPredGetOutput(handle, output_index, &(out_softmax_[0]), out_softmax_.size()); MXPredGetOutput(handle, output_index, &(out_softmax_[0]), out_softmax_.size());
} }
......
...@@ -44,7 +44,7 @@ public: ...@@ -44,7 +44,7 @@ public:
MXPredGetOutputShape(handle, output_index, &shape, &shape_len); MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
size = 1; size = 1;
for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i]; for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
assert(size == out_predictions_.size()); assert(out_predictions_.size() == 1 || size == out_predictions_.size());
MXPredGetOutput(handle, output_index, &(out_predictions_[0]), out_predictions_.size()); MXPredGetOutput(handle, output_index, &(out_predictions_[0]), out_predictions_.size());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment