Commit 1b7f595b authored by Gonzalo Martin Garcia

Added save states for the EMA model when training from checkpoint

parent 578a3833
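This commit threads the EMA (exponential moving average) weights through the trainer's checkpoint logic so they survive a resume. For context, here is a minimal sketch of how timm's ModelEmaV2 is typically driven during training; the toy model, data, and hyperparameters are illustrative assumptions, not this repository's code:

    import torch
    from timm.utils import ModelEmaV2

    model = torch.nn.Linear(8, 1)                 # stand-in for the diffusion net
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    ema = ModelEmaV2(model, decay=0.995)          # ema.module holds the averaged copy

    for step in range(100):
        x, y = torch.randn(4, 8), torch.randn(4, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model)                         # blend current weights into the average

Because ema.module is a full copy of the model, its state dict can be checkpointed and restored exactly like the model's own, which is what the hunks below add.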
@@ -135,6 +135,10 @@ def ddpm_trainer(model,
     # set lr cosine schedule (commonly used in diffusion models)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)
+    # set ema model for training
+    if ema_training:
+        ema = ModelEmaV2(model, decay=decay, device=model.device)
     # if checkpoint path is given, load the model from checkpoint
     last_epoch = -1
     if checkpoint:
@@ -155,6 +159,9 @@ def ddpm_trainer(model,
         scheduler_state_dict = checkpoint['scheduler']
         scheduler.load_state_dict(scheduler_state_dict)
         scheduler.last_epoch = last_epoch
+        # load ema model state
+        if ema_training:
+            ema.module.load_state_dict(checkpoint['ema'])
     except Exception as e:
         print("Error loading checkpoint. Exception: ", e)
@@ -169,10 +176,6 @@ def ddpm_trainer(model,
     if model.recon_loss == 'nll':
         low = 0
-    # EMA
-    if ema_training:
-        ema = ModelEmaV2(model, decay=decay, device=model.device)
     # Using W&B
     with wandb.init(project='Unconditional Landscapes', name=run_name, entity='deep-lab-', id=run_name, resume=True) as run:
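Note that this hunk removes the original ModelEmaV2 construction rather than adding new logic: the construction was moved ahead of the checkpoint-loading block (first hunk) so the EMA copy already exists when checkpoint['ema'] is restored into it. Previously a resumed run would have restarted the average from the current network weights.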
@@ -272,6 +275,7 @@ def ddpm_trainer(model,
                 'model': model.net.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'scheduler': scheduler.state_dict(),
+                'ema': ema.module.state_dict(),
                 'running_loss': running_trainloss/nr_train_batches,
                 'running_test_loss': running_testloss/nr_test_batches,
                 }, os.path.join(save_dir, f"model_epoch_{epoch}.pth"))
@@ -285,6 +289,7 @@ def ddpm_trainer(model,
                 'model': model.net.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'scheduler': scheduler.state_dict(),
+                'ema': ema.module.state_dict(),
                 'running_loss': running_trainloss/nr_train_batches,
                 'running_test_loss': running_testloss/nr_test_batches,
                 }, os.path.join(save_dir, f"model_epoch_{final}.pth"))
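Taken together, the save and load sides form a round trip for the EMA weights. A hedged sketch of that round trip, under the same illustrative assumptions as above (toy model, placeholder path; the real trainer also stores optimizer, scheduler, and loss entries):

    import torch
    from timm.utils import ModelEmaV2

    model = torch.nn.Linear(8, 1)
    ema = ModelEmaV2(model, decay=0.995)

    # save: EMA weights live under their own 'ema' key, next to the raw model
    torch.save({'model': model.state_dict(),
                'ema': ema.module.state_dict()}, 'ckpt.pth')

    # resume: create the EMA wrapper first (as the first hunk now does),
    # then overwrite its fresh copy with the saved averaged weights
    checkpoint = torch.load('ckpt.pth')
    model.load_state_dict(checkpoint['model'])
    ema = ModelEmaV2(model, decay=0.995)
    ema.module.load_state_dict(checkpoint['ema'])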