<#-- pythonExecute.ftl: fragment that emits Python statements (mx.nd / self._networks calls). -->
<#-- Every generated line carries a fixed 20-space indent; presumably this matches the method body the fragment is included into (NOTE(review): confirm against the including template). -->
<#-- Allocate one zero-filled mx.nd tensor per layer variable member: leading batch_size dimension, followed by the dimensions returned by tc.cutDimensions for that member. -->
<#list tc.getLayerVariableMembers()?keys as member>
                    ${member} = mx.nd.zeros((batch_size, ${tc.join(tc.cutDimensions(tc.getLayerVariableMembers()[member]), ", ")},), ctx=mx_context)
</#list>
<#-- Allocate one zero-filled mx.nd tensor per architecture output symbol: leading batch_size dimension, followed by the dimensions of the output's declared IO type. -->
<#list tc.architectureOutputSymbols as output>
                    ${tc.getName(output)} = mx.nd.zeros((batch_size, ${tc.join(output.ioDeclaration.type.dimensions, ", ")},), ctx=mx_context)
</#list>
<#-- Materialize each architecture constant as a (batch_size, 1) mx.nd tensor filled with its integer value. -->
<#-- ?c renders the number without locale formatting, so it is always a valid Python literal. -->
<#list tc.architecture.constants as constant>
                    ${tc.getName(constant)} = mx.nd.full((batch_size, 1,), ${constant.intValue?c}, ctx=mx_context)
</#list>

<#-- Emit the Python statements that execute every network instruction of the architecture, in list order. -->
<#-- Numbers are interpolated with ?c so they never pick up locale grouping separators (e.g. "1,000"), which would be invalid Python. -->
<#-- instructionCounter always equals networkInstruction?index; it is used below to look ahead at the following instruction. -->
<#assign instructionCounter = 0>
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.isUnroll()>
<#-- Unrolled instruction: emit one group of statements per resolved body. -->
<#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody>
                    <#if networkInstruction.name == "BeamSearch">
                    <#-- BeamSearch: gather the stream inputs and pass them to applyBeamSearch, which is not defined in this fragment (NOTE(review): confirm its signature). -->
                    input = ${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")}
                    <#-- Assumes getBeamSearchLength/Width return numbers, as required by ?c below - confirm. -->
                    <#assign length = tc.getBeamSearchLength(networkInstruction.toUnrollInstruction())>
                    <#assign width = tc.getBeamSearchWidth(networkInstruction.toUnrollInstruction())>
                    ${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]} = applyBeamSearch(input, 0, ${length?c}, ${width?c}, 1.0, ${networkInstruction?index?c}, input)
                    <#else>
                    <#-- Call the sub-network at this instruction's index on the body's stream inputs. -->
                    ${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index?c}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")})
                    <#-- Record the first output in `outputs` unless its name ends with "_output_". -->
                    <#if !(tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]?ends_with("_output_"))>
                    outputs.append(${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]})
                    </#if>
                    <#-- For each ArgMax element in the body, replace the first output by its argmax over axis 1, with expand_dims(1) re-inserting an axis at position 1. -->
                    <#list resolvedBody.elements as element>
                    <#if element.name == "ArgMax">
                    ${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]} = mx.nd.argmax(${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}, axis=1).expand_dims(1)
                    </#if>
                    </#list>
                    </#if>
</#list>
<#else>
<#-- Trainable body: call the sub-network at this instruction's index in self._networks. -->
<#if networkInstruction.body.isTrainable()>
                    ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index?c}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")})
                    <#-- Record the first output in `outputs` unless its name ends with "_output_". -->
                    <#if !(tc.getStreamOutputNames(networkInstruction.body)[0]?ends_with("_output_"))>
                    outputs.append(${tc.getStreamOutputNames(networkInstruction.body)[0]})
                    </#if>
                    <#list networkInstruction.body.elements as element>
                    <#-- Apply ArgMax unless the next instruction is a BeamSearch; the size check keeps the look-ahead index in range for the last instruction. -->
                    <#if element.name == "ArgMax" && (tc.architecture.networkInstructions?size <= instructionCounter+1 || tc.architecture.networkInstructions[instructionCounter+1].getName() != "BeamSearch")>
                    ${tc.getStreamOutputNames(networkInstruction.body)[0]} = mx.nd.argmax(${tc.getStreamOutputNames(networkInstruction.body)[0]}, axis=1).expand_dims(1)
                    </#if>
                    </#list>
<#else>
<#-- Non-trainable body: emit whatever tc.include(..., "PYTHON_INLINE") produces for it, then record the first output unless its name ends with "_state_". -->
${tc.include(networkInstruction.body, "PYTHON_INLINE")}
<#if !(tc.getStreamOutputNames(networkInstruction.body)[0]?ends_with("_state_"))>
                    outputs.append(${tc.getStreamOutputNames(networkInstruction.body)[0]})
</#if>
</#if>
</#if>
<#assign instructionCounter = instructionCounter + 1>
</#list>