From a132a0faf6305aa711b1beb5f894dbef197d8695 Mon Sep 17 00:00:00 2001
From: Sebastian Nickels <sn1c@protonmail.ch>
Date: Fri, 6 Sep 2019 03:15:28 +0200
Subject: [PATCH] Updated for NetworkInstructionSymbol

---
 .../resources/templates/gluon/CNNCreator.ftl  | 33 +++++--------------
 src/main/resources/templates/gluon/CNNNet.ftl | 33 +++++--------------
 2 files changed, 17 insertions(+), 49 deletions(-)

diff --git a/src/main/resources/templates/gluon/CNNCreator.ftl b/src/main/resources/templates/gluon/CNNCreator.ftl
index 55bab2ca..2414a4be 100644
--- a/src/main/resources/templates/gluon/CNNCreator.ftl
+++ b/src/main/resources/templates/gluon/CNNCreator.ftl
@@ -2,15 +2,9 @@ import mxnet as mx
 import logging
 import os
 
-<#list tc.architecture.streams as stream>
-<#if stream.isTrainable()>
-from CNNNet_${tc.fullArchitectureName} import Net_${stream?index}
-</#if>
-</#list>
-
-<#list tc.architecture.unrolls as unroll>
-<#if unroll.body.isTrainable()>
-from CNNNet_${tc.fullArchitectureName} import Net_${tc.architecture.streams?size + unroll?index}
+<#list tc.architecture.networkInstructions as networkInstruction>
+<#if networkInstruction.body.isTrainable()>
+from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index}
 </#if>
 </#list>
 
@@ -58,21 +52,12 @@ class ${tc.fileNameWithoutEnding}:
         return earliestLastEpoch
 
     def construct(self, context, data_mean=None, data_std=None):
-<#list tc.architecture.streams as stream>
-<#if stream.isTrainable()>
-        self.networks[${stream?index}] = Net_${stream?index}(data_mean=data_mean, data_std=data_std)
-        self.networks[${stream?index}].collect_params().initialize(self.weight_initializer, ctx=context)
-        self.networks[${stream?index}].hybridize()
-        self.networks[${stream?index}](<#list tc.getStreamInputDimensions(stream) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>)
-</#if>
-</#list>
-
-<#list tc.architecture.unrolls as unroll>
-<#if unroll.body.isTrainable()>
-        self.networks[${tc.architecture.streams?size + unroll?index}] = Net_${tc.architecture.streams?size + unroll?index}(data_mean=data_mean, data_std=data_std)
-        self.networks[${tc.architecture.streams?size + unroll?index}].collect_params().initialize(self.weight_initializer, ctx=context)
-        self.networks[${tc.architecture.streams?size + unroll?index}].hybridize()
-        self.networks[${tc.architecture.streams?size + unroll?index}](<#list tc.getStreamInputDimensions(unroll.body) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>)
+<#list tc.architecture.networkInstructions as networkInstruction>
+<#if networkInstruction.body.isTrainable()>
+        self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std)
+        self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, ctx=context)
+        self.networks[${networkInstruction?index}].hybridize()
+        self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>)
 </#if>
 </#list>
 
diff --git a/src/main/resources/templates/gluon/CNNNet.ftl b/src/main/resources/templates/gluon/CNNNet.ftl
index c62f1342..cd42d7d1 100644
--- a/src/main/resources/templates/gluon/CNNNet.ftl
+++ b/src/main/resources/templates/gluon/CNNNet.ftl
@@ -78,35 +78,18 @@ class NoNormalization(gluon.HybridBlock):
         return x
 
 
-<#list tc.architecture.streams as stream>
-<#if stream.isTrainable()>
-class Net_${stream?index}(gluon.HybridBlock):
+<#list tc.architecture.networkInstructions as networkInstruction>
+<#if networkInstruction.body.isTrainable()>
+class Net_${networkInstruction?index}(gluon.HybridBlock):
     def __init__(self, data_mean=None, data_std=None, **kwargs):
-        super(Net_${stream?index}, self).__init__(**kwargs)
+        super(Net_${networkInstruction?index}, self).__init__(**kwargs)
         self.last_layers = {}
         with self.name_scope():
-${tc.include(stream, "ARCHITECTURE_DEFINITION")}
+${tc.include(networkInstruction.body, "ARCHITECTURE_DEFINITION")}
 
-    def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(stream), ", ")}):
-${tc.include(stream, "FORWARD_FUNCTION")}
-        return ${tc.join(tc.getStreamOutputNames(stream), ", ")}
-
-</#if>
-</#list>
-
-
-<#list tc.architecture.unrolls as unroll>
-<#if unroll.body.isTrainable()>
-class Net_${unroll?index}(gluon.HybridBlock):
-    def __init__(self, data_mean=None, data_std=None, **kwargs):
-        super(Net_${unroll?index}, self).__init__(**kwargs)
-        self.last_layers = {}
-        with self.name_scope():
-${tc.include(unroll.body, "ARCHITECTURE_DEFINITION")}
-
-    def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(unroll.body), ", ")}):
-${tc.include(unroll.body, "FORWARD_FUNCTION")}
-        return ${tc.join(tc.getStreamOutputNames(unroll.body), ", ")}
+    def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")}):
+${tc.include(networkInstruction.body, "FORWARD_FUNCTION")}
+        return ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")}
 
 </#if>
 </#list>
-- 
GitLab