<#-- pythonExecuteTest.ftl: emits the Python code that moves one batch to mx_context, allocates the output placeholders and runs the networks. -->
                    labels = [batch.label[labelIndex].as_in_context(mx_context) for labelIndex in range(${tc.architectureOutputs?size?c})]

<#-- Bind each architecture input name to the batch data entry at the same position, copied to mx_context. -->
<#list tc.architectureInputs as inputName>
                    ${inputName} = batch.data[${inputName?index}].as_in_context(mx_context)
</#list>

<#-- Allocate zero-filled placeholders for the network outputs on mx_context. -->
<#if tc.architectureOutputSymbols?size gt 1>
<#-- More than one output symbol: emit a single list, named after the first symbol without its index, -->
<#-- holding one placeholder per architecture output (all with the first symbol's dimensions). -->
<#assign outputName = tc.getNameWithoutIndex(tc.getName(tc.architectureOutputSymbols[0]))>
                    ${outputName} = [mx.nd.zeros((batch_size, ${tc.join(tc.architectureOutputSymbols[0].ioDeclaration.type.dimensions, ", ")},), ctx=mx_context) for i in range(${tc.architectureOutputs?size?c})]
<#else>
<#-- At most one output symbol reaches this branch. Each statement ends with its own line break; -->
<#-- no separator is emitted, because a trailing comma would turn the assignment into a 1-tuple. -->
<#list tc.architectureOutputSymbols as output>
                    ${tc.getName(output)} = mx.nd.zeros((batch_size, ${tc.join(output.ioDeclaration.type.dimensions, ", ")},), ctx=mx_context)
</#list>
</#if>

<#-- One zero-filled placeholder per layer variable member; the dimensions come from tc.cutDimensions(...) and are prefixed with batch_size. -->
<#list tc.getLayerVariableMembers()?keys as member>
                    ${member} = mx.nd.zeros((batch_size, ${tc.join(tc.cutDimensions(tc.getLayerVariableMembers()[member]), ", ")},), ctx=mx_context)
</#list>

<#-- Materialise every architecture constant as a (batch_size, 1) array filled with its integer value. -->
<#list tc.architecture.constants as constant>
                    ${tc.getName(constant)} = mx.nd.full((batch_size, 1,), ${constant.intValue?c}, ctx=mx_context)
</#list>

<#-- Wait for the pending array operations above, then start with empty result lists. -->
<#-- outputs collects the network outputs; attentionList is only filled for attention networks. -->
                    nd.waitall()

                    outputs = []
                    attentionList = []
<#-- Emit the forward pass for every network instruction. -->
<#-- Unrolled instructions are decoded with a beam search; all other instructions are a single network call. -->
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.isUnroll()>
<#-- Beam state: a list of (token list, accumulated score, attention list) tuples, one tuple per beam. -->
                    k = ${tc.getBeamSearchWidth(networkInstruction)}
<#list tc.getUnrollInputNames(networkInstruction, "1") as inputName>
<#if tc.getNameWithoutIndex(inputName) == tc.outputName>
<#-- Initial beam: the start token, score 1.0 and a zero attention placeholder. -->
<#-- NOTE(review): the placeholder width 64 is hard-coded; confirm it matches the attention output of the network. -->
                    sequences = [([${inputName}], mx.nd.full((batch_size, 1,), 1.0, ctx=mx_context), [mx.nd.full((batch_size, 64,), 0.0, ctx=mx_context)])]
</#if>
</#list>

                    for i in range(1, ${tc.getBeamSearchMaxLength(networkInstruction)}):
                        all_candidates = []

                        for seq, score, attention in sequences:
<#-- Feed the last token of the current beam back in as the network input of this step. -->
<#list tc.getUnrollInputNames(networkInstruction, "i") as inputName>
<#if tc.getNameWithoutIndex(inputName) == tc.outputName>
                            ${inputName} = seq[-1]
</#if>
</#list>
<#if tc.isAttentionNetwork()>
                            ${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")}, attention_ = self._networks[${networkInstruction?index}](${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")})
<#else>
                            ${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")})
</#if>
<#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName>
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
                            out = ${outputName}
</#if>
</#list>

<#-- Request values and indices together so that every candidate is scored with the value of the entry it selects. -->
<#-- Slicing out at column top_index would read the score of class number top_index instead of class j. -->
                            topk_values, topk_indices = out.topk(k=k, ret_typ='both')

                            for top_index in range(topk_indices.shape[1]):
                                j = mx.nd.slice_axis(topk_indices, axis=1, begin=top_index, end=top_index+1)
                                currentScore = mx.nd.slice_axis(topk_values, axis=1, begin=top_index, end=top_index+1)
                                newScore = mx.nd.expand_dims(score.squeeze() * currentScore.squeeze(), axis=1)
<#if tc.isAttentionNetwork()>
                                candidate = (seq + [j],  newScore, attention + [attention_])
<#else>
                                candidate = (seq + [j],  newScore, attention + [])
</#if>
                                all_candidates.append(candidate)

<#-- Rank the candidates separately for every batch entry, then concatenate the per-entry results rank by rank along dim 0. -->
                        ordered = []
                        newSequences = []
                        for batch_entry in range(batch_size):
                            ordered.append([])
                            batchCandidate = [([seq[batch_entry] for seq in candidate[0]], candidate[1][batch_entry], [attention[batch_entry].expand_dims(axis=0) for attention in candidate[2]]) for candidate in all_candidates]
<#-- Highest score first: the truncation to k below has to keep the best candidates, not the worst ones. -->
                            ordered[batch_entry] = sorted(batchCandidate, key=lambda tup: tup[1].asscalar(), reverse=True)
                            if batch_entry == 0:
                                newSequences = ordered[batch_entry]
                            else:
                                newSequences = [([mx.nd.concat(newSequences[sequenceIndex][0][seqIndex], ordered[batch_entry][sequenceIndex][0][seqIndex], dim=0) for seqIndex in range(len(newSequences[sequenceIndex][0]))],
                                    mx.nd.concat(newSequences[sequenceIndex][1], ordered[batch_entry][sequenceIndex][1], dim=0),
                                    [mx.nd.concat(newSequences[sequenceIndex][2][attentionIndex], ordered[batch_entry][sequenceIndex][2][attentionIndex], dim=0) for attentionIndex in range(len(newSequences[sequenceIndex][2]))])
                                    for sequenceIndex in range(len(newSequences))]

<#-- Add an axis at position 1 to tokens and scores again; attention entries are passed through unchanged. -->
                        newSequences = [([newSequences[sequenceIndex][0][seqIndex].expand_dims(axis=1) for seqIndex in range(len(newSequences[sequenceIndex][0]))],
                            newSequences[sequenceIndex][1].expand_dims(axis=1), [newSequences[sequenceIndex][2][attentionIndex] for attentionIndex in range(len(newSequences[sequenceIndex][2]))])
                            for sequenceIndex in range(len(newSequences))]

                        sequences = newSequences[:k]

<#-- Publish the first (best-ranked) beam: one output, and for attention networks one attention entry, per step. -->
                    for i in range(1, ${tc.getBeamSearchMaxLength(networkInstruction)}):
<#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName>
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
                        ${outputName} = sequences[0][0][i]
                        outputs.append(${outputName})
<#if tc.isAttentionNetwork()>
                        attentionList.append(sequences[0][2][i])
</#if>
</#if>
</#list>
<#else>
<#-- Plain stream instruction: a single call of the corresponding network. -->
                    ${tc.join(tc.getStreamOutputNames(networkInstruction.body, true), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, true), ", ")})

<#list tc.getStreamOutputNames(networkInstruction.body, true) as outputName>
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
<#-- The raw output is recorded first; the variable itself is then replaced by its argmax when the stream ends with one. -->
                    outputs.append(${outputName})
<#if tc.endsWithArgmax(networkInstruction.body)>
                    ${outputName} = mx.nd.argmax(${outputName}, axis=1).expand_dims(1)
</#if>
</#if>
</#list>
</#if>
</#list>