Commit 04d7a206 authored by Gonzalo Martin Garcia

solved minor bugs, added images to README

parent 056a9793
Showing with 68 additions and 29 deletions
# Unconditional Diffusion model
This repo presents our implementation of the unconditional diffusion model, with its popular optimizations, from *Denoising Diffusion Probabilistic Models* by Ho et al. and *Improved Denoising Diffusion Probabilistic Models* by Nichol and Dhariwal. The pipeline contains training, sampling, and evaluation functions meant to be run on the HPC.
We show the correctness of our pipeline by training unconditional diffusion models on the landscape (LHQ) and celebrity face (CelebA-HQ) datasets, generating realistic images at a resolution of 128x128 px.
## Background
Diffusion models are a class of generative models that offer a unique approach to modeling complex data distributions: they simulate a stochastic process, known as a diffusion process, that gradually transforms data from a simple initial distribution into the complex data distribution. More specifically, the simple initial distribution is Gaussian noise, which is iteratively denoised into coherent images by modeling the entire data distribution present in the training set.
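For intuition, the forward half of this process has a closed form (Ho et al.) that lets a noised sample x_t be drawn directly from a clean image x_0 in a single step. The snippet below is a minimal sketch of that formula, not our pipeline code; the linear beta schedule is the paper's standard choice and an assumption here:

```python
import torch

# Closed-form forward (noising) process: q(x_t | x_0) = N(sqrt(abar_t) x_0, (1 - abar_t) I).
# Linear beta schedule over T steps, as in Ho et al.; our pipeline's actual
# schedule and hyperparameters are configured separately.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

def noise_image(x_0: torch.Tensor, t: int) -> torch.Tensor:
    """Draw x_t ~ q(x_t | x_0) for images x_0 normalized to [-1, 1]."""
    eps = torch.randn_like(x_0)  # isotropic Gaussian noise
    return alphas_bar[t].sqrt() * x_0 + (1.0 - alphas_bar[t]).sqrt() * eps
```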
## Sample Examples
### Unconditional Landscape Generation:
### Unconditional Celebrity Face Generation:
<table>
<tr>
<td><img src="imgs/celebAHQ/sample_0_97.png" alt="celeb"></td>
<td><img src="imgs/celebAHQ/sample_0_2.png" alt="celeb"></td>
<td><img src="imgs/celebAHQ/sample_4_3852.png" alt="celeb"></td>
<td><img src="imgs/celebAHQ/sample_0_53.png" alt="celeb"></td>
<td><img src="imgs/celebAHQ/sample_0_142.png" alt="celeb"></td>
</tr>
<tr>
<td><img src="imgs/celebAHQ/sample_0_117.png" alt="celeb"></td>
<td><img src="imgs/celebAHQ/sample_0_84.png" alt="celeb"></td>
<td><img src="imgs/celebAHQ/sample_0_41.png" alt="celeb"></td>
<td><img src="imgs/celebAHQ/sample_0_20.png" alt="celeb"></td>
<td><img src="imgs/celebAHQ/sample_0_48.png" alt="celeb"></td>
</tr>
<tr>
<td><img src="imgs/celebAHQ/sample_0_22.png" alt="celeb"></td>
<td><img src="imgs/celebAHQ/sample_0_36.png" alt="celeb"></td>
<td><img src="imgs/celebAHQ/sample_5_3160.png" alt="celeb"></td>
<td><img src="imgs/celebAHQ/sample_5_3484.png" alt="celeb"></td>
<td><img src="imgs/celebAHQ/sample_0_111.png" alt="celeb"></td>
</tr>
</table>
## Recreating Results
We used the following modules:
@@ -40,8 +63,7 @@ To evaluate the performance of the unconditional diffusion model:
3. For a detailed overview on evaluation metrics, refer to [evaluation_readme](evaluation/evaluation_readme.md).
## Pipeline Description
This repository houses our comprehensive pipeline, designed to conveniently train, sample from, and evaluate our unconditional diffusion model.
The pipeline is initiated via the experiment_creator.ipynb notebook, which is run separately on our local machine. This notebook allows for the configuration of every aspect of the diffusion model, including all hyperparameters. These configurations extend to the underlying neural backbone (a UNet), as well as the training parameters, such as training from a checkpoint, the Weights & Biases run name for resumption, optimizer selection, adjustment of the CosineAnnealingLR learning-rate schedule parameters, and more. Moreover, it includes parameters for evaluating and sampling images via a trained diffusion model. A hypothetical sketch of such a settings file is shown below.
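For illustration, the settings such a notebook run might write out could look roughly like the following. Only `modelname` and `batchsize` are confirmed by the sampling code further down; every other key is a hypothetical placeholder:

```python
# Hypothetical sketch of the experiment settings; not the repo's actual schema.
meta_setting = {"modelname": "UNet", "batchsize": 32}
model_setting = {}  # kwargs forwarded to the UNet constructor
training_setting = {
    "optimizer": "AdamW",                 # optimizer selection
    "lr_scheduler": "CosineAnnealingLR",  # learning-rate schedule named above
    "wandb_run_name": "my-run",           # Weights & Biases run to resume
    "resume_from_checkpoint": False,      # train from checkpoint or from scratch
}
```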
...
@@ -60,7 +60,7 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
        checkpoint_list = [checkpoint_i for checkpoint_i in os.listdir(f) if checkpoint_i.endswith(".pth")]
        for checkpoint_i in os.listdir(f):
            if checkpoint_i.endswith(".pth"):
                # forward every sampling option when recursing over the checkpoints in the folder
                ddpm_sampler(model, checkpoint_i, experiment_path, device, intermediate=intermediate, n_times=n_times, reconstruction=reconstruction, batch_size=batch_size, sample_all=False)
        return 0
# load model
@@ -73,6 +73,7 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
        model = model.to(device)
    except Exception as e:
        print("Error loading checkpoint. Exception:", e)
        raise
# create samples directory for the complete experiment (if first time sampling images)
output_dir = f'{experiment_path}samples/'
@@ -85,7 +86,6 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
    if epoch:
        e = int(epoch[0])
    else:
        raise ValueError(f"No digit found in the filename: {model_name}")
    model_dir = os.path.join(output_dir, f'epoch_{e}')
    os.makedirs(model_dir, exist_ok=True)
@@ -125,11 +125,14 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
    name = 'sample'
    # store the raw generated images within the tensor
    torch.save(generated, os.path.join(sample_dir, f"image_tensor{j}"))
    # if additional normalization is desired:
    # normalize to (-1,1)
    #a = generated.min()
    #b = generated.max()
    #A, B = -1, 1
    #generated = (generated - a) / (b - a) * (B - A) + A
    # clip to (-1,1)
    #generated = generated.clamp(-1, 1)
    # save training images
    if 'train' in locals():
        for i in range(train.size(0)):
@@ -149,3 +152,4 @@ def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
            image.save(image_path)
    except Exception as e:
        print("Error saving image. Exception:", e)
        raise
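As an aside on the post-processing block left commented out further above, the two options can be written as small helpers; this is an illustrative sketch, not part of the repo. Min-max rescaling maps each batch exactly onto [-1, 1] but alters contrast batch by batch, whereas clamping keeps pixel values intact and only truncates outliers:

```python
import torch

def rescale_minmax(generated: torch.Tensor, A: float = -1.0, B: float = 1.0) -> torch.Tensor:
    # per-batch min-max rescaling onto [A, B], as in the commented-out lines
    a, b = generated.min(), generated.max()
    return (generated - a) / (b - a) * (B - A) + A

def clip_range(generated: torch.Tensor) -> torch.Tensor:
    # hard clipping onto [-1, 1], the alternative left in the comments
    return generated.clamp(-1, 1)
```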
New image files added under imgs/celebAHQ/:
- sample_0_2.png (26.9 KiB)
- sample_0_20.png (25.4 KiB)
- sample_0_22.png (25.7 KiB)
- sample_0_36.png (28.4 KiB)
- sample_0_41.png (27.4 KiB)
- sample_0_48.png (24.2 KiB)
- sample_0_53.png (25.5 KiB)
- sample_0_84.png (27.1 KiB)
- sample_0_97.png (26 KiB)
- sample_0_111.png (25.4 KiB)
- sample_0_117.png (23.7 KiB)
- sample_0_142.png (28.7 KiB)
- sample_4_3852.png (30.8 KiB)
- sample_5_3160.png (27.7 KiB)
- sample_5_3484.png (26.3 KiB)
@@ -76,7 +76,6 @@ def sample_func(f):
        sampling_setting = json.load(fp)
    # init UNet
    batchsize = meta_setting["batchsize"]
    net = globals()[meta_setting["modelname"]](**model_setting).to(device)
    #net = torch.compile(net)
    net = net.to(device)
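The `globals()[meta_setting["modelname"]]` lookup above resolves the model class by name at runtime. A more explicit alternative is a small registry; the sketch below is a suggestion, not the repo's code, and assumes a `UNet` class importable from a hypothetical `unet` module:

```python
from unet import UNet  # hypothetical import path

# Explicit registry: the set of constructible models is visible in one place,
# and an unknown name fails with a KeyError scoped to the registry rather
# than picking up an arbitrary global.
MODEL_REGISTRY = {"UNet": UNet}

def build_model(meta_setting: dict, model_setting: dict, device):
    cls = MODEL_REGISTRY[meta_setting["modelname"]]
    return cls(**model_setting).to(device)
```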
...
@@ -319,10 +319,8 @@ class DDPM(nn.Module):
                noise = torch.randn(x_0_recon.shape, device=self.device)
            else:
                noise = torch.zeros(x_0_recon.shape, device=self.device)
            # draw x_{t-1} from the denoising distribution via the shared helper
            x_0_recon = self.denoised_latent(noise, x_0_recon, t)
        return x_0_recon
@@ -355,10 +353,8 @@ class DDPM(nn.Module):
                noise = torch.randn(x_t_1.shape, device=self.device)
            else:
                noise = torch.zeros(x_t_1.shape, device=self.device)
            # draw x_{t-1} from the denoising distribution via the shared helper
            x_t_1 = self.denoised_latent(noise, x_t_1, t)
        return x_t_1
@torch.no_grad()
@@ -382,15 +378,33 @@ class DDPM(nn.Module):
                noise = torch.randn(x_t_1.shape, device=self.device)
            else:
                noise = torch.zeros(x_t_1.shape, device=self.device)
            # draw x_{t-1} from the denoising distribution via the shared helper
            x_t_1 = self.denoised_latent(noise, x_t_1, t)
            # store noised image
            x[t-1] = x_t_1.squeeze(0)
        return x
    @torch.no_grad()
    def denoised_latent(self, noise, x_t_1, t):
        '''
        Computes the Gaussian reparameterization of the denoising distribution at timestep t, given isotropic noise.

        Parameters:
        noise (tensor): Batch of standard-normal noise used in the reparameterization (zeros at the final step)
        x_t_1 (tensor): Batch of input latents, with color channels assumed to be normalized between [-1,1]
        t (int): Current timestep, broadcast across the batch

        Returns:
        x_t_1 (tensor): Batch of latents drawn from the denoising distribution at timestep t-1
        '''
        # get the denoising distribution parameters
        mean, std, _ = self.forward(x_t_1, torch.full((x_t_1.shape[0],), t, device=self.device))
        # compute the drawn denoised latent at timestep t-1
        x_t_1 = mean + std * noise
        return x_t_1
    # Loss functions
    def loss_simplified(self, forward_noise, pred_noise, t=None):
...
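Pieced together from the hunks above, the reverse (ancestral) sampling loop that `denoised_latent` now serves looks roughly like the sketch below; the driver's name, signature, and the default `T` are assumptions:

```python
import torch

@torch.no_grad()
def sample_sketch(model, shape, T=1000):
    # hypothetical driver for DDPM.denoised_latent(); T is assumed
    x_t_1 = torch.randn(shape, device=model.device)  # start from pure Gaussian noise
    for t in range(T, 0, -1):
        if t > 1:
            noise = torch.randn(x_t_1.shape, device=model.device)
        else:
            # the final step adds no noise, matching the branches above
            noise = torch.zeros(x_t_1.shape, device=model.device)
        x_t_1 = model.denoised_latent(noise, x_t_1, t)
    return x_t_1
```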