Commit dc316085 authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Fixed loss for beam search

parent cd626bc9
Pipeline #226778 passed with stages
in 5 minutes and 12 seconds
...@@ -31,14 +31,14 @@ ...@@ -31,14 +31,14 @@
k = ${tc.getBeamSearchWidth(networkInstruction)} k = ${tc.getBeamSearchWidth(networkInstruction)}
<#list tc.getUnrollInputNames(networkInstruction, "1") as inputName> <#list tc.getUnrollInputNames(networkInstruction, "1") as inputName>
<#if tc.getNameWithoutIndex(inputName) == tc.outputName> <#if tc.getNameWithoutIndex(inputName) == tc.outputName>
sequences = [([${inputName}], mx.nd.full((batch_size, 1,), 1.0, ctx=mx_context), [mx.nd.full((batch_size, 64,), 0.0, ctx=mx_context)])] sequences = [([${inputName}], mx.nd.full((batch_size, 1,), 1.0, ctx=mx_context), [mx.nd.full((batch_size, 64,), 0.0, ctx=mx_context)], [mx.nd.full((batch_size, 64,), 0.0, ctx=mx_context)])]
</#if> </#if>
</#list> </#list>
for i in range(1, ${tc.getBeamSearchMaxLength(networkInstruction)}): for i in range(1, ${tc.getBeamSearchMaxLength(networkInstruction)}):
all_candidates = [] all_candidates = []
for seq, score, attention in sequences: for seq, score, seqLossList, attention in sequences:
<#list tc.getUnrollInputNames(networkInstruction, "i") as inputName> <#list tc.getUnrollInputNames(networkInstruction, "i") as inputName>
<#if tc.getNameWithoutIndex(inputName) == tc.outputName> <#if tc.getNameWithoutIndex(inputName) == tc.outputName>
${inputName} = seq[-1] ${inputName} = seq[-1]
...@@ -52,6 +52,7 @@ ...@@ -52,6 +52,7 @@
<#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName> <#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName>
<#if tc.getNameWithoutIndex(outputName) == tc.outputName> <#if tc.getNameWithoutIndex(outputName) == tc.outputName>
out = ${outputName} out = ${outputName}
newLossList = seqLossList + [loss_function(${outputName}, labels[${tc.getIndex(outputName, true)}])]
</#if> </#if>
</#list> </#list>
...@@ -62,9 +63,9 @@ ...@@ -62,9 +63,9 @@
currentScore = mx.nd.slice_axis(out, axis=1, begin=top_index, end=top_index+1) currentScore = mx.nd.slice_axis(out, axis=1, begin=top_index, end=top_index+1)
newScore = mx.nd.expand_dims(score.squeeze() * currentScore.squeeze(), axis=1) newScore = mx.nd.expand_dims(score.squeeze() * currentScore.squeeze(), axis=1)
<#if tc.isAttentionNetwork()> <#if tc.isAttentionNetwork()>
candidate = (seq + [j], newScore, attention + [attention_]) candidate = (seq + [j], newScore, newLossList, attention + [attention_])
<#else> <#else>
candidate = (seq + [j], newScore, attention + []) candidate = (seq + [j], newScore, newLossList, attention + [])
</#if> </#if>
all_candidates.append(candidate) all_candidates.append(candidate)
...@@ -72,18 +73,21 @@ ...@@ -72,18 +73,21 @@
newSequences = [] newSequences = []
for batch_entry in range(batch_size): for batch_entry in range(batch_size):
ordered.append([]) ordered.append([])
batchCandidate = [([seq[batch_entry] for seq in candidate[0]], candidate[1][batch_entry], [attention[batch_entry].expand_dims(axis=0) for attention in candidate[2]]) for candidate in all_candidates] batchCandidate = [([seq[batch_entry] for seq in candidate[0]], candidate[1][batch_entry], [seq[batch_entry] for seq in candidate[2]], [attention[batch_entry].expand_dims(axis=0) for attention in candidate[3]]) for candidate in all_candidates]
ordered[batch_entry] = sorted(batchCandidate, key=lambda tup: tup[1].asscalar()) ordered[batch_entry] = sorted(batchCandidate, key=lambda tup: tup[1].asscalar())
if batch_entry == 0: if batch_entry == 0:
newSequences = ordered[batch_entry] newSequences = ordered[batch_entry]
else: else:
newSequences = [([mx.nd.concat(newSequences[sequenceIndex][0][seqIndex], ordered[batch_entry][sequenceIndex][0][seqIndex], dim=0) for seqIndex in range(len(newSequences[sequenceIndex][0]))], newSequences = [([mx.nd.concat(newSequences[sequenceIndex][0][seqIndex], ordered[batch_entry][sequenceIndex][0][seqIndex], dim=0) for seqIndex in range(len(newSequences[sequenceIndex][0]))],
mx.nd.concat(newSequences[sequenceIndex][1], ordered[batch_entry][sequenceIndex][1], dim=0), mx.nd.concat(newSequences[sequenceIndex][1], ordered[batch_entry][sequenceIndex][1], dim=0),
[mx.nd.concat(newSequences[sequenceIndex][2][attentionIndex], ordered[batch_entry][sequenceIndex][2][attentionIndex], dim=0) for attentionIndex in range(len(newSequences[sequenceIndex][2]))]) [mx.nd.concat(newSequences[sequenceIndex][2][lossIndex], ordered[batch_entry][sequenceIndex][2][lossIndex], dim=0) for lossIndex in range(len(newSequences[sequenceIndex][2]))],
[mx.nd.concat(newSequences[sequenceIndex][3][attentionIndex], ordered[batch_entry][sequenceIndex][3][attentionIndex], dim=0) for attentionIndex in range(len(newSequences[sequenceIndex][3]))])
for sequenceIndex in range(len(newSequences))] for sequenceIndex in range(len(newSequences))]
newSequences = [([newSequences[sequenceIndex][0][seqIndex].expand_dims(axis=1) for seqIndex in range(len(newSequences[sequenceIndex][0]))], newSequences = [([newSequences[sequenceIndex][0][seqIndex].expand_dims(axis=1) for seqIndex in range(len(newSequences[sequenceIndex][0]))],
newSequences[sequenceIndex][1].expand_dims(axis=1), [newSequences[sequenceIndex][2][attentionIndex] for attentionIndex in range(len(newSequences[sequenceIndex][2]))]) newSequences[sequenceIndex][1].expand_dims(axis=1),
newSequences[sequenceIndex][2],
[newSequences[sequenceIndex][3][attentionIndex] for attentionIndex in range(len(newSequences[sequenceIndex][3]))])
for sequenceIndex in range(len(newSequences))] for sequenceIndex in range(len(newSequences))]
sequences = newSequences[:][:k] sequences = newSequences[:][:k]
...@@ -93,9 +97,9 @@ ...@@ -93,9 +97,9 @@
<#if tc.getNameWithoutIndex(outputName) == tc.outputName> <#if tc.getNameWithoutIndex(outputName) == tc.outputName>
${outputName} = sequences[0][0][i] ${outputName} = sequences[0][0][i]
outputs.append(${outputName}) outputs.append(${outputName})
lossList.append(loss_function(${outputName}, labels[${tc.getIndex(outputName, true)}])) lossList.append(sequences[0][2][i])
<#if tc.isAttentionNetwork()> <#if tc.isAttentionNetwork()>
attentionList.append(sequences[0][2][i]) attentionList.append(sequences[0][3][i])
</#if> </#if>
</#if> </#if>
</#list> </#list>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment