From 049ef86222889d93587f5cea2c99cd12686ba920 Mon Sep 17 00:00:00 2001
From: gonzalomartingarcia0 <gonzalomartingarcia0@gmail.com>
Date: Thu, 29 Jun 2023 17:26:11 +0200
Subject: [PATCH] Implemented evaluation function, now computes FID, IS and KID
for a given sample of generated images w.r.t. testing data. Fixed device
problems with the cosine noise scheduler
---
dataloader/load.py | 7 +++-
evaluation/evaluate.py | 95 ++++++++++++++++++++++++++++++++++++++++--
main.py | 5 ++-
models/Framework.py | 11 ++---
4 files changed, 106 insertions(+), 12 deletions(-)
diff --git a/dataloader/load.py b/dataloader/load.py
index 4737e3e..8b9c66e 100644
--- a/dataloader/load.py
+++ b/dataloader/load.py
@@ -56,6 +56,7 @@ class UnconditionalDataset(Dataset):
self.transform = transforms.RandomChoice([transform_rotate,transform_randomcrop])
else :
self.transform = transforms.Compose([transforms.ToTensor(),
+ transforms.Lambda(lambda x: (x * 255).type(torch.uint8)),
transforms.Resize(img_size)])
def __len__(self):
@@ -64,8 +65,12 @@ class UnconditionalDataset(Dataset):
def __getitem__(self,idx):
path = self.df.iloc[idx].Filepath
img = Image.open(path)
+ if img.mode == 'RGBA':
+ background = Image.new('RGB', img.size, (255, 255, 255))
+ background.paste(img, mask=img.split()[3])
+ img = background
return self.transform(img),0
def tensor2PIL(self,img):
back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()])
- return back2pil(img)
\ No newline at end of file
+ return back2pil(img)
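Note on the dataloader change: the added Lambda converts the test images to uint8 because torchmetrics' Inception-based metrics (FID, IS, KID) expect uint8 tensors in [0, 255] when built with normalize=False, while normalize=True expects float tensors in [0, 1]. A minimal sketch of the two equivalent preprocessing options, for illustration only (not part of the patch):

    import torch
    from torchvision import transforms
    from torchmetrics.image.fid import FrechetInceptionDistance

    # Option A: convert to uint8 in [0, 255], as this patch does (normalize=False)
    to_uint8 = transforms.Compose([
        transforms.ToTensor(),                                     # float32 in [0, 1]
        transforms.Lambda(lambda x: (x * 255).type(torch.uint8)),  # uint8 in [0, 255]
    ])
    fid_uint8 = FrechetInceptionDistance(normalize=False)

    # Option B: keep float images in [0, 1] and let the metric handle the scaling
    to_float = transforms.ToTensor()
    fid_float = FrechetInceptionDistance(normalize=True)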
diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py
index bd48972..2d7f6e7 100644
--- a/evaluation/evaluate.py
+++ b/evaluation/evaluate.py
@@ -1,10 +1,19 @@
-from evaluation.sample import ddpm_sampler
+from torchmetrics.image.fid import FrechetInceptionDistance
+from torchmetrics.image.inception import InceptionScore
+from torchmetrics.image.kid import KernelInceptionDistance
+import re
+import os
+from PIL import Image
+from torchvision import transforms
+import torch
def ddpm_evaluator(model,
device,
dataloader,
checkpoint,
- experiment_path
+ experiment_path,
+ sample_idx=0,
+ **args,
):
'''
Takes a trained diffusion model from 'checkpoint' and evaluates its performance on the test
@@ -14,6 +23,84 @@ def ddpm_evaluator(model,
checkpoint: Name of the saved pth. file containing the trained weights and biases
experiment_path: Path to the experiment folder where the evaluation results will be stored
testloader: Loads the test dataset
- TODO ...
'''
- return None
+
+ checkpoint_path = f'{experiment_path}trained_ddpm/{checkpoint}'
+ # create an evaluation directory for the whole experiment (if evaluating for the first time)
+ output_dir = f'{experiment_path}evaluations/'
+ os.makedirs(output_dir, exist_ok=True)
+
+ # create an evaluation subdirectory for the current model checkpoint (epoch)
+ model_name = os.path.basename(checkpoint_path)
+ epoch = re.findall(r'\d+', model_name)
+ if epoch:
+ e = int(epoch[0])
+ else:
+ raise ValueError(f"No digit found in the filename: {model_name}")
+ model_dir = os.path.join(output_dir,f'epoch_{e}')
+ os.makedirs(model_dir, exist_ok=True)
+
+ # create a directory for this evaluation run of the current checkpoint
+ eval_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))]
+ indx_list = [int(d.split('_')[1]) for d in eval_dir_list if d.startswith('evaluation_')]
+ j = max(indx_list, default=-1) + 1
+ eval_dir = os.path.join(model_dir, f'evaluation_{j}')
+ os.makedirs(eval_dir, exist_ok=True)
+
+ # path of the text file where the evaluation scores are stored
+ eval_path = os.path.join(eval_dir, 'eval.txt')
+
+ # get sampled images
+ transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (x * 255).type(torch.uint8))])
+ sample_path = os.path.join(f'{experiment_path}samples/',f'epoch_{e}',f'sample_{sample_idx}')
+ ignore_tensor = f'image_tensor{j}'
+ images = []
+ for samplename in os.listdir(sample_path):
+ if samplename == ignore_tensor:
+ continue
+ img = Image.open(os.path.join(sample_path, samplename))
+ img = transform(img)
+ images.append(img)
+ generated = torch.stack(images).to(device)
+ generated_batches = torch.split(generated, dataloader.batch_size)
+ nr_generated_batches = len(generated_batches)
+ nr_real_batches = len(dataloader)
+
+ # Init FID, IS and KID metrics
+ fid = FrechetInceptionDistance(normalize=False).to(device)
+ iscore = InceptionScore(normalize=False).to(device)
+ kid = KernelInceptionDistance(normalize=False, subset_size=32).to(device)
+
+ # Update FID and KID with the full test set; update all metrics with the generated batches
+ for idx, (data, _) in enumerate(dataloader):
+ data = data.to(device)
+ fid.update(data, real=True)
+ kid.update(data, real=True)
+ if idx < nr_generated_batches:
+ gen = generated_batches[idx].to(device)
+ fid.update(gen, real=False)
+ kid.update(gen, real=False)
+ iscore.update(gen)
+
+ # If there are more generated batches than real batches, add the remaining ones as well
+ for idx in range(nr_real_batches, nr_generated_batches):
+ gen = generated_batches[idx].to(device)
+ fid.update(gen, real=False)
+ kid.update(gen, real=False)
+ iscore.update(gen)
+
+ # compute the final FID, IS and KID scores
+ fid_score = fid.compute()
+ i_score = iscore.compute()
+ kid_score = kid.compute()
+
+ # store in txt file
+ with open(str(eval_path), 'a') as txt:
+ result = f'FID_epoch_{e}_sample_{sample_idx}:'
+ txt.write(result + str(fid_score.item()) + '\n')
+ result = f'KID_epoch_{e}_sample_{sample_idx}:'
+ txt.write(result + str(kid_score) + '\n')
+ result = f'IS_epoch_{e}_sample_{sample_idx}:'
+ txt.write(result + str(i_score) + '\n')
+
+
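Note on the evaluator: FID, KID and IS all follow torchmetrics' update/compute pattern, so real and generated batches can be fed incrementally and the scores are computed once at the end. A minimal, self-contained sketch of that pattern (the random uint8 tensors are stand-ins for real and generated images; only the metric calls mirror the function above):

    import torch
    from torchmetrics.image.fid import FrechetInceptionDistance
    from torchmetrics.image.inception import InceptionScore
    from torchmetrics.image.kid import KernelInceptionDistance

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    fid = FrechetInceptionDistance(normalize=False).to(device)
    iscore = InceptionScore(normalize=False).to(device)
    kid = KernelInceptionDistance(normalize=False, subset_size=32).to(device)

    # stand-in batches of uint8 RGB images, shape (N, 3, H, W), values in [0, 255]
    real = torch.randint(0, 256, (64, 3, 128, 128), dtype=torch.uint8, device=device)
    fake = torch.randint(0, 256, (64, 3, 128, 128), dtype=torch.uint8, device=device)

    fid.update(real, real=True)
    fid.update(fake, real=False)
    kid.update(real, real=True)
    kid.update(fake, real=False)
    iscore.update(fake)                 # IS only looks at generated images

    fid_score = fid.compute()           # scalar tensor
    kid_mean, kid_std = kid.compute()   # mean and std over KID subsets
    is_mean, is_std = iscore.compute()  # mean and std over IS splits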
diff --git a/main.py b/main.py
index dedee95..d60ec1c 100644
--- a/main.py
+++ b/main.py
@@ -117,7 +117,8 @@ def evaluate_func(f):
# load dataset
batchsize = meta_setting["batchsize"]
test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
- test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize)
+ #test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=len(test_dataset), shuffle=False)
+ test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize, shuffle=False)
# init Unet
net = globals()[meta_setting["modelname"]](**model_setting).to(device)
@@ -134,7 +135,7 @@ def evaluate_func(f):
print(f"EVALUATION SETTINGS:\n\n {evaluation_setting}\n\n")
print("\n\nSTART EVALUATION\n\n")
- globals()[meta_setting["evaluation_function"]](model=framework, device=device, testloader = test_dataloader,safepath = f,**evaluation_setting,)
+ globals()[meta_setting["evaluation_function"]](model=framework, device=device, dataloader = test_dataloader,safepath = f,**evaluation_setting,)
print("\n\nFINISHED EVALUATION\n\n")
diff --git a/models/Framework.py b/models/Framework.py
index 84783d3..1d26682 100644
--- a/models/Framework.py
+++ b/models/Framework.py
@@ -1,6 +1,7 @@
import torch
from torch import nn
import torch.nn.functional as F
+import math
class DDPM(nn.Module):
@@ -107,14 +108,14 @@ class DDPM(nn.Module):
alpha_bar (tensor): Follows a sigmoid-like curve with a linear drop-off in the middle.
Length is diffusion_steps.
'''
- cosine_0 = DDPM.cosine(0, diffusion_steps= diffusion_steps)
- alpha_bar = [DDPM.cosine(t,diffusion_steps = diffusion_steps)/cosine_0
+ cosine_0 = DDPM.cosine(torch.tensor(0, device=device), diffusion_steps=diffusion_steps)
+ alpha_bar = [DDPM.cosine(torch.tensor(t, device=device), diffusion_steps=diffusion_steps)/cosine_0
for t in range(1, diffusion_steps+1)]
shift = [1] + alpha_bar[:-1]
beta = 1 - torch.div(torch.tensor(alpha_bar, device=device), torch.tensor(shift, device=device))
- beta = torch.clamp(beta, min =0, max = 0.999) #suggested by paper
+ beta = torch.clamp(beta, min=0, max=0.999).to(device) # suggested by paper
alpha = 1 - beta
- alpha_bar = torch.tensor(alpha_bar)
+ alpha_bar = torch.tensor(alpha_bar,device=device)
return beta, alpha, alpha_bar
@staticmethod
@@ -165,7 +166,7 @@ class DDPM(nn.Module):
Returns:
(numpy.float64): Value of the cosine function at timestep t
'''
- return (np.cos((((t/diffusion_steps)+s)*np.pi)/((1+s)*2)))**2
+ return (torch.cos((((t/diffusion_steps)+s)*math.pi)/((1+s)*2)))**2
####
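Note on the scheduler fix: the hunk above moves the cosine function from numpy to torch and creates every intermediate tensor on the target device, which removes the CPU/GPU mismatch. For illustration only, a fully vectorized sketch of the same schedule that builds alpha_bar directly on the device without the per-timestep Python loop; the function name cosine_schedule and the offset s=0.008 are assumptions here, since the default value of s used by DDPM.cosine is not visible in this hunk:

    import math
    import torch

    def cosine_schedule(diffusion_steps, device, s=0.008):
        # f(t) = cos(((t/T + s) / (1 + s)) * pi/2)**2, evaluated for t = 0 .. T
        t = torch.arange(diffusion_steps + 1, device=device, dtype=torch.float64)
        f = torch.cos(((t / diffusion_steps + s) / (1 + s)) * math.pi / 2) ** 2
        alpha_bar = f[1:] / f[0]                      # normalize so the schedule starts near 1
        shift = torch.cat([torch.ones(1, device=device, dtype=torch.float64),
                           alpha_bar[:-1]])
        beta = torch.clamp(1 - alpha_bar / shift, min=0, max=0.999)  # clipping as in the code above
        alpha = 1 - beta
        return beta.float(), alpha.float(), alpha_bar.float()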
--
GitLab