diff --git a/evaluation/sample.py b/evaluation/sample.py index 82b65ff741a6c4b3e41c2736b34de4da53967a3a..46756227ed6c9979615d576102dfa5d75fedd478 100644 --- a/evaluation/sample.py +++ b/evaluation/sample.py @@ -3,7 +3,7 @@ import torch from torchvision import transforms import re -def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False): +def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False,n_times=1): ''' Samples a tensor of 'batch_size' images from a trained diffusion model with 'checkpoint'. The generated images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}. Where e is the epoch @@ -25,6 +25,7 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, for checkpoint_i in os.listdir(f): if checkpoint_i.endswith(".pth"): ddpm_sampler(model, checkpoint_i, experiment_path, device, sample_all=False) + return 0 # load model try: checkpoint_path = f'{experiment_path}trained_ddpm/{checkpoint}' @@ -61,26 +62,30 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, # transform back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()]) - # generate batch_size images - if intermediate: - generated = model.sample_intermediates_latents() - name = 'sample_intermediate' - else: - generated = model.sample(batch_size=batch_size) - name = 'sample' + + n = n_times + for k in range(n): + # generate batch_size images + if intermediate: + generated = model.sample_intermediates_latents() + name = 'sample_intermediate' + else: + generated = model.sample(batch_size=batch_size) + name = 'sample' - #store the raw generated images within the tensor - torch.save(generated,os.path.join(sample_dir, f"image_tensor{j}")) - #normalize to (-1,1) - a = generated.min() - b = generated.max() - A,B=-1,1 - generated = (generated-a)/(b-a)*(B-A)+A - # save generated images - for i in range(generated.size(0)): - image = back2pil(generated[i]) - image_path = os.path.join(sample_dir, f'{name}_{j}_{i}.png') - try: - image.save(image_path) - except Exception as e: - print("Error saving image. Exception:", e) + #store the raw generated images within the tensor + torch.save(generated,os.path.join(sample_dir, f"image_tensor{j}")) + #normalize to (-1,1) + a = generated.min() + b = generated.max() + A,B=-1,1 + generated = (generated-a)/(b-a)*(B-A)+A + # save generated images + for i in range(generated.size(0)): + index = i + k*generated.size(0) + image = back2pil(generated[i+k*n]) + image_path = os.path.join(sample_dir, f'{name}_{j}_{i}.png') + try: + image.save(image_path) + except Exception as e: + print("Error saving image. Exception:", e)