From a132a0faf6305aa711b1beb5f894dbef197d8695 Mon Sep 17 00:00:00 2001 From: Sebastian Nickels <sn1c@protonmail.ch> Date: Fri, 6 Sep 2019 03:15:28 +0200 Subject: [PATCH] Updated for NetworkInstructionSymbol --- .../resources/templates/gluon/CNNCreator.ftl | 33 +++++-------------- src/main/resources/templates/gluon/CNNNet.ftl | 33 +++++-------------- 2 files changed, 17 insertions(+), 49 deletions(-) diff --git a/src/main/resources/templates/gluon/CNNCreator.ftl b/src/main/resources/templates/gluon/CNNCreator.ftl index 55bab2ca..2414a4be 100644 --- a/src/main/resources/templates/gluon/CNNCreator.ftl +++ b/src/main/resources/templates/gluon/CNNCreator.ftl @@ -2,15 +2,9 @@ import mxnet as mx import logging import os -<#list tc.architecture.streams as stream> -<#if stream.isTrainable()> -from CNNNet_${tc.fullArchitectureName} import Net_${stream?index} -</#if> -</#list> - -<#list tc.architecture.unrolls as unroll> -<#if unroll.body.isTrainable()> -from CNNNet_${tc.fullArchitectureName} import Net_${tc.architecture.streams?size + unroll?index} +<#list tc.architecture.networkInstructions as networkInstruction> +<#if networkInstruction.body.isTrainable()> +from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index} </#if> </#list> @@ -58,21 +52,12 @@ class ${tc.fileNameWithoutEnding}: return earliestLastEpoch def construct(self, context, data_mean=None, data_std=None): -<#list tc.architecture.streams as stream> -<#if stream.isTrainable()> - self.networks[${stream?index}] = Net_${stream?index}(data_mean=data_mean, data_std=data_std) - self.networks[${stream?index}].collect_params().initialize(self.weight_initializer, ctx=context) - self.networks[${stream?index}].hybridize() - self.networks[${stream?index}](<#list tc.getStreamInputDimensions(stream) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>) -</#if> -</#list> - -<#list tc.architecture.unrolls as unroll> -<#if unroll.body.isTrainable()> - self.networks[${tc.architecture.streams?size + unroll?index}] = Net_${tc.architecture.streams?size + unroll?index}(data_mean=data_mean, data_std=data_std) - self.networks[${tc.architecture.streams?size + unroll?index}].collect_params().initialize(self.weight_initializer, ctx=context) - self.networks[${tc.architecture.streams?size + unroll?index}].hybridize() - self.networks[${tc.architecture.streams?size + unroll?index}](<#list tc.getStreamInputDimensions(unroll.body) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>) +<#list tc.architecture.networkInstructions as networkInstruction> +<#if networkInstruction.body.isTrainable()> + self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std) + self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, ctx=context) + self.networks[${networkInstruction?index}].hybridize() + self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>) </#if> </#list> diff --git a/src/main/resources/templates/gluon/CNNNet.ftl b/src/main/resources/templates/gluon/CNNNet.ftl index c62f1342..cd42d7d1 100644 --- a/src/main/resources/templates/gluon/CNNNet.ftl +++ b/src/main/resources/templates/gluon/CNNNet.ftl @@ -78,35 +78,18 @@ class NoNormalization(gluon.HybridBlock): return x -<#list tc.architecture.streams as stream> -<#if stream.isTrainable()> -class Net_${stream?index}(gluon.HybridBlock): +<#list tc.architecture.networkInstructions as networkInstruction> +<#if networkInstruction.body.isTrainable()> +class Net_${networkInstruction?index}(gluon.HybridBlock): def __init__(self, data_mean=None, data_std=None, **kwargs): - super(Net_${stream?index}, self).__init__(**kwargs) + super(Net_${networkInstruction?index}, self).__init__(**kwargs) self.last_layers = {} with self.name_scope(): -${tc.include(stream, "ARCHITECTURE_DEFINITION")} +${tc.include(networkInstruction.body, "ARCHITECTURE_DEFINITION")} - def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(stream), ", ")}): -${tc.include(stream, "FORWARD_FUNCTION")} - return ${tc.join(tc.getStreamOutputNames(stream), ", ")} - -</#if> -</#list> - - -<#list tc.architecture.unrolls as unroll> -<#if unroll.body.isTrainable()> -class Net_${unroll?index}(gluon.HybridBlock): - def __init__(self, data_mean=None, data_std=None, **kwargs): - super(Net_${unroll?index}, self).__init__(**kwargs) - self.last_layers = {} - with self.name_scope(): -${tc.include(unroll.body, "ARCHITECTURE_DEFINITION")} - - def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(unroll.body), ", ")}): -${tc.include(unroll.body, "FORWARD_FUNCTION")} - return ${tc.join(tc.getStreamOutputNames(unroll.body), ", ")} + def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")}): +${tc.include(networkInstruction.body, "FORWARD_FUNCTION")} + return ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} </#if> </#list> -- GitLab