import numpy as np
import copy
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn.functional as F
import os
import wandb
from copy import deepcopy

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Simple training function for the unconditional diffusion model
def simple_trainer(model, device, epochs, trainloader, testloader, bs, lr, T, criterion=nn.MSELoss()):
    criterion.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        running_trainloss = []
        running_testloss = []

        # training pass
        for idx, (x, _) in enumerate(trainloader):
            x = x.to(device)                                     # inputs have to go to device
            t = torch.randint(low=0, high=T, size=(1,)).item()   # scalar timestep, doesn't have to go to device
            x_t, forward_noise = model.forward_trajectory(x, t)  # noise x up to step t
            optimizer.zero_grad()
            mean, std, pred_noise = model.forward(x_t, t)        # changed to forward since model is an nn.Module
            loss = criterion(forward_noise, pred_noise)
            loss.backward()
            optimizer.step()
            running_trainloss.append(loss.cpu().item())          # MUST be on cpu before appending to list
        print(f"Train loss in epoch {epoch}: {np.mean(running_trainloss)}")

        # evaluation pass
        model.eval()
        with torch.no_grad():
            for idx, (x, _) in enumerate(testloader):
                x = x.to(device)
                t = torch.randint(low=0, high=T, size=(1,)).item()
                x_t, forward_noise = model.forward_trajectory(x, t)
                mean, std, pred_noise = model.forward(x_t, t)
                loss = criterion(forward_noise, pred_noise)
                running_testloss.append(loss.cpu().item())
        print(f"Test loss in epoch {epoch}: {np.mean(running_testloss)}")

# EMA class
# Important! This EMA class is not ours: it was taken from the PyTorch Image Models library (timm) and performs an
# exponential moving average over the trained weights of a given model's neural net, which was suggested by the paper
# "Improved Denoising Diffusion Probabilistic Models" by Nichol and Dhariwal to stabilize and improve training and generalization.
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/model_ema.py
class ModelEmaV2(nn.Module):
    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEmaV2, self).__init__()
        # make a copy of the model for accumulating the moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device  # perform EMA on a different device from the model if set
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)
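
# Minimal usage sketch for ModelEmaV2 (illustrative, not part of the original training entry point).
# After every optimizer step the shadow copy is nudged towards the live weights with
#     ema_w <- decay * ema_w + (1 - decay) * model_w,
# and the EMA weights (ema.module) are the ones you would save and sample from.
#
#     ema = ModelEmaV2(model, decay=0.9999, device=device)
#     for x_0, _ in trainloader:
#         ...                    # forward pass, loss.backward(), optimizer.step()
#         ema.update(model)      # accumulate the moving average after each step
#     torch.save(ema.module.state_dict(), 'ema_weights.pth')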

# Training function for the unconditional diffusion model
def ddpm_trainer(model,
                 device,
                 trainloader, testloader,
                 store_iter=10,
                 eval_iter=10,
                 epochs=50,
                 optimizer_class=torch.optim.AdamW,
                 optimizer_params=None,
                 learning_rate=0.001,
                 verbose=False,
                 run_name=None,
                 checkpoint=None,
                 experiment_path=None,
                 T_max=5*10000,  # None,
                 eta_min=1e-5,
                 ema_training=True,
                 decay=0.9999,
                 **args
                 ):
    '''
    model:            Properly initialized DDPM model.
    store_iter:       Stores the trained DDPM every store_iter epochs.
    experiment_path:  Path to the model's experiment folder, where the trained model is stored every store_iter epochs.
    eval_iter:        Evaluates the trained DDPM on the test data every eval_iter epochs.
    epochs:           Number of epochs we train the model further.
    optimizer_class:  PyTorch optimizer.
    optimizer_params: Parameters for the PyTorch optimizer.
    learning_rate:    For optimizer initialization when training from scratch, i.e. without a checkpoint.
    verbose:          If True, prints the running losses for every epoch.
    run_name:         Run name for WandB. IF YOU TRAIN FROM A CHECKPOINT, MAKE SURE TO USE THE SAME
                      'run_name' FOR THE DATA TO BE LOGGED ON THE SAME WANDB RUN!
    trainloader:      Loads the train dataset.
    testloader:       Loads the test dataset.
    checkpoint:       Name of the saved .pth file containing the trained weights and biases.
    T_max:            CosineAnnealingLR scheduler argument (number of training steps for a full cycle).
    eta_min:          CosineAnnealingLR scheduler argument (the scheduler oscillates between the highest lr 'learning_rate'
                      and the minimum lr 'eta_min').
    decay:            EMA decay rate used to weight the EMA model when computing the weighted average between the trained
                      and EMA weights for the network's weight update.
    See the usage sketch at the end of this file for an example call.
    '''
    # set optimizer parameters and learning rate
    if optimizer_params is None:
        optimizer_params = dict(lr=learning_rate)
    optimizer = optimizer_class(model.net.parameters(), **optimizer_params)

    # set cosine lr schedule (commonly used in diffusion models)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)

    # set EMA model for training
    if ema_training:
        ema = ModelEmaV2(model, decay=decay, device=model.device)

    # if a checkpoint name is given, load the model from the checkpoint
    last_epoch = -1
    if checkpoint:
        try:
            checkpoint_path = f'{experiment_path}trained_ddpm/{checkpoint}'
            # load the checkpoint
            checkpoint = torch.load(checkpoint_path)
            # update last_epoch
            last_epoch = checkpoint['epoch']
            # load weights and biases of the U-Net
            model_state_dict = checkpoint['model']
            model.net.load_state_dict(model_state_dict)
            model = model.to(device)
            # load optimizer state
            optimizer_state_dict = checkpoint['optimizer']
            optimizer.load_state_dict(optimizer_state_dict)
            # load learning rate scheduler state
            scheduler_state_dict = checkpoint['scheduler']
            scheduler.load_state_dict(scheduler_state_dict)
            scheduler.last_epoch = last_epoch
            # load EMA model state
            if ema_training:
                ema.module.load_state_dict(checkpoint['ema'])
        except Exception as e:
            print("Error loading checkpoint. Exception: ", e)

    # pick KL loss function
    if model.kl_loss == 'weighted':
        loss_func = model.loss_weighted
    else:
        loss_func = model.loss_simplified

    # pick the lowest timestep that can be drawn (0 only if the nll reconstruction loss is used)
    low = 1
    if model.recon_loss == 'nll':
        low = 0
# Using W&B
with wandb.init(project='Unconditional Landscapes', name=run_name, entity='deep-lab-', id=run_name, resume=True) as run:
# Log some info
run.config.learning_rate = learning_rate
run.config.optimizer = optimizer.__class__.__name__
#run.watch(model.net)
# training loop
# last model was stored at epoch last_epoch, we continue training from there, i.e. last_epoch+1 (else we start at epoch 0)
for epoch in range(last_epoch+1, (last_epoch+1)+epochs):
running_trainloss = 0
nr_train_batches = 0
# train
model.net.train()
for idx,(x_0, _) in enumerate(trainloader):
x_0 = x_0.to(device)
t = torch.randint(low=low, high=model.diffusion_steps, size=(x_0.shape[0],), device = device)
optimizer.zero_grad()
# Define masks for zero and non-zero elements of t
mask_zero_t = (t == 0)
mask_non_zero_t = (t != 0)
t[mask_zero_t] = 1
x_t, forward_noise = model.forward_trajectory(x_0,t)
mean, std, pred_noise = model.forward(x_t,t)
loss = 0
# Compute kl loss
if torch.any(mask_non_zero_t):
loss = loss_func(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t], t[mask_non_zero_t])
running_trainloss += loss.item()
nr_train_batches += 1
run.log({'loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})
# If reconstrcution loss was drawn
if torch.any(mask_zero_t):
recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t])
loss += recon_loss
run.log({'recon_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx})
loss.backward()
optimizer.step()
if ema_training:
ema.update(model)
scheduler.step()
if verbose:
print(f"Loss in epoch {epoch}:{running_trainloss/nr_train_batches}")
run.log({'running_loss': running_trainloss/nr_train_batches})

            # evaluation
            if ((epoch+1) % eval_iter == 0) or ((epoch+1) % store_iter == 0):
                running_testloss = 0
                nr_test_batches = 0

                model.net.eval()
                with torch.no_grad():
                    for idx, (x_0, _) in enumerate(testloader):
                        x_0 = x_0.to(device)
                        t = torch.randint(low=low, high=model.diffusion_steps, size=(x_0.shape[0],), device=device)

                        # define masks for zero and non-zero elements of t
                        mask_zero_t = (t == 0)
                        mask_non_zero_t = (t != 0)
                        t[mask_zero_t] = 1

                        x_t, forward_noise = model.forward_trajectory(x_0, t)
                        mean, std, pred_noise = model.forward(x_t, t)
                        loss = 0

                        # compute KL loss
                        if torch.any(mask_non_zero_t):
                            loss = loss_func(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t], t[mask_non_zero_t])
                            running_testloss += loss.item()
                            nr_test_batches += 1
                            run.log({'test_loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})

                        # if the reconstruction loss was drawn
                        if torch.any(mask_zero_t):
                            recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t])
                            loss += recon_loss
                            run.log({'recon_test_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx})

                if verbose:
                    print(f"Test loss in epoch {epoch}: {running_testloss/nr_test_batches}")
                run.log({'running_test_loss': running_testloss/nr_test_batches})

                # store model
                if ((epoch+1) % store_iter == 0):
                    save_dir = os.path.join(experiment_path, 'trained_ddpm/')
                    os.makedirs(save_dir, exist_ok=True)
                    torch.save({
                        'epoch':             epoch,
                        'model':             model.net.state_dict(),
                        'optimizer':         optimizer.state_dict(),
                        'scheduler':         scheduler.state_dict(),
                        'ema':               ema.module.state_dict(),
                        'running_loss':      running_trainloss/nr_train_batches,
                        'running_test_loss': running_testloss/nr_test_batches,
                    }, os.path.join(save_dir, f"model_epoch_{epoch}.pth"))

        # always store the last version of the model if we trained through all epochs
        # (note: this assumes ema_training is True and that at least one evaluation pass has run)
        final = (last_epoch+1)+epochs
        save_dir = os.path.join(experiment_path, 'trained_ddpm/')
        os.makedirs(save_dir, exist_ok=True)
        torch.save({
            'epoch':             final,
            'model':             model.net.state_dict(),
            'optimizer':         optimizer.state_dict(),
            'scheduler':         scheduler.state_dict(),
            'ema':               ema.module.state_dict(),
            'running_loss':      running_trainloss/nr_train_batches,
            'running_test_loss': running_testloss/nr_test_batches,
        }, os.path.join(save_dir, f"model_epoch_{final}.pth"))