Commit 1b7f595b authored by Gonzalo Martin Garcia

Added save states for the EMA model when training from a checkpoint

parent 578a3833
@@ -135,6 +135,10 @@ def ddpm_trainer(model,
     # set lr cosine schedule (commonly used in diffusion models)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)
+    # set ema model for training
+    if ema_training:
+        ema = ModelEmaV2(model, decay=decay, device = model.device)
     # if checkpoint path is given, load the model from checkpoint
     last_epoch = -1
     if checkpoint:
@@ -155,6 +159,9 @@ def ddpm_trainer(model,
             scheduler_state_dict = checkpoint['scheduler']
             scheduler.load_state_dict(scheduler_state_dict)
             scheduler.last_epoch = last_epoch
+            # load ema model state
+            if ema_training:
+                ema.module.load_state_dict(checkpoint['ema'])
         except Exception as e:
             print("Error loading checkpoint. Exception: ", e)
@@ -169,10 +176,6 @@ def ddpm_trainer(model,
     if model.recon_loss == 'nll':
         low = 0
-    # EMA
-    if ema_training:
-        ema = ModelEmaV2(model, decay=decay, device = model.device)
     # Using W&B
     with wandb.init(project='Unconditional Landscapes', name=run_name, entity='deep-lab-', id=run_name, resume=True) as run:
@@ -272,6 +275,7 @@ def ddpm_trainer(model,
                     'model': model.net.state_dict(),
                     'optimizer': optimizer.state_dict(),
                     'scheduler': scheduler.state_dict(),
+                    'ema': ema.module.state_dict(),
                     'running_loss': running_trainloss/nr_train_batches,
                     'running_test_loss': running_testloss/nr_test_batches,
                 }, os.path.join(save_dir, f"model_epoch_{epoch}.pth"))
@@ -285,6 +289,7 @@ def ddpm_trainer(model,
                     'model': model.net.state_dict(),
                     'optimizer': optimizer.state_dict(),
                     'scheduler': scheduler.state_dict(),
+                    'ema': ema.module.state_dict(),
                     'running_loss': running_trainloss/nr_train_batches,
                     'running_test_loss': running_testloss/nr_test_batches,
                 }, os.path.join(save_dir, f"model_epoch_{final}.pth"))
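For reference, a minimal sketch of how the 'ema' entry written by this commit can be restored for evaluation. It assumes ModelEmaV2 is timm's timm.utils.ModelEmaV2 (which keeps the averaged copy of the model in its .module attribute), and that model and decay are defined by the surrounding training script; checkpoint_path is a placeholder name, not part of the original code.

import torch
from timm.utils import ModelEmaV2

# Rebuild the EMA wrapper around a freshly constructed model, mirroring the trainer.
ema = ModelEmaV2(model, decay=decay, device=model.device)

# Load a checkpoint written by ddpm_trainer and restore the averaged weights.
# checkpoint_path is a hypothetical placeholder for one of the saved .pth files.
ckpt = torch.load(checkpoint_path, map_location='cpu')
ema.module.load_state_dict(ckpt['ema'])

# ema.module now holds the EMA weights, typically the ones used for sampling.
ema.module.eval()

The change itself is a fix of ordering: by constructing the EMA model before the checkpoint is loaded and restoring its state from checkpoint['ema'], a resumed run continues from the saved averaged weights instead of re-initializing the EMA from the raw model, which the previous placement after checkpoint loading implicitly did.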