# Framework.py
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F


class DDPM(nn.Module):
def __init__(self,
net=None,
diffusion_steps = 50,
out_shape = (3,32,32),
noise_schedule = 'linear',
beta_1 = 1e-4,
beta_T = 0.02,
alpha_bar_lower_bound = 0.9,
var_schedule = 'same',
kl_loss = 'simplified',
recon_loss = 'none',
device=None):
'''
net: U-Net
diffusion_steps: Length of the Markov chain
        out_shape: Shape of the model's in- and output images
        noise_schedule: Method used to initialize the noise dist. variances, 'linear', 'cosine' or 'cosine_bounded'
        beta_1, beta_T: Variances of the first and last noise dist. (only for the 'linear' noise schedule)
        alpha_bar_lower_bound: Lower bound for alpha_bar, i.e. an upper bound on the variance of the complete noise dist. (only for the 'cosine_bounded' noise schedule)
        var_schedule: Choice of the denoising dist. variances, 'same' (sigma_t^2 = beta_t) or 'true' (true posterior variance)
        kl_loss: Choice between the mathematically exact 'weighted' KL loss and the 'simplified' loss most commonly used in practice
        recon_loss: 'none' to ignore the reconstruction loss or 'nll' to compute the negative log-likelihood of x_0
'''
        super().__init__()
self.device = device
# initialize the beta's, alpha's and alpha_bar's for the given noise schedule
if noise_schedule == 'linear':
beta, alpha, alpha_bar = self.linear_schedule(diffusion_steps, beta_1, beta_T, device=self.device)
elif noise_schedule == 'cosine':
beta, alpha, alpha_bar = self.cosine_schedule(diffusion_steps, device=self.device)
elif noise_schedule == 'cosine_bounded':
beta, alpha, alpha_bar = self.bounded_cosine_schedule(diffusion_steps, alpha_bar_lower_bound, device=self.device)
else:
raise ValueError('Unimplemented noise scheduler')
        # initialize the denoising variances for the given variance schedule
if var_schedule == 'same':
var = beta
elif var_schedule == 'true':
            var = torch.stack([beta[0]] + [((1-alpha_bar[t-1])/(1-alpha_bar[t]))*beta[t] for t in range(1, diffusion_steps)])
else:
raise ValueError('Unimplemented variance scheduler')
# check for invalid kl_loss argument
        if kl_loss not in ('simplified', 'weighted'):
raise ValueError("Unimplemented loss function")
self.net = net
self.diffusion_steps = diffusion_steps
self.noise_schedule = noise_schedule
self.var_schedule = var_schedule
self.beta = beta
self.alpha = alpha
self.alpha_bar = alpha_bar
self.sqrt_1_minus_alpha_bar = torch.sqrt(1-alpha_bar) # for forward std
self.sqrt_alpha_bar = torch.sqrt(alpha_bar) # for forward mean
self.var = var
self.std = torch.sqrt(self.var)
self.kl_loss = kl_loss
self.recon_loss = recon_loss
self.out_shape = out_shape
# precomputed for efficiency reasons
        self.noise_scaler = (1-alpha)/self.sqrt_1_minus_alpha_bar
self.mean_scaler = 1/torch.sqrt(self.alpha)
self.mse_weight = (self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar))
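        # Reference (Ho et al., "Denoising Diffusion Probabilistic Models"): the scalers above implement
        # the denoising mean
        #   mu_theta(x_t, t) = mean_scaler[t] * (x_t - noise_scaler[t] * eps_theta(x_t, t))
        #                    = 1/sqrt(alpha_t) * (x_t - (1-alpha_t)/sqrt(1-alpha_bar_t) * eps_theta(x_t, t)),
        # while mse_weight[t] = beta_t^2 / (2 * sigma_t^2 * alpha_t * (1 - alpha_bar_t)) is the per-timestep
        # weight of the weighted KL loss (see loss_weighted below).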
@staticmethod
def linear_schedule(diffusion_steps, beta_1, beta_T, device):
        '''
Function that returns the noise distribution hyperparameters for the linear schedule.
Parameters:
diffusion_steps (int): Length of the Markov chain.
beta_1 (float): Variance of the first noise distribution.
beta_T (float): Variance of the last noise distribution.
Returns:
beta (tensor): Linearly scaled from beta[0] = beta_1 to beta[-1] = beta_T, length is diffusion_steps.
alpha (tensor): Length is diffusion_steps.
alpha_bar (tensor): Length is diffusion_steps.
'''
beta = torch.linspace(beta_1, beta_T, diffusion_steps,device=device)
alpha = 1 - beta
alpha_bar = torch.cumprod(alpha, dim=0)
return beta, alpha, alpha_bar
@staticmethod
def cosine_schedule(diffusion_steps, device):
'''
Function that returns the noise distribution hyperparameters for the cosine schedule.
From "Improved Denoising Diffusion Probabilistic Models" by Nichol and Dhariwal.
Parameters:
diffusion_steps (int): Length of the Markov chain.
Returns:
beta (tensor): Length is diffusion_steps.
alpha (tensor): Length is diffusion_steps.
alpha_bar (tensor): Follows a sigmoid-like curve with a linear drop-off in the middle.
Length is diffusion_steps.
'''
        cosine_0 = DDPM.cosine(0, diffusion_steps=diffusion_steps)
        alpha_bar = [DDPM.cosine(t, diffusion_steps=diffusion_steps)/cosine_0
                     for t in range(1, diffusion_steps+1)]
        shift = [1] + alpha_bar[:-1]
        beta = 1 - torch.div(torch.tensor(alpha_bar, dtype=torch.float32, device=device),
                             torch.tensor(shift, dtype=torch.float32, device=device))
        beta = torch.clamp(beta, min=0, max=0.999)  # clipping suggested by the paper
        alpha = 1 - beta
        alpha_bar = torch.tensor(alpha_bar, dtype=torch.float32, device=device)
return beta, alpha, alpha_bar
@staticmethod
def bounded_cosine_schedule(diffusion_steps, alpha_bar_lower_bound, device):
'''
Function that returns the noise distribution hyperparameters for our experimental version of a
bounded cosine schedule. Benefits are still unproven. It still has a linear drop-off in alpha_bar,
but it's not sigmoidal and the betas are no longer smooth.
Parameters:
diffusion_steps (int): Length of the Markov chain
Returns:
beta (tensor): Length is diffusion_steps
alpha (tensor): Length is diffusion_steps
alpha_bar (tensor): Bounded between (alpha_bar_lower_bound, 1) with a linear drop-off in the middle.
Length is diffusion_steps
'''
# get cosine alpha_bar (that range from 1 to 0)
_, _, alpha_bar = DDPM.cosine_schedule(diffusion_steps, device)
        # min-max normalize alpha_bar, then rescale it to the range (alpha_bar_lower_bound, 0.9999)
        min_val = torch.min(alpha_bar)
        max_val = torch.max(alpha_bar)
        alpha_bar = (alpha_bar - min_val) / (max_val - min_val)
        alpha_bar = alpha_bar * (0.9999 - alpha_bar_lower_bound) + alpha_bar_lower_bound  # 0.9999 so that beta_1 ~ 1e-4
# recompute beta, alpha and alpha_bar
alpha_bar = alpha_bar.tolist()
shift = [1] + alpha_bar[:-1]
        beta = 1 - torch.div(torch.tensor(alpha_bar, device=device), torch.tensor(shift, device=device))
        beta = torch.clamp(beta, min=0, max=0.999)
        beta, _ = torch.sort(beta)
alpha = 1 - beta
alpha_bar = torch.cumprod(alpha, dim=0)
return beta, alpha, alpha_bar
@staticmethod
def cosine(t, diffusion_steps, s = 0.008):
'''
Helper function that computes the cosine function from "Improved Denoising Diffusion Probabilistic Models"
by Nichol and Dhariwal, used for the cosine noise schedules.
Parameters:
t (int): Current timestep
diffusion_steps (int): Length of the Markov chain
s (float): Offset value suggested by the paper. Should be chosen such that sqrt(beta[0]) ~ 1/127.5
(for small T=50, this is not possible)
Returns:
(numpy.float64): Value of the cosine function at timestep t
'''
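        # f(t) = cos^2( ((t/T + s) / (1 + s)) * pi/2 );  the schedule then uses alpha_bar_t = f(t) / f(0)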
return (np.cos((((t/diffusion_steps)+s)*np.pi)/((1+s)*2)))**2
####
# Important to note: Timesteps are adjusted to the range t in [1, diffusion_steps] akin to the paper
# equations, where x_0 denotes the input image and x_t the noised latent after adding noise t times.
# Both trajectories work on batches assuming shape=(batch_size, channels, height, width).
####
# Forward Trajectory Functions:
@torch.no_grad()
def forward_trajectory(self, x_0, t = None):
'''
Applies noise t times to each input image in the batch x_0.
Parameters:
x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
t (tensor): Batch of timesteps, by default goes through full forward trajectory
Returns:
x_T (tensor): Batch of noised images at timestep t
forward_noise (tensor): Batch of noise parameters from the noise distribution reparametrization used to draw x_T
'''
if t is None:
t = torch.full((x_0.shape[0],), self.diffusion_steps, device = self.device)
elif torch.any(t == 0):
raise ValueError("The tensor 't' contains a timestep zero.")
forward_noise = torch.randn(x_0.shape, device = self.device)
x_T = self.noised_latent(forward_noise, x_0, t)
return x_T , forward_noise
@torch.no_grad()
def noised_latent(self, forward_noise, x_0, t):
'''
Given a batch of noise parameters, this function recomputes the batch of noised images at their respective timesteps t.
This allows us to avoid storing all the intermediate latents x_t along the forward trajectory.
Parameters:
forward_noise (tensor): Batch of noise parameters from the noise distribution reparametrization used to draw x_t
x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
t (tensor): Batch of timesteps
Returns:
x_t (tensor): Batch of noised images at timestep t
'''
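        # reparametrization of q(x_t | x_0):  x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)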
mean, std = self.forward_dist_param(x_0, t)
x_t = mean + std*forward_noise
return x_t
@torch.no_grad()
def forward_dist_param(self, x_0, t):
'''
Computes the parameters of the complete noise distribution.
Parameters:
x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
t (tensor): Batch of timesteps
Returns:
mean (tensor): Batch of means for the complete noise distribution for each image in the batch x_0
std (tensor): Batch of std scalars for the complete noise distribution for each image in the batch x_0
'''
mean = self.sqrt_alpha_bar[t-1][:,None,None,None]*x_0
std = self.sqrt_1_minus_alpha_bar[t-1][:,None,None,None]
return mean, std
@torch.no_grad()
def single_forward_dist_param(self, x_t_1, t):
'''
Computes the parameters of the individual noise distribution.
Parameters:
x_t_1 (tensor): Batch of noised images at timestep t-1
t (tensor): Batch of timesteps
Returns:
mean (tensor): Batch of means for the individual noise distribution for each image in the batch x_t_1
std (tensor): Batch of std scalars for the individual noise distribution for each image in the batch x_t_1
'''
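        # single-step forward distribution:  q(x_t | x_{t-1}) = N( sqrt(1 - beta_t) * x_{t-1},  beta_t * I )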
mean = torch.sqrt(1-self.beta[t-1])[:,None,None,None]*x_t_1
std = torch.sqrt(self.beta[t-1])[:,None,None,None]
return mean, std
# Reverse Trajectory Functions:
def reverse_trajectory(self, x_t, t):
'''
        Draws denoised images x_{t-1} by reparametrizing the denoising distribution at timestep t for the
        current noised latents x_t.
Parameters:
x_t (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
            t (tensor): Batch of timesteps
Returns:
x_t_1 (tensor): Batch of denoised images at timestep t-1
'''
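        # one reverse step:  x_{t-1} = mu_theta(x_t, t) + sigma_t * z,  z ~ N(0, I)  (cf. Algorithm 2 in Ho et al., 2020)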
noise = torch.randn(x_t.shape, device=self.device)
        mean, std, _ = self.forward(x_t, t)
x_t_1 = mean + std*noise
return x_t_1
def forward(self, x_t, t):
'''
Passes the current noised images x_t and timesteps t through the U-Net in order to compute the
predicted noise, which is later used to determine the current denoising distribution parameters
(mean and std) in the reverse trajectory.
        Since the DDPM class inherits from nn.Module, this method is required to be named 'forward'.
        The name does not refer to the forward trajectory, but to the forward pass of the model itself,
        which here corresponds to the reverse trajectory.
Parameters:
x_t (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
t (tensor): Batch of timesteps
Returns:
mean (tensor): Batch of means for the complete noise dist. for each image in the batch x_t
std (tensor): Batch of std scalars for the complete noise dist. for each image in the batch x_t
pred_noise (tensor): Predicted noise for each image in the batch x_t
'''
pred_noise = self.net(x_t,t,return_dict=False)[0]
mean = self.mean_scaler[t-1][:,None,None,None]*(x_t - self.noise_scaler[t-1][:,None,None,None]*pred_noise)
std = self.std[t-1][:,None,None,None]
return mean, std, pred_noise
# Forward and Reverse Trajectory:
@torch.no_grad()
def complete_trajectory(self, x_0):
'''
Takes a batch of images and applies both trajectories sequentially, i.e. first adds noise to all
images along the forward chain and later removes the noise with the reverse chain.
        This function is used in the evaluation pipeline to measure how well the model is able to
        reconstruct/recover the training images after the forward trajectory has been applied.
Parameters:
x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
Returns:
x_0_recon (tensor): Batch of images given by the model reconstruction of x_0
'''
# apply forward trajectory
x_0_recon, _ = self.forward_trajectory(x_0)
# apply reverse trajectory
for t in reversed(range(1, self.diffusion_steps + 1)):
# draw noise used in the denoising dist. reparametrization
if t > 1:
noise = torch.randn(x_0_recon.shape, device=self.device)
else:
noise = torch.zeros(x_0_recon.shape, device=self.device)
# get denoising dist. param
            mean, std, _ = self.forward(x_0_recon, torch.full((x_0_recon.shape[0],), t, device=self.device))
# compute the drawn denoised latent at time t
x_0_recon = mean + std * noise
return x_0_recon
# Sampling Functions:
@torch.no_grad()
def sample(self, batch_size = 10, x_T=None):
'''
Samples batch_size images by passing a batch of randomly drawn noise parameters through the complete
reverse trajectory. The last denoising step is deterministic as suggested by the paper
"Denoising Diffusion Probabilistic Models" by Ho et al.
Parameters:
batch_size (int): Number of images to be sampled/generated from the diffusion model
x_T (tensor): Input of the reverse trajectory. Batch of noised images usually drawn
from an isotropic Gaussian, but can be set manually if desired.
Returns:
x_t_1 (tensor): Batch of sampled/generated images
'''
        # start with a batch of isotropic noise images (or the given argument)
        if x_T is not None:
            x_t_1 = x_T
        else:
            x_t_1 = torch.randn((batch_size,)+tuple(self.out_shape), device=self.device)
# apply reverse trajectory
for t in reversed(range(1, self.diffusion_steps+1)):
# draw noise used in the denoising dist. reparametrization
if t>1:
noise = torch.randn(x_t_1.shape, device=self.device)
else:
noise = torch.zeros(x_t_1.shape, device=self.device)
# get denoising dist. param
            mean, std, _ = self.forward(x_t_1, torch.full((x_t_1.shape[0],), t, device=self.device))
            # compute the drawn denoised latent at time t
            x_t_1 = mean + std*noise
return x_t_1
@torch.no_grad()
def sample_intermediates_latents(self):
'''
Samples a single image and provides all intermediate denoised images that were drawn along the reverse
trajectory. The last denoising step is deterministic as suggested by the paper "Denoising Diffusion
Probabilistic Models" by Ho et al.
Returns:
x (tensor): Contains the self.diffusion_steps+1 denoised image tensors
'''
# start with an image of pure noise (batch_size 1) and store it as part of the output
x_t_1 = torch.randn((1,) + tuple(self.out_shape), device=self.device)
x = torch.empty((self.diffusion_steps+1,) + tuple(self.out_shape), device=self.device)
x[-1] = x_t_1.squeeze(0)
# apply reverse trajectory
for t in reversed(range(1, self.diffusion_steps+1)):
# draw noise used in the denoising dist. reparametrization
if t>1:
noise = torch.randn(x_t_1.shape, device=self.device)
else:
noise = torch.zeros(x_t_1.shape, device=self.device)
# get denoising dist. param
            mean, std, _ = self.forward(x_t_1, torch.full((x_t_1.shape[0],), t, device=self.device))
            # compute the drawn denoised latent at time t
            x_t_1 = mean + std*noise
            # store the intermediate denoised image
            x[t-1] = x_t_1.squeeze(0)
return x
# Loss functions
def loss_simplified(self, forward_noise, pred_noise, t=None):
'''
Returns the Mean Squared Error (MSE) between the forward_noise used to compute the noised images x_t
along the forward trajectory and the predicted noise computed by the U-Net with the noised images
x_t and timestep t.
'''
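        # L_simple = E_{t, x_0, eps}[ || eps - eps_theta(x_t, t) ||^2 ]  (the simplified objective of Ho et al., 2020)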
return F.mse_loss(forward_noise, pred_noise)
def loss_weighted(self, forward_noise, pred_noise, t):
'''
Returns the mathematically correct weighted version of the simplified loss.
'''
        weight = self.mse_weight[t-1][:,None,None,None]
        return (weight * F.mse_loss(forward_noise, pred_noise, reduction='none')).mean()
    # Reconstruction term for the final denoising step (x_1 -> x_0), used when self.recon_loss == 'nll'
def loss_recon(self, x_0, mean_1, std_1):
'''
Returns the reconstruction loss given by the mean negative log-likelihood of x_0 under the last
denoising Gaussian distribution with mean mean_1 and standard deviation std_1.
'''
return -torch.distributions.Normal(mean_1, std_1).log_prob(x_0).mean()
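

# ---------------------------------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): it assumes a noise-prediction network
# with the call signature net(x, t, return_dict=False) -> (pred_noise, ...) expected by DDPM.forward,
# e.g. a diffusers-style U-Net. The ToyNet below is only a stand-in for such a network, and the loop
# is a minimal single training step, not a full training pipeline.
# ---------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    class ToyNet(nn.Module):
        '''Minimal stand-in noise predictor; a real U-Net would also condition on the timestep t.'''
        def __init__(self, channels=3):
            super().__init__()
            self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        def forward(self, x, t, return_dict=False):
            return (self.conv(x),)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = DDPM(net=ToyNet(), diffusion_steps=50, out_shape=(3, 32, 32), device=device).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # one training step on a random batch (stands in for a batch of images normalized to [-1, 1])
    x_0 = torch.rand((8, 3, 32, 32), device=device)*2 - 1
    t = torch.randint(1, model.diffusion_steps + 1, (x_0.shape[0],), device=device)
    x_t, forward_noise = model.forward_trajectory(x_0, t)
    _, _, pred_noise = model(x_t, t)
    loss = model.loss_simplified(forward_noise, pred_noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # generate a small batch of samples with the (untrained) model
    samples = model.sample(batch_size=4)
    print(loss.item(), samples.shape)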