Skip to content
Snippets Groups Projects
Commit 5b1ea326 authored by Srijeet Roy's avatar Srijeet Roy
Browse files

sampling function with image reconstruction

parent 4eac7622
Branches
No related tags found
No related merge requests found
import os
import torch
from PIL import Image
import pickle
from torchvision import transforms
from torch.utils.data import DataLoader
import re
def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False,n_times=1):
def sample_reconstruction(model, path_to_training_data, batch_size=15, device='cpu'):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)),
transforms.Resize(128)])
training_images = os.listdir(path_to_training_data)
# store filenames
training_names = []
# store individual image tensors
training_images_list = []
for file in training_images:
if file.endswith('.png'):
filepath = os.path.join(path_to_training_data, file)
training_names.append(file)
im = Image.open(filepath)
im = transform(im)
training_images_list.append(im)
if len(training_images_list) == batch_size:
break
print(f'current sample size: {len(training_names)}')
train = torch.stack(training_images_list).to(device)
train_dataloader = DataLoader(train, batch_size=128, shuffle=False)
for batch in train_dataloader:
generated = model.complete_trajectory(batch)
return train, generated, training_names
def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, reconstruction=False, batch_size=15,sample_all=False,n_times=1):
'''
Samples a tensor of 'n_times'*'batch_size' images from a trained diffusion model with 'checkpoint'. The generated
Samples a tensor of 'batch_size' images from a trained diffusion model with 'checkpoint'. The generated
images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}. Where e is the epoch
w.r.t. the model which we are sampling form and j is an index separating the sampled batches for every call of this
sampling function.
w.r.t. the model which we are sampling form and j is an index separating images from each call of the
sampling function for the given mode.
model: Diffusion model
checkpoint: Name of the saved pth. file containing the trained weights and biases
experiment_path: Path to the experiment directory where the samples will saved under the diectory samples
batch_size: The number of images the model samples
batch_size: The number of images to sample
intermediate: Bool value. If False the sampling function will draw a batch of images, else it will just
sample a single image, but store all the intermediate noised latents along the reverse chain
sample_all: If True, samples a batch of images for the given model at every stored checkpoint at once
......@@ -33,7 +66,7 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
# load model
try:
checkpoint_path = f'{experiment_path}trained_ddpm/{checkpoint}'
checkpoint = torch.load(checkpoint_path)
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# load weights and biases of the U-Net
net_state_dict = checkpoint['model']
model.net.load_state_dict(net_state_dict)
......@@ -41,17 +74,19 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
except Exception as e:
print("Error loading checkpoint. Exception:", e)
# create samples directory for the complete experiment
# create samples directory for the complete experiment (if first time sampling images)
output_dir = f'{experiment_path}samples/'
#output_dir = os.path.join(experiment_path,'/samples/')
os.makedirs(output_dir, exist_ok=True)
# create sample directory for the current checkpoint epoch of the trained model
# create sample directory for the current version of the trained model
model_name = os.path.basename(checkpoint_path)
epoch = re.findall(r'\d+', model_name)
if epoch:
e = int(epoch[0])
else:
raise ValueError(f"No digit found in the filename: {filename}")
#raise ValueError(f"No digit found in the filename: {filename}")
raise ValueError(f"No digit found in the filename: {model_name}")
model_dir = os.path.join(output_dir,f'epoch_{e}')
os.makedirs(model_dir, exist_ok=True)
......@@ -68,24 +103,48 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
n = n_times
for k in range(n):
# generate batch_size images
'''
if intermediate:
generated = model.sample_intermediates_latents()
name = 'sample_intermediate'
else:
generated = model.sample(batch_size=batch_size)
name = 'sample'
'''
if intermediate:
generated = model.sample_intermediates_latents()
name = 'sample_intermediate'
elif reconstruction:
path_to_training_data = '/Users/roy/Desktop/Workspace/RWTH/SoSe 2023/Deep Learning Lab/DLL_vsc/data/lhq/train'
train, generated, training_names = sample_reconstruction(model, path_to_training_data, batch_size=batch_size, device=device)
with open('training_name_list', 'wb') as fp:
pickle.dump(training_names, fp)
name = 'reconstruction'
else:
generated = model.sample(batch_size=batch_size)
name = 'sample'
#store the raw generated tensor
#store the raw generated images within the tensor
torch.save(generated,os.path.join(sample_dir, f"image_tensor{j}"))
#normalize to (-1,1), not needed after 70 epochs, model learns to adhere to (-1,1)
#a = generated.min()
#b = generated.max()
#A,B=-1,1
#generated = (generated-a)/(b-a)*(B-A)+A
# transform tesnors to pil and save generated images
#normalize to (-1,1)
a = generated.min()
b = generated.max()
A,B=-1,1
generated = (generated-a)/(b-a)*(B-A)+A
# save training images
if 'train' in locals():
for i in range(train.size(0)):
index = i + k*train.size(0)
image = back2pil(train[i+k*n])
image_path = os.path.join(sample_dir, f'training_{j}_{i}_{training_names[i]}.png')
try:
image.save(image_path)
except Exception as e:
print("Error saving image. Exception:", e)
# save generated images
for i in range(generated.size(0)):
index = i + k*generated.size(0)
image = back2pil(generated[i+k*n])
image_path = os.path.join(sample_dir, f'{name}_{j}_{i}.png')
image_path = os.path.join(sample_dir, f'{name}_{j}_{i}_{training_names[i]}.png')
try:
image.save(image_path)
except Exception as e:
......
import os
import torch
from torchvision import transforms
import re
def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False,n_times=1):
'''
Samples a tensor of 'n_times'*'batch_size' images from a trained diffusion model with 'checkpoint'. The generated
images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}. Where e is the epoch
w.r.t. the model which we are sampling form and j is an index separating the sampled batches for every call of this
sampling function.
model: Diffusion model
checkpoint: Name of the saved pth. file containing the trained weights and biases
experiment_path: Path to the experiment directory where the samples will saved under the diectory samples
batch_size: The number of images the model samples
intermediate: Bool value. If False the sampling function will draw a batch of images, else it will just
sample a single image, but store all the intermediate noised latents along the reverse chain
sample_all: If True, samples a batch of images for the given model at every stored checkpoint at once
n_times: Integer denoting how many times we draw a batch of 'batch_size'. If we want to draw 10k images
the GPU will draw batches of 512 images 20 times to reach this goal.
'''
# If we want to sample from every checkpoint of the current model, recursively call this function for all checkpoints
if sample_all:
f = f'{experiment_path}trained_ddpm/'
checkpoint_list = [checkpoint_i for checkpoint_i in os.listdir(f) if checkpoint_i.endswith(".pth")]
for checkpoint_i in os.listdir(f):
if checkpoint_i.endswith(".pth"):
ddpm_sampler(model, checkpoint_i, experiment_path, device, sample_all=False)
return 0
# load model
try:
checkpoint_path = f'{experiment_path}trained_ddpm/{checkpoint}'
checkpoint = torch.load(checkpoint_path)
# load weights and biases of the U-Net
net_state_dict = checkpoint['model']
model.net.load_state_dict(net_state_dict)
model = model.to(device)
except Exception as e:
print("Error loading checkpoint. Exception:", e)
# create samples directory for the complete experiment
output_dir = f'{experiment_path}samples/'
os.makedirs(output_dir, exist_ok=True)
# create sample directory for the current checkpoint epoch of the trained model
model_name = os.path.basename(checkpoint_path)
epoch = re.findall(r'\d+', model_name)
if epoch:
e = int(epoch[0])
else:
raise ValueError(f"No digit found in the filename: {filename}")
model_dir = os.path.join(output_dir,f'epoch_{e}')
os.makedirs(model_dir, exist_ok=True)
# create the sample directory for this sampling run for the current version of the model
sample_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))]
indx_list = [int(d.split('_')[1]) for d in sample_dir_list if d.startswith('sample_')]
j = max(indx_list, default=-1) + 1
sample_dir = os.path.join(model_dir, f'sample_{j}')
os.makedirs(sample_dir, exist_ok=True)
# transform
back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()])
n = n_times
for k in range(n):
# generate batch_size images
if intermediate:
generated = model.sample_intermediates_latents()
name = 'sample_intermediate'
else:
generated = model.sample(batch_size=batch_size)
name = 'sample'
#store the raw generated tensor
torch.save(generated,os.path.join(sample_dir, f"image_tensor{j}"))
#normalize to (-1,1), not needed after 70 epochs, model learns to adhere to (-1,1)
#a = generated.min()
#b = generated.max()
#A,B=-1,1
#generated = (generated-a)/(b-a)*(B-A)+A
# transform tesnors to pil and save generated images
for i in range(generated.size(0)):
index = i + k*generated.size(0)
image = back2pil(generated[i+k*n])
image_path = os.path.join(sample_dir, f'{name}_{j}_{i}.png')
try:
image.save(image_path)
except Exception as e:
print("Error saving image. Exception:", e)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment