Skip to content
Snippets Groups Projects
Commit a186748b authored by Gonzalo Martin Garcia's avatar Gonzalo Martin Garcia
Browse files

missing dataloader in main

parent c71c8a35
No related branches found
No related tags found
No related merge requests found
......@@ -3,7 +3,7 @@ import torch
from torchvision import transforms
import re
def cdm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False,n_times=1):
def cdm_sampler(model, checkpoint, experiment_path, dataloader, device, intermediate=False, batch_size=15,sample_all=False,n_times=1):
'''
Samples a tensor of 'batch_size' images from a trained diffusion model with 'checkpoint'. The generated
images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}. Where e is the epoch
......
......@@ -75,6 +75,13 @@ def sample_func(f):
with open(f+"/sampling_setting.json","r") as fp:
sampling_setting = json.load(fp)
with open(f+"/dataset_setting.json","r") as fp:
dataset_setting = json.load(fp)
batchsize = sampling_setting["batch_size"]
test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize,shuffle=True)
# init Unet
net = globals()[meta_setting["modelname"]](**model_setting).to(device)
#net = torch.compile(net)
......@@ -88,7 +95,7 @@ def sample_func(f):
print(f"SAMPLING SETTINGS:\n\n {sampling_setting}\n\n")
print("\n\nSTART SAMPLING\n\n")
globals()[meta_setting["sampling_function"]](model=framework,device=device ,**sampling_setting,)
globals()[meta_setting["sampling_function"]](model=framework,device=device, dataloader = test_dataloader, **sampling_setting,)
print("\n\nFINISHED SAMPLING\n\n")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment