diff --git a/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl b/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl index c2bb59b93b6ae17ae610646738435efac408810e..579ac61b8596f5e15f3851329cd5bff219964a72 100644 --- a/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl +++ b/src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl @@ -185,13 +185,19 @@ class ${tc.fileNameWithoutEnding}: if True: <#-- Fix indentation --> <#include "pythonExecute.ftl"> - predictions = [ -<#list tc.architectureOutputs as output_name> - mx.nd.argmax(${output_name}, axis=1)<#sep>, -</#list> -] + out_names=[] + <#list tc.architectureOutputs as output_name> + out_names.append(${output_name}) + </#list> + predictions = [] + for output_name in out_names: + if mx.nd.shape_array(output_name).size > 1: + predictions.append(mx.nd.argmax(output_name, axis=1)) + #ArgMax already applied + else: + predictions.append(output_name) - <#include "elements/BeamSearchStart.ftl"> + <#include "elements/BeamSearch.ftl"> metric.update(preds=predictions, labels=labels) train_metric_score = metric.get()[1] @@ -213,12 +219,17 @@ class ${tc.fileNameWithoutEnding}: if True: <#-- Fix indentation --> <#include "pythonExecute.ftl"> - predictions = [ -<#list tc.architectureOutputs as output_name> - mx.nd.argmax(${output_name}, axis=1)<#sep>, -</#list> - - ] + out_names=[] + <#list tc.architectureOutputs as output_name> + out_names.append(${output_name}) + </#list> + predictions = [] + for output_name in out_names: + if mx.nd.shape_array(output_name).size > 1: + predictions.append(mx.nd.argmax(output_name, axis=1)) + #ArgMax already applied + else: + predictions.append(output_name) metric.update(preds=predictions, labels=labels) test_metric_score = metric.get()[1] diff --git a/src/main/resources/templates/gluon/elements/ArgMax.ftl b/src/main/resources/templates/gluon/elements/ArgMax.ftl index 44c6e7ab11b36e7b242a0e9db689594a6963f1e1..b686873590e9f4400b5a1551f8373070afc6a375 100644 --- a/src/main/resources/templates/gluon/elements/ArgMax.ftl +++ b/src/main/resources/templates/gluon/elements/ArgMax.ftl @@ -1,4 +1,5 @@ <#assign input = element.inputs[0]> <#if mode == "FORWARD_FUNCTION"> - ${element.name} = F.ndarray.argmax(${input}, keepdims=True) + <#-- only passtrough method, argmax logic is applied in pythonExecute.ftl and CNNSupervisedTrainer.ftl --> + ${element.name} = ${input} </#if> \ No newline at end of file diff --git a/src/main/resources/templates/gluon/pythonExecute.ftl b/src/main/resources/templates/gluon/pythonExecute.ftl index 40f86d6570e11c7ecd83473f939de7d002600615..ba627a11a167c0c9712f92cad9b8f535410be176 100644 --- a/src/main/resources/templates/gluon/pythonExecute.ftl +++ b/src/main/resources/templates/gluon/pythonExecute.ftl @@ -9,10 +9,20 @@ <#if networkInstruction.isUnroll()> <#list networkInstruction.toUnrollInstruction().resolvedBodies as resolvedBody> ${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, resolvedBody), ", ")}) + <#list resolvedBody.elements as element> + <#if element.name == "ArgMax"> + ${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")} = mx.nd.argmax(${tc.join(tc.getStreamOutputNames(networkInstruction.body, resolvedBody), ", ")}, axis=1) + </#if> + </#list> </#list> <#else> <#if networkInstruction.body.isTrainable()> ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")}) + <#list networkInstruction.body.elements as element> + <#if element.name == "ArgMax"> + ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = mx.nd.argmax(${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")}, axis=1) + </#if> + </#list> <#else> ${tc.include(networkInstruction.body, "PYTHON_INLINE")} </#if> diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py b/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py index 795f867ddf101c9be48848cacc112de73b0255f9..e3682957ecb5aef2d050f5fbbbdae2efa1055ade 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py @@ -140,7 +140,6 @@ class CNNSupervisedTrainer_Alexnet: predictions_ = self._networks[0](data_) - loss = \ loss_function(predictions_, predictions_label) @@ -178,9 +177,15 @@ class CNNSupervisedTrainer_Alexnet: predictions_ = self._networks[0](data_) - - predictions = [ - mx.nd.argmax(predictions_, axis=1)] + out_names=[] + out_names.append(predictions_) + predictions = [] + for output_name in out_names: + if mx.nd.shape_array(output_name).size > 1: + predictions.append(mx.nd.argmax(output_name, axis=1)) + #ArgMax already applied + else: + predictions.append(output_name) metric.update(preds=predictions, labels=labels) @@ -200,10 +205,15 @@ class CNNSupervisedTrainer_Alexnet: predictions_ = self._networks[0](data_) - - predictions = [ - mx.nd.argmax(predictions_, axis=1) - ] + out_names=[] + out_names.append(predictions_) + predictions = [] + for output_name in out_names: + if mx.nd.shape_array(output_name).size > 1: + predictions.append(mx.nd.argmax(output_name, axis=1)) + #ArgMax already applied + else: + predictions.append(output_name) metric.update(preds=predictions, labels=labels) test_metric_score = metric.get()[1] diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py b/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py index a34bcc250f55fd3aab285c66c5db7a55deb47ee9..7cf9b09a37e28b4cf334f7083098fa2f6c12a038 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py @@ -140,7 +140,6 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: softmax_ = self._networks[0](data_) - loss = \ loss_function(softmax_, softmax_label) @@ -178,9 +177,15 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: softmax_ = self._networks[0](data_) - - predictions = [ - mx.nd.argmax(softmax_, axis=1)] + out_names=[] + out_names.append(softmax_) + predictions = [] + for output_name in out_names: + if mx.nd.shape_array(output_name).size > 1: + predictions.append(mx.nd.argmax(output_name, axis=1)) + #ArgMax already applied + else: + predictions.append(output_name) metric.update(preds=predictions, labels=labels) @@ -200,10 +205,15 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: softmax_ = self._networks[0](data_) - - predictions = [ - mx.nd.argmax(softmax_, axis=1) - ] + out_names=[] + out_names.append(softmax_) + predictions = [] + for output_name in out_names: + if mx.nd.shape_array(output_name).size > 1: + predictions.append(mx.nd.argmax(output_name, axis=1)) + #ArgMax already applied + else: + predictions.append(output_name) metric.update(preds=predictions, labels=labels) test_metric_score = metric.get()[1] diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py b/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py index 290db22b6432beabc785a22cf99c3a33b906916f..49090160070411ab904280b9a69c8392beab4d0b 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py @@ -140,7 +140,6 @@ class CNNSupervisedTrainer_VGG16: predictions_ = self._networks[0](data_) - loss = \ loss_function(predictions_, predictions_label) @@ -178,9 +177,15 @@ class CNNSupervisedTrainer_VGG16: predictions_ = self._networks[0](data_) - - predictions = [ - mx.nd.argmax(predictions_, axis=1)] + out_names=[] + out_names.append(predictions_) + predictions = [] + for output_name in out_names: + if mx.nd.shape_array(output_name).size > 1: + predictions.append(mx.nd.argmax(output_name, axis=1)) + #ArgMax already applied + else: + predictions.append(output_name) metric.update(preds=predictions, labels=labels) @@ -200,10 +205,15 @@ class CNNSupervisedTrainer_VGG16: predictions_ = self._networks[0](data_) - - predictions = [ - mx.nd.argmax(predictions_, axis=1) - ] + out_names=[] + out_names.append(predictions_) + predictions = [] + for output_name in out_names: + if mx.nd.shape_array(output_name).size > 1: + predictions.append(mx.nd.argmax(output_name, axis=1)) + #ArgMax already applied + else: + predictions.append(output_name) metric.update(preds=predictions, labels=labels) test_metric_score = metric.get()[1]