Skip to content
Snippets Groups Projects
Commit 268080df authored by Gonzalo Martin Garcia's avatar Gonzalo Martin Garcia
Browse files

Added CosineAnnealingLR Schedule to the training function. Adapted the...

Added CosineAnnealingLR Schedule to the training function. Adapted the diffusion model class to handle batches of timesteps and updated the training loop to draw a batch t instad of a single integer per batch.
parent 89776614
Branches
No related tags found
No related merge requests found
%% Cell type:code id: tags:
``` python
from trainer.train import *
from dataloader.load import *
from models.Framework import *
from models.unet_unconditional_diffusion import *
import torch
from torch import nn
```
%% Cell type:markdown id: tags:
# Prepare experiment
1. Choose Hyperparameter Settings
2. Run notebook on local maschine to generate experiment folder with the JSON files containing the settings
3. scp experiment folder to the HPC
4. Run Pipeline by adding following to batch file:
- Train Model: &emsp;&emsp;&emsp;&emsp;&emsp; `python main.py train "<absolute path of experiment folder in hpc>"`
- Sample Images: &emsp;&emsp;&emsp; `python main.py sample "<absolute path of experiment folder in hpc>"`
- Evaluate Model: &emsp;&emsp;&emsp; `python main.py evaluate "<absolute path of experiment folder in hpc>"`
%% Cell type:code id: tags:
``` python
import torch
####
# Settings
####
# Dataset path
datapath = "/work/lect0100/lhq_256"
# Experiment setup
run_name = 'big_diffusers_fin' # WANDB and experiment folder Name!
checkpoint = 'model_epoch_8.pth' # Name of checkpoint pth file or None
run_name = 'batch_timesteps' # WANDB and experiment folder Name!
checkpoint = None #'model_epoch_8.pth' # Name of checkpoint pth file or None
experiment_path = '/work/lect0100/experiments_gonzalo/'+ run_name +'/'
# Path to save generated experiment folder on local machine
local_path ="/Users/gonzalo/Desktop/" + run_name + '/settings'
# Diffusion Model Settings
diffusion_steps = 200
image_size = 128
image_size = 64
channels = 3
# Training
batchsize = 8
batchsize = 32
epochs = 30
store_iter = 3
store_iter = 1
eval_iter = 500
learning_rate = 0.0001
lr_schedule = False
optimizername = "torch.optim.AdamW"
optimizer_params = None
verbose = True
# checkpoint = None #(If no checkpoint training, ie. random weights)
# Sampling
sample_size = 10
intermediate = False # True if you want to sample one image and all ist intermediate latents
# Evaluating
...
###
# Advanced Settings Dictionaries
###
meta_setting = dict(modelname = "UNet_Unconditional_Diffusion_Bottleneck_Variant",
dataset = "UnconditionalDataset",
framework = "DDPM",
trainloop_function = "ddpm_trainer",
sampling_function = 'ddpm_sampler',
evaluation_function = 'ddpm_evaluator',
batchsize = batchsize
)
dataset_setting = dict(fpath = datapath,
img_size = image_size,
frac =0.8,
skip_first_n = 0,
ext = ".png",
transform=True
)
model_setting = dict( channels_in=channels,
channels_out =channels ,
activation='relu', # activation function. Options: {'relu', 'leakyrelu', 'selu', 'gelu', 'silu'/'swish'}
weight_init='he', # weight initialization. Options: {'he', 'torch'}
projection_features=64, # number of image features after first convolution layer
time_dim=batchsize, #dont chnage!!!
time_channels=diffusion_steps, # number of time channels #TODO same as diffusion steps?
num_stages=4, # number of stages in contracting/expansive path
stage_list=None, # specify number of features produced by stages
num_blocks=1, # number of ConvResBlock in each contracting/expansive path
num_groupnorm_groups=32, # number of groups used in Group Normalization inside a ConvResBlock
dropout=0.1, # drop-out to be applied inside a ConvResBlock
attention_list=None, # specify MHA pattern across stages
num_attention_heads=1,
)
framework_setting = dict(
diffusion_steps = diffusion_steps, # dont change!!
out_shape = (channels,image_size,image_size), # dont change!!
noise_schedule = 'linear',
beta_1 = 1e-4,
beta_T = 0.02,
alpha_bar_lower_bound = 0.9,
var_schedule = 'same',
kl_loss = 'simplified',
recon_loss = 'nll',
)
training_setting = dict(
epochs = epochs,
store_iter = store_iter,
eval_iter = eval_iter,
optimizer_class=optimizername,
optimizer_params = optimizer_params,
#optimizer_params=dict(lr=learning_rate), # don't change!
learning_rate = learning_rate,
lr_schedule = lr_schedule,
run_name=run_name,
checkpoint= checkpoint,
experiment_path = experiment_path,
verbose = verbose,
T_max = 5*10000, # cosine lr param
eta_min= 1e-5, # cosine lr param
)
sampling_setting = dict(
checkpoint = checkpoint,
experiment_path = experiment_path,
batch_size = sample_size,
intermediate = intermediate
)
# TODO
evaluation_setting = dict(
checkpoint = checkpoint,
experiment_path = experiment_path,
)
```
%% Cell type:code id: tags:
``` python
import os
import json
f = local_path
if os.path.exists(f):
print("path already exists, pick a new name!")
print("break")
else:
print("create folder")
#os.mkdir(f)
os.makedirs(f, exist_ok=True)
print("folder created ")
with open(f+"/meta_setting.json","w+") as fp:
json.dump(meta_setting,fp)
with open(f+"/dataset_setting.json","w+") as fp:
json.dump(dataset_setting,fp)
with open(f+"/model_setting.json","w+") as fp:
json.dump(model_setting,fp)
with open(f+"/framework_setting.json","w+") as fp:
json.dump(framework_setting,fp)
with open(f+"/training_setting.json","w+") as fp:
json.dump(training_setting,fp)
with open(f+"/sampling_setting.json","w+") as fp:
json.dump(sampling_setting,fp)
with open(f+"/evaluation_setting.json","w+") as fp:
json.dump(evaluation_setting,fp)
print("stored json files in folder")
print(meta_setting)
print(dataset_setting)
print(model_setting)
print(framework_setting)
print(training_setting)
print(sampling_setting)
print(evaluation_setting)
```
%% Output
create folder
folder created
stored json files in folder
{'modelname': 'UNet_Unconditional_Diffusion_Bottleneck_Variant', 'dataset': 'UnconditionalDataset', 'framework': 'DDPM', 'trainloop_function': 'ddpm_trainer', 'sampling_function': 'ddpm_sampler', 'evaluation_function': 'ddpm_evaluator', 'batchsize': 8}
{'fpath': '/work/lect0100/lhq_256', 'img_size': 128, 'frac': 0.8, 'skip_first_n': 0, 'ext': '.png', 'transform': True}
{'channels_in': 3, 'channels_out': 3, 'activation': 'relu', 'weight_init': 'he', 'projection_features': 64, 'time_dim': 8, 'time_channels': 200, 'num_stages': 4, 'stage_list': None, 'num_blocks': 1, 'num_groupnorm_groups': 32, 'dropout': 0.1, 'attention_list': None, 'num_attention_heads': 1}
{'diffusion_steps': 200, 'out_shape': (3, 128, 128), 'noise_schedule': 'linear', 'beta_1': 0.0001, 'beta_T': 0.02, 'alpha_bar_lower_bound': 0.9, 'var_schedule': 'same', 'kl_loss': 'simplified', 'recon_loss': 'nll'}
{'epochs': 30, 'store_iter': 3, 'eval_iter': 500, 'optimizer_class': 'torch.optim.AdamW', 'optimizer_params': None, 'learning_rate': 0.0001, 'lr_schedule': False, 'run_name': 'big_diffusers_fin', 'checkpoint': 'model_epoch_8.pth', 'experiment_path': '/work/lect0100/experiments_gonzalo/big_diffusers_fin/', 'verbose': True}
{'checkpoint': 'model_epoch_8.pth', 'experiment_path': '/work/lect0100/experiments_gonzalo/big_diffusers_fin/', 'batch_size': 10, 'intermediate': False}
{'checkpoint': 'model_epoch_8.pth', 'experiment_path': '/work/lect0100/experiments_gonzalo/big_diffusers_fin/'}
{'modelname': 'UNet_Unconditional_Diffusion_Bottleneck_Variant', 'dataset': 'UnconditionalDataset', 'framework': 'DDPM', 'trainloop_function': 'ddpm_trainer', 'sampling_function': 'ddpm_sampler', 'evaluation_function': 'ddpm_evaluator', 'batchsize': 32}
{'fpath': '/work/lect0100/lhq_256', 'img_size': 64, 'frac': 0.8, 'skip_first_n': 0, 'ext': '.png', 'transform': True}
{'channels_in': 3, 'channels_out': 3, 'activation': 'relu', 'weight_init': 'he', 'projection_features': 64, 'time_dim': 32, 'time_channels': 200, 'num_stages': 4, 'stage_list': None, 'num_blocks': 1, 'num_groupnorm_groups': 32, 'dropout': 0.1, 'attention_list': None, 'num_attention_heads': 1}
{'diffusion_steps': 200, 'out_shape': (3, 64, 64), 'noise_schedule': 'linear', 'beta_1': 0.0001, 'beta_T': 0.02, 'alpha_bar_lower_bound': 0.9, 'var_schedule': 'same', 'kl_loss': 'simplified', 'recon_loss': 'nll'}
{'epochs': 30, 'store_iter': 1, 'eval_iter': 500, 'optimizer_class': 'torch.optim.AdamW', 'optimizer_params': None, 'learning_rate': 0.0001, 'run_name': 'batch_timesteps', 'checkpoint': None, 'experiment_path': '/work/lect0100/experiments_gonzalo/batch_timesteps/', 'verbose': True}
{'checkpoint': None, 'experiment_path': '/work/lect0100/experiments_gonzalo/batch_timesteps/', 'batch_size': 10, 'intermediate': False}
{'checkpoint': None, 'experiment_path': '/work/lect0100/experiments_gonzalo/batch_timesteps/'}
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
```
......
......@@ -68,9 +68,9 @@ class DDPM(nn.Module):
self.recon_loss = recon_loss
self.out_shape = out_shape
# precomputed for efficiency reasons
self.noise_scaler = (1-alpha)/( self.sqrt_1_minus_alpha_bar)
self.mean_scaler = 1/torch.sqrt(self.alpha)
self.mse_weight = (self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar))
self.noise_scaler = ((1-alpha)/( self.sqrt_1_minus_alpha_bar))
self.mean_scaler = (1/torch.sqrt(self.alpha))
self.mse_weight = ((self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar)))
@staticmethod
def linear_schedule(diffusion_steps, beta_1, beta_T, device):
......@@ -191,8 +191,8 @@ class DDPM(nn.Module):
'''
if t is None:
t = self.diffusion_steps
elif t == 0:
return x_0, torch.zeros(x_0.shape, device = self.device)
elif torch.any(t == 0):
raise ValueError("The tensor 't' contains a timestep zero.")
forward_noise = torch.randn(x_0.shape, device = self.device)
x_T = self.noised_latent(forward_noise, x_0, t)
return x_T , forward_noise
......@@ -228,8 +228,8 @@ class DDPM(nn.Module):
mean (tensor): Batch of means for the complete noise distribution for each image in the batch x_0
std (tensor): Batch of std scalars for the complete noise distribution for each image in the batch x_0
'''
mean = self.sqrt_alpha_bar[t-1]*x_0
std = self.sqrt_1_minus_alpha_bar[t-1]
mean = self.sqrt_alpha_bar[t-1].view(-1, 1, 1, 1)*x_0
std = self.sqrt_1_minus_alpha_bar[t-1].view(-1, 1, 1, 1)
return mean, std
@torch.no_grad()
......@@ -245,8 +245,8 @@ class DDPM(nn.Module):
mean (tensor): Batch of means for the individual noise distribution for each image in the batch x_t_1
std (tensor): Batch of std scalars for the individual noise distribution for each image in the batch x_t_1
'''
mean = torch.sqrt(1-self.beta[t-1])*x_t_1
std = torch.sqrt(self.beta[t-1])
mean = torch.sqrt(1-self.beta[t-1]).view(-1, 1, 1, 1)*x_t_1
std = torch.sqrt(self.beta[t-1]).view(-1, 1, 1, 1)
return mean, std
......@@ -288,8 +288,8 @@ class DDPM(nn.Module):
pred_noise (tensor): Predicted noise for each image in the batch x_t
'''
pred_noise = self.net(x_t,t,return_dict=False)[0]
mean = self.mean_scaler[t-1]*(x_t - self.noise_scaler[t-1]*pred_noise)
std = self.std[t-1]
mean = self.mean_scaler[t-1].view(-1, 1, 1, 1)*(x_t - self.noise_scaler[t-1].view(-1, 1, 1, 1)*pred_noise)
std = self.std[t-1].view(-1, 1, 1, 1)
return mean, std, pred_noise
......@@ -370,10 +370,9 @@ class DDPM(nn.Module):
Returns:
x (tensor): Contains the self.diffusion_steps+1 denoised image tensors
'''
# start with an image of pure noise and store it as part of the output
# start with an image of pure noise (batch_size 1) and store it as part of the output
x_t_1 = torch.randn((1,) + tuple(self.out_shape), device=self.device)
#x_t_1 = torch.randn(self.out_shape, device=self.device)
x = torch.empty((self.diffusion_steps+1,) + tuple(self.out_shape), device=self.device)
x = torch.empty((self.diffusion_steps+1,1,) + tuple(self.out_shape), device=self.device)
x[-1] = x_t_1
# apply reverse trajectory
for t in reversed(range(1, self.diffusion_steps+1)):
......@@ -408,7 +407,7 @@ class DDPM(nn.Module):
'''
Returns the mathematically correct weighted version of the simplified loss.
'''
return self.mse_weight[t-1]*F.mse_loss(forward_noise, pred_noise)
return self.mse_weight[t-1].view(-1, 1, 1, 1)*F.mse_loss(forward_noise, pred_noise)
# If t=0 and self.recon_loss == 'nll'
......@@ -418,3 +417,6 @@ class DDPM(nn.Module):
denoising Gaussian distribution with mean mean_1 and standard deviation std_1.
'''
return -torch.distributions.Normal(mean_1, std_1).log_prob(x_0).mean()
......@@ -64,11 +64,12 @@ def ddpm_trainer(model,
optimizer_class=torch.optim.AdamW,
optimizer_params=None,
learning_rate = 0.001,
lr_schedule = False,
verbose = False,
run_name=None,
checkpoint= None,
experiment_path = None,
T_max = 5*10000, # None,
eta_min= 1e-5,
**args
):
'''
......@@ -79,23 +80,25 @@ def ddpm_trainer(model,
epochs: Number of epochs we train the model further.
optimizer_class: PyTorch optimizer.
optimizer_param: Parameters for the PyTorch optimizer.
learning_rate: For optimizer initialization or if a checkpoint exists for our manual learning rate.
lr_schedule: If True, manually sets the learning rate of the optimizer to the given one.
learning_rate: For optimizer initialization when training from zero, i.e. no checkpoint
verbose: If True, prints the running losses for every epoch.
run_name: Run name for WandB. IF YOU TRAIN FROM CHECKPOINT MAKE SURE TO USE THE SAME
'run_name' FOR THE DATA TO BE LOGGED ON THE SAME WANDB RUN!
trainloader: Loads the train dataset
testloader: Loads the test dataset
checkpoint: Name of the saved pth. file containing the trained weights and biases
T_max: CosineAnnealingLR scheduler argument (nr of steps in training for a full cycle)
eta_min: CosineAnnealingLR scheduler argument (scheduler oscillates between highest lr 'leraning_rate' and minimum lr 'eta_min')
'''
# set optimizer parameters and learning rate
if optimizer_params is None:
optimizer_params = dict(lr=learning_rate)
else:
optimizer_params['lr'] = learning_rate
optimizer = optimizer_class(model.net.parameters(), **optimizer_params)
# set lr cosine schedule (comonly used in diffusion models)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)
# if checkpoint path is given, load the model from checkpoint
last_epoch = -1
if checkpoint:
......@@ -112,10 +115,10 @@ def ddpm_trainer(model,
# load optimizer state
optimizer_state_dict = checkpoint['optimizer']
optimizer.load_state_dict(optimizer_state_dict)
# If we want to decrease the learning rate manually
if lr_schedule:
for param_group in optimizer.param_groups:
param_group['lr'] = learning_rate
# load learning rate schedule state
scheduler_state_dict = checkpoint['scheduler']
scheduler.load_state_dict(scheduler_state_dict)
scheduler.last_epoch = last_epoch
except Exception as e:
print("Error loading checkpoint. Exception: ", e)
......@@ -137,9 +140,6 @@ def ddpm_trainer(model,
run.config.learning_rate = learning_rate
run.config.optimizer = optimizer.__class__.__name__
run.watch(model.net)
# log the learning rate in each run s.t. for checkpoint training we can see at what times the learning rate has been
# manually stepped
wandb.log({"learning_rate": learning_rate})
# training loop
# last model was stored at epoch last_epoch, we continue training from there, i.e. last_epoch+1 (else we start at epoch 0)
......@@ -151,28 +151,47 @@ def ddpm_trainer(model,
model.net.train()
for idx,(x_0, _) in enumerate(trainloader):
x_0 = x_0.to(device)
t = torch.randint(low=low, high=model.diffusion_steps, size=(1,)).item()
t = torch.randint(low=low, high=model.diffusion_steps, size=(x_0.shape[0],), device = device)
optimizer.zero_grad()
if t>0:
# Define masks for zero and non-zero elements of t
mask_zero_t = (t == 0)
mask_non_zero_t = (t != 0)
t[mask_zero_t] = 1
x_t, forward_noise = model.forward_trajectory(x_0,t)
_, _, pred_noise = model.forward(x_t,t)
loss = loss_func(forward_noise,pred_noise,t)
mean, std, pred_noise = model.forward(x_t,t)
loss = 0
# Compute kl loss
if torch.any(mask_non_zero_t):
loss = loss_func(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t], t[mask_non_zero_t])
running_trainloss += loss.item()
nr_train_batches += 1
run.log({'loss': loss.item(), 'epoch': epoch, 'batch': idx})
else: # reconstruction loss
x_1, _ = model.forward_trajectory(x_0,1)
mean_1, std_1, _ = model.forward(x_1,1)
loss = model.loss_recon(x_0, mean_1, std_1)
run.log({'recon_loss': loss.item(),'epoch': epoch, 'batch': idx})
run.log({'loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})
# If reconstrcution loss was drawn
if torch.any(mask_zero_t):
recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t])
loss += recon_loss
run.log({'recon_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx})
loss.backward()
optimizer.step()
scheduler.step()
if verbose:
print(f"Loss in epoch {epoch}:{running_trainloss/nr_train_batches}")
run.log({'running_loss': running_trainloss/nr_train_batches})
# WORKING OLD VERSION
#x_t, forward_noise = model.forward_trajectory(x_0,t)
#_, _, pred_noise = model.forward(x_t,t)
#loss = loss_func(forward_noise,pred_noise,t)
#running_trainloss += loss.item()
#nr_train_batches += 1
#run.log({'loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})
# evaluation
if ((epoch+1) % eval_iter == 0) or ((epoch+1) % store_iter == 0):
running_testloss = 0
......@@ -182,20 +201,29 @@ def ddpm_trainer(model,
with torch.no_grad():
for idx,(x_0,_) in enumerate(testloader):
x_0 = x_0.to(device)
t = torch.randint(low=low,high=model.diffusion_steps, size=(1,)).item()
t = torch.randint(low=low, high=model.diffusion_steps, size=(x_0.shape[0],), device = device)
# Define masks for zero and non-zero elements of t
mask_zero_t = (t == 0)
mask_non_zero_t = (t != 0)
if t>0:
t[mask_zero_t] = 1
x_t, forward_noise = model.forward_trajectory(x_0,t)
_, _, pred_noise = model.forward(x_t,t)
loss = loss_func(forward_noise,pred_noise, t)
mean, std, pred_noise = model.forward(x_t,t)
loss = 0
# Compute kl loss
if torch.any(mask_non_zero_t):
loss = loss_func(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t], t[mask_non_zero_t])
running_testloss += loss.item()
nr_test_batches += 1
run.log({'test_loss': loss.item(), 'epoch': epoch, 'batch': idx})
else: # reconstruction loss
x_1, _ = model.forward_trajectory(x_0,1)
mean_1, std_1, _ = model.forward(x_1,1)
loss = model.loss_recon(x_0, mean_1, std_1)
run.log({'recon_test_loss': loss.item(), 'epoch': epoch, 'batch': idx})
run.log({'test_loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})
# If reconstrcution loss was drawn
if torch.any(mask_zero_t):
recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t])
loss += recon_loss
run.log({'recon_test_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx})
if verbose:
print(f"Test loss in epoch {epoch}:{running_testloss/nr_test_batches}")
......@@ -209,11 +237,12 @@ def ddpm_trainer(model,
'epoch': epoch,
'model': model.net.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'running_loss': running_trainloss/nr_train_batches,
'running_test_loss': running_testloss/nr_test_batches,
}, os.path.join(save_dir, f"model_epoch_{epoch}.pth"))
# always store the last version of the model
# always store the last version of the model if we trained through all epochs
final = (last_epoch+1)+epochs
save_dir = os.path.join(experiment_path, 'trained_ddpm/')
os.makedirs(save_dir, exist_ok=True)
......@@ -221,6 +250,9 @@ def ddpm_trainer(model,
'epoch': final,
'model': model.net.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'running_loss': running_trainloss/nr_train_batches,
'running_test_loss': running_testloss/nr_test_batches,
}, os.path.join(save_dir, f"model_epoch_{final}.pth"))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment