Skip to content
Snippets Groups Projects
Select Git revision
  • 81a33135c1cf281719a890cccaee36297c9c0c28
  • main default protected
  • celebAHQ
  • ddpm-diffusers
4 results

sample.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    Framework.py 19.01 KiB
    import torch 
    from torch import nn 
    import torch.nn.functional as F
    
    class DDPM(nn.Module):
        
        def __init__(self,
                     net=None,
                     diffusion_steps = 50, 
                     out_shape = (3,32,32), 
                     noise_schedule = 'linear', 
                     beta_1 = 1e-4, 
                     beta_T = 0.02,
                     alpha_bar_lower_bound = 0.9,
                     var_schedule = 'same', 
                     kl_loss = 'simplified', 
                     recon_loss = 'none',
                     device=None):
            '''
            net:                   U-Net
            diffusion_steps:       Length of the Markov chain
            out_shape:             Shape of the models's in- and output images
            noise_schedule:        Methods of initialization for the noise dist. variances, 'linear', 'cosine' or bounded_cosine
            beta_1, beta_T:        Variances for the first and last noise dist. (only for the 'linear' noise schedule)
            alpha_bar_lower_bound: Upper bound for the varaince of the complete noise dist. (only for the 'cosine_bounded' noise schedule)
            var_schedule:          Options to initialize or learn the denoising dist. variances, 'same', 'true'
            kl_loss:               Choice between the mathematically correct 'weighted' or in practice most commonly used 'simplified' KL loss
            recon_loss:            Is 'none' to ignore the reconstruction loss or 'nll' to compute the negative log likelihood
            '''
            super(DDPM,self).__init__()
            self.device = device
            
            # initialize the beta's, alpha's and alpha_bar's for the given noise schedule
            if noise_schedule == 'linear':
                beta, alpha, alpha_bar = self.linear_schedule(diffusion_steps, beta_1, beta_T, device=self.device)
            elif noise_schedule == 'cosine':
                beta, alpha, alpha_bar = self.cosine_schedule(diffusion_steps, device=self.device)
            elif noise_schedule == 'cosine_bounded':
                beta, alpha, alpha_bar = self.bounded_cosine_schedule(diffusion_steps, alpha_bar_lower_bound, device=self.device)
            else:
                raise ValueError('Unimplemented noise scheduler')
                
            # initialize the denoising varainces for the given varaince schedule 
            if var_schedule == 'same':
                var = beta
            elif var_schedule == 'true':
                var = [beta[0]] + [((1-alpha_bar[t-1])/(1-alpha_bar[t]))*beta[t] for t in range (1,diffusion_steps)]
                var = torch.tensor(var, device=self.device)
            else:
                raise ValueError('Unimplemented variance scheduler')
            
            # check for invalid kl_loss argument
            if (kl_loss != 'simplified') & (kl_loss != 'weighted'): 
                raise ValueError("Unimplemented loss function")
            
            self.net = net
            self.diffusion_steps = diffusion_steps
            self.noise_schedule = noise_schedule
            self.var_schedule = var_schedule
            self.beta = beta
            self.alpha = alpha
            self.alpha_bar = alpha_bar
            self.sqrt_1_minus_alpha_bar = torch.sqrt(1-alpha_bar) # for forward std
            self.sqrt_alpha_bar = torch.sqrt(alpha_bar) # for forward mean
            self.var = var
            self.std = torch.sqrt(self.var)
            self.kl_loss = kl_loss 
            self.recon_loss = recon_loss 
            self.out_shape = out_shape
            # precomputed for efficiency reasons
            self.noise_scaler = (1-alpha)/( self.sqrt_1_minus_alpha_bar)
            self.mean_scaler = 1/torch.sqrt(self.alpha)
            self.mse_weight = (self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar))
              
        @staticmethod
        def linear_schedule(diffusion_steps, beta_1, beta_T, device):
            ''''
            Function that returns the noise distribution hyperparameters for the linear schedule.  
    
            Parameters:
            diffusion_steps (int): Length of the Markov chain.
            beta_1        (float): Variance of the first noise distribution.
            beta_T        (float): Variance of the last noise distribution.
            
            Returns:
            beta      (tensor): Linearly scaled from beta[0] = beta_1 to beta[-1] = beta_T, length is diffusion_steps.
            alpha     (tensor): Length is diffusion_steps.
            alpha_bar (tensor): Length is diffusion_steps.
            '''
            beta = torch.linspace(beta_1, beta_T, diffusion_steps,device=device)
            alpha = 1 - beta
            alpha_bar = torch.cumprod(alpha, dim=0)
            return beta, alpha, alpha_bar
        
        @staticmethod    
        def cosine_schedule(diffusion_steps, device):
            '''
            Function that returns the noise distribution hyperparameters for the cosine schedule.
            From "Improved Denoising Diffusion Probabilistic Models" by Nichol and Dhariwal.
    
            Parameters:
            diffusion_steps (int): Length of the Markov chain.
            
            Returns:
            beta      (tensor): Length is diffusion_steps.
            alpha     (tensor): Length is diffusion_steps.
            alpha_bar (tensor): Follows a sigmoid-like curve with a linear drop-off in the middle. 
                                Length is diffusion_steps.
            '''
            cosine_0 = DDPM.cosine(0, diffusion_steps= diffusion_steps)
            alpha_bar = [DDPM.cosine(t,diffusion_steps = diffusion_steps)/cosine_0 
                         for t in range(1, diffusion_steps+1)]
            shift = [1] + alpha_bar[:-1]
            beta = 1 - torch.div(torch.tensor(alpha_bar, device=device), torch.tensor(shift, device=device))
            beta = torch.clamp(beta, min =0, max = 0.999) #suggested by paper
            alpha = 1 - beta
            alpha_bar = torch.tensor(alpha_bar)
            return beta, alpha, alpha_bar
        
        @staticmethod    
        def bounded_cosine_schedule(diffusion_steps, alpha_bar_lower_bound, device):
            '''
            Function that returns the noise distribution hyperparameters for our experimental version of a 
            bounded cosine schedule. Benefits are still unproven. It still has a linear drop-off in alpha_bar, 
            but it's not sigmoidal and the betas are no longer smooth.
    
            Parameters:
            diffusion_steps (int): Length of the Markov chain
            
            Returns:
            beta      (tensor): Length is diffusion_steps
            alpha     (tensor): Length is diffusion_steps
            alpha_bar (tensor): Bounded between (alpha_bar_lower_bound, 1) with a linear drop-off in the middle. 
                                Length is diffusion_steps
            '''
            # get cosine alpha_bar (that range from 1 to 0)
            _, _, alpha_bar = DDPM.cosine_schedule(diffusion_steps, device)
            # apply min max normalization on alpha_bar (range from lower_bound to 0.999)
            min_val = torch.min(alpha_bar)
            max_val = torch.max(alpha_bar)
            alpha_bar = (alpha_bar - min_val) / (max_val - min_val)
            alpha_bar = alpha_bar * (0.9999 - alpha_bar_lower_bound) + alpha_bar_lower_bound # for 0.9999=>beta_1 = 1e-4
            # recompute beta, alpha and alpha_bar
            alpha_bar = alpha_bar.tolist()
            shift = [1] + alpha_bar[:-1]
            beta = 1 - torch.div(torch.tensor(alpha_bar, device = device), torch.tensor(shift, device=device))
            beta = torch.clamp(beta, min=0, max = 0.999)
            beta = torch.tensor(sorted(beta), device = device)
            alpha = 1 - beta
            alpha_bar = torch.cumprod(alpha, dim=0)
            return beta, alpha, alpha_bar
    
        @staticmethod
        def cosine(t, diffusion_steps, s = 0.008):
            '''
            Helper function that computes the cosine function from "Improved Denoising Diffusion Probabilistic Models" 
            by Nichol and Dhariwal, used for the cosine noise schedules.
    
            Parameters:
            t               (int): Current timestep
            diffusion_steps (int): Length of the Markov chain
            s             (float): Offset value suggested by the paper. Should be chosen such that sqrt(beta[0]) ~ 1/127.5 
                                   (for small T=50, this is not possible)
                                    
            Returns:
            (numpy.float64): Value of the cosine function at timestep t
            '''
            return (np.cos((((t/diffusion_steps)+s)*np.pi)/((1+s)*2)))**2
        
        
        ####
        # Important to note: Timesteps are adjusted to the range t in [1, diffusion_steps] akin to the paper 
        # equations, where x_0 denotes the input image and x_t the noised latent after adding noise t times.
        # Both trajectories work on batches assuming shape=(batch_size, channels, height, width).
        ####
        
        # Forward Trajectory Functions:
        
        @torch.no_grad()
        def forward_trajectory(self, x_0, t = None):
            '''
            Applies noise t times to each input image in the batch x_0.
            
            Parameters:
            x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
            t   (tensor): Batch of timesteps, by default goes through full forward trajectory
        
            Returns:
            x_T           (tensor): Batch of noised images at timestep t
            forward_noise (tensor): Batch of noise parameters from the noise distribution reparametrization used to draw x_T
            '''
            if t is None:
                t = torch.full((x_0.shape[0],), self.diffusion_steps, device = self.device)
            elif torch.any(t == 0):
                raise ValueError("The tensor 't' contains a timestep zero.")
            forward_noise = torch.randn(x_0.shape, device = self.device)
            x_T = self.noised_latent(forward_noise, x_0, t) 
            return x_T , forward_noise
        
        @torch.no_grad()
        def noised_latent(self, forward_noise, x_0, t):
            '''
            Given a batch of noise parameters, this function recomputes the batch of noised images at their respective timesteps t.
            This allows us to avoid storing all the intermediate latents x_t along the forward trajectory.
            
            Parameters:
            forward_noise (tensor): Batch of noise parameters from the noise distribution reparametrization used to draw x_t
            x_0           (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
            t             (tensor): Batch of timesteps
        
            Returns:
            x_t           (tensor): Batch of noised images at timestep t
            '''
            mean, std = self.forward_dist_param(x_0, t)
            x_t = mean + std*forward_noise
            return x_t
        
        @torch.no_grad()
        def forward_dist_param(self, x_0, t):
            '''
            Computes the parameters of the complete noise distribution.
            
            Parameters:
            x_0  (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
            t    (tensor): Batch of timesteps
        
            Returns:
            mean (tensor): Batch of means for the complete noise distribution for each image in the batch x_0
            std  (tensor): Batch of std scalars for the complete noise distribution for each image in the batch x_0
            '''
            mean = self.sqrt_alpha_bar[t-1][:,None,None,None]*x_0
            std = self.sqrt_1_minus_alpha_bar[t-1][:,None,None,None]
            return mean, std
        
        @torch.no_grad()
        def single_forward_dist_param(self, x_t_1, t):
            '''
            Computes the parameters of the individual noise distribution.
    
            Parameters:
            x_t_1 (tensor): Batch of noised images at timestep t-1
            t     (tensor): Batch of timesteps
            
            Returns:
            mean (tensor): Batch of means for the individual noise distribution for each image in the batch x_t_1
            std  (tensor): Batch of std scalars for the individual noise distribution for each image in the batch x_t_1
            '''
            mean = torch.sqrt(1-self.beta[t-1])[:,None,None,None]*x_t_1
            std = torch.sqrt(self.beta[t-1])[:,None,None,None]
            return mean, std
    
    
        # Reverse Trajectory Functions:
    
        def reverse_trajectory(self, x_t, t):
            '''
            Draws a denoised images x_{t-1} by reparametrizing the denoising distribution at times t for the current noised
            latents x_t.
    
            Parameters:
            x_t (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
            t   (tensor): Batch of timestep
            
            Returns:
            x_t_1 (tensor): Batch of denoised images at timestep t-1
            '''
            noise = torch.randn(x_t.shape, device=self.device)
            mean, std , _ = self.forward(x_t, t)
            x_t_1 = mean + std*noise
            return x_t_1
        
        def forward(self, x_t, t):
            '''
            Passes the current noised images x_t and timesteps t through the U-Net in order to compute the 
            predicted noise, which is later used to determine the current denoising distribution parameters
            (mean and std) in the reverse trajectory. 
            Since the DDPM class is inheriting from the nn.Module class, this function is required to share 
            the name 'forward'. This naming scheme does not refer to the forward trajectory, but the forward 
            pass of the model itself, which concerns to the reverse trajectory.
            
            Parameters:
            x_t (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
            t   (tensor): Batch of timesteps
            
            Returns:
            mean        (tensor): Batch of means for the complete noise dist. for each image in the batch x_t
            std         (tensor): Batch of std scalars for the complete noise dist. for each image in the batch x_t
            pred_noise  (tensor): Predicted noise for each image in the batch x_t
            '''
            pred_noise = self.net(x_t,t,return_dict=False)[0]
            mean = self.mean_scaler[t-1][:,None,None,None]*(x_t - self.noise_scaler[t-1][:,None,None,None]*pred_noise)
            std = self.std[t-1][:,None,None,None]
            return mean, std, pred_noise
        
    
        # Forward and Reverse Trajectory:
    
        @torch.no_grad()
        def complete_trajectory(self, x_0):
            '''
            Takes a batch of images and applies both trajectories sequentially, i.e. first adds noise to all 
            images along the forward chain and later removes the noise with the reverse chain.
            This function will be used in the evaluation pipeline as a means to evaluate its performance on 
            how well it is able to reconstruct/recover the training images after applying the forward trajectory.
    
            Parameters:
            x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
    
            Returns:
            x_0_recon (tensor): Batch of images given by the model reconstruction of x_0
            '''
            # apply forward trajectory
            x_0_recon, _ = self.forward_trajectory(x_0)
            # apply reverse trajectory
            for t in reversed(range(1, self.diffusion_steps + 1)):
                # draw noise used in the denoising dist. reparametrization
                if t > 1:
                    noise = torch.randn(x_0_recon.shape, device=self.device)
                else:
                    noise = torch.zeros(x_0_recon.shape, device=self.device)
                # get denoising dist. param
                mean, std, _ = self.forward(x_0_recon, torch.full((x_0_recon.shape[0],), t ,device = self.device))
                # compute the drawn denoised latent at time t
                x_0_recon = mean + std * noise
            return x_0_recon 
          
    
        # Sampling Functions:
    
        @torch.no_grad()
        def sample(self, batch_size = 10,  x_T=None):
            '''
            Samples batch_size images by passing a batch of randomly drawn noise parameters through the complete 
            reverse trajectory. The last denoising step is deterministic as suggested by the paper 
            "Denoising Diffusion Probabilistic Models" by Ho et al.
    
            Parameters:
            batch_size (int): Number of images to be sampled/generated from the diffusion model 
            x_T     (tensor): Input of the reverse trajectory. Batch of noised images usually drawn 
                              from an isotropic Gaussian, but can be set manually if desired.
    
            Returns:
            x_t_1 (tensor): Batch of sampled/generated images
            '''
            # start with a batch of isotropic noise images (or given arguemnt)
            if x_T:
                x_t_1 = x_T
            else:
                x_t_1 = torch.randn((batch_size,)+tuple(self.out_shape), device=self.device)
            # apply reverse trajectory
            for t in reversed(range(1, self.diffusion_steps+1)):
                # draw noise used in the denoising dist. reparametrization
                if t>1:
                    noise = torch.randn(x_t_1.shape, device=self.device)
                else:
                    noise = torch.zeros(x_t_1.shape, device=self.device)
                # get denoising dist. param
                mean, std, _ = self.forward(x_t_1, torch.full((x_t_1.shape[0],), t ,device = self.device))
                # compute the drawn densoined latent at time t
                x_t_1 = mean + std*noise
            return x_t_1
    
        @torch.no_grad()
        def sample_intermediates_latents(self):
            '''
            Samples a single image and provides all intermediate denoised images that were drawn along the reverse 
            trajectory. The last denoising step is deterministic as suggested by the paper "Denoising Diffusion 
            Probabilistic Models" by Ho et al.
            
            Returns:
            x (tensor): Contains the self.diffusion_steps+1 denoised image tensors
            '''
            # start with an image of pure noise (batch_size 1) and store it as part of the output 
            x_t_1 = torch.randn((1,) + tuple(self.out_shape), device=self.device)
            x = torch.empty((self.diffusion_steps+1,) + tuple(self.out_shape), device=self.device)
            x[-1] = x_t_1.squeeze(0)
            # apply reverse trajectory
            for t in reversed(range(1, self.diffusion_steps+1)):
                # draw noise used in the denoising dist. reparametrization
                if t>1:
                    noise = torch.randn(x_t_1.shape, device=self.device)
                else:
                    noise = torch.zeros(x_t_1.shape, device=self.device)
                # get denoising dist. param
                mean, std, _ = self.forward(x_t_1, torch.full((x_t_1.shape[0],), t ,device = self.device))
                # compute the drawn densoined latent at time t
                x_t_1 = mean + std*noise
                # store noised image
                x[t-1] = x_t_1.squeeze(0)
            #x_sq = x.squeeze(1)
            #return x_sq
            return x
    
    
        # Loss functions
        
        def loss_simplified(self, forward_noise, pred_noise, t=None):
            '''
            Returns the Mean Squared Error (MSE) between the forward_noise used to compute the noised images x_t 
            along the forward trajectory and the predicted noise computed by the U-Net with the noised images 
            x_t and timestep t.
            '''
            return F.mse_loss(forward_noise, pred_noise)
        
        
        def loss_weighted(self, forward_noise, pred_noise, t):
            '''
            Returns the mathematically correct weighted version of the simplified loss.
            '''
            return self.mse_weight[t-1][:,None,None,None]*F.mse_loss(forward_noise, pred_noise)
        
        
        # If t=0 and self.recon_loss == 'nll'
        def loss_recon(self, x_0, mean_1, std_1):
            '''
            Returns the reconstruction loss given by the mean negative log-likelihood of x_0 under the last 
            denoising Gaussian distribution with mean mean_1 and standard deviation std_1.
            '''
            return -torch.distributions.Normal(mean_1, std_1).log_prob(x_0).mean()