diff --git a/README.md b/README.md
index e9249ae828134d327941e75bd7cdb41fa80d25a2..03068ef903724f077eae63f35b0bd81d0f42d268 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@ This repository contains the pipeline for training, sampling, and evaluation of
 
 We demonstrate our results by training conditional diffusion models to solve the tasks of class conditional image generation and inpainting. For the class labeled dataset, we chose to use the Animal Face (AFHQ) dataset containing three classes; dog, cat, and wildlife, each of them with a representation of 5000 training images. For the inpainting dataset, we train the model on the same Landscape dataset (LHQ) as with the unconditional diffusion model, and generate our own labels by randomly drawing black rectangle masks which the model learns to inpaint. For this purpose, we implement class and image conditioning mechanisms into our UNet. For class conditioning, we also make use of our implementation of classifier-free guided diffusion to achieve better sample quality results.
 
-These techniques are presented in the papers **GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models** by Nichol et al. and **Classifier-Free Diffusion Guidance** by Ho and Salamatin.
+These techniques are presented in the papers [*GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models*](https://arxiv.org/abs/2112.10741) by Nichol et al. and [*Classifier-Free Diffusion Guidance*](https://arxiv.org/abs/2207.12598) by Ho and Salimans.
 
 ## Motivation
 
@@ -36,7 +36,7 @@ Unconditional diffusion models are great at learning the image distributions pre
 </table>
 
 ### Image Inpainting:
-show grid examples in imgs/paint
+<img src="imgs/lhq/lhq.png" alt="inpaint" height="600px">
 
 ## Recreating Results
 
diff --git a/dataloader/load.py b/dataloader/load.py
index 9bf0a294501a3c6a0c93c1bfed1ad68241e15fcb..770243a195982c9b1071ed12616999749d3844d8 100644
--- a/dataloader/load.py
+++ b/dataloader/load.py
@@ -38,21 +38,21 @@ class ConditionalDataset_AFHQ_Class(Dataset):
             theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size))
 
             transform_rotate_flip = transforms.Compose([transforms.ToTensor(),
-                                        transforms.Resize(intermediate_size,antialias=True),
-                                        transforms.RandomRotation((theta/np.pi*180),interpolation=transforms.InterpolationMode.BILINEAR),
-                                        transforms.CenterCrop(img_size),
-                                        transforms.RandomHorizontalFlip(p=0.5),
-                                        transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
+                                                        transforms.Resize(intermediate_size,antialias=True),
+                                                        transforms.RandomRotation((theta/np.pi*180),interpolation=transforms.InterpolationMode.BILINEAR),
+                                                        transforms.CenterCrop(img_size),
+                                                        transforms.RandomHorizontalFlip(p=0.5),
+                                                        transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
 
             transform_flip = transforms.Compose([transforms.ToTensor(),
-                                        transforms.Resize(img_size, antialias=True),
-                                        transforms.RandomHorizontalFlip(p=0.5),
-                                        transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
+                                                 transforms.Resize(img_size, antialias=True),
+                                                 transforms.RandomHorizontalFlip(p=0.5),
+                                                 transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
 
             self.transform = transforms.RandomChoice([transform_rotate_flip,transform_flip])
 
         else :
             self.transform = transforms.Compose([transforms.ToTensor(),
-                            transforms.Resize(img_size)])
+                                                 transforms.Resize(img_size)])
 
     def __len__(self):
diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py
index 727985e722a42481f6ba80d316d1a4608b54218d..764605e8e5ab3936d7a30fd570bc4687f0168557 100644
--- a/evaluation/evaluate.py
+++ b/evaluation/evaluate.py
@@ -329,4 +329,103 @@ def cdm_evaluator_lhq_paint(experiment_path, realpath, genpath, size=128, arch='
                           sample=sample_size,
                           size=size,
                           name_appendix=name_appendix)
-    print('Finish!')
\ No newline at end of file
+    print('Finish!')
+
+
+def simple_evaluator(model,
+                     device,
+                     dataloader,
+                     checkpoint,
+                     experiment_path,
+                     sample_idx=0,
+                     **args,
+                     ):
+    '''
+    Takes a trained diffusion model from 'checkpoint' and evaluates its performance on the test
+    dataset 'dataloader' w.r.t. the three most important performance metrics: FID, IS, and KID. This
+    continues our evaluation function for the LDM upscaler and may be updated accordingly.
+
+    checkpoint:      Name of the saved .pth file containing the trained weights and biases
+    experiment_path: Path to the experiment folder where the evaluation results will be stored
+    dataloader:      Loads the test dataset for evaluation
+    sample_idx:      Integer that denotes which sample directory sample_{sample_idx} from the checkpoint model shall be used for evaluation
+    '''
+
+    checkpoint_path = f'{experiment_path}trained_ddpm/{checkpoint}'
+    # create evaluation directory for the complete experiment (if first time sampling images)
+    output_dir = f'{experiment_path}evaluations/'
+    os.makedirs(output_dir, exist_ok=True)
+
+    # create evaluation directory for the current version of the trained model
+    model_name = os.path.basename(checkpoint_path)
+    epoch = re.findall(r'\d+', model_name)
+    if epoch:
+        e = int(epoch[0])
+    else:
+        raise ValueError(f"No digit found in the filename: {model_name}")
+    model_dir = os.path.join(output_dir,f'epoch_{e}')
+    os.makedirs(model_dir, exist_ok=True)
+
+    # create the evaluation directory for this evaluation run for the current version of the model
+    eval_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))]
+    indx_list = [int(d.split('_')[1]) for d in eval_dir_list if d.startswith('evaluation_')]
+    j = max(indx_list, default=-1) + 1
+    eval_dir = os.path.join(model_dir, f'evaluation_{j}')
+    os.makedirs(eval_dir, exist_ok=True)
+
+    # Compute Metrics
+    eval_path = os.path.join(eval_dir, 'eval.txt')
+
+    # get sampled images
+    transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (x * 255).type(torch.uint8))])
+    sample_path = os.path.join(f'{experiment_path}samples/',f'epoch_{e}',f'sample_{sample_idx}')
+    ignore_tensor = f'image_tensor{j}'
+    images = []
+    for samplename in os.listdir(sample_path):
+        if samplename == ignore_tensor:
+            continue
+        img = Image.open(os.path.join(sample_path, samplename))
+        img = transform(img)
+        images.append(img)
+    # split them into batches for GPU memory
+    generated = torch.stack(images).to(device)
+    generated_batches = torch.split(generated, dataloader.batch_size)
+    nr_generated_batches = len(generated_batches)
+    nr_real_batches = len(dataloader)
+
+    # Init FID, IS and KID scores
+    fid = FrechetInceptionDistance(normalize=False).to(device)
+    iscore = InceptionScore(normalize=False).to(device)
+    kid = KernelInceptionDistance(normalize=False, subset_size=32).to(device)
+
+    # Update scores for the full testing dataset w.r.t. the sampled batches
+    for idx,(data, _) in enumerate(dataloader):
+        data = data.to(device)
+        fid.update(data, real=True)
+        kid.update(data, real=True)
+        if idx < nr_generated_batches:
+            gen = generated_batches[idx].to(device)
+            fid.update(gen, real=False)
+            kid.update(gen, real=False)
+            iscore.update(gen)
+
+    # If there are sampled images left, add them too
+    for idx in range(nr_real_batches, nr_generated_batches):
+        gen = generated_batches[idx].to(device)
+        fid.update(gen, real=False)
+        kid.update(gen, real=False)
+        iscore.update(gen)
+
+    # compute total FID, IS and KID
+    fid_score = fid.compute()
+    i_score = iscore.compute()
+    kid_score = kid.compute()
+
+    # store results in txt file
+    with open(str(eval_path), 'a') as txt:
+        result = f'FID_epoch_{e}_sample_{sample_idx}:'
+        txt.write(result + str(fid_score.item()) + '\n')
+        result = f'KID_epoch_{e}_sample_{sample_idx}:'
+        txt.write(result + str(kid_score) + '\n')
+        result = f'IS_epoch_{e}_sample_{sample_idx}:'
+        txt.write(result + str(i_score) + '\n')
diff --git "a/imgs/lhq/\342\200\216lhq.png" "b/imgs/lhq/\342\200\216lhq.png"
new file mode 100644
index 0000000000000000000000000000000000000000..61d6956e9b6d27a368dfce9bae3186b419da1e87
Binary files /dev/null and "b/imgs/lhq/\342\200\216lhq.png" differ
diff --git a/main.py b/main.py
index 24c427753dfa7689ff4d69b0360e3a62fc25d91b..c6a99284bb3eca8356816aa73c3f43dfb101df19 100644
--- a/main.py
+++ b/main.py
@@ -30,7 +30,7 @@ def train_func(f):
         training_setting = json.load(fp)
         training_setting["optimizer_class"] = eval(training_setting["optimizer_class"])
 
-    # init dataloaders
+    # init Dataloaders
     batchsize = meta_setting["batchsize"]
     training_dataset = globals()[meta_setting["dataset"]](train = True,**dataset_setting)
     test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
@@ -74,7 +74,7 @@ def sample_func(f):
     with open(f+"/dataset_setting.json","r") as fp:
         dataset_setting = json.load(fp)
 
-    # init dataloader
+    # init Dataloader
     batchsize = sampling_setting["batch_size"]
     test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
     test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize,shuffle=True)
diff --git a/models/conditional_unet.py b/models/conditional_unet.py
index 6539ec3e3ccd916c0eefde51e586af3525e98998..7e0843df1580efb5f520abe727bd05c8f81aef47 100644
--- a/models/conditional_unet.py
+++ b/models/conditional_unet.py
@@ -24,6 +24,8 @@ class Conditional_UNet_Res(nn.Module):
             channels_out = channels_in
         fctr = np.asarray(fctr)*n_channels
 
+        self.cond = cond
+
         # learned time embedding
         self.time_embedder = TimeEmbedding(time_dim = time_dim)
         # learned class embedding
@@ -73,7 +75,7 @@
         t_emb = self.time_embedder(t).to(input.device)
 
         # compute class embedding if present and training on class conditioned data
-        if cond == 'class' and (y is not None):
+        if self.cond == 'class' and (y is not None):
            c_emb = self.class_embedder(y).to(input.device)
         else:
             c_emb = torch.zeros_like(t_emb).to(input.device)
@@ -88,10 +90,10 @@
         tc_emb4 = self.tc_embedder4(tc_emb)
 
         # first conv layers
-        if cond == 'image':
+        if self.cond == 'image':
             # concat latent with masked image if training in image conditioned data
             cat = torch.concat((input, y), dim=1)
-        elif cond == 'class':
+        elif self.cond == 'class':
             cat = input
 
         x = self.first_conv(cat) + tc_emb0[:,:,None,None]
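
The following is a minimal, hypothetical usage sketch for the simple_evaluator added in evaluation/evaluate.py above; it is not part of the diff. The dataset keyword arguments, checkpoint filename, and experiment path are illustrative placeholders (the real values come from the experiment's JSON settings), and model is passed as None because simple_evaluator only scores images that were already sampled to disk.

# Hypothetical usage sketch for simple_evaluator (not part of the diff above).
# Dataset arguments, checkpoint filename, and experiment path are placeholders.
import torch
from torch.utils.data import DataLoader

from dataloader.load import ConditionalDataset_AFHQ_Class
from evaluation.evaluate import simple_evaluator

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# test split of the class-conditional AFHQ dataset
# (keyword arguments are assumed; in the pipeline they come from dataset_setting.json)
test_dataset = ConditionalDataset_AFHQ_Class(train=False, fpath='data/afhq', img_size=128)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# 'model' is accepted but not used by simple_evaluator itself, since the function
# scores images already sampled into {experiment_path}samples/epoch_{e}/sample_{sample_idx}
simple_evaluator(model=None,
                 device=device,
                 dataloader=test_dataloader,
                 checkpoint='model_epoch_100.pth',         # placeholder; epoch is parsed from the digits
                 experiment_path='experiments/afhq_cdm/',  # placeholder; must end with '/'
                 sample_idx=0)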