Skip to content
Snippets Groups Projects
Commit 049ef862 authored by Gonzalo Martin Garcia's avatar Gonzalo Martin Garcia
Browse files

Implemented evaluation function, now computes FID, IS and KID for a given...

Implemented evaluation function, now computes FID, IS and KID for a given sample of generated images w.r.t. testing data. Fixed device porblems with cosine noise scheduler
parent 81a33135
No related branches found
No related tags found
No related merge requests found
......@@ -56,6 +56,7 @@ class UnconditionalDataset(Dataset):
self.transform = transforms.RandomChoice([transform_rotate,transform_randomcrop])
else :
self.transform = transforms.Compose([transforms.ToTensor(),
transforms.Lambda(lambda x: (x * 255).type(torch.uint8)),
transforms.Resize(img_size)])
def __len__(self):
......@@ -64,6 +65,10 @@ class UnconditionalDataset(Dataset):
def __getitem__(self,idx):
path = self.df.iloc[idx].Filepath
img = Image.open(path)
if img.mode == 'RGBA':
background = Image.new('RGB', img.size, (255, 255, 255))
background.paste(img, mask=img.split()[3])
img = background
return self.transform(img),0
def tensor2PIL(self,img):
......
from evaluation.sample import ddpm_sampler
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
import re
import os
from PIL import Image
from torchvision import transforms
import torch
def ddpm_evaluator(model,
device,
dataloader,
checkpoint,
experiment_path
experiment_path,
sample_idx=0,
**args,
):
'''
Takes a trained diffusion model from 'checkpoint' and evaluates its performance on the test
......@@ -14,6 +23,84 @@ def ddpm_evaluator(model,
checkpoint: Name of the saved pth. file containing the trained weights and biases
experiment_path: Path to the experiment folder where the evaluation results will be stored
testloader: Loads the test dataset
TODO ...
'''
return None
checkpoint_path = f'{experiment_path}trained_ddpm/{checkpoint}'
# create evaluation directory for the complete experiment (if first time sampling images)
output_dir = f'{experiment_path}evaluations/'
os.makedirs(output_dir, exist_ok=True)
# create sample directory for the current version of the trained model
model_name = os.path.basename(checkpoint_path)
epoch = re.findall(r'\d+', model_name)
if epoch:
e = int(epoch[0])
else:
raise ValueError(f"No digit found in the filename: {filename}")
model_dir = os.path.join(output_dir,f'epoch_{e}')
os.makedirs(model_dir, exist_ok=True)
# create the sample directory for this sampling run for the current version of the model
eval_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))]
indx_list = [int(d.split('_')[1]) for d in eval_dir_list if d.startswith('evaluation_')]
j = max(indx_list, default=-1) + 1
eval_dir = os.path.join(model_dir, f'evaluation_{j}')
os.makedirs(eval_dir, exist_ok=True)
# Compute FID SCORE
eval_path = os.path.join(eval_dir, 'eval.txt')
# get sampled images
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (x * 255).type(torch.uint8))])
sample_path = os.path.join(f'{experiment_path}samples/',f'epoch_{e}',f'sample_{sample_idx}')
ignore_tensor = f'image_tensor{j}'
images = []
for samplename in os.listdir(sample_path):
if samplename == ignore_tensor:
continue
img = Image.open(os.path.join(sample_path, samplename))
img = transform(img)
images.append(img)
generated = torch.stack(images).to(device)
generated_batches = torch.split(generated, dataloader.batch_size)
nr_generated_batches = len(generated_batches)
nr_real_batches = len(dataloader)
# Init FID and IS
fid = FrechetInceptionDistance(normalize = False).to(device)
iscore = InceptionScore(normalize=False).to(device)
kid = KernelInceptionDistance(normalize=False, subset_size=32).to(device)
# Update FID score for full testing dataset and the sampled batch
for idx,(data, _) in enumerate(dataloader):
data = data.to(device)
fid.update(data, real=True)
kid.update(data, real=True)
if idx < nr_generated_batches:
gen = generated_batches[idx].to(device)
fid.update(gen, real=False)
kid.update(gen, real=False)
iscore.update(gen)
# If there are more generated images left, add them too
for idx in range(nr_real_batches, nr_generated_batches):
gen = generated_batches[idx].to(device)
fid.update(gen, real=False)
kid.update(gen, real=False)
iscore.update(gen)
# compute total FID and IS
fid_score = fid.compute()
i_score = iscore.compute()
kid_score = kid.compute()
# store in txt file
with open(str(eval_path), 'a') as txt:
result = f'FID_epoch_{e}_sample_{sample_idx}:'
txt.write(result + str(fid_score.item()) + '\n')
result = f'KID_epoch_{e}_sample_{sample_idx}:'
txt.write(result + str(kid_score) + '\n')
result = f'IS_epoch_{e}_sample_{sample_idx}:'
txt.write(result + str(i_score) + '\n')
......@@ -117,7 +117,8 @@ def evaluate_func(f):
# load dataset
batchsize = meta_setting["batchsize"]
test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize)
#test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=len(test_dataset), shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize, shuffle=False)
# init Unet
net = globals()[meta_setting["modelname"]](**model_setting).to(device)
......@@ -134,7 +135,7 @@ def evaluate_func(f):
print(f"EVALUATION SETTINGS:\n\n {evaluation_setting}\n\n")
print("\n\nSTART EVALUATION\n\n")
globals()[meta_setting["evaluation_function"]](model=framework, device=device, testloader = test_dataloader,safepath = f,**evaluation_setting,)
globals()[meta_setting["evaluation_function"]](model=framework, device=device, dataloader = test_dataloader,safepath = f,**evaluation_setting,)
print("\n\nFINISHED EVALUATION\n\n")
......
import torch
from torch import nn
import torch.nn.functional as F
import math
class DDPM(nn.Module):
......@@ -107,14 +108,14 @@ class DDPM(nn.Module):
alpha_bar (tensor): Follows a sigmoid-like curve with a linear drop-off in the middle.
Length is diffusion_steps.
'''
cosine_0 = DDPM.cosine(0, diffusion_steps= diffusion_steps)
alpha_bar = [DDPM.cosine(t,diffusion_steps = diffusion_steps)/cosine_0
cosine_0 = DDPM.cosine(torch.tensor(0, device=device), diffusion_steps= diffusion_steps)
alpha_bar = [DDPM.cosine(torch.tensor(t, device=device),diffusion_steps = diffusion_steps)/cosine_0
for t in range(1, diffusion_steps+1)]
shift = [1] + alpha_bar[:-1]
beta = 1 - torch.div(torch.tensor(alpha_bar, device=device), torch.tensor(shift, device=device))
beta = torch.clamp(beta, min =0, max = 0.999) #suggested by paper
beta = torch.clamp(beta, min =0, max = 0.999).to(device) #suggested by paper
alpha = 1 - beta
alpha_bar = torch.tensor(alpha_bar)
alpha_bar = torch.tensor(alpha_bar,device=device)
return beta, alpha, alpha_bar
@staticmethod
......@@ -165,7 +166,7 @@ class DDPM(nn.Module):
Returns:
(numpy.float64): Value of the cosine function at timestep t
'''
return (np.cos((((t/diffusion_steps)+s)*np.pi)/((1+s)*2)))**2
return (torch.cos((((t/diffusion_steps)+s)*math.pi)/((1+s)*2)))**2
####
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment