Commit 49ac443c authored by Christian Fuß's avatar Christian Fuß
Browse files

small bugfix

parent ed7b7277
......@@ -58,7 +58,7 @@ class ${tc.fileNameWithoutEnding}:
self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${networkInstruction?index}].hybridize()
self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body, false) as dimensions>
<#if dimensions[0] == "-1">self.networks[${networkInstruction?index}].${dimensions[1]}.begin_state(batch_size=1, ctx=context)<#else>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)</#if> <#sep>, </#list>)
<#if dimensions[0] == "-1">self.networks[${networkInstruction?index}].${dimensions[1]}.begin_state(batch_size=1, ctx=context) <#sep>,<#else>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)</#if> <#sep>, </#list>)
</#if>
</#list>
......
......@@ -133,7 +133,7 @@ class ${tc.fileNameWithoutEnding}:
train_iter.reset()
for batch_i, batch in enumerate(train_iter):
<#list tc.architectureInputs as input_name>
${input_name} = batch.data[0].as_in_context(mx_context)
${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
</#list>
<#list tc.architectureOutputs as output_name>
${output_name}label = batch.label[${output_name?index}].as_in_context(mx_context)
......@@ -172,7 +172,7 @@ class ${tc.fileNameWithoutEnding}:
metric = mx.metric.create(eval_metric)
for batch_i, batch in enumerate(train_iter):
<#list tc.architectureInputs as input_name>
${input_name} = batch.data[0].as_in_context(mx_context)
${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
</#list>
labels = [
......@@ -239,7 +239,7 @@ class ${tc.fileNameWithoutEnding}:
metric = mx.metric.create(eval_metric)
for batch_i, batch in enumerate(test_iter):
<#list tc.architectureInputs as input_name>
${input_name} = batch.data[0].as_in_context(mx_context)
${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
</#list>
labels = [
......
......@@ -31,7 +31,7 @@
<#if networkInstruction.body.isTrainable()>
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")})
<#list networkInstruction.body.elements as element>
<#if element.name == "ArgMax" && tc.architecture.networkInstructions[instructionCounter+1].getName() != "BeamSearch">
<#if element.name == "ArgMax" && (tc.architecture.networkInstructions?size <= instructionCounter+1 || tc.architecture.networkInstructions[instructionCounter+1].getName() != "BeamSearch")>
${tc.getStreamOutputNames(networkInstruction.body)[0]} = mx.nd.argmax(${tc.getStreamOutputNames(networkInstruction.body)[0]}, axis=1)
</#if>
</#list>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment