// CNNPredictor_LeNet.h
#ifndef CNNPREDICTOR_LENET
#define CNNPREDICTOR_LENET

#include "caffe2/core/common.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/core/workspace.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/init.h"

// Define USE_GPU for GPU computation. Default is CPU computation.
//#define USE_GPU

#ifdef USE_GPU
#include "caffe2/core/context_gpu.h"
#endif

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <vector>

21 22
CAFFE2_DEFINE_string(init_net, "./model/LeNet/init_net.pb", "The given path to the init protobuffer.");
CAFFE2_DEFINE_string(predict_net, "./model/LeNet/predict_net.pb", "The given path to the predict protobuffer.");
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45

using namespace caffe2;

class CNNPredictor_LeNet{
    private:
        TensorCPU input;
        Workspace workSpace;
        NetDef initNet, predictNet;

    public:
        const std::vector<TIndex> input_shapes = {{1,1,28,28}};

        explicit CNNPredictor_LeNet(){
            init(input_shapes);
        }

        ~CNNPredictor_LeNet(){};

        void init(const std::vector<TIndex> &input_shapes){
            int n = 0;
            char **a[1];
            caffe2::GlobalInit(&n, a);

46 47
            if (!std::ifstream(FLAGS_init_net).good()) {
                std::cerr << "\nNetwork loading failure, init_net file '" << FLAGS_init_net << "' does not exist." << std::endl;
48 49 50
                exit(1);
            }

51 52
            if (!std::ifstream(FLAGS_predict_net).good()) {
                std::cerr << "\nNetwork loading failure, predict_net file '" << FLAGS_predict_net << "' does not exist." << std::endl;
53 54 55 56 57 58
                exit(1);
            }

            std::cout << "\nLoading network..." << std::endl;

            // Read protobuf
59 60
            CAFFE_ENFORCE(ReadProtoFromFile(FLAGS_init_net, &initNet));
            CAFFE_ENFORCE(ReadProtoFromFile(FLAGS_predict_net, &predictNet));
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117

            // Set device type
            #ifdef USE_GPU
            predictNet.mutable_device_option()->set_device_type(CUDA);
            initNet.mutable_device_option()->set_device_type(CUDA);
            std::cout << "== GPU mode selected " << " ==" << std::endl;
            #else
            predictNet.mutable_device_option()->set_device_type(CPU);
            initNet.mutable_device_option()->set_device_type(CPU);

            for(int i = 0; i < predictNet.op_size(); ++i){
                predictNet.mutable_op(i)->mutable_device_option()->set_device_type(CPU);
            }
            for(int i = 0; i < initNet.op_size(); ++i){
                initNet.mutable_op(i)->mutable_device_option()->set_device_type(CPU);
            }
            std::cout << "== CPU mode selected " << " ==" << std::endl;
            #endif

            // Load network
            CAFFE_ENFORCE(workSpace.RunNetOnce(initNet));
            CAFFE_ENFORCE(workSpace.CreateNet(predictNet));
            std::cout << "== Network loaded " << " ==" << std::endl;

            input.Resize(input_shapes);
        }

        void predict(const std::vector<float> &image, std::vector<float> &predictions){
            //Note: ShareExternalPointer requires a float pointer.
            input.ShareExternalPointer((float *) image.data());

            // Get input blob
            #ifdef USE_GPU
            auto dataBlob = workSpace.GetBlob("data")->GetMutable<TensorCUDA>();
            #else
            auto dataBlob = workSpace.GetBlob("data")->GetMutable<TensorCPU>();
            #endif

            // Copy from input data
            dataBlob->CopyFrom(input);

            // Forward
            workSpace.RunNet(predictNet.name());

            // Get output blob
            #ifdef USE_GPU
            auto predictionsBlob = TensorCPU(workSpace.GetBlob("predictions")->Get<TensorCUDA>());
            #else
            auto predictionsBlob = workSpace.GetBlob("predictions")->Get<TensorCPU>();
            #endif
            predictions.assign(predictionsBlob.data<float>(),predictionsBlob.data<float>() + predictionsBlob.size());

            google::protobuf::ShutdownProtobufLibrary();
        }
};

#endif // CNNPREDICTOR_LENET