CNNDataLoader.ftl 3.66 KB
Newer Older
Nicola Gatto's avatar
Nicola Gatto committed
1
2
3
4
5
import os
import h5py
import mxnet as mx
import logging
import sys
6
from mxnet import nd
Nicola Gatto's avatar
Nicola Gatto committed
7

8
class ${tc.fileNameWithoutEnding}:
9
10
    _input_names_ = [<#list tc.architectureInputs as inputName>'${inputName?keep_before_last("_")}'<#sep>, </#list>]
    _output_names_ = [${tc.join(tc.architectureOutputs, ",", "'", "label'")}]
Nicola Gatto's avatar
Nicola Gatto committed
11
12
13
14
15
16
17

    def __init__(self):
        self._data_dir = "${tc.dataPath}/"

    def load_data(self, batch_size):
        train_h5, test_h5 = self.load_h5_files()

18
19
20
21
22
23
        train_data = {}
        data_mean = {}
        data_std = {}

        for input_name in self._input_names_:
            train_data[input_name] = train_h5[input_name]
24
25
            data_mean[input_name + '_'] = nd.array(train_h5[input_name][:].mean(axis=0))
            data_std[input_name + '_'] = nd.array(train_h5[input_name][:].std(axis=0) + 1e-5)
26
27

        train_label = {}
28
        index = 0
29
        for output_name in self._output_names_:
30
31
            train_label[index] = train_h5[output_name]
            index += 1
32
33
34
35

        train_iter = mx.io.NDArrayIter(data=train_data,
                                       label=train_label,
                                       batch_size=batch_size)
Nicola Gatto's avatar
Nicola Gatto committed
36
37

        test_iter = None
38

Nicola Gatto's avatar
Nicola Gatto committed
39
        if test_h5 != None:
40
41
42
43
44
            test_data = {}
            for input_name in self._input_names_:
                test_data[input_name] = test_h5[input_name]

            test_label = {}
45
            index = 0
46
            for output_name in self._output_names_:
47
48
                test_label[index] = test_h5[output_name]
                index += 1
49
50
51
52
53

            test_iter = mx.io.NDArrayIter(data=test_data,
                                          label=test_label,
                                          batch_size=batch_size)

Nicola Gatto's avatar
Nicola Gatto committed
54
55
56
57
58
59
60
        return train_iter, test_iter, data_mean, data_std

    def load_h5_files(self):
        train_h5 = None
        test_h5 = None
        train_path = self._data_dir + "train.h5"
        test_path = self._data_dir + "test.h5"
61

Nicola Gatto's avatar
Nicola Gatto committed
62
63
        if os.path.isfile(train_path):
            train_h5 = h5py.File(train_path, 'r')
64
65
66
67
68
69
70
71
72
73
74
75
76

            for input_name in self._input_names_:
                if not input_name in train_h5:
                    logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the dataset "
                                  + "'" + input_name + "'")
                    sys.exit(1)

            for output_name in self._output_names_:
                if not output_name in train_h5:
                    logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the dataset "
                                  + "'" + output_name + "'")
                    sys.exit(1)

Nicola Gatto's avatar
Nicola Gatto committed
77
78
            if os.path.isfile(test_path):
                test_h5 = h5py.File(test_path, 'r')
79
80
81
82
83
84
85
86
87
88
89
90

                for input_name in self._input_names_:
                    if not input_name in test_h5:
                        logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the dataset "
                                      + "'" + input_name + "'")
                        sys.exit(1)

                for output_name in self._output_names_:
                    if not output_name in test_h5:
                        logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the dataset "
                                      + "'" + output_name + "'")
                        sys.exit(1)
Nicola Gatto's avatar
Nicola Gatto committed
91
92
            else:
                logging.warning("Couldn't load test set. File '" + os.path.abspath(test_path) + "' does not exist.")
93

Nicola Gatto's avatar
Nicola Gatto committed
94
95
96
97
            return train_h5, test_h5
        else:
            logging.error("Data loading failure. File '" + os.path.abspath(train_path) + "' does not exist.")
            sys.exit(1)