From 5dd1bc7be6ec1b24cc9084e15e40fe1d6589357f Mon Sep 17 00:00:00 2001 From: gonzalomartingarcia0 <gonzalomartingarcia0@gmail.com> Date: Sat, 22 Jul 2023 18:45:32 +0200 Subject: [PATCH] Combined both versions of the conditional diffusion model, now allows for classifier guided diffusion class conditional, and inpainting image generation. --- .gitignore | 3 +- dataloader/load.py | 108 +++++++++++++++++++++---- evaluation/sample.py | 120 ++++++++++++++++++++++++---- main.py | 44 +++++----- models/ConditionalDiffusionModel.py | 71 +++++++++------- models/conditional_unet.py | 38 ++++++--- trainer/train.py | 97 +++++++++++----------- 7 files changed, 335 insertions(+), 146 deletions(-) diff --git a/.gitignore b/.gitignore index 339aa3e..df3b0f2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,8 @@ .DS_Store */__pycache__ -*/trained_ddpm +*/trained_cdm root +.ipynb_checkpoints experiments trainer/__pycache__ wandb diff --git a/dataloader/load.py b/dataloader/load.py index a9ce93c..9db88f4 100644 --- a/dataloader/load.py +++ b/dataloader/load.py @@ -6,18 +6,16 @@ from PIL import Image import pandas as pd import numpy as np -class ConditionalDataset(Dataset): +class ConditionalDataset_AFHQ_Class(Dataset): def __init__(self,fpath,img_size,train,frac =0.8,skip_first_n = 0,ext = ".png",transform=True ): """ Args: - fpath (string): Path to the folder where images are stored - img_size (int): size of output image img_size=height=width - ext (string): type of images used(eg .png) + fpath (string): Path to the folder where images are stored + img_size (int): Size of output image img_size=height=width + ext (string): Type of images used(eg .png) transform (Bool): Image augmentation for diffusion model - skip_first_n: skips the first n values. Usefull for datasets that are sorted by increasing Likeliehood - train (Bool): Choose dataset to be either train set or test set. frac(float) required - frac (float): value within (0,1] (seeded)random shuffles dataset, then divides into train and test set. - """ + train (Bool): Choose dataset to be either train set or test set. 
frac(float) required + """ if train: fpath = os.path.join(fpath, 'train') else: @@ -32,23 +30,19 @@ class ConditionalDataset(Dataset): if name.endswith(ext): file_list.append(os.path.join(root, name)) class_list.append(self.class_to_idx[os.path.basename(root)]) - #df = pd.DataFrame({"Filepath":file_list},) - #self.df = df[df["Filepath"].str.endswith(ext)] self.df = pd.DataFrame({"Filepath": file_list}) self.class_list = class_list if transform: - # for training intermediate_size = 137 - theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) #Check dataloading.ipynb in analysis-depot for more details - + theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) + transform_rotate_flip = transforms.Compose([transforms.ToTensor(), transforms.Resize(intermediate_size,antialias=True), transforms.RandomRotation((theta/np.pi*180),interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(img_size), transforms.RandomHorizontalFlip(p=0.5), transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))]) - transform_flip = transforms.Compose([transforms.ToTensor(), transforms.Resize(img_size, antialias=True), @@ -57,10 +51,9 @@ class ConditionalDataset(Dataset): self.transform = transforms.RandomChoice([transform_rotate_flip,transform_flip]) else : - # for evaluation self.transform = transforms.Compose([transforms.ToTensor(), - transforms.Lambda(lambda x: (x * 255).type(torch.uint8)), transforms.Resize(img_size)]) + def __len__(self): return len(self.df) @@ -74,3 +67,86 @@ class ConditionalDataset(Dataset): def tensor2PIL(self,img): back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()]) return back2pil(img) + + +class ConditionalDataset_LHQ_Paint(Dataset): + def __init__(self,fpath,img_size,train,frac =0.8,skip_first_n = 0,ext = ".png",transform=True ): + """ + Args: + fpath (string): Path to the folder where images are stored + img_size (int): Size of output image img_size=height=width + ext (string): Type of images used(eg .png) + transform (Bool): Image augmentation for diffusion model + train (Bool): Choose dataset to be either train set or test set. frac(float) required + frac (float): value within (0,1] (seeded)random shuffles dataset, then divides into train and test set. 
+ """ + + ### Create DataFrame + file_list = [] + for root, dirs, files in os.walk(fpath, topdown=False): + for name in sorted(files): + file_list.append(os.path.join(root, name)) + + df = pd.DataFrame({"Filepath":file_list},) + self.df = df[df["Filepath"].str.endswith(ext)] + + + if train: + df_train = self.df.sample(frac=frac,random_state=2) + self.df = df_train + else: + df_train = self.df.sample(frac=frac,random_state=2) + df_test = df.drop(df_train.index) + self.df = df_test + + if transform: + intermediate_size = 150 + theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) + + transform_rotate = transforms.Compose([transforms.ToTensor(), + transforms.Resize(intermediate_size,antialias=True), + transforms.RandomRotation(theta/np.pi*180,interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(img_size), + transforms.RandomHorizontalFlip(p=0.5), + transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))]) + + transform_randomcrop = transforms.Compose([transforms.ToTensor(), + transforms.Resize(intermediate_size), + transforms.RandomCrop(img_size), + transforms.RandomHorizontalFlip(p=0.5), + transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))]) + + self.transform = transforms.RandomChoice([transform_rotate,transform_randomcrop]) + else : + self.transform = transforms.Compose([transforms.ToTensor(), + transforms.Resize(img_size)]) + + + def __len__(self): + return len(self.df) + + def __getitem__(self,idx): + # get image + path = self.df.iloc[idx].Filepath + img = Image.open(path) + # apply transformation + img_tensor = self.transform(img) + # draw random rectangle + min_height = 30 + min_width = 30 + max_x = img_tensor.shape[1] - min_height + max_y = img_tensor.shape[2] - min_width + x = np.random.randint(0, max_x) + y = np.random.randint(0, max_y) + max_height = min(img_size, img_tensor.shape[1] - x) + max_width = min(img_size, img_tensor.shape[2] - y) + rect_height = torch.randint(min_height, max_height, (1,)).item() + rect_width = torch.randint(min_width, max_width, (1,)).item() + # create copy of image and add blacked out rectangle + masked_img = img_tensor.clone() + masked_img[:, x:x+rect_height, y:y+rect_width] = -1 + return img_tensor, masked_img + + def tensor2PIL(self,img): + back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()]) + return back2pil(img) diff --git a/evaluation/sample.py b/evaluation/sample.py index 4d98deb..0d9af73 100644 --- a/evaluation/sample.py +++ b/evaluation/sample.py @@ -3,25 +3,26 @@ import torch from torchvision import transforms import re -def cdm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False,n_times=1): +def cdm_sampler_afhq_class(model, checkpoint, experiment_path, dataloader, device, intermediate=False, batch_size=15,sample_all=False,n_times=1): ''' - Samples a tensor of 'batch_size' images from a trained diffusion model with 'checkpoint'. The generated + Samples a tensor of 'n_times'*'batch_size' images from a trained diffusion model with 'checkpoint'. The generated images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}. Where e is the epoch - w.r.t. the model which we are sampling form and j is an index separating images from each call of the - sampling function for the given mode. + w.r.t. the model which we are sampling form and j is an index separating the sampled batches for every call of this + sampling function. 
- model: Diffusion model - checkpoint: Name of the saved pth. file containing the trained weights and biases - conditioning_path: Path to the conditioning input tensors from whihc the samples will be generated from - experiment_path: Path to the experiment directory where the samples will saved under the diectory samples - batch_size: The number of images to sample - intermediate: Bool value. If False the sampling function will draw a batch of images, else it will just - sample a single image, but store all the intermediate noised latents along the reverse chain - sample_all: If True, samples a batch of images for the given model at every stored checkpoint at once - n_times: Integer denoting how many times we draw a batch of 'batch_size'. If we want to draw 10k images - the GPU will draw batches of 512 images 20 times to reach this goal. + model: Diffusion model + checkpoint: Name of the saved pth. file containing the trained weights and biases + experiment_path: Path to the experiment directory where the samples will saved under the diectory samples + batch_size: The number of images the model samples + dataloaer: To sample testing conditinoning batches + intermediate: Bool value. If False the sampling function will draw a batch of images, else it will just + sample a single image, but store all the intermediate noised latents along the reverse chain + sample_all: If True, samples a batch of images for the given model at every stored checkpoint at once + n_times: Integer denoting how many times we draw a batch of 'batch_size'. If we want to draw 10k images + the GPU will draw batches of 512 images 20 times to reach this goal. ''' + # If we want to sample from every checkpoint of the current model, recursively call this function for all checkpoints if sample_all: f = f'{experiment_path}trained_cdm/' @@ -44,7 +45,6 @@ def cdm_sampler(model, checkpoint, experiment_path, device, intermediate=False, # create samples directory for the complete experiment (if first time sampling images) output_dir = f'{experiment_path}samples/' - #output_dir = os.path.join(experiment_path,'/samples/') os.makedirs(output_dir, exist_ok=True) # create sample directory for the current version of the trained model @@ -53,8 +53,7 @@ def cdm_sampler(model, checkpoint, experiment_path, device, intermediate=False, if epoch: e = int(epoch[0]) else: - #raise ValueError(f"No digit found in the filename: {filename}") - e = 0 + raise ValueError(f"No digit found in the filename") model_dir = os.path.join(output_dir,f'epoch_{e}') os.makedirs(model_dir, exist_ok=True) @@ -86,3 +85,90 @@ def cdm_sampler(model, checkpoint, experiment_path, device, intermediate=False, except Exception as e: print("Error saving image. Exception:", e) run_indx += 1 + + +def cdm_sampler_lhq_paint(model, checkpoint, experiment_path, dataloader, device, intermediate=False, batch_size=15,sample_all=False,n_times=1): + ''' + Samples a tensor of 'n_times'*'batch_size' images from a trained diffusion model with 'checkpoint'. The generated + images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}. Where e is the epoch + w.r.t. the model which we are sampling form and j is an index separating the sampled batches for every call of this + sampling function. + + model: Diffusion model + checkpoint: Name of the saved pth. 
file containing the trained weights and biases + experiment_path: Path to the experiment directory where the samples will be saved under the directory samples + dataloader: Dataloader from which we draw testing images with black rectangles to inpaint on + batch_size: The number of images the model samples + device: Device (cpu or cuda) on which sampling is run + intermediate: Bool value. If False the sampling function will draw a batch of images, else it will just + sample a single image, but store all the intermediate noised latents along the reverse chain + sample_all: If True, samples a batch of images for the given model at every stored checkpoint at once + n_times: Integer denoting how many times we draw a batch of 'batch_size'. If we want to draw 10k images + the GPU will draw batches of 512 images 20 times to reach this goal. + ''' + + + # If we want to sample from every checkpoint of the current model, recursively call this function for all checkpoints + if sample_all: + f = f'{experiment_path}trained_cdm/' + checkpoint_list = [checkpoint_i for checkpoint_i in os.listdir(f) if checkpoint_i.endswith(".pth")] + for checkpoint_i in os.listdir(f): + if checkpoint_i.endswith(".pth"): + cdm_sampler_lhq_paint(model, checkpoint_i, experiment_path, dataloader, device, batch_size=batch_size, intermediate=intermediate, sample_all=False) + return 0 + + # load model + try: + checkpoint_path = f'{experiment_path}trained_cdm/{checkpoint}' + checkpoint = torch.load(checkpoint_path) + # load weights and biases of the U-Net + net_state_dict = checkpoint['model'] + model.net.load_state_dict(net_state_dict) + model = model.to(device) + except Exception as e: + print("Error loading checkpoint. Exception:", e) + + # create samples directory for the complete experiment + output_dir = f'{experiment_path}samples/' + os.makedirs(output_dir, exist_ok=True) + + # create sample directory for the current version of the trained model + model_name = os.path.basename(checkpoint_path) + epoch = re.findall(r'\d+', model_name) + if epoch: + e = int(epoch[0]) + else: + raise ValueError("No digit found in the filename") + model_dir = os.path.join(output_dir,f'epoch_{e}') + os.makedirs(model_dir, exist_ok=True) + + # create the sample directory for this sampling run for the current version of the model + sample_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))] + indx_list = [int(d.split('_')[1]) for d in sample_dir_list if d.startswith('sample_')] + j = max(indx_list, default=-1) + 1 + sample_dir = os.path.join(model_dir, f'sample_{j}') + os.makedirs(sample_dir, exist_ok=True) + + # transform + back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()]) + + # sample inpainted images conditioned on the masked test images and transform to PIL + run_indx=0 + for idx,(_, y) in enumerate(dataloader): + y = y.to(device) + generated = model.sample(y=y, batch_size=y.size(0)) + #torch.save(generated,os.path.join(sample_dir, f"image_tensor{idx}")) + # save images + for i in range(generated.size(0)): + image = back2pil(generated[i]) + image_path = os.path.join(sample_dir, f'sample_{j}_{run_indx}.png') + image_raw = back2pil(y[i]) + image_path_raw = os.path.join(sample_dir, f'raw_{j}_{run_indx}.png') + try: + image.save(image_path) + image_raw.save(image_path_raw) + except Exception as e: + print("Error saving image. 
Exception:", e) + run_indx += 1 + if run_indx>=n_times: + break diff --git a/main.py b/main.py index 701fa00..e24ad8e 100644 --- a/main.py +++ b/main.py @@ -3,13 +3,12 @@ import json import sys from dataloader.load import * from models.ConditionalDiffusionModel import * -from trainer.train import cdm_trainer -from evaluation.sample import cdm_sampler -from evaluation.evaluate import cdm_evaluator +from trainer.train import * +from evaluation.sample import * +from evaluation.evaluate import * from models.conditional_unet import * import torch - def train_func(f): #load all settings device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -31,19 +30,16 @@ def train_func(f): training_setting = json.load(fp) training_setting["optimizer_class"] = eval(training_setting["optimizer_class"]) - + # init dataloaders batchsize = meta_setting["batchsize"] training_dataset = globals()[meta_setting["dataset"]](train = True,**dataset_setting) test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting) - training_dataloader = torch.utils.data.DataLoader(training_dataset,batch_size=batchsize,shuffle=True) test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize,shuffle=True) - - + # init UNet net = globals()[meta_setting["modelname"]](**model_setting).to(device) - #net = torch.compile(net) net = net.to(device) - + # init Diffusion Model framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting) print(f"META SETTINGS:\n\n {meta_setting}\n\n") @@ -74,12 +70,18 @@ def sample_func(f): with open(f+"/sampling_setting.json","r") as fp: sampling_setting = json.load(fp) - - # init Unet + + with open(f+"/dataset_setting.json","r") as fp: + dataset_setting = json.load(fp) + + # init dataloader + batchsize = sampling_setting["batch_size"] + test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting) + test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize,shuffle=True) + # init UNet net = globals()[meta_setting["modelname"]](**model_setting).to(device) - #net = torch.compile(net) net = net.to(device) - # init unconditional diffusion model + # init Diffusion model framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting) print(f"META SETTINGS:\n\n {meta_setting}\n\n") @@ -88,7 +90,7 @@ def sample_func(f): print(f"SAMPLING SETTINGS:\n\n {sampling_setting}\n\n") print("\n\nSTART SAMPLING\n\n") - globals()[meta_setting["sampling_function"]](model=framework,device=device ,**sampling_setting,) + globals()[meta_setting["sampling_function"]](model=framework,device=device ,dataloader=test_dataloader, **sampling_setting,) print("\n\nFINISHED SAMPLING\n\n") @@ -113,18 +115,15 @@ def evaluate_func(f): with open(f+"/dataset_setting.json","r") as fp: dataset_setting = json.load(fp) - # load dataset + # init dataloaders batchsize = meta_setting["batchsize"] test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting) #test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=len(test_dataset), shuffle=False) test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize, shuffle=False) - - # init Unet + # init UNet net = globals()[meta_setting["modelname"]](**model_setting).to(device) - #net = torch.compile(net) net = net.to(device) - - # init unconditional diffusion model + # init Diffusion Model framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting) print(f"META SETTINGS:\n\n 
{meta_setting}\n\n") @@ -140,9 +139,8 @@ def evaluate_func(f): - +# run training, sampling or evaluation if __name__ == '__main__': - print(sys.argv) functions = {'train': train_func,'sample': sample_func,'evaluate': evaluate_func} diff --git a/models/ConditionalDiffusionModel.py b/models/ConditionalDiffusionModel.py index 4dbf5cb..24c70b5 100644 --- a/models/ConditionalDiffusionModel.py +++ b/models/ConditionalDiffusionModel.py @@ -9,7 +9,6 @@ class CDM(nn.Module): net=None, diffusion_steps = 50, out_shape = (3,128,128), - conditional_shape = (3), noise_schedule = 'linear', beta_1 = 1e-4, beta_T = 0.02, @@ -17,19 +16,20 @@ class CDM(nn.Module): var_schedule = 'same', kl_loss = 'simplified', recon_loss = 'none', + class_free_guidence = False, guidance_score = 3, # for classifier-free guided diffusion device=None): ''' net: U-Net diffusion_steps: Length of the Markov chain out_shape: Shape of the models's in- and output images - conditional_shape: Shape of the low resolution image the DM is conditioned on for Super Resolution - noise_schedule: Methods of initialization for the noise dist. variances, 'linear', 'cosine' or bounded_cosine + noise_schedule: Methods of initialization for the noise dist. variances, 'linear', 'cosine' or 'bounded_cosine' beta_1, beta_T: Variances for the first and last noise dist. (only for the 'linear' noise schedule) alpha_bar_lower_bound: Upper bound for the varaince of the complete noise dist. (only for the 'cosine_bounded' noise schedule) var_schedule: Options to initialize or learn the denoising dist. variances, 'same', 'true' kl_loss: Choice between the mathematically correct 'weighted' or in practice most commonly used 'simplified' KL loss recon_loss: Is 'none' to ignore the reconstruction loss or 'nll' to compute the negative log likelihood + class_free_guidence: Boolean flag that indicates if classifier-free guided diffusion is used for training and sampling ''' super(CDM,self).__init__() self.device = device @@ -59,6 +59,7 @@ class CDM(nn.Module): self.net = net self.guidance_score = guidance_score + self.class_free_guidence = class_free_guidence self.diffusion_steps = diffusion_steps self.noise_schedule = noise_schedule self.var_schedule = var_schedule @@ -72,7 +73,6 @@ class CDM(nn.Module): self.kl_loss = kl_loss self.recon_loss = recon_loss self.out_shape = out_shape - self.conditional_shape = conditional_shape # precomputed for efficiency reasons self.noise_scaler = (1-alpha)/( self.sqrt_1_minus_alpha_bar) self.mean_scaler = 1/torch.sqrt(self.alpha) @@ -273,7 +273,7 @@ class CDM(nn.Module): ''' noise = torch.randn(x_t.shape, device=self.device) pred_noise = self.forward(x_t, t, y) - mean, std = self.reverse_dist_param(x_t, pred_noise, t) #self.forward(x_t, t, c) + mean, std = self.reverse_dist_param(x_t, pred_noise, t) x_t_1 = mean + std*noise return x_t_1 @@ -292,9 +292,9 @@ class CDM(nn.Module): y (tensor): Batch of conditional information for each input image Returns: - mean (tensor): Batch of means for the complete noise dist. for each image in the batch x_t - std (tensor): Batch of std scalars for the complete noise dist. for each image in the batch x_t - pred_noise (tensor): Predicted noise for each image in the batch x_t + mean (tensor): Batch of means for the complete noise dist. for each image in the batch x_t under y + std (tensor): Batch of std scalars for the complete noise dist. 
for each image in the batch x_t under y + pred_noise (tensor): Predicted noise for each image in the batch x_t under y ''' pred_noise = self.net(x_t,t, y) return pred_noise @@ -308,8 +308,8 @@ class CDM(nn.Module): t (tensor): Batch of timesteps Returns: - mean (tensor): Batch of means for the complete noise distribution for each image in the batch x_0 - std (tensor): Batch of std scalars for the complete noise distribution for each image in the batch x_0 + mean (tensor): Batch of means for the complete noise distribution for each image in the batch x_0 under y + std (tensor): Batch of std scalars for the complete noise distribution for each image in the batch x_0 under y ''' mean = self.mean_scaler[t-1][:,None,None,None]*(x_t - self.noise_scaler[t-1][:,None,None,None]*pred_noise) std = self.std[t-1][:,None,None,None] @@ -331,7 +331,7 @@ class CDM(nn.Module): y (tensor): Batch of conditional information for each input image Returns: - x_0_recon (tensor): Batch of images given by the model reconstruction of x_0 + x_0_recon (tensor): Batch of images given by the model reconstruction of x_0 under y ''' # apply forward trajectory x_0_recon, _ = self.forward_trajectory(x_0) @@ -345,7 +345,13 @@ class CDM(nn.Module): # set timestep batch with all entries as t t_batch = torch.full((x_0_recon.shape[0],), t ,device = self.device) # get denoising dist. param - pred_noise = self.forward(x_0_recon, t_batch, y) + if self.class_free_guidence: + # get classififer-free guided diffusion noise parameter + pred_noise_cond = self.forward(x_0_recon, t_batch, y) # param with conditioning + pred_noise_uncond = self.forward(x_0_recon, t_batch, y=None) # param without conditioning + pred_noise = torch.lerp(pred_noise_uncond, pred_noise_cond, self.guidance_score) # linear interpolation of the two + else: + pred_noise = self.forward(x_0_recon, t_batch, y) mean, std = self.reverse_dist_param(x_0_recon, pred_noise, t_batch) # compute the drawn denoised latent at time t x_0_recon = mean + std * noise @@ -365,12 +371,12 @@ class CDM(nn.Module): batch_size (int): Number of images to be sampled/generated from the diffusion model x_T (tensor): Input of the reverse trajectory. Batch of noised images usually drawn from an isotropic Gaussian, but can be set manually if desired. - y (tensor): Batch of conditional information for each input image + y (tensor): Batch of conditional information for each input image Returns: - x_t_1 (tensor): Batch of sampled/generated images + x_t_1 (tensor): Batch of sampled/generated images under y ''' - # start with a batch of isotropic noise images (or given arguemnt) + # start with a batch of isotropic noise images (or given argument x_T) if x_T: x_t_1 = x_T else: @@ -382,15 +388,17 @@ class CDM(nn.Module): noise = torch.randn(x_t_1.shape, device=self.device) else: noise = torch.zeros(x_t_1.shape, device=self.device) - # set timestep batch with all entries as t + # get denoising dist. 
param t_batch = torch.full((x_t_1.shape[0],), t ,device = self.device) - # get classififer-free guided diffusion noise parameter - pred_noise_cond = self.forward(x_t_1, t_batch, y) # param with conditioning - pred_noise_uncond = self.forward(x_t_1, t_batch, y=None) # param without conditioning - # linear interpolation of the two - pred_noise = torch.lerp(pred_noise_uncond, pred_noise_cond, self.guidance_score) + if self.class_free_guidence: + # get classififer-free guided diffusion noise parameter + pred_noise_cond = self.forward(x_t_1, t_batch, y) # param with conditioning + pred_noise_uncond = self.forward(x_t_1, t_batch, y=None) # param without conditioning + pred_noise = torch.lerp(pred_noise_uncond, pred_noise_cond, self.guidance_score) # linear interpolation of the two + else: + pred_noise = self.forward(x_t_1, t_batch, y) mean, std = self.reverse_dist_param(x_t_1, pred_noise, t_batch) - # compute the drawn densoined latent at time t + # compute the drawn denoised latent at time t x_t_1 = mean + std*noise return x_t_1 @@ -418,20 +426,21 @@ class CDM(nn.Module): noise = torch.randn(x_t_1.shape, device=self.device) else: noise = torch.zeros(x_t_1.shape, device=self.device) - # set timestep batch with all entries as t + # get denoising dist. param t_batch = torch.full((x_t_1.shape[0],), t ,device = self.device) - # get classififer-free guided diffusion noise parameter - pred_noise_cond = self.forward(x_t_1, t_batch, y) # param with conditioning - pred_noise_uncond = self.forward(x_t_1, t_batch, y=None) # param without conditioning - # linear interpolation of the two - pred_noise = torch.lerp(pred_noise_uncond, pred_noise_cond, self.guidance_score) + if self.class_free_guidence: + # get classififer-free guided diffusion noise parameter + pred_noise_cond = self.forward(x_t_1, t_batch, y) # param with conditioning + pred_noise_uncond = self.forward(x_t_1, t_batch, y=None) # param without conditioning + # linear interpolation of the two + pred_noise = torch.lerp(pred_noise_uncond, pred_noise_cond, self.guidance_score) + else: + pred_noise = self.forward(x_t_1, t_batch, y) mean, std = self.reverse_dist_param(x_t_1, pred_noise, t_batch) - # compute the drawn densoined latent at time t + # compute the drawn denoised latent at time t x_t_1 = mean + std*noise # store noised image x[t-1] = x_t_1.squeeze(0) - #x_sq = x.squeeze(1) - #return x_sq return x diff --git a/models/conditional_unet.py b/models/conditional_unet.py index 8f59e64..6539ec3 100644 --- a/models/conditional_unet.py +++ b/models/conditional_unet.py @@ -8,13 +8,17 @@ import numpy as np # U-Net model class Conditional_UNet_Res(nn.Module): - def __init__(self, attention,channels_in=3, nr_class=3,n_channels=64,fctr = [1,2,4,4,8],time_dim=256,**args): + def __init__(self, attention,channels_in=3, nr_class=3,n_channels=64,fctr = [1,2,4,4,8],time_dim=256,cond = 'class',**args): """ - attention : (Bool) wether to use attention layers or not - channels_in : (Int) - n_channels : (Int) Channel size after first convolution - fctr : (list) list of factors for further channel size wrt n_channels - time_dim : (Int) dimenison size for time end class embeding vector + attention: (Bool) Whether to use attention layers or not + channels_in: (Int) + n_channels: (Int) Channel size after first convolution + fctr: (list) List of factors for further channel size wrt n_channels + time_dim: (Int) Dimenison size for time end class embeding vector + cond: (str) What conditioning mechanism will be used, 'class' for adding + class embeddings and 'image' for 
concatenating the conditional image color + channel-wise with the input latent. + nr_class: (int) Number of possible classes the images may be conditioned on """ super().__init__() channels_out = channels_in @@ -33,6 +37,9 @@ class Conditional_UNet_Res(nn.Module): self.tc_embedder4 = torch.nn.Sequential(nn.Linear(time_dim,fctr[4]),nn.SELU(),nn.Linear(fctr[4],fctr[4])) # first conv block + self.cond = cond + if cond == 'image': + self.first_conv = nn.Conv2d(channels_in+channels_in,fctr[0],kernel_size=1, padding='same', bias=True) + elif cond == 'class': self.first_conv = nn.Conv2d(channels_in,fctr[0],kernel_size=3, padding='same', bias=True) #down blocks @@ -62,25 +69,32 @@ class Conditional_UNet_Res(nn.Module): self.mha42 = MHABlock(fctr[4]) def forward(self, input, t, y): - # compute time mebedding + # compute time embedding t_emb = self.time_embedder(t).to(input.device) - # compute class embedding if present - if y is not None: + + # compute class embedding if present and training on class conditioned data + if self.cond == 'class' and (y is not None): c_emb = self.class_embedder(y).to(input.device) else: c_emb = torch.zeros_like(t_emb).to(input.device) # combine both embeddings tc_emb = t_emb + c_emb - # learnable layers + # learnable embedding layers tc_emb0 = self.tc_embedder0(tc_emb) tc_emb1 = self.tc_embedder1(tc_emb) tc_emb2 = self.tc_embedder2(tc_emb) tc_emb3 = self.tc_embedder3(tc_emb) tc_emb4 = self.tc_embedder4(tc_emb) - # first two conv layers - x = self.first_conv(input) + tc_emb0[:,:,None,None] + # first conv layers + if self.cond == 'image': + # concat latent with masked image if training on image conditioned data + cat = torch.concat((input, y), dim=1) + elif self.cond == 'class': + cat = input + + x = self.first_conv(cat) + tc_emb0[:,:,None,None] #time and class mb skip1,x = self.down1(x,tc_emb1) diff --git a/trainer/train.py b/trainer/train.py index 9fe1525..8558d59 100644 --- a/trainer/train.py +++ b/trainer/train.py @@ -84,8 +84,7 @@ class ModelEmaV2(nn.Module): -# Training function for the unconditional diffusion model - +# Training function for the conditional diffusion model def cdm_trainer(model, device, trainloader, testloader, @@ -106,24 +105,24 @@ def cdm_trainer( **args ): ''' - model: Properly initialized DDPM model. - store_iter: Stores the trained DDPM every store_iter epochs. - experiment_path: Path to the models experiment folder, where the trained model will be stored every store_iter epochs - eval_iter: Evaluates the trained DDPM on testing data every eval_iter epochs. - epochs: Number of epochs we train the model further. - optimizer_class: PyTorch optimizer. - optimizer_param: Parameters for the PyTorch optimizer. - learning_rate: For optimizer initialization when training from zero, i.e. no checkpoint - verbose: If True, prints the running losses for every epoch. - run_name: Run name for WandB. IF YOU TRAIN FROM CHECKPOINT MAKE SURE TO USE THE SAME - 'run_name' FOR THE DATA TO BE LOGGED ON THE SAME WANDB RUN! - trainloader: Loads the train dataset - testloader: Loads the test dataset - checkpoint: Name of the saved pth. 
file containing the trained weights and biases - T_max: CosineAnnealingLR scheduler argument (nr of steps in training for a full cycle) - eta_min: CosineAnnealingLR scheduler argument (scheduler oscillates between highest lr 'leraning_rate' and minimum lr 'eta_min') - decay: EMA decay rate that is used to weight the effect of the ema model when computing the weighted avg between trained and - ema weights for the networks weight update + model: Properly initialized DDPM model. + store_iter: Stores the trained DDPM every store_iter epochs. + experiment_path: Path to the models experiment folder, where the trained model will be stored every store_iter epochs + eval_iter: Evaluates the trained DDPM on testing data every eval_iter epochs. + epochs: Number of epochs we train the model further. + optimizer_class: PyTorch optimizer. + optimizer_param: Parameters for the PyTorch optimizer. + learning_rate: For optimizer initialization when training from zero, i.e. no checkpoint + verbose: If True, prints the running losses for every epoch. + run_name: Run name for WandB. IF YOU TRAIN FROM CHECKPOINT MAKE SURE TO USE THE SAME + 'run_name' FOR THE DATA TO BE LOGGED ON THE SAME WANDB RUN! + trainloader: Loads the train dataset + testloader: Loads the test dataset + checkpoint: Name of the saved pth. file containing the trained weights and biases + T_max: CosineAnnealingLR scheduler argument (nr of steps in training for a full cycle) + eta_min: CosineAnnealingLR scheduler argument (scheduler oscillates between highest lr 'leraning_rate' and minimum lr 'eta_min') + decay: EMA decay rate that is used to weight the effect of the ema model when computing the weighted avg between trained and + ema weights for the networks weight update ''' # set optimizer parameters and learning rate @@ -155,9 +154,9 @@ def cdm_trainer(model, optimizer_state_dict = checkpoint['optimizer'] optimizer.load_state_dict(optimizer_state_dict) # load learning rate schedule state - #scheduler_state_dict = checkpoint['scheduler'] - #scheduler.load_state_dict(scheduler_state_dict) - #scheduler.last_epoch = last_epoch + scheduler_state_dict = checkpoint['scheduler'] + scheduler.load_state_dict(scheduler_state_dict) + scheduler.last_epoch = (last_epoch+1)*len(trainloader) # load ema model state if ema_training: ema.module.load_state_dict(checkpoint['ema']) @@ -206,17 +205,20 @@ def cdm_trainer(model, # apply noise x_t, forward_noise = model.forward_trajectory(x_0,t) - # compute denoising step at time t under CFGD - rand_prob = torch.rand(x_0.shape[0]).to(device) - mask_condition = (rand_prob <= 0.9) - pred_noise = torch.zeros_like(x_t).to(device) - # for every image with porb. of 90% we apply forward pass cond. on class, 10% prob. without class - if torch.any(mask_condition): - pred_noise[mask_condition] = model.forward(x_t[mask_condition], t[mask_condition], y = y[mask_condition]) - if torch.any(~mask_condition): - pred_noise[~mask_condition] = model.forward(x_t[~mask_condition], t[~mask_condition], y=None) - mean, std = model.reverse_dist_param(x_t, pred_noise, t) - + + if model.class_free_guidence: + # compute denoising step at time t under CFGD + rand_prob = torch.rand(x_0.shape[0]).to(device) + mask_condition = (rand_prob <= 0.9) + pred_noise = torch.zeros_like(x_t).to(device) + # for every image in the batch, with porb. of 90% we apply forward pass conditioned on class, 10% prob. 
without class + if torch.any(mask_condition): + pred_noise[mask_condition] = model.forward(x_t[mask_condition], t[mask_condition], y = y[mask_condition]) + if torch.any(~mask_condition): + pred_noise[~mask_condition] = model.forward(x_t[~mask_condition], t[~mask_condition], y=None) + else: + pred_noise = model.forward(x_t,t,y=y) + loss = 0 # compute kl loss if torch.any(mask_non_zero_t): @@ -227,6 +229,7 @@ def cdm_trainer(model, # if reconstrcution loss was drawn if torch.any(mask_zero_t): + mean, std = model.reverse_dist_param(x_t, pred_noise, t) recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t]) loss += recon_loss run.log({'recon_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx}) @@ -260,18 +263,19 @@ def cdm_trainer(model, # apply noise x_t, forward_noise = model.forward_trajectory(x_0,t) - - # compute denoising step at time t under CFGD - rand_prob = torch.rand(x_0.shape[0]) - mask_condition = (rand_prob <= 0.9) - pred_noise = torch.zeros_like(x_t) - # for every image with porb. of 90% we apply forward pass cond. on class, 10% prob. without class - if torch.any(mask_condition): - pred_noise[mask_condition] = model.forward(x_t[mask_condition], t[mask_condition], y = y[mask_condition]) - if torch.any(~mask_condition): - pred_noise[~mask_condition] = model.forward(x_t[~mask_condition], t[~mask_condition], y=None) - mean, std = model.reverse_dist_param(x_t, pred_noise, t) - + if model.class_free_guidence: + # compute denoising step at time t under CFGD + rand_prob = torch.rand(x_0.shape[0]) + mask_condition = (rand_prob <= 0.9) + pred_noise = torch.zeros_like(x_t) + # for every image in the batch, with porb. of 90% we apply forward pass conditioned on class, 10% prob. without class + if torch.any(mask_condition): + pred_noise[mask_condition] = model.forward(x_t[mask_condition],t[mask_condition],y = y[mask_condition]) + if torch.any(~mask_condition): + pred_noise[~mask_condition] = model.forward(x_t[~mask_condition], t[~mask_condition], y=None) + else: + pred_noise = model.forward(x_t,t,y=y) + loss = 0 # Compute kl loss if torch.any(mask_non_zero_t): @@ -282,6 +286,7 @@ def cdm_trainer(model, # If reconstrcution loss was drawn if torch.any(mask_zero_t): + mean, std = model.reverse_dist_param(x_t, pred_noise, t) recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t]) loss += recon_loss run.log({'recon_test_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx}) -- GitLab
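The core mechanism this commit adds in CDM.sample and cdm_trainer is classifier-free guidance: during training each sample keeps its class label with probability 0.9 and is passed with y=None otherwise, and at sampling time the conditional and unconditional noise predictions are combined with torch.lerp weighted by guidance_score. The standalone sketch below illustrates that combination in isolation; it is not code from the repository, and the helper names and toy tensor shapes are assumptions made for the example.

    import torch

    def cfg_noise(pred_noise_uncond, pred_noise_cond, guidance_score):
        # torch.lerp(a, b, w) = a + w*(b - a); with guidance_score > 1 this
        # extrapolates beyond the conditional prediction, which is the update
        # CDM.sample applies when class_free_guidence is enabled.
        return torch.lerp(pred_noise_uncond, pred_noise_cond, guidance_score)

    def keep_label_mask(batch_size, p_keep=0.9, device='cpu'):
        # Training-time conditioning dropout: with probability p_keep (0.9 in
        # the patch) a sample keeps its class label, otherwise it is treated
        # as unconditional (y=None) so the network learns both modes.
        return torch.rand(batch_size, device=device) <= p_keep

    if __name__ == '__main__':
        eps_uncond = torch.randn(4, 3, 8, 8)  # toy unconditional noise prediction
        eps_cond = torch.randn(4, 3, 8, 8)    # toy class-conditional noise prediction
        guided = cfg_noise(eps_uncond, eps_cond, guidance_score=3.0)
        mask = keep_label_mask(batch_size=4)
        print(guided.shape, mask.tolist())

Note that torch.lerp(a, b, w) = a + w*(b - a) only interpolates for w in [0, 1]; the guidance_score of 3 used in the patch pushes the estimate past the conditional prediction, which is the usual classifier-free guidance setting even though the in-code comment describes it as a linear interpolation of the two.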