Skip to content
Snippets Groups Projects
Commit 049ef862 authored by Gonzalo Martin Garcia's avatar Gonzalo Martin Garcia
Browse files

Implemented evaluation function, now computes FID, IS and KID for a given...

Implemented evaluation function, now computes FID, IS and KID for a given sample of generated images w.r.t. testing data. Fixed device porblems with cosine noise scheduler
parent 81a33135
Branches
No related tags found
No related merge requests found
...@@ -56,6 +56,7 @@ class UnconditionalDataset(Dataset): ...@@ -56,6 +56,7 @@ class UnconditionalDataset(Dataset):
self.transform = transforms.RandomChoice([transform_rotate,transform_randomcrop]) self.transform = transforms.RandomChoice([transform_rotate,transform_randomcrop])
else : else :
self.transform = transforms.Compose([transforms.ToTensor(), self.transform = transforms.Compose([transforms.ToTensor(),
transforms.Lambda(lambda x: (x * 255).type(torch.uint8)),
transforms.Resize(img_size)]) transforms.Resize(img_size)])
def __len__(self): def __len__(self):
...@@ -64,6 +65,10 @@ class UnconditionalDataset(Dataset): ...@@ -64,6 +65,10 @@ class UnconditionalDataset(Dataset):
def __getitem__(self,idx): def __getitem__(self,idx):
path = self.df.iloc[idx].Filepath path = self.df.iloc[idx].Filepath
img = Image.open(path) img = Image.open(path)
if img.mode == 'RGBA':
background = Image.new('RGB', img.size, (255, 255, 255))
background.paste(img, mask=img.split()[3])
img = background
return self.transform(img),0 return self.transform(img),0
def tensor2PIL(self,img): def tensor2PIL(self,img):
......
from evaluation.sample import ddpm_sampler from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
import re
import os
from PIL import Image
from torchvision import transforms
import torch
def ddpm_evaluator(model, def ddpm_evaluator(model,
device, device,
dataloader, dataloader,
checkpoint, checkpoint,
experiment_path experiment_path,
sample_idx=0,
**args,
): ):
''' '''
Takes a trained diffusion model from 'checkpoint' and evaluates its performance on the test Takes a trained diffusion model from 'checkpoint' and evaluates its performance on the test
...@@ -14,6 +23,84 @@ def ddpm_evaluator(model, ...@@ -14,6 +23,84 @@ def ddpm_evaluator(model,
checkpoint: Name of the saved pth. file containing the trained weights and biases checkpoint: Name of the saved pth. file containing the trained weights and biases
experiment_path: Path to the experiment folder where the evaluation results will be stored experiment_path: Path to the experiment folder where the evaluation results will be stored
testloader: Loads the test dataset testloader: Loads the test dataset
TODO ...
''' '''
return None
checkpoint_path = f'{experiment_path}trained_ddpm/{checkpoint}'
# create evaluation directory for the complete experiment (if first time sampling images)
output_dir = f'{experiment_path}evaluations/'
os.makedirs(output_dir, exist_ok=True)
# create sample directory for the current version of the trained model
model_name = os.path.basename(checkpoint_path)
epoch = re.findall(r'\d+', model_name)
if epoch:
e = int(epoch[0])
else:
raise ValueError(f"No digit found in the filename: {filename}")
model_dir = os.path.join(output_dir,f'epoch_{e}')
os.makedirs(model_dir, exist_ok=True)
# create the sample directory for this sampling run for the current version of the model
eval_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))]
indx_list = [int(d.split('_')[1]) for d in eval_dir_list if d.startswith('evaluation_')]
j = max(indx_list, default=-1) + 1
eval_dir = os.path.join(model_dir, f'evaluation_{j}')
os.makedirs(eval_dir, exist_ok=True)
# Compute FID SCORE
eval_path = os.path.join(eval_dir, 'eval.txt')
# get sampled images
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (x * 255).type(torch.uint8))])
sample_path = os.path.join(f'{experiment_path}samples/',f'epoch_{e}',f'sample_{sample_idx}')
ignore_tensor = f'image_tensor{j}'
images = []
for samplename in os.listdir(sample_path):
if samplename == ignore_tensor:
continue
img = Image.open(os.path.join(sample_path, samplename))
img = transform(img)
images.append(img)
generated = torch.stack(images).to(device)
generated_batches = torch.split(generated, dataloader.batch_size)
nr_generated_batches = len(generated_batches)
nr_real_batches = len(dataloader)
# Init FID and IS
fid = FrechetInceptionDistance(normalize = False).to(device)
iscore = InceptionScore(normalize=False).to(device)
kid = KernelInceptionDistance(normalize=False, subset_size=32).to(device)
# Update FID score for full testing dataset and the sampled batch
for idx,(data, _) in enumerate(dataloader):
data = data.to(device)
fid.update(data, real=True)
kid.update(data, real=True)
if idx < nr_generated_batches:
gen = generated_batches[idx].to(device)
fid.update(gen, real=False)
kid.update(gen, real=False)
iscore.update(gen)
# If there are more generated images left, add them too
for idx in range(nr_real_batches, nr_generated_batches):
gen = generated_batches[idx].to(device)
fid.update(gen, real=False)
kid.update(gen, real=False)
iscore.update(gen)
# compute total FID and IS
fid_score = fid.compute()
i_score = iscore.compute()
kid_score = kid.compute()
# store in txt file
with open(str(eval_path), 'a') as txt:
result = f'FID_epoch_{e}_sample_{sample_idx}:'
txt.write(result + str(fid_score.item()) + '\n')
result = f'KID_epoch_{e}_sample_{sample_idx}:'
txt.write(result + str(kid_score) + '\n')
result = f'IS_epoch_{e}_sample_{sample_idx}:'
txt.write(result + str(i_score) + '\n')
...@@ -117,7 +117,8 @@ def evaluate_func(f): ...@@ -117,7 +117,8 @@ def evaluate_func(f):
# load dataset # load dataset
batchsize = meta_setting["batchsize"] batchsize = meta_setting["batchsize"]
test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting) test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize) #test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=len(test_dataset), shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize, shuffle=False)
# init Unet # init Unet
net = globals()[meta_setting["modelname"]](**model_setting).to(device) net = globals()[meta_setting["modelname"]](**model_setting).to(device)
...@@ -134,7 +135,7 @@ def evaluate_func(f): ...@@ -134,7 +135,7 @@ def evaluate_func(f):
print(f"EVALUATION SETTINGS:\n\n {evaluation_setting}\n\n") print(f"EVALUATION SETTINGS:\n\n {evaluation_setting}\n\n")
print("\n\nSTART EVALUATION\n\n") print("\n\nSTART EVALUATION\n\n")
globals()[meta_setting["evaluation_function"]](model=framework, device=device, testloader = test_dataloader,safepath = f,**evaluation_setting,) globals()[meta_setting["evaluation_function"]](model=framework, device=device, dataloader = test_dataloader,safepath = f,**evaluation_setting,)
print("\n\nFINISHED EVALUATION\n\n") print("\n\nFINISHED EVALUATION\n\n")
......
import torch import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
import math
class DDPM(nn.Module): class DDPM(nn.Module):
...@@ -107,14 +108,14 @@ class DDPM(nn.Module): ...@@ -107,14 +108,14 @@ class DDPM(nn.Module):
alpha_bar (tensor): Follows a sigmoid-like curve with a linear drop-off in the middle. alpha_bar (tensor): Follows a sigmoid-like curve with a linear drop-off in the middle.
Length is diffusion_steps. Length is diffusion_steps.
''' '''
cosine_0 = DDPM.cosine(0, diffusion_steps= diffusion_steps) cosine_0 = DDPM.cosine(torch.tensor(0, device=device), diffusion_steps= diffusion_steps)
alpha_bar = [DDPM.cosine(t,diffusion_steps = diffusion_steps)/cosine_0 alpha_bar = [DDPM.cosine(torch.tensor(t, device=device),diffusion_steps = diffusion_steps)/cosine_0
for t in range(1, diffusion_steps+1)] for t in range(1, diffusion_steps+1)]
shift = [1] + alpha_bar[:-1] shift = [1] + alpha_bar[:-1]
beta = 1 - torch.div(torch.tensor(alpha_bar, device=device), torch.tensor(shift, device=device)) beta = 1 - torch.div(torch.tensor(alpha_bar, device=device), torch.tensor(shift, device=device))
beta = torch.clamp(beta, min =0, max = 0.999) #suggested by paper beta = torch.clamp(beta, min =0, max = 0.999).to(device) #suggested by paper
alpha = 1 - beta alpha = 1 - beta
alpha_bar = torch.tensor(alpha_bar) alpha_bar = torch.tensor(alpha_bar,device=device)
return beta, alpha, alpha_bar return beta, alpha, alpha_bar
@staticmethod @staticmethod
...@@ -165,7 +166,7 @@ class DDPM(nn.Module): ...@@ -165,7 +166,7 @@ class DDPM(nn.Module):
Returns: Returns:
(numpy.float64): Value of the cosine function at timestep t (numpy.float64): Value of the cosine function at timestep t
''' '''
return (np.cos((((t/diffusion_steps)+s)*np.pi)/((1+s)*2)))**2 return (torch.cos((((t/diffusion_steps)+s)*math.pi)/((1+s)*2)))**2
#### ####
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment