diff --git a/trainer/train.py b/trainer/train.py
index 0d78a6b90745c9ce287d2d4282547229c7725cd9..23eabd1746602049a0b84632a934430be8716317 100644
--- a/trainer/train.py
+++ b/trainer/train.py
@@ -135,6 +135,10 @@ def ddpm_trainer(model,
     # set lr cosine schedule (commonly used in diffusion models)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)
 
+    # set EMA model for training
+    if ema_training:
+        ema = ModelEmaV2(model, decay=decay, device=model.device)
+
     # if checkpoint path is given, load the model from checkpoint
     last_epoch = -1
     if checkpoint:
@@ -155,6 +159,9 @@ def ddpm_trainer(model,
             scheduler_state_dict = checkpoint['scheduler']
             scheduler.load_state_dict(scheduler_state_dict)
             scheduler.last_epoch = last_epoch
+            # load EMA model state
+            if ema_training:
+                ema.module.load_state_dict(checkpoint['ema'])
         except Exception as e:
             print("Error loading checkpoint. Exception: ", e)
 
@@ -169,12 +176,8 @@ def ddpm_trainer(model,
     if model.recon_loss == 'nll':
         low = 0
 
-    # EMA
-    if ema_training:
-        ema = ModelEmaV2(model, decay=decay, device = model.device)
-
     # Using W&B
-    with wandb.init(project='Unconditional Landscapes', name=run_name, entity='deep-lab-', id=run_name, resume=True) as run:
+    with wandb.init(project='Unconditional Landscapes', name=run_name, entity='deep-lab-', id=run_name, resume=True) as run:
 
         # Log some info
         run.config.learning_rate = learning_rate
@@ -272,6 +275,7 @@ def ddpm_trainer(model,
                 'model': model.net.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'scheduler': scheduler.state_dict(),
+                'ema': ema.module.state_dict() if ema_training else None,  # guard: ema only exists when ema_training is set
                 'running_loss': running_trainloss/nr_train_batches,
                 'running_test_loss': running_testloss/nr_test_batches,
             }, os.path.join(save_dir, f"model_epoch_{epoch}.pth"))
@@ -285,6 +289,7 @@ def ddpm_trainer(model,
                 'model': model.net.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'scheduler': scheduler.state_dict(),
+                'ema': ema.module.state_dict() if ema_training else None,  # guard: ema only exists when ema_training is set
                 'running_loss': running_trainloss/nr_train_batches,
                 'running_test_loss': running_testloss/nr_test_batches,
             }, os.path.join(save_dir, f"model_epoch_{final}.pth"))
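
Note on the EMA pattern above: ModelEmaV2 (from timm.utils) keeps a detached copy of the model's weights in ema.module and updates it as an exponential moving average after each optimizer step. The sketch below shows the usual update/checkpoint flow under that assumption; the loop and the names net, opt, and the synthetic loss are illustrative stand-ins, not code from trainer/train.py:

    import torch
    from timm.utils import ModelEmaV2

    net = torch.nn.Linear(8, 8)                  # stand-in for the diffusion net
    opt = torch.optim.AdamW(net.parameters(), lr=1e-4)
    ema = ModelEmaV2(net, decay=0.9999)          # EMA copy lives in ema.module

    for step in range(100):                      # stand-in for the epoch/batch loop
        loss = net(torch.randn(4, 8)).pow(2).mean()   # synthetic loss
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema.update(net)                          # ema_w = decay*ema_w + (1-decay)*w

    # both weight sets go into the checkpoint, mirroring the patch above:
    torch.save({'model': net.state_dict(), 'ema': ema.module.state_dict()}, 'ckpt.pth')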
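
Moving the ModelEmaV2 construction ahead of the checkpoint-loading block (first two hunks) is what makes restoring checkpoint['ema'] possible: the EMA wrapper must already exist before its state can be loaded.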