Commit 9e3f0875 authored by Sebastian N.'s avatar Sebastian N.
Browse files

Merge branch 'rnn' of...

Merge branch 'rnn' of git.rwth-aachen.de:monticore/EmbeddedMontiArc/generators/CNNArch2Gluon into rnn
parents c1c4c3ff f80449ff
......@@ -137,13 +137,25 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(architectureElement, getWriter(), netDefinitionMode);
}
public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream) {
return getStreamInputs(stream).keySet();
public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream, boolean addStateIndex) {
if(addStateIndex) {
Set<String> names = getStreamInputs(stream, addStateIndex).keySet();
Set<String> newNames = new LinkedHashSet<>();
for (String name : names) {
// if LSTM state, transform name into list of hidden state and cell state
if (name.endsWith("_state_")) {
name = "[" + name + "[0], " + name + "[1]]";
}
newNames.add(name);
}
return newNames;
}
return getStreamInputs(stream, addStateIndex).keySet();
}
// used for unroll
public List<String> getStreamInputNames(SerialCompositeElementSymbol stream, SerialCompositeElementSymbol currentStream) {
List<String> inputNames = new LinkedList<>(getStreamInputNames(stream));
public List<String> getStreamInputNames(SerialCompositeElementSymbol stream, SerialCompositeElementSymbol currentStream, boolean addStateIndex) {
List<String> inputNames = new LinkedList<>(getStreamInputNames(stream, addStateIndex));
Map<String, String> pairs = getUnrollPairs(stream, currentStream);
for (int i = 0; i != inputNames.size(); ++i) {
......@@ -157,19 +169,19 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
public Collection<List<String>> getStreamInputDimensions(SerialCompositeElementSymbol stream, boolean useStateDim) {
if(useStateDim) {
return getStreamInputs(stream).values();
return getStreamInputs(stream, false).values();
}else{
Set<String> names = getStreamInputs(stream).keySet();
List<List<String>> dims = new ArrayList<List<String>>(getStreamInputs(stream).values());
Set<String> names = getStreamInputs(stream, true).keySet();
List<List<String>> dims = new ArrayList<List<String>>(getStreamInputs(stream, false).values());
List<List<String>> result = new ArrayList<List<String>>();
int index = 0;
for (String name : names) {
if (name.endsWith("_state_")) {
if (name.endsWith("_state_") || name.endsWith("_state_[0]")) {
ArrayList dim = new ArrayList<String>();
dim.add("-1");
dim.add(name.replace("_state_", "_output_"));
dim.add(name.replace("_state_", "_output_.begin_state(batch_size=1, ctx=context)"));
result.add(dim);
} else {
}else{
result.add(dims.get(index));
}
index++;
......@@ -188,7 +200,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
}
outputNames.addAll(getStreamLayerVariableMembers(stream, "1", true, false).keySet());
outputNames.addAll(getStreamLayerVariableMembers(stream, "1", true, false, false).keySet());
return outputNames;
}
......@@ -214,7 +226,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
int index = 0;
for (SerialCompositeElementSymbol stream : getArchitecture().getStreams()) {
List<List<String>> value = new ArrayList<>();
Map<String, List<String>> member = getStreamLayerVariableMembers(stream, batchSize, true, includeStates);
Map<String, List<String>> member = getStreamLayerVariableMembers(stream, batchSize, true, includeStates, false);
for (List<String> entry: member.values()){
value.add(entry);
ArrayList<String> streamIndex = new ArrayList<String>();
......@@ -260,7 +272,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return pairs;
}
private Map<String, List<String>> getStreamInputs(SerialCompositeElementSymbol stream) {
private Map<String, List<String>> getStreamInputs(SerialCompositeElementSymbol stream, boolean addStateIndex) {
Map<String, List<String>> inputs = new LinkedHashMap<>();
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
......@@ -279,12 +291,12 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
}
inputs.putAll(getStreamLayerVariableMembers(stream, "1", false, false));
inputs.putAll(getStreamLayerVariableMembers(stream, "1", false, false, addStateIndex));
return inputs;
}
private Map<String, List<String>> getStreamLayerVariableMembers(SerialCompositeElementSymbol stream, String batchSize, boolean includeOutput, boolean includeStates) {
private Map<String, List<String>> getStreamLayerVariableMembers(SerialCompositeElementSymbol stream, String batchSize, boolean includeOutput, boolean includeStates, boolean addStateIndex) {
Map<String, List<String>> members = new LinkedHashMap<>();
List<ArchitectureElementSymbol> elements = stream.getSpannedScope().resolveLocally(ArchitectureElementSymbol.KIND);
......@@ -300,7 +312,12 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
(PredefinedLayerDeclaration) layerVariableDeclaration.getLayer().getDeclaration();
if (predefinedLayerDeclaration.isValidMember(VariableSymbol.Member.STATE)) {
String name = variable.getName() + "_state_";
String name;
if(addStateIndex && predefinedLayerDeclaration.getName().equals(AllPredefinedLayers.GRU_NAME)){
name = variable.getName() + "_state_[0]";
}else{
name = variable.getName() + "_state_";
}
List<Integer> intDimensions = predefinedLayerDeclaration.computeOutputTypes(
layerVariableDeclaration.getLayer().getInputTypes(),
......
......@@ -58,7 +58,7 @@ class ${tc.fileNameWithoutEnding}:
self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${networkInstruction?index}].hybridize()
self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body, false) as dimensions>
<#if dimensions[0] == "-1">self.networks[${networkInstruction?index}].${dimensions[1]}.begin_state(batch_size=1, ctx=context)[0]<#else>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)</#if> <#sep>, </#list>)
<#if dimensions[0] == "-1">self.networks[${networkInstruction?index}].${dimensions[1]}<#else>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)</#if> <#sep>, </#list>)
</#if>
</#list>
......
......@@ -47,6 +47,7 @@ class ${tc.fileNameWithoutEnding}:
test_label[index] = test_h5[output_name]
index += 1
test_iter = mx.io.NDArrayIter(data=test_data,
label=test_label,
batch_size=batch_size)
......
......@@ -131,7 +131,7 @@ class Net_${networkInstruction?index}(gluon.HybridBlock):
with self.name_scope():
${tc.include(networkInstruction.body, "ARCHITECTURE_DEFINITION")}
def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")}):
def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ")}):
${tc.include(networkInstruction.body, "FORWARD_FUNCTION")}
return ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")}
......
......@@ -16,10 +16,10 @@ public:
const std::string json_file = "model/${tc.componentName}/model_${networkInstruction?index}_newest-symbol.json";
const std::string param_file = "model/${tc.componentName}/model_${networkInstruction?index}_newest-0000.params";
const std::vector<std::string> input_keys = {
<#if tc.getStreamInputNames(networkInstruction.body)?size == 1>
<#if tc.getStreamInputNames(networkInstruction.body, false)?size == 1>
"data"
<#else>
<#list tc.getStreamInputNames(networkInstruction.body) as variable>"data${variable?index}"<#sep>, </#list>
<#list tc.getStreamInputNames(networkInstruction.body, false) as variable>"data${variable?index}"<#sep>, </#list>
</#if>
};
const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getStreamInputDimensions(networkInstruction.body, true) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
......@@ -35,9 +35,9 @@ public:
if(handle) MXPredFree(handle);
}
void predict(${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ", "const std::vector<float> &in_", "")},
void predict(${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ", "const std::vector<float> &in_", "")},
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ", "std::vector<float> &out_", "")}){
<#list tc.getStreamInputNames(networkInstruction.body) as variable>
<#list tc.getStreamInputNames(networkInstruction.body, false) as variable>
MXPredSetInput(handle, input_keys[${variable?index}].c_str(), in_${variable}.data(), static_cast<mx_uint>(in_${variable}.size()));
</#list>
......
......@@ -279,7 +279,7 @@ class ${tc.fileNameWithoutEnding}:
train_iter.reset()
for batch_i, batch in enumerate(train_iter):
<#list tc.architectureInputs as input_name>
${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
${input_name} = batch.data[0].as_in_context(mx_context)
</#list>
<#list tc.architectureOutputs as output_name>
${output_name}label = batch.label[${output_name?index}].as_in_context(mx_context)
......@@ -318,7 +318,7 @@ class ${tc.fileNameWithoutEnding}:
metric = mx.metric.create(eval_metric)
for batch_i, batch in enumerate(train_iter):
<#list tc.architectureInputs as input_name>
${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
${input_name} = batch.data[0].as_in_context(mx_context)
</#list>
labels = [
......@@ -348,7 +348,7 @@ class ${tc.fileNameWithoutEnding}:
metric = mx.metric.create(eval_metric)
for batch_i, batch in enumerate(test_iter):
<#list tc.architectureInputs as input_name>
${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
${input_name} = batch.data[0].as_in_context(mx_context)
</#list>
labels = [
......
......@@ -6,11 +6,7 @@
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input})
<#elseif mode == "PYTHON_INLINE">
<#if input?ends_with("_state_")>
${element.name} = mx.nd.split(data=${input}[0], axis=0, num_outputs=${num_outputs})
<#else>
${element.name} = mx.nd.split(data=${input}, axis=0, num_outputs=${num_outputs})
</#if>
<#elseif mode == "CPP_INLINE">
${element.name} = ${input}
</#if>
\ No newline at end of file
......@@ -12,11 +12,11 @@
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.isUnroll()>
<#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody>
_predictor_${networkInstruction?index}_.predict(${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")}, ${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")});
_predictor_${networkInstruction?index}_.predict(${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody, false), ", ")}, ${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")});
</#list>
<#else>
<#if networkInstruction.body.isTrainable()>
_predictor_${networkInstruction?index}_.predict(${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")}, ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")});
_predictor_${networkInstruction?index}_.predict(${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ")}, ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")});
<#else>
<#-- ${tc.include(networkInstruction.body, "CPP_INLINE")}; -->
</#if>
......
......@@ -14,12 +14,12 @@
<#if networkInstruction.isUnroll()>
<#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody>
<#if networkInstruction.name == "BeamSearch">
input = ${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")}
input = ${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody, true), ", ")}
<#assign length = tc.getBeamSearchLength(networkInstruction.toUnrollInstruction())>
<#assign width = tc.getBeamSearchWidth(networkInstruction.toUnrollInstruction())>
${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]} = applyBeamSearch(input, 0, ${length}, ${width}, 1.0, ${networkInstruction?index}, input)
<#else>
${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")?replace("_state_","_state_[0]")})
${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody, true), ", ")?replace("_state_","_state_")})
<#if !(tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]?ends_with("_output_"))>
outputs.append(${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]})
</#if>
......@@ -32,7 +32,7 @@
</#list>
<#else>
<#if networkInstruction.body.isTrainable()>
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")?replace("_state_","_state_[0]")})
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, true), ", ")?replace("_state_","_state_")})
<#if !(tc.getStreamOutputNames(networkInstruction.body)[0]?ends_with("_output_"))>
outputs.append(${tc.getStreamOutputNames(networkInstruction.body)[0]})
</#if>
......
......@@ -13,7 +13,7 @@
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.isUnroll()>
<#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody>
${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")?replace("_state_","_state_[0]")})
${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody, true), ", ")?replace("_state_","_state_")})
lossList.append(loss_function(${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}, ${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}label))
<#list resolvedBody.elements as element>
<#if element.name == "ArgMax">
......@@ -23,7 +23,7 @@
</#list>
<#else>
<#if networkInstruction.body.isTrainable()>
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")?replace("_state_","_state_[0]")})
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, true), ", ")?replace("_state_","_state_")})
<#if !(tc.getStreamOutputNames(networkInstruction.body)[0]?ends_with("_output_"))>
lossList.append(loss_function(${tc.getStreamOutputNames(networkInstruction.body)[0]}, ${tc.getStreamOutputNames(networkInstruction.body)[0]}label))
</#if>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment