From 94d3b4a86704e7ab1651f607ba173a304c49e97a Mon Sep 17 00:00:00 2001
From: Sebastian Nickels <sn1c@protonmail.ch>
Date: Fri, 16 Aug 2019 20:49:41 +0200
Subject: [PATCH] Fixed a bug which caused that RNNs could not be used without
 variable

---
 src/main/resources/templates/gluon/elements/GRU.ftl  | 8 ++++++--
 src/main/resources/templates/gluon/elements/LSTM.ftl | 8 ++++++--
 src/main/resources/templates/gluon/elements/RNN.ftl  | 8 ++++++--
 3 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/src/main/resources/templates/gluon/elements/GRU.ftl b/src/main/resources/templates/gluon/elements/GRU.ftl
index 37bbfc66..10179926 100644
--- a/src/main/resources/templates/gluon/elements/GRU.ftl
+++ b/src/main/resources/templates/gluon/elements/GRU.ftl
@@ -1,10 +1,14 @@
 <#if element.member == "NONE">
 <#assign input = element.inputs[0]>
 <#if mode == "ARCHITECTURE_DEFINITION">
-            self.rnn_${element.element.name} = gluon.rnn.GRU(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC')
+            self.${element.name} = gluon.rnn.GRU(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC')
             <#include "OutputShape.ftl">
 <#elseif mode == "FORWARD_FUNCTION">
-        ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_)
+<#if element.isVariable()>
+        ${element.name}, ${element.element.name}_state_ = self.${element.name}(${input}, ${element.element.name}_state_)
+<#else>
+        ${element.name} = self.${element.name}(${input})
+</#if>
 </#if>
 <#elseif element.member == "STATE">
 <#if element.inputs?size gte 1>
diff --git a/src/main/resources/templates/gluon/elements/LSTM.ftl b/src/main/resources/templates/gluon/elements/LSTM.ftl
index 34cc7411..c78f4551 100644
--- a/src/main/resources/templates/gluon/elements/LSTM.ftl
+++ b/src/main/resources/templates/gluon/elements/LSTM.ftl
@@ -1,10 +1,14 @@
 <#if element.member == "NONE">
 <#assign input = element.inputs[0]>
 <#if mode == "ARCHITECTURE_DEFINITION">
-            self.rnn_${element.element.name} = gluon.rnn.LSTM(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC')
+            self.${element.name} = gluon.rnn.LSTM(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC')
             <#include "OutputShape.ftl">
 <#elseif mode == "FORWARD_FUNCTION">
-        ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_)
+<#if element.isVariable()>
+        ${element.name}, ${element.element.name}_state_ = self.${element.name}(${input}, ${element.element.name}_state_)
+<#else>
+        ${element.name} = self.${element.name}(${input})
+</#if>
 </#if>
 <#elseif element.member == "STATE">
 <#if element.inputs?size gte 1>
diff --git a/src/main/resources/templates/gluon/elements/RNN.ftl b/src/main/resources/templates/gluon/elements/RNN.ftl
index 2fc1904f..1fbd75bc 100644
--- a/src/main/resources/templates/gluon/elements/RNN.ftl
+++ b/src/main/resources/templates/gluon/elements/RNN.ftl
@@ -1,10 +1,14 @@
 <#if element.member == "NONE">
 <#assign input = element.inputs[0]>
 <#if mode == "ARCHITECTURE_DEFINITION">
-            self.rnn_${element.element.name} = gluon.rnn.RNN(hidden_size=${element.units?c}, num_layers=${element.layers?c}, activation='tanh', layout='NTC')
+            self.${element.name} = gluon.rnn.RNN(hidden_size=${element.units?c}, num_layers=${element.layers?c}, activation='tanh', layout='NTC')
             <#include "OutputShape.ftl">
 <#elseif mode == "FORWARD_FUNCTION">
-        ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_)
+<#if element.isVariable()>
+        ${element.name}, ${element.element.name}_state_ = self.${element.name}(${input}, ${element.element.name}_state_)
+<#else>
+        ${element.name} = self.${element.name}(${input})
+</#if>
 </#if>
 <#elseif element.member == "STATE">
 <#if element.inputs?size gte 1>
-- 
GitLab