# CNNDataLoader_cartpole_master_dqn.py
import os
import h5py
import mxnet as mx
import logging
import sys
from mxnet import nd


class CNNDataLoader_cartpole_master_dqn:
    """Load the cartpole DQN data sets from HDF5 files into MXNet iterators.

    The data directory must contain ``train.h5`` and may contain ``test.h5``.
    Every file that is loaded must hold one dataset per name listed in
    ``_input_names_`` and ``_output_names_``. Otherwise an error is logged and
    the process exits with status 1.
    """

    _input_names_ = ['state']
    _output_names_ = ['qvalues_label']

    def __init__(self, data_dir="data/"):
        """Create a loader reading ``train.h5`` / ``test.h5`` from ``data_dir``.

        Args:
            data_dir: Directory holding the HDF5 files (default ``"data/"``).
        """
        self._data_dir = data_dir

    def load_data(self, batch_size):
        """Build training/test iterators and per-input normalization statistics.

        Args:
            batch_size: Batch size passed to each ``mx.io.NDArrayIter``.

        Returns:
            A tuple ``(train_iter, test_iter, data_mean, data_std)``:
            ``test_iter`` is ``None`` when no ``test.h5`` exists;
            ``data_mean`` / ``data_std`` map each input name to an ``nd.array``
            holding the axis-0 mean / standard deviation (+1e-5) of the
            training data.
        """
        train_h5, test_h5 = self.load_h5_files()

        train_data = {}
        data_mean = {}
        data_std = {}
        for input_name in self._input_names_:
            train_data[input_name] = train_h5[input_name]
            # Read the dataset into memory once and reuse it for both statistics.
            values = train_h5[input_name][:]
            data_mean[input_name] = nd.array(values.mean(axis=0))
            # The epsilon keeps every std entry strictly positive
            # (presumably so callers can divide by it safely).
            data_std[input_name] = nd.array(values.std(axis=0) + 1e-5)

        train_label = {name: train_h5[name] for name in self._output_names_}

        train_iter = mx.io.NDArrayIter(data=train_data,
                                       label=train_label,
                                       batch_size=batch_size)

        test_iter = None
        if test_h5 is not None:
            test_data = {name: test_h5[name] for name in self._input_names_}
            test_label = {name: test_h5[name] for name in self._output_names_}
            test_iter = mx.io.NDArrayIter(data=test_data,
                                          label=test_label,
                                          batch_size=batch_size)

        return train_iter, test_iter, data_mean, data_std

    def load_h5_files(self):
        """Open and validate the training and (optional) test HDF5 files.

        The returned files are left open: the iterators built in
        ``load_data`` hold references to datasets inside them.

        Returns:
            ``(train_h5, test_h5)``, where ``test_h5`` is ``None`` if
            ``test.h5`` does not exist (a warning is logged in that case).

        Raises:
            SystemExit: (status 1) if ``train.h5`` is missing, or if any
                opened file lacks a required dataset.
        """
        train_path = os.path.join(self._data_dir, "train.h5")
        test_path = os.path.join(self._data_dir, "test.h5")

        if not os.path.isfile(train_path):
            logging.error("Data loading failure. File '" + os.path.abspath(train_path) + "' does not exist.")
            sys.exit(1)

        train_h5 = h5py.File(train_path, 'r')
        self._check_datasets(train_h5, train_path)

        test_h5 = None
        if os.path.isfile(test_path):
            test_h5 = h5py.File(test_path, 'r')
            self._check_datasets(test_h5, test_path)
        else:
            logging.warning("Couldn't load test set. File '" + os.path.abspath(test_path) + "' does not exist.")

        return train_h5, test_h5

    def _check_datasets(self, h5_file, path):
        """Exit with status 1 if ``h5_file`` lacks any required input/output dataset.

        Args:
            h5_file: An opened HDF5 file (anything supporting ``in`` and ``close()``).
            path: Path the file was opened from; used in the error message.
        """
        for name in self._input_names_ + self._output_names_:
            if name not in h5_file:
                logging.error("The HDF5 file '" + os.path.abspath(path) + "' has to contain the dataset "
                              + "'" + name + "'")
                # Release the handle in case a caller catches SystemExit.
                h5_file.close()
                sys.exit(1)