Commit f80449ff authored by Christian Fuß
Browse files

solved an issue with LSTMs not getting both hidden and cell state

parent 665375a8
Pipeline #193155 failed with stages in 18 seconds
...@@ -137,13 +137,25 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -137,13 +137,25 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(architectureElement, getWriter(), netDefinitionMode); include(architectureElement, getWriter(), netDefinitionMode);
} }
public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream) { public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream, boolean addStateIndex) {
return getStreamInputs(stream).keySet(); if(addStateIndex) {
Set<String> names = getStreamInputs(stream, addStateIndex).keySet();
Set<String> newNames = new LinkedHashSet<>();
for (String name : names) {
// if LSTM state, transform name into list of hidden state and cell state
if (name.endsWith("_state_")) {
name = "[" + name + "[0], " + name + "[1]]";
}
newNames.add(name);
}
return newNames;
}
return getStreamInputs(stream, addStateIndex).keySet();
} }
// used for unroll // used for unroll
public List<String> getStreamInputNames(SerialCompositeElementSymbol stream, SerialCompositeElementSymbol currentStream) { public List<String> getStreamInputNames(SerialCompositeElementSymbol stream, SerialCompositeElementSymbol currentStream, boolean addStateIndex) {
List<String> inputNames = new LinkedList<>(getStreamInputNames(stream)); List<String> inputNames = new LinkedList<>(getStreamInputNames(stream, addStateIndex));
Map<String, String> pairs = getUnrollPairs(stream, currentStream); Map<String, String> pairs = getUnrollPairs(stream, currentStream);
for (int i = 0; i != inputNames.size(); ++i) { for (int i = 0; i != inputNames.size(); ++i) {
...@@ -157,19 +169,19 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -157,19 +169,19 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
public Collection<List<String>> getStreamInputDimensions(SerialCompositeElementSymbol stream, boolean useStateDim) { public Collection<List<String>> getStreamInputDimensions(SerialCompositeElementSymbol stream, boolean useStateDim) {
if(useStateDim) { if(useStateDim) {
return getStreamInputs(stream).values(); return getStreamInputs(stream, false).values();
}else{ }else{
Set<String> names = getStreamInputs(stream).keySet(); Set<String> names = getStreamInputs(stream, true).keySet();
List<List<String>> dims = new ArrayList<List<String>>(getStreamInputs(stream).values()); List<List<String>> dims = new ArrayList<List<String>>(getStreamInputs(stream, false).values());
List<List<String>> result = new ArrayList<List<String>>(); List<List<String>> result = new ArrayList<List<String>>();
int index = 0; int index = 0;
for (String name : names) { for (String name : names) {
if (name.endsWith("_state_")) { if (name.endsWith("_state_") || name.endsWith("_state_[0]")) {
ArrayList dim = new ArrayList<String>(); ArrayList dim = new ArrayList<String>();
dim.add("-1"); dim.add("-1");
dim.add(name.replace("_state_", "_output_")); dim.add(name.replace("_state_", "_output_.begin_state(batch_size=1, ctx=context)"));
result.add(dim); result.add(dim);
} else { }else{
result.add(dims.get(index)); result.add(dims.get(index));
} }
index++; index++;
...@@ -188,7 +200,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -188,7 +200,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
} }
} }
outputNames.addAll(getStreamLayerVariableMembers(stream, "1", true, false).keySet()); outputNames.addAll(getStreamLayerVariableMembers(stream, "1", true, false, false).keySet());
return outputNames; return outputNames;
} }
...@@ -214,7 +226,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -214,7 +226,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
int index = 0; int index = 0;
for (SerialCompositeElementSymbol stream : getArchitecture().getStreams()) { for (SerialCompositeElementSymbol stream : getArchitecture().getStreams()) {
List<List<String>> value = new ArrayList<>(); List<List<String>> value = new ArrayList<>();
Map<String, List<String>> member = getStreamLayerVariableMembers(stream, batchSize, true, includeStates); Map<String, List<String>> member = getStreamLayerVariableMembers(stream, batchSize, true, includeStates, false);
for (List<String> entry: member.values()){ for (List<String> entry: member.values()){
value.add(entry); value.add(entry);
ArrayList<String> streamIndex = new ArrayList<String>(); ArrayList<String> streamIndex = new ArrayList<String>();
...@@ -260,7 +272,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -260,7 +272,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return pairs; return pairs;
} }
private Map<String, List<String>> getStreamInputs(SerialCompositeElementSymbol stream) { private Map<String, List<String>> getStreamInputs(SerialCompositeElementSymbol stream, boolean addStateIndex) {
Map<String, List<String>> inputs = new LinkedHashMap<>(); Map<String, List<String>> inputs = new LinkedHashMap<>();
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) { for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
...@@ -279,12 +291,12 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -279,12 +291,12 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
} }
} }
inputs.putAll(getStreamLayerVariableMembers(stream, "1", false, false)); inputs.putAll(getStreamLayerVariableMembers(stream, "1", false, false, addStateIndex));
return inputs; return inputs;
} }
private Map<String, List<String>> getStreamLayerVariableMembers(SerialCompositeElementSymbol stream, String batchSize, boolean includeOutput, boolean includeStates) { private Map<String, List<String>> getStreamLayerVariableMembers(SerialCompositeElementSymbol stream, String batchSize, boolean includeOutput, boolean includeStates, boolean addStateIndex) {
Map<String, List<String>> members = new LinkedHashMap<>(); Map<String, List<String>> members = new LinkedHashMap<>();
List<ArchitectureElementSymbol> elements = stream.getSpannedScope().resolveLocally(ArchitectureElementSymbol.KIND); List<ArchitectureElementSymbol> elements = stream.getSpannedScope().resolveLocally(ArchitectureElementSymbol.KIND);
...@@ -300,7 +312,12 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -300,7 +312,12 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
(PredefinedLayerDeclaration) layerVariableDeclaration.getLayer().getDeclaration(); (PredefinedLayerDeclaration) layerVariableDeclaration.getLayer().getDeclaration();
if (predefinedLayerDeclaration.isValidMember(VariableSymbol.Member.STATE)) { if (predefinedLayerDeclaration.isValidMember(VariableSymbol.Member.STATE)) {
String name = variable.getName() + "_state_"; String name;
if(addStateIndex && predefinedLayerDeclaration.getName().equals(AllPredefinedLayers.GRU_NAME)){
name = variable.getName() + "_state_[0]";
}else{
name = variable.getName() + "_state_";
}
List<Integer> intDimensions = predefinedLayerDeclaration.computeOutputTypes( List<Integer> intDimensions = predefinedLayerDeclaration.computeOutputTypes(
layerVariableDeclaration.getLayer().getInputTypes(), layerVariableDeclaration.getLayer().getInputTypes(),
......
...@@ -58,7 +58,7 @@ class ${tc.fileNameWithoutEnding}: ...@@ -58,7 +58,7 @@ class ${tc.fileNameWithoutEnding}:
self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, ctx=context) self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[${networkInstruction?index}].hybridize() self.networks[${networkInstruction?index}].hybridize()
self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body, false) as dimensions> self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body, false) as dimensions>
<#if dimensions[0] == "-1">self.networks[${networkInstruction?index}].${dimensions[1]}.begin_state(batch_size=1, ctx=context)[0]<#else>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)</#if> <#sep>, </#list>) <#if dimensions[0] == "-1">self.networks[${networkInstruction?index}].${dimensions[1]}<#else>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)</#if> <#sep>, </#list>)
</#if> </#if>
</#list> </#list>
......
...@@ -47,6 +47,7 @@ class ${tc.fileNameWithoutEnding}: ...@@ -47,6 +47,7 @@ class ${tc.fileNameWithoutEnding}:
test_label[index] = test_h5[output_name] test_label[index] = test_h5[output_name]
index += 1 index += 1
test_iter = mx.io.NDArrayIter(data=test_data, test_iter = mx.io.NDArrayIter(data=test_data,
label=test_label, label=test_label,
batch_size=batch_size) batch_size=batch_size)
......
...@@ -131,7 +131,7 @@ class Net_${networkInstruction?index}(gluon.HybridBlock): ...@@ -131,7 +131,7 @@ class Net_${networkInstruction?index}(gluon.HybridBlock):
with self.name_scope(): with self.name_scope():
${tc.include(networkInstruction.body, "ARCHITECTURE_DEFINITION")} ${tc.include(networkInstruction.body, "ARCHITECTURE_DEFINITION")}
def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")}): def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ")}):
${tc.include(networkInstruction.body, "FORWARD_FUNCTION")} ${tc.include(networkInstruction.body, "FORWARD_FUNCTION")}
return ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} return ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")}
......
...@@ -16,10 +16,10 @@ public: ...@@ -16,10 +16,10 @@ public:
const std::string json_file = "model/${tc.componentName}/model_${networkInstruction?index}_newest-symbol.json"; const std::string json_file = "model/${tc.componentName}/model_${networkInstruction?index}_newest-symbol.json";
const std::string param_file = "model/${tc.componentName}/model_${networkInstruction?index}_newest-0000.params"; const std::string param_file = "model/${tc.componentName}/model_${networkInstruction?index}_newest-0000.params";
const std::vector<std::string> input_keys = { const std::vector<std::string> input_keys = {
<#if tc.getStreamInputNames(networkInstruction.body)?size == 1> <#if tc.getStreamInputNames(networkInstruction.body, false)?size == 1>
"data" "data"
<#else> <#else>
<#list tc.getStreamInputNames(networkInstruction.body) as variable>"data${variable?index}"<#sep>, </#list> <#list tc.getStreamInputNames(networkInstruction.body, false) as variable>"data${variable?index}"<#sep>, </#list>
</#if> </#if>
}; };
const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getStreamInputDimensions(networkInstruction.body, true) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>}; const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getStreamInputDimensions(networkInstruction.body, true) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
...@@ -35,9 +35,9 @@ public: ...@@ -35,9 +35,9 @@ public:
if(handle) MXPredFree(handle); if(handle) MXPredFree(handle);
} }
void predict(${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ", "const std::vector<float> &in_", "")}, void predict(${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ", "const std::vector<float> &in_", "")},
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ", "std::vector<float> &out_", "")}){ ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ", "std::vector<float> &out_", "")}){
<#list tc.getStreamInputNames(networkInstruction.body) as variable> <#list tc.getStreamInputNames(networkInstruction.body, false) as variable>
MXPredSetInput(handle, input_keys[${variable?index}].c_str(), in_${variable}.data(), static_cast<mx_uint>(in_${variable}.size())); MXPredSetInput(handle, input_keys[${variable?index}].c_str(), in_${variable}.data(), static_cast<mx_uint>(in_${variable}.size()));
</#list> </#list>
......
...@@ -168,7 +168,7 @@ class ${tc.fileNameWithoutEnding}: ...@@ -168,7 +168,7 @@ class ${tc.fileNameWithoutEnding}:
train_iter.reset() train_iter.reset()
for batch_i, batch in enumerate(train_iter): for batch_i, batch in enumerate(train_iter):
<#list tc.architectureInputs as input_name> <#list tc.architectureInputs as input_name>
${input_name} = batch.data[${input_name?index}].as_in_context(mx_context) ${input_name} = batch.data[0].as_in_context(mx_context)
</#list> </#list>
<#list tc.architectureOutputs as output_name> <#list tc.architectureOutputs as output_name>
${output_name}label = batch.label[${output_name?index}].as_in_context(mx_context) ${output_name}label = batch.label[${output_name?index}].as_in_context(mx_context)
...@@ -207,7 +207,7 @@ class ${tc.fileNameWithoutEnding}: ...@@ -207,7 +207,7 @@ class ${tc.fileNameWithoutEnding}:
metric = mx.metric.create(eval_metric) metric = mx.metric.create(eval_metric)
for batch_i, batch in enumerate(train_iter): for batch_i, batch in enumerate(train_iter):
<#list tc.architectureInputs as input_name> <#list tc.architectureInputs as input_name>
${input_name} = batch.data[${input_name?index}].as_in_context(mx_context) ${input_name} = batch.data[0].as_in_context(mx_context)
</#list> </#list>
labels = [ labels = [
...@@ -266,7 +266,7 @@ class ${tc.fileNameWithoutEnding}: ...@@ -266,7 +266,7 @@ class ${tc.fileNameWithoutEnding}:
metric = mx.metric.create(eval_metric) metric = mx.metric.create(eval_metric)
for batch_i, batch in enumerate(test_iter): for batch_i, batch in enumerate(test_iter):
<#list tc.architectureInputs as input_name> <#list tc.architectureInputs as input_name>
${input_name} = batch.data[${input_name?index}].as_in_context(mx_context) ${input_name} = batch.data[0].as_in_context(mx_context)
</#list> </#list>
labels = [ labels = [
......
...@@ -6,11 +6,7 @@ ...@@ -6,11 +6,7 @@
<#elseif mode == "FORWARD_FUNCTION"> <#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input}) ${element.name} = self.${element.name}(${input})
<#elseif mode == "PYTHON_INLINE"> <#elseif mode == "PYTHON_INLINE">
<#if input?ends_with("_state_")>
${element.name} = mx.nd.split(data=${input}[0], axis=0, num_outputs=${num_outputs})
<#else>
${element.name} = mx.nd.split(data=${input}, axis=0, num_outputs=${num_outputs}) ${element.name} = mx.nd.split(data=${input}, axis=0, num_outputs=${num_outputs})
</#if>
<#elseif mode == "CPP_INLINE"> <#elseif mode == "CPP_INLINE">
${element.name} = ${input} ${element.name} = ${input}
</#if> </#if>
\ No newline at end of file
...@@ -12,11 +12,11 @@ ...@@ -12,11 +12,11 @@
<#list tc.architecture.networkInstructions as networkInstruction> <#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.isUnroll()> <#if networkInstruction.isUnroll()>
<#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody> <#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody>
_predictor_${networkInstruction?index}_.predict(${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")}, ${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")}); _predictor_${networkInstruction?index}_.predict(${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody, false), ", ")}, ${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")});
</#list> </#list>
<#else> <#else>
<#if networkInstruction.body.isTrainable()> <#if networkInstruction.body.isTrainable()>
_predictor_${networkInstruction?index}_.predict(${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")}, ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")}); _predictor_${networkInstruction?index}_.predict(${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ")}, ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")});
<#else> <#else>
<#-- ${tc.include(networkInstruction.body, "CPP_INLINE")}; --> <#-- ${tc.include(networkInstruction.body, "CPP_INLINE")}; -->
</#if> </#if>
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
<#if networkInstruction.isUnroll()> <#if networkInstruction.isUnroll()>
<#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody> <#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody>
<#if networkInstruction.name == "BeamSearch"> <#if networkInstruction.name == "BeamSearch">
input = ${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")} input = ${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody, true), ", ")}
<#assign length = tc.getBeamSearchLength(networkInstruction.toUnrollInstruction())> <#assign length = tc.getBeamSearchLength(networkInstruction.toUnrollInstruction())>
<#assign width = tc.getBeamSearchWidth(networkInstruction.toUnrollInstruction())> <#assign width = tc.getBeamSearchWidth(networkInstruction.toUnrollInstruction())>
${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]} = applyBeamSearch(input, 0, ${length}, ${width}, 1.0, ${networkInstruction?index}, input) ${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]} = applyBeamSearch(input, 0, ${length}, ${width}, 1.0, ${networkInstruction?index}, input)
<#else> <#else>
${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")?replace("_state_","_state_[0]")}) ${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody, true), ", ")?replace("_state_","_state_")})
<#if !(tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]?ends_with("_output_"))> <#if !(tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]?ends_with("_output_"))>
outputs.append(${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}) outputs.append(${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]})
</#if> </#if>
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
</#list> </#list>
<#else> <#else>
<#if networkInstruction.body.isTrainable()> <#if networkInstruction.body.isTrainable()>
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")?replace("_state_","_state_[0]")}) ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, true), ", ")?replace("_state_","_state_")})
<#if !(tc.getStreamOutputNames(networkInstruction.body)[0]?ends_with("_output_"))> <#if !(tc.getStreamOutputNames(networkInstruction.body)[0]?ends_with("_output_"))>
outputs.append(${tc.getStreamOutputNames(networkInstruction.body)[0]}) outputs.append(${tc.getStreamOutputNames(networkInstruction.body)[0]})
</#if> </#if>
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
<#list tc.architecture.networkInstructions as networkInstruction> <#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.isUnroll()> <#if networkInstruction.isUnroll()>
<#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody> <#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody>
${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")?replace("_state_","_state_[0]")}) ${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody, true), ", ")?replace("_state_","_state_")})
lossList.append(loss_function(${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}, ${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}label)) lossList.append(loss_function(${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}, ${tc.getStreamOutputNames(networkInstruction.body, resolvedBody)[0]}label))
<#list resolvedBody.elements as element> <#list resolvedBody.elements as element>
<#if element.name == "ArgMax"> <#if element.name == "ArgMax">
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
</#list> </#list>
<#else> <#else>
<#if networkInstruction.body.isTrainable()> <#if networkInstruction.body.isTrainable()>
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")?replace("_state_","_state_[0]")}) ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, true), ", ")?replace("_state_","_state_")})
<#if !(tc.getStreamOutputNames(networkInstruction.body)[0]?ends_with("_output_"))> <#if !(tc.getStreamOutputNames(networkInstruction.body)[0]?ends_with("_output_"))>
lossList.append(loss_function(${tc.getStreamOutputNames(networkInstruction.body)[0]}, ${tc.getStreamOutputNames(networkInstruction.body)[0]}label)) lossList.append(loss_function(${tc.getStreamOutputNames(networkInstruction.body)[0]}, ${tc.getStreamOutputNames(networkInstruction.body)[0]}label))
</#if> </#if>
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment