Skip to content
Snippets Groups Projects
Commit 82c6d14f authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Implemented parameter passing for unroll

parent 9f39abf2
No related branches found
No related tags found
1 merge request!23Added Unroll-related features and layers
Pipeline #180626 failed
......@@ -140,6 +140,20 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return getStreamInputs(stream).keySet();
}
// used for unroll
public List<String> getStreamInputNames(SerialCompositeElementSymbol stream, SerialCompositeElementSymbol currentStream) {
List<String> inputNames = new LinkedList<>(getStreamInputNames(stream));
Map<String, String> pairs = getUnrollPairs(stream, currentStream);
for (int i = 0; i != inputNames.size(); ++i) {
if (pairs.containsKey(inputNames.get(i))) {
inputNames.set(i, pairs.get(inputNames.get(i)));
}
}
return inputNames;
}
public Collection<List<String>> getStreamInputDimensions(SerialCompositeElementSymbol stream) {
return getStreamInputs(stream).values();
}
......@@ -158,6 +172,20 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return outputNames;
}
// used for unroll
public List<String> getStreamOutputNames(SerialCompositeElementSymbol stream, SerialCompositeElementSymbol currentStream) {
List<String> outputNames = new LinkedList<>(getStreamOutputNames(stream));
Map<String, String> pairs = getUnrollPairs(stream, currentStream);
for (int i = 0; i != outputNames.size(); ++i) {
if (pairs.containsKey(outputNames.get(i))) {
outputNames.set(i, pairs.get(outputNames.get(i)));
}
}
return outputNames;
}
// Used to initialize all layer variable members which are passed through the networks
public Map<String, List<String>> getLayerVariableMembers(String batchSize) {
Map<String, List<String>> members = new LinkedHashMap<>();
......@@ -169,6 +197,34 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
return members;
}
// Calculate differently named VariableSymbol elements in two streams, currently used for the UnrollInstructionSymbol
// body which is resolved with t = CONST_OFFSET and the current body of the actual timestep t
public Map<String, String> getUnrollPairs(ArchitectureElementSymbol element, ArchitectureElementSymbol current) {
Map<String, String> pairs = new HashMap<>();
if (element instanceof CompositeElementSymbol && current instanceof CompositeElementSymbol) {
List<ArchitectureElementSymbol> elements = ((CompositeElementSymbol) element).getElements();
List<ArchitectureElementSymbol> currentElements = ((CompositeElementSymbol) current).getElements();
if (elements.size() == currentElements.size()) {
for (int i = 0; i != currentElements.size(); ++i) {
String name = getName(elements.get(i));
String currentName = getName(currentElements.get(i));
if (elements.get(i) instanceof VariableSymbol && currentElements.get(i) instanceof VariableSymbol) {
if (name != null && currentName != null && !name.equals(currentName)) {
pairs.put(name, currentName);
}
}
pairs.putAll(getUnrollPairs(elements.get(i), currentElements.get(i)));
}
}
}
return pairs;
}
private Map<String, List<String>> getStreamInputs(SerialCompositeElementSymbol stream) {
Map<String, List<String>> inputs = new LinkedHashMap<>();
......
......@@ -6,9 +6,15 @@
</#list>
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.isUnroll()>
<#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody>
${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")})
</#list>
<#else>
<#if networkInstruction.body.isTrainable()>
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")})
<#else>
${tc.include(networkInstruction.body, "PYTHON_INLINE")}
</#if>
</#if>
</#list>
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment