Commit 25b29329 authored by Christian Fuß's avatar Christian Fuß
Browse files

added BeamSearch functionality, not yet fully working

parent 9f516b40
Pipeline #183247 failed with stages
......@@ -25,6 +25,7 @@ import de.monticore.lang.monticar.cnnarch.generator.CNNArchTemplateController;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import java.io.Writer;
import java.util.*;
......@@ -315,4 +316,8 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return members;
}
/**
 * Returns the beam-search width configured on the given unroll instruction.
 *
 * <p>The width is read from the instruction's {@code AllPredefinedLayers.WIDTH_NAME}
 * argument.
 *
 * @param unroll the unroll instruction whose width argument is read
 * @return the integer value of the width argument
 * @throws NoSuchElementException if the instruction provides no integer value for the
 *     width argument
 */
public int getBeamSearchWidth(UnrollInstructionSymbol unroll){
    // Same exception type as a bare Optional.get(), but with a message that names the
    // missing argument so a misconfigured architecture is easy to diagnose.
    return unroll.getIntValue(AllPredefinedLayers.WIDTH_NAME)
            .orElseThrow(() -> new NoSuchElementException(
                    "Unroll instruction has no integer '" + AllPredefinedLayers.WIDTH_NAME
                            + "' argument; cannot determine the beam search width"));
}
}
......@@ -182,6 +182,31 @@ class ${tc.fileNameWithoutEnding}:
]
#TODO still needs testing, currently one path will always end up with p ~ 1.0
# Recursive beam-search helper (work in progress, see the TODO directly above).
# NOTE(review): the indentation of this block is flattened in this view, so the nesting
# of the loops and branches described below is inferred from statement order -- confirm
# against the actual template before relying on it.
#
# Arguments:
#   input      -- iterable of network outputs; every element ("beam") is passed to mx.nd.topk
#   depth      -- current search depth; the generated call site starts with 0
#   max_width  -- bound compared against depth (the configured beam search width at the call site)
#   currProb   -- probability accumulated so far; the generated call site starts with 1.0
#   netIndex   -- index into self._networks selecting the network that is re-applied
#   bestOutput -- output returned when no better candidate replaces it
# Returns bestOutput.
# NOTE(review): `self` is not a parameter here; presumably it is captured from an
# enclosing method -- verify. `input` also shadows the Python builtin of the same name.
def applyBeamSearch(input, depth, max_width, currProb, netIndex, bestOutput):
# Local to this call: it is not passed into or returned from the recursive calls.
bestProb = 0.0
while depth < max_width:
depth += 1
for beam in input:
# Top-2 entries of the beam along axis 0, using the default ret_typ
# (indices according to the MXNet docs -- confirm).
top_k = mx.nd.topk(beam, axis=0, k=2)
# Same top-2 selection, but returning the values; they are multiplied into currProb below.
top_k_values = mx.nd.topk(beam, ret_typ='value', axis=0, k=2)
for index in range(top_k.size):
#print mx.nd.array(top_k[index])
#print mx.nd.array(top_k_values[index])
# At depth 1 the network output for this candidate is passed both as the new input and
# as bestOutput, so the network is invoked twice with the same argument.
if depth == 1:
result = applyBeamSearch(self._networks[netIndex](mx.nd.array(top_k[index])), depth, max_width, currProb * top_k_values[index], netIndex, self._networks[netIndex](mx.nd.array(top_k[index])))
else:
# Deeper levels hand the caller's bestOutput down unchanged.
result = applyBeamSearch(self._networks[netIndex](mx.nd.array(top_k[index])), depth, max_width, currProb * top_k_values[index], netIndex, bestOutput)
# At the final depth, keep `result` if this call's currProb beats the local bestProb.
# Note that the comparison uses currProb, not the product passed to the recursive call.
if depth == max_width:
#print currProb
if currProb > bestProb:
bestProb = currProb
bestOutput = result
#print bestOutput
#print bestProb
return bestOutput
if True: <#-- Fix indentation -->
<#include "pythonExecute.ftl">
......@@ -197,7 +222,6 @@ class ${tc.fileNameWithoutEnding}:
else:
predictions.append(output_name)
<#include "elements/BeamSearch.ftl">
metric.update(preds=predictions, labels=labels)
train_metric_score = metric.get()[1]
......
......@@ -9,9 +9,13 @@
<#if networkInstruction.isUnroll()>
<#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody>
${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")})
<#-- For an unroll instruction named "BeamSearch": collect the stream outputs just produced by the network into `input` and overwrite the first stream output with the result of applyBeamSearch, started at depth 0 with the configured beam width, probability 1.0 and `input` as the initial best output. -->
<#if networkInstruction.name == "BeamSearch">
input = ${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")}
${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]} = applyBeamSearch(input, 0, ${tc.getBeamSearchWidth(networkInstruction.toUnrollInstruction())}, 1.0, ${networkInstruction?index}, input)
</#if>
<#list resolvedBody.elements as element>
<#if element.name == "ArgMax">
${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = mx.nd.argmax(${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")}, axis=1)
${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]} = mx.nd.argmax(${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}, axis=1)
</#if>
</#list>
</#list>
......@@ -20,7 +24,7 @@
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")})
<#list networkInstruction.body.elements as element>
<#if element.name == "ArgMax">
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = mx.nd.argmax(${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")}, axis=1)
${tc.getStreamOutputNames(networkInstruction.body)[0]} = mx.nd.argmax(${tc.getStreamOutputNames(networkInstruction.body)[0]}, axis=1)
</#if>
</#list>
<#else>
......
......@@ -13,17 +13,19 @@
lossList.append(loss_function(${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}, ${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}label))
<#list resolvedBody.elements as element>
<#if element.name == "ArgMax">
${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = mx.nd.argmax(${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")}, axis=1)
${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]} = mx.nd.argmax(${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}, axis=1)
</#if>
</#list>
</#list>
<#else>
<#if networkInstruction.body.isTrainable()>
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")})
lossList.append(loss_function(${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}, ${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}label))
<#if tc.getStreamOutputNames(networkInstruction.body)[0] != "encoder_output_">
lossList.append(loss_function(${tc.getStreamOutputNames(networkInstruction.body)[0]}, ${tc.getStreamOutputNames(networkInstruction.body)[0]}label))
</#if>
<#list networkInstruction.body.elements as element>
<#if element.name == "ArgMax">
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = mx.nd.argmax(${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")}, axis=1)
${tc.getStreamOutputNames(networkInstruction.body)[0]} = mx.nd.argmax(${tc.getStreamOutputNames(networkInstruction.body)[0]}, axis=1)
</#if>
</#list>
<#else>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment