train.py

    import numpy as np
    import torch
    from torch import nn
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    import torch.nn.functional as F
    import os
    import wandb
    from copy import deepcopy
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Simple Training function for the unconditional diffusion model
    def simple_trainer(model,device,epochs,trainloader,testloader,bs,lr,T,criterion = nn.MSELoss()):
        criterion.to(device)
        optimizer = torch.optim.AdamW(model.parameters(),lr=lr,)
    
        for epoch in range(epochs):
            model.train()
            running_trainloss = []
            running_testloss = []
            for idx,(x,_) in enumerate(trainloader):
                x = x.to(device)  # has to go to device 
                t = torch.randint(low=0,high=T,size=(1,)).item() # doesn't have to go to device 
                x_t,forward_noise = model.forward_trajectory(x,t)
                optimizer.zero_grad()
                mean,std,pred_noise = model.forward(x_t,t) # call forward explicitly since the model is an nn.Module

                loss = criterion(forward_noise,pred_noise)
                loss.backward()
                optimizer.step()
                trainstep = epoch*bs+idx  # global step counter (currently unused)
                running_trainloss.append(loss.cpu().item())  # must be moved to CPU before appending to the list


            print(f"Train loss in epoch {epoch}: {np.mean(running_trainloss)}")
                        
            
            model.eval()
            with torch.no_grad():
                for idx,(x,_) in enumerate(testloader):
                    x = x.to(device)
                    t = torch.randint(low=0,high=T,size=(1,)).item() 
                    x_t,forward_noise = model.forward_trajectory(x,t)
                    mean,std,pred_noise = model.forward(x_t,t) 
                    loss = criterion(forward_noise,pred_noise)
                    running_testloss.append(loss.cpu().item())
                    
            print(f"Testloss in step {epoch} :{np.mean(running_testloss)}")
    
    
    # EMA class
    # Important! This EMA class is not ours; it was taken from the PyTorch Image Models (timm) library. It maintains an
    # exponential moving average of the trained weights of a given model's neural net, which was suggested in the paper
    # "Improved Denoising Diffusion Probabilistic Models" by Nichol and Dhariwal to stabilize training and improve generalization.
    # https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/model_ema.py
    class ModelEmaV2(nn.Module):
        def __init__(self, model, decay=0.9999, device=None):
            super(ModelEmaV2, self).__init__()
            # make a copy of the model for accumulating moving average of weights
            self.module = deepcopy(model)
            self.module.eval()
            self.decay = decay
            self.device = device  # perform ema on different device from model if set
            if self.device is not None:
                self.module.to(device=device)
    
        def _update(self, model, update_fn):
            with torch.no_grad():
                for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                    if self.device is not None:
                        model_v = model_v.to(device=self.device)
                    ema_v.copy_(update_fn(ema_v, model_v))
    
        def update(self, model):
            self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
    
        def set(self, model):
            self._update(model, update_fn=lambda e, m: m)
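
    # Usage sketch for ModelEmaV2 (illustrative only; the real call sites are in ddpm_trainer below).
    # After every optimizer step the EMA weights follow ema_w <- decay * ema_w + (1 - decay) * model_w,
    # and the smoothed copy held in ema.module is the one you would evaluate or sample from.
    #
    #   ema = ModelEmaV2(model, decay=0.9999, device=device)
    #   for x, _ in trainloader:
    #       loss = compute_loss(model, x)  # hypothetical loss helper
    #       optimizer.zero_grad()
    #       loss.backward()
    #       optimizer.step()
    #       ema.update(model)              # fold the new weights into the moving average
    #   evaluate(ema.module)               # hypothetical: evaluate with the EMA weights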
    
    
    
    # Training function for the unconditional diffusion model
    
    def ddpm_trainer(model, 
                     device,
                     trainloader, testloader,
                     store_iter = 10,
                     eval_iter = 10,    
                     epochs = 50,    
                     optimizer_class=torch.optim.AdamW, 
                     optimizer_params=None,
                     learning_rate = 0.001,
                     verbose = False,
                     run_name=None,
                     checkpoint= None,
                     experiment_path = None,
                     T_max = 5*10000, # None,
                     eta_min= 1e-5,
                     ema_training = True,
                     decay = 0.9999,
                     **args
                     ):
        '''
        model:           Properly initialized DDPM model.
        store_iter:      Stores the trained DDPM every store_iter epochs.
        experiment_path: Path to the model's experiment folder, where the trained model is stored every store_iter epochs.
        eval_iter:       Evaluates the trained DDPM on the test data every eval_iter epochs.
        epochs:          Number of additional epochs to train the model.
        optimizer_class: PyTorch optimizer class.
        optimizer_params: Parameters for the PyTorch optimizer.
        learning_rate:   Used to initialize the optimizer when training from scratch, i.e. without a checkpoint.
        verbose:         If True, prints the running losses for every epoch.
        run_name:        Run name for WandB. IF YOU TRAIN FROM A CHECKPOINT MAKE SURE TO USE THE SAME
                         'run_name' FOR THE DATA TO BE LOGGED ON THE SAME WANDB RUN!
        trainloader:     Loads the train dataset.
        testloader:      Loads the test dataset.
        checkpoint:      Name of the saved .pth file containing the trained weights and biases.
        T_max:           CosineAnnealingLR scheduler argument (number of training steps for a full cosine cycle).
        eta_min:         CosineAnnealingLR scheduler argument (the scheduler oscillates between the maximum lr 'learning_rate' and the minimum lr 'eta_min').
        decay:           EMA decay rate used to weight the EMA weights against the newly trained weights when computing
                         the weighted average for the EMA model's update.
        '''
    
        # set optimizer parameters and learning rate
        if optimizer_params is None:
            optimizer_params = dict(lr=learning_rate)
        optimizer = optimizer_class(model.net.parameters(), **optimizer_params)
        
        # set lr cosine schedule (commonly used in diffusion models)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min) 
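        # CosineAnnealingLR anneals the lr from its initial value down to eta_min over T_max scheduler steps along
        # a cosine curve; scheduler.step() is called once per training batch below, so T_max is counted in batches.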
    
        # set ema model for training
        if ema_training:
            ema = ModelEmaV2(model, decay=decay, device = model.device)
    
        # if checkpoint path is given, load the model from checkpoint
        last_epoch = -1
        if checkpoint:
            try:
                checkpoint_path = os.path.join(experiment_path, 'trained_ddpm', checkpoint)
                # Load the checkpoint (map_location makes sure the tensors land on the current device)
                checkpoint = torch.load(checkpoint_path, map_location=device)
                # update last_epoch
                last_epoch = checkpoint['epoch']
                # load weights and biases of the U-Net
                model_state_dict = checkpoint['model']
                model.net.load_state_dict(model_state_dict)
                model = model.to(device)
                # load optimizer state
                optimizer_state_dict = checkpoint['optimizer']
                optimizer.load_state_dict(optimizer_state_dict)
                # load learning rate schedule state
                scheduler_state_dict = checkpoint['scheduler']
                scheduler.load_state_dict(scheduler_state_dict)
                scheduler.last_epoch = last_epoch
                # load ema model state
                if ema_training:
                    ema.module.load_state_dict(checkpoint['ema'])     
            except Exception as e:
                print("Error loading checkpoint. Exception: ", e)
                
        # pick kl loss function
        if model.kl_loss == 'weighted':
            loss_func = model.loss_weighted
        else:
            loss_func = model.loss_simplified
            
        # pick the lowest timestep that can be sampled: t = 0 is only drawn when the reconstruction (nll) loss is used
        low = 1
        if model.recon_loss == 'nll':
            low = 0
                 
        # Using W&B
        with wandb.init(project='Unconditional Landscapes', name=run_name, entity='deep-lab-', id=run_name, resume=True) as run: 
            
            # Log some info
            run.config.learning_rate = learning_rate
            run.config.optimizer = optimizer.__class__.__name__
            #run.watch(model.net)
            
            # training loop
            # last model was stored at epoch last_epoch, we continue training from there, i.e. last_epoch+1 (else we start at epoch 0)
            for epoch in range(last_epoch+1, (last_epoch+1)+epochs): 
                running_trainloss = 0
                nr_train_batches = 0
    
                # train
                model.net.train()
                for idx,(x_0, _) in enumerate(trainloader):
                    x_0 = x_0.to(device)
                    t = torch.randint(low=low, high=model.diffusion_steps, size=(x_0.shape[0],), device = device)
                    optimizer.zero_grad()
                                        
                    # Define masks for zero and non-zero elements of t
                    mask_zero_t = (t == 0)
                    mask_non_zero_t = (t != 0)
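                    # t == 0 entries are clamped to t = 1 before the forward pass (presumably so the noise-prediction
                    # path always receives a valid diffusion step); their reconstruction term is added separately below.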
                    
                    t[mask_zero_t] = 1
                    x_t, forward_noise = model.forward_trajectory(x_0,t)
                    mean, std, pred_noise = model.forward(x_t,t)                
    
                    loss = 0
                    # Compute kl loss
                    if torch.any(mask_non_zero_t):
                        loss = loss_func(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t], t[mask_non_zero_t])
                        running_trainloss += loss.item()
                        nr_train_batches += 1
                        run.log({'loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})
    
                    # If the reconstruction loss was drawn
                    if torch.any(mask_zero_t):
                        recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t])
                        loss += recon_loss
                        run.log({'recon_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx})
    
                    loss.backward()
                    optimizer.step()
                    if ema_training:
                        ema.update(model)
                    scheduler.step()
    
                if verbose:
                    print(f"Loss in epoch {epoch}:{running_trainloss/nr_train_batches}")
                run.log({'running_loss': running_trainloss/nr_train_batches})
    
                # evaluation
                if ((epoch+1) % eval_iter == 0) or ((epoch+1) % store_iter == 0):
                    running_testloss = 0
                    nr_test_batches = 0
                    
                    model.net.eval()
                    with torch.no_grad():
                        for idx,(x_0,_) in enumerate(testloader):
                            x_0 = x_0.to(device)
                            t = torch.randint(low=low, high=model.diffusion_steps, size=(x_0.shape[0],), device = device)
                            
                            # Define masks for zero and non-zero elements of t
                            mask_zero_t = (t == 0)
                            mask_non_zero_t = (t != 0)
    
                            t[mask_zero_t] = 1
                            x_t, forward_noise = model.forward_trajectory(x_0,t)
                            mean, std, pred_noise = model.forward(x_t,t)
    
                            loss = 0
                            # Compute kl loss
                            if torch.any(mask_non_zero_t):
                                loss = loss_func(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t], t[mask_non_zero_t])
                                running_testloss += loss.item()
                                nr_test_batches += 1
                                run.log({'test_loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})
    
                            # If the reconstruction loss was drawn
                            if torch.any(mask_zero_t):
                                recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t])
                                loss += recon_loss
                                run.log({'recon_test_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx})
    
                        if verbose:
                            print(f"Test loss in epoch {epoch}:{running_testloss/nr_test_batches}")
                        run.log({'running_test_loss': running_testloss/nr_test_batches})
                        
                    # store model
                    if ((epoch+1) % store_iter == 0):
                        save_dir = os.path.join(experiment_path, 'trained_ddpm/')
                        os.makedirs(save_dir, exist_ok=True)
                        torch.save({
                            'epoch': epoch,
                            'model': model.net.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'ema': ema.module.state_dict() if ema_training else None,
                            'running_loss': running_trainloss/nr_train_batches,
                            'running_test_loss': running_testloss/nr_test_batches,
                        }, os.path.join(save_dir, f"model_epoch_{epoch}.pth"))
                        
            # always store the last version of the model if we trained through all epochs
            final = (last_epoch+1)+epochs
            save_dir = os.path.join(experiment_path, 'trained_ddpm/')
            os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'epoch': final,
                'model': model.net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'ema': ema.module.state_dict() if ema_training else None,
                'running_loss': running_trainloss/nr_train_batches,
                'running_test_loss': running_testloss/nr_test_batches,
             }, os.path.join(save_dir, f"model_epoch_{final}.pth"))
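
    # Usage sketch (illustrative only): ddpm_trainer expects a DDPM wrapper exposing .net, .device, .diffusion_steps,
    # .kl_loss, .recon_loss, forward_trajectory, loss_weighted/loss_simplified and loss_recon as used above.
    # The model class, datasets and paths below are hypothetical placeholders.
    #
    #   model = DDPM(unet, diffusion_steps=1000, device=device)  # hypothetical wrapper
    #   trainloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    #   testloader = DataLoader(test_dataset, batch_size=32)
    #   ddpm_trainer(model, device, trainloader, testloader,
    #                epochs=50, learning_rate=1e-4, T_max=50000, eta_min=1e-5,
    #                run_name='landscape-ddpm-run', experiment_path='experiments/landscape_ddpm/',
    #                store_iter=10, eval_iter=10, verbose=True)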