diff --git a/dataloader/load.py b/dataloader/load.py
index 2fb3593a714a6f92b5cf66dde4d3bec72f00a871..11c3e2d7ee0d640f27ecc2d19b3d3635dfc77d80 100644
--- a/dataloader/load.py
+++ b/dataloader/load.py
@@ -43,13 +43,18 @@ class UnconditionalDataset_LHQ(Dataset):
             intermediate_size = 150
             theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) #Check dataloading.ipynb in analysis-depot for more details
             
-            transform_rotate =  transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)),
-                                transforms.Resize(intermediate_size,antialias=True),
-                                transforms.RandomRotation(theta/np.pi*180,interpolation=transforms.InterpolationMode.BILINEAR),
-                                transforms.CenterCrop(img_size),transforms.RandomHorizontalFlip(p=0.5)])
+            transform_rotate =  transforms.Compose([transforms.ToTensor(),
+                                                    transforms.Resize(intermediate_size,antialias=True),
+                                                    transforms.RandomRotation(theta/np.pi*180,interpolation=transforms.InterpolationMode.BILINEAR),
+                                                    transforms.CenterCrop(img_size),
+                                                    transforms.RandomHorizontalFlip(p=0.5), 
+                                                    transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
             
-            transform_randomcrop  =  transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)),
-                                transforms.Resize(intermediate_size),transforms.RandomCrop(img_size),transforms.RandomHorizontalFlip(p=0.5)])
+            transform_randomcrop  =  transforms.Compose([transforms.ToTensor(),
+                                                         transforms.Resize(intermediate_size),
+                                                         transforms.RandomCrop(img_size),
+                                                         transforms.RandomHorizontalFlip(p=0.5),
+                                                         transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
 
             self.transform =  transforms.RandomChoice([transform_rotate,transform_randomcrop])
         else : 
@@ -100,16 +105,16 @@ class UnconditionalDataset_CelebAHQ(Dataset):
             theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) #Check dataloading.ipynb in analysis-depot for more details
              
             transform_rotate_flip =  transforms.Compose([transforms.ToTensor(),
-                                     transforms.Resize(intermediate_size,antialias=True),
-                                     transforms.RandomRotation((theta/np.pi*180),interpolation=transforms.InterpolationMode.BILINEAR),
-                                     transforms.CenterCrop(img_size),
-                                     transforms.RandomHorizontalFlip(p=0.5),
-                                     transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
+                                                         transforms.Resize(intermediate_size,antialias=True),
+                                                         transforms.RandomRotation((theta/np.pi*180),interpolation=transforms.InterpolationMode.BILINEAR),
+                                                         transforms.CenterCrop(img_size),
+                                                         transforms.RandomHorizontalFlip(p=0.5),
+                                                         transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
 
             transform_flip =  transforms.Compose([transforms.ToTensor(),
-                              transforms.Resize(img_size, antialias=True),
-                              transforms.RandomHorizontalFlip(p=0.5),
-                              transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
+                                                  transforms.Resize(img_size, antialias=True),
+                                                  transforms.RandomHorizontalFlip(p=0.5),
+                                                  transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
 
             self.transform =  transforms.RandomChoice([transform_rotate_flip,transform_flip])                                                   
         else :
diff --git a/main.py b/main.py
index e7d5ca4233a6e9805cfd3940f001486649bf1446..0618633fcbad881645933098121cb7a385edf615 100644
--- a/main.py
+++ b/main.py
@@ -6,12 +6,12 @@ from models.Framework import *
 from trainer.train import ddpm_trainer
 from evaluation.sample import ddpm_sampler
 from evaluation.evaluate import ddpm_evaluator
-from models.all_unets import *
+from models.unets import *
 import torch 
 
 
 def train_func(f):
-  #load all settings 
+  # Load Settings 
   device = 'cuda' if torch.cuda.is_available() else 'cpu'
   print(f"device: {device}\n\n")
   print(f"folderpath: {f}\n\n")
@@ -31,19 +31,17 @@ def train_func(f):
       training_setting = json.load(fp)
       training_setting["optimizer_class"] = eval(training_setting["optimizer_class"])
 
-
+  # init dataloaders
   batchsize = meta_setting["batchsize"]
   training_dataset = globals()[meta_setting["dataset"]](train = True,**dataset_setting)
   test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
-
   training_dataloader = torch.utils.data.DataLoader( training_dataset,batch_size=batchsize)
   test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize)
-  
-
+  # init UNet
   net = globals()[meta_setting["modelname"]](**model_setting).to(device)  
   #net = torch.compile(net)
   net = net.to(device)
-  
+  # init Diffusion Model
   framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)
   
   print(f"META SETTINGS:\n\n {meta_setting}\n\n")
@@ -77,9 +75,8 @@ def sample_func(f):
 
   # init Unet
   net = globals()[meta_setting["modelname"]](**model_setting).to(device)
-  #net = torch.compile(net)
   net = net.to(device)
-  # init unconditional diffusion model
+  # init Diffusion Model
   framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)
 
   print(f"META SETTINGS:\n\n {meta_setting}\n\n")
diff --git a/models/all_unets.py b/models/unets.py
similarity index 100%
rename from models/all_unets.py
rename to models/unets.py
diff --git a/trainer/train.py b/trainer/train.py
index 5367dcb848f9c29c3d477012660e898f09535b1f..dd882fe8a27c1cec3f3df2f9f5005d57641ca2c2 100644
--- a/trainer/train.py
+++ b/trainer/train.py
@@ -85,7 +85,6 @@ class ModelEmaV2(nn.Module):
 
 
 # Training function for the unconditional diffusion model
-
 def ddpm_trainer(model, 
                  device,
                  trainloader, testloader,
@@ -157,7 +156,7 @@ def ddpm_trainer(model,
             # load learning rate schedule state
             scheduler_state_dict = checkpoint['scheduler']
             scheduler.load_state_dict(scheduler_state_dict)
-            scheduler.last_epoch = last_epoch
+            scheduler.last_epoch = (last_epoch+1)*len(trainloader)
             # load ema model state
             if ema_training:
                 ema.module.load_state_dict(checkpoint['ema'])