# CNNDataLoader_VGG16.py
# NOTE: this listing was extracted from a repository blame view; stray
# author/commit markers and bare line-number runs are page residue, not code.
import os
import h5py
import mxnet as mx
import logging
import sys
6 7 8
import numpy as np
import cv2
import importlib
9
from mxnet import nd
Nicola Gatto's avatar
Nicola Gatto committed
10

11
class CNNDataLoader_VGG16:
    """Load the VGG16 train/test HDF5 files and wrap them in MXNet iterators.

    The loader expects ``train.h5`` (mandatory) and ``test.h5`` (optional)
    inside ``self._data_dir``. Every file must contain one dataset per name in
    ``_input_names_`` and ``_output_names_``; an optional ``'images'`` dataset
    is passed through to the caller untouched.
    """

    # Names of the HDF5 datasets used as network inputs / labels.
    _input_names_ = ['data']
    _output_names_ = ['predictions_label']

    def __init__(self):
        # Directory (relative to the working directory) holding the .h5 files.
        self._data_dir = "data/VGG16/"

    def load_data(self, batch_size, shuffle=False):
        """Build iterators directly on top of the HDF5 datasets.

        Args:
            batch_size: Batch size handed to both ``mx.io.NDArrayIter`` objects.
            shuffle: Whether the training iterator shuffles (test never does).

        Returns:
            Tuple ``(train_iter, test_iter, data_mean, data_std, train_images,
            test_images)``. ``test_iter`` is ``None`` when no test file exists.
            ``data_mean`` / ``data_std`` are keyed by ``input_name + '_'``.
            The image entries are the ``'images'`` dataset of the respective
            file, or an empty dict when that dataset (or the file) is absent.
        """
        train_h5, test_h5 = self.load_h5_files()

        train_data = {}
        data_mean = {}
        data_std = {}
        for input_name in self._input_names_:
            train_data[input_name] = train_h5[input_name]
            # Read the dataset into memory once and reuse it for both
            # statistics instead of pulling it from disk twice.
            samples = train_h5[input_name][:]
            data_mean[input_name + '_'] = nd.array(samples.mean(axis=0))
            # The epsilon keeps later divisions by std away from zero.
            data_std[input_name + '_'] = nd.array(samples.std(axis=0) + 1e-5)

        train_images = train_h5['images'] if 'images' in train_h5 else {}

        # Labels are keyed by their position in _output_names_, not by name.
        train_label = {position: train_h5[output_name]
                       for position, output_name in enumerate(self._output_names_)}

        train_iter = mx.io.NDArrayIter(data=train_data,
                                       label=train_label,
                                       batch_size=batch_size,
                                       shuffle=shuffle)

        # Defaults for a missing test file; without them the return statement
        # would fail on an unbound ``test_images``.
        test_iter = None
        test_images = {}

        if test_h5 is not None:
            test_data = {input_name: test_h5[input_name]
                         for input_name in self._input_names_}

            if 'images' in test_h5:
                test_images = test_h5['images']

            test_label = {position: test_h5[output_name]
                          for position, output_name in enumerate(self._output_names_)}

            test_iter = mx.io.NDArrayIter(data=test_data,
                                          label=test_label,
                                          batch_size=batch_size)

        return train_iter, test_iter, data_mean, data_std, train_images, test_images

    def load_preprocessed_data(self, batch_size, preproc_lib, shuffle=False):
        """Run every sample through a preprocessing module, then build iterators.

        ``preproc_lib`` is imported as a module. It must expose a class with the
        same name as the module (instantiated, then ``init()`` is called) and a
        class named ``<prefix>_input``, where ``<prefix>`` is ``preproc_lib``
        up to its last underscore.

        Args:
            batch_size: Batch size handed to both ``mx.io.NDArrayIter`` objects.
            preproc_lib: Importable name of the preprocessing module.
            shuffle: Whether the training iterator shuffles (test never does).

        Returns:
            Same 6-tuple layout as ``load_data``. Mean/std are computed on the
            preprocessed training inputs. ``test_iter`` is ``None`` when no
            test file exists; image entries default to an empty dict.
        """
        train_h5, test_h5 = self.load_h5_files()

        wrapper = importlib.import_module(preproc_lib)
        instance = getattr(wrapper, preproc_lib)()
        instance.init()
        lib_head, _sep, _tail = preproc_lib.rpartition('_')
        inp = getattr(wrapper, lib_head + "_input")()

        train_data, train_label = self._preprocess_split(instance, inp, train_h5)

        data_mean = {}
        data_std = {}
        for input_name in self._input_names_:
            data_mean[input_name + '_'] = nd.array(train_data[input_name][:].mean(axis=0))
            data_std[input_name + '_'] = nd.array(train_data[input_name][:].asnumpy().std(axis=0) + 1e-5)

        train_images = train_h5['images'] if 'images' in train_h5 else {}

        train_iter = mx.io.NDArrayIter(data=train_data,
                                       label=train_label,
                                       batch_size=batch_size,
                                       shuffle=shuffle)

        # Mirror load_data: a missing test file yields no iterator and no
        # images instead of crashing while indexing ``None``.
        test_iter = None
        test_images = {}

        if test_h5 is not None:
            test_data, test_label = self._preprocess_split(instance, inp, test_h5)

            if 'images' in test_h5:
                test_images = test_h5['images']

            test_iter = mx.io.NDArrayIter(data=test_data,
                                          label=test_label,
                                          batch_size=batch_size)

        return train_iter, test_iter, data_mean, data_std, train_images, test_images

    @staticmethod
    def _buffer_shape(length, sample):
        """Return the buffer shape for ``length`` samples shaped like ``sample``.

        Array samples keep their own shape behind the leading sample axis;
        anything else (plain scalars) is stored as one column per sample.
        """
        if isinstance(sample, np.ndarray):
            return (length,) + sample.shape
        return (length, 1)

    def _preprocess_split(self, instance, inp, data_h5):
        """Preprocess every sample of one HDF5 file into zero-initialised buffers.

        Returns:
            ``(data, label)`` dicts keyed by input / output name. Each value is
            an MXNet array whose first axis has one entry per sample.
        """
        length = len(data_h5[self._input_names_[0]])

        # Sample 0 is preprocessed once up front purely to learn the output
        # shapes needed to allocate the buffers.
        probe = self.preprocess_data(instance, inp, 0, data_h5)

        data = {}
        for input_name in self._input_names_:
            shape = self._buffer_shape(length, getattr(probe, input_name + "_out"))
            data[input_name] = mx.nd.zeros(shape)

        label = {}
        for output_name in self._output_names_:
            shape = self._buffer_shape(length, getattr(probe, output_name + "_out"))
            label[output_name] = mx.nd.zeros(shape)

        for i in range(length):
            output = self.preprocess_data(instance, inp, i, data_h5)
            for input_name in self._input_names_:
                data[input_name][i] = getattr(output, input_name + "_out")
            for output_name in self._output_names_:
                # Take the label of sample i, not the one of the shape probe.
                label[output_name][i] = getattr(output, output_name + "_out")

        return data, label

    def preprocess_data(self, instance_wrapper, input_wrapper, index, data_h5):
        """Feed sample ``index`` of ``data_h5`` through the preprocessing instance.

        Each input and output dataset's entry at ``index`` is written onto the
        attribute of the same name on ``input_wrapper``, converted to match the
        attribute's current value: arrays become Fortran-ordered arrays with the
        attribute's dtype, everything else is cast with the attribute's type.

        Args:
            instance_wrapper: Object providing ``execute(input_wrapper)``.
            input_wrapper: Object with one attribute per input/output name.
            index: Position of the sample to read from every dataset.
            data_h5: Mapping from dataset name to an indexable dataset.

        Returns:
            Whatever ``instance_wrapper.execute(input_wrapper)`` returns.
        """
        for name in list(self._input_names_) + list(self._output_names_):
            data = data_h5[name][index]
            attr = getattr(input_wrapper, name)
            if isinstance(data, np.ndarray):
                data = np.asfortranarray(data).astype(attr.dtype)
            else:
                data = type(attr)(data)
            setattr(input_wrapper, name, data)
        return instance_wrapper.execute(input_wrapper)

    def _require_datasets(self, h5_file, path):
        """Exit with status 1 if ``h5_file`` lacks a required input/output dataset."""
        for name in list(self._input_names_) + list(self._output_names_):
            if name not in h5_file:
                logging.error("The HDF5 file '" + os.path.abspath(path) + "' has to contain the dataset "
                              + "'" + name + "'")
                sys.exit(1)

    def load_h5_files(self):
        """Open ``train.h5`` and, if present, ``test.h5`` from the data directory.

        The files are opened read-only and returned open, because callers keep
        using the datasets they contain.

        Returns:
            Tuple ``(train_h5, test_h5)``; ``test_h5`` is ``None`` (and a
            warning is logged) when the test file does not exist.

        Raises:
            SystemExit: With status 1 when the training file is missing or any
                opened file lacks one of the required datasets.
        """
        train_path = self._data_dir + "train.h5"
        test_path = self._data_dir + "test.h5"

        if not os.path.isfile(train_path):
            logging.error("Data loading failure. File '" + os.path.abspath(train_path) + "' does not exist.")
            sys.exit(1)

        train_h5 = h5py.File(train_path, 'r')
        self._require_datasets(train_h5, train_path)

        test_h5 = None
        if os.path.isfile(test_path):
            test_h5 = h5py.File(test_path, 'r')
            self._require_datasets(test_h5, test_path)
        else:
            logging.warning("Couldn't load test set. File '" + os.path.abspath(test_path) + "' does not exist.")

        return train_h5, test_h5