diff --git a/src/main/resources/templates/gluon/elements/GRU.ftl b/src/main/resources/templates/gluon/elements/GRU.ftl index 37bbfc665c7373a4ac79b419c4291f7f6a76c0a6..101799266dfe0914907e48e3b1bae8cc577db078 100644 --- a/src/main/resources/templates/gluon/elements/GRU.ftl +++ b/src/main/resources/templates/gluon/elements/GRU.ftl @@ -1,10 +1,14 @@ <#if element.member == "NONE"> <#assign input = element.inputs[0]> <#if mode == "ARCHITECTURE_DEFINITION"> - self.rnn_${element.element.name} = gluon.rnn.GRU(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC') + self.${element.name} = gluon.rnn.GRU(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC') <#include "OutputShape.ftl"> <#elseif mode == "FORWARD_FUNCTION"> - ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_) +<#if element.isVariable()> + ${element.name}, ${element.element.name}_state_ = self.${element.name}(${input}, ${element.element.name}_state_) +<#else> + ${element.name} = self.${element.name}(${input}) +</#if> </#if> <#elseif element.member == "STATE"> <#if element.inputs?size gte 1> diff --git a/src/main/resources/templates/gluon/elements/LSTM.ftl b/src/main/resources/templates/gluon/elements/LSTM.ftl index 34cc7411a83cbe8d960dfa2873deadaed73dfa6e..c78f4551fa95a7e0d4f5fd1f588324baf460406b 100644 --- a/src/main/resources/templates/gluon/elements/LSTM.ftl +++ b/src/main/resources/templates/gluon/elements/LSTM.ftl @@ -1,10 +1,14 @@ <#if element.member == "NONE"> <#assign input = element.inputs[0]> <#if mode == "ARCHITECTURE_DEFINITION"> - self.rnn_${element.element.name} = gluon.rnn.LSTM(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC') + self.${element.name} = gluon.rnn.LSTM(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC') <#include "OutputShape.ftl"> <#elseif mode == "FORWARD_FUNCTION"> - ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_) +<#if element.isVariable()> + ${element.name}, ${element.element.name}_state_ = self.${element.name}(${input}, ${element.element.name}_state_) +<#else> + ${element.name} = self.${element.name}(${input}) +</#if> </#if> <#elseif element.member == "STATE"> <#if element.inputs?size gte 1> diff --git a/src/main/resources/templates/gluon/elements/RNN.ftl b/src/main/resources/templates/gluon/elements/RNN.ftl index 2fc1904f2e53650322174adad182ea5c9dd5877b..1fbd75bcdfe13d0793cff72ce272537f17a1774d 100644 --- a/src/main/resources/templates/gluon/elements/RNN.ftl +++ b/src/main/resources/templates/gluon/elements/RNN.ftl @@ -1,10 +1,14 @@ <#if element.member == "NONE"> <#assign input = element.inputs[0]> <#if mode == "ARCHITECTURE_DEFINITION"> - self.rnn_${element.element.name} = gluon.rnn.RNN(hidden_size=${element.units?c}, num_layers=${element.layers?c}, activation='tanh', layout='NTC') + self.${element.name} = gluon.rnn.RNN(hidden_size=${element.units?c}, num_layers=${element.layers?c}, activation='tanh', layout='NTC') <#include "OutputShape.ftl"> <#elseif mode == "FORWARD_FUNCTION"> - ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_) +<#if element.isVariable()> + ${element.name}, ${element.element.name}_state_ = self.${element.name}(${input}, ${element.element.name}_state_) +<#else> + ${element.name} = self.${element.name}(${input}) +</#if> </#if> <#elseif element.member == "STATE"> <#if element.inputs?size gte 1>