From 5dd1bc7be6ec1b24cc9084e15e40fe1d6589357f Mon Sep 17 00:00:00 2001
From: gonzalomartingarcia0 <gonzalomartingarcia0@gmail.com>
Date: Sat, 22 Jul 2023 18:45:32 +0200
Subject: [PATCH] Combined both versions of the conditional diffusion model;
 it now supports classifier-free guided, class-conditional image generation
 as well as inpainting.

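The AFHQ variant (ConditionalDataset_AFHQ_Class / cdm_sampler_afhq_class) generates
images conditioned on a class label, optionally with classifier-free guidance. The
LHQ variant (ConditionalDataset_LHQ_Paint / cdm_sampler_lhq_paint) inpaints test
images, conditioning on a masked copy with a randomly placed blacked-out rectangle.

With the new class_free_guidence flag enabled, training randomly drops the class
label for roughly 10% of each batch and sampling mixes the conditional and
unconditional noise predictions via torch.lerp, i.e.

    pred_noise = pred_noise_uncond + guidance_score * (pred_noise_cond - pred_noise_uncond)

The conditional U-Net takes a new 'cond' argument: 'class' adds a class embedding
to the time embedding, while 'image' concatenates the conditioning image
channel-wise with the input latent.
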
---
 .gitignore                          |   3 +-
 dataloader/load.py                  | 108 +++++++++++++++++++++----
 evaluation/sample.py                | 120 ++++++++++++++++++++++++----
 main.py                             |  44 +++++-----
 models/ConditionalDiffusionModel.py |  71 +++++++++-------
 models/conditional_unet.py          |  38 ++++++---
 trainer/train.py                    |  97 +++++++++++-----------
 7 files changed, 335 insertions(+), 146 deletions(-)

diff --git a/.gitignore b/.gitignore
index 339aa3e..df3b0f2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,8 @@
 .DS_Store
 */__pycache__
-*/trained_ddpm
+*/trained_cdm
 root
+.ipynb_checkpoints
 experiments
 trainer/__pycache__
 wandb
diff --git a/dataloader/load.py b/dataloader/load.py
index a9ce93c..9db88f4 100644
--- a/dataloader/load.py
+++ b/dataloader/load.py
@@ -6,18 +6,16 @@ from PIL import Image
 import pandas as pd
 import numpy as np
 
-class ConditionalDataset(Dataset):
+class ConditionalDataset_AFHQ_Class(Dataset):
     def __init__(self,fpath,img_size,train,frac =0.8,skip_first_n = 0,ext = ".png",transform=True ):
         """
         Args:
-            fpath (string): Path to the folder where images are stored
-            img_size (int): size of output image img_size=height=width
-            ext (string): type of images used(eg .png)
+            fpath (string):   Path to the folder where images are stored
+            img_size (int):   Size of output image img_size=height=width
+            ext (string):     Type of images used (e.g. .png)
             transform (Bool): Image augmentation for diffusion model
-            skip_first_n: skips the first n values. Usefull for datasets that are sorted by increasing Likeliehood 
-            train (Bool): Choose dataset to be either train set or test set. frac(float) required 
-            frac (float): value within (0,1] (seeded)random shuffles dataset, then divides into train and test set. 
-                            """        
+            train (Bool):     Choose dataset to be either the train set or the test set; frac (float) required 
+        """        
         if train:
             fpath = os.path.join(fpath, 'train') 
         else:
@@ -32,23 +30,19 @@ class ConditionalDataset(Dataset):
                 if name.endswith(ext):
                     file_list.append(os.path.join(root, name))
                     class_list.append(self.class_to_idx[os.path.basename(root)])
-        #df = pd.DataFrame({"Filepath":file_list},)
-        #self.df = df[df["Filepath"].str.endswith(ext)]
         self.df = pd.DataFrame({"Filepath": file_list})
         self.class_list = class_list     
         
         if transform: 
-            # for training
             intermediate_size = 137
-            theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) #Check dataloading.ipynb in analysis-depot for more details
-             
+            theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size))
+            
             transform_rotate_flip =  transforms.Compose([transforms.ToTensor(),
                                      transforms.Resize(intermediate_size,antialias=True),
                                      transforms.RandomRotation((theta/np.pi*180),interpolation=transforms.InterpolationMode.BILINEAR),
                                      transforms.CenterCrop(img_size),
                                      transforms.RandomHorizontalFlip(p=0.5),
                                      transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
-
             
             transform_flip =  transforms.Compose([transforms.ToTensor(),
                               transforms.Resize(img_size, antialias=True),
@@ -57,10 +51,9 @@ class ConditionalDataset(Dataset):
 
             self.transform =  transforms.RandomChoice([transform_rotate_flip,transform_flip])                                                   
         else :
-            # for evaluation 
             self.transform =  transforms.Compose([transforms.ToTensor(),
-                                transforms.Lambda(lambda x: (x * 255).type(torch.uint8)),
                                 transforms.Resize(img_size)])
+
  
     def __len__(self):
         return len(self.df)
@@ -74,3 +67,86 @@ class ConditionalDataset(Dataset):
     def tensor2PIL(self,img):
         back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()])
         return back2pil(img)
+    
+    
+class ConditionalDataset_LHQ_Paint(Dataset):
+    def __init__(self,fpath,img_size,train,frac =0.8,skip_first_n = 0,ext = ".png",transform=True ):
+        """
+        Args:
+            fpath (string):   Path to the folder where images are stored
+            img_size (int):   Size of output image img_size=height=width
+            ext (string):     Type of images used (e.g. .png)
+            transform (Bool): Image augmentation for diffusion model
+            train (Bool):     Choose dataset to be either the train set or the test set; frac (float) required 
+            frac (float):     Value within (0,1]; the dataset is (seeded) randomly shuffled and then split into train and test sets. 
+        """        
+     
+        ### Create DataFrame
+        file_list = []
+        for root, dirs, files in os.walk(fpath, topdown=False):
+            for name in sorted(files):
+                file_list.append(os.path.join(root, name))
+
+        df = pd.DataFrame({"Filepath":file_list},)
+        self.df = df[df["Filepath"].str.endswith(ext)] 
+        
+        
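+        # seeded split: the fixed random_state keeps the train and test subsets reproducible and disjoint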
+        if train: 
+            df_train = self.df.sample(frac=frac,random_state=2)
+            self.df = df_train
+        else:
+            df_train = self.df.sample(frac=frac,random_state=2)
+            df_test = self.df.drop(df_train.index)
+            self.df = df_test
+
+        if transform: 
+            intermediate_size = 150
+            theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size))
+            
+            transform_rotate =  transforms.Compose([transforms.ToTensor(),
+                                                    transforms.Resize(intermediate_size,antialias=True),
+                                                    transforms.RandomRotation(theta/np.pi*180,interpolation=transforms.InterpolationMode.BILINEAR),
+                                                    transforms.CenterCrop(img_size),
+                                                    transforms.RandomHorizontalFlip(p=0.5),
+                                                    transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
+            
+            transform_randomcrop  =  transforms.Compose([transforms.ToTensor(),
+                                                         transforms.Resize(intermediate_size),
+                                                         transforms.RandomCrop(img_size),
+                                                         transforms.RandomHorizontalFlip(p=0.5), 
+                                                         transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
+
+            self.transform =  transforms.RandomChoice([transform_rotate,transform_randomcrop])                                                  
+        else :
+            self.transform =  transforms.Compose([transforms.ToTensor(),
+                                transforms.Resize(img_size)])
+
+ 
+    def __len__(self):
+        return len(self.df)
+    
+    def __getitem__(self,idx):
+        # get image
+        path =  self.df.iloc[idx].Filepath
+        img = Image.open(path)
+        # apply transformation
+        img_tensor = self.transform(img)
+        # draw random rectangle
+        min_height = 30
+        min_width = 30
+        max_x = img_tensor.shape[1] - min_height
+        max_y = img_tensor.shape[2] - min_width
+        x = np.random.randint(0, max_x)  
+        y = np.random.randint(0, max_y)
+        max_height = img_tensor.shape[1] - x
+        max_width  = img_tensor.shape[2] - y 
+        rect_height = torch.randint(min_height, max_height, (1,)).item()
+        rect_width  = torch.randint(min_width, max_width, (1,)).item() 
+        # create copy of image and add blacked out rectangle
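+        # a value of -1 corresponds to black in the [-1, 1] range produced by Normalize(mean=0.5, std=0.5), since (0 - 0.5)/0.5 = -1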
+        masked_img = img_tensor.clone()
+        masked_img[:, x:x+rect_height, y:y+rect_width] = -1 
+        return img_tensor, masked_img
+        
+    def tensor2PIL(self,img):
+        back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()])
+        return back2pil(img)
diff --git a/evaluation/sample.py b/evaluation/sample.py
index 4d98deb..0d9af73 100644
--- a/evaluation/sample.py
+++ b/evaluation/sample.py
@@ -3,25 +3,26 @@ import torch
 from torchvision import transforms
 import re
 
-def cdm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False,n_times=1):
+def cdm_sampler_afhq_class(model, checkpoint, experiment_path, dataloader, device, intermediate=False, batch_size=15,sample_all=False,n_times=1):
     '''
-    Samples a tensor of 'batch_size' images from a trained diffusion model with 'checkpoint'. The generated 
+    Samples a tensor of 'n_times'*'batch_size' images from a trained diffusion model with 'checkpoint'. The generated 
     images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}. Where e is the epoch 
-    w.r.t. the model which we are sampling form and j is an index separating images from each call of the 
-    sampling function for the given mode.
+    w.r.t. the model which we are sampling from and j is an index separating the sampled batches for every call of this 
+    sampling function.
   
-    model:             Diffusion model
-    checkpoint:        Name of the saved pth. file containing the trained weights and biases
-    conditioning_path: Path to the conditioning input tensors from whihc the samples will be generated from 
-    experiment_path:   Path to the experiment directory where the samples will saved under the diectory samples
-    batch_size:        The number of images to sample
-    intermediate:      Bool value. If False the sampling function will draw a batch of images, else it will just 
-                       sample a single image, but store all the intermediate noised latents along the reverse chain
-    sample_all:        If True, samples a batch of images for the given model at every stored checkpoint at once
-    n_times:           Integer denoting how many times we draw a batch of 'batch_size'. If we want to draw 10k images
-                       the GPU will draw batches of 512 images 20 times to reach this goal.  
+    model:           Diffusion model
+    checkpoint:      Name of the saved .pth file containing the trained weights and biases
+    experiment_path: Path to the experiment directory where the samples will be saved under the directory 'samples'
+    batch_size:      The number of images the model samples
+    dataloader:      Dataloader used to draw the testing conditioning batches 
+    intermediate:    Bool value. If False the sampling function will draw a batch of images, else it will just 
+                     sample a single image, but store all the intermediate noised latents along the reverse chain
+    sample_all:      If True, samples a batch of images for the given model at every stored checkpoint at once
+    n_times:         Integer denoting how many times we draw a batch of 'batch_size'. If we want to draw 10k images
+                     the GPU will draw batches of 512 images 20 times to reach this goal.  
     '''
 
+
     # If we want to sample from every checkpoint of the current model, recursively call this function for all checkpoints
     if sample_all:
         f = f'{experiment_path}trained_cdm/'
@@ -44,7 +45,6 @@ def cdm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
 
     # create samples directory for the complete experiment (if first time sampling images)
     output_dir = f'{experiment_path}samples/'
-    #output_dir = os.path.join(experiment_path,'/samples/')
     os.makedirs(output_dir, exist_ok=True)
 
     # create sample directory for the current version of the trained model
@@ -53,8 +53,7 @@ def cdm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
     if epoch:
         e = int(epoch[0])
     else:
-        #raise ValueError(f"No digit found in the filename: {filename}")
-        e = 0
+        raise ValueError("No digit found in the checkpoint filename")
     model_dir = os.path.join(output_dir,f'epoch_{e}')
     os.makedirs(model_dir, exist_ok=True)
 
@@ -86,3 +85,90 @@ def cdm_sampler(model, checkpoint, experiment_path, device, intermediate=False,
             except Exception as e:
                 print("Error saving image. Exception:", e)
             run_indx += 1
+            
+            
+def cdm_sampler_lhq_paint(model, checkpoint, experiment_path, dataloader, device, intermediate=False, batch_size=15,sample_all=False,n_times=1):
+    '''
+    Samples a tensor of 'n_times'*'batch_size' images from a trained diffusion model with 'checkpoint'. The generated 
+    images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}', where e is the epoch 
+    w.r.t. the model which we are sampling from and j is an index separating the sampled batches for every call of this 
+    sampling function.
+  
+    model:           Diffusion model
+    checkpoint:      Name of the saved .pth file containing the trained weights and biases
+    experiment_path: Path to the experiment directory where the samples will be saved under the directory 'samples'
+    dataloader:      Dataloader from which we draw testing images with black rectangles to inpaint on
+    batch_size:      The number of images the model samples
+    intermediate:    Bool value. If False the sampling function will draw a batch of images, else it will just 
+                     sample a single image, but store all the intermediate noised latents along the reverse chain
+    sample_all:      If True, samples a batch of images for the given model at every stored checkpoint at once
+    n_times:         Integer denoting how many times we draw a batch of 'batch_size'. If we want to draw 10k images
+                     the GPU will draw batches of 512 images 20 times to reach this goal.  
+    '''
+
+
+    # If we want to sample from every checkpoint of the current model, recursively call this function for all checkpoints
+    if sample_all:
+        f = f'{experiment_path}trained_cdm/'
+        checkpoint_list = [checkpoint_i for checkpoint_i in os.listdir(f) if checkpoint_i.endswith(".pth")]
+        for checkpoint_i in os.listdir(f):
+            if checkpoint_i.endswith(".pth"):
+                cdm_sampler_lhq_paint(model, checkpoint_i, experiment_path, dataloader, device, batch_size=batch_size, intermediate=intermediate, sample_all=False)
+        return 0
+
+    # load model
+    try:
+        checkpoint_path = f'{experiment_path}trained_cdm/{checkpoint}'
+        checkpoint = torch.load(checkpoint_path)
+        # load weights and biases of the U-Net
+        net_state_dict = checkpoint['model']
+        model.net.load_state_dict(net_state_dict)
+        model = model.to(device)
+    except Exception as e:
+        print("Error loading checkpoint. Exception:", e)
+
+    # create samples directory for the complete experiment
+    output_dir = f'{experiment_path}samples/'
+    os.makedirs(output_dir, exist_ok=True)
+
+    # create sample directory for the current version of the trained model
+    model_name = os.path.basename(checkpoint_path)
+    epoch = re.findall(r'\d+', model_name)
+    if epoch:
+        e = int(epoch[0])
+    else:
+        raise ValueError("No digit found in the checkpoint filename")
+    model_dir = os.path.join(output_dir,f'epoch_{e}')
+    os.makedirs(model_dir, exist_ok=True)
+
+    # create the sample directory for this sampling run for the current version of the model    
+    sample_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))]
+    indx_list = [int(d.split('_')[1]) for d in sample_dir_list if d.startswith('sample_')]
+    j = max(indx_list, default=-1) + 1
+    sample_dir = os.path.join(model_dir, f'sample_{j}')
+    os.makedirs(sample_dir, exist_ok=True)
+
+    # transform
+    back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()])
+
+    # sample inpainted images conditioned on the masked test images, then transform to PIL and save
+    run_indx=0
+    for idx,(_, y) in enumerate(dataloader):
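+        # y is the batch of masked test images (second element of the dataset tuple); the model inpaints the blacked-out regions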
+        y = y.to(device)
+        generated = model.sample(y=y, batch_size=y.size(0))
+        #torch.save(generated,os.path.join(sample_dir, f"image_tensor{idx}"))
+        # save images
+        for i in range(generated.size(0)):
+            image = back2pil(generated[i])
+            image_path = os.path.join(sample_dir, f'sample_{j}_{run_indx}.png')
+            image_raw = back2pil(y[i])
+            image_path_raw = os.path.join(sample_dir, f'raw_{j}_{run_indx}.png')
+            try:
+                image.save(image_path)
+                image_raw.save(image_path_raw)
+            except Exception as e:
+                print("Error saving image. Exception:", e)
+            run_indx += 1
+        if idx+1 >= n_times:
+            break
diff --git a/main.py b/main.py
index 701fa00..e24ad8e 100644
--- a/main.py
+++ b/main.py
@@ -3,13 +3,12 @@ import json
 import sys
 from dataloader.load import  *
 from models.ConditionalDiffusionModel import *
-from trainer.train import cdm_trainer
-from evaluation.sample import cdm_sampler
-from evaluation.evaluate import cdm_evaluator
+from trainer.train import *
+from evaluation.sample import *
+from evaluation.evaluate import *
 from models.conditional_unet import *
 import torch 
 
-
 def train_func(f):
   #load all settings 
   device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -31,19 +30,16 @@ def train_func(f):
       training_setting = json.load(fp)
       training_setting["optimizer_class"] = eval(training_setting["optimizer_class"])
 
-
+  # init dataloaders
   batchsize = meta_setting["batchsize"]
   training_dataset = globals()[meta_setting["dataset"]](train = True,**dataset_setting)
   test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
-
   training_dataloader = torch.utils.data.DataLoader(training_dataset,batch_size=batchsize,shuffle=True)
   test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize,shuffle=True)
-  
-
+  # init UNet
   net = globals()[meta_setting["modelname"]](**model_setting).to(device)  
-  #net = torch.compile(net)
   net = net.to(device)
-  
+  # init Diffusion Model 
   framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)
   
   print(f"META SETTINGS:\n\n {meta_setting}\n\n")
@@ -74,12 +70,18 @@ def sample_func(f):
 
   with open(f+"/sampling_setting.json","r") as fp:
       sampling_setting = json.load(fp)
-
-  # init Unet
+        
+  with open(f+"/dataset_setting.json","r") as fp:
+      dataset_setting = json.load(fp)    
+  
+  # init dataloader
+  batchsize = sampling_setting["batch_size"]
+  test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
+  test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize,shuffle=True)
+  # init UNet
   net = globals()[meta_setting["modelname"]](**model_setting).to(device)
-  #net = torch.compile(net)
   net = net.to(device)
-  # init unconditional diffusion model
+  # init Diffusion model
   framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)
 
   print(f"META SETTINGS:\n\n {meta_setting}\n\n")
@@ -88,7 +90,7 @@ def sample_func(f):
   print(f"SAMPLING SETTINGS:\n\n {sampling_setting}\n\n")
 
   print("\n\nSTART SAMPLING\n\n")
-  globals()[meta_setting["sampling_function"]](model=framework,device=device ,**sampling_setting,)
+  globals()[meta_setting["sampling_function"]](model=framework,device=device ,dataloader=test_dataloader, **sampling_setting,)
   print("\n\nFINISHED SAMPLING\n\n")
 
 
@@ -113,18 +115,15 @@ def evaluate_func(f):
   with open(f+"/dataset_setting.json","r") as fp:
       dataset_setting = json.load(fp)
 
-  # load dataset
+  # init dataloaders
   batchsize = meta_setting["batchsize"]
   test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
   #test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=len(test_dataset),  shuffle=False)
   test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize,  shuffle=False)
-
-  # init Unet
+  # init UNet
   net = globals()[meta_setting["modelname"]](**model_setting).to(device)
-  #net = torch.compile(net)
   net = net.to(device)
-  
-  # init unconditional diffusion model
+  # init Diffusion Model
   framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)
 
   print(f"META SETTINGS:\n\n {meta_setting}\n\n")
@@ -140,9 +139,8 @@ def evaluate_func(f):
 
 
 
-  
+# run training, sampling or evaluation
 if __name__ == '__main__':
-    
   
   print(sys.argv)
   functions = {'train': train_func,'sample': sample_func,'evaluate': evaluate_func}
diff --git a/models/ConditionalDiffusionModel.py b/models/ConditionalDiffusionModel.py
index 4dbf5cb..24c70b5 100644
--- a/models/ConditionalDiffusionModel.py
+++ b/models/ConditionalDiffusionModel.py
@@ -9,7 +9,6 @@ class CDM(nn.Module):
                  net=None,
                  diffusion_steps = 50, 
                  out_shape = (3,128,128),
-                 conditional_shape = (3),
                  noise_schedule = 'linear', 
                  beta_1 = 1e-4, 
                  beta_T = 0.02,
@@ -17,19 +16,20 @@ class CDM(nn.Module):
                  var_schedule = 'same', 
                  kl_loss = 'simplified', 
                  recon_loss = 'none',
+                 class_free_guidence = False,
                  guidance_score = 3, # for classifier-free guided diffusion
                  device=None):
         '''
         net:                   U-Net
         diffusion_steps:       Length of the Markov chain
         out_shape:             Shape of the models's in- and output images
-        conditional_shape:     Shape of the low resolution image the DM is conditioned on for Super Resolution
-        noise_schedule:        Methods of initialization for the noise dist. variances, 'linear', 'cosine' or bounded_cosine
+        noise_schedule:        Methods of initialization for the noise dist. variances, 'linear', 'cosine' or 'bounded_cosine'
         beta_1, beta_T:        Variances for the first and last noise dist. (only for the 'linear' noise schedule)
         alpha_bar_lower_bound: Upper bound for the varaince of the complete noise dist. (only for the 'cosine_bounded' noise schedule)
         var_schedule:          Options to initialize or learn the denoising dist. variances, 'same', 'true'
         kl_loss:               Choice between the mathematically correct 'weighted' or in practice most commonly used 'simplified' KL loss
         recon_loss:            Is 'none' to ignore the reconstruction loss or 'nll' to compute the negative log likelihood
+        class_free_guidence:   Boolean flag that indicates if classifier-free guided diffusion is used for training and sampling
         '''
         super(CDM,self).__init__()
         self.device = device
@@ -59,6 +59,7 @@ class CDM(nn.Module):
         
         self.net = net
         self.guidance_score = guidance_score
+        self.class_free_guidence = class_free_guidence
         self.diffusion_steps = diffusion_steps
         self.noise_schedule = noise_schedule
         self.var_schedule = var_schedule
@@ -72,7 +73,6 @@ class CDM(nn.Module):
         self.kl_loss = kl_loss 
         self.recon_loss = recon_loss 
         self.out_shape = out_shape
-        self.conditional_shape = conditional_shape
         # precomputed for efficiency reasons
         self.noise_scaler = (1-alpha)/( self.sqrt_1_minus_alpha_bar)
         self.mean_scaler = 1/torch.sqrt(self.alpha)
@@ -273,7 +273,7 @@ class CDM(nn.Module):
         '''
         noise = torch.randn(x_t.shape, device=self.device)
         pred_noise = self.forward(x_t, t, y)
-        mean, std = self.reverse_dist_param(x_t, pred_noise, t) #self.forward(x_t, t, c)        
+        mean, std = self.reverse_dist_param(x_t, pred_noise, t)        
         x_t_1 = mean + std*noise
         return x_t_1
     
@@ -292,9 +292,9 @@ class CDM(nn.Module):
         y (tensor):   Batch of conditional information for each input image
         
         Returns:
-        mean        (tensor): Batch of means for the complete noise dist. for each image in the batch x_t
-        std         (tensor): Batch of std scalars for the complete noise dist. for each image in the batch x_t
-        pred_noise  (tensor): Predicted noise for each image in the batch x_t
+        mean        (tensor): Batch of means for the complete noise dist. for each image in the batch x_t under y
+        std         (tensor): Batch of std scalars for the complete noise dist. for each image in the batch x_t under y
+        pred_noise  (tensor): Predicted noise for each image in the batch x_t under y
         '''
         pred_noise = self.net(x_t,t, y)
         return pred_noise
@@ -308,8 +308,8 @@ class CDM(nn.Module):
         t    (tensor): Batch of timesteps
     
         Returns:
-        mean (tensor): Batch of means for the complete noise distribution for each image in the batch x_0
-        std  (tensor): Batch of std scalars for the complete noise distribution for each image in the batch x_0
+        mean (tensor): Batch of means for the complete noise distribution for each image in the batch x_0 under y
+        std  (tensor): Batch of std scalars for the complete noise distribution for each image in the batch x_0 under y
         '''
         mean = self.mean_scaler[t-1][:,None,None,None]*(x_t - self.noise_scaler[t-1][:,None,None,None]*pred_noise)
         std = self.std[t-1][:,None,None,None]
@@ -331,7 +331,7 @@ class CDM(nn.Module):
         y (tensor):   Batch of conditional information for each input image
 
         Returns:
-        x_0_recon (tensor): Batch of images given by the model reconstruction of x_0
+        x_0_recon (tensor): Batch of images given by the model reconstruction of x_0 under y
         '''
         # apply forward trajectory
         x_0_recon, _ = self.forward_trajectory(x_0)
@@ -345,7 +345,13 @@ class CDM(nn.Module):
             # set timestep batch with all entries as t    
             t_batch = torch.full((x_0_recon.shape[0],), t ,device = self.device)
             # get denoising dist. param
-            pred_noise = self.forward(x_0_recon, t_batch, y)
+            if self.class_free_guidence:
+                # get classifier-free guided diffusion noise parameter
+                pred_noise_cond = self.forward(x_0_recon, t_batch, y) # param with conditioning
+                pred_noise_uncond = self.forward(x_0_recon, t_batch, y=None) # param without conditioning
+                pred_noise = torch.lerp(pred_noise_uncond, pred_noise_cond, self.guidance_score) # linear interpolation of the two
+            else:
+                pred_noise = self.forward(x_0_recon, t_batch, y)
             mean, std = self.reverse_dist_param(x_0_recon, pred_noise, t_batch)
             # compute the drawn denoised latent at time t
             x_0_recon = mean + std * noise
@@ -365,12 +371,12 @@ class CDM(nn.Module):
         batch_size (int): Number of images to be sampled/generated from the diffusion model 
         x_T     (tensor): Input of the reverse trajectory. Batch of noised images usually drawn 
                           from an isotropic Gaussian, but can be set manually if desired.
-        y (tensor):   Batch of conditional information for each input image
+        y (tensor):       Batch of conditional information for each input image
 
         Returns:
-        x_t_1 (tensor): Batch of sampled/generated images
+        x_t_1 (tensor): Batch of sampled/generated images under y
         '''
-        # start with a batch of isotropic noise images (or given arguemnt)
+        # start with a batch of isotropic noise images (or given argument x_T)
         if x_T:
             x_t_1 = x_T
         else:
@@ -382,15 +388,17 @@ class CDM(nn.Module):
                 noise = torch.randn(x_t_1.shape, device=self.device)
             else:
                 noise = torch.zeros(x_t_1.shape, device=self.device)
-            # set timestep batch with all entries as t
+            # get denoising dist. param
             t_batch = torch.full((x_t_1.shape[0],), t ,device = self.device)
-            # get classififer-free guided diffusion noise parameter
-            pred_noise_cond = self.forward(x_t_1, t_batch, y) # param with conditioning
-            pred_noise_uncond = self.forward(x_t_1, t_batch, y=None) # param without conditioning
-            # linear interpolation of the two
-            pred_noise = torch.lerp(pred_noise_uncond, pred_noise_cond, self.guidance_score) 
+            if self.class_free_guidence:
+                # get classifier-free guided diffusion noise parameter
+                pred_noise_cond = self.forward(x_t_1, t_batch, y) # param with conditioning
+                pred_noise_uncond = self.forward(x_t_1, t_batch, y=None) # param without conditioning
+                pred_noise = torch.lerp(pred_noise_uncond, pred_noise_cond, self.guidance_score) # linear interpolation of the two
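+                # torch.lerp(a, b, w) = a + w*(b - a), so with guidance_score > 1 the prediction is
+                # pushed past the conditional estimate, away from the unconditional one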
+            else:
+                pred_noise = self.forward(x_t_1, t_batch, y)
             mean, std = self.reverse_dist_param(x_t_1, pred_noise, t_batch)
-            # compute the drawn densoined latent at time t
+            # compute the drawn denoised latent at time t
             x_t_1 = mean + std*noise
         return x_t_1
 
@@ -418,20 +426,21 @@ class CDM(nn.Module):
                 noise = torch.randn(x_t_1.shape, device=self.device)
             else:
                 noise = torch.zeros(x_t_1.shape, device=self.device)
-            # set timestep batch with all entries as t
+            # get denoising dist. param
             t_batch = torch.full((x_t_1.shape[0],), t ,device = self.device)
-            # get classififer-free guided diffusion noise parameter
-            pred_noise_cond = self.forward(x_t_1, t_batch, y) # param with conditioning
-            pred_noise_uncond = self.forward(x_t_1, t_batch, y=None) # param without conditioning
-            # linear interpolation of the two
-            pred_noise = torch.lerp(pred_noise_uncond, pred_noise_cond, self.guidance_score)
+            if self.class_free_guidence: 
+                # get classifier-free guided diffusion noise parameter
+                pred_noise_cond = self.forward(x_t_1, t_batch, y) # param with conditioning
+                pred_noise_uncond = self.forward(x_t_1, t_batch, y=None) # param without conditioning
+                # linear interpolation of the two
+                pred_noise = torch.lerp(pred_noise_uncond, pred_noise_cond, self.guidance_score)
+            else:
+                pred_noise = self.forward(x_t_1, t_batch, y)
             mean, std = self.reverse_dist_param(x_t_1, pred_noise, t_batch)
-            # compute the drawn densoined latent at time t
+            # compute the drawn denoised latent at time t
             x_t_1 = mean + std*noise
             # store noised image
             x[t-1] = x_t_1.squeeze(0)
-        #x_sq = x.squeeze(1)
-        #return x_sq
         return x
 
 
diff --git a/models/conditional_unet.py b/models/conditional_unet.py
index 8f59e64..6539ec3 100644
--- a/models/conditional_unet.py
+++ b/models/conditional_unet.py
@@ -8,13 +8,17 @@ import numpy as np
 # U-Net model
 class Conditional_UNet_Res(nn.Module):
 
-  def __init__(self, attention,channels_in=3, nr_class=3,n_channels=64,fctr = [1,2,4,4,8],time_dim=256,**args):
+  def __init__(self, attention,channels_in=3, nr_class=3,n_channels=64,fctr = [1,2,4,4,8],time_dim=256,cond = 'class',**args):
     """
-    attention : (Bool) wether to use attention layers or not
-    channels_in : (Int) 
-    n_channels : (Int) Channel size after first convolution
-    fctr : (list) list of factors for further channel size wrt n_channels
-    time_dim : (Int) dimenison size for time end class embeding vector  
+    attention: (Bool)  Whether to use attention layers or not
+    channels_in: (Int) Number of channels of the input images
+    n_channels:  (Int) Channel size after first convolution
+    fctr: (list)       List of factors for further channel size wrt n_channels
+    time_dim: (Int)    Dimension of the time and class embedding vectors
+    cond: (str)        What conditioning mechanism will be used, 'class' for adding 
+                       class embeddings and 'image' for concatenating the conditional image color 
+                       channel-wise with the input latent.
+    nr_class: (int)    Number of possible classes the images may be conditioned on
     """
     super().__init__()
     channels_out = channels_in
@@ -33,6 +37,9 @@ class Conditional_UNet_Res(nn.Module):
     self.tc_embedder4 = torch.nn.Sequential(nn.Linear(time_dim,fctr[4]),nn.SELU(),nn.Linear(fctr[4],fctr[4]))
     
     # first conv block
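+    # 'image' conditioning concatenates the masked conditioning image channel-wise with the input
+    # latent, so the first conv takes 2*channels_in channels; 'class' conditioning keeps channels_in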
-    self.first_conv =  nn.Conv2d(channels_in,fctr[0],kernel_size=3, padding='same', bias=True)
+    self.cond = cond
+    if cond == 'image':
+        self.first_conv =  nn.Conv2d(channels_in+channels_in,fctr[0],kernel_size=1, padding='same', bias=True)
+    elif cond == 'class': 
+        self.first_conv =  nn.Conv2d(channels_in,fctr[0],kernel_size=3, padding='same', bias=True)
 
     #down blocks
@@ -62,25 +69,32 @@ class Conditional_UNet_Res(nn.Module):
     self.mha42 = MHABlock(fctr[4])
 
   def forward(self, input, t, y):
-    # compute time mebedding
+    # compute time embedding
     t_emb  = self.time_embedder(t).to(input.device)
-    # compute class embedding if present
-    if y is not None:
+    
+    # compute class embedding if present and training on class-conditioned data
+    if self.cond == 'class' and (y is not None):
         c_emb  = self.class_embedder(y).to(input.device)
     else:
         c_emb = torch.zeros_like(t_emb).to(input.device)
     # combine both embeddings
     tc_emb = t_emb + c_emb
 
-    # learnable layers
+    # learnable embedding layers
     tc_emb0 = self.tc_embedder0(tc_emb)
     tc_emb1 = self.tc_embedder1(tc_emb)
     tc_emb2 = self.tc_embedder2(tc_emb)
     tc_emb3 = self.tc_embedder3(tc_emb)
     tc_emb4 = self.tc_embedder4(tc_emb)
 
-    # first two conv layers
-    x = self.first_conv(input) + tc_emb0[:,:,None,None]
+    # first conv layers
+    if self.cond == 'image':
+        # concatenate the latent with the masked image channel-wise when training on image-conditioned data
+        cat = torch.concat((input, y), dim=1)  
+    elif self.cond == 'class':
+        cat = input
+        
+    x = self.first_conv(cat) + tc_emb0[:,:,None,None]
 
     #time and class mb
     skip1,x = self.down1(x,tc_emb1)
diff --git a/trainer/train.py b/trainer/train.py
index 9fe1525..8558d59 100644
--- a/trainer/train.py
+++ b/trainer/train.py
@@ -84,8 +84,7 @@ class ModelEmaV2(nn.Module):
 
 
 
-# Training function for the unconditional diffusion model
-
+# Training function for the conditional diffusion model
 def cdm_trainer(model, 
                  device,
                  trainloader, testloader,
@@ -106,24 +105,24 @@ def cdm_trainer(model,
                  **args
                  ):
     '''
-    model:           Properly initialized DDPM model.
-    store_iter:      Stores the trained DDPM every store_iter epochs.
-    experiment_path: Path to the models experiment folder, where the trained model will be stored every store_iter epochs
-    eval_iter:       Evaluates the trained DDPM on testing data every eval_iter epochs.
-    epochs:          Number of epochs we train the model further. 
-    optimizer_class: PyTorch optimizer.
-    optimizer_param: Parameters for the PyTorch optimizer.
-    learning_rate:   For optimizer initialization when training from zero, i.e. no checkpoint
-    verbose:         If True, prints the running losses for every epoch.
-    run_name:        Run name for WandB. IF YOU TRAIN FROM CHECKPOINT MAKE SURE TO USE THE SAME
-                     'run_name' FOR THE DATA TO BE LOGGED ON THE SAME WANDB RUN!
-    trainloader:     Loads the train dataset
-    testloader:      Loads the test dataset
-    checkpoint:      Name of the saved pth. file containing the trained weights and biases
-    T_max:           CosineAnnealingLR scheduler argument (nr of steps in training for a full cycle) 
-    eta_min:         CosineAnnealingLR scheduler argument (scheduler oscillates between highest lr 'leraning_rate' and minimum lr 'eta_min')
-    decay:           EMA decay rate that is used to weight the effect of the ema model when computing the weighted avg between trained and
-                     ema weights for the networks weight update 
+    model:               Properly initialized DDPM model.
+    store_iter:          Stores the trained DDPM every store_iter epochs.
+    experiment_path:     Path to the model's experiment folder, where the trained model will be stored every store_iter epochs
+    eval_iter:           Evaluates the trained DDPM on testing data every eval_iter epochs.
+    epochs:              Number of epochs we train the model further. 
+    optimizer_class:     PyTorch optimizer.
+    optimizer_param:     Parameters for the PyTorch optimizer.
+    learning_rate:       For optimizer initialization when training from zero, i.e. no checkpoint
+    verbose:             If True, prints the running losses for every epoch.
+    run_name:            Run name for WandB. IF YOU TRAIN FROM CHECKPOINT MAKE SURE TO USE THE SAME
+                         'run_name' FOR THE DATA TO BE LOGGED ON THE SAME WANDB RUN!
+    trainloader:         Loads the train dataset
+    testloader:          Loads the test dataset
+    checkpoint:          Name of the saved .pth file containing the trained weights and biases
+    T_max:               CosineAnnealingLR scheduler argument (nr of steps in training for a full cycle) 
+    eta_min:             CosineAnnealingLR scheduler argument (scheduler oscillates between the highest lr 'learning_rate' and the minimum lr 'eta_min')
+    decay:               EMA decay rate that is used to weight the effect of the EMA model when computing the weighted avg between trained and
+                         EMA weights for the network's weight update 
     '''
 
     # set optimizer parameters and learning rate
@@ -155,9 +154,9 @@ def cdm_trainer(model,
             optimizer_state_dict = checkpoint['optimizer']
             optimizer.load_state_dict(optimizer_state_dict)
             # load learning rate schedule state
-            #scheduler_state_dict = checkpoint['scheduler']
-            #scheduler.load_state_dict(scheduler_state_dict)
-            #scheduler.last_epoch = last_epoch
+            scheduler_state_dict = checkpoint['scheduler']
+            scheduler.load_state_dict(scheduler_state_dict)
+            scheduler.last_epoch = (last_epoch+1)*len(trainloader)
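+            # last_epoch counts scheduler.step() calls; scaling by len(trainloader) assumes the scheduler is stepped once per batch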
             # load ema model state
             if ema_training:
                 ema.module.load_state_dict(checkpoint['ema'])     
@@ -206,17 +205,20 @@ def cdm_trainer(model,
                 # apply noise
                 x_t, forward_noise = model.forward_trajectory(x_0,t)                
                 
-                # compute denoising step at time t under CFGD
-                rand_prob = torch.rand(x_0.shape[0]).to(device)
-                mask_condition = (rand_prob <= 0.9)
-                pred_noise = torch.zeros_like(x_t).to(device)
-                # for every image with porb. of 90% we apply forward pass cond. on class, 10% prob. without class
-                if torch.any(mask_condition):                
-                    pred_noise[mask_condition] = model.forward(x_t[mask_condition], t[mask_condition], y = y[mask_condition])
-                if torch.any(~mask_condition):
-                    pred_noise[~mask_condition] = model.forward(x_t[~mask_condition], t[~mask_condition], y=None)
-                mean, std = model.reverse_dist_param(x_t, pred_noise, t)
-
+                
+                if model.class_free_guidence:
+                    # compute denoising step at time t under CFGD
+                    rand_prob = torch.rand(x_0.shape[0]).to(device)
+                    mask_condition = (rand_prob <= 0.9)
+                    pred_noise = torch.zeros_like(x_t).to(device)
+                    # for every image in the batch, with prob. of 90% we apply the forward pass conditioned on the class and with prob. of 10% without the class
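+                    # dropping the class label for a random ~10% of samples trains the unconditional predictions needed for classifier-free guidance at sampling time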
+                    if torch.any(mask_condition):                
+                        pred_noise[mask_condition] = model.forward(x_t[mask_condition], t[mask_condition], y = y[mask_condition])
+                    if torch.any(~mask_condition):
+                        pred_noise[~mask_condition] = model.forward(x_t[~mask_condition], t[~mask_condition], y=None)
+                else:
+                    pred_noise = model.forward(x_t,t,y=y)
+                
                 loss = 0
                 # compute kl loss
                 if torch.any(mask_non_zero_t):
@@ -227,6 +229,7 @@ def cdm_trainer(model,
 
                 # if reconstrcution loss was drawn      
                 if torch.any(mask_zero_t):
+                    mean, std = model.reverse_dist_param(x_t, pred_noise, t)
                     recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t])
                     loss += recon_loss
                     run.log({'recon_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx})
@@ -260,18 +263,19 @@ def cdm_trainer(model,
 
                         # apply noise
                         x_t, forward_noise = model.forward_trajectory(x_0,t)
-                        
-                        # compute denoising step at time t under CFGD
-                        rand_prob = torch.rand(x_0.shape[0])
-                        mask_condition = (rand_prob <= 0.9)
-                        pred_noise = torch.zeros_like(x_t)
-                        # for every image with porb. of 90% we apply forward pass cond. on class, 10% prob. without class
-                        if torch.any(mask_condition):
-                            pred_noise[mask_condition] = model.forward(x_t[mask_condition], t[mask_condition], y = y[mask_condition])
-                        if torch.any(~mask_condition):
-                            pred_noise[~mask_condition] = model.forward(x_t[~mask_condition], t[~mask_condition], y=None)
-                        mean, std = model.reverse_dist_param(x_t, pred_noise, t)
- 
+                        if model.class_free_guidence:
+                            # compute denoising step at time t under CFGD
+                            rand_prob = torch.rand(x_0.shape[0])
+                            mask_condition = (rand_prob <= 0.9)
+                            pred_noise = torch.zeros_like(x_t)
+                            # for every image in the batch, with prob. of 90% we apply the forward pass conditioned on the class and with prob. of 10% without the class
+                            if torch.any(mask_condition):
+                                pred_noise[mask_condition] = model.forward(x_t[mask_condition],t[mask_condition],y = y[mask_condition])
+                            if torch.any(~mask_condition):
+                                pred_noise[~mask_condition] = model.forward(x_t[~mask_condition], t[~mask_condition], y=None)
+                        else:
+                            pred_noise = model.forward(x_t,t,y=y)
+                            
                         loss = 0
                         # Compute kl loss
                         if torch.any(mask_non_zero_t):
@@ -282,6 +286,7 @@ def cdm_trainer(model,
 
                         # If reconstrcution loss was drawn      
                         if torch.any(mask_zero_t):
+                            mean, std = model.reverse_dist_param(x_t, pred_noise, t)
                             recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t])
                             loss += recon_loss
                             run.log({'recon_test_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx})
-- 
GitLab