diff --git a/experiment_creator.ipynb b/experiment_creator.ipynb
index c877ff2dac828a17f2b46056fd3f482ca2dd9de2..202045fdad5b4e1e674b59c27fd12c4383265ea6 100644
--- a/experiment_creator.ipynb
+++ b/experiment_creator.ipynb
@@ -46,8 +46,8 @@
     "datapath = \"/work/lect0100/lhq_256\"\n",
     "\n",
     "# Experiment setup\n",
-    "run_name = 'big_diffusers_fin' # WANDB and experiment folder Name!\n",
-    "checkpoint = 'model_epoch_8.pth' # Name of checkpoint pth file or None \n",
+    "run_name = 'batch_timesteps' # WANDB and experiment folder Name!\n",
+    "checkpoint = None #'model_epoch_8.pth' # Name of checkpoint pth file or None \n",
     "experiment_path = '/work/lect0100/experiments_gonzalo/'+ run_name +'/'\n",
     "\n",
     "# Path to save generated experiment folder on local machine\n",
@@ -55,16 +55,15 @@
     "\n",
     "# Diffusion Model Settings\n",
     "diffusion_steps = 200\n",
-    "image_size = 128\n",
+    "image_size = 64\n",
     "channels = 3\n",
     "\n",
     "# Training\n",
-    "batchsize = 8\n",
+    "batchsize = 32\n",
     "epochs = 30\n",
-    "store_iter = 3\n",
+    "store_iter = 1\n",
     "eval_iter = 500\n",
     "learning_rate = 0.0001\n",
-    "lr_schedule = False\n",
     "optimizername = \"torch.optim.AdamW\"\n",
     "optimizer_params = None\n",
     "verbose = True\n",
@@ -133,11 +132,12 @@
     "                optimizer_params = optimizer_params,\n",
     "                #optimizer_params=dict(lr=learning_rate), # don't change! \n",
     "                learning_rate = learning_rate,\n",
-    "                lr_schedule = lr_schedule,\n",
     "                run_name=run_name,\n",
     "                checkpoint= checkpoint,\n",
     "                experiment_path = experiment_path,\n",
     "                verbose = verbose,\n",
+    "                T_max = 5*10000, # cosine lr param\n",
+    "                eta_min= 1e-5, # cosine lr param\n",
     "                )\n",
     "sampling_setting = dict( \n",
     "                checkpoint = checkpoint, \n",
@@ -164,13 +164,13 @@
       "create folder\n",
       "folder created \n",
       "stored json files in folder\n",
-      "{'modelname': 'UNet_Unconditional_Diffusion_Bottleneck_Variant', 'dataset': 'UnconditionalDataset', 'framework': 'DDPM', 'trainloop_function': 'ddpm_trainer', 'sampling_function': 'ddpm_sampler', 'evaluation_function': 'ddpm_evaluator', 'batchsize': 8}\n",
-      "{'fpath': '/work/lect0100/lhq_256', 'img_size': 128, 'frac': 0.8, 'skip_first_n': 0, 'ext': '.png', 'transform': True}\n",
-      "{'channels_in': 3, 'channels_out': 3, 'activation': 'relu', 'weight_init': 'he', 'projection_features': 64, 'time_dim': 8, 'time_channels': 200, 'num_stages': 4, 'stage_list': None, 'num_blocks': 1, 'num_groupnorm_groups': 32, 'dropout': 0.1, 'attention_list': None, 'num_attention_heads': 1}\n",
-      "{'diffusion_steps': 200, 'out_shape': (3, 128, 128), 'noise_schedule': 'linear', 'beta_1': 0.0001, 'beta_T': 0.02, 'alpha_bar_lower_bound': 0.9, 'var_schedule': 'same', 'kl_loss': 'simplified', 'recon_loss': 'nll'}\n",
-      "{'epochs': 30, 'store_iter': 3, 'eval_iter': 500, 'optimizer_class': 'torch.optim.AdamW', 'optimizer_params': None, 'learning_rate': 0.0001, 'lr_schedule': False, 'run_name': 'big_diffusers_fin', 'checkpoint': 'model_epoch_8.pth', 'experiment_path': '/work/lect0100/experiments_gonzalo/big_diffusers_fin/', 'verbose': True}\n",
-      "{'checkpoint': 'model_epoch_8.pth', 'experiment_path': '/work/lect0100/experiments_gonzalo/big_diffusers_fin/', 'batch_size': 10, 'intermediate': False}\n",
-      "{'checkpoint': 'model_epoch_8.pth', 'experiment_path': '/work/lect0100/experiments_gonzalo/big_diffusers_fin/'}\n"
+      "{'modelname': 'UNet_Unconditional_Diffusion_Bottleneck_Variant', 'dataset': 'UnconditionalDataset', 'framework': 'DDPM', 'trainloop_function': 'ddpm_trainer', 'sampling_function': 'ddpm_sampler', 'evaluation_function': 'ddpm_evaluator', 'batchsize': 32}\n",
+      "{'fpath': '/work/lect0100/lhq_256', 'img_size': 64, 'frac': 0.8, 'skip_first_n': 0, 'ext': '.png', 'transform': True}\n",
+      "{'channels_in': 3, 'channels_out': 3, 'activation': 'relu', 'weight_init': 'he', 'projection_features': 64, 'time_dim': 32, 'time_channels': 200, 'num_stages': 4, 'stage_list': None, 'num_blocks': 1, 'num_groupnorm_groups': 32, 'dropout': 0.1, 'attention_list': None, 'num_attention_heads': 1}\n",
+      "{'diffusion_steps': 200, 'out_shape': (3, 64, 64), 'noise_schedule': 'linear', 'beta_1': 0.0001, 'beta_T': 0.02, 'alpha_bar_lower_bound': 0.9, 'var_schedule': 'same', 'kl_loss': 'simplified', 'recon_loss': 'nll'}\n",
+      "{'epochs': 30, 'store_iter': 1, 'eval_iter': 500, 'optimizer_class': 'torch.optim.AdamW', 'optimizer_params': None, 'learning_rate': 0.0001, 'run_name': 'batch_timesteps', 'checkpoint': None, 'experiment_path': '/work/lect0100/experiments_gonzalo/batch_timesteps/', 'verbose': True}\n",
+      "{'checkpoint': None, 'experiment_path': '/work/lect0100/experiments_gonzalo/batch_timesteps/', 'batch_size': 10, 'intermediate': False}\n",
+      "{'checkpoint': None, 'experiment_path': '/work/lect0100/experiments_gonzalo/batch_timesteps/'}\n"
      ]
     }
    ],
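Note: the two new notebook parameters `T_max` and `eta_min` are forwarded to a `torch.optim.lr_scheduler.CosineAnnealingLR` in the trainer (see trainer/train.py below). A minimal standalone sketch of how these values shape the learning rate; the dummy linear layer stands in for `model.net` and is illustrative only, the numeric values mirror the notebook:

# Sketch: how T_max and eta_min shape the cosine schedule (dummy model, values from the notebook).
import torch

dummy = torch.nn.Linear(1, 1)  # stand-in for model.net
optimizer = torch.optim.AdamW(dummy.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5 * 10000, eta_min=1e-5)

for step in range(3):
    optimizer.step()   # one training step
    scheduler.step()   # lr follows a cosine from 1e-4 down to eta_min over T_max scheduler steps
    print(step, scheduler.get_last_lr()[0])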
diff --git a/models/Framework.py b/models/Framework.py
index a25056c46a0a90b45e141c4c1c8d2ece97725c93..32ffa6c53f671fd9ce994f4b24641a39bbb81fee 100644
--- a/models/Framework.py
+++ b/models/Framework.py
@@ -68,9 +68,9 @@ class DDPM(nn.Module):
         self.recon_loss = recon_loss 
         self.out_shape = out_shape
         # precomputed for efficiency reasons
-        self.noise_scaler = (1-alpha)/( self.sqrt_1_minus_alpha_bar)
-        self.mean_scaler = 1/torch.sqrt(self.alpha)
-        self.mse_weight = (self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar))
+        self.noise_scaler = ((1-alpha)/( self.sqrt_1_minus_alpha_bar))
+        self.mean_scaler = (1/torch.sqrt(self.alpha))
+        self.mse_weight = ((self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar)))
           
     @staticmethod
     def linear_schedule(diffusion_steps, beta_1, beta_T, device):
@@ -191,8 +191,8 @@ class DDPM(nn.Module):
         '''
         if t is None:
             t = self.diffusion_steps
-        elif t == 0:
-            return x_0, torch.zeros(x_0.shape, device = self.device)
+        elif torch.any(t == 0):
+            raise ValueError("The tensor 't' contains a timestep zero.")
         forward_noise = torch.randn(x_0.shape, device = self.device)
         x_T = self.noised_latent(forward_noise, x_0, t) 
         return x_T , forward_noise
@@ -228,8 +228,8 @@ class DDPM(nn.Module):
         mean (tensor): Batch of means for the complete noise distribution for each image in the batch x_0
         std  (tensor): Batch of std scalars for the complete noise distribution for each image in the batch x_0
         '''
-        mean = self.sqrt_alpha_bar[t-1]*x_0
-        std = self.sqrt_1_minus_alpha_bar[t-1]
+        mean = self.sqrt_alpha_bar[t-1].view(-1, 1, 1, 1)*x_0
+        std = self.sqrt_1_minus_alpha_bar[t-1].view(-1, 1, 1, 1)
         return mean, std
     
     @torch.no_grad()
@@ -245,8 +245,8 @@ class DDPM(nn.Module):
         mean (tensor): Batch of means for the individual noise distribution for each image in the batch x_t_1
         std  (tensor): Batch of std scalars for the individual noise distribution for each image in the batch x_t_1
         '''
-        mean = torch.sqrt(1-self.beta[t-1])*x_t_1
-        std = torch.sqrt(self.beta[t-1])
+        mean = torch.sqrt(1-self.beta[t-1]).view(-1, 1, 1, 1)*x_t_1
+        std = torch.sqrt(self.beta[t-1]).view(-1, 1, 1, 1)
         return mean, std
 
 
@@ -288,8 +288,8 @@ class DDPM(nn.Module):
         pred_noise  (tensor): Predicted noise for each image in the batch x_t
         '''
         pred_noise = self.net(x_t,t,return_dict=False)[0]
-        mean = self.mean_scaler[t-1]*(x_t - self.noise_scaler[t-1]*pred_noise)
-        std = self.std[t-1]
+        mean = self.mean_scaler[t-1].view(-1, 1, 1, 1)*(x_t - self.noise_scaler[t-1].view(-1, 1, 1, 1)*pred_noise)
+        std = self.std[t-1].view(-1, 1, 1, 1)
         return mean, std, pred_noise
     
 
@@ -370,10 +370,9 @@ class DDPM(nn.Module):
         Returns:
         x (tensor): Contains the self.diffusion_steps+1 denoised image tensors
         '''
-        # start with an image of pure noise and store it as part of the output 
+        # start with an image of pure noise (batch_size 1) and store it as part of the output 
         x_t_1 = torch.randn((1,) + tuple(self.out_shape), device=self.device)
-        #x_t_1 = torch.randn(self.out_shape, device=self.device)
-        x = torch.empty((self.diffusion_steps+1,) + tuple(self.out_shape), device=self.device)
+        x = torch.empty((self.diffusion_steps+1,1,) + tuple(self.out_shape), device=self.device)
         x[-1] = x_t_1
         # apply reverse trajectory
         for t in reversed(range(1, self.diffusion_steps+1)):
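Note: with the explicit batch dimension added above, the stored denoising trajectory has shape (diffusion_steps+1, 1, C, H, W) instead of (diffusion_steps+1, C, H, W). A quick shape check under assumed values (the names below are illustrative, not the class attributes):

# Quick shape check for the trajectory buffer (assumed diffusion_steps and out_shape).
import torch

diffusion_steps, out_shape = 200, (3, 64, 64)
x_t_1 = torch.randn((1,) + out_shape)                       # pure-noise start, batch size 1
x = torch.empty((diffusion_steps + 1, 1) + out_shape)       # one slot per timestep, batch dim kept
x[-1] = x_t_1
print(x.shape)  # torch.Size([201, 1, 3, 64, 64])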
@@ -408,7 +407,7 @@ class DDPM(nn.Module):
         '''
         Returns the mathematically correct weighted version of the simplified loss.
         '''
-        return self.mse_weight[t-1]*F.mse_loss(forward_noise, pred_noise)
+        return self.mse_weight[t-1].view(-1, 1, 1, 1)*F.mse_loss(forward_noise, pred_noise)
     
     
     # If t=0 and self.recon_loss == 'nll'
@@ -418,3 +417,6 @@ class DDPM(nn.Module):
         denoising Gaussian distribution with mean mean_1 and standard deviation std_1.
         '''
         return -torch.distributions.Normal(mean_1, std_1).log_prob(x_0).mean()
+
+
+
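Note: the `.view(-1, 1, 1, 1)` edits above exist so that per-timestep coefficients indexed with a batched timestep tensor broadcast correctly against image batches. A minimal sketch of the pattern under assumed shapes (the names here are illustrative, not the `DDPM` attributes):

# Sketch of the broadcasting pattern used above (illustrative names, assumed shapes).
import torch

diffusion_steps, batch = 200, 4
sqrt_alpha_bar = torch.rand(diffusion_steps)          # per-timestep scalar coefficients
x_0 = torch.randn(batch, 3, 64, 64)                   # batch of images
t = torch.randint(low=1, high=diffusion_steps + 1, size=(batch,))

# Indexing with a batched t gives shape (batch,); .view(-1, 1, 1, 1)
# reshapes it to (batch, 1, 1, 1) so it broadcasts over (batch, C, H, W).
mean = sqrt_alpha_bar[t - 1].view(-1, 1, 1, 1) * x_0
print(mean.shape)  # torch.Size([4, 3, 64, 64])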
diff --git a/trainer/train.py b/trainer/train.py
index b043512f41040a443b40c2fe96d3e91dc48b4d3d..13e919158b0d6b272773c107d2ca0a614fa08fc1 100644
--- a/trainer/train.py
+++ b/trainer/train.py
@@ -64,11 +64,12 @@ def ddpm_trainer(model,
                  optimizer_class=torch.optim.AdamW, 
                  optimizer_params=None,
                  learning_rate = 0.001,
-                 lr_schedule = False,
                  verbose = False,
                  run_name=None,
                  checkpoint= None,
                  experiment_path = None,
+                 T_max = 5*10000, # None,
+                 eta_min= 1e-5,
                  **args
                  ):
     '''
@@ -79,23 +80,25 @@ def ddpm_trainer(model,
     epochs:          Number of epochs we train the model further. 
     optimizer_class: PyTorch optimizer.
     optimizer_param: Parameters for the PyTorch optimizer.
-    learning_rate:   For optimizer initialization or if a checkpoint exists for our manual learning rate.
-    lr_schedule:     If True, manually sets the learning rate of the optimizer to the given one.
    "+    learning_rate:   For optimizer initialization when training from scratch, i.e. when no checkpoint is given\n",
     verbose:         If True, prints the running losses for every epoch.
     run_name:        Run name for WandB. IF YOU TRAIN FROM CHECKPOINT MAKE SURE TO USE THE SAME
                      'run_name' FOR THE DATA TO BE LOGGED ON THE SAME WANDB RUN!
     trainloader:     Loads the train dataset
     testloader:      Loads the test dataset
     checkpoint:      Name of the saved pth. file containing the trained weights and biases
    "+    T_max:           CosineAnnealingLR scheduler argument (number of scheduler steps for one full cosine cycle)\n",
    "+    eta_min:         CosineAnnealingLR scheduler argument (the lr is annealed between the initial lr 'learning_rate' and the minimum lr 'eta_min')\n",
     '''
 
     # set optimizer parameters and learning rate
     if optimizer_params is None:
         optimizer_params = dict(lr=learning_rate)
-    else:
-        optimizer_params['lr'] = learning_rate
     optimizer = optimizer_class(model.net.parameters(), **optimizer_params)
     
+    # set cosine lr schedule (commonly used in diffusion models)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min) 
+
     # if checkpoint path is given, load the model from checkpoint
     last_epoch = -1
     if checkpoint:
@@ -112,10 +115,10 @@ def ddpm_trainer(model,
             # load optimizer state
             optimizer_state_dict = checkpoint['optimizer']
             optimizer.load_state_dict(optimizer_state_dict)
-            # If we want to decrease the learning rate manually
-            if lr_schedule:
-                for param_group in optimizer.param_groups:
-                    param_group['lr'] = learning_rate
+            # load learning rate schedule state
+            scheduler_state_dict = checkpoint['scheduler']
+            scheduler.load_state_dict(scheduler_state_dict)
+            scheduler.last_epoch = last_epoch
         except Exception as e:
             print("Error loading checkpoint. Exception: ", e)
             
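Note: because the cosine scheduler is stateful, its state_dict is now restored here alongside the optimizer state and written into the checkpoints further down. A minimal round-trip sketch, using a hypothetical 'checkpoint.pth' path and placeholder model/lr values:

# Sketch: the scheduler state travels with the optimizer state in the checkpoint dict.
import torch

net = torch.nn.Linear(1, 1)                                   # placeholder for model.net
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5 * 10000, eta_min=1e-5)

torch.save({'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()}, 'checkpoint.pth')

checkpoint = torch.load('checkpoint.pth')
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])   # resumes the cosine schedule where it stopped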
@@ -129,7 +132,7 @@ def ddpm_trainer(model,
     low = 1
     if model.recon_loss == 'nll':
         low = 0
-                 
+             
     # Using W&B
     with wandb.init(project='test-project', name=run_name, entity='gonzalomartingarcia0', id=run_name, resume=True) as run:
         
@@ -137,41 +140,57 @@ def ddpm_trainer(model,
         run.config.learning_rate = learning_rate
         run.config.optimizer = optimizer.__class__.__name__
         run.watch(model.net)
-        # log the learning rate in each run s.t. for checkpoint training we can see at what times the learning rate has been 
-        # manually stepped
-        wandb.log({"learning_rate": learning_rate})
         
         # training loop
         # last model was stored at epoch last_epoch, we continue training from there, i.e. last_epoch+1 (else we start at epoch 0)
         for epoch in range(last_epoch+1, (last_epoch+1)+epochs): 
             running_trainloss = 0
             nr_train_batches = 0
-            
+
             # train
             model.net.train()
             for idx,(x_0, _) in enumerate(trainloader):
                 x_0 = x_0.to(device)
-                t = torch.randint(low=low, high=model.diffusion_steps, size=(1,)).item()
+                t = torch.randint(low=low, high=model.diffusion_steps, size=(x_0.shape[0],), device = device)
                 optimizer.zero_grad()
+                                    
+                # Define masks for zero and non-zero elements of t
+                mask_zero_t = (t == 0)
+                mask_non_zero_t = (t != 0)
                 
-                if t>0:
-                    x_t, forward_noise = model.forward_trajectory(x_0,t)
-                    _, _, pred_noise = model.forward(x_t,t)
-                    loss = loss_func(forward_noise,pred_noise,t)
+                t[mask_zero_t] = 1 # map t=0 to t=1 so the forward trajectory stays defined; these samples use the reconstruction loss below
+                x_t, forward_noise = model.forward_trajectory(x_0,t)
+                mean, std, pred_noise = model.forward(x_t,t)                
+
+                loss = 0
+                # Compute kl loss
+                if torch.any(mask_non_zero_t):
+                    loss = loss_func(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t], t[mask_non_zero_t])
                     running_trainloss += loss.item()
                     nr_train_batches += 1
-                    run.log({'loss': loss.item(), 'epoch': epoch, 'batch': idx})
-                else: # reconstruction loss
-                    x_1, _ = model.forward_trajectory(x_0,1)
-                    mean_1, std_1, _ = model.forward(x_1,1)
-                    loss = model.loss_recon(x_0, mean_1, std_1)
-                    run.log({'recon_loss': loss.item(),'epoch': epoch, 'batch': idx})
-                    
-                loss.backward() 
+                    run.log({'loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})
+
+                # If reconstruction loss was drawn (t == 0)
+                if torch.any(mask_zero_t):
+                    recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t])
+                    loss += recon_loss
+                    run.log({'recon_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx})
+
+                loss.backward()
                 optimizer.step()
+                scheduler.step()
+                
             if verbose:
                 print(f"Loss in epoch {epoch}:{running_trainloss/nr_train_batches}")
-            run.log({'running_loss': running_trainloss/nr_train_batches})
+                run.log({'running_loss': running_trainloss/nr_train_batches})
+
                 
             # evaluation
             if ((epoch+1) % eval_iter == 0) or ((epoch+1) % store_iter == 0):
@@ -182,21 +201,30 @@ def ddpm_trainer(model,
                 with torch.no_grad():
                     for idx,(x_0,_) in enumerate(testloader):
                         x_0 = x_0.to(device)
-                        t = torch.randint(low=low,high=model.diffusion_steps, size=(1,)).item()
+                        t = torch.randint(low=low, high=model.diffusion_steps, size=(x_0.shape[0],), device = device)
                         
-                        if t>0:
-                            x_t, forward_noise = model.forward_trajectory(x_0,t)
-                            _, _, pred_noise = model.forward(x_t,t)
-                            loss = loss_func(forward_noise,pred_noise, t)
+                        # Define masks for zero and non-zero elements of t
+                        mask_zero_t = (t == 0)
+                        mask_non_zero_t = (t != 0)
+
+                        t[mask_zero_t] = 1 # map t=0 to t=1 so the forward trajectory stays defined
+                        x_t, forward_noise = model.forward_trajectory(x_0,t)
+                        mean, std, pred_noise = model.forward(x_t,t)
+
+                        loss = 0
+                        # Compute kl loss
+                        if torch.any(mask_non_zero_t):
+                            loss = loss_func(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t], t[mask_non_zero_t])
                             running_testloss += loss.item()
                             nr_test_batches += 1
-                            run.log({'test_loss': loss.item(), 'epoch': epoch, 'batch': idx})
-                        else: # reconstruction loss
-                            x_1, _ = model.forward_trajectory(x_0,1)
-                            mean_1, std_1, _ = model.forward(x_1,1)
-                            loss = model.loss_recon(x_0, mean_1, std_1)
-                            run.log({'recon_test_loss': loss.item(), 'epoch': epoch, 'batch': idx})
-                    
+                            run.log({'test_loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})
+
+                        # If reconstruction loss was drawn (t == 0)
+                        if torch.any(mask_zero_t):
+                            recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t])
+                            loss += recon_loss
+                            run.log({'recon_test_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx})
+
                     if verbose:
                         print(f"Test loss in epoch {epoch}:{running_testloss/nr_test_batches}")
                     run.log({'running_test_loss': running_testloss/nr_test_batches})
@@ -209,11 +237,12 @@ def ddpm_trainer(model,
                         'epoch': epoch,
                         'model': model.net.state_dict(),
                         'optimizer': optimizer.state_dict(),
+                        'scheduler': scheduler.state_dict(),
                         'running_loss': running_trainloss/nr_train_batches,
                         'running_test_loss': running_testloss/nr_test_batches,
                     }, os.path.join(save_dir, f"model_epoch_{epoch}.pth"))
                     
-        # always store the last version of the model
+        # always store the last version of the model if we trained through all epochs
         final = (last_epoch+1)+epochs
         save_dir = os.path.join(experiment_path, 'trained_ddpm/')
         os.makedirs(save_dir, exist_ok=True)
@@ -221,6 +250,9 @@ def ddpm_trainer(model,
             'epoch': final,
             'model': model.net.state_dict(),
             'optimizer': optimizer.state_dict(),
+            'scheduler': scheduler.state_dict(),
             'running_loss': running_trainloss/nr_train_batches,
             'running_test_loss': running_testloss/nr_test_batches,
          }, os.path.join(save_dir, f"model_epoch_{final}.pth"))
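Note: to summarize the batched-timestep logic introduced in the trainer, a self-contained sketch of the mask split between the noise-prediction loss (t > 0) and the reconstruction loss (t == 0). The stand-in losses below replace the trainer's loss_func and model.loss_recon and are assumptions, not the project's code:

# Self-contained sketch of the mask split used in the trainer (stand-in losses, assumed shapes).
import torch
import torch.nn.functional as F

batch, diffusion_steps = 8, 200
x_0 = torch.randn(batch, 3, 64, 64)
t = torch.randint(low=0, high=diffusion_steps, size=(batch,))

mask_zero_t = (t == 0)        # samples scored with the reconstruction loss
mask_non_zero_t = (t != 0)    # samples scored with the noise-prediction loss
t[mask_zero_t] = 1            # keep the forward trajectory defined for every sample

forward_noise = torch.randn_like(x_0)   # would come from model.forward_trajectory(x_0, t)
pred_noise = torch.randn_like(x_0)      # would come from model.forward(x_t, t)

loss = 0
if torch.any(mask_non_zero_t):
    # stand-in for loss_func(forward_noise[...], pred_noise[...], t[...])
    loss = F.mse_loss(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t])
if torch.any(mask_zero_t):
    # stand-in for model.loss_recon(x_0[...], mean[...], std[...])
    loss = loss + F.mse_loss(pred_noise[mask_zero_t], x_0[mask_zero_t])
print(loss)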
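For completeness, a usage note on the scheduler choice: CosineAnnealingLR is stepped once per optimizer step (not once per epoch) in the updated trainer, so T_max should be set relative to the total number of training batches; the 5*10000 default above is the value chosen in the notebook, not a library default.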