Skip to content
Snippets Groups Projects
Commit 1c942e2a authored by Christian Fuß's avatar Christian Fuß
Browse files

made batch sizes > 1 possible for Beamsearch

parent 8aa46be1
No related branches found
No related tags found
1 merge request!25Adjusted BeamSearch to work with batch sizes > 1
......@@ -214,7 +214,7 @@ class ${tc.fileNameWithoutEnding}:
del optimizer_params['learning_rate_decay']
train_batch_size = batch_size
test_batch_size = ${tc.hasUnrollInstructions()?then('1', 'batch_size')}
test_batch_size = ${tc.hasUnrollInstructions()?then('batch_size', 'batch_size')}
train_iter, train_test_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_data(train_batch_size, test_batch_size)
......
......@@ -30,7 +30,7 @@
k = ${tc.getBeamSearchWidth(networkInstruction)}
<#list tc.getUnrollInputNames(networkInstruction, "1") as inputName>
<#if tc.getNameWithoutIndex(inputName) == tc.outputName>
sequences = [([${inputName}], 1.0)]
sequences = [([${inputName}], mx.nd.full((batch_size, 1,), 1.0, ctx=mx_context))]
</#if>
</#list>
......@@ -55,14 +55,27 @@
</#if>
</#list>
topk = out.topk(k=k)[0]
topk = out.topk(k=k)
for j in topk:
candidate = (seq + [mx.nd.full((1, 1,), j, ctx=mx_context)], score * out[0][j].asscalar())
for top_index in range(len(topk[0])):
j = mx.nd.slice_axis(topk, axis=1, begin=top_index, end=top_index+1)
currentScore = mx.nd.slice_axis(out, axis=1, begin=top_index, end=top_index+1)
newScore = mx.nd.expand_dims(score.squeeze() * currentScore.squeeze(), axis=1)
candidate = (seq + [j], newScore)
all_candidates.append(candidate)
ordered = sorted(all_candidates, key=lambda tup: tup[1])
sequences = ordered[:k]
ordered = []
newSequences = []
for batch_entry in range(batch_size):
ordered.append([])
batchCandidate = [([y[batch_entry] for y in x[0]], x[1][batch_entry]) for x in all_candidates]
ordered[batch_entry] = sorted(batchCandidate, key=lambda tup: tup[1].asscalar())
if batch_entry == 0:
newSequences = ordered[batch_entry]
else:
newSequences = [([mx.nd.concat(newSequences[x][0][y], ordered[batch_entry][x][0][y], dim=0).expand_dims(axis=1) for y in range(len(newSequences[x][0]))], mx.nd.concat(newSequences[x][1], ordered[batch_entry][x][1], dim=0).expand_dims(axis=1)) for x in range(len(newSequences))]
sequences = newSequences[:][:k]
for i in range(1, ${tc.getBeamSearchMaxLength(networkInstruction)}):
<#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment