Commit a132a0fa authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Updated for NetworkInstructionSymbol

parent c0574e47
Pipeline #180400 failed with stages
in 46 seconds
......@@ -2,15 +2,9 @@ import mxnet as mx
import logging
import os
<#list tc.architecture.streams as stream>
<#if stream.isTrainable()>
from CNNNet_${tc.fullArchitectureName} import Net_${stream?index}
</#if>
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.body.isTrainable()>
from CNNNet_${tc.fullArchitectureName} import Net_${tc.architecture.streams?size + unroll?index}
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.body.isTrainable()>
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index}
</#if>
</#list>
......@@ -58,21 +52,12 @@ class ${tc.fileNameWithoutEnding}:
return earliestLastEpoch
def construct(self, context, data_mean=None, data_std=None):
<#list tc.architecture.streams as stream>
<#if stream.isTrainable()>
self.networks[${stream?index}] = Net_${stream?index}(data_mean=data_mean, data_std=data_std)
self.networks[${stream?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${stream?index}].hybridize()
self.networks[${stream?index}](<#list tc.getStreamInputDimensions(stream) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>)
</#if>
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.body.isTrainable()>
self.networks[${tc.architecture.streams?size + unroll?index}] = Net_${tc.architecture.streams?size + unroll?index}(data_mean=data_mean, data_std=data_std)
self.networks[${tc.architecture.streams?size + unroll?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${tc.architecture.streams?size + unroll?index}].hybridize()
self.networks[${tc.architecture.streams?size + unroll?index}](<#list tc.getStreamInputDimensions(unroll.body) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>)
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.body.isTrainable()>
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std)
self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${networkInstruction?index}].hybridize()
self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>)
</#if>
</#list>
......
......@@ -78,35 +78,18 @@ class NoNormalization(gluon.HybridBlock):
return x
<#list tc.architecture.streams as stream>
<#if stream.isTrainable()>
class Net_${stream?index}(gluon.HybridBlock):
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.body.isTrainable()>
class Net_${networkInstruction?index}(gluon.HybridBlock):
def __init__(self, data_mean=None, data_std=None, **kwargs):
super(Net_${stream?index}, self).__init__(**kwargs)
super(Net_${networkInstruction?index}, self).__init__(**kwargs)
self.last_layers = {}
with self.name_scope():
${tc.include(stream, "ARCHITECTURE_DEFINITION")}
${tc.include(networkInstruction.body, "ARCHITECTURE_DEFINITION")}
def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(stream), ", ")}):
${tc.include(stream, "FORWARD_FUNCTION")}
return ${tc.join(tc.getStreamOutputNames(stream), ", ")}
</#if>
</#list>
<#list tc.architecture.unrolls as unroll>
<#if unroll.body.isTrainable()>
class Net_${unroll?index}(gluon.HybridBlock):
def __init__(self, data_mean=None, data_std=None, **kwargs):
super(Net_${unroll?index}, self).__init__(**kwargs)
self.last_layers = {}
with self.name_scope():
${tc.include(unroll.body, "ARCHITECTURE_DEFINITION")}
def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(unroll.body), ", ")}):
${tc.include(unroll.body, "FORWARD_FUNCTION")}
return ${tc.join(tc.getStreamOutputNames(unroll.body), ", ")}
def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")}):
${tc.include(networkInstruction.body, "FORWARD_FUNCTION")}
return ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")}
</#if>
</#list>
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment