Skip to content
Snippets Groups Projects
Commit c2f2d411 authored by Christian Fuß's avatar Christian Fuß
Browse files

fixed bug leading to multiple equivalent architecture outputs for unrolls

parent f4ecf597
No related branches found
No related tags found
1 merge request!23Added Unroll-related features and layers
Pipeline #178367 failed
......@@ -117,34 +117,35 @@ public:
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.isTrainable()>
class ${tc.fileNameWithoutEnding}_${unroll?index}{
<#list unroll.getBodiesForAllTimesteps() as body>
<#if body.isTrainable()>
class ${tc.fileNameWithoutEnding}_${tc.architecture.streams?size + body?index}{
public:
const std::string json_file = "model/${tc.componentName}/model_${unroll?index}_newest-symbol.json";
const std::string param_file = "model/${tc.componentName}/model_${unroll?index}_newest-0000.params";
const std::string json_file = "model/${tc.componentName}/model_${tc.architecture.streams?size + body?index}_newest-symbol.json";
const std::string param_file = "model/${tc.componentName}/model_${tc.architecture.streams?size + body?index}_newest-0000.params";
const std::vector<std::string> input_keys = {
<#if tc.getUnrollInputNames(unroll)?size == 1>
<#if tc.getStreamInputNames(body)?size == 1>
"data"
<#else>
<#list tc.getUnrollInputNames(unroll) as variable>"data${variable?index}"<#sep>, </#list>
<#list tc.getStreamInputNames(body) as variable>"data${variable?index}"<#sep>, </#list>
</#if>
};
const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getUnrollInputDimensions(unroll) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getStreamInputDimensions(body) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
const bool use_gpu = false;
PredictorHandle handle;
explicit ${tc.fileNameWithoutEnding}_${unroll?index}(){
explicit ${tc.fileNameWithoutEnding}_${tc.architecture.streams?size + body?index}(){
init(json_file, param_file, input_keys, input_shapes, use_gpu);
}
~${tc.fileNameWithoutEnding}_${unroll?index}(){
~${tc.fileNameWithoutEnding}_${tc.architecture.streams?size + body?index}(){
if(handle) MXPredFree(handle);
}
void predict(${tc.join(tc.getUnrollInputNames(unroll), ", ", "const std::vector<float> &in_", "")},
${tc.join(tc.getUnrollOutputNames(unroll), ", ", "std::vector<float> &out_", "")}){
<#list tc.getUnrollInputNames(unroll) as variable>
void predict(${tc.join(tc.getStreamInputNames(body), ", ", "const std::vector<float> &in_", "")},
${tc.join(tc.getStreamOutputNames(body), ", ", "std::vector<float> &out_", "")}){
<#list tc.getStreamInputNames(body) as variable>
MXPredSetInput(handle, input_keys[${variable?index}].c_str(), in_${variable}.data(), static_cast<mx_uint>(in_${variable}.size()));
</#list>
......@@ -155,7 +156,7 @@ public:
mx_uint shape_len;
size_t size;
<#list tc.getUnrollOutputNames(unroll) as variable>
<#list tc.getStreamOutputNames(body) as variable>
output_index = ${variable?index?c};
MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
size = 1;
......@@ -222,6 +223,7 @@ public:
};
</#if>
</#list>
</#list>
#endif // ${tc.fileNameWithoutEnding?upper_case}
......@@ -6,7 +6,8 @@
<#list tc.getLayerVariableMembers("1")?keys as member>
vector<float> ${member}(${tc.join(tc.getLayerVariableMembers("1")[member], " * ")})
</#list>
<#list tc.architecture.outputs as output>
<#list tc.getNoDuplicateArchitectureOutputs() as output>
<#if tc.getName(output)??>
vector<float> ${tc.getName(output)}(${tc.join(output.ioDeclaration.type.dimensions, " * ")});
</#if>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment