diff --git a/evaluation/sample.py b/evaluation/sample.py
index 82b65ff741a6c4b3e41c2736b34de4da53967a3a..46756227ed6c9979615d576102dfa5d75fedd478 100644
--- a/evaluation/sample.py
+++ b/evaluation/sample.py
@@ -3,7 +3,7 @@ import torch
from torchvision import transforms
import re
-def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False):
+def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False,n_times=1):
'''
Samples a tensor of 'batch_size' images from a trained diffusion model with 'checkpoint'. The generated
images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}. Where e is the epoch
@@ -25,6 +25,7 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
for checkpoint_i in os.listdir(f):
if checkpoint_i.endswith(".pth"):
ddpm_sampler(model, checkpoint_i, experiment_path, device, sample_all=False)
+ return 0
# load model
try:
checkpoint_path = f'{experiment_path}trained_ddpm/{checkpoint}'
@@ -61,26 +62,30 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
# transform
back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()])
- # generate batch_size images
- if intermediate:
- generated = model.sample_intermediates_latents()
- name = 'sample_intermediate'
- else:
- generated = model.sample(batch_size=batch_size)
- name = 'sample'
+
+ n = n_times
+ for k in range(n):
+ # generate batch_size images
+ if intermediate:
+ generated = model.sample_intermediates_latents()
+ name = 'sample_intermediate'
+ else:
+ generated = model.sample(batch_size=batch_size)
+ name = 'sample'
- #store the raw generated images within the tensor
- torch.save(generated,os.path.join(sample_dir, f"image_tensor{j}"))
- #normalize to (-1,1)
- a = generated.min()
- b = generated.max()
- A,B=-1,1
- generated = (generated-a)/(b-a)*(B-A)+A
- # save generated images
- for i in range(generated.size(0)):
- image = back2pil(generated[i])
- image_path = os.path.join(sample_dir, f'{name}_{j}_{i}.png')
- try:
- image.save(image_path)
- except Exception as e:
- print("Error saving image. Exception:", e)
+ #store the raw generated images within the tensor
+        torch.save(generated,os.path.join(sample_dir, f"image_tensor{j}_{k}"))
+ #normalize to (-1,1)
+ a = generated.min()
+ b = generated.max()
+ A,B=-1,1
+ generated = (generated-a)/(b-a)*(B-A)+A
+ # save generated images
+ for i in range(generated.size(0)):
+ index = i + k*generated.size(0)
+            image = back2pil(generated[i])
+            image_path = os.path.join(sample_dir, f'{name}_{j}_{index}.png')
+ try:
+ image.save(image_path)
+ except Exception as e:
+ print("Error saving image. Exception:", e)