CNNPredictor_VGG16.h 3.97 KB
Newer Older
1
2
3
#ifndef CNNPREDICTOR_VGG16
#define CNNPREDICTOR_VGG16

4
5
6
7
8
9
#include "caffe2/core/common.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/core/workspace.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/init.h"

10
// Define USE_GPU for GPU computation. Default is CPU computation.
11
12
13
14
15
//#define USE_GPU

#ifdef USE_GPU
#include "caffe2/core/context_gpu.h"
#endif
16
17

#include <string>
18
19
20
21
22
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <vector>

// Command-line flags holding the paths of the serialized networks loaded by
// CNNPredictor_VGG16::init() (read there as FLAGS_init_net / FLAGS_predict_net).
// Defaults point at ./model/VGG16/ relative to the working directory.
// NOTE(review): the DEFINE macros presumably create the flag variables
// themselves; if so, including this header from more than one translation
// unit would produce duplicate definitions — confirm and move to a .cpp if needed.
CAFFE2_DEFINE_string(init_net, "./model/VGG16/init_net.pb", "The given path to the init protobuffer.");
CAFFE2_DEFINE_string(predict_net, "./model/VGG16/predict_net.pb", "The given path to the predict protobuffer.");
23

24
using namespace caffe2;
25
26

class CNNPredictor_VGG16{
27
28
29
30
    private:
        TensorCPU input;
        Workspace workSpace;
        NetDef initNet, predictNet;
31

32
33
    public:
        const std::vector<TIndex> input_shapes = {{1,3,224,224}};
34

35
36
        explicit CNNPredictor_VGG16(){
            init(input_shapes);
37
38
        }

39
        ~CNNPredictor_VGG16(){};
40
41
42
43
44
45
46

        void init(const std::vector<TIndex> &input_shapes){
            int n = 0;
            char **a[1];
            caffe2::GlobalInit(&n, a);

            if (!std::ifstream(FLAGS_init_net).good()) {
47
48
                std::cerr << "\nNetwork loading failure, init_net file '" << FLAGS_init_net << "' does not exist." << std::endl;
                exit(1);
49
50
51
            }

            if (!std::ifstream(FLAGS_predict_net).good()) {
52
53
                std::cerr << "\nNetwork loading failure, predict_net file '" << FLAGS_predict_net << "' does not exist." << std::endl;
                exit(1);
54
55
            }

56
            std::cout << "\nLoading network..." << std::endl;
57

58
59
60
61
62
            // Read protobuf
            CAFFE_ENFORCE(ReadProtoFromFile(FLAGS_init_net, &initNet));
            CAFFE_ENFORCE(ReadProtoFromFile(FLAGS_predict_net, &predictNet));

            // Set device type
63
            #ifdef USE_GPU
64
65
66
            predictNet.mutable_device_option()->set_device_type(CUDA);
            initNet.mutable_device_option()->set_device_type(CUDA);
            std::cout << "== GPU mode selected " << " ==" << std::endl;
67
            #else
68
69
70
71
72
73
74
75
            predictNet.mutable_device_option()->set_device_type(CPU);
            initNet.mutable_device_option()->set_device_type(CPU);

            for(int i = 0; i < predictNet.op_size(); ++i){
                predictNet.mutable_op(i)->mutable_device_option()->set_device_type(CPU);
            }
            for(int i = 0; i < initNet.op_size(); ++i){
                initNet.mutable_op(i)->mutable_device_option()->set_device_type(CPU);
76
            }
77
            std::cout << "== CPU mode selected " << " ==" << std::endl;
78
            #endif
79
80
81
82
83
84
85

            // Load network
            CAFFE_ENFORCE(workSpace.RunNetOnce(initNet));
            CAFFE_ENFORCE(workSpace.CreateNet(predictNet));
            std::cout << "== Network loaded " << " ==" << std::endl;

            input.Resize(input_shapes);
86
87
        }

88
89
90
91
92
        void predict(const std::vector<float> &data, std::vector<float> &predictions){
            //Note: ShareExternalPointer requires a float pointer.
            input.ShareExternalPointer((float *) data.data());

            // Get input blob
93
            #ifdef USE_GPU
94
            auto dataBlob = workSpace.GetBlob("data")->GetMutable<TensorCUDA>();
95
            #else
96
            auto dataBlob = workSpace.GetBlob("data")->GetMutable<TensorCPU>();
97
            #endif
98
99
100
101
102
103
104
105

            // Copy from input data
            dataBlob->CopyFrom(input);

            // Forward
            workSpace.RunNet(predictNet.name());

            // Get output blob
106
            #ifdef USE_GPU
107
            auto predictionsBlob = TensorCPU(workSpace.GetBlob("predictions")->Get<TensorCUDA>());
108
            #else
109
            auto predictionsBlob = workSpace.GetBlob("predictions")->Get<TensorCPU>();
110
            #endif
111
112
113
114
            predictions.assign(predictionsBlob.data<float>(),predictionsBlob.data<float>() + predictionsBlob.size());

            google::protobuf::ShutdownProtobufLibrary();
        }
115
116
117
};

#endif // CNNPREDICTOR_VGG16