From c2f2d4115fb39a1b7d1be3382893a7bdf955b999 Mon Sep 17 00:00:00 2001 From: Christian Fuss <chrifuss@freenet.de> Date: Fri, 30 Aug 2019 17:44:00 +0200 Subject: [PATCH] fixed bug leading to multiple equivalent architecture outputs for unrolls --- .../templates/gluon/CNNPredictor.ftl | 28 ++++++++++--------- .../resources/templates/gluon/execute.ftl | 3 +- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/main/resources/templates/gluon/CNNPredictor.ftl b/src/main/resources/templates/gluon/CNNPredictor.ftl index cad5fadd..9df30c91 100644 --- a/src/main/resources/templates/gluon/CNNPredictor.ftl +++ b/src/main/resources/templates/gluon/CNNPredictor.ftl @@ -117,34 +117,35 @@ public: </#list> <#list tc.architecture.unrolls as unroll> -<#if unroll.isTrainable()> -class ${tc.fileNameWithoutEnding}_${unroll?index}{ +<#list unroll.getBodiesForAllTimesteps() as body> +<#if body.isTrainable()> +class ${tc.fileNameWithoutEnding}_${tc.architecture.streams?size + body?index}{ public: - const std::string json_file = "model/${tc.componentName}/model_${unroll?index}_newest-symbol.json"; - const std::string param_file = "model/${tc.componentName}/model_${unroll?index}_newest-0000.params"; + const std::string json_file = "model/${tc.componentName}/model_${tc.architecture.streams?size + body?index}_newest-symbol.json"; + const std::string param_file = "model/${tc.componentName}/model_${tc.architecture.streams?size + body?index}_newest-0000.params"; const std::vector<std::string> input_keys = { -<#if tc.getUnrollInputNames(unroll)?size == 1> +<#if tc.getStreamInputNames(body)?size == 1> "data" <#else> - <#list tc.getUnrollInputNames(unroll) as variable>"data${variable?index}"<#sep>, </#list> + <#list tc.getStreamInputNames(body) as variable>"data${variable?index}"<#sep>, </#list> </#if> }; - const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getUnrollInputDimensions(unroll) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>}; + const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getStreamInputDimensions(body) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>}; const bool use_gpu = false; PredictorHandle handle; - explicit ${tc.fileNameWithoutEnding}_${unroll?index}(){ + explicit ${tc.fileNameWithoutEnding}_${tc.architecture.streams?size + body?index}(){ init(json_file, param_file, input_keys, input_shapes, use_gpu); } - ~${tc.fileNameWithoutEnding}_${unroll?index}(){ + ~${tc.fileNameWithoutEnding}_${tc.architecture.streams?size + body?index}(){ if(handle) MXPredFree(handle); } - void predict(${tc.join(tc.getUnrollInputNames(unroll), ", ", "const std::vector<float> &in_", "")}, - ${tc.join(tc.getUnrollOutputNames(unroll), ", ", "std::vector<float> &out_", "")}){ -<#list tc.getUnrollInputNames(unroll) as variable> + void predict(${tc.join(tc.getStreamInputNames(body), ", ", "const std::vector<float> &in_", "")}, + ${tc.join(tc.getStreamOutputNames(body), ", ", "std::vector<float> &out_", "")}){ +<#list tc.getStreamInputNames(body) as variable> MXPredSetInput(handle, input_keys[${variable?index}].c_str(), in_${variable}.data(), static_cast<mx_uint>(in_${variable}.size())); </#list> @@ -155,7 +156,7 @@ public: mx_uint shape_len; size_t size; -<#list tc.getUnrollOutputNames(unroll) as variable> +<#list tc.getStreamOutputNames(body) as variable> output_index = ${variable?index?c}; MXPredGetOutputShape(handle, output_index, &shape, &shape_len); size = 1; @@ -222,6 +223,7 @@ public: }; </#if> </#list> +</#list> #endif // ${tc.fileNameWithoutEnding?upper_case} diff --git a/src/main/resources/templates/gluon/execute.ftl b/src/main/resources/templates/gluon/execute.ftl index cb1418ae..3913e00e 100644 --- a/src/main/resources/templates/gluon/execute.ftl +++ b/src/main/resources/templates/gluon/execute.ftl @@ -6,7 +6,8 @@ <#list tc.getLayerVariableMembers("1")?keys as member> vector<float> ${member}(${tc.join(tc.getLayerVariableMembers("1")[member], " * ")}) </#list> -<#list tc.architecture.outputs as output> + +<#list tc.getNoDuplicateArchitectureOutputs() as output> <#if tc.getName(output)??> vector<float> ${tc.getName(output)}(${tc.join(output.ioDeclaration.type.dimensions, " * ")}); </#if> -- GitLab