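<#-- CNNPredictor.ftl (8.02 KB): FreeMarker template that generates a C++ MXNet predictor header.
     Page chrome captured together with this file (kept here as a template comment so it is not
     emitted into the generated C++):
     "Aufgrund einer Wartung wird GitLab am 19.10. zwischen 8:00 und 9:00 Uhr kurzzeitig nicht zur
     Verfügung stehen. / Due to maintenance, GitLab will be temporarily unavailable on 19.10.
     between 8:00 and 9:00 am." -->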
#ifndef ${tc.fileNameWithoutEnding?upper_case}
#define ${tc.fileNameWithoutEnding?upper_case}

#include <mxnet/c_predict_api.h>

#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

#include <CNNBufferFile.h>

<#list tc.architecture.streams as stream>
<#if stream.isTrainable()>
<#assign predictorClass = "${tc.fileNameWithoutEnding}_${stream?index}">
/**
 * MXNet predictor for trainable stream ${stream?index} of ${tc.componentName}.
 *
 * The constructor loads json_file / param_file and creates the MXNet
 * predictor. The process exits (after printing a diagnostic to stderr) if
 * either file is empty/unreadable or if any MXNet C API call fails.
 * The object owns its PredictorHandle and is therefore not copyable.
 */
class ${predictorClass}{
public:
    const std::string json_file = "model/${tc.componentName}/model_${stream?index}_newest-symbol.json";
    const std::string param_file = "model/${tc.componentName}/model_${stream?index}_newest-0000.params";
    // A single input is bound as "data"; several inputs as "data0", "data1", ...
    const std::vector<std::string> input_keys = {
<#if tc.getStreamInputNames(stream)?size == 1>
        "data"
<#else>
        <#list tc.getStreamInputNames(stream) as variable>"data${variable?index}"<#sep>, </#list>
</#if>
    };
    const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getStreamInputDimensions(stream) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
    const bool use_gpu = false;

    // Stays null unless MXPredCreate succeeds, so the destructor never frees
    // an uninitialized handle.
    PredictorHandle handle = nullptr;

    explicit ${predictorClass}(){
        init(json_file, param_file, input_keys, input_shapes, use_gpu);
    }

    // Copying would leave two objects that both free the same handle.
    ${predictorClass}(const ${predictorClass}&) = delete;
    ${predictorClass}& operator=(const ${predictorClass}&) = delete;

    ~${predictorClass}(){
        if(handle) MXPredFree(handle);
    }

    /**
     * Runs one forward pass.
     *
     * Each in_* vector is copied into the matching network input. Each out_*
     * vector must already hold exactly as many elements as the corresponding
     * network output (asserted in debug builds) and is overwritten with it.
     * Exits the process if an MXNet call fails.
     */
    void predict(${tc.join(tc.getStreamInputNames(stream), ", ", "const std::vector<float> &in_", "")},
                 ${tc.join(tc.getStreamOutputNames(stream), ", ", "std::vector<float> &out_", "")}){
<#list tc.getStreamInputNames(stream) as variable>
        check_status(MXPredSetInput(handle, input_keys[${variable?index}].c_str(), in_${variable}.data(), static_cast<mx_uint>(in_${variable}.size())),
                     "MXPredSetInput");
</#list>

        check_status(MXPredForward(handle), "MXPredForward");

        mx_uint output_index;
        mx_uint *shape = nullptr;
        mx_uint shape_len = 0;
        size_t size;

<#list tc.getStreamOutputNames(stream) as variable>
        output_index = ${variable?index?c};
        check_status(MXPredGetOutputShape(handle, output_index, &shape, &shape_len), "MXPredGetOutputShape");
        size = 1;
        for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
        assert(size == out_${variable}.size());
        // .data() instead of &out[0]: no out-of-range access on an empty vector.
        check_status(MXPredGetOutput(handle, output_index, out_${variable}.data(), static_cast<mx_uint>(out_${variable}.size())),
                     "MXPredGetOutput");

</#list>
    }

    /**
     * Loads the symbol JSON and parameter file and creates the MXNet predictor
     * in `handle`. A predictor created by an earlier init() call is freed first.
     *
     * @param json_file    path of the serialized symbol graph
     * @param param_file   path of the serialized parameters
     * @param input_keys   names of the network inputs
     * @param input_shapes one shape per entry of input_keys
     * @param use_gpu      create the predictor on GPU 0 instead of the CPU
     */
    void init(const std::string &json_file,
              const std::string &param_file,
              const std::vector<std::string> &input_keys,
              const std::vector<std::vector<mx_uint>> &input_shapes,
              const bool &use_gpu){

        if(handle){
            MXPredFree(handle);
            handle = nullptr;
        }

        BufferFile json_data(json_file);
        BufferFile param_data(param_file);

        // MXNet device type codes: 1 = CPU, 2 = GPU.
        const int dev_type = use_gpu ? 2 : 1;
        const int dev_id = 0;

        if (json_data.GetLength() == 0 ||
            param_data.GetLength() == 0) {
            std::fprintf(stderr, "Could not load model files '%s' / '%s'\n",
                         json_file.c_str(), param_file.c_str());
            std::exit(-1);
        }

        // MXPredCreate takes C arrays; std::vector replaces the previous
        // variable-length arrays, which are not standard C++.
        std::vector<const char*> input_keys_ptr;
        input_keys_ptr.reserve(input_keys.size());
        for(const std::string &key : input_keys){
            input_keys_ptr.push_back(key.c_str());
        }

        // CSR-style shape encoding: the dimensions of input i are
        // input_shape_data[input_shape_indptr[i] .. input_shape_indptr[i+1]).
        std::vector<mx_uint> input_shape_indptr;
        input_shape_indptr.reserve(input_shapes.size() + 1);
        input_shape_indptr.push_back(0);
        std::vector<mx_uint> input_shape_data;
        for(const std::vector<mx_uint> &shape : input_shapes){
            input_shape_data.insert(input_shape_data.end(), shape.begin(), shape.end());
            input_shape_indptr.push_back(static_cast<mx_uint>(input_shape_data.size()));
        }

        check_status(MXPredCreate(static_cast<const char*>(json_data.GetBuffer()),
                                  static_cast<const char*>(param_data.GetBuffer()),
                                  static_cast<int>(param_data.GetLength()),
                                  dev_type,
                                  dev_id,
                                  static_cast<mx_uint>(input_keys_ptr.size()),
                                  input_keys_ptr.data(),
                                  input_shape_indptr.data(),
                                  input_shape_data.data(),
                                  &handle),
                     "MXPredCreate");
        assert(handle);
    }

private:
    // MXNet C API functions return 0 on success; on failure report MXNet's
    // last error message and terminate, like the missing-model-file path.
    static void check_status(int status, const char *call){
        if(status != 0){
            std::fprintf(stderr, "%s failed: %s\n", call, MXGetLastError());
            std::exit(-1);
        }
    }
};
</#if>
</#list>

<#list tc.architecture.unrolls as unroll>
<#-- NOTE(review): the class/model index is streams?size + body?index and body?index restarts for
     every unroll, so an architecture with more than one unroll would generate duplicate class names.
     Left unchanged because the model_<index>_newest-* file names presumably have to match the files
     written by training — confirm before renumbering. -->
<#list unroll.getBodiesForAllTimesteps() as body>
<#if body.isTrainable()>
<#assign predictorClass = "${tc.fileNameWithoutEnding}_${tc.architecture.streams?size + body?index}">
/**
 * MXNet predictor for trainable unroll body ${tc.architecture.streams?size + body?index} of ${tc.componentName}.
 *
 * The constructor loads json_file / param_file and creates the MXNet
 * predictor. The process exits (after printing a diagnostic to stderr) if
 * either file is empty/unreadable or if any MXNet C API call fails.
 * The object owns its PredictorHandle and is therefore not copyable.
 */
class ${predictorClass}{
public:
    const std::string json_file = "model/${tc.componentName}/model_${tc.architecture.streams?size + body?index}_newest-symbol.json";
    const std::string param_file = "model/${tc.componentName}/model_${tc.architecture.streams?size + body?index}_newest-0000.params";
    // A single input is bound as "data"; several inputs as "data0", "data1", ...
    const std::vector<std::string> input_keys = {
<#if tc.getStreamInputNames(body)?size == 1>
        "data"
<#else>
        <#list tc.getStreamInputNames(body) as variable>"data${variable?index}"<#sep>, </#list>
</#if>
    };
    const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getStreamInputDimensions(body) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
    const bool use_gpu = false;

    // Stays null unless MXPredCreate succeeds, so the destructor never frees
    // an uninitialized handle.
    PredictorHandle handle = nullptr;

    explicit ${predictorClass}(){
        init(json_file, param_file, input_keys, input_shapes, use_gpu);
    }

    // Copying would leave two objects that both free the same handle.
    ${predictorClass}(const ${predictorClass}&) = delete;
    ${predictorClass}& operator=(const ${predictorClass}&) = delete;

    ~${predictorClass}(){
        if(handle) MXPredFree(handle);
    }

    /**
     * Runs one forward pass.
     *
     * Each in_* vector is copied into the matching network input. Each out_*
     * vector must already hold exactly as many elements as the corresponding
     * network output (asserted in debug builds) and is overwritten with it.
     * Exits the process if an MXNet call fails.
     */
    void predict(${tc.join(tc.getStreamInputNames(body), ", ", "const std::vector<float> &in_", "")},
                 ${tc.join(tc.getStreamOutputNames(body), ", ", "std::vector<float> &out_", "")}){
<#list tc.getStreamInputNames(body) as variable>
        check_status(MXPredSetInput(handle, input_keys[${variable?index}].c_str(), in_${variable}.data(), static_cast<mx_uint>(in_${variable}.size())),
                     "MXPredSetInput");
</#list>

        check_status(MXPredForward(handle), "MXPredForward");

        mx_uint output_index;
        mx_uint *shape = nullptr;
        mx_uint shape_len = 0;
        size_t size;

<#list tc.getStreamOutputNames(body) as variable>
        output_index = ${variable?index?c};
        check_status(MXPredGetOutputShape(handle, output_index, &shape, &shape_len), "MXPredGetOutputShape");
        size = 1;
        for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
        assert(size == out_${variable}.size());
        // .data() instead of &out[0]: no out-of-range access on an empty vector.
        check_status(MXPredGetOutput(handle, output_index, out_${variable}.data(), static_cast<mx_uint>(out_${variable}.size())),
                     "MXPredGetOutput");

</#list>
    }

    /**
     * Loads the symbol JSON and parameter file and creates the MXNet predictor
     * in `handle`. A predictor created by an earlier init() call is freed first.
     *
     * @param json_file    path of the serialized symbol graph
     * @param param_file   path of the serialized parameters
     * @param input_keys   names of the network inputs
     * @param input_shapes one shape per entry of input_keys
     * @param use_gpu      create the predictor on GPU 0 instead of the CPU
     */
    void init(const std::string &json_file,
              const std::string &param_file,
              const std::vector<std::string> &input_keys,
              const std::vector<std::vector<mx_uint>> &input_shapes,
              const bool &use_gpu){

        if(handle){
            MXPredFree(handle);
            handle = nullptr;
        }

        BufferFile json_data(json_file);
        BufferFile param_data(param_file);

        // MXNet device type codes: 1 = CPU, 2 = GPU.
        const int dev_type = use_gpu ? 2 : 1;
        const int dev_id = 0;

        if (json_data.GetLength() == 0 ||
            param_data.GetLength() == 0) {
            std::fprintf(stderr, "Could not load model files '%s' / '%s'\n",
                         json_file.c_str(), param_file.c_str());
            std::exit(-1);
        }

        // MXPredCreate takes C arrays; std::vector replaces the previous
        // variable-length arrays, which are not standard C++.
        std::vector<const char*> input_keys_ptr;
        input_keys_ptr.reserve(input_keys.size());
        for(const std::string &key : input_keys){
            input_keys_ptr.push_back(key.c_str());
        }

        // CSR-style shape encoding: the dimensions of input i are
        // input_shape_data[input_shape_indptr[i] .. input_shape_indptr[i+1]).
        std::vector<mx_uint> input_shape_indptr;
        input_shape_indptr.reserve(input_shapes.size() + 1);
        input_shape_indptr.push_back(0);
        std::vector<mx_uint> input_shape_data;
        for(const std::vector<mx_uint> &shape : input_shapes){
            input_shape_data.insert(input_shape_data.end(), shape.begin(), shape.end());
            input_shape_indptr.push_back(static_cast<mx_uint>(input_shape_data.size()));
        }

        check_status(MXPredCreate(static_cast<const char*>(json_data.GetBuffer()),
                                  static_cast<const char*>(param_data.GetBuffer()),
                                  static_cast<int>(param_data.GetLength()),
                                  dev_type,
                                  dev_id,
                                  static_cast<mx_uint>(input_keys_ptr.size()),
                                  input_keys_ptr.data(),
                                  input_shape_indptr.data(),
                                  input_shape_data.data(),
                                  &handle),
                     "MXPredCreate");
        assert(handle);
    }

private:
    // MXNet C API functions return 0 on success; on failure report MXNet's
    // last error message and terminate, like the missing-model-file path.
    static void check_status(int status, const char *call){
        if(status != 0){
            std::fprintf(stderr, "%s failed: %s\n", call, MXGetLastError());
            std::exit(-1);
        }
    }
};
</#if>
</#list>
</#list>

#endif // ${tc.fileNameWithoutEnding?upper_case}