Skip to content
Snippets Groups Projects
Commit 1f5e2446 authored by Tobias Seibel's avatar Tobias Seibel
Browse files

added n times sampling possibility

parent cf7fe26e
No related branches found
No related tags found
No related merge requests found
......@@ -3,7 +3,7 @@ import torch
from torchvision import transforms
import re
def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False):
def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False,n_times=1):
'''
Samples a tensor of 'batch_size' images from a trained diffusion model with 'checkpoint'. The generated
images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}. Where e is the epoch
......@@ -25,6 +25,7 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
for checkpoint_i in os.listdir(f):
if checkpoint_i.endswith(".pth"):
ddpm_sampler(model, checkpoint_i, experiment_path, device, sample_all=False)
return 0
# load model
try:
checkpoint_path = f'{experiment_path}trained_ddpm/{checkpoint}'
......@@ -61,6 +62,9 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
# transform
back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()])
n = n_times
for k in range(n):
# generate batch_size images
if intermediate:
generated = model.sample_intermediates_latents()
......@@ -78,7 +82,8 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
generated = (generated-a)/(b-a)*(B-A)+A
# save generated images
for i in range(generated.size(0)):
image = back2pil(generated[i])
index = i + k*generated.size(0)
image = back2pil(generated[i+k*n])
image_path = os.path.join(sample_dir, f'{name}_{j}_{i}.png')
try:
image.save(image_path)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment