From a5e9d403640e65ac7b49b30443d105bb81f69291 Mon Sep 17 00:00:00 2001
From: Sebastian Nickels <sn1c@protonmail.ch>
Date: Sat, 17 Aug 2019 21:51:49 +0200
Subject: [PATCH] Outputs now can be used as inputs

---
 .../CNNArch2GluonArchitectureSupportChecker.java             | 5 +++++
 .../gluongenerator/CNNArch2GluonTemplateController.java      | 2 +-
 src/main/resources/templates/gluon/elements/Output.ftl       | 2 ++
 src/main/resources/templates/gluon/pythonExecute.ftl         | 3 +++
 .../resources/target_code/CNNSupervisedTrainer_Alexnet.py    | 3 +++
 .../CNNSupervisedTrainer_CifarClassifierNetwork.py           | 3 +++
 src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py | 3 +++
 7 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonArchitectureSupportChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonArchitectureSupportChecker.java
index 484b61ca..901be31d 100644
--- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonArchitectureSupportChecker.java
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonArchitectureSupportChecker.java
@@ -37,4 +37,9 @@ public class CNNArch2GluonArchitectureSupportChecker extends ArchitectureSupport
         return true;
     }
 
+    @Override
+    protected boolean checkOutputAsInput(ArchitectureSymbol architecture) {
+        return true;
+    }
+
 }
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java
index fa28ee43..864d4da7 100644
--- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java
@@ -173,7 +173,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
         Map<String, List<String>> inputs = new LinkedHashMap<>();
 
         for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
-            if (element.isInput()) {
+            if (element.isInput() || element.isOutput()) {
                 List<Integer> intDimensions = element.getOutputTypes().get(0).getDimensions();
 
                 List<String> dimensions = new ArrayList<>();
diff --git a/src/main/resources/templates/gluon/elements/Output.ftl b/src/main/resources/templates/gluon/elements/Output.ftl
index 4fd9f028..dc06d3b6 100644
--- a/src/main/resources/templates/gluon/elements/Output.ftl
+++ b/src/main/resources/templates/gluon/elements/Output.ftl
@@ -1,3 +1,4 @@
+<#if element.inputs?size gte 1>
 <#assign input = element.inputs[0]>
 <#if mode == "FORWARD_FUNCTION">
         ${element.name} = ${input}
@@ -6,3 +7,4 @@
 <#elseif mode == "CPP_INLINE">
     ${element.name} = ${input};
 </#if>
+</#if>
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/pythonExecute.ftl b/src/main/resources/templates/gluon/pythonExecute.ftl
index 73d8f4dd..c0093eeb 100644
--- a/src/main/resources/templates/gluon/pythonExecute.ftl
+++ b/src/main/resources/templates/gluon/pythonExecute.ftl
@@ -1,6 +1,9 @@
 <#list tc.getLayerVariableMembers("batch_size")?keys as member>
                     ${member} = mx.nd.zeroes((${tc.join(tc.getLayerVariableMembers("batch_size")[member], ", ")},), ctx=mx_context)
 </#list>
+<#list tc.architecture.outputs as output>
+                    ${tc.getName(output)} = mx.nd.zeroes(((${tc.join(output.ioDeclaration.type.dimensions, ", ")},), ctx=mx_context)
+</#list>
 
 <#list tc.architecture.streams as stream>
 <#if stream.isTrainable()>
diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py b/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
index 538d579d..0d22884f 100644
--- a/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
+++ b/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_Alexnet:
                 predictions_label = batch.label[0].as_in_context(mx_context)
 
                 with autograd.record():
+                    predictions_ = mx.nd.zeroes(((10,), ctx=mx_context)
 
                     predictions_ = self._networks[0](data_)
 
@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_Alexnet:
                 ]
 
                 if True:
+                    predictions_ = mx.nd.zeroes(((10,), ctx=mx_context)
 
                     predictions_ = self._networks[0](data_)
 
@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_Alexnet:
                 ]
 
                 if True:
+                    predictions_ = mx.nd.zeroes(((10,), ctx=mx_context)
 
                     predictions_ = self._networks[0](data_)
 
diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py b/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
index 363b6830..2a8aee06 100644
--- a/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
+++ b/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
                 softmax_label = batch.label[0].as_in_context(mx_context)
 
                 with autograd.record():
+                    softmax_ = mx.nd.zeroes(((10,), ctx=mx_context)
 
                     softmax_ = self._networks[0](data_)
 
@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
                 ]
 
                 if True:
+                    softmax_ = mx.nd.zeroes(((10,), ctx=mx_context)
 
                     softmax_ = self._networks[0](data_)
 
@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
                 ]
 
                 if True:
+                    softmax_ = mx.nd.zeroes(((10,), ctx=mx_context)
 
                     softmax_ = self._networks[0](data_)
 
diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py b/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
index 56606ef4..7651ff55 100644
--- a/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
+++ b/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_VGG16:
                 predictions_label = batch.label[0].as_in_context(mx_context)
 
                 with autograd.record():
+                    predictions_ = mx.nd.zeroes(((1000,), ctx=mx_context)
 
                     predictions_ = self._networks[0](data_)
 
@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_VGG16:
                 ]
 
                 if True:
+                    predictions_ = mx.nd.zeroes(((1000,), ctx=mx_context)
 
                     predictions_ = self._networks[0](data_)
 
@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_VGG16:
                 ]
 
                 if True:
+                    predictions_ = mx.nd.zeroes(((1000,), ctx=mx_context)
 
                     predictions_ = self._networks[0](data_)
 
-- 
GitLab