<#-- pythonExecuteTrain.ftl: FreeMarker fragment that emits the per-batch Python statements of the generated training code. -->
<#-- FTL comments like this one produce no output; only the indented Python lines end up in the generated file. -->
<#-- Copy one label per architecture output from batch.label onto mx_context, in positional order. -->
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(${tc.architectureOutputs?size?c})]

<#-- Bind every architecture input to a Python variable of the same name; the input's position in tc.architectureInputs selects the batch.data slot. -->
<#list tc.architectureInputs as input_name>
                    ${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
</#list>

<#-- Pre-allocate zero-filled NDArrays on mx_context for the network outputs. -->
<#-- More than one output symbol: emit a single Python list named after the first symbol (index stripped). -->
<#-- The list holds one buffer per architecture output, and every buffer is shaped from the FIRST symbol's dimensions. -->
<#if tc.architectureOutputSymbols?size gt 1>
<#assign outputName = tc.getNameWithoutIndex(tc.getName(tc.architectureOutputSymbols[0]))>
                    ${outputName} = [mx.nd.zeros((batch_size, ${tc.join(tc.architectureOutputSymbols[0].ioDeclaration.type.dimensions, ", ")},), ctx=mx_context) for i in range(${tc.architectureOutputs?size?c})]
<#else>
<#-- Fewer than two output symbols: emit one buffer per symbol under the symbol's own name. -->
<#-- NOTE(review): the sep directive below cannot fire in this branch (the list has at most one element). -->
<#-- If it ever did, the emitted trailing comma would make the assignment a 1-tuple in Python; confirm before reusing this loop for longer lists. -->
<#list tc.architectureOutputSymbols as output>
                    ${tc.getName(output)} = mx.nd.zeros((batch_size, ${tc.join(output.ioDeclaration.type.dimensions, ", ")},), ctx=mx_context)<#sep>,
</#list>
</#if>

<#-- Emit one zero-filled NDArray per layer variable member, named after the member key. -->
<#-- Shape is batch_size followed by the joined result of tc.cutDimensions for that member. -->
<#list tc.getLayerVariableMembers()?keys as member>
                    ${member} = mx.nd.zeros((batch_size, ${tc.join(tc.cutDimensions(tc.getLayerVariableMembers()[member]), ", ")},), ctx=mx_context)
</#list>

<#-- Emit one (batch_size, 1) NDArray per architecture constant, filled with the constant's integer value. -->
<#-- The ?c built-in renders the number in computer format, i.e. without locale-specific grouping. -->
<#list tc.architecture.constants as constant>
                    ${tc.getName(constant)} = mx.nd.full((batch_size, 1,), ${constant.intValue?c}, ctx=mx_context)
</#list>

<#-- Wait for MXNet's pending asynchronous operations (the copies and allocations emitted above) before running the networks. -->
                    nd.waitall()

<#-- Accumulator for the per-output loss terms appended by the network instructions that follow. -->
                    lossList = []

<#-- Emit the forward pass and loss collection for every network instruction; the instruction's index selects the entry in self._networks. -->
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.isUnroll()>
<#-- Unrolled instruction: emit a Python loop over steps 1 .. max length - 1, where i is the Python-side step index. -->
                    for i in range(1, ${tc.getBeamSearchMaxLength(networkInstruction)}):
<#if tc.isAttentionNetwork()>
<#-- Attention networks return one extra trailing value, which the generated code discards via "_". -->
                        ${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")}, _ = self._networks[${networkInstruction?index}](${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")})
<#else>
                        ${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")})
</#if>

<#-- Only outputs whose index-free name equals tc.outputName contribute a loss term. -->
<#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName>
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
                        lossList.append(loss_function(${outputName}, labels[${tc.getIndex(outputName, true)}]))
<#-- The loss is taken on the raw output first; only afterwards is the output replaced by its argmax along axis 1, with a trailing axis re-added. -->
<#if tc.endsWithArgmax(networkInstruction.body)>
                        ${outputName} = mx.nd.argmax(${outputName}, axis=1).expand_dims(1)
</#if>
<#-- Teacher forcing: overwrite the output with the matching label so it is what the next step sees. -->
<#-- NOTE(review): the flag is compared against the string "True", so it is presumably passed as text rather than a bool; verify against the caller. -->
                        if use_teacher_forcing == "True":
                            ${outputName} = mx.nd.expand_dims(labels[${tc.getIndex(outputName, true)}], axis=1)
</#if>
</#list>
<#else>
<#-- Plain stream instruction: a single forward call, no loop and no teacher forcing. -->
                    ${tc.join(tc.getStreamOutputNames(networkInstruction.body, true), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, true), ", ")})

<#-- Same loss rule as the unrolled case: only outputs matching tc.outputName contribute, and argmax is applied after the loss. -->
<#list tc.getStreamOutputNames(networkInstruction.body, true) as outputName>
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
                    lossList.append(loss_function(${outputName}, labels[${tc.getIndex(outputName, true)}]))
<#if tc.endsWithArgmax(networkInstruction.body)>
                    ${outputName} = mx.nd.argmax(${outputName}, axis=1).expand_dims(1)
</#if>
</#if>
</#list>
</#if>
</#list>