"""Visual check of DPNet predictions against ground-truth labels.

For every ``<key>.png`` in EXAMPLES_PATH, run the saved MXNet checkpoint on
the image and scatter-plot its outputs next to the values stored in
``<key>_labels.csv``.
"""
import csv
import os
from collections import namedtuple

import cv2
import matplotlib.pyplot as plt
import mxnet as mx
import numpy as np
from matplotlib.ticker import FormatStrFormatter

EXAMPLES_PATH = "/media/sveta/4991e634-dd81-4cb9-bf46-2fa9c7159263/TORCS_examples/"
MODEL_PATH = "../../dpnet_weights/normalized/dpnet_newest"
RAW_PATH = "/media/sveta/4991e634-dd81-4cb9-bf46-2fa9c7159263/TORCS_raw/"

# Lightweight batch container passed to ``mod.forward``; only a ``data``
# field is provided. Defined once here instead of on every loop iteration.
Batch = namedtuple('Batch', ['data'])


def main():
    """Load the checkpoint and plot predictions vs. ground truth per example."""
    # Load epoch 0 of the checkpoint saved under MODEL_PATH.
    sym, arg_params, aux_params = mx.model.load_checkpoint(MODEL_PATH, 0)
    mod = mx.mod.Module(symbol=sym, context=mx.cpu(), data_names=['data'],
                        label_names=['predictions_label'])
    # Inference only: the input is fixed to a single 3x210x280 image, so
    # every example image must already be 210x280 (no resizing is done).
    mod.bind(for_training=False, data_shapes=[('data', (1, 3, 210, 280))],
             label_shapes=mod._label_shapes)
    mod.set_params(arg_params, aux_params, allow_missing=True)

    # os.listdir order is arbitrary; sort so examples appear reproducibly.
    files = sorted(f for f in os.listdir(EXAMPLES_PATH) if f.endswith(".png"))
    for f in files:
        key = f[:-len(".png")]
        img = get_image(key)
        labels = get_labels(key)

        # Predict: first output tensor, first (only) item of the batch.
        mod.forward(Batch([mx.nd.array(img)]))
        predictions = mod.get_outputs()[0].asnumpy()[0].tolist()

        # Plot ground truth against predicted, one figure per example.
        plt.scatter(range(len(labels)), labels, marker='x',
                    label='Ground truth')
        plt.scatter(range(len(predictions)), predictions, marker='x',
                    label='Predicted')
        plt.title(key)  # identify which example is being shown
        plt.legend()
        plt.grid()
        plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        plt.show()


def get_image(key):
    """Load ``<key>.png`` from EXAMPLES_PATH as a (1, channel, height, width) RGB array.

    Raises:
        IOError: if OpenCV cannot read the file (``cv2.imread`` returned None).
    """
    filename = os.path.join(EXAMPLES_PATH, key + ".png")
    img = cv2.imread(filename)  # read image in b,g,r order, (height, width, channel)
    if img is None:
        # Without this check the failure surfaces later as an opaque cvtColor error.
        raise IOError("Could not read image: " + filename)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # change to r,g,b order
    # (height, width, channel) -> (channel, height, width), then add the
    # leading example axis -> (example, channel, height, width).
    img = np.transpose(img, (2, 0, 1))
    return img[np.newaxis, :]


def get_labels(key):
    """Read the ground-truth values for ``key`` from ``<key>_labels.csv``.

    Only the first row of the file is used. It is split on spaces and every
    non-empty field is converted to ``float``.

    Raises:
        ValueError: if the file has no rows or a field is not a number.
    """
    labels_file = os.path.join(EXAMPLES_PATH, key + "_labels.csv")
    # csv needs a text-mode file on Python 3 (the old 'rb' mode raises
    # "iterator should return strings, not bytes"); newline='' per csv docs.
    with open(labels_file, newline='') as csvfile:
        csv_reader = csv.reader(csvfile, delimiter=' ', quotechar='|')
        first_row = next(csv_reader, None)
    if first_row is None:
        raise ValueError("Empty labels file: " + labels_file)
    # Repeated or trailing spaces yield empty fields; skip them.
    return [float(a) for a in first_row if a]


if __name__ == "__main__":
    main()