Select Git revision
-
Paolo Bonzini authored
All files under GPLv2 will get GPLv2+ changes starting tomorrow. event_notifier.c and exec-obsolete.h were only ever touched by Red Hat employees and can be relicensed now. Signed-off-by:
Paolo Bonzini <pbonzini@redhat.com> Signed-off-by:
Anthony Liguori <aliguori@us.ibm.com>
Paolo Bonzini authored All files under GPLv2 will get GPLv2+ changes starting tomorrow. event_notifier.c and exec-obsolete.h were only ever touched by Red Hat employees and can be relicensed now. Signed-off-by:
Paolo Bonzini <pbonzini@redhat.com> Signed-off-by:
Anthony Liguori <aliguori@us.ibm.com>
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
sample.py 4.53 KiB
import os
import torch
from torchvision import transforms
import re
def _epoch_from_checkpoint_name(checkpoint_path):
    '''Return the first run of digits in the checkpoint file name as an int, or 0 if it has none.'''
    digits = re.findall(r'\d+', os.path.basename(checkpoint_path))
    return int(digits[0]) if digits else 0


def _next_sample_dir(model_dir):
    '''
    Create the next free 'sample_{j}' directory inside model_dir and return (path, j).
    j is one past the largest numeric index among the existing 'sample_*' sub-directories
    (0 if there are none). Entries whose index is not a plain integer (e.g. 'sample_old')
    are ignored instead of raising ValueError.
    '''
    used = []
    for name in os.listdir(model_dir):
        if not name.startswith('sample_') or not os.path.isdir(os.path.join(model_dir, name)):
            continue
        index = name.split('_')[1]
        if index.isdecimal():
            used.append(int(index))
    j = max(used, default=-1) + 1
    sample_dir = os.path.join(model_dir, f'sample_{j}')
    os.makedirs(sample_dir, exist_ok=True)
    return sample_dir, j


def _min_max_normalize(x):
    '''Linearly rescale tensor x to [0, 1]; a constant tensor maps to zeros instead of NaN.'''
    lo = x.min()
    span = x.max() - lo
    if span == 0:
        # avoid 0/0 -> NaN for a constant image; (x - lo) is all zeros in that case anyway
        span = torch.ones_like(span)
    return (x - lo) / span


def ldm_sampler(model, checkpoint, experiment_path, dataloader, device, intermediate=False, batch_size=15, sample_all=False, n_times=1):
    '''
    Sample images from a trained diffusion model and save them as PNG files.

    The weights stored under the 'model' key of 'experiment_path/trained_ldm/{checkpoint}' are loaded
    into model.net. Then, for every batch (_, y) drawn from 'dataloader', model.sample(y=y, ...) is
    called. Each generated image is written to
    'experiment_path/samples/epoch_{e}/sample_{j}/sample_{j}_{k}.png', next to its conditioning input
    'raw_{j}_{k}.png'. Here e is the first number in the checkpoint file name (0 if there is none),
    j indexes the sampling runs for that epoch and k indexes the saved images. Sampling stops after
    the first batch that brings the number of saved images above 9.

    model: Diffusion model exposing .net (the U-Net), .vq_model and .sample(y=..., batch_size=...)
    checkpoint: File name of the saved .pth file, relative to 'experiment_path/trained_ldm/'
    experiment_path: Experiment directory. Paths are built by plain string concatenation, so it
        must end with a path separator (e.g. 'runs/exp1/').
    dataloader: Iterable of (x, y) batches; only y (the conditioning tensor) is used
    device: Device the model, its .vq_model and the conditioning tensors are moved to; also used as
        map_location when loading the checkpoint
    intermediate: Currently unused; kept for backward compatibility
    batch_size: Only forwarded to the recursive calls when sample_all=True; the effective batch size
        is that of the dataloader
    sample_all: If True, run the sampler once for every .pth file in 'experiment_path/trained_ldm/'
        (the 'checkpoint' argument is then ignored) and return 0
    n_times: Currently unused; kept for backward compatibility

    Returns 0 when sample_all is True, otherwise None. If the checkpoint cannot be loaded, the error
    is printed and nothing is sampled, so images from untrained weights are never stored under the
    checkpoint's epoch.
    '''
    # Sample every stored checkpoint of this experiment by recursing once per .pth file.
    if sample_all:
        checkpoint_dir = f'{experiment_path}trained_ldm/'
        for checkpoint_i in sorted(os.listdir(checkpoint_dir)):
            if checkpoint_i.endswith(".pth"):
                ldm_sampler(model, checkpoint_i, experiment_path, dataloader, device,
                            intermediate=intermediate, batch_size=batch_size,
                            sample_all=False, n_times=n_times)
        return 0

    # Load the U-Net weights; map_location lets GPU-saved checkpoints load onto any device.
    checkpoint_path = f'{experiment_path}trained_ldm/{checkpoint}'
    try:
        state = torch.load(checkpoint_path, map_location=device)
        model.net.load_state_dict(state['model'])
        model = model.to(device)
    except Exception as err:
        print("Error loading checkpoint. Exception:", err)
        return
    model.vq_model.to(device)

    # samples/epoch_{e}/sample_{j}/ -- a fresh run directory for every call
    output_dir = f'{experiment_path}samples/'
    os.makedirs(output_dir, exist_ok=True)
    epoch = _epoch_from_checkpoint_name(checkpoint_path)
    model_dir = os.path.join(output_dir, f'epoch_{epoch}')
    os.makedirs(model_dir, exist_ok=True)
    sample_dir, j = _next_sample_dir(model_dir)

    back2pil = transforms.Compose([transforms.Lambda(_min_max_normalize), transforms.ToPILImage()])

    run_indx = 0
    for _, y in dataloader:
        y = y.to(device)
        generated = model.sample(y=y, batch_size=y.size(0))
        for i in range(generated.size(0)):
            # detach + cpu: ToPILImage converts via .numpy(), which rejects tensors requiring grad
            image = back2pil(generated[i].detach().cpu())
            image_raw = back2pil(y[i].detach().cpu())
            image_path = os.path.join(sample_dir, f'sample_{j}_{run_indx}.png')
            image_path_raw = os.path.join(sample_dir, f'raw_{j}_{run_indx}.png')
            try:
                image.save(image_path)
                image_raw.save(image_path_raw)
            except Exception as err:
                # best effort: a failed write must not abort the remaining samples
                print("Error saving image. Exception:", err)
            run_indx += 1
        # stop drawing new batches once more than 9 images have been produced
        if run_indx > 9:
            break