Commit 3d557f77 authored by Gonzalo Martin Garcia's avatar Gonzalo Martin Garcia

solved minor bugs, changes to readme, updated sampling functions for readability

parent 5dd1bc7b
Showing 87 additions and 59 deletions
......@@ -2,18 +2,41 @@
This repository contains the pipeline for training, sampling, and evaluation of our Conditional Diffusion Model on the HPC. It builds upon our code from the unconditional-diffusion repository.
We demonstrate our results by training conditional diffusion models to solve the tasks of class-conditional image generation and inpainting. For the class-labeled dataset, we chose the Animal Faces (AFHQ) dataset, which contains three classes, dog, cat, and wildlife, each represented by 5000 training images. For the inpainting dataset, we train the model on the same Landscape dataset (LHQ) as the unconditional diffusion model and generate our own labels by randomly drawing black rectangle masks which the model learns to inpaint.
For this purpose, we implement class and image conditioning mechanisms in our UNet. For class conditioning, we also make use of our implementation of classifier-free guided diffusion to achieve better sample quality.
These techniques are presented in the papers **GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models** by Nichol et al. and **Classifier-Free Diffusion Guidance** by Ho and Salimans.
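At sampling time, classifier-free guidance combines the UNet's conditioned and unconditioned noise predictions. Below is a minimal sketch of that combination, mirroring the `torch.lerp` call used in our model code; the function name and tensor shapes are illustrative only:

```python
import torch

def guided_noise(pred_noise_uncond: torch.Tensor,
                 pred_noise_cond: torch.Tensor,
                 guidance_score: float) -> torch.Tensor:
    # lerp(a, b, w) = a + w * (b - a): interpolates between the unconditioned
    # and conditioned predictions; scores > 1 extrapolate beyond the conditioned one.
    return torch.lerp(pred_noise_uncond, pred_noise_cond, guidance_score)

# A guidance_score of 0 ignores the condition, 1 uses only the conditioned
# prediction, and larger values push samples further towards the condition.
noise_estimate = guided_noise(torch.randn(4, 3, 64, 64), torch.randn(4, 3, 64, 64), 3.0)
```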
## Motivation
Unconditional diffusion models are great at learning the image distribution present in the data. However, apart from the choice of dataset, the user has no control over the individual images the diffusion model generates. By training on a labeled dataset, we can inject an input with the desired specifications, which the conditional diffusion model conforms to. For class-conditional image generation, we are able to choose the class from which the image is drawn, while for image inpainting we can provide any partially obscured image to be reimagined.
## Samples
### Class Conditional Image Generation:
<table>
<tr>
<td><img src="imgs/afhq/sample_2_41_1.png" alt="Dog"></td>
<td><img src="imgs/afhq/sample_2_83_1.png" alt="Dog"></td>
<td><img src="imgs/afhq/sample_2_1092_1.png" alt="Dog"></td>
<td><img src="imgs/afhq/sample_2_1034_1.png" alt="Dog"></td>
</tr>
<tr>
<td><img src="imgs/afhq/sample_2_1268_0.png" alt="Cat"></td>
<td><img src="imgs/afhq/sample_2_920_0.png" alt="Cat"></td>
<td><img src="imgs/afhq/sample_2_2484_0.png" alt="Cat"></td>
<td><img src="imgs/afhq/sample_2_1626_0.png" alt="Cat"></td>
</tr>
<tr>
<td><img src="imgs/afhq/sample_2_917_2.png" alt="Wild"></td>
<td><img src="imgs/afhq/sample_2_1321_2.png" alt="Wild"></td>
<td><img src="imgs/afhq/sample_2_3309_2.png" alt="Wild"></td>
<td><img src="imgs/afhq/sample_2_2579_2.png" alt="Wild"></td>
</tr>
</table>
### Image Inpainting:
show grid examples in imgs/paint
## Recreating Results
......@@ -40,10 +63,12 @@ otherwise, the sampling will be done with randomly initialized weights.
3. Within the repository folder, run ```python main.py sample <path to experimentfolder>/settings```
### Model Evaluation
1. Make sure that the checkpoint file is in the **trained_cdm** folder inside the experiment folder. Alternatively, one can create this folder manually and add the checkpoint file.
2. Also make sure that the correct checkpoint name is given in the JSON file ```settings/evaluation_settings.json```; otherwise, the sampling will be done with randomly initialized weights.
3. Within the repository folder, run ```python main.py evaluate <path to experimentfolder>/settings```
...
## Pipeline Description
This repository houses our comprehensive pipeline, designed to conveniently train, sample from, and evaluate our conditional diffusion model.
The pipeline is initiated via the experiment_creator.ipynb notebook, which may be run separately on a local machine. This notebook allows for the configuration of every aspect of the diffusion model, including all hyperparameters. These configurations extend to the underlying UNet backbone as well as the training parameters, such as training from a checkpoint, the Weights & Biases run name for resumption, optimizer selection, adjustment of the CosineAnnealingLR learning rate schedule parameters, and more. Moreover, it includes parameters for evaluating and sampling images with a trained diffusion model.
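For illustration, here is a hedged sketch of the optimizer and CosineAnnealingLR setup that these settings control; the concrete optimizer choice, learning rate, and schedule length are placeholders, not the values used in our experiments:

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the UNet backbone
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000, eta_min=1e-6)

for step in range(1000):
    optimizer.zero_grad()
    loss = model(torch.randn(16, 8)).pow(2).mean()  # dummy loss for the sketch
    loss.backward()
    optimizer.step()
    scheduler.step()  # cosine-anneal the learning rate towards eta_min
```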
......
......@@ -81,6 +81,8 @@ class ConditionalDataset_LHQ_Paint(Dataset):
frac (float): value in (0,1]; the dataset is randomly shuffled (seeded) and then divided into train and test sets according to this fraction.
"""
self.img_size = img_size
### Create DataFrame
file_list = []
for root, dirs, files in os.walk(fpath, topdown=False):
......@@ -138,8 +140,8 @@ class ConditionalDataset_LHQ_Paint(Dataset):
max_y = img_tensor.shape[2] - min_width
x = np.random.randint(0, max_x)
y = np.random.randint(0, max_y)
max_height = min(self.img_size, img_tensor.shape[1] - x)
max_width = min(self.img_size, img_tensor.shape[2] - y)
rect_height = torch.randint(min_height, max_height, (1,)).item()
rect_width = torch.randint(min_width, max_width, (1,)).item()
# create copy of image and add blacked out rectangle
......
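To make the fragment above easier to follow, here is a hedged, self-contained sketch of the random rectangle masking used to create the inpainting labels. The helper name `draw_random_mask`, the default minimum sizes, and the fill value of -1 (assuming images normalized to [-1, 1]) are illustrative assumptions, not the exact repository code:

```python
import numpy as np
import torch

def draw_random_mask(img_tensor: torch.Tensor, img_size: int,
                     min_height: int = 32, min_width: int = 32) -> torch.Tensor:
    # Pick a top-left corner that leaves room for at least the minimum rectangle.
    max_x = img_tensor.shape[1] - min_height
    max_y = img_tensor.shape[2] - min_width
    x = np.random.randint(0, max_x)
    y = np.random.randint(0, max_y)
    # Cap the rectangle at the image border and the configured maximum size.
    max_height = min(img_size, img_tensor.shape[1] - x)
    max_width = min(img_size, img_tensor.shape[2] - y)
    rect_height = torch.randint(min_height, max_height, (1,)).item()
    rect_width = torch.randint(min_width, max_width, (1,)).item()
    # Copy the image and black out the rectangle; the model learns to inpaint it.
    masked = img_tensor.clone()
    masked[:, x:x + rect_height, y:y + rect_width] = -1.0
    return masked
```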
......@@ -37,7 +37,7 @@ def cdm_evaluator(model,
if epoch:
e = int(epoch[0])
else:
raise ValueError(f"No digit found in the filename: {filename}")
raise ValueError(f"No digit found in the filename")
model_dir = os.path.join(output_dir,f'epoch_{e}')
os.makedirs(model_dir, exist_ok=True)
......
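The epoch index used for the output folder is parsed from the checkpoint filename. A hedged sketch of that logic follows; the regex-based extraction is an assumption inferred from the fragment above, where `epoch` behaves like a list of digit matches:

```python
import os
import re

def epoch_output_dir(checkpoint_name: str, output_dir: str) -> str:
    # Assumed pattern: the first number in the filename is the epoch,
    # e.g. "model_epoch_120.pth" -> 120.
    epoch = re.findall(r"\d+", checkpoint_name)
    if epoch:
        e = int(epoch[0])
    else:
        raise ValueError("No digit found in the filename")
    # One sub-folder per evaluated checkpoint, as in the evaluator above.
    model_dir = os.path.join(output_dir, f"epoch_{e}")
    os.makedirs(model_dir, exist_ok=True)
    return model_dir
```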
......@@ -29,7 +29,7 @@ def cdm_sampler_afhq_class(model, checkpoint, experiment_path, dataloader, devic
checkpoint_list = [checkpoint_i for checkpoint_i in os.listdir(f) if checkpoint_i.endswith(".pth")]
for checkpoint_i in os.listdir(f):
if checkpoint_i.endswith(".pth"):
cdm_sampler_afhq_class(model, checkpoint_i, experiment_path, dataloader, device, batch_size=batch_size, intermediate=intermediate, n_times=n_times, sample_all=False)
return 0
# load model
......@@ -42,6 +42,8 @@ def cdm_sampler_afhq_class(model, checkpoint, experiment_path, dataloader, devic
model = model.to(device)
except Exception as e:
print("Error loading checkpoint. Exception:", e)
raise
# create samples directory for the complete experiment (if first time sampling images)
output_dir = f'{experiment_path}samples/'
......@@ -84,6 +86,7 @@ def cdm_sampler_afhq_class(model, checkpoint, experiment_path, dataloader, devic
image.save(image_path)
except Exception as e:
print("Error saving image. Exception:", e)
raise
run_indx += 1
......@@ -114,7 +117,7 @@ def cdm_sampler_lhq_paint(model, checkpoint, experiment_path, dataloader, device
checkpoint_list = [checkpoint_i for checkpoint_i in os.listdir(f) if checkpoint_i.endswith(".pth")]
for checkpoint_i in os.listdir(f):
if checkpoint_i.endswith(".pth"):
cdm_sampler_lhq_paint(model, checkpoint_i, experiment_path, dataloader, device, batch_size=batch_size, intermediate=intermediate, n_times=n_times, sample_all=False)
return 0
# load model
......@@ -127,6 +130,7 @@ def cdm_sampler_lhq_paint(model, checkpoint, experiment_path, dataloader, device
model = model.to(device)
except Exception as e:
print("Error loading checkpoint. Exception:", e)
raise
# create samples directory for the complete experiment
output_dir = f'{experiment_path}samples/'
......@@ -169,6 +173,7 @@ def cdm_sampler_lhq_paint(model, checkpoint, experiment_path, dataloader, device
image_raw.save(image_path_raw)
except Exception as e:
print("Error saving image. Exception:", e)
raise
run_indx += 1
if run_indx>=n_times:
break
......@@ -30,6 +30,7 @@ class CDM(nn.Module):
kl_loss: Choice between the mathematically correct 'weighted' KL loss and the 'simplified' version most commonly used in practice
recon_loss: 'none' to ignore the reconstruction loss, or 'nll' to compute the negative log-likelihood
class_free_guidence: Boolean flag that indicates whether classifier-free guided diffusion is used for training and sampling
guidance_score: Interpolation factor between the UNet's unconditioned and conditioned outputs when sampling with classifier-free guidance
'''
super(CDM,self).__init__()
self.device = device
......@@ -342,19 +343,7 @@ class CDM(nn.Module):
noise = torch.randn(x_0_recon.shape, device=self.device)
else:
noise = torch.zeros(x_0_recon.shape, device=self.device)
x_0_recon = self.denoised_latent(noise, x_0_recon, t, y)
return x_0_recon
......@@ -388,18 +377,7 @@ class CDM(nn.Module):
noise = torch.randn(x_t_1.shape, device=self.device)
else:
noise = torch.zeros(x_t_1.shape, device=self.device)
x_t_1 = self.denoised_latent(noise, x_t_1, t, y)
return x_t_1
@torch.no_grad()
......@@ -426,22 +404,40 @@ class CDM(nn.Module):
noise = torch.randn(x_t_1.shape, device=self.device)
else:
noise = torch.zeros(x_t_1.shape, device=self.device)
# compute the drawn denoised latent at time t
x_t_1 = self.denoised_latent(noise, x_t_1, t, y)
# store noised image
x[t-1] = x_t_1.squeeze(0)
return x
@torch.no_grad()
def denoised_latent(self, noise, x_t_1, t,y):
'''
Computes the Gaussian reparameterization for the denoising distribution at time t given an isotropic noise parameter.
Parameters:
noise (tensor): Batch of Gaussian noise used in the reparameterization of the denoising distribution (zeros at the final step)
x_t_1 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
t (int): Current timestep
y (tensor): Batch of conditional information for each input image
Returns:
x_t_1 (tensor): Batch of denoised latents drawn from the reverse distribution at timestep t
'''
# get denoising dist. param
t_batch = torch.full((x_t_1.shape[0],), t ,device = self.device)
if self.class_free_guidence:
# get classifier-free guided diffusion noise parameter
pred_noise_cond = self.forward(x_t_1, t_batch, y) # param with conditioning
pred_noise_uncond = self.forward(x_t_1, t_batch, y=None) # param without conditioning
pred_noise = torch.lerp(pred_noise_uncond, pred_noise_cond, self.guidance_score) # linear interpolation of the two
else:
pred_noise = self.forward(x_t_1, t_batch, y)
mean, std = self.reverse_dist_param(x_t_1, pred_noise, t_batch)
# compute the denoised latent reparametrization at time t
x_t_1 = mean + std*noise
return x_t_1
# Loss functions
......
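To show where the refactored `denoised_latent` helper fits, here is a hedged sketch of a reverse diffusion sampling loop built around it; the loop bounds, the total timestep count `T`, and the final-step noise handling follow the standard DDPM recipe and are assumptions rather than a verbatim copy of the CDM sampling methods:

```python
import torch

@torch.no_grad()
def sample_sketch(model, shape, y, T):
    # Start from pure Gaussian noise and denoise step by step.
    x_t = torch.randn(shape, device=model.device)
    for t in range(T, 0, -1):
        if t > 1:
            noise = torch.randn(x_t.shape, device=model.device)
        else:
            noise = torch.zeros(x_t.shape, device=model.device)  # no noise on the last step
        # One reverse step: predict the noise (with classifier-free guidance inside
        # the helper if enabled) and reparameterize the denoised latent.
        x_t = model.denoised_latent(noise, x_t, t, y)
    return x_t
```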