CNNDataLoader.ftl 5.67 KB
Newer Older
Bernhard Rumpe's avatar
BR-sy  
Bernhard Rumpe committed
1
<#-- (c) https://github.com/MontiCore/monticore -->
Nicola Gatto's avatar
Nicola Gatto committed
2 3 4 5 6
import os
import h5py
import mxnet as mx
import logging
import sys
Julian Dierkes's avatar
Julian Dierkes committed
7 8
import numpy as np
import cv2
9
from mxnet import nd
Nicola Gatto's avatar
Nicola Gatto committed
10

11
class ${tc.fileNameWithoutEnding}:
12 13
    _input_names_ = [<#list tc.architectureInputs as inputName>'${inputName?keep_before_last("_")}'<#sep>, </#list>]
    _output_names_ = [${tc.join(tc.architectureOutputs, ",", "'", "label'")}]
Nicola Gatto's avatar
Nicola Gatto committed
14 15 16 17

    def __init__(self):
        self._data_dir = "${tc.dataPath}/"

18
    def load_data(self, batch_size):
Nicola Gatto's avatar
Nicola Gatto committed
19 20
        train_h5, test_h5 = self.load_h5_files()

21 22 23
        train_data = {}
        data_mean = {}
        data_std = {}
24
        train_images = {}
25 26 27

        for input_name in self._input_names_:
            train_data[input_name] = train_h5[input_name]
28 29
            data_mean[input_name + '_'] = nd.array(train_h5[input_name][:].mean(axis=0))
            data_std[input_name + '_'] = nd.array(train_h5[input_name][:].std(axis=0) + 1e-5)
30

31 32 33
            if 'images' in train_h5:
                train_images = train_h5['images']

34
        train_label = {}
35
        index = 0
36
        for output_name in self._output_names_:
37 38
            train_label[index] = train_h5[output_name]
            index += 1
39 40 41

        train_iter = mx.io.NDArrayIter(data=train_data,
                                       label=train_label,
42
                                       batch_size=batch_size)
Nicola Gatto's avatar
Nicola Gatto committed
43 44

        test_iter = None
45

Nicola Gatto's avatar
Nicola Gatto committed
46
        if test_h5 != None:
47
            test_data = {}
48
            test_images = {}
49 50 51
            for input_name in self._input_names_:
                test_data[input_name] = test_h5[input_name]

52 53 54
                if 'images' in test_h5:
                    test_images = test_h5['images']

55
            test_label = {}
56
            index = 0
57
            for output_name in self._output_names_:
58 59
                test_label[index] = test_h5[output_name]
                index += 1
60 61 62

            test_iter = mx.io.NDArrayIter(data=test_data,
                                          label=test_label,
63
                                          batch_size=batch_size)
64

65
        return train_iter, test_iter, data_mean, data_std, train_images, test_images
Nicola Gatto's avatar
Nicola Gatto committed
66

67
    def load_data_img(self, batch_size, img_size):
Julian Dierkes's avatar
Julian Dierkes committed
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
        train_h5, test_h5 = self.load_h5_files()
        width = img_size[0]
        height = img_size[1]

        comb_data = {}
        data_mean = {}
        data_std = {}

        for input_name in self._input_names_:
            train_data = train_h5[input_name][:]
            test_data = test_h5[input_name][:]

            train_shape = train_data.shape
            test_shape = test_data.shape

            comb_data[input_name] = mx.nd.zeros((train_shape[0]+test_shape[0], train_shape[1], width, height))
            for i, img in enumerate(train_data):
                img = img.transpose(1,2,0)
                comb_data[input_name][i] = cv2.resize(img, (width, height)).reshape((train_shape[1],width,height))
            for i, img in enumerate(test_data):
                img = img.transpose(1, 2, 0)
                comb_data[input_name][i+train_shape[0]] = cv2.resize(img, (width, height)).reshape((train_shape[1], width, height))

            data_mean[input_name + '_'] = nd.array(comb_data[input_name][:].mean(axis=0))
            data_std[input_name + '_'] = nd.array(comb_data[input_name][:].asnumpy().std(axis=0) + 1e-5)

        comb_label = {}
        for output_name in self._output_names_:
            train_labels = train_h5[output_name][:]
            test_labels = test_h5[output_name][:]
            comb_label[output_name] = np.append(train_labels, test_labels, axis=0)


        train_iter = mx.io.NDArrayIter(data=comb_data,
                                       label=comb_label,
                                       batch_size=batch_size)

        test_iter = None

        return train_iter, test_iter, data_mean, data_std

Nicola Gatto's avatar
Nicola Gatto committed
109 110 111 112 113
    def load_h5_files(self):
        train_h5 = None
        test_h5 = None
        train_path = self._data_dir + "train.h5"
        test_path = self._data_dir + "test.h5"
114

Nicola Gatto's avatar
Nicola Gatto committed
115 116
        if os.path.isfile(train_path):
            train_h5 = h5py.File(train_path, 'r')
117 118 119 120 121 122 123 124 125 126 127 128 129

            for input_name in self._input_names_:
                if not input_name in train_h5:
                    logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the dataset "
                                  + "'" + input_name + "'")
                    sys.exit(1)

            for output_name in self._output_names_:
                if not output_name in train_h5:
                    logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the dataset "
                                  + "'" + output_name + "'")
                    sys.exit(1)

Nicola Gatto's avatar
Nicola Gatto committed
130 131
            if os.path.isfile(test_path):
                test_h5 = h5py.File(test_path, 'r')
132 133 134 135 136 137 138 139 140 141 142 143

                for input_name in self._input_names_:
                    if not input_name in test_h5:
                        logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the dataset "
                                      + "'" + input_name + "'")
                        sys.exit(1)

                for output_name in self._output_names_:
                    if not output_name in test_h5:
                        logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the dataset "
                                      + "'" + output_name + "'")
                        sys.exit(1)
Nicola Gatto's avatar
Nicola Gatto committed
144 145
            else:
                logging.warning("Couldn't load test set. File '" + os.path.abspath(test_path) + "' does not exist.")
146

Nicola Gatto's avatar
Nicola Gatto committed
147 148 149
            return train_h5, test_h5
        else:
            logging.error("Data loading failure. File '" + os.path.abspath(train_path) + "' does not exist.")
Bernhard Rumpe's avatar
BR-sy  
Bernhard Rumpe committed
150
            sys.exit(1)