# custom_functions.py
import os

import mxnet as mx
import numpy as np

# Names given to the data and label arrays produced by the iterators in
# load_data_rec (passed as ImageIter's data_name / label_name).
DATA_NAME = 'data'
PREDICTIONS_LABEL = 'predictions_label'

# Per-channel statistics, one value per image channel (3 channels).
# Presumably computed over the training images; channel order is not
# established here — confirm before relying on it. Currently only referenced
# by a commented-out normalisation note in load_data_rec.
TRAIN_MEAN = [99.39394537, 110.60877108, 117.86127587]
TRAIN_STD = [42.04910545, 49.47874084, 62.61726178]


def _make_rec_iter(path_imgrec, batch_size, height, width, label_width):
    """Return an ``mx.image.ImageIter`` that reads batches from one ``.rec`` file.

    The data/label array names are taken from DATA_NAME and PREDICTIONS_LABEL.
    """
    return mx.image.ImageIter(
        path_imgrec=path_imgrec,
        data_shape=(3, height, width),  # (channels, height, width)
        batch_size=batch_size,
        label_width=label_width,
        data_name=DATA_NAME,
        label_name=PREDICTIONS_LABEL,
    )


def load_data_rec(data_dir, batch_size, *, width=280, height=210,
                  label_width=14, mean_image_path="./mean_image.npy"):
    """Build the train/test RecordIO image iterators and load the mean image.

    Args:
        data_dir: Directory containing ``torcs_train.rec`` and
            ``torcs_test.rec``. Joined with ``os.path.join``, so a trailing
            path separator is optional.
        batch_size: Number of samples per batch, for both iterators.
        width: Image width in pixels (default 280).
        height: Image height in pixels (default 210).
        label_width: Number of label values per record (default 14).
        mean_image_path: Path of the ``.npy`` file holding the mean image.
            Defaults to ``./mean_image.npy``, resolved against the current
            working directory (not ``data_dir``), as before.

    Returns:
        ``(train_iter, test_iter, data_mean, data_std)``. The two iterators are
        ``mx.image.ImageIter`` objects, ``data_mean`` is the array loaded from
        ``mean_image_path``, and ``data_std`` is always ``None`` (no std
        normalisation is provided).
    """
    # Loaded before the iterators are constructed, so a missing mean file
    # raises before any .rec file is opened.
    # NOTE(review): presumably shaped (3, height, width) to match data_shape.
    data_mean = np.load(mean_image_path)

    # NOTE(review): ImageIter does not shuffle unless shuffle=True is passed;
    # consider enabling it for the training iterator.
    train_iter = _make_rec_iter(os.path.join(data_dir, "torcs_train.rec"),
                                batch_size, height, width, label_width)
    test_iter = _make_rec_iter(os.path.join(data_dir, "torcs_test.rec"),
                               batch_size, height, width, label_width)

    # No std normalisation is applied. A per-channel alternative was
    # previously sketched here, broadcasting TRAIN_MEAN / TRAIN_STD to
    # (3, height, width), e.g.
    #   np.asarray([[[a] * width] * height for a in TRAIN_MEAN])
    data_std = None

    return train_iter, test_iter, data_mean, data_std