Commit 720a8fe8 authored by Sebastian N.'s avatar Sebastian N.
Browse files

Fixed BeamSearch (now compiles)

parent 808a6be5
Pipeline #203352 failed with stages
in 18 seconds
......@@ -28,16 +28,16 @@
int k = ${tc.getBeamSearchWidth(networkInstruction)};
<#list tc.getUnrollInputNames(networkInstruction, "1") as inputName>
<#if tc.getNameWithoutIndex(inputName) == tc.outputName>
vector<pair<vector<vector<float>>, double>> sequences{make_pair(vector<float>{${inputName}}, 1.0)};
vector<pair<vector<vector<float>>, double>> sequences{make_pair(vector<vector<float>>{${inputName}}, 1.0)};
</#if>
</#list>
for (size_t i = 1; i < ${tc.getBeamSearchMaxLength(networkInstruction)}; ++i) {
vector<pair<vector<vector<float>>, double>> allCandidates;
for (const pair<vector<vector<float>>& pair : sequences) {
vector<vector<float>> seq = pair.first;
double score = pair.second;
for (const pair<vector<vector<float>>, double>& p : sequences) {
vector<vector<float>> seq = p.first;
double score = p.second;
<#list tc.getUnrollInputNames(networkInstruction, "i") as inputName>
<#if tc.getNameWithoutIndex(inputName) == tc.outputName>
......@@ -47,7 +47,7 @@
_predictor_${networkInstruction?index}_.predict(${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")}, ${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")});
<#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName>
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
vector<float>& out = ${outputName};
vector<float> out = ${outputName};
</#if>
</#list>
......@@ -58,12 +58,12 @@
sort(topk.begin(), topk.end(), [] (const pair<int, float>& p1, const pair<int, float>& p2) {
return p1.second > p2.second;
};
});
topk = vector<pair<int, float>>(topk.begin(), topk.begin() + std::min<int>(k, topk.size()));
for (const pair<int, float>& pair : topk) {
vector<vector<float>> currentSeq = seq;
currentSeq.push_back(vector<float>{pair.first});
currentSeq.push_back(vector<float>{(float) pair.first});
allCandidates.emplace_back(currentSeq, score * pair.second);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment