Commit 3c80ac8b authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Implemented multiple inputs

parent ec3f56ed
......@@ -6,7 +6,7 @@ from CNNNet_${tc.fullArchitectureName} import Net
class ${tc.fileNameWithoutEnding}:
_model_dir_ = "model/${tc.componentName}/"
_model_prefix_ = "model"
_input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")},)</#list>]
_input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")},)<#sep>, </#list>]
def __init__(self):
self.weight_initializer = mx.init.Normal()
......
......@@ -13,9 +13,14 @@ class ${tc.fileNameWithoutEnding}{
public:
const std::string json_file = "model/${tc.componentName}/model_newest-symbol.json";
const std::string param_file = "model/${tc.componentName}/model_newest-0000.params";
//const std::vector<std::string> input_keys = {"data"};
const std::vector<std::string> input_keys = {${tc.join(tc.architectureInputs, ",", "\"", "\"")}};
const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.architecture.inputs as input>{1,${tc.join(input.definition.type.dimensions, ",")}}<#if input?has_next>,</#if></#list>};
const std::vector<std::string> input_keys = {
<#if (tc.architectureInputs?size == 1)>
"data"
<#else>
<#list tc.architectureInputs as inputName>"data${inputName?index}"<#sep>, </#list>
</#if>
};
const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.architecture.inputs as input>{1, ${tc.join(input.definition.type.dimensions, ", ")}}<#sep>, </#list>};
const bool use_gpu = false;
PredictorHandle handle;
......@@ -31,7 +36,11 @@ public:
void predict(${tc.join(tc.architectureInputs, ", ", "const std::vector<float> &", "")},
${tc.join(tc.architectureOutputs, ", ", "std::vector<float> &", "")}){
<#list tc.architectureInputs as inputName>
<#if (tc.architectureInputs?size == 1)>
MXPredSetInput(handle, "data", ${inputName}.data(), static_cast<mx_uint>(${inputName}.size()));
<#else>
MXPredSetInput(handle, "data${inputName?index}", ${inputName}.data(), static_cast<mx_uint>(${inputName}.size()));
</#if>
</#list>
MXPredForward(handle);
......@@ -71,22 +80,17 @@ public:
const mx_uint num_input_nodes = input_keys.size();
<#if (tc.architectureInputs?size >= 2)>
const char* input_keys_ptr[num_input_nodes];
for(mx_uint i = 0; i < num_input_nodes; i++){
input_keys_ptr[i] = input_keys[i].c_str();
}
<#else>
const char* input_key[1] = { "data" };
const char** input_keys_ptr = input_key;
</#if>
mx_uint shape_data_size = 0;
mx_uint input_shape_indptr[input_shapes.size() + 1];
input_shape_indptr[0] = 0;
for(mx_uint i = 0; i < input_shapes.size(); i++){
input_shape_indptr[i+1] = input_shapes[i].size();
shape_data_size += input_shapes[i].size();
input_shape_indptr[i+1] = shape_data_size;
}
mx_uint input_shape_data[shape_data_size];
......
......@@ -13,9 +13,10 @@ class CNNPredictor_Alexnet{
public:
const std::string json_file = "model/Alexnet/model_newest-symbol.json";
const std::string param_file = "model/Alexnet/model_newest-0000.params";
//const std::vector<std::string> input_keys = {"data"};
const std::vector<std::string> input_keys = {"data"};
const std::vector<std::vector<mx_uint>> input_shapes = {{1,3,224,224}};
const std::vector<std::string> input_keys = {
"data"
};
const std::vector<std::vector<mx_uint>> input_shapes = {{1, 3, 224, 224}};
const bool use_gpu = false;
PredictorHandle handle;
......@@ -67,15 +68,17 @@ public:
const mx_uint num_input_nodes = input_keys.size();
const char* input_key[1] = { "data" };
const char** input_keys_ptr = input_key;
const char* input_keys_ptr[num_input_nodes];
for(mx_uint i = 0; i < num_input_nodes; i++){
input_keys_ptr[i] = input_keys[i].c_str();
}
mx_uint shape_data_size = 0;
mx_uint input_shape_indptr[input_shapes.size() + 1];
input_shape_indptr[0] = 0;
for(mx_uint i = 0; i < input_shapes.size(); i++){
input_shape_indptr[i+1] = input_shapes[i].size();
shape_data_size += input_shapes[i].size();
input_shape_indptr[i+1] = shape_data_size;
}
mx_uint input_shape_data[shape_data_size];
......@@ -101,4 +104,4 @@ public:
}
};
#endif // CNNPREDICTOR_ALEXNET
\ No newline at end of file
#endif // CNNPREDICTOR_ALEXNET
......@@ -13,9 +13,10 @@ class CNNPredictor_CifarClassifierNetwork{
public:
const std::string json_file = "model/CifarClassifierNetwork/model_newest-symbol.json";
const std::string param_file = "model/CifarClassifierNetwork/model_newest-0000.params";
//const std::vector<std::string> input_keys = {"data"};
const std::vector<std::string> input_keys = {"data"};
const std::vector<std::vector<mx_uint>> input_shapes = {{1,3,32,32}};
const std::vector<std::string> input_keys = {
"data"
};
const std::vector<std::vector<mx_uint>> input_shapes = {{1, 3, 32, 32}};
const bool use_gpu = false;
PredictorHandle handle;
......@@ -67,15 +68,17 @@ public:
const mx_uint num_input_nodes = input_keys.size();
const char* input_key[1] = { "data" };
const char** input_keys_ptr = input_key;
const char* input_keys_ptr[num_input_nodes];
for(mx_uint i = 0; i < num_input_nodes; i++){
input_keys_ptr[i] = input_keys[i].c_str();
}
mx_uint shape_data_size = 0;
mx_uint input_shape_indptr[input_shapes.size() + 1];
input_shape_indptr[0] = 0;
for(mx_uint i = 0; i < input_shapes.size(); i++){
input_shape_indptr[i+1] = input_shapes[i].size();
shape_data_size += input_shapes[i].size();
input_shape_indptr[i+1] = shape_data_size;
}
mx_uint input_shape_data[shape_data_size];
......@@ -101,4 +104,4 @@ public:
}
};
#endif // CNNPREDICTOR_CIFARCLASSIFIERNETWORK
\ No newline at end of file
#endif // CNNPREDICTOR_CIFARCLASSIFIERNETWORK
......@@ -13,9 +13,10 @@ class CNNPredictor_VGG16{
public:
const std::string json_file = "model/VGG16/model_newest-symbol.json";
const std::string param_file = "model/VGG16/model_newest-0000.params";
//const std::vector<std::string> input_keys = {"data"};
const std::vector<std::string> input_keys = {"data"};
const std::vector<std::vector<mx_uint>> input_shapes = {{1,3,224,224}};
const std::vector<std::string> input_keys = {
"data"
};
const std::vector<std::vector<mx_uint>> input_shapes = {{1, 3, 224, 224}};
const bool use_gpu = false;
PredictorHandle handle;
......@@ -67,15 +68,17 @@ public:
const mx_uint num_input_nodes = input_keys.size();
const char* input_key[1] = { "data" };
const char** input_keys_ptr = input_key;
const char* input_keys_ptr[num_input_nodes];
for(mx_uint i = 0; i < num_input_nodes; i++){
input_keys_ptr[i] = input_keys[i].c_str();
}
mx_uint shape_data_size = 0;
mx_uint input_shape_indptr[input_shapes.size() + 1];
input_shape_indptr[0] = 0;
for(mx_uint i = 0; i < input_shapes.size(); i++){
input_shape_indptr[i+1] = input_shapes[i].size();
shape_data_size += input_shapes[i].size();
input_shape_indptr[i+1] = shape_data_size;
}
mx_uint input_shape_data[shape_data_size];
......@@ -101,4 +104,4 @@ public:
}
};
#endif // CNNPREDICTOR_VGG16
\ No newline at end of file
#endif // CNNPREDICTOR_VGG16
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment