From 1f5e24469d7d0a97646619f108b3e4add158d5ff Mon Sep 17 00:00:00 2001 From: Tobias Seibel <55710042+SeibelT@users.noreply.github.com> Date: Fri, 30 Jun 2023 08:37:23 +0200 Subject: [PATCH] added n times sampling possibility --- evaluation/sample.py | 51 ++++++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/evaluation/sample.py b/evaluation/sample.py index 82b65ff..4675622 100644 --- a/evaluation/sample.py +++ b/evaluation/sample.py @@ -3,7 +3,7 @@ import torch from torchvision import transforms import re -def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False): +def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False,n_times=1): ''' Samples a tensor of 'batch_size' images from a trained diffusion model with 'checkpoint'. The generated images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}. Where e is the epoch @@ -25,6 +25,7 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, for checkpoint_i in os.listdir(f): if checkpoint_i.endswith(".pth"): ddpm_sampler(model, checkpoint_i, experiment_path, device, sample_all=False) + return 0 # load model try: checkpoint_path = f'{experiment_path}trained_ddpm/{checkpoint}' @@ -61,26 +62,30 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, # transform back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()]) - # generate batch_size images - if intermediate: - generated = model.sample_intermediates_latents() - name = 'sample_intermediate' - else: - generated = model.sample(batch_size=batch_size) - name = 'sample' + + n = n_times + for k in range(n): + # generate batch_size images + if intermediate: + generated = model.sample_intermediates_latents() + name = 'sample_intermediate' + else: + generated = model.sample(batch_size=batch_size) + name = 'sample' - #store the raw generated images within the tensor - torch.save(generated,os.path.join(sample_dir, f"image_tensor{j}")) - #normalize to (-1,1) - a = generated.min() - b = generated.max() - A,B=-1,1 - generated = (generated-a)/(b-a)*(B-A)+A - # save generated images - for i in range(generated.size(0)): - image = back2pil(generated[i]) - image_path = os.path.join(sample_dir, f'{name}_{j}_{i}.png') - try: - image.save(image_path) - except Exception as e: - print("Error saving image. Exception:", e) + #store the raw generated images within the tensor + torch.save(generated,os.path.join(sample_dir, f"image_tensor{j}")) + #normalize to (-1,1) + a = generated.min() + b = generated.max() + A,B=-1,1 + generated = (generated-a)/(b-a)*(B-A)+A + # save generated images + for i in range(generated.size(0)): + index = i + k*generated.size(0) + image = back2pil(generated[i+k*n]) + image_path = os.path.join(sample_dir, f'{name}_{j}_{i}.png') + try: + image.save(image_path) + except Exception as e: + print("Error saving image. Exception:", e) -- GitLab