# test_dpnet.py -- run a saved DPNet checkpoint on example images and plot
# its predictions against the ground-truth labels.
import csv
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import mxnet as mx
import os

# Data and weight locations. Each can be overridden through an environment
# variable; the defaults are the original hard-coded paths.

# Directory holding the example images (<key>.png) and their label files
# (<key>_labels.csv).
EXAMPLES_PATH = os.environ.get(
    "DPNET_EXAMPLES_PATH",
    "/media/sveta/4991e634-dd81-4cb9-bf46-2fa9c7159263/TORCS_examples/")
# Checkpoint prefix passed to mx.model.load_checkpoint.
MODEL_PATH = os.environ.get(
    "DPNET_MODEL_PATH",
    "../../dpnet_weights/normalized/dpnet_newest")
RAW_PATH = os.environ.get(
    "DPNET_RAW_PATH",
    "/media/sveta/4991e634-dd81-4cb9-bf46-2fa9c7159263/TORCS_raw/")


def main():
    """Run the saved DPNet checkpoint on every example and plot the results.

    For each ``*.png`` file in ``EXAMPLES_PATH`` the image and its
    ``<key>_labels.csv`` ground truth are loaded, the image is pushed through
    the network, and the ground-truth values are scattered against the
    network's outputs in a separate figure titled with the example key.
    """
    from collections import namedtuple

    # Lightweight batch container handed to mod.forward; only ``data`` is set.
    Batch = namedtuple('Batch', ['data'])

    # Load saved checkpoint (epoch 0) and bind it for inference only.
    sym, arg_params, aux_params = mx.model.load_checkpoint(MODEL_PATH, 0)
    mod = mx.mod.Module(symbol=sym,
                        context=mx.cpu(),
                        data_names=['data'],
                        label_names=['predictions_label'])
    # Batch of one 3-channel 210x280 image laid out as (N, C, H, W).
    # get_image does not resize, so example PNGs are expected to be this size.
    mod.bind(for_training=False,
             data_shapes=[('data', (1, 3, 210, 280))],
             label_shapes=mod._label_shapes)
    # allow_missing=True: do not fail if the checkpoint lacks some parameters.
    mod.set_params(arg_params, aux_params, allow_missing=True)

    # Sorted so examples are shown in a reproducible order
    # (os.listdir order is arbitrary).
    keys = sorted(f[:-len(".png")] for f in os.listdir(EXAMPLES_PATH)
                  if f.endswith(".png"))
    for key in keys:
        img = get_image(key)
        labels = get_labels(key)

        # Predict: first network output, first (only) sample of the batch.
        mod.forward(Batch([mx.nd.array(img)]))
        prob = mod.get_outputs()[0].asnumpy()[0].tolist()

        # Plot ground truth against predicted on a fresh figure so points
        # never accumulate across examples (e.g. with non-blocking backends).
        fig, ax = plt.subplots()
        ax.scatter(range(len(labels)), labels, marker='x', label='Ground truth')
        ax.scatter(range(len(prob)), prob, marker='x', label='Predicted')
        ax.set_title(key)
        ax.legend()
        ax.grid()
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        plt.show()
        plt.close(fig)


def get_image(key, examples_path=None):
    """Load ``<key>.png`` as a network input batch holding one image.

    Args:
        key: example name, i.e. the PNG file name without the extension.
        examples_path: directory containing the examples; defaults to
            ``EXAMPLES_PATH``.

    Returns:
        numpy array shaped ``(1, channel, height, width)`` with channels in
        R, G, B order. The image is not resized.

    Raises:
        IOError: if OpenCV cannot read the image file.
    """
    if examples_path is None:
        examples_path = EXAMPLES_PATH
    filename = os.path.join(examples_path, key + ".png")
    img = cv2.imread(filename)  # (height, width, channel), b,g,r order
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError("Could not read image: {}".format(filename))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # change to r,g,b order
    img = img.transpose(2, 0, 1)  # -> (channel, height, width)
    return img[np.newaxis, :]  # -> (example, channel, height, width)


def get_labels(key, examples_path=None):
    """Read the ground-truth label vector for example ``key``.

    Only the first row of ``<key>_labels.csv`` is used. Fields are separated
    by spaces; empty fields caused by repeated or trailing spaces are skipped.

    Args:
        key: example name (the image file name without ``.png``).
        examples_path: directory containing the examples; defaults to
            ``EXAMPLES_PATH``.

    Returns:
        list of floats parsed from the first row.

    Raises:
        ValueError: if the file has no rows or a field is not a number.
    """
    if examples_path is None:
        examples_path = EXAMPLES_PATH
    labels_file = os.path.join(examples_path, key + "_labels.csv")
    # csv.reader needs str lines on Python 3, so open in text mode with
    # newline='' as recommended by the csv module documentation.
    with open(labels_file, newline='') as csvfile:
        csv_reader = csv.reader(csvfile, delimiter=' ', quotechar='|')
        first_row = next(csv_reader, None)
    if first_row is None:
        raise ValueError("No labels found in {}".format(labels_file))
    return [float(field) for field in first_row if field]

# Only run when executed as a script: the file is named test_*.py, so test
# runners such as pytest would otherwise trigger inference and plotting on
# import.
if __name__ == "__main__":
    main()