Commit 74cd8aa9 authored by Nicola Gatto's avatar Nicola Gatto

Fix gluon generation test

parent 8e8d2bc3
...@@ -6,12 +6,15 @@ from CNNNet_mnist_mnistClassifier_net import Net ...@@ -6,12 +6,15 @@ from CNNNet_mnist_mnistClassifier_net import Net
class CNNCreator_mnist_mnistClassifier_net: class CNNCreator_mnist_mnistClassifier_net:
_model_dir_ = "model/mnist.LeNetNetwork/" _model_dir_ = "model/mnist.LeNetNetwork/"
_model_prefix_ = "model" _model_prefix_ = "model"
_input_shapes_ = [(1,28,28)] _input_shapes_ = [(1,28,28,)]
def __init__(self): def __init__(self):
self.weight_initializer = mx.init.Normal() self.weight_initializer = mx.init.Normal()
self.net = None self.net = None
def get_input_shapes(self):
return self._input_shapes_
def load(self, context): def load(self, context):
lastEpoch = 0 lastEpoch = 0
param_file = None param_file = None
......
...@@ -30,8 +30,7 @@ public: ...@@ -30,8 +30,7 @@ public:
void predict(const std::vector<float> &image, void predict(const std::vector<float> &image,
std::vector<float> &predictions){ std::vector<float> &predictions){
MXPredSetInput(handle, "data", image.data(), image.size()); MXPredSetInput(handle, "data", image.data(), static_cast<mx_uint>(image.size()));
//MXPredSetInput(handle, "image", image.data(), image.size());
MXPredForward(handle); MXPredForward(handle);
...@@ -61,8 +60,6 @@ public: ...@@ -61,8 +60,6 @@ public:
int dev_type = use_gpu ? 2 : 1; int dev_type = use_gpu ? 2 : 1;
int dev_id = 0; int dev_id = 0;
handle = 0;
if (json_data.GetLength() == 0 || if (json_data.GetLength() == 0 ||
param_data.GetLength() == 0) { param_data.GetLength() == 0) {
std::exit(-1); std::exit(-1);
...@@ -70,10 +67,8 @@ public: ...@@ -70,10 +67,8 @@ public:
const mx_uint num_input_nodes = input_keys.size(); const mx_uint num_input_nodes = input_keys.size();
const char* input_keys_ptr[num_input_nodes]; const char* input_key[1] = { "data" };
for(mx_uint i = 0; i < num_input_nodes; i++){ const char** input_keys_ptr = input_key;
input_keys_ptr[i] = input_keys[i].c_str();
}
mx_uint shape_data_size = 0; mx_uint shape_data_size = 0;
mx_uint input_shape_indptr[input_shapes.size() + 1]; mx_uint input_shape_indptr[input_shapes.size() + 1];
...@@ -92,8 +87,8 @@ public: ...@@ -92,8 +87,8 @@ public:
} }
} }
MXPredCreate((const char*)json_data.GetBuffer(), MXPredCreate(static_cast<const char*>(json_data.GetBuffer()),
(const char*)param_data.GetBuffer(), static_cast<const char*>(param_data.GetBuffer()),
static_cast<size_t>(param_data.GetLength()), static_cast<size_t>(param_data.GetLength()),
dev_type, dev_type,
dev_id, dev_id,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment