diff --git a/main.py b/main.py
index e24ad8eae82762438cd2718ea83c55176e5892fb..24c427753dfa7689ff4d69b0360e3a62fc25d91b 100644
--- a/main.py
+++ b/main.py
@@ -115,16 +115,6 @@ def evaluate_func(f):
     with open(f+"/dataset_setting.json","r") as fp:
         dataset_setting = json.load(fp)
 
-    # init dataloaders
-    batchsize = meta_setting["batchsize"]
-    test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
-    #test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=len(test_dataset), shuffle=False)
-    test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize, shuffle=False)
-    # init UNet
-    net = globals()[meta_setting["modelname"]](**model_setting).to(device)
-    net = net.to(device)
-    # init Diffusion Model
-    framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)
 
     print(f"META SETTINGS:\n\n {meta_setting}\n\n")
     print(f"DATASET SETTINGS:\n\n {dataset_setting}\n\n")
@@ -133,7 +123,7 @@ def evaluate_func(f):
     print(f"EVALUATION SETTINGS:\n\n {evaluation_setting}\n\n")
 
     print("\n\nSTART EVALUATION\n\n")
-    globals()[meta_setting["evaluation_function"]](model=framework, device=device, dataloader = test_dataloader,safepath = f,**evaluation_setting,)
+    globals()[meta_setting["evaluation_function"]](**evaluation_setting)
 
     print("\n\nFINISHED EVALUATION\n\n")
 