# -
# Authored by Jammer, Tim
# Code owners
# Assign users and groups as approvers for specific file changes. Learn more.
# visualization.py 3.00 KiB
import pandas as pd
import matplotlib.pyplot as plt
import os
import re
import argparse
# Parameter whose value is used as the subplot caption and to order the
# subplots; the first capture group is taken from each run's
# input_parameter.params file. Change the pattern to caption by another
# parameter. Raw string: "\d" in a plain literal is an invalid escape
# sequence (DeprecationWarning, SyntaxWarning since Python 3.12).
caption_regex = re.compile(r"--learning_rate (\d+\.\d+)")
# Files a result directory must contain to be included in the plot.
_REQUIRED_FILES = ("train_hist.csv", "test_accuracy.txt",
                   "input_parameter.params")


def _parse_args():
    """Parse and validate the command-line arguments."""
    parser = argparse.ArgumentParser(
        description='Visualizes the results of a parameterscan')
    # Not required: with required=True the "results" default could never
    # take effect. Passing --input_dir explicitly works exactly as before.
    parser.add_argument("--input_dir", type=str, default="results",
                        required=False, action='store',
                        help="Input directory")
    parser.add_argument("--num_rows", type=int, default=4,
                        required=False, action='store',
                        help="Plot Layout: Number of rows")
    parser.add_argument("--max_loss", type=float, default=0,
                        required=False, action='store',
                        help="maximum loss to display (yscale of plot)")
    parser.add_argument("--output", type=str, default="plot.pdf",
                        required=False, action='store',
                        help="output filename")
    args = parser.parse_args()
    # Fail with a usage message instead of a ZeroDivisionError later on.
    if args.num_rows < 1:
        parser.error("--num_rows must be at least 1")
    # os.walk silently yields nothing for a missing directory.
    if not os.path.isdir(args.input_dir):
        parser.error("input directory %r does not exist" % args.input_dir)
    return args


def _find_cases(input_dir):
    """Return every directory below input_dir that holds all _REQUIRED_FILES.

    The search is recursive (os.walk). Directories missing any required
    file are reported on stdout and skipped.
    """
    cases = []
    for path, directories, _files in os.walk(input_dir):
        for dirname in directories:
            candidate = os.path.join(path, dirname)
            if all(os.path.isfile(os.path.join(candidate, name))
                   for name in _REQUIRED_FILES):
                cases.append(candidate)
            else:
                print("%s has not the expected files" % dirname)
    return cases


def _caption_for(case):
    """Return the subplot caption for one result directory.

    The caption is the first match of caption_regex in the directory's
    input_parameter.params file. If nothing matches, the directory's base
    name is used instead of aborting the whole plot.
    """
    with open(os.path.join(case, "input_parameter.params"), 'r') as handle:
        # Search the whole file: when the parameter is on the first line
        # this yields the same first match as a first-line-only lookup,
        # and parameters on later lines are found as well.
        matches = caption_regex.findall(handle.read())
    if matches:
        return matches[0]
    print("%s: caption parameter not found, using directory name" % case)
    return os.path.basename(case)


def _caption_sort_key(title):
    """Order numeric captions by value (so 2 < 10); others follow by name."""
    try:
        return (0, float(title), str(title))
    except (TypeError, ValueError):
        return (1, 0.0, str(title))


def _plot_results(results, num_rows, max_loss, output):
    """Draw one loss subplot per run, save the figure, then display it.

    results  -- list of (caption, train_hist) pairs in plotting order;
                train_hist must have 'epoch', 'loss' and 'val_loss' columns.
    num_rows -- number of subplot rows; columns are added as needed.
    max_loss -- upper y-limit for all subplots; 0 keeps autoscaling.
    output   -- filename the figure is saved to.
    """
    num_cols = (len(results) + num_rows - 1) // num_rows  # ceiling division
    # squeeze=False always returns a 2-D array of axes, so axes[row, col]
    # also works for 1x1, 1xN and Nx1 layouts.
    fig, axes = plt.subplots(num_rows, num_cols, sharex=True, sharey=True,
                             squeeze=False)
    if max_loss != 0:
        # sharey=True propagates this limit to every subplot.
        axes[0, 0].set_ylim(0, max_loss)
    for i, (title, train_hist) in enumerate(results):
        this_plot = axes[i // num_cols, i % num_cols]
        this_plot.set_title(title)
        this_plot.plot(train_hist['epoch'], train_hist['loss'], 'bo-',
                       label='Train loss')
        this_plot.plot(train_hist['epoch'], train_hist['val_loss'], 'rd--',
                       label='Val loss')
        this_plot.set_xlabel('Epoch')
        this_plot.set_ylabel('Loss')
        this_plot.legend()
    # Save before show(): show() may close the figure, and saving afterwards
    # can write an empty canvas.
    fig.savefig(output)
    plt.show()


def main():
    """Plot train/validation loss curves for every run of a parameter scan.

    Every directory below --input_dir that contains train_hist.csv,
    test_accuracy.txt and input_parameter.params becomes one subplot,
    captioned with the value matched by caption_regex and ordered by it.
    The figure is written to --output and then shown.
    """
    args = _parse_args()
    cases = _find_cases(args.input_dir)
    if not cases:
        print("Nothing to visualize")
        return
    # A list rather than a dict keyed by caption: runs that share a caption
    # are all plotted instead of silently overwriting each other.
    results = [(_caption_for(case),
                pd.read_csv(os.path.join(case, "train_hist.csv")))
               for case in cases]
    results.sort(key=lambda item: _caption_sort_key(item[0]))
    _plot_results(results, args.num_rows, args.max_loss, args.output)
# Run the visualization only when executed as a script, not on import.
if __name__ == "__main__":
    main()