Select Git revision
GitVersion.yml
-
Marcel Nellesen authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
param_generator.py 2.71 KiB
import argparse
import random
import os
## Search area of the random parameter scan.
# main() samples each hyperparameter between its min_* and max_* bound:
# integers via random.randint (both ends inclusive), the learning rate via
# random.uniform.
min_batch_size, max_batch_size = 16, 1024
min_lr, max_lr = 0.0001, 0.5
min_dense_layers, max_dense_layers = 1, 5
min_conv_layers, max_conv_layers = 1, 3
min_num_filters, max_num_filters = 2, 32
min_num_units, max_num_units = 11, 128

## Settings applied identically to every generated parameter set.
# num_epochs here is what main() passes, overriding write_parameter_file's
# own default.
use_augmentation = True
num_epochs = 10
def write_parameter_file(
    job_number=1,
    num_epochs=20,
    batch_size=128,
    lr=0.01,
    num_units=30,
    dense_layers=1,
    num_filters=16,
    conv_layers=1,
    augment=False,
    path="./parameter/",
):
    """Write one parameter set as a line of command-line flags.

    The file ``<path>/<job_number>.params`` is created (or overwritten) and
    receives a single newline-terminated line of ``--flag value`` pairs.

    Args:
        job_number: File stem of the output file (``3`` -> ``3.params``).
        num_epochs: Value written for ``--num_epochs``.
        batch_size: Value written for ``--batch_size``.
        lr: Value written for ``--learning_rate``.
        num_units: Value written for ``--num_units``.
        dense_layers: Value written for ``--dense_layers``.
        num_filters: Value written for ``--num_filters``.
        conv_layers: Value written for ``--conv_layers``.
        augment: When true, the bare flag ``--augment_data`` is appended.
        path: Output directory (str or path-like); it must already exist.
            A trailing separator is optional.
    """
    flags = [
        "--num_epochs %d" % num_epochs,
        "--batch_size %d" % batch_size,
        # %g keeps six significant digits, so small learning rates are not
        # truncated the way a fixed six-decimal %f would truncate them.
        "--learning_rate %g" % lr,
        "--num_units %d" % num_units,
        "--dense_layers %d" % dense_layers,
        "--conv_layers %d" % conv_layers,
        "--num_filters %d" % num_filters,
    ]
    if augment:
        flags.append("--augment_data")
    parameters = " ".join(flags) + "\n"

    # os.path.join avoids a doubled separator for paths ending in "/" and
    # also accepts pathlib.Path objects.
    out_file = os.path.join(path, str(job_number) + ".params")
    with open(out_file, "w", encoding="utf-8") as file:
        file.write(parameters)
def main(argv=None):
    """Generate random hyperparameter files for a parameter scan.

    Parses the command line, creates the output directory if needed, and
    writes files ``0.params`` ... ``<n>.params`` via write_parameter_file().
    Each file holds one parameter set sampled from the module-level
    ``min_*``/``max_*`` search area.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, as argparse always does.
    """
    parser = argparse.ArgumentParser(
        description='generates parameters for a random parameterscan. The range of parameters to generate can be setup in the script')
    parser.add_argument("--output_dir", type=str, default="./parameter",
                        required=False, action='store',
                        help="The directory where the output should be stored")
    parser.add_argument("-n", type=int, default=5,
                        required=False, action='store',
                        help="Highest parameter-file index; files 0..n "
                             "(n+1 in total) are generated")
    parser.add_argument("--seed", type=int, default=None,
                        required=False, action='store',
                        help="Seed for the random generator to make a scan "
                             "reproducible (default: unseeded)")
    args = parser.parse_args(argv)
    if args.n < 0:
        parser.error("-n must be non-negative")
    if args.seed is not None:
        random.seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)
    # Deliberately writes n + 1 files (indices 0..n) so that a Slurm array
    # boundary declared as either 0-n or 1-n finds a file for every task.
    for i in range(args.n + 1):
        write_parameter_file(
            job_number=i,
            num_epochs=num_epochs,
            batch_size=random.randint(min_batch_size, max_batch_size),
            lr=random.uniform(min_lr, max_lr),
            num_units=random.randint(min_num_units, max_num_units),
            dense_layers=random.randint(min_dense_layers, max_dense_layers),
            num_filters=random.randint(min_num_filters, max_num_filters),
            conv_layers=random.randint(min_conv_layers, max_conv_layers),
            augment=use_augmentation,
            path=args.output_dir,
        )


if __name__ == "__main__":
    main()