diff --git a/dataloader/load.py b/dataloader/load.py index 4737e3e72c0da46629aa524c921d70415c232566..8b9c66eff83aa6c3f9aca9d01da813db3e16498c 100644 --- a/dataloader/load.py +++ b/dataloader/load.py @@ -56,6 +56,7 @@ class UnconditionalDataset(Dataset): self.transform = transforms.RandomChoice([transform_rotate,transform_randomcrop]) else : self.transform = transforms.Compose([transforms.ToTensor(), + transforms.Lambda(lambda x: (x * 255).type(torch.uint8)), transforms.Resize(img_size)]) def __len__(self): @@ -64,8 +65,12 @@ class UnconditionalDataset(Dataset): def __getitem__(self,idx): path = self.df.iloc[idx].Filepath img = Image.open(path) + if img.mode == 'RGBA': + background = Image.new('RGB', img.size, (255, 255, 255)) + background.paste(img, mask=img.split()[3]) + img = background return self.transform(img),0 def tensor2PIL(self,img): back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()]) - return back2pil(img) \ No newline at end of file + return back2pil(img) diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index bd489723d8e7009d6832e43c9aeee20e16461628..2d7f6e7929f5d96e6ee1591dd3a4eaa8f8dccda3 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -1,10 +1,19 @@ -from evaluation.sample import ddpm_sampler +from torchmetrics.image.fid import FrechetInceptionDistance +from torchmetrics.image.inception import InceptionScore +from torchmetrics.image.kid import KernelInceptionDistance +import re +import os +from PIL import Image +from torchvision import transforms +import torch def ddpm_evaluator(model, device, dataloader, checkpoint, - experiment_path + experiment_path, + sample_idx=0, + **args, ): ''' Takes a trained diffusion model from 'checkpoint' and evaluates its performance on the test @@ -14,6 +23,84 @@ def ddpm_evaluator(model, checkpoint: Name of the saved pth. file containing the trained weights and biases experiment_path: Path to the experiment folder where the evaluation results will be stored testloader: Loads the test dataset - TODO ... ''' - return None + + checkpoint_path = f'{experiment_path}trained_ddpm/{checkpoint}' + # create evaluation directory for the complete experiment (if first time sampling images) + output_dir = f'{experiment_path}evaluations/' + os.makedirs(output_dir, exist_ok=True) + + # create sample directory for the current version of the trained model + model_name = os.path.basename(checkpoint_path) + epoch = re.findall(r'\d+', model_name) + if epoch: + e = int(epoch[0]) + else: + raise ValueError(f"No digit found in the filename: {filename}") + model_dir = os.path.join(output_dir,f'epoch_{e}') + os.makedirs(model_dir, exist_ok=True) + + # create the sample directory for this sampling run for the current version of the model + eval_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))] + indx_list = [int(d.split('_')[1]) for d in eval_dir_list if d.startswith('evaluation_')] + j = max(indx_list, default=-1) + 1 + eval_dir = os.path.join(model_dir, f'evaluation_{j}') + os.makedirs(eval_dir, exist_ok=True) + + # Compute FID SCORE + eval_path = os.path.join(eval_dir, 'eval.txt') + + # get sampled images + transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (x * 255).type(torch.uint8))]) + sample_path = os.path.join(f'{experiment_path}samples/',f'epoch_{e}',f'sample_{sample_idx}') + ignore_tensor = f'image_tensor{j}' + images = [] + for samplename in os.listdir(sample_path): + if samplename == ignore_tensor: + continue + img = Image.open(os.path.join(sample_path, samplename)) + img = transform(img) + images.append(img) + generated = torch.stack(images).to(device) + generated_batches = torch.split(generated, dataloader.batch_size) + nr_generated_batches = len(generated_batches) + nr_real_batches = len(dataloader) + + # Init FID and IS + fid = FrechetInceptionDistance(normalize = False).to(device) + iscore = InceptionScore(normalize=False).to(device) + kid = KernelInceptionDistance(normalize=False, subset_size=32).to(device) + + # Update FID score for full testing dataset and the sampled batch + for idx,(data, _) in enumerate(dataloader): + data = data.to(device) + fid.update(data, real=True) + kid.update(data, real=True) + if idx < nr_generated_batches: + gen = generated_batches[idx].to(device) + fid.update(gen, real=False) + kid.update(gen, real=False) + iscore.update(gen) + + # If there are more generated images left, add them too + for idx in range(nr_real_batches, nr_generated_batches): + gen = generated_batches[idx].to(device) + fid.update(gen, real=False) + kid.update(gen, real=False) + iscore.update(gen) + + # compute total FID and IS + fid_score = fid.compute() + i_score = iscore.compute() + kid_score = kid.compute() + + # store in txt file + with open(str(eval_path), 'a') as txt: + result = f'FID_epoch_{e}_sample_{sample_idx}:' + txt.write(result + str(fid_score.item()) + '\n') + result = f'KID_epoch_{e}_sample_{sample_idx}:' + txt.write(result + str(kid_score) + '\n') + result = f'IS_epoch_{e}_sample_{sample_idx}:' + txt.write(result + str(i_score) + '\n') + + diff --git a/main.py b/main.py index dedee957f94dd6cd4a20341633a8e30ee494726c..d60ec1cfb1511a01699361529fc134af37b13771 100644 --- a/main.py +++ b/main.py @@ -117,7 +117,8 @@ def evaluate_func(f): # load dataset batchsize = meta_setting["batchsize"] test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting) - test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize) + #test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=len(test_dataset), shuffle=False) + test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize, shuffle=False) # init Unet net = globals()[meta_setting["modelname"]](**model_setting).to(device) @@ -134,7 +135,7 @@ def evaluate_func(f): print(f"EVALUATION SETTINGS:\n\n {evaluation_setting}\n\n") print("\n\nSTART EVALUATION\n\n") - globals()[meta_setting["evaluation_function"]](model=framework, device=device, testloader = test_dataloader,safepath = f,**evaluation_setting,) + globals()[meta_setting["evaluation_function"]](model=framework, device=device, dataloader = test_dataloader,safepath = f,**evaluation_setting,) print("\n\nFINISHED EVALUATION\n\n") diff --git a/models/Framework.py b/models/Framework.py index 84783d3cf4d12350f2d1de54947dcd7ffb49c766..1d26682d05ed1d5553b1e184297d3495a0951a03 100644 --- a/models/Framework.py +++ b/models/Framework.py @@ -1,6 +1,7 @@ import torch from torch import nn import torch.nn.functional as F +import math class DDPM(nn.Module): @@ -107,14 +108,14 @@ class DDPM(nn.Module): alpha_bar (tensor): Follows a sigmoid-like curve with a linear drop-off in the middle. Length is diffusion_steps. ''' - cosine_0 = DDPM.cosine(0, diffusion_steps= diffusion_steps) - alpha_bar = [DDPM.cosine(t,diffusion_steps = diffusion_steps)/cosine_0 + cosine_0 = DDPM.cosine(torch.tensor(0, device=device), diffusion_steps= diffusion_steps) + alpha_bar = [DDPM.cosine(torch.tensor(t, device=device),diffusion_steps = diffusion_steps)/cosine_0 for t in range(1, diffusion_steps+1)] shift = [1] + alpha_bar[:-1] beta = 1 - torch.div(torch.tensor(alpha_bar, device=device), torch.tensor(shift, device=device)) - beta = torch.clamp(beta, min =0, max = 0.999) #suggested by paper + beta = torch.clamp(beta, min =0, max = 0.999).to(device) #suggested by paper alpha = 1 - beta - alpha_bar = torch.tensor(alpha_bar) + alpha_bar = torch.tensor(alpha_bar,device=device) return beta, alpha, alpha_bar @staticmethod @@ -165,7 +166,7 @@ class DDPM(nn.Module): Returns: (numpy.float64): Value of the cosine function at timestep t ''' - return (np.cos((((t/diffusion_steps)+s)*np.pi)/((1+s)*2)))**2 + return (torch.cos((((t/diffusion_steps)+s)*math.pi)/((1+s)*2)))**2 ####