CNNDataLoader_Alexnet.py 8.48 KB
Newer Older
Nicola Gatto's avatar
Nicola Gatto committed
1
2
3
4
5
import os
import h5py
import mxnet as mx
import logging
import sys
6
7
import numpy as np
import importlib
8
from mxnet import nd
Nicola Gatto's avatar
Nicola Gatto committed
9

10
class CNNDataLoader_Alexnet:
    """Load the Alexnet train/test HDF5 files into MXNet ``NDArrayIter`` objects.

    The files are read from ``<_data_dir>/train.h5`` and ``<_data_dir>/test.h5``.
    Each file must contain one dataset per name in ``_input_names_`` and
    ``_output_names_``. An optional ``images`` dataset is passed through as-is.
    """

    # Names of the HDF5 datasets fed to the network as inputs.
    _input_names_ = ['data']
    # Names of the HDF5 datasets used as labels.
    _output_names_ = ['predictions_label']

    def __init__(self):
        # Directory (with trailing slash) containing train.h5 / test.h5.
        self._data_dir = "data/Alexnet/"

    def load_data(self, batch_size, shuffle=False):
        """Build iterators directly over the raw HDF5 datasets.

        Args:
            batch_size: batch size for both iterators.
            shuffle: whether the training iterator shuffles (the test iterator never does).

        Returns:
            ``(train_iter, test_iter, data_mean, data_std, train_images, test_images)``.
            ``test_iter`` is None when no test file exists. ``data_mean`` and
            ``data_std`` map ``"<input_name>_"`` to the axis-0 mean / std
            (+1e-5) of the training inputs. The image entries are ``{}`` when
            the file has no ``images`` dataset.
        """
        train_h5, test_h5 = self.load_h5_files()

        train_data = {}
        data_mean = {}
        data_std = {}
        train_images = {}

        for input_name in self._input_names_:
            train_data[input_name] = train_h5[input_name]
            # Read the dataset into memory once and reuse it for both statistics.
            values = train_h5[input_name][:]
            data_mean[input_name + '_'] = nd.array(values.mean(axis=0))
            # Small epsilon added to the std, presumably so later division by it is safe.
            data_std[input_name + '_'] = nd.array(values.std(axis=0) + 1e-5)

        if 'images' in train_h5:
            train_images = train_h5['images']

        # Labels are keyed by position (0, 1, ...), one entry per output name.
        train_label = {index: train_h5[output_name]
                       for index, output_name in enumerate(self._output_names_)}

        train_iter = mx.io.NDArrayIter(data=train_data,
                                       label=train_label,
                                       batch_size=batch_size,
                                       shuffle=shuffle)

        test_iter = None
        # Bound up front so the return is valid even when there is no test file.
        test_images = {}

        if test_h5 is not None:
            test_data = {input_name: test_h5[input_name] for input_name in self._input_names_}

            if 'images' in test_h5:
                test_images = test_h5['images']

            test_label = {index: test_h5[output_name]
                          for index, output_name in enumerate(self._output_names_)}

            test_iter = mx.io.NDArrayIter(data=test_data,
                                          label=test_label,
                                          batch_size=batch_size)

        return train_iter, test_iter, data_mean, data_std, train_images, test_images

    def load_preprocessed_data(self, batch_size, preproc_lib, shuffle=False):
        """Run every sample through a preprocessing module, then build iterators.

        ``preproc_lib`` is imported as a module. That module must define a class
        of the same name (instantiated and ``init()``-ed) and a class named
        ``<prefix>_input``, where ``<prefix>`` is ``preproc_lib`` up to its last
        underscore.

        Returns the same 6-tuple as :meth:`load_data`. ``test_iter`` is None
        when no test file exists.
        """
        train_h5, test_h5 = self.load_h5_files()

        wrapper = importlib.import_module(preproc_lib)
        instance = getattr(wrapper, preproc_lib)()
        instance.init()
        lib_head, _sep, _tail = preproc_lib.rpartition('_')
        inp = getattr(wrapper, lib_head + "_input")()

        train_data, train_label = self._preprocess_dataset(instance, inp, train_h5)

        data_mean = {}
        data_std = {}
        for input_name in self._input_names_:
            data_mean[input_name + '_'] = nd.array(train_data[input_name][:].mean(axis=0))
            data_std[input_name + '_'] = nd.array(train_data[input_name][:].asnumpy().std(axis=0) + 1e-5)

        train_images = {}
        if 'images' in train_h5:
            train_images = train_h5['images']

        train_iter = mx.io.NDArrayIter(data=train_data,
                                       label=train_label,
                                       batch_size=batch_size,
                                       shuffle=shuffle)

        test_iter = None
        test_images = {}
        # load_h5_files returns test_h5=None when test.h5 is missing; mirror load_data.
        if test_h5 is not None:
            test_data, test_label = self._preprocess_dataset(instance, inp, test_h5)

            if 'images' in test_h5:
                test_images = test_h5['images']

            test_iter = mx.io.NDArrayIter(data=test_data,
                                          label=test_label,
                                          batch_size=batch_size)

        return train_iter, test_iter, data_mean, data_std, train_images, test_images

    def _preprocess_dataset(self, instance_wrapper, input_wrapper, data_h5):
        """Preprocess every sample of ``data_h5`` and stack the results.

        Returns ``(data, label)``: dicts keyed by input / output name, each an
        ``mx.nd`` array whose first axis indexes samples.
        """
        num_samples = len(data_h5[self._input_names_[0]])
        # The first sample's output fixes the per-sample shape of every field.
        shape_output = self.preprocess_data(instance_wrapper, input_wrapper, 0, data_h5)

        data = {name: mx.nd.zeros(self._stacked_shape(shape_output, name, num_samples))
                for name in self._input_names_}
        label = {name: mx.nd.zeros(self._stacked_shape(shape_output, name, num_samples))
                 for name in self._output_names_}

        for i in range(num_samples):
            output = self.preprocess_data(instance_wrapper, input_wrapper, i, data_h5)
            for input_name in self._input_names_:
                data[input_name][i] = getattr(output, input_name + "_out")
            for output_name in self._output_names_:
                # Take the label from this sample's output, not from the shape probe.
                label[output_name][i] = getattr(output, output_name + "_out")

        return data, label

    @staticmethod
    def _stacked_shape(sample_output, name, num_samples):
        """Shape of an array holding ``num_samples`` values of ``sample_output.<name>_out``.

        ndarray fields keep their own shape after the sample axis. Any other
        value (e.g. a scalar) is stored as a single column.
        """
        value = getattr(sample_output, name + "_out")
        if isinstance(value, np.ndarray):
            return (num_samples,) + value.shape
        return (num_samples, 1)

    def preprocess_data(self, instance_wrapper, input_wrapper, index, data_h5):
        """Copy sample ``index`` of every input and output dataset into ``input_wrapper``.

        Each value is converted to match the wrapper's existing field. Then
        ``instance_wrapper.execute(input_wrapper)`` is called and its result
        returned.
        """
        for name in self._input_names_ + self._output_names_:
            self._assign_wrapper_field(input_wrapper, name, data_h5[name][index])
        return instance_wrapper.execute(input_wrapper)

    @staticmethod
    def _assign_wrapper_field(input_wrapper, name, value):
        """Set ``input_wrapper.<name>`` to ``value``, coerced to the field's current type."""
        attr = getattr(input_wrapper, name)
        if isinstance(value, np.ndarray):
            # Column-major layout with the field's dtype — presumably what the
            # wrapped preprocessing library expects; confirm against the wrapper.
            value = np.asfortranarray(value).astype(attr.dtype)
        else:
            value = type(attr)(value)
        setattr(input_wrapper, name, value)

    def load_h5_files(self):
        """Open ``train.h5`` (required) and ``test.h5`` (optional) read-only.

        Returns ``(train_h5, test_h5)``. ``test_h5`` is None when the test file
        is missing. Logs an error and exits the process if ``train.h5`` is
        missing or either file lacks a required dataset.
        """
        train_path = self._data_dir + "train.h5"
        test_path = self._data_dir + "test.h5"

        if not os.path.isfile(train_path):
            logging.error("Data loading failure. File '" + os.path.abspath(train_path) + "' does not exist.")
            sys.exit(1)

        # Files stay open: the iterators read lazily from the returned datasets.
        train_h5 = h5py.File(train_path, 'r')
        self._check_datasets(train_h5, train_path)

        test_h5 = None
        if os.path.isfile(test_path):
            test_h5 = h5py.File(test_path, 'r')
            self._check_datasets(test_h5, test_path)
        else:
            logging.warning("Couldn't load test set. File '" + os.path.abspath(test_path) + "' does not exist.")

        return train_h5, test_h5

    def _check_datasets(self, h5_file, path):
        """Exit with an error unless every input and output dataset exists in ``h5_file``."""
        for name in self._input_names_ + self._output_names_:
            if name not in h5_file:
                logging.error("The HDF5 file '" + os.path.abspath(path) + "' has to contain the dataset "
                              + "'" + name + "'")
                sys.exit(1)