From 94d3b4a86704e7ab1651f607ba173a304c49e97a Mon Sep 17 00:00:00 2001 From: Sebastian Nickels <sn1c@protonmail.ch> Date: Fri, 16 Aug 2019 20:49:41 +0200 Subject: [PATCH] Fixed a bug which caused that RNNs could not be used without variable --- src/main/resources/templates/gluon/elements/GRU.ftl | 8 ++++++-- src/main/resources/templates/gluon/elements/LSTM.ftl | 8 ++++++-- src/main/resources/templates/gluon/elements/RNN.ftl | 8 ++++++-- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/main/resources/templates/gluon/elements/GRU.ftl b/src/main/resources/templates/gluon/elements/GRU.ftl index 37bbfc66..10179926 100644 --- a/src/main/resources/templates/gluon/elements/GRU.ftl +++ b/src/main/resources/templates/gluon/elements/GRU.ftl @@ -1,10 +1,14 @@ <#if element.member == "NONE"> <#assign input = element.inputs[0]> <#if mode == "ARCHITECTURE_DEFINITION"> - self.rnn_${element.element.name} = gluon.rnn.GRU(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC') + self.${element.name} = gluon.rnn.GRU(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC') <#include "OutputShape.ftl"> <#elseif mode == "FORWARD_FUNCTION"> - ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_) +<#if element.isVariable()> + ${element.name}, ${element.element.name}_state_ = self.${element.name}(${input}, ${element.element.name}_state_) +<#else> + ${element.name} = self.${element.name}(${input}) +</#if> </#if> <#elseif element.member == "STATE"> <#if element.inputs?size gte 1> diff --git a/src/main/resources/templates/gluon/elements/LSTM.ftl b/src/main/resources/templates/gluon/elements/LSTM.ftl index 34cc7411..c78f4551 100644 --- a/src/main/resources/templates/gluon/elements/LSTM.ftl +++ b/src/main/resources/templates/gluon/elements/LSTM.ftl @@ -1,10 +1,14 @@ <#if element.member == "NONE"> <#assign input = element.inputs[0]> <#if mode == "ARCHITECTURE_DEFINITION"> - self.rnn_${element.element.name} = gluon.rnn.LSTM(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC') + self.${element.name} = gluon.rnn.LSTM(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC') <#include "OutputShape.ftl"> <#elseif mode == "FORWARD_FUNCTION"> - ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_) +<#if element.isVariable()> + ${element.name}, ${element.element.name}_state_ = self.${element.name}(${input}, ${element.element.name}_state_) +<#else> + ${element.name} = self.${element.name}(${input}) +</#if> </#if> <#elseif element.member == "STATE"> <#if element.inputs?size gte 1> diff --git a/src/main/resources/templates/gluon/elements/RNN.ftl b/src/main/resources/templates/gluon/elements/RNN.ftl index 2fc1904f..1fbd75bc 100644 --- a/src/main/resources/templates/gluon/elements/RNN.ftl +++ b/src/main/resources/templates/gluon/elements/RNN.ftl @@ -1,10 +1,14 @@ <#if element.member == "NONE"> <#assign input = element.inputs[0]> <#if mode == "ARCHITECTURE_DEFINITION"> - self.rnn_${element.element.name} = gluon.rnn.RNN(hidden_size=${element.units?c}, num_layers=${element.layers?c}, activation='tanh', layout='NTC') + self.${element.name} = gluon.rnn.RNN(hidden_size=${element.units?c}, num_layers=${element.layers?c}, activation='tanh', layout='NTC') <#include "OutputShape.ftl"> <#elseif mode == "FORWARD_FUNCTION"> - ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_) +<#if element.isVariable()> + ${element.name}, ${element.element.name}_state_ = self.${element.name}(${input}, ${element.element.name}_state_) +<#else> + ${element.name} = self.${element.name}(${input}) +</#if> </#if> <#elseif element.member == "STATE"> <#if element.inputs?size gte 1> -- GitLab