diff --git a/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl b/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
index c2bb59b93b6ae17ae610646738435efac408810e..579ac61b8596f5e15f3851329cd5bff219964a72 100644
--- a/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
+++ b/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
@@ -185,13 +185,19 @@ class ${tc.fileNameWithoutEnding}:
                 if True: <#-- Fix indentation -->
 <#include "pythonExecute.ftl">
 
-                predictions = [
-<#list tc.architectureOutputs as output_name>
-                    mx.nd.argmax(${output_name}, axis=1)<#sep>,
-</#list>
-]
+                out_names=[]
+                <#list tc.architectureOutputs as output_name>
+                out_names.append(${output_name})
+                </#list>
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
-                <#include "elements/BeamSearchStart.ftl">
+                <#include "elements/BeamSearch.ftl">
 
                 metric.update(preds=predictions, labels=labels)
             train_metric_score = metric.get()[1]
@@ -213,12 +219,17 @@ class ${tc.fileNameWithoutEnding}:
                 if True: <#-- Fix indentation -->
 <#include "pythonExecute.ftl">
 
-                predictions = [
-<#list tc.architectureOutputs as output_name>
-                    mx.nd.argmax(${output_name}, axis=1)<#sep>,
-</#list>
-
-                ]
+                out_names=[]
+                <#list tc.architectureOutputs as output_name>
+                out_names.append(${output_name})
+                </#list>
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
                 metric.update(preds=predictions, labels=labels)
             test_metric_score = metric.get()[1]
diff --git a/src/main/resources/templates/gluon/elements/ArgMax.ftl b/src/main/resources/templates/gluon/elements/ArgMax.ftl
index 44c6e7ab11b36e7b242a0e9db689594a6963f1e1..b686873590e9f4400b5a1551f8373070afc6a375 100644
--- a/src/main/resources/templates/gluon/elements/ArgMax.ftl
+++ b/src/main/resources/templates/gluon/elements/ArgMax.ftl
@@ -1,4 +1,5 @@
 <#assign input = element.inputs[0]>
 <#if mode == "FORWARD_FUNCTION">
-        ${element.name} = F.ndarray.argmax(${input}, keepdims=True)
+        <#-- only passtrough method, argmax logic is applied in pythonExecute.ftl and CNNSupervisedTrainer.ftl -->
+        ${element.name} = ${input}
 </#if>
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/pythonExecute.ftl b/src/main/resources/templates/gluon/pythonExecute.ftl
index 40f86d6570e11c7ecd83473f939de7d002600615..ba627a11a167c0c9712f92cad9b8f535410be176 100644
--- a/src/main/resources/templates/gluon/pythonExecute.ftl
+++ b/src/main/resources/templates/gluon/pythonExecute.ftl
@@ -9,10 +9,20 @@
 <#if networkInstruction.isUnroll()>
 <#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody>
                     ${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")})
+                    <#list resolvedBody.elements as element>
+                    <#if element.name == "ArgMax">
+                    ${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = mx.nd.argmax(${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")}, axis=1)
+                    </#if>
+                    </#list>
 </#list>
 <#else>
 <#if networkInstruction.body.isTrainable()>
                     ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")})
+                    <#list networkInstruction.body.elements as element>
+                    <#if element.name == "ArgMax">
+                    ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = mx.nd.argmax(${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")}, axis=1)
+                    </#if>
+                    </#list>
 <#else>
 ${tc.include(networkInstruction.body, "PYTHON_INLINE")}
 </#if>
diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py b/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
index 795f867ddf101c9be48848cacc112de73b0255f9..e3682957ecb5aef2d050f5fbbbdae2efa1055ade 100644
--- a/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
+++ b/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
@@ -140,7 +140,6 @@ class CNNSupervisedTrainer_Alexnet:
 
                     predictions_ = self._networks[0](data_)
 
-
                     loss = \
                         loss_function(predictions_, predictions_label)
 
@@ -178,9 +177,15 @@ class CNNSupervisedTrainer_Alexnet:
 
                     predictions_ = self._networks[0](data_)
 
-
-                predictions = [
-                    mx.nd.argmax(predictions_, axis=1)]
+                out_names=[]
+                out_names.append(predictions_)
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
 
                 metric.update(preds=predictions, labels=labels)
@@ -200,10 +205,15 @@ class CNNSupervisedTrainer_Alexnet:
 
                     predictions_ = self._networks[0](data_)
 
-
-                predictions = [
-                    mx.nd.argmax(predictions_, axis=1)
-                ]
+                out_names=[]
+                out_names.append(predictions_)
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
                 metric.update(preds=predictions, labels=labels)
             test_metric_score = metric.get()[1]
diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py b/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
index a34bcc250f55fd3aab285c66c5db7a55deb47ee9..7cf9b09a37e28b4cf334f7083098fa2f6c12a038 100644
--- a/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
+++ b/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
@@ -140,7 +140,6 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
 
                     softmax_ = self._networks[0](data_)
 
-
                     loss = \
                         loss_function(softmax_, softmax_label)
 
@@ -178,9 +177,15 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
 
                     softmax_ = self._networks[0](data_)
 
-
-                predictions = [
-                    mx.nd.argmax(softmax_, axis=1)]
+                out_names=[]
+                out_names.append(softmax_)
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
 
                 metric.update(preds=predictions, labels=labels)
@@ -200,10 +205,15 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
 
                     softmax_ = self._networks[0](data_)
 
-
-                predictions = [
-                    mx.nd.argmax(softmax_, axis=1)
-                ]
+                out_names=[]
+                out_names.append(softmax_)
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
                 metric.update(preds=predictions, labels=labels)
             test_metric_score = metric.get()[1]
diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py b/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
index 290db22b6432beabc785a22cf99c3a33b906916f..49090160070411ab904280b9a69c8392beab4d0b 100644
--- a/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
+++ b/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
@@ -140,7 +140,6 @@ class CNNSupervisedTrainer_VGG16:
 
                     predictions_ = self._networks[0](data_)
 
-
                     loss = \
                         loss_function(predictions_, predictions_label)
 
@@ -178,9 +177,15 @@ class CNNSupervisedTrainer_VGG16:
 
                     predictions_ = self._networks[0](data_)
 
-
-                predictions = [
-                    mx.nd.argmax(predictions_, axis=1)]
+                out_names=[]
+                out_names.append(predictions_)
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
 
                 metric.update(preds=predictions, labels=labels)
@@ -200,10 +205,15 @@ class CNNSupervisedTrainer_VGG16:
 
                     predictions_ = self._networks[0](data_)
 
-
-                predictions = [
-                    mx.nd.argmax(predictions_, axis=1)
-                ]
+                out_names=[]
+                out_names.append(predictions_)
+                predictions = []
+                for output_name in out_names:
+                    if mx.nd.shape_array(output_name).size > 1:
+                        predictions.append(mx.nd.argmax(output_name, axis=1))
+                    #ArgMax already applied
+                    else:
+                        predictions.append(output_name)
 
                 metric.update(preds=predictions, labels=labels)
             test_metric_score = metric.get()[1]