From 74cd8aa96de84299e0ee30962ec46ace9e36bd39 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Wed, 24 Apr 2019 18:16:47 +0200 Subject: [PATCH] Fix gluon generation test --- .../gluon/CNNCreator_mnist_mnistClassifier_net.py | 5 ++++- .../CNNPredictor_mnist_mnistClassifier_net.h | 15 +++++---------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/test/resources/target_code/gluon/CNNCreator_mnist_mnistClassifier_net.py b/src/test/resources/target_code/gluon/CNNCreator_mnist_mnistClassifier_net.py index 8a699b4..b60dc93 100644 --- a/src/test/resources/target_code/gluon/CNNCreator_mnist_mnistClassifier_net.py +++ b/src/test/resources/target_code/gluon/CNNCreator_mnist_mnistClassifier_net.py @@ -6,12 +6,15 @@ from CNNNet_mnist_mnistClassifier_net import Net class CNNCreator_mnist_mnistClassifier_net: _model_dir_ = "model/mnist.LeNetNetwork/" _model_prefix_ = "model" - _input_shapes_ = [(1,28,28)] + _input_shapes_ = [(1,28,28,)] def __init__(self): self.weight_initializer = mx.init.Normal() self.net = None + def get_input_shapes(self): + return self._input_shapes_ + def load(self, context): lastEpoch = 0 param_file = None diff --git a/src/test/resources/target_code/gluon/CNNPredictor_mnist_mnistClassifier_net.h b/src/test/resources/target_code/gluon/CNNPredictor_mnist_mnistClassifier_net.h index 5823ff3..7e61699 100644 --- a/src/test/resources/target_code/gluon/CNNPredictor_mnist_mnistClassifier_net.h +++ b/src/test/resources/target_code/gluon/CNNPredictor_mnist_mnistClassifier_net.h @@ -30,8 +30,7 @@ public: void predict(const std::vector &image, std::vector &predictions){ - MXPredSetInput(handle, "data", image.data(), image.size()); - //MXPredSetInput(handle, "image", image.data(), image.size()); + MXPredSetInput(handle, "data", image.data(), static_cast(image.size())); MXPredForward(handle); @@ -61,8 +60,6 @@ public: int dev_type = use_gpu ? 2 : 1; int dev_id = 0; - handle = 0; - if (json_data.GetLength() == 0 || param_data.GetLength() == 0) { std::exit(-1); @@ -70,10 +67,8 @@ public: const mx_uint num_input_nodes = input_keys.size(); - const char* input_keys_ptr[num_input_nodes]; - for(mx_uint i = 0; i < num_input_nodes; i++){ - input_keys_ptr[i] = input_keys[i].c_str(); - } + const char* input_key[1] = { "data" }; + const char** input_keys_ptr = input_key; mx_uint shape_data_size = 0; mx_uint input_shape_indptr[input_shapes.size() + 1]; @@ -92,8 +87,8 @@ public: } } - MXPredCreate((const char*)json_data.GetBuffer(), - (const char*)param_data.GetBuffer(), + MXPredCreate(static_cast(json_data.GetBuffer()), + static_cast(param_data.GetBuffer()), static_cast(param_data.GetLength()), dev_type, dev_id, -- GitLab