<#-- CNNPredictor.ftl: generates one MXNet predictor class per network stream of the component. -->
#ifndef ${tc.fileNameWithoutEnding?upper_case}
#define ${tc.fileNameWithoutEnding?upper_case}

#include <mxnet/c_predict_api.h>

#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

#include <CNNBufferFile.h>

<#list tc.architecture.streams as stream>
<#if stream.isNetwork()>
/**
 * Predictor for network stream ${stream?index} of component ${tc.componentName},
 * built on the MXNet C predict API (c_predict_api.h).
 *
 * The constructor reads the symbol JSON and parameter files named below (via
 * BufferFile) and creates an MXNet predictor. predict() sets the inputs, runs
 * one forward pass and copies each output into a caller-provided vector.
 * The object owns the predictor handle and frees it in its destructor, so
 * copying is disabled.
 */
class ${tc.fileNameWithoutEnding}_${stream?index}{
public:
    const std::string json_file = "model/${tc.componentName}/model_${stream?index}_newest-symbol.json";
    const std::string param_file = "model/${tc.componentName}/model_${stream?index}_newest-0000.params";

    // Input node names: "data" for a single-input stream, otherwise "data0", "data1", ...
    const std::vector<std::string> input_keys = {
<#if (tc.getStreamInputNames(stream)?size == 1)>
        "data"
<#else>
        <#list tc.getStreamInputNames(stream) as inputName>"data${inputName?index}"<#sep>, </#list>
</#if>
    };

    // One shape per input, each prefixed with a leading dimension of 1.
    const std::vector<std::vector<mx_uint>> input_shapes = {<#list stream.getFirstAtomicElements() as input>{1, ${tc.join(input.definition.type.dimensions, ", ")}}<#sep>, </#list>};

    const bool use_gpu = false;

    // Null until MXPredCreate succeeds, so the destructor never frees an
    // indeterminate pointer.
    PredictorHandle handle = nullptr;

    explicit ${tc.fileNameWithoutEnding}_${stream?index}(){
        init(json_file, param_file, input_keys, input_shapes, use_gpu);
    }

    ~${tc.fileNameWithoutEnding}_${stream?index}(){
        if(handle) MXPredFree(handle);
    }

    // The handle is owned exclusively; a copy would free it twice.
    ${tc.fileNameWithoutEnding}_${stream?index}(const ${tc.fileNameWithoutEnding}_${stream?index} &) = delete;
    ${tc.fileNameWithoutEnding}_${stream?index} &operator=(const ${tc.fileNameWithoutEnding}_${stream?index} &) = delete;

    /**
     * Runs one forward pass.
     *
     * Each input vector is bound to the matching node name from input_keys
     * ("data" or "data<i>"). Network output i is copied into the i-th output
     * vector. That vector's size must already equal the element count of the
     * output (checked with assert).
     */
    void predict(${tc.join(tc.getStreamInputNames(stream), ", ", "const std::vector<float> &", "")},
                 ${tc.join(tc.getStreamOutputNames(stream), ", ", "std::vector<float> &", "")}){
<#list tc.getStreamInputNames(stream) as inputName>
<#if (tc.getStreamInputNames(stream)?size == 1)>
        MXPredSetInput(handle, "data", ${inputName}.data(), static_cast<mx_uint>(${inputName}.size()));
<#else>
        MXPredSetInput(handle, "data${inputName?index}", ${inputName}.data(), static_cast<mx_uint>(${inputName}.size()));
</#if>
</#list>

        MXPredForward(handle);

        mx_uint output_index = 0;
        mx_uint *shape = nullptr;
        mx_uint shape_len = 0;  // stays 0 (empty loop) if the shape query fails
        size_t size = 0;

<#list tc.getStreamOutputNames(stream) as outputName>
        output_index = ${outputName?index?c};
        MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
        size = 1;
        for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
        assert(size == ${outputName}.size());
        // .data() is valid even for an empty vector, unlike &v[0].
        MXPredGetOutput(handle, output_index, ${outputName}.data(), static_cast<mx_uint>(${outputName}.size()));

</#list>
    }

    /**
     * Reads the model files and creates the MXNet predictor stored in `handle`.
     *
     * @param json_file     path of the network symbol JSON
     * @param param_file    path of the trained parameter file
     * @param input_keys    input node names, in binding order
     * @param input_shapes  one shape per entry of input_keys
     * @param use_gpu       selects device type 2 (GPU) instead of 1 (CPU)
     *
     * Terminates the process with exit code -1 (after a message on stderr)
     * if either file yields zero bytes or if MXPredCreate reports failure.
     */
    void init(const std::string &json_file,
              const std::string &param_file,
              const std::vector<std::string> &input_keys,
              const std::vector<std::vector<mx_uint>> &input_shapes,
              const bool &use_gpu){

        BufferFile json_data(json_file);
        BufferFile param_data(param_file);

        // MXNet device type codes: 1 = CPU, 2 = GPU.
        const int dev_type = use_gpu ? 2 : 1;
        const int dev_id = 0;

        if (json_data.GetLength() == 0 ||
            param_data.GetLength() == 0) {
            std::fprintf(stderr, "Could not load model files '%s' / '%s'\n",
                         json_file.c_str(), param_file.c_str());
            std::exit(-1);
        }

        const mx_uint num_input_nodes = static_cast<mx_uint>(input_keys.size());

        // MXPredCreate takes C arrays; std::vector replaces the original
        // variable-length arrays, which are not standard C++.
        std::vector<const char*> input_keys_ptr;
        input_keys_ptr.reserve(input_keys.size());
        for (const std::string &key : input_keys) {
            input_keys_ptr.push_back(key.c_str());
        }

        // All shapes are concatenated into input_shape_data; shape i occupies
        // the index range [input_shape_indptr[i], input_shape_indptr[i+1]).
        std::vector<mx_uint> input_shape_indptr(1, 0);
        std::vector<mx_uint> input_shape_data;
        for (const std::vector<mx_uint> &shape : input_shapes) {
            input_shape_data.insert(input_shape_data.end(), shape.begin(), shape.end());
            input_shape_indptr.push_back(static_cast<mx_uint>(input_shape_data.size()));
        }

        const int status = MXPredCreate(static_cast<const char*>(json_data.GetBuffer()),
                                        static_cast<const char*>(param_data.GetBuffer()),
                                        static_cast<int>(param_data.GetLength()),
                                        dev_type,
                                        dev_id,
                                        num_input_nodes,
                                        input_keys_ptr.data(),
                                        input_shape_indptr.data(),
                                        input_shape_data.data(),
                                        &handle);
        // Checked explicitly: an assert alone vanishes in release builds.
        if (status != 0 || handle == nullptr) {
            std::fprintf(stderr, "MXPredCreate failed: %s\n", MXGetLastError());
            handle = nullptr;
            std::exit(-1);
        }
    }
};
</#if>
</#list>

#endif // ${tc.fileNameWithoutEnding?upper_case}