Skip to content
Snippets Groups Projects
Commit 1ded4b88 authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Fixed bugs

parent 4be3befa
No related branches found
No related tags found
1 merge request!23Added Unroll-related features and layers
Pipeline #203307 failed
......@@ -183,8 +183,12 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
if (element.isOutput()) {
String name = getName(element);
if (asArray) {
name = getNameAsArray(name);
if (asArray && element instanceof VariableSymbol) {
VariableSymbol variable = (VariableSymbol) element;
if (variable.getType() == VariableSymbol.Type.IO) {
name = getNameAsArray(name);
}
}
outputNames.add(name);
......@@ -287,8 +291,12 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
String name = getName(element);
if (element.isOutput() && outputAsArray) {
name = getNameAsArray(name);
if (outputAsArray && element.isOutput() && element instanceof VariableSymbol) {
VariableSymbol variable = (VariableSymbol) element;
if (variable.getType() == VariableSymbol.Type.IO) {
name = getNameAsArray(name);
}
}
inputs.put(name, dimensions);
......
......@@ -42,10 +42,10 @@
</#if>
</#list>
<#if tc.isAttentionNetwork()>
${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")}, attention_ = self._networks[${networkInstruction?index}](${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")})
attentionList.append(attention_)
${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")}, attention_ = self._networks[${networkInstruction?index}](${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")})
attentionList.append(attention_)
<#else>
${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")})
${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")})
</#if>
<#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName>
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment