<#--
  Generates the Python (MXNet) code that runs the network(s) on one batch.
  - Moves the batch labels and every architecture input onto mx_context.
  - Zero-initialises the output buffers and layer-variable members
    (batch_size rows each); constants are filled with their int value.
  - For every network instruction:
    * Unroll instructions run a beam search of width
      tc.getBeamSearchWidth(...) for up to tc.getBeamSearchMaxLength(...)
      steps. Each beam is (token list, score, attention list).
      - The score is the running product of the per-step output values of
        the chosen tokens, starting at 1.0, so a larger score is better.
      - Each step expands every beam by its top-k tokens, using the top-k
        VALUES as the step scores so that token and score stay aligned.
      - Per batch entry, candidates are ranked in descending score order and
        the best k are kept. sequences[0] is therefore the best beam, and
        its tokens are appended to `outputs`.
      NOTE(review): assumes the network output is a probability
      distribution over axis 1 (e.g. softmax) - confirm.
      NOTE(review): the initial attention placeholder width is hard-coded
      to 64.
    * Other instructions call the network once and append its output. When
      the body ends with argmax, the variable is replaced by its argmax
      index (expanded to a column).
  NOTE(review): the closing list/if directives for this fragment are not
  visible in this view; the Python indentation has also been collapsed.
-->
labels = [batch.label[i].as_in_context(mx_context) for i in range(${tc.architectureOutputs?size?c})] <#list tc.architectureInputs as input_name> ${input_name} = batch.data[${input_name?index}].as_in_context(mx_context) <#if tc.architectureOutputSymbols?size gt 1> <#assign outputName = tc.getNameWithoutIndex(tc.getName(tc.architectureOutputSymbols[0]))> ${outputName} = [mx.nd.zeros((batch_size, ${tc.join(tc.architectureOutputSymbols[0].ioDeclaration.type.dimensions, ", ")},), ctx=mx_context) for i in range(${tc.architectureOutputs?size?c})] <#else> <#list tc.architectureOutputSymbols as output> ${tc.getName(output)} = mx.nd.zeros((batch_size, ${tc.join(output.ioDeclaration.type.dimensions, ", ")},), ctx=mx_context)<#sep>, <#list tc.getLayerVariableMembers()?keys as member> ${member} = mx.nd.zeros((batch_size, ${tc.join(tc.cutDimensions(tc.getLayerVariableMembers()[member]), ", ")},), ctx=mx_context) <#list tc.architecture.constants as constant> ${tc.getName(constant)} = mx.nd.full((batch_size, 1,), ${constant.intValue?c}, ctx=mx_context) nd.waitall() outputs = [] attentionList=[] <#list tc.architecture.networkInstructions as networkInstruction> <#if networkInstruction.isUnroll()> k = ${tc.getBeamSearchWidth(networkInstruction)} <#list tc.getUnrollInputNames(networkInstruction, "1") as inputName> <#if tc.getNameWithoutIndex(inputName) == tc.outputName> sequences = [([${inputName}], mx.nd.full((batch_size, 1,), 1.0, ctx=mx_context), [mx.nd.full((batch_size, 64,), 0.0, ctx=mx_context)])] for i in range(1, ${tc.getBeamSearchMaxLength(networkInstruction)}): all_candidates = [] for seq, score, attention in sequences: <#list tc.getUnrollInputNames(networkInstruction, "i") as inputName> <#if tc.getNameWithoutIndex(inputName) == tc.outputName> ${inputName} = seq[-1] <#if tc.isAttentionNetwork()> ${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")}, attention_ = self._networks[${networkInstruction?index}](${tc.join(tc.getUnrollInputNames(networkInstruction, 
"i"), ", ")}) <#else> ${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")}) <#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName> <#if tc.getNameWithoutIndex(outputName) == tc.outputName> out = ${outputName} topk = out.topk(k=k) topkScores = out.topk(k=k, ret_typ='value') for top_index in range(len(topk[0])): j = mx.nd.slice_axis(topk, axis=1, begin=top_index, end=top_index+1) currentScore = mx.nd.slice_axis(topkScores, axis=1, begin=top_index, end=top_index+1) newScore = mx.nd.expand_dims(score.squeeze() * currentScore.squeeze(), axis=1) <#if tc.isAttentionNetwork()> candidate = (seq + [j], newScore, attention + [attention_]) <#else> candidate = (seq + [j], newScore, attention + []) all_candidates.append(candidate) ordered = [] newSequences = [] for batch_entry in range(batch_size): ordered.append([]) batchCandidate = [([seq[batch_entry] for seq in candidate[0]], candidate[1][batch_entry], [attention[batch_entry].expand_dims(axis=0) for attention in candidate[2]]) for candidate in all_candidates] ordered[batch_entry] = sorted(batchCandidate, key=lambda tup: tup[1].asscalar(), reverse=True) if batch_entry == 0: newSequences = ordered[batch_entry] else: newSequences = [([mx.nd.concat(newSequences[sequenceIndex][0][seqIndex], ordered[batch_entry][sequenceIndex][0][seqIndex], dim=0) for seqIndex in range(len(newSequences[sequenceIndex][0]))], mx.nd.concat(newSequences[sequenceIndex][1], ordered[batch_entry][sequenceIndex][1], dim=0), [mx.nd.concat(newSequences[sequenceIndex][2][attentionIndex], ordered[batch_entry][sequenceIndex][2][attentionIndex], dim=0) for attentionIndex in range(len(newSequences[sequenceIndex][2]))]) for sequenceIndex in range(len(newSequences))] newSequences = [([newSequences[sequenceIndex][0][seqIndex].expand_dims(axis=1) for seqIndex in range(len(newSequences[sequenceIndex][0]))], newSequences[sequenceIndex][1].expand_dims(axis=1), 
[newSequences[sequenceIndex][2][attentionIndex] for attentionIndex in range(len(newSequences[sequenceIndex][2]))]) for sequenceIndex in range(len(newSequences))] sequences = newSequences[:][:k] for i in range(1, ${tc.getBeamSearchMaxLength(networkInstruction)}): <#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName> <#if tc.getNameWithoutIndex(outputName) == tc.outputName> ${outputName} = sequences[0][0][i] outputs.append(${outputName}) <#if tc.isAttentionNetwork()> attentionList.append(sequences[0][2][i]) <#else> ${tc.join(tc.getStreamOutputNames(networkInstruction.body, true), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, true), ", ")}) <#list tc.getStreamOutputNames(networkInstruction.body, true) as outputName> <#if tc.getNameWithoutIndex(outputName) == tc.outputName> outputs.append(${outputName}) <#if tc.endsWithArgmax(networkInstruction.body)> ${outputName} = mx.nd.argmax(${outputName}, axis=1).expand_dims(1)