diff --git a/trainer/train.py b/trainer/train.py
index 0d78a6b90745c9ce287d2d4282547229c7725cd9..23eabd1746602049a0b84632a934430be8716317 100644
--- a/trainer/train.py
+++ b/trainer/train.py
@@ -135,6 +135,10 @@ def ddpm_trainer(model,
     # set lr cosine schedule (commonly used in diffusion models)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min) 
 
+    # create the EMA model up front so checkpoint loading below can restore its state
+    if ema_training:
+        ema = ModelEmaV2(model, decay=decay, device=model.device)
+
     # if checkpoint path is given, load the model from checkpoint
     last_epoch = -1
     if checkpoint:
@@ -155,6 +159,9 @@ def ddpm_trainer(model,
             scheduler_state_dict = checkpoint['scheduler']
             scheduler.load_state_dict(scheduler_state_dict)
             scheduler.last_epoch = last_epoch
+            # restore the EMA weights saved alongside the model, if present
+            if ema_training and checkpoint.get('ema') is not None:
+                ema.module.load_state_dict(checkpoint['ema'])
         except Exception as e:
             print("Error loading checkpoint. Exception: ", e)
             
@@ -169,12 +176,8 @@ def ddpm_trainer(model,
     if model.recon_loss == 'nll':
         low = 0
              
-    # EMA
-    if ema_training:
-        ema = ModelEmaV2(model, decay=decay, device = model.device)
-
     # Using W&B
     with wandb.init(project='Unconditional Landscapes', name=run_name, entity='deep-lab-', id=run_name, resume=True) as run:
         
         # Log some info
         run.config.learning_rate = learning_rate
@@ -272,6 +275,7 @@ def ddpm_trainer(model,
                         'model': model.net.state_dict(),
                         'optimizer': optimizer.state_dict(),
                         'scheduler': scheduler.state_dict(),
+                        'ema': ema.module.state_dict() if ema_training else None,
                         'running_loss': running_trainloss/nr_train_batches,
                         'running_test_loss': running_testloss/nr_test_batches,
                     }, os.path.join(save_dir, f"model_epoch_{epoch}.pth"))
@@ -285,6 +289,7 @@ def ddpm_trainer(model,
             'model': model.net.state_dict(),
             'optimizer': optimizer.state_dict(),
             'scheduler': scheduler.state_dict(),
+            'ema': ema.module.state_dict() if ema_training else None,
             'running_loss': running_trainloss/nr_train_batches,
             'running_test_loss': running_testloss/nr_test_batches,
          }, os.path.join(save_dir, f"model_epoch_{final}.pth"))
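
For context: ModelEmaV2 comes from timm (timm.utils.model_ema) and maintains a second copy of the model whose weights are an exponential moving average of the training weights, updated as ema = decay * ema + (1 - decay) * current after each optimizer step. The sketch below is a minimal, self-contained illustration of that update rule under the same semantics; SimpleEma and the demo model are hypothetical names for illustration, not part of this diff or of timm.

# Minimal sketch of the EMA update that ModelEmaV2 performs.
# "SimpleEma" is an illustrative stand-in, not the timm class itself.
from copy import deepcopy

import torch
import torch.nn as nn


class SimpleEma(nn.Module):
    """Holds an exponential-moving-average copy of a model's weights."""

    def __init__(self, model: nn.Module, decay: float = 0.9999, device=None):
        super().__init__()
        self.module = deepcopy(model).eval()  # averaged weights live here
        self.decay = decay
        self.device = device
        if device is not None:
            self.module.to(device)

    @torch.no_grad()
    def update(self, model: nn.Module):
        # ema <- decay * ema + (1 - decay) * current, applied per tensor
        for ema_v, model_v in zip(self.module.state_dict().values(),
                                  model.state_dict().values()):
            if self.device is not None:
                model_v = model_v.to(self.device)
            ema_v.copy_(self.decay * ema_v + (1.0 - self.decay) * model_v)


# Hypothetical usage mirroring the trainer: update after each optimizer
# step, then sample/evaluate and checkpoint from ema.module.
net = nn.Linear(4, 4)
ema = SimpleEma(net, decay=0.999)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
for _ in range(3):
    loss = net(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update(net)
torch.save({'ema': ema.module.state_dict()}, '/tmp/ema_demo.pth')

This is also why the diff moves EMA construction ahead of checkpoint loading and adds an 'ema' entry to the saved dict: the averaged weights are state that must survive a restart, and at sampling time one reads from ema.module rather than the raw training model.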