From a221923ca101817dceee0a9a4be6821a296bb1da Mon Sep 17 00:00:00 2001
From: Sebastian Nickels <sn1c@protonmail.ch>
Date: Fri, 16 Aug 2019 01:50:17 +0200
Subject: [PATCH] Added LSTM and GRU layers

---
 pom.xml                                       |  4 ++--
 .../CNNArch2GluonLayerSupportChecker.java     |  2 ++
 .../templates/gluon/elements/GRU.ftl          | 20 +++++++++++++++++++
 .../templates/gluon/elements/LSTM.ftl         | 20 +++++++++++++++++++
 .../templates/gluon/elements/RNN.ftl          |  2 +-
 5 files changed, 45 insertions(+), 3 deletions(-)
 create mode 100644 src/main/resources/templates/gluon/elements/GRU.ftl
 create mode 100644 src/main/resources/templates/gluon/elements/LSTM.ftl

diff --git a/pom.xml b/pom.xml
index e814f9df..7436e70a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -8,14 +8,14 @@
 
   <groupId>de.monticore.lang.monticar</groupId>
   <artifactId>cnnarch-gluon-generator</artifactId>
-  <version>0.2.7-SNAPSHOT</version>
+  <version>0.2.8-SNAPSHOT</version>
 
   <!-- == PROJECT DEPENDENCIES ============================================= -->
 
   <properties>
 
     <!-- .. SE-Libraries .................................................. -->
-      <CNNArch.version>0.3.2-SNAPSHOT</CNNArch.version>
+      <CNNArch.version>0.3.3-SNAPSHOT</CNNArch.version>
       <CNNTrain.version>0.3.6-SNAPSHOT</CNNTrain.version>
       <CNNArch2X.version>0.0.3-SNAPSHOT</CNNArch2X.version>
       <embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonLayerSupportChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonLayerSupportChecker.java
index 0f73871c..e657b8a1 100644
--- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonLayerSupportChecker.java
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonLayerSupportChecker.java
@@ -24,6 +24,8 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
         supportedLayerList.add(AllPredefinedLayers.FLATTEN_NAME);
         supportedLayerList.add(AllPredefinedLayers.ONE_HOT_NAME);
         supportedLayerList.add(AllPredefinedLayers.RNN_NAME);
+        supportedLayerList.add(AllPredefinedLayers.LSTM_NAME);
+        supportedLayerList.add(AllPredefinedLayers.GRU_NAME);
     }
 
 }
diff --git a/src/main/resources/templates/gluon/elements/GRU.ftl b/src/main/resources/templates/gluon/elements/GRU.ftl
new file mode 100644
index 00000000..37bbfc66
--- /dev/null
+++ b/src/main/resources/templates/gluon/elements/GRU.ftl
@@ -0,0 +1,20 @@
+<#if element.member == "NONE">
+<#assign input = element.inputs[0]>
+<#if mode == "ARCHITECTURE_DEFINITION">
+            self.rnn_${element.element.name} = gluon.rnn.GRU(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC')
+            <#include "OutputShape.ftl">
+<#elseif mode == "FORWARD_FUNCTION">
+        ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_)
+</#if>
+<#elseif element.member == "STATE">
+<#if element.inputs?size gte 1>
+<#assign input = element.inputs[0]>
+<#if mode == "FORWARD_FUNCTION">
+        ${element.name} = ${input}
+<#elseif mode == "PYTHON_INLINE">
+                    ${element.name} = ${input}
+<#elseif mode == "CPP_INLINE">
+    ${element.name} = ${input}
+</#if>
+</#if>
+</#if>
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/elements/LSTM.ftl b/src/main/resources/templates/gluon/elements/LSTM.ftl
new file mode 100644
index 00000000..34cc7411
--- /dev/null
+++ b/src/main/resources/templates/gluon/elements/LSTM.ftl
@@ -0,0 +1,20 @@
+<#if element.member == "NONE">
+<#assign input = element.inputs[0]>
+<#if mode == "ARCHITECTURE_DEFINITION">
+            self.rnn_${element.element.name} = gluon.rnn.LSTM(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC')
+            <#include "OutputShape.ftl">
+<#elseif mode == "FORWARD_FUNCTION">
+        ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_)
+</#if>
+<#elseif element.member == "STATE">
+<#if element.inputs?size gte 1>
+<#assign input = element.inputs[0]>
+<#if mode == "FORWARD_FUNCTION">
+        ${element.name} = ${input}
+<#elseif mode == "PYTHON_INLINE">
+                    ${element.name} = ${input}
+<#elseif mode == "CPP_INLINE">
+    ${element.name} = ${input}
+</#if>
+</#if>
+</#if>
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/elements/RNN.ftl b/src/main/resources/templates/gluon/elements/RNN.ftl
index 1125ff11..2fc1904f 100644
--- a/src/main/resources/templates/gluon/elements/RNN.ftl
+++ b/src/main/resources/templates/gluon/elements/RNN.ftl
@@ -1,7 +1,7 @@
 <#if element.member == "NONE">
 <#assign input = element.inputs[0]>
 <#if mode == "ARCHITECTURE_DEFINITION">
-            self.rnn_${element.element.name} = gluon.rnn.RNN(hidden_size=${element.units?c}, num_layers=${element.layers?c}, layout='NTC')
+            self.rnn_${element.element.name} = gluon.rnn.RNN(hidden_size=${element.units?c}, num_layers=${element.layers?c}, activation='tanh', layout='NTC')
             <#include "OutputShape.ftl">
 <#elseif mode == "FORWARD_FUNCTION">
         ${element.name}, ${element.element.name}_state_ = self.rnn_${element.element.name}(${input}, ${element.element.name}_state_)
-- 
GitLab