From 01c16d6c76baded5f8eaea5d66d7f54bcb20af1d Mon Sep 17 00:00:00 2001
From: Christian Fuss <chrifuss@freenet.de>
Date: Mon, 9 Sep 2019 15:59:46 +0200
Subject: [PATCH] added functionality for ArgMax layer

---
 .../templates/gluon/CNNSupervisedTrainer.ftl  | 35 ++++++++++++-------
 .../templates/gluon/elements/ArgMax.ftl       |  3 +-
 .../templates/gluon/pythonExecute.ftl         | 10 ++++++
 .../CNNSupervisedTrainer_Alexnet.py           | 26 +++++++++-----
 ...upervisedTrainer_CifarClassifierNetwork.py | 26 +++++++++-----
 .../target_code/CNNSupervisedTrainer_VGG16.py | 26 +++++++++-----
 6 files changed, 89 insertions(+), 37 deletions(-)

diff --git a/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl b/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
index c2bb59b9..579ac61b 100644
--- a/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
+++ b/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
@@ -185,13 +185,19 @@ class ${tc.fileNameWithoutEnding}:
                 if True: <#-- Fix indentation -->
 <#include "pythonExecute.ftl">
 
-                predictions = [
-<#list tc.architectureOutputs as output_name>
-                    mx.nd.argmax(${output_name}, axis=1)<#sep>,
-</#list>
-]
+                out_names=[]
+                <#list tc.architectureOutputs as output_name>
+                out_names.append(${output_name})
+                </#list>
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
-                <#include "elements/BeamSearchStart.ftl">
+                <#include "elements/BeamSearch.ftl">
 
                 metric.update(preds=predictions, labels=labels)
             train_metric_score = metric.get()[1]
@@ -213,12 +219,17 @@ class ${tc.fileNameWithoutEnding}:
                 if True: <#-- Fix indentation -->
 <#include "pythonExecute.ftl">
 
-                predictions = [
-<#list tc.architectureOutputs as output_name>
-                    mx.nd.argmax(${output_name}, axis=1)<#sep>,
-</#list>
-
-                ]
+                out_names=[]
+                <#list tc.architectureOutputs as output_name>
+                out_names.append(${output_name})
+                </#list>
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
                 metric.update(preds=predictions, labels=labels)
             test_metric_score = metric.get()[1]
diff --git a/src/main/resources/templates/gluon/elements/ArgMax.ftl b/src/main/resources/templates/gluon/elements/ArgMax.ftl
index 44c6e7ab..b6868735 100644
--- a/src/main/resources/templates/gluon/elements/ArgMax.ftl
+++ b/src/main/resources/templates/gluon/elements/ArgMax.ftl
@@ -1,4 +1,5 @@
 <#assign input = element.inputs[0]>
 <#if mode == "FORWARD_FUNCTION">
-        ${element.name} = F.ndarray.argmax(${input}, keepdims=True)
+        <#-- pass-through only: the argmax itself is applied in pythonExecute.ftl and CNNSupervisedTrainer.ftl -->
+        ${element.name} = ${input}
 </#if>
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/pythonExecute.ftl b/src/main/resources/templates/gluon/pythonExecute.ftl
index 40f86d65..ba627a11 100644
--- a/src/main/resources/templates/gluon/pythonExecute.ftl
+++ b/src/main/resources/templates/gluon/pythonExecute.ftl
@@ -9,10 +9,20 @@
 <#if networkInstruction.isUnroll()>
 <#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody>
                     ${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")})
+                    <#list resolvedBody.elements as element>
+                    <#if element.name == "ArgMax">
+                    ${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = mx.nd.argmax(${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")}, axis=1)
+                    </#if>
+                    </#list>
 </#list>
 <#else>
 <#if networkInstruction.body.isTrainable()>
                     ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")})
+                    <#list networkInstruction.body.elements as element>
+                    <#if element.name == "ArgMax">
+                    ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = mx.nd.argmax(${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")}, axis=1)
+                    </#if>
+                    </#list>
 <#else>
 ${tc.include(networkInstruction.body, "PYTHON_INLINE")}
 </#if>
diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py b/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
index 795f867d..e3682957 100644
--- a/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
+++ b/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
@@ -140,7 +140,6 @@ class CNNSupervisedTrainer_Alexnet:
 
                     predictions_ = self._networks[0](data_)
 
-
                     loss = \
                         loss_function(predictions_, predictions_label)
 
@@ -178,9 +177,15 @@ class CNNSupervisedTrainer_Alexnet:
 
                     predictions_ = self._networks[0](data_)
 
-
-                predictions = [
-                    mx.nd.argmax(predictions_, axis=1)]
+                out_names=[]
+                out_names.append(predictions_)
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
 
                 metric.update(preds=predictions, labels=labels)
@@ -200,10 +205,15 @@ class CNNSupervisedTrainer_Alexnet:
 
                     predictions_ = self._networks[0](data_)
 
-
-                predictions = [
-                    mx.nd.argmax(predictions_, axis=1)
-                ]
+                out_names=[]
+                out_names.append(predictions_)
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
                 metric.update(preds=predictions, labels=labels)
             test_metric_score = metric.get()[1]
diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py b/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
index a34bcc25..7cf9b09a 100644
--- a/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
+++ b/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
@@ -140,7 +140,6 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
 
                     softmax_ = self._networks[0](data_)
 
-
                     loss = \
                         loss_function(softmax_, softmax_label)
 
@@ -178,9 +177,15 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
 
                     softmax_ = self._networks[0](data_)
 
-
-                predictions = [
-                    mx.nd.argmax(softmax_, axis=1)]
+                out_names=[]
+                out_names.append(softmax_)
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
 
                 metric.update(preds=predictions, labels=labels)
@@ -200,10 +205,15 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
 
                     softmax_ = self._networks[0](data_)
 
-
-                predictions = [
-                    mx.nd.argmax(softmax_, axis=1)
-                ]
+                out_names=[]
+                out_names.append(softmax_)
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
                 metric.update(preds=predictions, labels=labels)
             test_metric_score = metric.get()[1]
diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py b/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
index 290db22b..49090160 100644
--- a/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
+++ b/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
@@ -140,7 +140,6 @@ class CNNSupervisedTrainer_VGG16:
 
                     predictions_ = self._networks[0](data_)
 
-
                     loss = \
                         loss_function(predictions_, predictions_label)
 
@@ -178,9 +177,15 @@ class CNNSupervisedTrainer_VGG16:
 
                     predictions_ = self._networks[0](data_)
 
-
-                predictions = [
-                    mx.nd.argmax(predictions_, axis=1)]
+                out_names=[]
+                out_names.append(predictions_)
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
 
                 metric.update(preds=predictions, labels=labels)
@@ -200,10 +205,15 @@ class CNNSupervisedTrainer_VGG16:
 
                     predictions_ = self._networks[0](data_)
 
-
-                predictions = [
-                    mx.nd.argmax(predictions_, axis=1)
-                ]
+                out_names=[]
+                out_names.append(predictions_)
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
                 metric.update(preds=predictions, labels=labels)
             test_metric_score = metric.get()[1]
-- 
GitLab