diff --git a/evaluation/sample.py b/evaluation/sample.py index 9805ff2b622cb72788a82b739f55b62de195bdfe..58ee3be53ae6858803f0261026ec251c94511b71 100644 --- a/evaluation/sample.py +++ b/evaluation/sample.py @@ -3,7 +3,7 @@ import torch from torchvision import transforms import re -def cdm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False,n_times=1): +def cdm_sampler(model, checkpoint, experiment_path, dataloader, device, intermediate=False, batch_size=15,sample_all=False,n_times=1): ''' Samples a tensor of 'batch_size' images from a trained diffusion model with 'checkpoint'. The generated images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}. Where e is the epoch diff --git a/main.py b/main.py index 701fa00701092b6fa65c4752d31525a06eab9382..467552345d0823737d2693a0349b5153d536137b 100644 --- a/main.py +++ b/main.py @@ -75,6 +75,13 @@ def sample_func(f): with open(f+"/sampling_setting.json","r") as fp: sampling_setting = json.load(fp) + with open(f+"/dataset_setting.json","r") as fp: + dataset_setting = json.load(fp) + + batchsize = sampling_setting["batch_size"] + test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting) + test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize,shuffle=True) + # init Unet net = globals()[meta_setting["modelname"]](**model_setting).to(device) #net = torch.compile(net) @@ -88,7 +95,7 @@ def sample_func(f): print(f"SAMPLING SETTINGS:\n\n {sampling_setting}\n\n") print("\n\nSTART SAMPLING\n\n") - globals()[meta_setting["sampling_function"]](model=framework,device=device ,**sampling_setting,) + globals()[meta_setting["sampling_function"]](model=framework,device=device, dataloader = test_dataloader, **sampling_setting,) print("\n\nFINISHED SAMPLING\n\n")