Skip to content
Snippets Groups Projects
Commit c0574e47 authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Temporarily removed getBodiesForTimesteps call

parent 604bf5b8
No related branches found
No related tags found
1 merge request!23Added Unroll-related features and layers
Pipeline #180358 failed
......@@ -9,12 +9,10 @@ from CNNNet_${tc.fullArchitectureName} import Net_${stream?index}
</#list>
<#list tc.architecture.unrolls as unroll>
<#list unroll.getBodiesForAllTimesteps() as body>
<#if body.isTrainable()>
from CNNNet_${tc.fullArchitectureName} import Net_${tc.architecture.streams?size + body?index}
<#if unroll.body.isTrainable()>
from CNNNet_${tc.fullArchitectureName} import Net_${tc.architecture.streams?size + unroll?index}
</#if>
</#list>
</#list>
class ${tc.fileNameWithoutEnding}:
_model_dir_ = "model/${tc.componentName}/"
......@@ -70,14 +68,12 @@ class ${tc.fileNameWithoutEnding}:
</#list>
<#list tc.architecture.unrolls as unroll>
<#list unroll.getBodiesForAllTimesteps() as body>
<#if body.isTrainable()>
self.networks[${tc.architecture.streams?size + body?index}] = Net_${tc.architecture.streams?size + body?index}(data_mean=data_mean, data_std=data_std)
self.networks[${tc.architecture.streams?size + body?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${tc.architecture.streams?size + body?index}].hybridize()
self.networks[${tc.architecture.streams?size + body?index}](<#list tc.getStreamInputDimensions(body) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>)
<#if unroll.body.isTrainable()>
self.networks[${tc.architecture.streams?size + unroll?index}] = Net_${tc.architecture.streams?size + unroll?index}(data_mean=data_mean, data_std=data_std)
self.networks[${tc.architecture.streams?size + unroll?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${tc.architecture.streams?size + unroll?index}].hybridize()
self.networks[${tc.architecture.streams?size + unroll?index}](<#list tc.getStreamInputDimensions(unroll.body) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>)
</#if>
</#list>
</#list>
if not os.path.exists(self._model_dir_):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment