# log_parser.py
# Parse a training log for train/validation MSE and throughput,
# then plot the MSE curves.
import matplotlib.pyplot as plt
import numpy as np
import re

# Plot one point per epoch (the last batch MSE logged in it) when True,
# or one point per logged batch when False.
EPOCH_WISE = False
# NOTE(review): NETWORK is not referenced anywhere else in this script --
# confirm whether it was meant for the plot title or can be removed.
NETWORK = "DPNet"
# Path of the training log to parse (relative to the working directory).
LOGFILE = "../../dpnet_weights/normalized/train.log"
# In per-batch mode, x-axis spacing between consecutive logged batch MSE
# values; presumably the logging interval in batches -- confirm.
STEP = 50

# Read the whole log into memory; the `with` block closes the file handle
# immediately instead of leaving it open until garbage collection.
with open(LOGFILE) as log_file:
    rows = log_file.read().strip()

train_mse = []
validation_mse = []
train_iterations = []
validation_iterations = []
speeds = []

# Grab the set of training epochs, sorted numerically (not as strings).
epochs = sorted({int(e) for e in re.findall(r'Epoch\[(\d+)\]', rows)})
for epoch in epochs:
    # Every pattern is anchored on this epoch's "Epoch[N]" tag. Values are
    # captured up to the next whitespace so that further tab-separated
    # metrics on the same line do not break float().
    epoch_tag = r'Epoch\[' + str(epoch) + r'\].*'
    mse = [float(v) for v in re.findall(epoch_tag + r'\smse=(\S+)', rows)]
    speeds += [float(v) for v in
               re.findall(epoch_tag + r'\sSpeed: (\S+) samples', rows)]
    validation = re.findall(epoch_tag + r'Validation-mse=(\S+)', rows)

    if EPOCH_WISE:
        # One point per epoch: the last batch MSE logged in that epoch.
        # An epoch with no batch lines contributes no train point, which
        # keeps train_mse and train_iterations the same length.
        if mse:
            train_mse.append(mse[-1])
            train_iterations.append(epoch)
        validation_x = epoch
    else:
        # One point per logged batch, spaced STEP apart and continuing from
        # where the previous epoch's points ended.
        last_iteration = train_iterations[-1] if train_iterations else 0
        train_mse += mse
        train_iterations += range(last_iteration + STEP,
                                  last_iteration + STEP * (len(mse) + 1),
                                  STEP)
        # The validation point sits one STEP after this epoch's last batch point.
        validation_x = last_iteration + STEP * (len(mse) + 1)

    # A truncated log (e.g. training stopped mid-epoch) can lack the
    # validation line: skip the point instead of raising IndexError, and keep
    # validation_mse / validation_iterations aligned.
    if validation:
        validation_mse.append(float(validation[0]))
        validation_iterations.append(validation_x)

# np.mean of an empty list yields nan plus a RuntimeWarning, so report
# the absence of speed entries explicitly instead.
if speeds:
    print("Mean speed is " + str(np.mean(speeds)))
else:
    print("No 'Speed:' entries found in " + LOGFILE)

# Plot the train and validation MSE curves on a shared x-axis:
# epoch numbers when EPOCH_WISE is set, otherwise iterations.
plt.figure()
plt.plot(train_iterations, train_mse, label="train")
plt.plot(validation_iterations, validation_mse, label="validation")
plt.xlabel("Epochs #" if EPOCH_WISE else "Iterations")
plt.ylabel("MSE")
plt.legend(loc="upper right")
plt.grid()
plt.show()