Commit b0e8514a authored by Christian Fuß

Fixed a small bug with batch sizes > 2. Added usage of attention images in beam search.

parent fd5ad132
Pipeline #213849 failed with stages
in 19 seconds
......@@ -18,7 +18,7 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.4-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.8-SNAPSHOT</CNNTrain.version>
<CNNTrain.version>0.3.9-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.5-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.2-SNAPSHOT</EMADL2PythonWrapper.version>
......
......@@ -30,14 +30,14 @@
k = ${tc.getBeamSearchWidth(networkInstruction)}
<#list tc.getUnrollInputNames(networkInstruction, "1") as inputName>
<#if tc.getNameWithoutIndex(inputName) == tc.outputName>
sequences = [([${inputName}], mx.nd.full((batch_size, 1,), 1.0, ctx=mx_context))]
sequences = [([${inputName}], mx.nd.full((batch_size, 1,), 1.0, ctx=mx_context), [mx.nd.full((batch_size, 64,), 0.0, ctx=mx_context)])]
</#if>
</#list>
for i in range(1, ${tc.getBeamSearchMaxLength(networkInstruction)}):
all_candidates = []
for seq, score in sequences:
for seq, score, attention in sequences:
<#list tc.getUnrollInputNames(networkInstruction, "i") as inputName>
<#if tc.getNameWithoutIndex(inputName) == tc.outputName>
${inputName} = seq[-1]
......@@ -45,7 +45,6 @@
</#list>
<#if tc.isAttentionNetwork()>
${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")}, attention_ = self._networks[${networkInstruction?index}](${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")})
attentionList.append(attention_)
<#else>
${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")})
</#if>
......@@ -61,22 +60,25 @@
j = mx.nd.slice_axis(topk, axis=1, begin=top_index, end=top_index+1)
currentScore = mx.nd.slice_axis(out, axis=1, begin=top_index, end=top_index+1)
newScore = mx.nd.expand_dims(score.squeeze() * currentScore.squeeze(), axis=1)
candidate = (seq + [j], newScore)
<#if tc.isAttentionNetwork()>
candidate = (seq + [j], newScore, attention + [attention_])
<#else>
candidate = (seq + [j], newScore, attention + [])
</#if>
all_candidates.append(candidate)
ordered = []
newSequences = []
for batch_entry in range(batch_size):
ordered.append([])
batchCandidate = [([y[batch_entry] for y in x[0]], x[1][batch_entry]) for x in all_candidates]
batchCandidate = [([y[batch_entry] for y in x[0]], x[1][batch_entry], [y[batch_entry].expand_dims(axis=0) for y in x[2]]) for x in all_candidates]
ordered[batch_entry] = sorted(batchCandidate, key=lambda tup: tup[1].asscalar())
if batch_entry == 0:
newSequences = ordered[batch_entry]
elif batch_entry < (batch_size -1):
newSequences = [([mx.nd.concat(newSequences[x][0][y], ordered[batch_entry][x][0][y], dim=0) for y in range(len(newSequences[x][0]))], mx.nd.concat(newSequences[x][1], ordered[batch_entry][x][1], dim=0)) for x in range(len(newSequences))]
# expand dims only once
else:
newSequences = [([mx.nd.concat(newSequences[x][0][y], ordered[batch_entry][x][0][y], dim=0).expand_dims(axis=1) for y in range(len(newSequences[x][0]))], mx.nd.concat(newSequences[x][1], ordered[batch_entry][x][1], dim=0).expand_dims(axis=1)) for x in range(len(newSequences))]
newSequences = [([mx.nd.concat(newSequences[x][0][y], ordered[batch_entry][x][0][y], dim=0) for y in range(len(newSequences[x][0]))], mx.nd.concat(newSequences[x][1], ordered[batch_entry][x][1], dim=0), [mx.nd.concat(newSequences[x][2][y], ordered[batch_entry][x][2][y], dim=0) for y in range(len(newSequences[x][2]))]) for x in range(len(newSequences))]
newSequences = [([newSequences[x][0][y].expand_dims(axis=1) for y in range(len(newSequences[x][0]))], newSequences[x][1].expand_dims(axis=1), [newSequences[x][2][y] for y in range(len(newSequences[x][2]))]) for x in range(len(newSequences))]
sequences = newSequences[:][:k]
......@@ -85,6 +87,9 @@
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
${outputName} = sequences[0][0][i]
outputs.append(${outputName})
<#if tc.isAttentionNetwork()>
attentionList.append(sequences[0][2][i])
</#if>
</#if>
</#list>
<#else>
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment