Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
visualization.py 3.00 KiB
import pandas as pd
import matplotlib.pyplot as plt
import os
import re
import argparse

# Regex selecting the run parameter that is used as the subplot caption; the
# captured value also determines the order of the subplots. Accepts plain
# decimals ("0.001"), integers ("1") and scientific notation ("2.5e-4") so
# that values like 2.5e-4 are not truncated to "2.5".
caption_regex = re.compile(r"--learning_rate (\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")
# Legacy alias of caption_regex; kept in case other code refers to it.
p = caption_regex


# Files that must all exist in a directory for it to be treated as one result case.
_REQUIRED_FILES = ("train_hist.csv", "test_accuracy.txt", "input_parameter.params")


def _parse_args():
    """Parse and validate the command-line options."""
    parser = argparse.ArgumentParser(
        description='Visualizes the results of a parameterscan')
    # Not required: with required=True the "results" default could never apply.
    parser.add_argument("--input_dir", type=str, default="results",
                        required=False, action='store',
                        help="Input directory (default: %(default)s)")
    parser.add_argument("--num_rows", type=int, default=4,
                        required=False, action='store',
                        help="Plot Layout: Number of rows")
    parser.add_argument("--max_loss", type=float, default=0,
                        required=False, action='store',
                        help="maximum loss to display (yscale of plot)")
    parser.add_argument("--output", type=str, default="plot.pdf",
                        required=False, action='store',
                        help="output filename")
    args = parser.parse_args()
    # Guard against a ZeroDivisionError when computing the grid layout.
    if args.num_rows < 1:
        parser.error("--num_rows must be at least 1")
    return args


def _find_cases(input_dir):
    """Return the sorted paths of all directories below input_dir holding every required file.

    Directories missing any of the files are reported on stdout and skipped.
    """
    cases = []
    for path, directories, _files in os.walk(input_dir):
        for directory in directories:
            case = os.path.join(path, directory)
            if all(os.path.isfile(os.path.join(case, name)) for name in _REQUIRED_FILES):
                cases.append(case)
            else:
                print("%s has not the expected files" % directory)
    # os.walk gives no ordering guarantee; sort so the output is reproducible.
    return sorted(cases)


def _load_case(case):
    """Return ``(title, train_hist)`` for one result directory.

    ``title`` is the value captured by ``caption_regex`` from the first line of
    ``input_parameter.params``, or the directory name when it does not match.
    ``train_hist`` is the DataFrame read from ``train_hist.csv``.
    """
    # Only the first line of the parameter file is searched.
    with open(os.path.join(case, "input_parameter.params"), 'r') as file:
        param_string = file.readline()
    matches = caption_regex.findall(param_string)
    # Fall back to the directory name instead of aborting with IndexError.
    title = matches[0] if matches else os.path.basename(case)
    return title, pd.read_csv(os.path.join(case, "train_hist.csv"))


def _title_sort_key(title):
    """Order numeric titles by value, followed by non-numeric titles alphabetically."""
    try:
        return (0, float(title), title)
    except ValueError:
        return (1, 0.0, title)


def main():
    """Plot train/validation loss of every run below --input_dir in a grid and save it.

    One subplot is drawn per result directory, captioned with the parameter
    captured by ``caption_regex`` and ordered by its numeric value. The figure
    is written to --output and then shown.
    """
    args = _parse_args()

    cases = _find_cases(args.input_dir)
    if not cases:
        print("Nothing to visualize")
        return

    # A list rather than a dict keyed by title, so runs sharing a title are
    # all plotted instead of silently overwriting each other.
    results = sorted((_load_case(case) for case in cases),
                     key=lambda item: _title_sort_key(item[0]))

    num_cols = -(-len(results) // args.num_rows)  # ceiling division
    # squeeze=False always returns a 2-D array of Axes, so axes[row, col]
    # works for every grid shape, including a single 1x1 subplot.
    fig, axes = plt.subplots(args.num_rows, num_cols, sharex=True, sharey=True,
                             squeeze=False)

    for index, (title, train_hist) in enumerate(results):
        ax = axes[index // num_cols, index % num_cols]
        ax.set_title(title)
        ax.plot(train_hist['epoch'], train_hist['loss'], 'bo-', label='Train loss')
        ax.plot(train_hist['epoch'], train_hist['val_loss'], 'rd--', label='Val loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()

    if args.max_loss != 0:
        # The y axis is shared (sharey=True), so limiting one subplot limits all.
        axes[0, 0].set_ylim(0, args.max_loss)

    # Save before showing: once an interactive window is closed the figure may
    # be gone, and a later savefig would write an empty image.
    fig.savefig(args.output)
    plt.show()


if __name__ == "__main__":
    # Script entry point; importing this module does not run the plotting.
    main()