// CNNPredictor_torcs_agent_torcsAgent_dqn.h
#ifndef CNNPREDICTOR_TORCS_AGENT_TORCSAGENT_DQN
#define CNNPREDICTOR_TORCS_AGENT_TORCSAGENT_DQN

#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

#include <mxnet/c_predict_api.h>

#include <CNNBufferFile.h>

// Authors (per repository history): Nicola Gatto, Sebastian N.
class CNNPredictor_torcs_agent_torcsAgent_dqn_0{
13
public:
Nicola Gatto's avatar
Nicola Gatto committed
14 15 16 17 18
    const std::string json_file = "model/torcs.agent.dqn.TorcsDQN/model_0_newest-symbol.json";
    const std::string param_file = "model/torcs.agent.dqn.TorcsDQN/model_0_newest-0000.params";
    const std::vector<std::string> input_keys = {
        "data"
    };
Sebastian N.'s avatar
Sebastian N. committed
19
    const std::vector<std::vector<mx_uint>> input_shapes = {{1, 5}};
20 21 22 23
    const bool use_gpu = false;

    PredictorHandle handle;

Nicola Gatto's avatar
Nicola Gatto committed
24
    explicit CNNPredictor_torcs_agent_torcsAgent_dqn_0(){
25 26 27
        init(json_file, param_file, input_keys, input_shapes, use_gpu);
    }

Nicola Gatto's avatar
Nicola Gatto committed
28
    ~CNNPredictor_torcs_agent_torcsAgent_dqn_0(){
29 30 31
        if(handle) MXPredFree(handle);
    }

Sebastian N.'s avatar
Sebastian N. committed
32 33 34
    void predict(const std::vector<float> &in_state_,
                 std::vector<float> &out_qvalues_){
        MXPredSetInput(handle, input_keys[0].c_str(), in_state_.data(), static_cast<mx_uint>(in_state_.size()));
35 36 37 38 39 40 41 42 43 44 45 46

        MXPredForward(handle);

        mx_uint output_index;
        mx_uint *shape = 0;
        mx_uint shape_len;
        size_t size;

        output_index = 0;
        MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
        size = 1;
        for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
Sebastian N.'s avatar
Sebastian N. committed
47
        assert(size == out_qvalues_.size());
Sebastian N.'s avatar
Sebastian N. committed
48
        MXPredGetOutput(handle, output_index, &(out_qvalues_[0]), out_qvalues_.size());
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70

    }

    void init(const std::string &json_file,
              const std::string &param_file,
              const std::vector<std::string> &input_keys,
              const std::vector<std::vector<mx_uint>> &input_shapes,
              const bool &use_gpu){

        BufferFile json_data(json_file);
        BufferFile param_data(param_file);

        int dev_type = use_gpu ? 2 : 1;
        int dev_id = 0;

        if (json_data.GetLength() == 0 ||
            param_data.GetLength() == 0) {
            std::exit(-1);
        }

        const mx_uint num_input_nodes = input_keys.size();

Nicola Gatto's avatar
Nicola Gatto committed
71 72 73 74
        const char* input_keys_ptr[num_input_nodes];
        for(mx_uint i = 0; i < num_input_nodes; i++){
            input_keys_ptr[i] = input_keys[i].c_str();
        }
75 76 77 78 79 80

        mx_uint shape_data_size = 0;
        mx_uint input_shape_indptr[input_shapes.size() + 1];
        input_shape_indptr[0] = 0;
        for(mx_uint i = 0; i < input_shapes.size(); i++){
            shape_data_size += input_shapes[i].size();
Nicola Gatto's avatar
Nicola Gatto committed
81
            input_shape_indptr[i+1] = shape_data_size;
82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
        }

        mx_uint input_shape_data[shape_data_size];
        mx_uint index = 0;
        for(mx_uint i = 0; i < input_shapes.size(); i++){
            for(mx_uint j = 0; j < input_shapes[i].size(); j++){
                input_shape_data[index] = input_shapes[i][j];
                index++;
            }
        }

        MXPredCreate(static_cast<const char*>(json_data.GetBuffer()),
                     static_cast<const char*>(param_data.GetBuffer()),
                     static_cast<size_t>(param_data.GetLength()),
                     dev_type,
                     dev_id,
                     num_input_nodes,
                     input_keys_ptr,
                     input_shape_indptr,
                     input_shape_data,
                     &handle);
        assert(handle);
    }
};

#endif // CNNPREDICTOR_TORCS_AGENT_TORCSAGENT_DQN