diff --git a/README.md b/README.md
index 380723a671e008f7e1b4bf87c484e14d1f66a14e..4af2c69b0162ee122f87943c492376498f884a14 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,38 @@
-# Diffusion_Project
-
-A partial reimplemnetation of the paper *papername*. The pipeline contains the training and sampling for an image-to-image conditioned LDM model.
-
-
-## Background
-Unconditional diffusion models are ...
+# Unconditional Diffusion Model
+
+This repo presents our implementation of the unconditional diffusion model, together with its popular optimizations, from *Denoising Diffusion Probabilistic Models* by Ho et al. and *Improved Denoising Diffusion Probabilistic Models* by Nichol and Dhariwal. The pipeline contains training, sampling, and evaluation functions meant to be run on the HPC.
+
+We show the correctness of our pipeline by training unconditional diffusion models on the landscape (LHQ) and celebrity face (CelebAHQ) datasets, generating realistic images at a resolution of 128x128px.
+Diffusion models are a class of generative models that model complex data distributions by simulating a stochastic process, known as a diffusion process, which gradually transforms samples from a simple initial distribution into samples from the data distribution. More specifically, the simple distribution is Gaussian noise, which is iteratively denoised into coherent images by a network trained to model the data distribution of the training set.
 
 ## Sample Examples
-...
-
+### Unconditional Landscape Generation:
+
+### Unconditional Celebrity Face Generation:
+<table>
+  <tr>
+    <td><img src="imgs/celebAHQ/sample_0_97.png" alt="celeb"></td>
+    <td><img src="imgs/celebAHQ/sample_0_2.png" alt="celeb"></td>
+    <td><img src="imgs/celebAHQ/sample_4_3852.png" alt="celeb"></td>
+    <td><img src="imgs/celebAHQ/sample_0_53.png" alt="celeb"></td>
+    <td><img src="imgs/celebAHQ/sample_0_142.png" alt="celeb"></td>
+  </tr>
+  <tr>
+    <td><img src="imgs/celebAHQ/sample_0_117.png" alt="celeb"></td>
+    <td><img src="imgs/celebAHQ/sample_0_84.png" alt="celeb"></td>
+    <td><img src="imgs/celebAHQ/sample_0_41.png" alt="celeb"></td>
+    <td><img src="imgs/celebAHQ/sample_0_20.png" alt="celeb"></td>
+    <td><img src="imgs/celebAHQ/sample_0_48.png" alt="celeb"></td>
+  </tr>
+  <tr>
+    <td><img src="imgs/celebAHQ/sample_0_22.png" alt="celeb"></td>
+    <td><img src="imgs/celebAHQ/sample_0_36.png" alt="celeb"></td>
+    <td><img src="imgs/celebAHQ/sample_5_3160.png" alt="celeb"></td>
+    <td><img src="imgs/celebAHQ/sample_5_3484.png" alt="celeb"></td>
+    <td><img src="imgs/celebAHQ/sample_0_111.png" alt="celeb"></td>
+  </tr>
+</table>
 
 ## Recreating Results
 We used the following modules:
@@ -40,8 +63,7 @@ To evaluate the performance of the unconditional diffusion model:
 3. For a detailed overview on evaluation metrics, refer to [evaluation_readme](evaluation/evaluation_readme.md).
 
 
-## Comprehensive Description
-
+## Pipeline Description
 This repository houses our comprehensive pipeline, designed to conveniently train, sample from, and evaluate our unconditional diffusion model. The pipeline is initiated via the experiment_creator.ipynb notebook, which is run separately on our local machine. This notebook allows for the configuration of every aspect of the diffusion model, including all hyperparameters.
 These configurations extend to the underlying neural backbone UNet, as well as the training parameters, such as training from a checkpoint, the Weights & Biases run name for resumption, optimizer selection, adjustment of the CosineAnnealingLR learning rate schedule parameters, and more. Moreover, it includes parameters for evaluating and sampling images via a trained diffusion model.
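For readers landing on this README without the papers at hand: the forward half of the process described above has a closed form, so a noisy training example at any timestep can be produced in a single step. The following is a minimal, illustrative sketch, not code from this repo, assuming the linear beta schedule of Ho et al.; all names such as `alpha_bar` are our own.

```python
import torch

# Illustrative sketch of the closed-form forward (noising) process from Ho et al.
# None of these names come from this repo's code.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # linear noise schedule
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)     # cumulative products of (1 - beta)

def noise_image(x0: torch.Tensor, t: int) -> torch.Tensor:
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(abar_t) * x0, (1 - abar_t) * I)."""
    eps = torch.randn_like(x0)
    return alpha_bar[t].sqrt() * x0 + (1.0 - alpha_bar[t]).sqrt() * eps
```

The network is trained to predict `eps` from `x_t`; the reverse process then uses that prediction to walk the noise back to an image.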
diff --git a/dataloader/load.py b/dataloader/load.py
index 7fba61ce564989371d89384ba0c3b2e63faf2121..2fb3593a714a6f92b5cf66dde4d3bec72f00a871 100644
--- a/dataloader/load.py
+++ b/dataloader/load.py
@@ -54,7 +54,7 @@ class UnconditionalDataset_LHQ(Dataset):
             self.transform = transforms.RandomChoice([transform_rotate,transform_randomcrop])
         else :
             self.transform = transforms.Compose([transforms.ToTensor(),
-                                                transforms.Resize(img_size)])
+                                                 transforms.Resize(img_size)])
 
     def __len__(self):
         return len(self.df)
@@ -114,7 +114,7 @@ class UnconditionalDataset_CelebAHQ(Dataset):
             self.transform = transforms.RandomChoice([transform_rotate_flip,transform_flip])
         else :
             self.transform = transforms.Compose([transforms.ToTensor(),
-                                                transforms.Resize(img_size)])
+                                                 transforms.Resize(img_size)])
 
     def __len__(self):
         return len(self.df)
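Both dataset classes above share one pattern: a stochastic `RandomChoice` augmentation pipeline for training and a deterministic `ToTensor` + `Resize` for evaluation. A self-contained sketch of that pattern follows; the specific augmentations here (rotation, flip) are illustrative stand-ins for the repo's `transform_rotate`/`transform_flip` variants.

```python
import torchvision.transforms as transforms

img_size = 128  # illustrative; matches the 128x128px samples shown above

# training: randomly pick one augmentation pipeline per image
train_transform = transforms.RandomChoice([
    transforms.Compose([transforms.ToTensor(),
                        transforms.RandomRotation(15),
                        transforms.Resize(img_size)]),
    transforms.Compose([transforms.ToTensor(),
                        transforms.RandomHorizontalFlip(p=1.0),
                        transforms.Resize(img_size)]),
])

# evaluation: deterministic tensor conversion and resize only
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(img_size),
])
```

Keeping the evaluation path deterministic matters for metrics such as FID, since random augmentations would otherwise leak into the reference statistics.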
Exception:", e) + raise diff --git a/imgs/celebAHQ/sample_0_111.png b/imgs/celebAHQ/sample_0_111.png new file mode 100644 index 0000000000000000000000000000000000000000..14b8b76855cfff7cc4455068fb5e2cf462beb767 Binary files /dev/null and b/imgs/celebAHQ/sample_0_111.png differ diff --git a/imgs/celebAHQ/sample_0_117.png b/imgs/celebAHQ/sample_0_117.png new file mode 100644 index 0000000000000000000000000000000000000000..a83ed632785fd7c6343e0c1664b51039c4bf5b03 Binary files /dev/null and b/imgs/celebAHQ/sample_0_117.png differ diff --git a/imgs/celebAHQ/sample_0_142.png b/imgs/celebAHQ/sample_0_142.png new file mode 100644 index 0000000000000000000000000000000000000000..7e73153d5ba2ef7dd250a6f4911e71c0ac04e808 Binary files /dev/null and b/imgs/celebAHQ/sample_0_142.png differ diff --git a/imgs/celebAHQ/sample_0_2.png b/imgs/celebAHQ/sample_0_2.png new file mode 100644 index 0000000000000000000000000000000000000000..cedc94bb7a383c145a99d3620a3207e6d9f5243d Binary files /dev/null and b/imgs/celebAHQ/sample_0_2.png differ diff --git a/imgs/celebAHQ/sample_0_20.png b/imgs/celebAHQ/sample_0_20.png new file mode 100644 index 0000000000000000000000000000000000000000..79a0ce0ec8f4650e2e85db11298ea0befcabeb44 Binary files /dev/null and b/imgs/celebAHQ/sample_0_20.png differ diff --git a/imgs/celebAHQ/sample_0_22.png b/imgs/celebAHQ/sample_0_22.png new file mode 100644 index 0000000000000000000000000000000000000000..1b23c965b3d06856dcc9207d3ca1b43f5d34b469 Binary files /dev/null and b/imgs/celebAHQ/sample_0_22.png differ diff --git a/imgs/celebAHQ/sample_0_36.png b/imgs/celebAHQ/sample_0_36.png new file mode 100644 index 0000000000000000000000000000000000000000..b379f3f2899b23cc00d3398b166a3e21f2ec13fc Binary files /dev/null and b/imgs/celebAHQ/sample_0_36.png differ diff --git a/imgs/celebAHQ/sample_0_41.png b/imgs/celebAHQ/sample_0_41.png new file mode 100644 index 0000000000000000000000000000000000000000..cfc0589321ff59ba61d03b5fb98041b767a03c0a Binary files /dev/null and b/imgs/celebAHQ/sample_0_41.png differ diff --git a/imgs/celebAHQ/sample_0_48.png b/imgs/celebAHQ/sample_0_48.png new file mode 100644 index 0000000000000000000000000000000000000000..325a0f7730d0042b6bb2f3e78236236f6d5e7d24 Binary files /dev/null and b/imgs/celebAHQ/sample_0_48.png differ diff --git a/imgs/celebAHQ/sample_0_53.png b/imgs/celebAHQ/sample_0_53.png new file mode 100644 index 0000000000000000000000000000000000000000..063cdc693254e3b5801a02ca2681381204b6f57b Binary files /dev/null and b/imgs/celebAHQ/sample_0_53.png differ diff --git a/imgs/celebAHQ/sample_0_84.png b/imgs/celebAHQ/sample_0_84.png new file mode 100644 index 0000000000000000000000000000000000000000..b3bf9b4f546c7d265caf844af13e1a3f37ef204e Binary files /dev/null and b/imgs/celebAHQ/sample_0_84.png differ diff --git a/imgs/celebAHQ/sample_0_97.png b/imgs/celebAHQ/sample_0_97.png new file mode 100644 index 0000000000000000000000000000000000000000..91c3de771b0b68b371a0cd27f010fcd24926ee12 Binary files /dev/null and b/imgs/celebAHQ/sample_0_97.png differ diff --git a/imgs/celebAHQ/sample_4_3852.png b/imgs/celebAHQ/sample_4_3852.png new file mode 100644 index 0000000000000000000000000000000000000000..3ee30c1cb35057783309e45dd189c74832813b6e Binary files /dev/null and b/imgs/celebAHQ/sample_4_3852.png differ diff --git a/imgs/celebAHQ/sample_5_3160.png b/imgs/celebAHQ/sample_5_3160.png new file mode 100644 index 0000000000000000000000000000000000000000..c97fa2816f147ebb760f5cbde13f102d5d5a531a Binary files /dev/null and b/imgs/celebAHQ/sample_5_3160.png 
diff --git a/imgs/celebAHQ/sample_0_111.png b/imgs/celebAHQ/sample_0_111.png
new file mode 100644
index 0000000000000000000000000000000000000000..14b8b76855cfff7cc4455068fb5e2cf462beb767
Binary files /dev/null and b/imgs/celebAHQ/sample_0_111.png differ
diff --git a/imgs/celebAHQ/sample_0_117.png b/imgs/celebAHQ/sample_0_117.png
new file mode 100644
index 0000000000000000000000000000000000000000..a83ed632785fd7c6343e0c1664b51039c4bf5b03
Binary files /dev/null and b/imgs/celebAHQ/sample_0_117.png differ
diff --git a/imgs/celebAHQ/sample_0_142.png b/imgs/celebAHQ/sample_0_142.png
new file mode 100644
index 0000000000000000000000000000000000000000..7e73153d5ba2ef7dd250a6f4911e71c0ac04e808
Binary files /dev/null and b/imgs/celebAHQ/sample_0_142.png differ
diff --git a/imgs/celebAHQ/sample_0_2.png b/imgs/celebAHQ/sample_0_2.png
new file mode 100644
index 0000000000000000000000000000000000000000..cedc94bb7a383c145a99d3620a3207e6d9f5243d
Binary files /dev/null and b/imgs/celebAHQ/sample_0_2.png differ
diff --git a/imgs/celebAHQ/sample_0_20.png b/imgs/celebAHQ/sample_0_20.png
new file mode 100644
index 0000000000000000000000000000000000000000..79a0ce0ec8f4650e2e85db11298ea0befcabeb44
Binary files /dev/null and b/imgs/celebAHQ/sample_0_20.png differ
diff --git a/imgs/celebAHQ/sample_0_22.png b/imgs/celebAHQ/sample_0_22.png
new file mode 100644
index 0000000000000000000000000000000000000000..1b23c965b3d06856dcc9207d3ca1b43f5d34b469
Binary files /dev/null and b/imgs/celebAHQ/sample_0_22.png differ
diff --git a/imgs/celebAHQ/sample_0_36.png b/imgs/celebAHQ/sample_0_36.png
new file mode 100644
index 0000000000000000000000000000000000000000..b379f3f2899b23cc00d3398b166a3e21f2ec13fc
Binary files /dev/null and b/imgs/celebAHQ/sample_0_36.png differ
diff --git a/imgs/celebAHQ/sample_0_41.png b/imgs/celebAHQ/sample_0_41.png
new file mode 100644
index 0000000000000000000000000000000000000000..cfc0589321ff59ba61d03b5fb98041b767a03c0a
Binary files /dev/null and b/imgs/celebAHQ/sample_0_41.png differ
diff --git a/imgs/celebAHQ/sample_0_48.png b/imgs/celebAHQ/sample_0_48.png
new file mode 100644
index 0000000000000000000000000000000000000000..325a0f7730d0042b6bb2f3e78236236f6d5e7d24
Binary files /dev/null and b/imgs/celebAHQ/sample_0_48.png differ
diff --git a/imgs/celebAHQ/sample_0_53.png b/imgs/celebAHQ/sample_0_53.png
new file mode 100644
index 0000000000000000000000000000000000000000..063cdc693254e3b5801a02ca2681381204b6f57b
Binary files /dev/null and b/imgs/celebAHQ/sample_0_53.png differ
diff --git a/imgs/celebAHQ/sample_0_84.png b/imgs/celebAHQ/sample_0_84.png
new file mode 100644
index 0000000000000000000000000000000000000000..b3bf9b4f546c7d265caf844af13e1a3f37ef204e
Binary files /dev/null and b/imgs/celebAHQ/sample_0_84.png differ
diff --git a/imgs/celebAHQ/sample_0_97.png b/imgs/celebAHQ/sample_0_97.png
new file mode 100644
index 0000000000000000000000000000000000000000..91c3de771b0b68b371a0cd27f010fcd24926ee12
Binary files /dev/null and b/imgs/celebAHQ/sample_0_97.png differ
diff --git a/imgs/celebAHQ/sample_4_3852.png b/imgs/celebAHQ/sample_4_3852.png
new file mode 100644
index 0000000000000000000000000000000000000000..3ee30c1cb35057783309e45dd189c74832813b6e
Binary files /dev/null and b/imgs/celebAHQ/sample_4_3852.png differ
diff --git a/imgs/celebAHQ/sample_5_3160.png b/imgs/celebAHQ/sample_5_3160.png
new file mode 100644
index 0000000000000000000000000000000000000000..c97fa2816f147ebb760f5cbde13f102d5d5a531a
Binary files /dev/null and b/imgs/celebAHQ/sample_5_3160.png differ
diff --git a/imgs/celebAHQ/sample_5_3484.png b/imgs/celebAHQ/sample_5_3484.png
new file mode 100644
index 0000000000000000000000000000000000000000..0281d5f503c8774017d1169c1d23d0dc094dc9da
Binary files /dev/null and b/imgs/celebAHQ/sample_5_3484.png differ
diff --git a/main.py b/main.py
index 3c35493f6c5e1f10a13c67b2903f333e0018aa52..e7d5ca4233a6e9805cfd3940f001486649bf1446 100644
--- a/main.py
+++ b/main.py
@@ -76,7 +76,6 @@ def sample_func(f):
         sampling_setting = json.load(fp)
 
     # init Unet
-    batchsize = meta_setting["batchsize"]
     net = globals()[meta_setting["modelname"]](**model_setting).to(device)
     #net = torch.compile(net)
     net = net.to(device)
diff --git a/models/UnconditionalDiffusionModel.py b/models/UnconditionalDiffusionModel.py
index 675e0fad9ab16531bf73f423c082e008f0c1130b..5de0db7edd7f6d47dece4cc7c7d3e940770e735b 100644
--- a/models/UnconditionalDiffusionModel.py
+++ b/models/UnconditionalDiffusionModel.py
@@ -319,10 +319,8 @@ class DDPM(nn.Module):
             noise = torch.randn(x_0_recon.shape, device=self.device)
         else:
             noise = torch.zeros(x_0_recon.shape, device=self.device)
-        # get denoising dist. param
-        mean, std, _ = self.forward(x_0_recon, torch.full((x_0_recon.shape[0],), t ,device = self.device))
         # compute the drawn denoised latent at time t
-        x_0_recon = mean + std * noise
+        x_0_recon = self.denoised_latent(noise, x_0_recon, t)
         return x_0_recon
 
@@ -355,10 +353,8 @@ class DDPM(nn.Module):
             noise = torch.randn(x_t_1.shape, device=self.device)
         else:
             noise = torch.zeros(x_t_1.shape, device=self.device)
-        # get denoising dist. param
-        mean, std, _ = self.forward(x_t_1, torch.full((x_t_1.shape[0],), t ,device = self.device))
         # compute the drawn denoised latent at time t
-        x_t_1 = mean + std*noise
+        x_t_1 = self.denoised_latent(noise, x_t_1, t)
         return x_t_1
 
     @torch.no_grad()
@@ -382,15 +378,33 @@ class DDPM(nn.Module):
             noise = torch.randn(x_t_1.shape, device=self.device)
         else:
             noise = torch.zeros(x_t_1.shape, device=self.device)
-        # get denoising dist. param
-        mean, std, _ = self.forward(x_t_1, torch.full((x_t_1.shape[0],), t ,device = self.device))
         # compute the drawn denoised latent at time t
-        x_t_1 = mean + std*noise
+        x_t_1 = self.denoised_latent(noise, x_t_1, t)
         # store noised image
         x[t-1] = x_t_1.squeeze(0)
     return x
 
+    @torch.no_grad()
+    def denoised_latent(self, noise, x_t_1, t):
+        '''
+        Computes the Gaussian reparameterization for the denoising dist. at timestep t, given isotropic noise.
+
+        Parameters:
+        noise (tensor): Batch of noise drawn for the reparameterization of the denoising distribution
+        x_t_1 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
+        t (int): The current timestep
+
+        Returns:
+        x_t_1 (tensor): Batch of denoised latents at timestep t-1
+        '''
+        # get denoising dist. param
+        mean, std, _ = self.forward(x_t_1, torch.full((x_t_1.shape[0],), t ,device = self.device))
+        # compute the drawn denoised latent at time t
+        x_t_1 = mean + std*noise
+        return x_t_1
+
+
     # Loss functions
     def loss_simplified(self, forward_noise, pred_noise, t=None):
diff --git a/trainer/train.py b/trainer/train.py
index efacc50eb81f330b4247310ec12ebcbf178c4769..5367dcb848f9c29c3d477012660e898f09535b1f 100644
--- a/trainer/train.py
+++ b/trainer/train.py
@@ -163,6 +163,7 @@ def ddpm_trainer(model,
             ema.module.load_state_dict(checkpoint['ema'])
         except Exception as e:
             print("Error loading checkpoint. Exception: ", e)
+            raise
 
     # pick kl loss function
     if model.kl_loss == 'weighted':
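The `mean + std*noise` line that the new `denoised_latent` helper centralizes is the reparameterized draw from the learned reverse distribution p(x_{t-1} | x_t). As a reference point, here is a standalone sketch of the ancestral sampling loop built on that draw; `model` is a hypothetical stand-in assumed to return `(mean, std, pred_noise)` from `forward`, as in the diff above.

```python
import torch

@torch.no_grad()
def sample_loop(model, shape, T, device):
    # start from pure Gaussian noise, x_T ~ N(0, I)
    x_t = torch.randn(shape, device=device)
    for t in reversed(range(T)):
        # no noise is added on the final step, so the last draw is deterministic
        z = torch.randn(shape, device=device) if t > 0 else torch.zeros(shape, device=device)
        mean, std, _ = model.forward(x_t, torch.full((shape[0],), t, device=device))
        # reparameterized draw from p(x_{t-1} | x_t)
        x_t = mean + std * z
    return x_t
```

Factoring this draw into one helper, as the diff does, keeps the three sampling variants (reconstruction, final sample, intermediate trajectory) from drifting apart when the reverse-step formula changes.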