From a221923ca101817dceee0a9a4be6821a296bb1da Mon Sep 17 00:00:00 2001 From: Sebastian Nickels <sn1c@protonmail.ch> Date: Fri, 16 Aug 2019 01:50:17 +0200 Subject: [PATCH] Added LSTM and GRU layers --- pom.xml | 4 ++-- .../CNNArch2GluonLayerSupportChecker.java | 2 ++ .../templates/gluon/elements/GRU.ftl | 20 +++++++++++++++++++ .../templates/gluon/elements/LSTM.ftl | 20 +++++++++++++++++++ .../templates/gluon/elements/RNN.ftl | 2 +- 5 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 src/main/resources/templates/gluon/elements/GRU.ftl create mode 100644 src/main/resources/templates/gluon/elements/LSTM.ftl diff --git a/pom.xml b/pom.xml index e814f9df..7436e70a 100644 --- a/pom.xml +++ b/pom.xml @@ -8,14 +8,14 @@ <groupId>de.monticore.lang.monticar</groupId> <artifactId>cnnarch-gluon-generator</artifactId> - <version>0.2.7-SNAPSHOT</version> + <version>0.2.8-SNAPSHOT</version> <!-- == PROJECT DEPENDENCIES ============================================= --> <properties> <!-- .. SE-Libraries .................................................. --> - <CNNArch.version>0.3.2-SNAPSHOT</CNNArch.version> + <CNNArch.version>0.3.3-SNAPSHOT</CNNArch.version> <CNNTrain.version>0.3.6-SNAPSHOT</CNNTrain.version> <CNNArch2X.version>0.0.3-SNAPSHOT</CNNArch2X.version> <embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator> diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonLayerSupportChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonLayerSupportChecker.java index 0f73871c..e657b8a1 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonLayerSupportChecker.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonLayerSupportChecker.java @@ -24,6 +24,8 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker { supportedLayerList.add(AllPredefinedLayers.FLATTEN_NAME); supportedLayerList.add(AllPredefinedLayers.ONE_HOT_NAME); supportedLayerList.add(AllPredefinedLayers.RNN_NAME); + supportedLayerList.add(AllPredefinedLayers.LSTM_NAME); + supportedLayerList.add(AllPredefinedLayers.GRU_NAME); } } diff --git a/src/main/resources/templates/gluon/elements/GRU.ftl b/src/main/resources/templates/gluon/elements/GRU.ftl new file mode 100644 index 00000000..37bbfc66 --- /dev/null +++ b/src/main/resources/templates/gluon/elements/GRU.ftl @@ -0,0 +1,20 @@ +<#if element.member == "NONE"> +<#assign input = element.inputs[0]> +<#if mode == "ARCHITECTURE_DEFINITION"> + self.rnn_${element.element.name} = gluon.rnn.GRU(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC') + <#include "OutputShape.ftl"> +<#elseif mode == "FORWARD_FUNCTION"> + ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_) +</#if> +<#elseif element.member == "STATE"> +<#if element.inputs?size gte 1> +<#assign input = element.inputs[0]> +<#if mode == "FORWARD_FUNCTION"> + ${element.name} = ${input} +<#elseif mode == "PYTHON_INLINE"> + ${element.name} = ${input} +<#elseif mode == "CPP_INLINE"> + ${element.name} = ${input} +</#if> +</#if> +</#if> \ No newline at end of file diff --git a/src/main/resources/templates/gluon/elements/LSTM.ftl b/src/main/resources/templates/gluon/elements/LSTM.ftl new file mode 100644 index 00000000..34cc7411 --- /dev/null +++ b/src/main/resources/templates/gluon/elements/LSTM.ftl @@ -0,0 +1,20 @@ +<#if element.member == "NONE"> +<#assign input = element.inputs[0]> +<#if mode == "ARCHITECTURE_DEFINITION"> + self.rnn_${element.element.name} = gluon.rnn.LSTM(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC') + <#include "OutputShape.ftl"> +<#elseif mode == "FORWARD_FUNCTION"> + ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_) +</#if> +<#elseif element.member == "STATE"> +<#if element.inputs?size gte 1> +<#assign input = element.inputs[0]> +<#if mode == "FORWARD_FUNCTION"> + ${element.name} = ${input} +<#elseif mode == "PYTHON_INLINE"> + ${element.name} = ${input} +<#elseif mode == "CPP_INLINE"> + ${element.name} = ${input} +</#if> +</#if> +</#if> \ No newline at end of file diff --git a/src/main/resources/templates/gluon/elements/RNN.ftl b/src/main/resources/templates/gluon/elements/RNN.ftl index 1125ff11..2fc1904f 100644 --- a/src/main/resources/templates/gluon/elements/RNN.ftl +++ b/src/main/resources/templates/gluon/elements/RNN.ftl @@ -1,7 +1,7 @@ <#if element.member == "NONE"> <#assign input = element.inputs[0]> <#if mode == "ARCHITECTURE_DEFINITION"> - self.rnn_${element.element.name} = gluon.rnn.RNN(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC') + self.rnn_${element.element.name} = gluon.rnn.RNN(hidden_size=${element.units?c}, num_layers=${element.layers?c}, activation='tanh', layout='NTC') <#include "OutputShape.ftl"> <#elseif mode == "FORWARD_FUNCTION"> ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_) -- GitLab