Commit dc316085 authored by Sebastian Nickels

Fixed loss for beam search

parent cd626bc9
Pipeline #226778 passed with stages
in 5 minutes and 12 seconds
......@@ -31,14 +31,14 @@
k = ${tc.getBeamSearchWidth(networkInstruction)}
<#list tc.getUnrollInputNames(networkInstruction, "1") as inputName>
<#if tc.getNameWithoutIndex(inputName) == tc.outputName>
sequences = [([${inputName}], mx.nd.full((batch_size, 1,), 1.0, ctx=mx_context), [mx.nd.full((batch_size, 64,), 0.0, ctx=mx_context)])]
sequences = [([${inputName}], mx.nd.full((batch_size, 1,), 1.0, ctx=mx_context), [mx.nd.full((batch_size, 64,), 0.0, ctx=mx_context)], [mx.nd.full((batch_size, 64,), 0.0, ctx=mx_context)])]
</#if>
</#list>
for i in range(1, ${tc.getBeamSearchMaxLength(networkInstruction)}):
all_candidates = []
for seq, score, attention in sequences:
for seq, score, seqLossList, attention in sequences:
<#list tc.getUnrollInputNames(networkInstruction, "i") as inputName>
<#if tc.getNameWithoutIndex(inputName) == tc.outputName>
${inputName} = seq[-1]
......@@ -52,6 +52,7 @@
<#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName>
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
out = ${outputName}
newLossList = seqLossList + [loss_function(${outputName}, labels[${tc.getIndex(outputName, true)}])]
</#if>
</#list>
......@@ -62,9 +63,9 @@
currentScore = mx.nd.slice_axis(out, axis=1, begin=top_index, end=top_index+1)
newScore = mx.nd.expand_dims(score.squeeze() * currentScore.squeeze(), axis=1)
<#if tc.isAttentionNetwork()>
candidate = (seq + [j], newScore, attention + [attention_])
candidate = (seq + [j], newScore, newLossList, attention + [attention_])
<#else>
candidate = (seq + [j], newScore, attention + [])
candidate = (seq + [j], newScore, newLossList, attention + [])
</#if>
all_candidates.append(candidate)
......@@ -72,18 +73,21 @@
newSequences = []
for batch_entry in range(batch_size):
ordered.append([])
batchCandidate = [([seq[batch_entry] for seq in candidate[0]], candidate[1][batch_entry], [attention[batch_entry].expand_dims(axis=0) for attention in candidate[2]]) for candidate in all_candidates]
batchCandidate = [([seq[batch_entry] for seq in candidate[0]], candidate[1][batch_entry], [seq[batch_entry] for seq in candidate[2]], [attention[batch_entry].expand_dims(axis=0) for attention in candidate[3]]) for candidate in all_candidates]
ordered[batch_entry] = sorted(batchCandidate, key=lambda tup: tup[1].asscalar())
if batch_entry == 0:
newSequences = ordered[batch_entry]
else:
newSequences = [([mx.nd.concat(newSequences[sequenceIndex][0][seqIndex], ordered[batch_entry][sequenceIndex][0][seqIndex], dim=0) for seqIndex in range(len(newSequences[sequenceIndex][0]))],
mx.nd.concat(newSequences[sequenceIndex][1], ordered[batch_entry][sequenceIndex][1], dim=0),
[mx.nd.concat(newSequences[sequenceIndex][2][attentionIndex], ordered[batch_entry][sequenceIndex][2][attentionIndex], dim=0) for attentionIndex in range(len(newSequences[sequenceIndex][2]))])
[mx.nd.concat(newSequences[sequenceIndex][2][lossIndex], ordered[batch_entry][sequenceIndex][2][lossIndex], dim=0) for lossIndex in range(len(newSequences[sequenceIndex][2]))],
[mx.nd.concat(newSequences[sequenceIndex][3][attentionIndex], ordered[batch_entry][sequenceIndex][3][attentionIndex], dim=0) for attentionIndex in range(len(newSequences[sequenceIndex][3]))])
for sequenceIndex in range(len(newSequences))]
newSequences = [([newSequences[sequenceIndex][0][seqIndex].expand_dims(axis=1) for seqIndex in range(len(newSequences[sequenceIndex][0]))],
newSequences[sequenceIndex][1].expand_dims(axis=1), [newSequences[sequenceIndex][2][attentionIndex] for attentionIndex in range(len(newSequences[sequenceIndex][2]))])
newSequences[sequenceIndex][1].expand_dims(axis=1),
newSequences[sequenceIndex][2],
[newSequences[sequenceIndex][3][attentionIndex] for attentionIndex in range(len(newSequences[sequenceIndex][3]))])
for sequenceIndex in range(len(newSequences))]
sequences = newSequences[:][:k]
......@@ -93,9 +97,9 @@
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
${outputName} = sequences[0][0][i]
outputs.append(${outputName})
lossList.append(loss_function(${outputName}, labels[${tc.getIndex(outputName, true)}]))
lossList.append(sequences[0][2][i])
<#if tc.isAttentionNetwork()>
attentionList.append(sequences[0][2][i])
attentionList.append(sequences[0][3][i])
</#if>
</#if>
</#list>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment