Skip to content
Snippets Groups Projects
Commit 38147ff1 authored by Gonzalo Martin Garcia's avatar Gonzalo Martin Garcia
Browse files

EMA training added. Fixed some minor bugs.

parent cf57c965
No related branches found
No related tags found
No related merge requests found
......@@ -36,7 +36,7 @@ def train_func(f):
#model = globals()[meta_setting["modelname"]](**model_setting).to(device)
#net = torch.compile(model)
net = UNet2DModel(
sample_size=128,
sample_size=64,
in_channels=3,
out_channels=3,
layers_per_block=2,
......
......@@ -68,9 +68,9 @@ class DDPM(nn.Module):
self.recon_loss = recon_loss
self.out_shape = out_shape
# precomputed for efficiency reasons
self.noise_scaler = ((1-alpha)/( self.sqrt_1_minus_alpha_bar))
self.mean_scaler = (1/torch.sqrt(self.alpha))
self.mse_weight = ((self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar)))
self.noise_scaler = (1-alpha)/( self.sqrt_1_minus_alpha_bar)
self.mean_scaler = 1/torch.sqrt(self.alpha)
self.mse_weight = (self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar))
@staticmethod
def linear_schedule(diffusion_steps, beta_1, beta_T, device):
......@@ -183,14 +183,14 @@ class DDPM(nn.Module):
Parameters:
x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
t (int): Timestep, by default goes through full forward trajectory
t (tensor): Batch of timesteps, by default goes through full forward trajectory
Returns:
x_T (tensor): Batch of noised images at timestep t
forward_noise (tensor): Batch of noise parameters from the noise distribution reparametrization used to draw x_T
'''
if t is None:
t = self.diffusion_steps
t = torch.full((x_0.shape[0],), self.diffusion_steps, device = self.device)
elif torch.any(t == 0):
raise ValueError("The tensor 't' contains a timestep zero.")
forward_noise = torch.randn(x_0.shape, device = self.device)
......@@ -200,13 +200,13 @@ class DDPM(nn.Module):
@torch.no_grad()
def noised_latent(self, forward_noise, x_0, t):
'''
Given a batch of noise parameters, this function recomputes the batch of noised images at timestep t.
Given a batch of noise parameters, this function recomputes the batch of noised images at their respective timesteps t.
This allows us to avoid storing all the intermediate latents x_t along the forward trajectory.
Parameters:
forward_noise (tensor): Batch of noise parameters from the noise distribution reparametrization used to draw x_t
x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
t (int): Timestep
t (tensor): Batch of timesteps
Returns:
x_t (tensor): Batch of noised images at timestep t
......@@ -222,14 +222,14 @@ class DDPM(nn.Module):
Parameters:
x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
t (int): Timestep
t (tensor): Batch of timesteps
Returns:
mean (tensor): Batch of means for the complete noise distribution for each image in the batch x_0
std (tensor): Batch of std scalars for the complete noise distribution for each image in the batch x_0
'''
mean = self.sqrt_alpha_bar[t-1].view(-1, 1, 1, 1)*x_0
std = self.sqrt_1_minus_alpha_bar[t-1].view(-1, 1, 1, 1)
mean = self.sqrt_alpha_bar[t-1][:,None,None,None]*x_0
std = self.sqrt_1_minus_alpha_bar[t-1][:,None,None,None]
return mean, std
@torch.no_grad()
......@@ -239,14 +239,14 @@ class DDPM(nn.Module):
Parameters:
x_t_1 (tensor): Batch of noised images at timestep t-1
t (int): Timestep
t (tensor): Batch of timesteps
Returns:
mean (tensor): Batch of means for the individual noise distribution for each image in the batch x_t_1
std (tensor): Batch of std scalars for the individual noise distribution for each image in the batch x_t_1
'''
mean = torch.sqrt(1-self.beta[t-1]).view(-1, 1, 1, 1)*x_t_1
std = torch.sqrt(self.beta[t-1]).view(-1, 1, 1, 1)
mean = torch.sqrt(1-self.beta[t-1])[:,None,None,None]*x_t_1
std = torch.sqrt(self.beta[t-1])[:,None,None,None]
return mean, std
......@@ -254,12 +254,12 @@ class DDPM(nn.Module):
def reverse_trajectory(self, x_t, t):
'''
Draws a denoised image x_{t-1} by reparametrizing the denoising distribution at time t for the current noised
latent x_t.
Draws denoised images x_{t-1} by reparametrizing the denoising distribution at times t for the current noised
latents x_t.
Parameters:
x_t (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
t (int): Timestep
t (tensor): Batch of timesteps
Returns:
x_t_1 (tensor): Batch of denoised images at timestep t-1
......@@ -271,16 +271,16 @@ class DDPM(nn.Module):
def forward(self, x_t, t):
'''
Passes the current noised image x_t and timestep t through the U-Net in order to compute the
predicted noise, which is later used to determine the current denoising distribution parameters in the
reverse trajectory.
Passes the current noised images x_t and timesteps t through the U-Net in order to compute the
predicted noise, which is later used to determine the current denoising distribution parameters
(mean and std) in the reverse trajectory.
Since the DDPM class is inheriting from the nn.Module class, this function is required to share
the name 'forward'. This naming scheme does not refer to the forward trajectory, but the forward
pass of the model itself, which concerns the reverse trajectory.
Parameters:
x_t (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
t (int): Timestep
t (tensor): Batch of timesteps
Returns:
mean (tensor): Batch of means for the complete noise dist. for each image in the batch x_t
......@@ -288,8 +288,8 @@ class DDPM(nn.Module):
pred_noise (tensor): Predicted noise for each image in the batch x_t
'''
pred_noise = self.net(x_t,t,return_dict=False)[0]
mean = self.mean_scaler[t-1].view(-1, 1, 1, 1)*(x_t - self.noise_scaler[t-1].view(-1, 1, 1, 1)*pred_noise)
std = self.std[t-1].view(-1, 1, 1, 1)
mean = self.mean_scaler[t-1][:,None,None,None]*(x_t - self.noise_scaler[t-1][:,None,None,None]*pred_noise)
std = self.std[t-1][:,None,None,None]
return mean, std, pred_noise
......@@ -319,7 +319,7 @@ class DDPM(nn.Module):
else:
noise = torch.zeros(x_0_recon.shape, device=self.device)
# get denoising dist. param
mean, std, _ = self.forward(x_0_recon, t)
mean, std, _ = self.forward(x_0_recon, torch.full((x_0_recon.shape[0],), t ,device = self.device))
# compute the drawn denoised latent at time t
x_0_recon = mean + std * noise
return x_0_recon
......@@ -355,7 +355,7 @@ class DDPM(nn.Module):
else:
noise = torch.zeros(x_t_1.shape, device=self.device)
# get denoising dist. param
mean, std, _ = self.forward(x_t_1, t)
mean, std, _ = self.forward(x_t_1, torch.full((x_t_1.shape[0],), t ,device = self.device))
# compute the drawn denoised latent at time t
x_t_1 = mean + std*noise
return x_t_1
......@@ -372,8 +372,8 @@ class DDPM(nn.Module):
'''
# start with an image of pure noise (batch_size 1) and store it as part of the output
x_t_1 = torch.randn((1,) + tuple(self.out_shape), device=self.device)
x = torch.empty((self.diffusion_steps+1,1,) + tuple(self.out_shape), device=self.device)
x[-1] = x_t_1
x = torch.empty((self.diffusion_steps+1,) + tuple(self.out_shape), device=self.device)
x[-1] = x_t_1.squeeze(0)
# apply reverse trajectory
for t in reversed(range(1, self.diffusion_steps+1)):
# draw noise used in the denoising dist. reparametrization
......@@ -382,14 +382,14 @@ class DDPM(nn.Module):
else:
noise = torch.zeros(x_t_1.shape, device=self.device)
# get denoising dist. param
mean, std, _ = self.forward(x_t_1, t)
mean, std, _ = self.forward(x_t_1, torch.full((x_t_1.shape[0],), t ,device = self.device))
# compute the drawn denoised latent at time t
x_t_1 = mean + std*noise
# store noised image
x[t-1] = x_t_1
x_sq = x.squeeze(1)
return x_sq
#return x
x[t-1] = x_t_1.squeeze(0)
#x_sq = x.squeeze(1)
#return x_sq
return x
# Loss functions
......@@ -407,7 +407,7 @@ class DDPM(nn.Module):
'''
Returns the mathematically correct weighted version of the simplified loss.
'''
return self.mse_weight[t-1].view(-1, 1, 1, 1)*F.mse_loss(forward_noise, pred_noise)
return self.mse_weight[t-1][:,None,None,None]*F.mse_loss(forward_noise, pred_noise)
# If t=0 and self.recon_loss == 'nll'
......
import numpy as np
import copy
import torch
from torch import nn
from torchvision import datasets,transforms
......@@ -10,6 +10,7 @@ import numpy as np
import torch.nn.functional as F
import os
import wandb
from copy import deepcopy
device = 'cuda' if torch.cuda.is_available() else 'cpu'
......@@ -53,6 +54,37 @@ def simple_trainer(model,device,epochs,trainloader,testloader,bs,lr,T,criterion
print(f"Testloss in step {epoch} :{np.mean(running_testloss)}")
# EMA class
# Important! This EMA class code is not ours; it was taken from the PyTorch Image Models library (timm). It performs an exponential moving
# average over the trained weights of a given model's neural net, as suggested by the paper "Improved Denoising Diffusion Probabilistic Models"
# by Nichol and Dhariwal, to stabilize and improve the training and generalization process.
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/model_ema.py
class ModelEmaV2(nn.Module):
    '''
    Keeps an exponential moving average (EMA) copy of a model's state_dict.

    A deep copy of the given model is stored in 'self.module' and switched to eval mode. Calling 'update'
    after each optimizer step blends the current values of the trained model into that copy.

    Parameters:
        model (nn.Module): Model whose state_dict entries are tracked
        decay (float): Weight given to the current EMA value in every update; the trained model's value
                       is weighted with (1 - decay)
        device: If set, the EMA copy is moved to this device and the trained model's values are moved
                there before being blended; if None, no device transfer is performed
    '''

    def __init__(self, model, decay=0.9999, device=None):
        super().__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        '''
        Applies 'update_fn' entry-wise to the state_dict of the EMA copy and of the given model.

        Parameters:
            model (nn.Module): Model providing the new values (entries are paired with the EMA copy by order)
            update_fn (callable): Maps (ema_value, model_value) to the new EMA value
        '''
        with torch.no_grad():
            ema_values = self.module.state_dict().values()
            model_values = model.state_dict().values()
            for ema_v, model_v in zip(ema_values, model_values):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point() or ema_v.is_complex():
                    ema_v.copy_(update_fn(ema_v, model_v))
                else:
                    # Integer/bool buffers (e.g. BatchNorm's num_batches_tracked) must not be averaged:
                    # the float blend would be truncated when written back, so with a decay close to 1
                    # the entry would stay stuck at its old value. Mirror the model's value instead.
                    ema_v.copy_(model_v)

    def update(self, model):
        '''
        Performs one EMA step: ema = decay * ema + (1 - decay) * model for every floating-point entry.

        Parameters:
            model (nn.Module): Trained model providing the current values
        '''
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        '''
        Overwrites the EMA copy with the current values of the given model.

        Parameters:
            model (nn.Module): Model whose values are copied into the EMA copy
        '''
        self._update(model, update_fn=lambda e, m: m)
# Training function for the unconditional diffusion model
def ddpm_trainer(model,
......@@ -70,6 +102,8 @@ def ddpm_trainer(model,
experiment_path = None,
T_max = 5*10000, # None,
eta_min= 1e-5,
ema_training = True,
decay = 0.9999,
**args
):
'''
......@@ -89,6 +123,8 @@ def ddpm_trainer(model,
checkpoint: Name of the saved pth. file containing the trained weights and biases
T_max: CosineAnnealingLR scheduler argument (nr of steps in training for a full cycle)
eta_min: CosineAnnealingLR scheduler argument (scheduler oscillates between highest lr 'learning_rate' and minimum lr 'eta_min')
decay: EMA decay rate that is used to weight the effect of the ema model when computing the weighted avg between trained and
ema weights for the networks weight update
'''
# set optimizer parameters and learning rate
......@@ -133,13 +169,17 @@ def ddpm_trainer(model,
if model.recon_loss == 'nll':
low = 0
# EMA
if ema_training:
ema = ModelEmaV2(model, decay=decay, device = model.device)
# Using W&B
with wandb.init(project='test-project', name=run_name, entity='gonzalomartingarcia0', id=run_name, resume=True) as run:
# Log some info
run.config.learning_rate = learning_rate
run.config.optimizer = optimizer.__class__.__name__
run.watch(model.net)
#run.watch(model.net)
# training loop
# last model was stored at epoch last_epoch, we continue training from there, i.e. last_epoch+1 (else we start at epoch 0)
......@@ -178,20 +218,14 @@ def ddpm_trainer(model,
loss.backward()
optimizer.step()
if ema_training:
ema.update(model)
scheduler.step()
if verbose:
print(f"Loss in epoch {epoch}:{running_trainloss/nr_train_batches}")
run.log({'running_loss': running_trainloss/nr_train_batches})
# WORKING OLD VERSION
#x_t, forward_noise = model.forward_trajectory(x_0,t)
#_, _, pred_noise = model.forward(x_t,t)
#loss = loss_func(forward_noise,pred_noise,t)
#running_trainloss += loss.item()
#nr_train_batches += 1
#run.log({'loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})
# evaluation
if ((epoch+1) % eval_iter == 0) or ((epoch+1) % store_iter == 0):
running_testloss = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment