import os
import torch
from torchvision import transforms
import re
def ldm_sampler(model, checkpoint, experiment_path, dataloader, device, intermediate=False, batch_size=15, sample_all=False, n_times=1):
    '''
    Samples a tensor of 'batch_size' images from a trained diffusion model with 'checkpoint'. The generated
    images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}', where e is the epoch
    of the model we are sampling from and j is an index separating the images of each call of the
    sampling function for the given model. (An example call is sketched at the bottom of this file.)
    model: Diffusion model
    checkpoint: Name of the saved .pth file containing the trained weights and biases
    dataloader: Dataloader yielding the conditioning inputs from which the samples are generated
    experiment_path: Path to the experiment directory; the samples are saved under its 'samples' subdirectory
    device: Device on which the sampling is run
    batch_size: The number of images to sample
    intermediate: Bool value. If False the sampling function will draw a batch of images, else it will just
        sample a single image, but store all the intermediate noised latents along the reverse chain
    sample_all: If True, samples a batch of images for the given model at every stored checkpoint at once
    n_times: Integer denoting how many times we draw a batch of 'batch_size'. For example, to draw 10k images
        the GPU draws batches of 512 images 20 times to reach this goal.
    '''
    # If we want to sample from every checkpoint of the current model, recursively call this function for each checkpoint
    if sample_all:
        checkpoint_dir = f'{experiment_path}trained_ldm/'
        checkpoint_list = [checkpoint_i for checkpoint_i in os.listdir(checkpoint_dir) if checkpoint_i.endswith(".pth")]
        for checkpoint_i in checkpoint_list:
            ldm_sampler(model, checkpoint_i, experiment_path, dataloader, device, batch_size=batch_size, intermediate=intermediate, sample_all=False)
        return 0
    # load model
    try:
        checkpoint_path = f'{experiment_path}trained_ldm/{checkpoint}'
        checkpoint = torch.load(checkpoint_path)
        # load weights and biases of the U-Net
        net_state_dict = checkpoint['model']
        model.net.load_state_dict(net_state_dict)
        model = model.to(device)
    except Exception as e:
        print("Error loading checkpoint. Exception:", e)
    # the VQ model is needed for decoding the latents, so move it to the same device
    model.vq_model.to(device)
    # create the samples directory for the complete experiment (if sampling images for the first time)
    output_dir = f'{experiment_path}samples/'
    os.makedirs(output_dir, exist_ok=True)
    # create the sample directory for the current version of the trained model, using the first
    # number in the checkpoint filename as the epoch (fall back to epoch 0 if no digit is found)
    model_name = os.path.basename(checkpoint_path)
    epoch = re.findall(r'\d+', model_name)
    e = int(epoch[0]) if epoch else 0
    model_dir = os.path.join(output_dir, f'epoch_{e}')
    os.makedirs(model_dir, exist_ok=True)
    # create the sample directory for this sampling run of the current version of the model,
    # continuing the numbering of any existing sample_* directories
    sample_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))]
    indx_list = [int(d.split('_')[1]) for d in sample_dir_list if d.startswith('sample_')]
    j = max(indx_list, default=-1) + 1
    sample_dir = os.path.join(model_dir, f'sample_{j}')
    os.makedirs(sample_dir, exist_ok=True)
    # min-max normalise each image tensor to [0, 1] and convert it to a PIL image
    back2pil = transforms.Compose([transforms.Lambda(lambda x: (x - x.min()) / (x.max() - x.min())), transforms.ToPILImage()])
    # sample the upscaling latent encodings, decode them and transform the results to PIL images
    run_indx = 0
    for idx, (_, y) in enumerate(dataloader):
        y = y.to(device)
        generated = model.sample(y=y, batch_size=y.size(0))
        # save the generated images together with their conditioning inputs
        for i in range(generated.size(0)):
            image = back2pil(generated[i])
            image_path = os.path.join(sample_dir, f'sample_{j}_{run_indx}.png')
            image_raw = back2pil(y[i])
            image_path_raw = os.path.join(sample_dir, f'raw_{j}_{run_indx}.png')
            try:
                image.save(image_path)
                image_raw.save(image_path_raw)
            except Exception as e:
                print("Error saving image. Exception:", e)
            run_indx += 1
        # stop iterating the dataloader once at least ten images have been saved
        if run_indx > 9:
            break
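

# Example usage: a minimal sketch, not part of the sampling code above. 'LatentDiffusionModel',
# 'build_conditioning_dataloader', the checkpoint name and the experiment path are hypothetical
# placeholders; substitute the model class, dataloader and paths of your own training run.
# ldm_sampler only assumes that the model exposes 'net', 'vq_model' and a
# 'sample(y=..., batch_size=...)' method, and that 'experiment_path' ends with a slash and
# contains a 'trained_ldm/' directory with .pth checkpoints.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = LatentDiffusionModel()                              # hypothetical diffusion model class
    dataloader = build_conditioning_dataloader(batch_size=15)   # hypothetical conditioning dataloader

    # sample one run from a single checkpoint ...
    ldm_sampler(model, "ldm_epoch_100.pth", "experiments/run_0/", dataloader, device, batch_size=15)

    # ... or sample from every stored checkpoint of the experiment
    ldm_sampler(model, None, "experiments/run_0/", dataloader, device, sample_all=True)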