Skip to content
Snippets Groups Projects
Commit ed7b7277 authored by Christian Fuß's avatar Christian Fuß
Browse files

added parts of BeamSearch for C++. Fixed bug with LSTM states not working as network inputs

parent 2cfa59f0
No related branches found
No related tags found
1 merge request: !23 "Added Unroll-related features and layers"
Pipeline #184436 failed
......@@ -84,6 +84,9 @@ public class CNNArch2Gluon extends CNNArchGenerator {
temp = controller.process("CNNSupervisedTrainer", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = controller.process("BeamSearch", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
temp = controller.process("execute", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
......
......@@ -155,8 +155,28 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return inputNames;
}
public Collection<List<String>> getStreamInputDimensions(SerialCompositeElementSymbol stream) {
return getStreamInputs(stream).values();
/**
 * Returns the input dimensions for every input of the given stream.
 *
 * @param stream       the stream whose inputs are inspected
 * @param useStateDim  if true, the raw dimensions from getStreamInputs are
 *                     returned unchanged (including state inputs' own dims);
 *                     if false, each "_state_" input is replaced by the
 *                     marker pair ["-1", "<name>_output_"] so templates can
 *                     emit a begin_state(...) call instead of a zero tensor.
 * @return one dimension list per stream input, in input order
 */
public Collection<List<String>> getStreamInputDimensions(SerialCompositeElementSymbol stream, boolean useStateDim) {
    // Compute the input map once instead of three times.
    Map<String, List<String>> streamInputs = getStreamInputs(stream);
    if (useStateDim) {
        return streamInputs.values();
    }
    List<List<String>> result = new ArrayList<>(streamInputs.size());
    // Iterate entries directly; avoids the fragile index-based pairing of
    // keySet() with a separately materialized values() list.
    for (Map.Entry<String, List<String>> input : streamInputs.entrySet()) {
        String name = input.getKey();
        if (name.endsWith("_state_")) {
            // Marker pair consumed by the templates: "-1" flags a state
            // input, second element names the layer producing the state.
            List<String> dim = new ArrayList<>();
            dim.add("-1");
            dim.add(name.replace("_state_", "_output_"));
            result.add(dim);
        } else {
            result.add(input.getValue());
        }
    }
    return result;
}
public Set<String> getStreamOutputNames(SerialCompositeElementSymbol stream) {
......@@ -188,11 +208,23 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
// Used to initialize all layer variable members which are passed through the networks
public Map<String, List<String>> getLayerVariableMembers(String batchSize) {
Map<String, List<String>> members = new LinkedHashMap<>();
// Collects every layer variable member across all streams and maps the member
// name to a list of lists: the member's dimension lists followed by a
// single-element list holding the index of the stream it belongs to.
// Templates read [member][0] (dimensions) and [member][1][0] (stream index).
public Map<String, List<List<String>>> getLayerVariableMembers(String batchSize) {
Map<String, List<List<String>>> members = new LinkedHashMap<>();
int index = 0;
for (SerialCompositeElementSymbol stream : getArchitecture().getStreams()) {
// Accumulates [dims, streamIndex] pairs for ALL members of this stream.
List<List<String>> value = new ArrayList<>();
Map<String, List<String>> member = getStreamLayerVariableMembers(stream, batchSize, true);
for (List<String> entry: member.values()){
value.add(entry);
ArrayList<String> streamIndex = new ArrayList<String>();
streamIndex.add(Integer.toString(index));
value.add(streamIndex);
}
// NOTE(review): every member name of this stream is mapped to the SAME
// shared `value` list, which interleaves the dims and stream indices of
// all members. The [member][0]/[member][1][0] template accesses are only
// consistent when a stream has exactly one layer variable member —
// confirm this invariant, otherwise dims and indices get mixed up.
for(String name: member.keySet()){
members.put(name, value);
}
index++;
}
return members;
......@@ -251,7 +283,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
private Map<String, List<String>> getStreamLayerVariableMembers(SerialCompositeElementSymbol stream, String batchSize, boolean includeOutput) {
Map<String, List<String>> members = new HashMap<>();
Map<String, List<String>> members = new LinkedHashMap<>();
List<ArchitectureElementSymbol> elements = stream.getSpannedScope().resolveLocally(ArchitectureElementSymbol.KIND);
for (ArchitectureElementSymbol element : elements) {
......
// Recursively expands the `width` most probable candidates per step until
// `maxDepth` is reached, tracking the best accumulated probability seen so far
// and returning the corresponding output (work in progress).
//
// NOTE(review): only the signature, the outer while-loop and the return are
// C++; everything from "for batchEntry in input:" onward is Python/MXNet code
// (mx.nd.topk, self._networks, range(...)) carried over from the Python
// implementation and will NOT compile — TODO port the body to the C++
// predictor API before enabling this.
vector<float> applyBeamSearch(vector<float> input, int depth, int width, int maxDepth, double currProb, int netIndex, vector<float> bestOutput)
{
double bestProb = 0.0;
while (depth < maxDepth){
depth ++;
int batchIndex = 0;
// --- begin unported Python pseudo-code ---
for batchEntry in input:
top_k_indices = mx.nd.topk(batchEntry, axis=0, k=width);
top_k_values = mx.nd.topk(batchEntry, ret_typ="value", axis=0, k=width);
for index in range(top_k_indices.size):
/*print mx.nd.array(top_k_indices[index]) */
/*print top_k_values[index] */
if depth == 1:
/*print mx.nd.array(top_k_indices[index]) */
// First step: seed bestOutput with the network's own prediction.
result = applyBeamSearch(self._networks[netIndex](mx.nd.array(top_k_indices[index])), depth, width, maxDepth,
currProb * top_k_values[index], netIndex, self._networks[netIndex](mx.nd.array(top_k_indices[index])));
else:
result = applyBeamSearch(self._networks[netIndex](mx.nd.array(top_k_indices[index])), depth, width, maxDepth,
currProb * top_k_values[index], netIndex, bestOutput);
if depth == maxDepth:
/*print currProb */
// Keep the candidate only if its accumulated probability beats
// the best one found so far.
if currProb > bestProb:
bestProb = currProb;
bestOutput[batchIndex] = result[batchIndex];
/*print "new bestOutput: ", bestOutput */
batchIndex ++;
// --- end unported Python pseudo-code ---
}
/*print bestOutput; */
/*cout << bestProb; */
return bestOutput;
}
\ No newline at end of file
......@@ -57,7 +57,8 @@ class ${tc.fileNameWithoutEnding}:
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std)
self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${networkInstruction?index}].hybridize()
self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)<#sep>, </#list>)
self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body, false) as dimensions>
<#if dimensions[0] == "-1">self.networks[${networkInstruction?index}].${dimensions[1]}.begin_state(batch_size=1, ctx=context)<#else>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)</#if> <#sep>, </#list>)
</#if>
</#list>
......
......@@ -22,7 +22,7 @@ public:
<#list tc.getStreamInputNames(networkInstruction.body) as variable>"data${variable?index}"<#sep>, </#list>
</#if>
};
const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getStreamInputDimensions(networkInstruction.body, true) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
const bool use_gpu = false;
PredictorHandle handle;
......
......@@ -133,7 +133,7 @@ class ${tc.fileNameWithoutEnding}:
train_iter.reset()
for batch_i, batch in enumerate(train_iter):
<#list tc.architectureInputs as input_name>
${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
${input_name} = batch.data[0].as_in_context(mx_context)
</#list>
<#list tc.architectureOutputs as output_name>
${output_name}label = batch.label[${output_name?index}].as_in_context(mx_context)
......@@ -172,7 +172,7 @@ class ${tc.fileNameWithoutEnding}:
metric = mx.metric.create(eval_metric)
for batch_i, batch in enumerate(train_iter):
<#list tc.architectureInputs as input_name>
${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
${input_name} = batch.data[0].as_in_context(mx_context)
</#list>
labels = [
......@@ -239,7 +239,7 @@ class ${tc.fileNameWithoutEnding}:
metric = mx.metric.create(eval_metric)
for batch_i, batch in enumerate(test_iter):
<#list tc.architectureInputs as input_name>
${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
${input_name} = batch.data[0].as_in_context(mx_context)
</#list>
labels = [
......
......@@ -2,7 +2,7 @@
vector<float> ${tc.getName(input)} = CNNTranslator::translate(${input.name}<#if input.arrayAccess.isPresent()>[${input.arrayAccess.get().intValue.get()?c}]</#if>);
</#list>
<#list tc.getLayerVariableMembers("1")?keys as member>
vector<float> ${member}(${tc.join(tc.getLayerVariableMembers("1")[member], " * ")})
vector<float> ${member}(${tc.join(tc.getLayerVariableMembers("1")[member][0], " * ")});
</#list>
<#list tc.architectureOutputSymbols as output>
......@@ -18,7 +18,7 @@
<#if networkInstruction.body.isTrainable()>
_predictor_${networkInstruction?index}_.predict(${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")}, ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")});
<#else>
${tc.include(networkInstruction.body, "CPP_INLINE")}
<#-- ${tc.include(networkInstruction.body, "CPP_INLINE")}; -->
</#if>
</#if>
</#list>
......
<#list tc.getLayerVariableMembers("batch_size")?keys as member>
${member} = mx.nd.zeros((${tc.join(tc.getLayerVariableMembers("batch_size")[member], ", ")},), ctx=mx_context)
<#if member?ends_with("_state_")>
encoder_state_ = self._networks[${tc.getLayerVariableMembers("batch_size")[member][1][0]}].${member?replace("_state_","_output_")}.begin_state(batch_size=0, ctx=mx_context)
<#else>
${member} = mx.nd.zeros((${tc.join(tc.getLayerVariableMembers("batch_size")[member][0], ", ")},), ctx=mx_context)
</#if>
</#list>
<#list tc.architectureOutputSymbols as output>
${tc.getName(output)} = mx.nd.zeros((batch_size, ${tc.join(output.ioDeclaration.type.dimensions, ", ")},), ctx=mx_context)
......
<#list tc.getLayerVariableMembers("batch_size")?keys as member>
${member} = mx.nd.zeros((${tc.join(tc.getLayerVariableMembers("batch_size")[member], ", ")},), ctx=mx_context)
<#if member?ends_with("_state_")>
encoder_state_ = self._networks[${tc.getLayerVariableMembers("batch_size")[member][1][0]}].${member?replace("_state_","_output_")}.begin_state(batch_size=0, ctx=mx_context)
<#else>
${member} = mx.nd.zeros((${tc.join(tc.getLayerVariableMembers("batch_size")[member][0], ", ")},), ctx=mx_context)
</#if>
</#list>
<#list tc.architectureOutputSymbols as output>
${tc.getName(output)} = mx.nd.zeros((batch_size, ${tc.join(output.ioDeclaration.type.dimensions, ", ")},), ctx=mx_context)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment