From c2f2d4115fb39a1b7d1be3382893a7bdf955b999 Mon Sep 17 00:00:00 2001
From: Christian Fuss <chrifuss@freenet.de>
Date: Fri, 30 Aug 2019 17:44:00 +0200
Subject: [PATCH] fixed bug leading to multiple equivalent architecture outputs
 for unrolls

---
 .../templates/gluon/CNNPredictor.ftl          | 28 ++++++++++---------
 .../resources/templates/gluon/execute.ftl     |  3 +-
 2 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/src/main/resources/templates/gluon/CNNPredictor.ftl b/src/main/resources/templates/gluon/CNNPredictor.ftl
index cad5fadd..9df30c91 100644
--- a/src/main/resources/templates/gluon/CNNPredictor.ftl
+++ b/src/main/resources/templates/gluon/CNNPredictor.ftl
@@ -117,34 +117,35 @@ public:
 </#list>
 
 <#list tc.architecture.unrolls as unroll>
-<#if unroll.isTrainable()>
-class ${tc.fileNameWithoutEnding}_${unroll?index}{
+<#list unroll.getBodiesForAllTimesteps() as body>
+<#if body.isTrainable()>
+class ${tc.fileNameWithoutEnding}_${tc.architecture.streams?size + body?index}{
 public:
-    const std::string json_file = "model/${tc.componentName}/model_${unroll?index}_newest-symbol.json";
-    const std::string param_file = "model/${tc.componentName}/model_${unroll?index}_newest-0000.params";
+    const std::string json_file = "model/${tc.componentName}/model_${tc.architecture.streams?size + body?index}_newest-symbol.json";
+    const std::string param_file = "model/${tc.componentName}/model_${tc.architecture.streams?size + body?index}_newest-0000.params";
     const std::vector<std::string> input_keys = {
-<#if tc.getUnrollInputNames(unroll)?size == 1>
+<#if tc.getStreamInputNames(body)?size == 1>
         "data"
 <#else>
-        <#list tc.getUnrollInputNames(unroll) as variable>"data${variable?index}"<#sep>, </#list>
+        <#list tc.getStreamInputNames(body) as variable>"data${variable?index}"<#sep>, </#list>
 </#if>
     };
-    const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getUnrollInputDimensions(unroll) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
+    const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getStreamInputDimensions(body) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
     const bool use_gpu = false;
 
     PredictorHandle handle;
 
-    explicit ${tc.fileNameWithoutEnding}_${unroll?index}(){
+    explicit ${tc.fileNameWithoutEnding}_${tc.architecture.streams?size + body?index}(){
         init(json_file, param_file, input_keys, input_shapes, use_gpu);
     }
 
-    ~${tc.fileNameWithoutEnding}_${unroll?index}(){
+    ~${tc.fileNameWithoutEnding}_${tc.architecture.streams?size + body?index}(){
         if(handle) MXPredFree(handle);
     }
 
-    void predict(${tc.join(tc.getUnrollInputNames(unroll), ", ", "const std::vector<float> &in_", "")},
-                 ${tc.join(tc.getUnrollOutputNames(unroll), ", ", "std::vector<float> &out_", "")}){
-<#list tc.getUnrollInputNames(unroll) as variable>
+    void predict(${tc.join(tc.getStreamInputNames(body), ", ", "const std::vector<float> &in_", "")},
+                 ${tc.join(tc.getStreamOutputNames(body), ", ", "std::vector<float> &out_", "")}){
+<#list tc.getStreamInputNames(body) as variable>
         MXPredSetInput(handle, input_keys[${variable?index}].c_str(), in_${variable}.data(), static_cast<mx_uint>(in_${variable}.size()));
 </#list>
 
@@ -155,7 +156,7 @@ public:
         mx_uint shape_len;
         size_t size;
 
-<#list tc.getUnrollOutputNames(unroll) as variable>
+<#list tc.getStreamOutputNames(body) as variable>
         output_index = ${variable?index?c};
         MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
         size = 1;
@@ -222,6 +223,7 @@ public:
 };
 </#if>
 </#list>
+</#list>
 
 
 #endif // ${tc.fileNameWithoutEnding?upper_case}
diff --git a/src/main/resources/templates/gluon/execute.ftl b/src/main/resources/templates/gluon/execute.ftl
index cb1418ae..3913e00e 100644
--- a/src/main/resources/templates/gluon/execute.ftl
+++ b/src/main/resources/templates/gluon/execute.ftl
@@ -6,7 +6,8 @@
 <#list tc.getLayerVariableMembers("1")?keys as member>
     vector<float> ${member}(${tc.join(tc.getLayerVariableMembers("1")[member], " * ")})
 </#list>
-<#list tc.architecture.outputs as output>
+
+<#list tc.getNoDuplicateArchitectureOutputs() as output>
 <#if tc.getName(output)??>
     vector<float> ${tc.getName(output)}(${tc.join(output.ioDeclaration.type.dimensions, " * ")});
 </#if>
-- 
GitLab