CNNDataLoader_Alexnet.py 5.49 KB
Newer Older
Nicola Gatto's avatar
Nicola Gatto committed
1 2 3 4 5
import os
import h5py
import mxnet as mx
import logging
import sys
Sebastian N.'s avatar
Sebastian N. committed
6 7
import numpy as np
import cv2
8
from mxnet import nd
Nicola Gatto's avatar
Nicola Gatto committed
9

10
class CNNDataLoader_Alexnet:
    """Load the Alexnet HDF5 data sets and wrap them in MXNet iterators.

    ``train.h5`` (required) and ``test.h5`` (optional) are looked up inside
    ``self._data_dir``. Each file has to contain every dataset named in
    ``_input_names_`` and ``_output_names_``.
    """

    # Names of the HDF5 datasets that are fed to the network as inputs.
    _input_names_ = ['data']
    # Names of the HDF5 datasets that are used as labels.
    _output_names_ = ['predictions_label']

    def __init__(self):
        # Directory prefix that is prepended to "train.h5" / "test.h5".
        self._data_dir = "data/Alexnet/"

    def load_data(self, batch_size):
        """Build iterators over the training set and, if present, the test set.

        Args:
            batch_size: Batch size passed to every ``mx.io.NDArrayIter``.

        Returns:
            A tuple ``(train_iter, test_iter, data_mean, data_std,
            train_images, test_images)``:

            * ``test_iter`` is ``None`` when no test file was loaded.
            * ``data_mean`` / ``data_std`` map ``input_name + '_'`` to the
              mean / standard deviation (plus ``1e-5``) over axis 0 of the
              training data.
            * ``train_images`` / ``test_images`` are the ``'images'`` dataset
              of the respective file, or an empty dict when the file has no
              such dataset (or, for the test set, was not loaded at all).
        """
        train_h5, test_h5 = self.load_h5_files()

        data_mean = {}
        data_std = {}
        for input_name in self._input_names_:
            # Materialise the dataset once and reuse it for both statistics.
            values = train_h5[input_name][:]
            data_mean[input_name + '_'] = nd.array(values.mean(axis=0))
            # The epsilon keeps a later division by the std away from zero.
            data_std[input_name + '_'] = nd.array(values.std(axis=0) + 1e-5)

        train_iter = self._build_iter(train_h5, batch_size)
        train_images = train_h5['images'] if 'images' in train_h5 else {}

        # Both defaults must exist before the branch: the test file is
        # optional and the return statement below always reads them.
        test_iter = None
        test_images = {}
        if test_h5 is not None:
            test_iter = self._build_iter(test_h5, batch_size)
            if 'images' in test_h5:
                test_images = test_h5['images']

        return train_iter, test_iter, data_mean, data_std, train_images, test_images

    def load_data_img(self, batch_size, img_size):
        """Resize all images and merge train and test data into one iterator.

        Args:
            batch_size: Batch size passed to ``mx.io.NDArrayIter``.
            img_size: Indexable whose first element is the target width and
                whose second element is the target height.

        Returns:
            A tuple ``(train_iter, test_iter, data_mean, data_std)``.
            ``test_iter`` is always ``None`` because the test samples (when a
            test file exists) are appended to the training samples.
            ``data_mean`` / ``data_std`` are computed over axis 0 of the
            combined, resized data.
        """
        train_h5, test_h5 = self.load_h5_files()
        # The test file is optional; fall back to the training split alone.
        sources = [train_h5] if test_h5 is None else [train_h5, test_h5]
        width = img_size[0]
        height = img_size[1]

        comb_data = {}
        data_mean = {}
        data_std = {}
        for input_name in self._input_names_:
            splits = [h5_file[input_name][:] for h5_file in sources]
            channels = splits[0].shape[1]
            total = sum(split.shape[0] for split in splits)

            combined = mx.nd.zeros((total, channels, width, height))
            row = 0
            for split in splits:
                for img in split:
                    # Move axis 0 to the end before handing the image to cv2.
                    resized = cv2.resize(img.transpose(1, 2, 0), (width, height))
                    # NOTE(review): reshape reinterprets the buffer instead of
                    # moving the channel axis back to the front - confirm
                    # whether a transpose(2, 0, 1) was intended here.
                    combined[row] = resized.reshape((channels, width, height))
                    row += 1
            comb_data[input_name] = combined

            data_mean[input_name + '_'] = nd.array(combined.mean(axis=0))
            data_std[input_name + '_'] = nd.array(combined.asnumpy().std(axis=0) + 1e-5)

        comb_label = {}
        for output_name in self._output_names_:
            comb_label[output_name] = np.concatenate(
                [h5_file[output_name][:] for h5_file in sources], axis=0)

        train_iter = mx.io.NDArrayIter(data=comb_data,
                                       label=comb_label,
                                       batch_size=batch_size)

        test_iter = None

        return train_iter, test_iter, data_mean, data_std

    def load_h5_files(self):
        """Open ``train.h5`` and, if it exists, ``test.h5`` read-only.

        Returns:
            A tuple ``(train_h5, test_h5)`` of open ``h5py.File`` objects;
            ``test_h5`` is ``None`` (and a warning is logged) when the test
            file does not exist.

        Raises:
            SystemExit: With exit code 1 when the training file does not
                exist or when an opened file lacks a required dataset.
        """
        train_path = self._data_dir + "train.h5"
        test_path = self._data_dir + "test.h5"

        if not os.path.isfile(train_path):
            logging.error("Data loading failure. File '%s' does not exist.",
                          os.path.abspath(train_path))
            sys.exit(1)

        train_h5 = h5py.File(train_path, 'r')
        self._require_datasets(train_h5, train_path)

        if not os.path.isfile(test_path):
            logging.warning("Couldn't load test set. File '%s' does not exist.",
                            os.path.abspath(test_path))
            return train_h5, None

        test_h5 = h5py.File(test_path, 'r')
        self._require_datasets(test_h5, test_path)
        return train_h5, test_h5

    def _build_iter(self, h5_file, batch_size):
        """Return an ``NDArrayIter`` over the input/label datasets of ``h5_file``."""
        data = {name: h5_file[name] for name in self._input_names_}
        # Labels are keyed by their position in ``_output_names_``.
        label = {index: h5_file[name]
                 for index, name in enumerate(self._output_names_)}
        return mx.io.NDArrayIter(data=data,
                                 label=label,
                                 batch_size=batch_size)

    def _require_datasets(self, h5_file, path):
        """Exit with code 1 if ``h5_file`` lacks a required input/output dataset."""
        for name in list(self._input_names_) + list(self._output_names_):
            if name not in h5_file:
                logging.error("The HDF5 file '%s' has to contain the dataset '%s'",
                              os.path.abspath(path), name)
                sys.exit(1)