#ifndef ${tc.fileNameWithoutEnding?upper_case}
#define ${tc.fileNameWithoutEnding?upper_case}

#include "caffe2/core/common.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/core/workspace.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/init.h"

// Enable define USE_GPU if you want to use gpu
//#define USE_GPU

#ifdef USE_GPU
#include "caffe2/core/context_gpu.h"
#endif

#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <vector>

CAFFE2_DEFINE_string(init_net, "./model/${tc.fullArchitectureName}/init_net.pb", "The given path to the init protobuffer.");
CAFFE2_DEFINE_string(predict_net, "./model/${tc.fullArchitectureName}/predict_net.pb", "The given path to the predict protobuffer.");

using namespace caffe2;

class ${tc.fileNameWithoutEnding}{
    private:
        TensorCPU input;             // CPU staging tensor; shares the caller's buffer in predict()
        Workspace workSpace;         // owns all network blobs (parameters, inputs, outputs)
        NetDef initNet, predictNet;  // protobuf graph definitions loaded from disk

    public:
        // One {1, dims...} entry per network input; the leading 1 is the batch size.
        const std::vector<TIndex> input_shapes = {<#list tc.architecture.inputs as input>{1,${tc.join(input.definition.type.dimensions, ",")}}<#if input?has_next>,</#if></#list>};
        const bool use_gpu = false;

        explicit ${tc.fileNameWithoutEnding}(){
            init(input_shapes);
        }

        //~${tc.fileNameWithoutEnding}(){};

        // Loads the init/predict protobufs, pins all operators to the selected
        // device (CPU, or CUDA when USE_GPU is defined), materializes the
        // parameters and the network in the workspace, and pre-sizes the input
        // tensor. If a model file is missing it reports to stderr and returns
        // early, leaving this predictor unusable.
        void init(const std::vector<TIndex> &input_shapes){
            // No real command line is available here, so hand GlobalInit a
            // well-defined dummy argument vector. (The previous code passed an
            // uninitialized char** array, which is undefined behavior if
            // GlobalInit dereferences it.)
            int argc = 1;
            char arg0[] = "predictor";
            char *argvArr[] = { arg0, nullptr };
            char **argv = argvArr;
            caffe2::GlobalInit(&argc, &argv);

            if (!std::ifstream(FLAGS_init_net).good()) {
                std::cerr << "Network loading failure, init_net file '" << FLAGS_init_net << "' does not exist." << std::endl;
                return;
            }

            if (!std::ifstream(FLAGS_predict_net).good()) {
                std::cerr << "Network loading failure, predict_net file '" << FLAGS_predict_net << "' does not exist." << std::endl;
                return;
            }

            std::cout << "****************************************************************" << std::endl;
            std::cout << "Loading network..." << std::endl;

            // Read protobuf
            CAFFE_ENFORCE(ReadProtoFromFile(FLAGS_init_net, &initNet));
            CAFFE_ENFORCE(ReadProtoFromFile(FLAGS_predict_net, &predictNet));

            // Set device type
#ifdef USE_GPU
            predictNet.mutable_device_option()->set_device_type(CUDA);
            initNet.mutable_device_option()->set_device_type(CUDA);
            std::cout << "== GPU mode selected " << " ==" << std::endl;
#else
            predictNet.mutable_device_option()->set_device_type(CPU);
            initNet.mutable_device_option()->set_device_type(CPU);

            // Serialized models may carry per-operator device options, so
            // every op has to be forced to CPU as well, not just the net.
            for(int i = 0; i < predictNet.op_size(); ++i){
                predictNet.mutable_op(i)->mutable_device_option()->set_device_type(CPU);
            }
            for(int i = 0; i < initNet.op_size(); ++i){
                initNet.mutable_op(i)->mutable_device_option()->set_device_type(CPU);
            }
            std::cout << "== CPU mode selected " << " ==" << std::endl;
#endif

            // Load network: run the init net once to create the parameter
            // blobs, then build the predict net for repeated execution.
            CAFFE_ENFORCE(workSpace.RunNetOnce(initNet));
            CAFFE_ENFORCE(workSpace.CreateNet(predictNet));
            std::cout << "== Network loaded " << " ==" << std::endl;

            input.Resize(input_shapes);
        }

        // Runs one forward pass: wraps the caller's input vector without
        // copying, copies it into the network's input blob, executes the
        // predict net, and copies each output blob into the matching result
        // vector. Safe to call repeatedly on the same instance.
        void predict(${tc.join(tc.architectureInputs, ", ", "const std::vector<float> &", "")}, ${tc.join(tc.architectureOutputs, ", ", "std::vector<float> &", "")}){
            //Note: ShareExternalPointer requires a float pointer, so the
            //const qualifier of the caller's buffer is cast away; the data is
            //only read (CopyFrom below), never written through this pointer.
            input.ShareExternalPointer((float *) ${tc.join(tc.architectureInputs, ",", "","")}.data());

            // Get input blob
<#--<#list tc.architectureInputs as inputName>-->
#ifdef USE_GPU
            <#--auto ${inputName + "Blob"} = workSpace.GetBlob("${inputName}")->GetMutable<TensorCUDA>();-->
            auto dataBlob = workSpace.GetBlob("data")->GetMutable<TensorCUDA>();
#else
            <#--auto ${inputName + "Blob"} = workSpace.GetBlob("${inputName}")->GetMutable<TensorCPU>();-->
            auto dataBlob = workSpace.GetBlob("data")->GetMutable<TensorCPU>();
#endif

<#--</#list>-->
            // Copy from input data
            dataBlob->CopyFrom(input);

            // Forward
            workSpace.RunNet(predictNet.name());

            // Get output blob
<#list tc.architectureOutputs as outputName>
#ifdef USE_GPU
            auto ${outputName + "Blob"} = TensorCPU(workSpace.GetBlob("${outputName}")->Get<TensorCUDA>());
#else
            auto ${outputName + "Blob"} = workSpace.GetBlob("${outputName}")->Get<TensorCPU>();
#endif
            ${outputName}.assign(${outputName + "Blob"}.data<float>(),${outputName + "Blob"}.data<float>() + ${outputName + "Blob"}.size());

</#list>
            // BUGFIX: google::protobuf::ShutdownProtobufLibrary() used to be
            // called here. It permanently tears down protobuf state, so the
            // FIRST prediction broke every subsequent predict() call and any
            // later predictor construction. Protobuf must be shut down at
            // most once, at program exit — not per inference.
        }
};

#endif // ${tc.fileNameWithoutEnding?upper_case}