Skip to content
Snippets Groups Projects
Select Git revision
  • Topic/1453-userInvitation
  • master default protected
  • gitkeep
  • dev protected
  • Hotfix/2521-fixEmail
  • Issue/2431-fixesInNotificationFooter
  • Issue/2309-docs
  • Hotfix/2354-emailNotificationPID
  • Issue/2263-changeMailingDomain
  • Issue/2158-emailServicedesk
  • Hotfix/2087-efNet6
  • Issue/1910-MigrationtoNET6.0
  • Sprint/2022-01
  • Sprint/2021-23
  • Issue/1746-ApplicationProfileStoringMethod
  • Sprint/2021-08
  • Product/202-userInvitation
  • Sprint/2021-05
  • Hotfix/1370-swaggerDescription
  • Sprint/2021-04
  • v2.6.3
  • v2.6.2
  • v2.6.1
  • v2.6.0
  • v2.5.2
  • v2.5.1
  • v2.5.0
  • v2.4.2
  • v2.4.1
  • v2.4.0
  • v2.3.0
  • v2.2.0
  • v2.1.1
  • v2.1.0
  • v2.0.1
  • v2.0.0
  • v1.3.2
  • v1.3.1
  • v1.3.0
  • v1.2.1
40 results

GitVersion.yml

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    param_generator.py 2.71 KiB
    import argparse
    import random
    import os
    
    ## Search space: lower/upper bound for every randomly drawn hyperparameter.
    ## main() samples the integer ranges with random.randint (both ends
    ## inclusive) and the learning rate with random.uniform.
    min_batch_size, max_batch_size = 16, 1024
    min_lr, max_lr = 0.0001, 0.5
    min_dense_layers, max_dense_layers = 1, 5
    min_conv_layers, max_conv_layers = 1, 3
    min_num_filters, max_num_filters = 2, 32
    min_num_units, max_num_units = 11, 128

    ## Fixed values shared by every generated parameter set.
    num_epochs = 10
    use_augmentation = True
    
    
    def write_parameter_file(
            job_number=1,
            num_epochs=20,
            batch_size=128,
            lr=0.01,
            num_units=30,
            dense_layers=1,
            num_filters=16,
            conv_layers=1,
            augment=False,
            path="./parameter/",
    ):
        """Write one hyperparameter set as a single line of command-line flags.

        The line is stored in ``<path>/<job_number>.params`` and looks like
        ``--num_epochs 20 --batch_size 128 --learning_rate 0.010000 ...``.
        An existing file of the same name is overwritten.

        Args:
            job_number: Base name of the output file.
            num_epochs: Value written after ``--num_epochs``.
            batch_size: Value written after ``--batch_size``.
            lr: Value written after ``--learning_rate`` (``%f``, i.e. six
                decimal places).
            num_units: Value written after ``--num_units``.
            dense_layers: Value written after ``--dense_layers``.
            num_filters: Value written after ``--num_filters``.
            conv_layers: Value written after ``--conv_layers``.
            augment: If true, ``--augment_data`` is appended to the line.
            path: Directory to write into; it is not created here. An empty
                string means the current working directory.

        Raises:
            OSError: If the file cannot be opened for writing, for example
                because ``path`` does not exist.
        """
        augment_str = "--augment_data" if augment else ""

        # The space in front of ``%s`` is emitted even when there is no
        # augmentation flag, which keeps the output byte-identical to files
        # produced by earlier versions of this function.
        template = (
            "--num_epochs %d --batch_size %d --learning_rate %f"
            " --num_units %d --dense_layers %d --conv_layers %d"
            " --num_filters %d %s\n"
        )
        parameters = template % (
            num_epochs,
            batch_size,
            lr,
            num_units,
            dense_layers,
            conv_layers,
            num_filters,
            augment_str,
        )

        # os.path.join avoids a doubled separator when ``path`` already ends
        # in one, and keeps an empty ``path`` from resolving to the
        # filesystem root ("/<job_number>.params").
        target = os.path.join(path, str(job_number) + ".params")
        with open(target, "w") as file:
            file.write(parameters)
    
    
    def main():
        """Parse the command line and write randomly sampled parameter files.

        Reads ``--output_dir`` and ``-n``, creates the output directory if it
        is missing, and writes one ``.params`` file per index from 0 to ``n``
        inclusive. Each file gets independently drawn values taken from the
        module-level ``min_*`` / ``max_*`` bounds.
        """
        cli = argparse.ArgumentParser(
            description='generates parameters for a random parameterscan. The range of parameters to generate can be setup in the script')
        cli.add_argument("--output_dir", action='store', type=str,
                         required=False, default="./parameter",
                         help="The directroy, where the output should be stored")
        cli.add_argument("-n", action='store', type=int,
                         required=False, default=5,
                         help="Number of params to generate")
        options = cli.parse_args()

        target_dir = options.output_dir
        os.makedirs(target_dir, exist_ok=True)

        # One file more than requested (indices 0..n): the spare entry makes
        # it harder to get the slurm array boundary wrong.
        job_count = options.n + 1
        for job_id in range(job_count):
            # Draw a fresh random value for every searched hyperparameter.
            sampled = dict(
                batch_size=random.randint(min_batch_size, max_batch_size),
                lr=random.uniform(min_lr, max_lr),
                num_units=random.randint(min_num_units, max_num_units),
                dense_layers=random.randint(min_dense_layers, max_dense_layers),
                num_filters=random.randint(min_num_filters, max_num_filters),
                conv_layers=random.randint(min_conv_layers, max_conv_layers),
            )
            write_parameter_file(
                job_number=job_id,
                num_epochs=num_epochs,
                augment=use_augmentation,
                path=target_dir,
                **sampled
            )
    
    
    # Generate the parameter files only when run as a script, not on import.
    if __name__ == "__main__":
        main()