diff --git a/main.py b/main.py
index 2717e977799a281f549c503bd3fa9862e32b93a7..c3e5c9d8aa399ccfc6abcd3a51e3d754081ce990 100644
--- a/main.py
+++ b/main.py
@@ -36,7 +36,7 @@ def train_func(f):
   #model = globals()[meta_setting["modelname"]](**model_setting).to(device)
   #net = torch.compile(model)
   net = UNet2DModel(
-    sample_size=128, 
+    sample_size=64, 
     in_channels=3,  
     out_channels=3,  
     layers_per_block=2,  
diff --git a/models/Framework.py b/models/Framework.py
index 32ffa6c53f671fd9ce994f4b24641a39bbb81fee..2b136d01484ebf1a5ce67ab340b80867f6a406e4 100644
--- a/models/Framework.py
+++ b/models/Framework.py
@@ -68,9 +68,9 @@ class DDPM(nn.Module):
         self.recon_loss = recon_loss 
         self.out_shape = out_shape
         # precomputed for efficiency reasons
-        self.noise_scaler = ((1-alpha)/( self.sqrt_1_minus_alpha_bar))
-        self.mean_scaler = (1/torch.sqrt(self.alpha))
-        self.mse_weight = ((self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar)))
+        self.noise_scaler = (1-alpha)/( self.sqrt_1_minus_alpha_bar)
+        self.mean_scaler = 1/torch.sqrt(self.alpha)
+        self.mse_weight = (self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar))
           
     @staticmethod
     def linear_schedule(diffusion_steps, beta_1, beta_T, device):
@@ -183,14 +183,14 @@ class DDPM(nn.Module):
         
         Parameters:
         x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
-        t      (int): Timestep, by default goes through full forward trajectory
+        t   (tensor): Batch of timesteps, by default goes through full forward trajectory
     
         Returns:
         x_T           (tensor): Batch of noised images at timestep t
         forward_noise (tensor): Batch of noise parameters from the noise distribution reparametrization used to draw x_T
         '''
         if t is None:
-            t = self.diffusion_steps
+            t = torch.full((x_0.shape[0],), self.diffusion_steps, device = self.device)
         elif torch.any(t == 0):
             raise ValueError("The tensor 't' contains a timestep zero.")
         forward_noise = torch.randn(x_0.shape, device = self.device)
@@ -200,13 +200,13 @@ class DDPM(nn.Module):
     @torch.no_grad()
     def noised_latent(self, forward_noise, x_0, t):
         '''
-        Given a batch of noise parameters, this function recomputes the batch of noised images at timestep t.
+        Given a batch of noise parameters, this function recomputes the batch of noised images at their respective timesteps t.
         This allows us to avoid storing all the intermediate latents x_t along the forward trajectory.
         
         Parameters:
         forward_noise (tensor): Batch of noise parameters from the noise distribution reparametrization used to draw x_t
         x_0           (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
-        t                (int): Timestep
+        t             (tensor): Batch of timesteps
     
         Returns:
         x_t           (tensor): Batch of noised images at timestep t
@@ -222,14 +222,14 @@ class DDPM(nn.Module):
         
         Parameters:
         x_0  (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
-        t       (int): Timestep
+        t    (tensor): Batch of timesteps
     
         Returns:
         mean (tensor): Batch of means for the complete noise distribution for each image in the batch x_0
         std  (tensor): Batch of std scalars for the complete noise distribution for each image in the batch x_0
         '''
-        mean = self.sqrt_alpha_bar[t-1].view(-1, 1, 1, 1)*x_0
-        std = self.sqrt_1_minus_alpha_bar[t-1].view(-1, 1, 1, 1)
+        mean = self.sqrt_alpha_bar[t-1][:,None,None,None]*x_0
+        std = self.sqrt_1_minus_alpha_bar[t-1][:,None,None,None]
         return mean, std
     
     @torch.no_grad()
@@ -239,14 +239,14 @@ class DDPM(nn.Module):
 
         Parameters:
         x_t_1 (tensor): Batch of noised images at timestep t-1
-        t        (int): Timestep
+        t     (tensor): Batch of timesteps
         
         Returns:
         mean (tensor): Batch of means for the individual noise distribution for each image in the batch x_t_1
         std  (tensor): Batch of std scalars for the individual noise distribution for each image in the batch x_t_1
         '''
-        mean = torch.sqrt(1-self.beta[t-1]).view(-1, 1, 1, 1)*x_t_1
-        std = torch.sqrt(self.beta[t-1]).view(-1, 1, 1, 1)
+        mean = torch.sqrt(1-self.beta[t-1])[:,None,None,None]*x_t_1
+        std = torch.sqrt(self.beta[t-1])[:,None,None,None]
         return mean, std
 
 
@@ -254,12 +254,12 @@ class DDPM(nn.Module):
 
     def reverse_trajectory(self, x_t, t):
         '''
-        Draws a denoised image x_{t-1} by reparametrizing the denoising distribution at time t for the current noised
-        latent x_t.
+        Draws a denoised images x_{t-1} by reparametrizing the denoising distribution at times t for the current noised
+        latents x_t.
 
         Parameters:
         x_t (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
-        t      (int): Timestep
+        t   (tensor): Batch of timestep
         
         Returns:
         x_t_1 (tensor): Batch of denoised images at timestep t-1
@@ -271,16 +271,16 @@ class DDPM(nn.Module):
     
     def forward(self, x_t, t):
         '''
-        Passes the current noised image x_t and timestep t through the U-Net in order to compute the 
-        predicted noise, which is later used to determine the current denoising distribution parameters in the 
-        reverse trajectory. 
+        Passes the current noised images x_t and timesteps t through the U-Net in order to compute the 
+        predicted noise, which is later used to determine the current denoising distribution parameters
+        (mean and std) in the reverse trajectory. 
         Since the DDPM class is inheriting from the nn.Module class, this function is required to share 
         the name 'forward'. This naming scheme does not refer to the forward trajectory, but the forward 
         pass of the model itself, which concerns to the reverse trajectory.
         
         Parameters:
         x_t (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
-        t      (int): Timestep
+        t   (tensor): Batch of timesteps
         
         Returns:
         mean        (tensor): Batch of means for the complete noise dist. for each image in the batch x_t
@@ -288,8 +288,8 @@ class DDPM(nn.Module):
         pred_noise  (tensor): Predicted noise for each image in the batch x_t
         '''
         pred_noise = self.net(x_t,t,return_dict=False)[0]
-        mean = self.mean_scaler[t-1].view(-1, 1, 1, 1)*(x_t - self.noise_scaler[t-1].view(-1, 1, 1, 1)*pred_noise)
-        std = self.std[t-1].view(-1, 1, 1, 1)
+        mean = self.mean_scaler[t-1][:,None,None,None]*(x_t - self.noise_scaler[t-1][:,None,None,None]*pred_noise)
+        std = self.std[t-1][:,None,None,None]
         return mean, std, pred_noise
     
 
@@ -319,7 +319,7 @@ class DDPM(nn.Module):
             else:
                 noise = torch.zeros(x_0_recon.shape, device=self.device)
             # get denoising dist. param
-            mean, std, _ = self.forward(x_0_recon, t)
+            mean, std, _ = self.forward(x_0_recon, torch.full((x_0_recon.shape[0],), t ,device = self.device))
             # compute the drawn denoised latent at time t
             x_0_recon = mean + std * noise
         return x_0_recon 
@@ -355,7 +355,7 @@ class DDPM(nn.Module):
             else:
                 noise = torch.zeros(x_t_1.shape, device=self.device)
             # get denoising dist. param
-            mean, std, _ = self.forward(x_t_1, t)
+            mean, std, _ = self.forward(x_t_1, torch.full((x_t_1.shape[0],), t ,device = self.device))
             # compute the drawn densoined latent at time t
             x_t_1 = mean + std*noise
         return x_t_1
@@ -372,8 +372,8 @@ class DDPM(nn.Module):
         '''
         # start with an image of pure noise (batch_size 1) and store it as part of the output 
         x_t_1 = torch.randn((1,) + tuple(self.out_shape), device=self.device)
-        x = torch.empty((self.diffusion_steps+1,1,) + tuple(self.out_shape), device=self.device)
-        x[-1] = x_t_1
+        x = torch.empty((self.diffusion_steps+1,) + tuple(self.out_shape), device=self.device)
+        x[-1] = x_t_1.squeeze(0)
         # apply reverse trajectory
         for t in reversed(range(1, self.diffusion_steps+1)):
             # draw noise used in the denoising dist. reparametrization
@@ -382,14 +382,14 @@ class DDPM(nn.Module):
             else:
                 noise = torch.zeros(x_t_1.shape, device=self.device)
             # get denoising dist. param
-            mean, std, _ = self.forward(x_t_1, t)
+            mean, std, _ = self.forward(x_t_1, torch.full((x_t_1.shape[0],), t ,device = self.device))
             # compute the drawn densoined latent at time t
             x_t_1 = mean + std*noise
             # store noised image
-            x[t-1] = x_t_1
-        x_sq = x.squeeze(1)
-        return x_sq
-        #return x
+            x[t-1] = x_t_1.squeeze(0)
+        #x_sq = x.squeeze(1)
+        #return x_sq
+        return x
 
 
     # Loss functions
@@ -407,7 +407,7 @@ class DDPM(nn.Module):
         '''
         Returns the mathematically correct weighted version of the simplified loss.
         '''
-        return self.mse_weight[t-1].view(-1, 1, 1, 1)*F.mse_loss(forward_noise, pred_noise)
+        return self.mse_weight[t-1][:,None,None,None]*F.mse_loss(forward_noise, pred_noise)
     
     
     # If t=0 and self.recon_loss == 'nll'
diff --git a/trainer/train.py b/trainer/train.py
index 13e919158b0d6b272773c107d2ca0a614fa08fc1..80431b0c02342e1dfba41a4b47cb9afc13fa87ce 100644
--- a/trainer/train.py
+++ b/trainer/train.py
@@ -1,6 +1,6 @@
 
 import numpy as np
-
+import copy
 import torch 
 from torch import nn
 from torchvision import datasets,transforms
@@ -10,6 +10,7 @@ import numpy as np
 import torch.nn.functional as F
 import os
 import wandb 
+from copy import deepcopy
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -53,6 +54,37 @@ def simple_trainer(model,device,epochs,trainloader,testloader,bs,lr,T,criterion
         print(f"Testloss in step {epoch} :{np.mean(running_testloss)}")
 
 
+# EMA class 
+# Important! This EMA class code is not ours and was taken from the Pytorch Image Models library called timm and performs exponential moving 
+# average on the trained weights for a given models neural net which was suggested by the paper "Improved Denoising Diffusion Probabilistic Models" 
+# by Nichol and Dhariwal to stabilize and improve the training and generalization process.
+# https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/model_ema.py
+class ModelEmaV2(nn.Module):
+    def __init__(self, model, decay=0.9999, device=None):
+        super(ModelEmaV2, self).__init__()
+        # make a copy of the model for accumulating moving average of weights
+        self.module = deepcopy(model)
+        self.module.eval()
+        self.decay = decay
+        self.device = device  # perform ema on different device from model if set
+        if self.device is not None:
+            self.module.to(device=device)
+
+    def _update(self, model, update_fn):
+        with torch.no_grad():
+            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+                if self.device is not None:
+                    model_v = model_v.to(device=self.device)
+                ema_v.copy_(update_fn(ema_v, model_v))
+
+    def update(self, model):
+        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
+
+    def set(self, model):
+        self._update(model, update_fn=lambda e, m: m)
+
+
+
 # Training function for the unconditional diffusion model
 
 def ddpm_trainer(model, 
@@ -70,6 +102,8 @@ def ddpm_trainer(model,
                  experiment_path = None,
                  T_max = 5*10000, # None,
                  eta_min= 1e-5,
+                 ema_training = True,
+                 decay = 0.9999,
                  **args
                  ):
     '''
@@ -89,6 +123,8 @@ def ddpm_trainer(model,
     checkpoint:      Name of the saved pth. file containing the trained weights and biases
     T_max:           CosineAnnealingLR scheduler argument (nr of steps in training for a full cycle) 
     eta_min:         CosineAnnealingLR scheduler argument (scheduler oscillates between highest lr 'leraning_rate' and minimum lr 'eta_min')
+    decay:           EMA decay rate that is used to weight the effect of the ema model when computing the weighted avg between trained and
+                     ema weights for the networks weight update 
     '''
 
     # set optimizer parameters and learning rate
@@ -133,13 +169,17 @@ def ddpm_trainer(model,
     if model.recon_loss == 'nll':
         low = 0
              
+    # EMA
+    if ema_training:
+        ema = ModelEmaV2(model, decay=decay, device = model.device)
+
     # Using W&B
     with wandb.init(project='test-project', name=run_name, entity='gonzalomartingarcia0', id=run_name, resume=True) as run:
         
         # Log some info
         run.config.learning_rate = learning_rate
         run.config.optimizer = optimizer.__class__.__name__
-        run.watch(model.net)
+        #run.watch(model.net)
         
         # training loop
         # last model was stored at epoch last_epoch, we continue training from there, i.e. last_epoch+1 (else we start at epoch 0)
@@ -178,20 +218,14 @@ def ddpm_trainer(model,
 
                 loss.backward()
                 optimizer.step()
+                if ema_training:
+                    ema.update(model)
                 scheduler.step()
-                
+
             if verbose:
                 print(f"Loss in epoch {epoch}:{running_trainloss/nr_train_batches}")
-                run.log({'running_loss': running_trainloss/nr_train_batches})
+            run.log({'running_loss': running_trainloss/nr_train_batches})
 
-                # WORKING OLD VERSION
-                #x_t, forward_noise = model.forward_trajectory(x_0,t)
-                #_, _, pred_noise = model.forward(x_t,t)
-                #loss = loss_func(forward_noise,pred_noise,t)
-                #running_trainloss += loss.item()
-                #nr_train_batches += 1
-                #run.log({'loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})
-                
             # evaluation
             if ((epoch+1) % eval_iter == 0) or ((epoch+1) % store_iter == 0):
                 running_testloss = 0