From a5e9d403640e65ac7b49b30443d105bb81f69291 Mon Sep 17 00:00:00 2001 From: Sebastian Nickels <sn1c@protonmail.ch> Date: Sat, 17 Aug 2019 21:51:49 +0200 Subject: [PATCH] Outputs now can be used as inputs --- .../CNNArch2GluonArchitectureSupportChecker.java | 5 +++++ .../gluongenerator/CNNArch2GluonTemplateController.java | 2 +- src/main/resources/templates/gluon/elements/Output.ftl | 2 ++ src/main/resources/templates/gluon/pythonExecute.ftl | 3 +++ .../resources/target_code/CNNSupervisedTrainer_Alexnet.py | 3 +++ .../CNNSupervisedTrainer_CifarClassifierNetwork.py | 3 +++ src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py | 3 +++ 7 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonArchitectureSupportChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonArchitectureSupportChecker.java index 484b61ca..901be31d 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonArchitectureSupportChecker.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonArchitectureSupportChecker.java @@ -37,4 +37,9 @@ public class CNNArch2GluonArchitectureSupportChecker extends ArchitectureSupport return true; } + @Override + protected boolean checkOutputAsInput(ArchitectureSymbol architecture) { + return true; + } + } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java index fa28ee43..864d4da7 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java @@ -173,7 +173,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { Map<String, List<String>> inputs = new LinkedHashMap<>(); for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) { - if (element.isInput()) { + if (element.isInput() || element.isOutput()) { List<Integer> intDimensions = element.getOutputTypes().get(0).getDimensions(); List<String> dimensions = new ArrayList<>(); diff --git a/src/main/resources/templates/gluon/elements/Output.ftl b/src/main/resources/templates/gluon/elements/Output.ftl index 4fd9f028..dc06d3b6 100644 --- a/src/main/resources/templates/gluon/elements/Output.ftl +++ b/src/main/resources/templates/gluon/elements/Output.ftl @@ -1,3 +1,4 @@ +<#if element.inputs?size gte 1> <#assign input = element.inputs[0]> <#if mode == "FORWARD_FUNCTION"> ${element.name} = ${input} @@ -6,3 +7,4 @@ <#elseif mode == "CPP_INLINE"> ${element.name} = ${input}; </#if> +</#if> \ No newline at end of file diff --git a/src/main/resources/templates/gluon/pythonExecute.ftl b/src/main/resources/templates/gluon/pythonExecute.ftl index 73d8f4dd..c0093eeb 100644 --- a/src/main/resources/templates/gluon/pythonExecute.ftl +++ b/src/main/resources/templates/gluon/pythonExecute.ftl @@ -1,6 +1,9 @@ <#list tc.getLayerVariableMembers("batch_size")?keys as member> ${member} = mx.nd.zeroes((${tc.join(tc.getLayerVariableMembers("batch_size")[member], ", ")},), ctx=mx_context) </#list> +<#list tc.architecture.outputs as output> + ${tc.getName(output)} = mx.nd.zeroes(((${tc.join(output.ioDeclaration.type.dimensions, ", ")},), ctx=mx_context) +</#list> <#list tc.architecture.streams as stream> <#if stream.isTrainable()> diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py b/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py index 538d579d..0d22884f 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py @@ -136,6 +136,7 @@ class CNNSupervisedTrainer_Alexnet: predictions_label = batch.label[0].as_in_context(mx_context) with autograd.record(): + predictions_ = mx.nd.zeroes(((10,), ctx=mx_context) predictions_ = self._networks[0](data_) @@ -172,6 +173,7 @@ class CNNSupervisedTrainer_Alexnet: ] if True: + predictions_ = mx.nd.zeroes(((10,), ctx=mx_context) predictions_ = self._networks[0](data_) @@ -192,6 +194,7 @@ class CNNSupervisedTrainer_Alexnet: ] if True: + predictions_ = mx.nd.zeroes(((10,), ctx=mx_context) predictions_ = self._networks[0](data_) diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py b/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py index 363b6830..2a8aee06 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py @@ -136,6 +136,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: softmax_label = batch.label[0].as_in_context(mx_context) with autograd.record(): + softmax_ = mx.nd.zeroes(((10,), ctx=mx_context) softmax_ = self._networks[0](data_) @@ -172,6 +173,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: ] if True: + softmax_ = mx.nd.zeroes(((10,), ctx=mx_context) softmax_ = self._networks[0](data_) @@ -192,6 +194,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: ] if True: + softmax_ = mx.nd.zeroes(((10,), ctx=mx_context) softmax_ = self._networks[0](data_) diff --git a/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py b/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py index 56606ef4..7651ff55 100644 --- a/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py +++ b/src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py @@ -136,6 +136,7 @@ class CNNSupervisedTrainer_VGG16: predictions_label = batch.label[0].as_in_context(mx_context) with autograd.record(): + predictions_ = mx.nd.zeroes(((1000,), ctx=mx_context) predictions_ = self._networks[0](data_) @@ -172,6 +173,7 @@ class CNNSupervisedTrainer_VGG16: ] if True: + predictions_ = mx.nd.zeroes(((1000,), ctx=mx_context) predictions_ = self._networks[0](data_) @@ -192,6 +194,7 @@ class CNNSupervisedTrainer_VGG16: ] if True: + predictions_ = mx.nd.zeroes(((1000,), ctx=mx_context) predictions_ = self._networks[0](data_) -- GitLab