Commit 82c6d14f authored by Sebastian N.'s avatar Sebastian N.
Browse files

Implemented parameter passing for unroll

parent 9f39abf2
Pipeline #180626 failed with stages
in 45 seconds
......@@ -140,6 +140,20 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return getStreamInputs(stream).keySet();
}
// used for unroll
public List<String> getStreamInputNames(SerialCompositeElementSymbol stream, SerialCompositeElementSymbol currentStream) {
List<String> inputNames = new LinkedList<>(getStreamInputNames(stream));
Map<String, String> pairs = getUnrollPairs(stream, currentStream);
for (int i = 0; i != inputNames.size(); ++i) {
if (pairs.containsKey(inputNames.get(i))) {
inputNames.set(i, pairs.get(inputNames.get(i)));
}
}
return inputNames;
}
public Collection<List<String>> getStreamInputDimensions(SerialCompositeElementSymbol stream) {
return getStreamInputs(stream).values();
}
......@@ -158,6 +172,20 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return outputNames;
}
// used for unroll
public List<String> getStreamOutputNames(SerialCompositeElementSymbol stream, SerialCompositeElementSymbol currentStream) {
List<String> outputNames = new LinkedList<>(getStreamOutputNames(stream));
Map<String, String> pairs = getUnrollPairs(stream, currentStream);
for (int i = 0; i != outputNames.size(); ++i) {
if (pairs.containsKey(outputNames.get(i))) {
outputNames.set(i, pairs.get(outputNames.get(i)));
}
}
return outputNames;
}
// Used to initialize all layer variable members which are passed through the networks
public Map<String, List<String>> getLayerVariableMembers(String batchSize) {
Map<String, List<String>> members = new LinkedHashMap<>();
......@@ -169,6 +197,34 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return members;
}
// Calculate differently named VariableSymbol elements in two streams, currently used for the UnrollInstructionSymbol
// body which is resolved with t = CONST_OFFSET and the current body of the actual timestep t
public Map<String, String> getUnrollPairs(ArchitectureElementSymbol element, ArchitectureElementSymbol current) {
Map<String, String> pairs = new HashMap<>();
if (element instanceof CompositeElementSymbol && current instanceof CompositeElementSymbol) {
List<ArchitectureElementSymbol> elements = ((CompositeElementSymbol) element).getElements();
List<ArchitectureElementSymbol> currentElements = ((CompositeElementSymbol) current).getElements();
if (elements.size() == currentElements.size()) {
for (int i = 0; i != currentElements.size(); ++i) {
String name = getName(elements.get(i));
String currentName = getName(currentElements.get(i));
if (elements.get(i) instanceof VariableSymbol && currentElements.get(i) instanceof VariableSymbol) {
if (name != null && currentName != null && !name.equals(currentName)) {
pairs.put(name, currentName);
}
}
pairs.putAll(getUnrollPairs(elements.get(i), currentElements.get(i)));
}
}
}
return pairs;
}
private Map<String, List<String>> getStreamInputs(SerialCompositeElementSymbol stream) {
Map<String, List<String>> inputs = new LinkedHashMap<>();
......
......@@ -6,9 +6,15 @@
</#list>
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.isUnroll()>
<#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody>
${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")})
</#list>
<#else>
<#if networkInstruction.body.isTrainable()>
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")})
<#else>
${tc.include(networkInstruction.body, "PYTHON_INLINE")}
</#if>
</#if>
</#list>
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment