Aufgrund einer Wartung wird GitLab am 26.10. zwischen 8:00 und 9:00 Uhr kurzzeitig nicht zur Verfügung stehen. / Due to maintenance, GitLab will be temporarily unavailable on 26.10. between 8:00 and 9:00 am.

CNNDataLoader_Alexnet.py 3.44 KB
Newer Older
Nicola Gatto's avatar
Nicola Gatto committed
1
2
3
4
5
import os
import h5py
import mxnet as mx
import logging
import sys
6
from mxnet import nd
Nicola Gatto's avatar
Nicola Gatto committed
7

8
class CNNDataLoader_Alexnet:
    """Load the Alexnet training/test sets from HDF5 files into MXNet iterators.

    The files ``train.h5`` (required) and ``test.h5`` (optional) are looked up
    inside ``self._data_dir``.  Every name in ``_input_names_`` and
    ``_output_names_`` must be present as a key in each file that is opened.
    """

    # Keys in the HDF5 files that are fed to the iterators as data.
    _input_names_ = ['data']
    # Keys in the HDF5 files that are fed to the iterators as labels.
    _output_names_ = ['predictions']

    def __init__(self):
        """Set the directory in which the HDF5 files are searched."""
        self._data_dir = "data/Alexnet/"

    def load_data(self, batch_size):
        """Build the train/test iterators and per-input normalization statistics.

        :param batch_size: batch size forwarded to ``mx.io.NDArrayIter``.
        :return: tuple ``(train_iter, test_iter, data_mean, data_std)``.
            ``test_iter`` is ``None`` when no test file was loaded.
            ``data_mean`` and ``data_std`` are dicts keyed by input name
            holding ``nd.array`` statistics taken along axis 0 of the
            training data (presumably the sample axis -- confirm against the
            dataset layout).
        """
        train_h5, test_h5 = self.load_h5_files()

        data_mean = {}
        data_std = {}
        for input_name in self._input_names_:
            # Slice the dataset once and reuse the result for both
            # statistics instead of slicing it separately for mean and std.
            values = train_h5[input_name][:]
            data_mean[input_name] = nd.array(values.mean(axis=0))
            # Small offset keeps the std strictly positive (presumably so a
            # later division by it cannot hit zero).
            data_std[input_name] = nd.array(values.std(axis=0) + 1e-5)

        train_iter = self._build_iter(train_h5, batch_size)

        test_iter = None
        if test_h5 is not None:
            test_iter = self._build_iter(test_h5, batch_size)

        return train_iter, test_iter, data_mean, data_std

    def load_h5_files(self):
        """Open ``train.h5`` and, if present, ``test.h5`` read-only.

        :return: tuple ``(train_h5, test_h5)`` of open ``h5py.File`` objects;
            ``test_h5`` is ``None`` (and a warning is logged) when the test
            file does not exist.  The files are handed to the caller open and
            are not closed here.
        :raises SystemExit: with code 1 (after logging an error) when the
            training file does not exist or an opened file lacks one of the
            required datasets.
        """
        # os.path.join yields the same path as plain concatenation when
        # _data_dir ends with a separator, and also works when it does not.
        train_path = os.path.join(self._data_dir, "train.h5")
        test_path = os.path.join(self._data_dir, "test.h5")

        if not os.path.isfile(train_path):
            logging.error("Data loading failure. File '%s' does not exist.",
                          os.path.abspath(train_path))
            sys.exit(1)

        train_h5 = h5py.File(train_path, 'r')
        self._require_datasets(train_path, train_h5)

        test_h5 = None
        if os.path.isfile(test_path):
            test_h5 = h5py.File(test_path, 'r')
            self._require_datasets(test_path, test_h5)
        else:
            logging.warning("Couldn't load test set. File '%s' does not exist.",
                            os.path.abspath(test_path))

        return train_h5, test_h5

    def _build_iter(self, h5_file, batch_size):
        """Wrap the input and output datasets of ``h5_file`` in an NDArrayIter."""
        data = {name: h5_file[name] for name in self._input_names_}
        label = {name: h5_file[name] for name in self._output_names_}
        return mx.io.NDArrayIter(data=data,
                                 label=label,
                                 batch_size=batch_size)

    def _require_datasets(self, path, h5_file):
        """Exit with code 1 if ``h5_file`` lacks a required input/output name."""
        # Inputs are checked before outputs, so the first missing name that
        # gets reported follows that order.
        for name in list(self._input_names_) + list(self._output_names_):
            if name not in h5_file:
                logging.error("The HDF5 file '%s' has to contain the dataset '%s'",
                              os.path.abspath(path), name)
                sys.exit(1)