Commit c4410482 authored by Srijeet Roy

add full evaluation pipeline

parent d0cf545a
# evaluate_full.py
import os
import argparse
import pickle
from pathlib import Path

import torch
import clip
from torchvision import transforms
from torchvision.models import resnet50

from vggface import *
from kNN import *
from metrics import *

if __name__ == '__main__':
    # device = "mps" if torch.backends.mps.is_available() else "cpu"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print('device:', device)

    print('Parsing arguments...')
    parser = argparse.ArgumentParser()
    parser.add_argument('-rp', '--realpath', required=True,
                        help='path to real images', type=str)
    parser.add_argument('-gp', '--genpath', required=True,
                        help='path to generated images', type=str)
    parser.add_argument('-d', '--data', nargs='?', const='lhq', default='lhq',
                        help='choose between "lhq" and "faces" dataset', type=str)
    parser.add_argument('-a', '--arch', nargs='?', const='cnn', default='cnn',
                        help='choose between "clip" and "cnn", default "cnn"', type=str)
    parser.add_argument('-m', '--mode', nargs='?', const='kNN', default='kNN',
                        help='choose between "kNN" and "pairs" (closest pairs), default "kNN"', type=str)
    parser.add_argument('-k', '--k', nargs='?', const=3, default=3,
                        help='k for kNN, default 3', type=int)
    parser.add_argument('-s', '--sample', nargs='?', const=10, default=10,
                        help='how many generated samples to compare, default 10, int or "all"')
    parser.add_argument('-n', '--name', nargs='?', const='', default='',
                        help='name appendix for the plot file', type=str)
    parser.add_argument('--fid', nargs='?', const='no', default='no',
                        help='compute FID and inception score, choose between "yes" and "no"', type=str)
    args = vars(parser.parse_args())

    path_to_real_images = args['realpath']
    path_to_generated_images = args['genpath']
    dataset = args['data']
    arch = args['arch']
    mode = args['mode']
    k_kNN = args['k']
    sample = args['sample']
    name_appendix = args['name']
    fid_bool = args['fid']

    print('Start')
    txt_filename = 'output/evaluation_' + dataset + '_' + arch + '_' + mode + '-' + name_appendix + '.txt'
    with open(txt_filename, 'w') as f:
        f.write(f'Path to real images: {path_to_real_images}\n')
        f.write(f'Path to generated images: {path_to_generated_images}\n')
        f.write(f'Experiment on {dataset} dataset\n')
        f.write(f'Using {arch} model to extract features\n')
        f.write(f'Plot of {mode} on {sample} samples\n')

    # the real-image directory is expected to contain 'train' and 'test' sub-directories
    path_to_training_images = os.path.join(path_to_real_images, 'train')
    path_to_test_images = os.path.join(path_to_real_images, 'test')

    # distribution-level metrics (FID, KID, IS and variants) on the test split
    if fid_bool == 'yes':
        eval_images = image_to_tensor(path_to_test_images)
        generated = image_to_tensor(path_to_generated_images)

        print('Computing FID, KID & inception scores...')
        fid_score, kid_score, is_score = compute_scores(eval_images, generated, device)
        with open(txt_filename, 'a') as f:
            f.write(f'FID score: {fid_score}\n')
            f.write(f'KID score: {kid_score}\n')
            f.write(f'Inception score: {is_score}\n')

        print('Computing Clean FID scores...')
        clean_fid_score, clip_clean_fid_score = clean_fid(path_to_test_images, path_to_generated_images)
        with open(txt_filename, 'a') as f:
            f.write(f'Clean FID score: {clean_fid_score}\n')
            f.write(f'Clean FID score with CLIP features: {clip_clean_fid_score}\n')

        print('Computing FID infinity and IS infinity scores...')
        fid_infinity, is_infinity = fid_inf_is_inf(path_to_test_images, path_to_generated_images)
        with open(txt_filename, 'a') as f:
            f.write(f'FID Infinity score: {fid_infinity}\n')
            f.write(f'IS Infinity score: {is_infinity}\n')

    print('Loading model...', arch)
    feature_flag = False

    # kNN-based evaluation
    if dataset == 'lhq':
        print('Dataset ', dataset)
        pth = '/home/wn455752/repo/evaluation/features/lhq'
        # load ResNet50 pretrained on Places365
        if arch == 'cnn':
            print('loading model...')
            path_to_pretrained_weights = '/home/wn455752/repo/evaluation/pretrained/resnet50_places365_pretrained/resnet50_places365_weights.pth'
            print('loading weights...')
            weights = torch.load(path_to_pretrained_weights)
            model = resnet50().to(device)
            print('initializing model with pretrained weights')
            model.load_state_dict(weights)
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Lambda(lambda x: x * 255)])
            with torch.no_grad():
                model.eval()
            # check for saved dataset features
            print('checking for saved dataset features')
            name_pth = Path(os.path.join(pth, 'resnet_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(pth, 'resnet_features/real_image_features.pt'))
            if feature_pth.is_file():
                print('Loading ResNet features of real images...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True
        # load CLIP
        elif arch == 'clip':
            print('loading model...')
            model, transform = clip.load("ViT-B/32", device=device)
            # check for saved dataset features
            name_pth = Path(os.path.join(pth, 'clip_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(pth, 'clip_features/real_image_features.pt'))
            if feature_pth.is_file():
                print('Loading CLIP features of real images...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True

    elif dataset == 'faces':
        print('Dataset ', dataset)
        pth = '/home/wn455752/repo/evaluation/features/faces'
        # load pretrained VGGFace
        if arch == 'cnn':
            print('loading model...')
            path_to_pretrained_weights = '/home/wn455752/repo/evaluation/pretrained/vggface_pretrained/VGG_FACE.t7'
            model = VGG_16().to(device)
            model.load_weights(path=path_to_pretrained_weights)
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Resize((224, 224)),
                                            transforms.Lambda(lambda x: x * 255)])
            with torch.no_grad():
                model.eval()
            # check for saved dataset features
            name_pth = Path(os.path.join(pth, 'vggface_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(pth, 'vggface_features/real_image_features.pt'))
            if feature_pth.is_file():
                print('Loading VGGFace features of real images...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True
        # load CLIP
        elif arch == 'clip':
            print('loading model...')
            model, transform = clip.load("ViT-B/32", device=device)
            # check for saved dataset features
            name_pth = Path(os.path.join(pth, 'clip_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(pth, 'clip_features/real_image_features.pt'))
            if feature_pth.is_file():
                print('Loading CLIP features of real images...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True

    knn = kNN()

    # collect images (real images only if their features are not cached yet)
    if not feature_flag:
        print('Collecting real images...')
        real_names, real_tensor = knn.get_images(path_to_training_images, transform)
        with open(name_pth, 'wb') as fp:
            pickle.dump(real_names, fp)
    print('Collecting generated images...')
    generated_names, generated_tensor = knn.get_images(path_to_generated_images, transform)

    # extract features
    if not feature_flag:
        print('Extracting features from real images...')
        real_features = knn.feature_extractor(real_tensor, model, device)
        torch.save(real_features, feature_pth)
    print('Extracting features from generated images...')
    generated_features = knn.feature_extractor(generated_tensor, model, device)

    if sample == 'all':
        sample_size = len(generated_names)
    else:
        sample_size = int(sample)

    if mode == 'kNN':
        print('Finding kNNs...')
        knn.kNN(real_names, generated_names,
                real_features, generated_features,
                path_to_training_images, path_to_generated_images,
                k=k_kNN,
                sample=sample_size,
                name_appendix=name_appendix)
    elif mode == 'pairs':
        knn.nearest_neighbor(real_names, generated_names,
                             real_features, generated_features,
                             path_to_training_images, path_to_generated_images,
                             sample=sample_size,
                             name_appendix=name_appendix)
    print('Finish!')

# Evaluation Pipeline

We conduct two types of evaluation: qualitative and quantitative.

### Quantitative evaluations

Quantitative evaluations compare the different backbone architectures of our unconditional diffusion model.
A set of 10,000 generated samples from each model variant is compared against the test set of the real dataset.
These evaluations include (a minimal computation sketch follows the list):
1. FID score
2. Inception score
3. Clean FID score (with CLIP)
4. FID infinity and IS infinity scores
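
The distribution-level scores are computed batch-wise with `torchmetrics`, as in `metrics.py`. A minimal sketch of the FID pattern, using random uint8 tensors as stand-ins for real and generated image batches (assumes `torchmetrics` and its `torch-fidelity` dependency are installed):

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# stand-ins for real/generated images: uint8 tensors in [0, 255] of shape
# (N, 3, H, W); in the pipeline these come from image_to_tensor()
real = torch.randint(0, 256, (64, 3, 128, 128), dtype=torch.uint8)
fake = torch.randint(0, 256, (64, 3, 128, 128), dtype=torch.uint8)

fid = FrechetInceptionDistance(feature=64)  # small feature dim keeps the toy example stable
fid.update(real, real=True)   # accumulate statistics of the real distribution
fid.update(fake, real=False)  # accumulate statistics of the generated distribution
print(fid.compute())          # lower is better
```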
### Qualitative evaluations

This set of evaluations qualitatively inspects whether our model has overfit to the training images. For this,
the full set of 10,000 generated samples from the best-performing model of the quantitative evaluation is compared
with the training set of the real dataset. Additionally, the same quality check is run on a hand-selected subset of
the best generations. The comparison is implemented as MSE values between features of the generated and training
samples, extracted with a pretrained model (ResNet50-Places365/VGGFace or CLIP). Based on these scores we compute
(a distance-computation sketch follows this list):

1. kNN - plot the k nearest neighbors of each generated sample
2. Closest pairs - plot the pairs with the smallest distance values
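
Concretely, both modes reduce to computing the distance between one generated feature vector and all real feature
vectors, then taking `topk` with `largest=False`, exactly as in `kNN.py`. A minimal sketch with placeholder feature
tensors:

```python
import torch

# placeholder features: 10,000 real and 100 generated 512-dim vectors
real_features = torch.randn(10_000, 512)
generated_features = torch.randn(100, 512)

k = 3
for i in range(len(generated_features)):
    # L2 distance between one generated feature and every real feature
    dist = torch.linalg.vector_norm(real_features - generated_features[i], ord=2, dim=1)
    nearest = dist.topk(k, largest=False)  # indices/distances of the k closest real images
    print(i, nearest.indices.tolist(), [round(v, 2) for v in nearest.values.tolist()])
```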
Execution starts with `evaluate_full.py`. Input arguments are:

* -rp, --realpath : Path to real images (string)
* -gp, --genpath : Path to generated images (string)
* -d, --data : Choose between 'lhq' (LHQ landscape dataset) and 'faces' (CelebAHQ faces dataset).
               Default = 'lhq' (string)
* -a, --arch : Choose between 'cnn' and 'clip'. The chosen pretrained model is used to extract features from the
               images. If 'cnn' is selected, the model is a ResNet50 pretrained on Places365 for the LHQ dataset
               and a pretrained VGGFace for the CelebAHQ dataset. Default = 'cnn' (string)
* -m, --mode : Choose between 'kNN' and 'pairs' (for closest pairs). Default = 'kNN' (string)
* -k, --k : k value for kNN. Default = 3 (int)
* -s, --sample : Choose between an int and 'all'. If mode is 'kNN', plot kNN for this many samples (the first s
                 samples in the directory of generated images). If mode is 'pairs', plot the top s closest pairs
                 from the entire directory of generated images. Default = 10 (int or 'all')
* -n, --name : Name appendix for the output files (string)
* --fid : Choose between 'yes' and 'no'. If 'yes', compute FID, Inception score, and the upgraded FID scores.
          Default = 'no' (string)
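
A typical invocation, with hypothetical paths, might look like:

```
python evaluate_full.py -rp data/celebahq256_imgs -gp samples/run1 -d faces -a clip -m kNN -k 3 -s 10 --fid yes
```

The evaluation text file and the plots are written to the `output/` directory.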
The path to real images leads to a directory with two sub-directories, train and test:

```
data
|_ lhq
|  |_ train
|  |_ test
|_ celebahq256_imgs
|  |_ train
|  |_ test
```
CLIP and CNN (ResNet50 or VGGFace) features of training images are saved after the first execution. This alleviates the need
to recompute features of real images for different sets of generated samples.
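
For reference, the caching follows a simple pickle/torch.save pattern, as in `evaluate_full.py` (a sketch with
placeholder values and shortened paths, not the exact repository layout):

```python
import pickle
import torch

real_names = ['img_0001.png', 'img_0002.png']  # placeholder filename list
real_features = torch.randn(2, 512)            # placeholder feature tensor

# first run: cache the filename list and the feature tensor
with open('features/lhq/clip_features/real_name_list', 'wb') as fp:
    pickle.dump(real_names, fp)
torch.save(real_features, 'features/lhq/clip_features/real_image_features.pt')

# later runs: load the cache instead of re-extracting features
with open('features/lhq/clip_features/real_name_list', 'rb') as fp:
    real_names = pickle.load(fp)
real_features = torch.load('features/lhq/clip_features/real_image_features.pt', map_location='cpu')
```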
### Links
1. ResNet50 pretrained on Places365 - https://github.com/CSAILVision/places365
2. Pretrained VGGFace - https://www.robots.ox.ac.uk/~vgg/software/vgg_face/
3. Clean FID - https://github.com/GaParmar/clean-fid/tree/main
4. FID infinity, IS infinity - https://github.com/mchong6/FID_IS_infinity/tree/master
# inception.py
# Source - https://github.com/mchong6/FID_IS_infinity
import torch
import torch.nn as nn
from torch.nn import Parameter as P
from torchvision.models.inception import inception_v3
import torch.nn.functional as F


# Module that wraps the inception network to enable use with dataparallel and
# returning pool features and logits.
class WrapInception(nn.Module):
    def __init__(self, net):
        super(WrapInception, self).__init__()
        self.net = net
        self.mean = P(torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1),
                      requires_grad=False)
        self.std = P(torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1),
                     requires_grad=False)

    def forward(self, x):
        # normalize with ImageNet statistics
        x = (x - self.mean) / self.std
        # upsample if necessary
        if x.shape[2] != 299 or x.shape[3] != 299:
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)
        # 299 x 299 x 3
        x = self.net.Conv2d_1a_3x3(x)
        # 149 x 149 x 32
        x = self.net.Conv2d_2a_3x3(x)
        # 147 x 147 x 32
        x = self.net.Conv2d_2b_3x3(x)
        # 147 x 147 x 64
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # 73 x 73 x 64
        x = self.net.Conv2d_3b_1x1(x)
        # 73 x 73 x 80
        x = self.net.Conv2d_4a_3x3(x)
        # 71 x 71 x 192
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # 35 x 35 x 192
        x = self.net.Mixed_5b(x)
        # 35 x 35 x 256
        x = self.net.Mixed_5c(x)
        # 35 x 35 x 288
        x = self.net.Mixed_5d(x)
        # 35 x 35 x 288
        x = self.net.Mixed_6a(x)
        # 17 x 17 x 768
        x = self.net.Mixed_6b(x)
        # 17 x 17 x 768
        x = self.net.Mixed_6c(x)
        # 17 x 17 x 768
        x = self.net.Mixed_6d(x)
        # 17 x 17 x 768
        x = self.net.Mixed_6e(x)
        # 17 x 17 x 768
        x = self.net.Mixed_7a(x)
        # 8 x 8 x 1280
        x = self.net.Mixed_7b(x)
        # 8 x 8 x 2048
        x = self.net.Mixed_7c(x)
        # 8 x 8 x 2048
        pool = torch.mean(x.view(x.size(0), x.size(1), -1), 2)
        # 1 x 1 x 2048
        logits = self.net.fc(F.dropout(pool, training=False).view(pool.size(0), -1))
        # 1000 (num_classes)
        return pool, logits


# Load and wrap the Inception model
def load_inception_net(parallel=False):
    inception_model = inception_v3(pretrained=True, transform_input=False)
    inception_model = WrapInception(inception_model.eval()).cuda()
    if parallel:
        inception_model = nn.DataParallel(inception_model)
    return inception_model

# kNN.py
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from collections import OrderedDict


class kNN():

    def __init__(self):
        pass

    def get_images(self, path, transform, *args, **kwargs):
        '''
        returns
            names: list of filenames
            image_tensor: tensor with all images
        '''
        image_files = os.listdir(path)
        # list to store filenames
        names = []
        # list to store images (transformed to tensors)
        images_list = []
        for file in image_files:
            if file.endswith('.png'):
                filepath = os.path.join(path, file)
                names.append(file)
                im = Image.open(filepath)
                if im.size[0] != 128:
                    im = im.resize((128, 128))  # DDPM was trained on 128x128 images
                im = transform(im)
                images_list.append(im)
        # single tensor holding all image tensors
        image_tensor = torch.stack(images_list)
        return names, image_tensor

    def feature_extractor(self, images, model, device='cpu', bs=128, *args, **kwargs):
        '''
        returns
            features: feature tensor extracted from the images by the given model
        '''
        dataloader = DataLoader(images, batch_size=bs, shuffle=False)
        features_list = []
        if model._get_name() == 'CLIP':
            with torch.no_grad():
                for item in dataloader:
                    features = model.encode_image(item.to(device))
                    features_list.append(features)
        else:
            with torch.no_grad():
                for item in dataloader:
                    features = model(item.to(device))
                    features_list.append(features)
        features = torch.cat(features_list, dim=0)
        return features

    def kNN(self, real_names, generated_names,
            real_features, generated_features,
            path_to_real_images, path_to_generated_images,
            k=3,
            sample=10,
            name_appendix='',
            *args, **kwargs):
        '''
        creates a plot with (generated image: k nearest real images) rows
        '''
        fig, ax = plt.subplots(sample, k + 1, figsize=((k + 1) * 3, sample * 2))
        for i in range(len(generated_features)):
            # L2 norm between one generated feature and all real features
            dist = torch.linalg.vector_norm(real_features - generated_features[i], ord=2, dim=1)
            # draw the generated image
            im = Image.open(os.path.join(path_to_generated_images, generated_names[i]))
            ax[i, 0].imshow(im)
            ax[i, 0].set_xticks([])
            ax[i, 0].set_yticks([])
            ax[i, 0].set_title(f'Generated: {generated_names[i].split("_")[2][:-4]}', fontsize=8)
            # kNN of the generated image
            knn = dist.topk(k, largest=False)
            j = 1
            # draw the k real images
            for idx in knn.indices:
                im = Image.open(os.path.join(path_to_real_images, real_names[idx.item()]))
                ax[i, j].imshow(im)
                ax[i, j].set_xticks([])
                ax[i, j].set_yticks([])
                ax[i, j].set_title(f'{real_names[idx.item()][:-4]}, {knn.values[j-1].item():.2f}', fontsize=8)
                j += 1
            if i == sample - 1:
                break
        # save the figure
        plot_name = f'{k}NN_{sample}_samples'
        if name_appendix != '':
            plot_name = plot_name + name_appendix
        fig.savefig('output/' + plot_name + '.png')

    def nearest_neighbor(self, real_names, generated_names,
                         real_features, generated_features,
                         path_to_real_images, path_to_generated_images,
                         sample=10,
                         name_appendix='',
                         *args, **kwargs):
        print('Computing nearest neighbors...')
        fig, ax = plt.subplots(sample, 2, figsize=(2 * 3, sample * 2))
        nn_dict = OrderedDict()
        for i in range(len(generated_features)):
            # L2 norm between one generated feature and all real features
            dist = torch.linalg.vector_norm(real_features - generated_features[i], ord=2, dim=1)
            # nearest neighbor of the generated image
            knn = dist.topk(1, largest=False)
            # insert into the dict: generated image -> (distance, index of the nearest neighbor)
            nn_dict[generated_names[i]] = (knn.values.item(), knn.indices.item())
        print('Finding closest pairs...')
        # sort by distance to get the closest generated-real pairs
        nn_dict_sorted = OrderedDict(sorted(nn_dict.items(), key=lambda item: item[1][0]))
        # names of the generated images that look closest to the real images
        gen_names = list(nn_dict_sorted.keys())
        print('Drawing the plot...')
        for i in range(sample):
            # draw the generated image
            im = Image.open(os.path.join(path_to_generated_images, gen_names[i]))
            ax[i, 0].imshow(im)
            ax[i, 0].set_xticks([])
            ax[i, 0].set_yticks([])
            ax[i, 0].set_title(f'Generated: {gen_names[i].split("_")[2][:-4]}', fontsize=8)
            # draw the closest real image
            knn_score, real_img_idx = nn_dict_sorted[gen_names[i]]
            im = Image.open(os.path.join(path_to_real_images, real_names[real_img_idx]))
            ax[i, 1].imshow(im)
            ax[i, 1].set_xticks([])
            ax[i, 1].set_yticks([])
            ax[i, 1].set_title(f'{real_names[real_img_idx][:-4]}, {knn_score:.2f}', fontsize=8)
        # save the figure
        plot_name = f'closest_pairs_top_{sample}'
        if name_appendix != '':
            plot_name = plot_name + name_appendix
        fig.savefig('output/' + plot_name + '.png')

# metrics.py
import os
import torch
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader
from itertools import cycle
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from cleanfid import fid
from score_infinity import calculate_FID_infinity_path, calculate_IS_infinity_path


def image_to_tensor(path):
    # load all .png images in a directory as one uint8 tensor in [0, 255]
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.ConvertImageDtype(dtype=torch.float),
                                    transforms.Lambda(lambda x: (x * 255).type(torch.uint8))])
    file_list = os.listdir(path)
    img_list = []
    for file in tqdm(file_list, desc='Loading images'):
        if file.endswith('.png'):
            img = Image.open(os.path.join(path, file))
            if img.size[0] != 128:
                img = img.resize((128, 128))
            img = transform(img)
            img_list.append(img)
    return torch.stack(img_list)


def compute_scores(real, generated, device):
    # FID, KID and IS via torchmetrics, accumulated batch-wise
    real_dataloader = DataLoader(real, batch_size=128, shuffle=True)
    generated_dataloader = DataLoader(generated, batch_size=128, shuffle=True)
    fid_metric = FrechetInceptionDistance().to(device)
    kid_metric = KernelInceptionDistance().to(device)
    inception = InceptionScore().to(device)
    for r, g in zip(real_dataloader, cycle(generated_dataloader)):
        r = r.to(device)
        g = g.to(device)
        fid_metric.update(r, real=True)
        fid_metric.update(g, real=False)
        kid_metric.update(r, real=True)
        kid_metric.update(g, real=False)
        inception.update(g)
    fid_score = fid_metric.compute()
    kid_score = kid_metric.compute()
    is_score = inception.compute()
    return fid_score, kid_score, is_score


def clean_fid(path_to_real_images, path_to_generated_images):
    # Clean FID (https://github.com/GaParmar/clean-fid) with Inception and CLIP features
    clean_fid_score = fid.compute_fid(path_to_real_images, path_to_generated_images, mode="clean", num_workers=0)
    clip_clean_fid_score = fid.compute_fid(path_to_real_images, path_to_generated_images, mode="clean",
                                           model_name="clip_vit_b_32")
    return clean_fid_score, clip_clean_fid_score


def fid_inf_is_inf(path_to_real_images, path_to_generated_images, batchsize=128):
    # extrapolated FID_infinity and IS_infinity (https://github.com/mchong6/FID_IS_infinity)
    fid_infinity = calculate_FID_infinity_path(path_to_real_images, path_to_generated_images, batch_size=batchsize)
    is_infinity = calculate_IS_infinity_path(path_to_generated_images, batch_size=batchsize)
    return fid_infinity, is_infinity

# score_infinity.py
# Source - https://github.com/mchong6/FID_IS_infinity
import math
import os
import glob

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from botorch.sampling.qmc import NormalQMCEngine
from sklearn.linear_model import LinearRegression
from tqdm import tqdm
from PIL import Image
from scipy import linalg
from inception import *


class randn_sampler():
    """
    Generates z~N(0,1) using random sampling or scrambled Sobol sequences.
    Args:
        ndim: (int)
            The dimension of z.
        use_sobol: (bool)
            If True, sample z from scrambled Sobol sequence. Else, sample
            from standard normal distribution.
            Default: False
        use_inv: (bool)
            If True, use inverse CDF to transform z from U[0,1] to N(0,1).
            Else, use Box-Muller transformation.
            Default: True
        cache: (bool)
            If True, we cache some amount of Sobol points and reorder them.
            This is mainly used for training GANs when we use two separate
            Sobol generators which helps stabilize the training.
            Default: False
    Examples::
        >>> sampler = randn_sampler(128, True)
        >>> z = sampler.draw(10) # Generates [10, 128] vector
    """

    def __init__(self, ndim, use_sobol=False, use_inv=True, cache=False):
        self.ndim = ndim
        self.cache = cache
        if use_sobol:
            self.sampler = NormalQMCEngine(d=ndim, inv_transform=use_inv)
            self.cached_points = torch.tensor([])
        else:
            self.sampler = None

    def draw(self, batch_size):
        if self.sampler is None:
            return torch.randn([batch_size, self.ndim])
        else:
            if self.cache:
                if len(self.cached_points) < batch_size:
                    # draw a large pool of Sobol points and shuffle them
                    self.cached_points = self.sampler.draw(int(1e6))[torch.randperm(int(1e6))]
                # sample without replacement from the cached points
                samples = self.cached_points[:batch_size]
                self.cached_points = self.cached_points[batch_size:]
                return samples
            else:
                return self.sampler.draw(batch_size)

def calculate_FID_infinity(gen_model, ndim, batch_size, gt_path, num_im=50000, num_points=15):
    """
    Calculates effectively unbiased FID_inf using extrapolation
    Args:
        gen_model: (nn.Module)
            The trained generator. Generator takes in z~N(0,1) and outputs
            an image of [-1, 1].
        ndim: (int)
            The dimension of z.
        batch_size: (int)
            The batch size of the generator.
        gt_path: (str)
            Path to saved FID statistics of true data.
        num_im: (int)
            Number of images we are generating to evaluate FID_inf.
            Default: 50000
        num_points: (int)
            Number of FID_N we evaluate to fit a line.
            Default: 15
    """
    # load pretrained inception model
    inception_model = load_inception_net()
    # define a sobol_inv sampler
    z_sampler = randn_sampler(ndim, True)

    # get all activations of generated images
    activations, _ = accumulate_activations(gen_model, inception_model, num_im, z_sampler, batch_size)

    fids = []
    # choose the number of images to evaluate FID_N at regular intervals over N
    fid_batches = np.linspace(5000, num_im, num_points).astype('int32')

    # evaluate FID_N
    for fid_batch_size in fid_batches:
        # shuffle and subsample the activations
        np.random.shuffle(activations)
        fid_activations = activations[:fid_batch_size]
        fids.append(calculate_FID(inception_model, fid_activations, gt_path))
    fids = np.array(fids).reshape(-1, 1)

    # fit a linear regression of FID_N against 1/N and extrapolate to N = infinity
    reg = LinearRegression().fit(1 / fid_batches.reshape(-1, 1), fids)
    fid_infinity = reg.predict(np.array([[0]]))[0, 0]
    return fid_infinity


def calculate_FID_infinity_path(real_path, fake_path, batch_size=50, min_fake=5000, num_points=15):
    """
    Calculates effectively unbiased FID_inf using extrapolation given
    paths to real and fake data
    Args:
        real_path: (str)
            Path to real dataset or precomputed .npz statistics.
        fake_path: (str)
            Path to fake dataset.
        batch_size: (int)
            The batch size for dataloader.
            Default: 50
        min_fake: (int)
            Minimum number of images to evaluate FID on.
            Default: 5000
        num_points: (int)
            Number of FID_N we evaluate to fit a line.
            Default: 15
    """
    # load pretrained inception model
    inception_model = load_inception_net()

    # get statistics of real images, or load precomputed ones
    if real_path.endswith('.npz'):
        real_m, real_s = load_path_statistics(real_path)
    else:
        real_act, _ = compute_path_statistics(real_path, batch_size, model=inception_model)
        real_m, real_s = np.mean(real_act, axis=0), np.cov(real_act, rowvar=False)

    # get all activations of generated images
    fake_act, _ = compute_path_statistics(fake_path, batch_size, model=inception_model)
    num_fake = len(fake_act)
    assert num_fake > min_fake, \
        'number of fake data must be greater than the minimum point for extrapolation'

    fids = []
    # choose the number of images to evaluate FID_N at regular intervals over N
    fid_batches = np.linspace(min_fake, num_fake, num_points).astype('int32')

    # evaluate FID_N
    for fid_batch_size in fid_batches:
        # shuffle and subsample the fake activations
        np.random.shuffle(fake_act)
        fid_activations = fake_act[:fid_batch_size]
        m, s = np.mean(fid_activations, axis=0), np.cov(fid_activations, rowvar=False)
        FID = numpy_calculate_frechet_distance(m, s, real_m, real_s)
        fids.append(FID)
    fids = np.array(fids).reshape(-1, 1)

    # fit a linear regression of FID_N against 1/N and extrapolate to N = infinity
    reg = LinearRegression().fit(1 / fid_batches.reshape(-1, 1), fids)
    fid_infinity = reg.predict(np.array([[0]]))[0, 0]
    return fid_infinity

def calculate_IS_infinity(gen_model, ndim, batch_size, num_im=50000, num_points=15):
    """
    Calculates effectively unbiased IS_inf using extrapolation
    Args:
        gen_model: (nn.Module)
            The trained generator. Generator takes in z~N(0,1) and outputs
            an image of [-1, 1].
        ndim: (int)
            The dimension of z.
        batch_size: (int)
            The batch size of the generator.
        num_im: (int)
            Number of images we are generating to evaluate IS_inf.
            Default: 50000
        num_points: (int)
            Number of IS_N we evaluate to fit a line.
            Default: 15
    """
    # load pretrained inception model
    inception_model = load_inception_net()
    # define a sobol_inv sampler
    z_sampler = randn_sampler(ndim, True)

    # get all logits of generated images
    _, logits = accumulate_activations(gen_model, inception_model, num_im, z_sampler, batch_size)

    IS = []
    # choose the number of images to evaluate IS_N at regular intervals over N
    IS_batches = np.linspace(5000, num_im, num_points).astype('int32')

    # evaluate IS_N
    for IS_batch_size in IS_batches:
        # shuffle and subsample the logits
        np.random.shuffle(logits)
        IS_logits = logits[:IS_batch_size]
        IS.append(calculate_inception_score(IS_logits)[0])
    IS = np.array(IS).reshape(-1, 1)

    # fit a linear regression of IS_N against 1/N and extrapolate to N = infinity
    reg = LinearRegression().fit(1 / IS_batches.reshape(-1, 1), IS)
    IS_infinity = reg.predict(np.array([[0]]))[0, 0]
    return IS_infinity


def calculate_IS_infinity_path(path, batch_size=50, min_fake=5000, num_points=15):
    """
    Calculates effectively unbiased IS_inf using extrapolation given
    a path to fake data
    Args:
        path: (str)
            Path to fake dataset.
        batch_size: (int)
            The batch size for dataloader.
            Default: 50
        min_fake: (int)
            Minimum number of images to evaluate IS on.
            Default: 5000
        num_points: (int)
            Number of IS_N we evaluate to fit a line.
            Default: 15
    """
    # load pretrained inception model
    inception_model = load_inception_net()

    # get all logits of generated images
    _, logits = compute_path_statistics(path, batch_size, model=inception_model)
    num_fake = len(logits)
    assert num_fake > min_fake, \
        'number of fake data must be greater than the minimum point for extrapolation'

    IS = []
    # choose the number of images to evaluate IS_N at regular intervals over N
    IS_batches = np.linspace(min_fake, num_fake, num_points).astype('int32')

    # evaluate IS_N
    for IS_batch_size in IS_batches:
        # shuffle and subsample the logits
        np.random.shuffle(logits)
        IS_logits = logits[:IS_batch_size]
        IS.append(calculate_inception_score(IS_logits)[0])
    IS = np.array(IS).reshape(-1, 1)

    # fit a linear regression of IS_N against 1/N and extrapolate to N = infinity
    reg = LinearRegression().fit(1 / IS_batches.reshape(-1, 1), IS)
    IS_infinity = reg.predict(np.array([[0]]))[0, 0]
    return IS_infinity

################# Functions for calculating and saving dataset inception statistics ##################

class im_dataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.imgpaths = self.get_imgpaths()
        self.transform = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor()])

    def get_imgpaths(self):
        # collect all .jpg and .png files below the data directory
        paths = glob.glob(os.path.join(self.data_dir, "**/*.jpg"), recursive=True) + \
                glob.glob(os.path.join(self.data_dir, "**/*.png"), recursive=True)
        return paths

    def __getitem__(self, idx):
        img_name = self.imgpaths[idx]
        image = self.transform(Image.open(img_name))
        return image

    def __len__(self):
        return len(self.imgpaths)

def load_path_statistics(path):
    """
    Given a path to a dataset npz file, load and return mu and sigma
    """
    if path.endswith('.npz'):
        f = np.load(path)
        m, s = f['mu'][:], f['sigma'][:]
        f.close()
        return m, s
    else:
        raise RuntimeError('Invalid path: %s' % path)


def compute_path_statistics(path, batch_size, model=None):
    """
    Given a path to a dataset, compute pooled Inception activations and logits.
    """
    if not os.path.exists(path):
        raise RuntimeError('Invalid path: %s' % path)
    if model is None:
        model = load_inception_net()
    dataset = im_dataset(path)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, drop_last=False)
    return get_activations(dataloader, model)


def get_activations(dataloader, model):
    """
    Get Inception activations from a dataset
    """
    pool = []
    logits = []
    for images in tqdm(dataloader):
        images = images.cuda()
        with torch.no_grad():
            pool_val, logits_val = model(images)
            pool += [pool_val]
            logits += [F.softmax(logits_val, 1)]
    return torch.cat(pool, 0).cpu().numpy(), torch.cat(logits, 0).cpu().numpy()


def accumulate_activations(gen_model, inception_model, num_im, z_sampler, batch_size):
    """
    Generate images and compute their Inception activations.
    """
    pool, logits = [], []
    for i in range(math.ceil(num_im / batch_size)):
        with torch.no_grad():
            z = z_sampler.draw(batch_size).cuda()
            fake_img = to_img(gen_model(z))
            pool_val, logits_val = inception_model(fake_img)
            pool += [pool_val]
            logits += [F.softmax(logits_val, 1)]
    pool = torch.cat(pool, 0)[:num_im]
    logits = torch.cat(logits, 0)[:num_im]
    return pool.cpu().numpy(), logits.cpu().numpy()


def to_img(x):
    """
    Normalizes an image from [-1, 1] to [0, 1]
    """
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    return x

####################### Functions to help calculate FID and IS #######################

def calculate_FID(model, act, gt_npz):
    """
    calculate FID given activations and a path to precomputed npz statistics
    """
    data_m, data_s = load_path_statistics(gt_npz)
    gen_m, gen_s = np.mean(act, axis=0), np.cov(act, rowvar=False)
    FID = numpy_calculate_frechet_distance(gen_m, gen_s, data_m, data_s)
    return FID


def calculate_inception_score(pred, num_splits=1):
    scores = []
    for index in range(num_splits):
        pred_chunk = pred[index * (pred.shape[0] // num_splits): (index + 1) * (pred.shape[0] // num_splits), :]
        # KL divergence between each predictive distribution and the marginal over the chunk
        kl_inception = pred_chunk * (np.log(pred_chunk) - np.log(np.expand_dims(np.mean(pred_chunk, 0), 0)))
        kl_inception = np.mean(np.sum(kl_inception, 1))
        scores.append(np.exp(kl_inception))
    return np.mean(scores), np.std(scores)

def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.
    Params:
    -- mu1   : Numpy array containing the activations of a layer of the
               inception net (like returned by the function 'get_predictions')
               for generated samples.
    -- mu2   : The sample mean over activations, precalculated on a
               representative data set.
    -- sigma1: The covariance matrix over activations for generated samples.
    -- sigma2: The covariance matrix over activations, precalculated on a
               representative data set.
    Returns:
    --   : The Frechet Distance.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # numerical error might give a slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)
    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)

# vggface.py
# VGG16 model from https://github.com/prlz77/vgg-face.pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchfile


class VGG_16(nn.Module):
    """
    Main Class
    """

    def __init__(self):
        """
        Constructor
        """
        super().__init__()
        self.block_size = [2, 2, 3, 3, 3]
        self.conv_1_1 = nn.Conv2d(3, 64, 3, stride=1, padding=1)
        self.conv_1_2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.conv_2_1 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv_2_2 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.conv_3_1 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.conv_3_2 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.conv_3_3 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.conv_4_1 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
        self.conv_4_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_4_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_5_1 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_5_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_5_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.fc6 = nn.Linear(512 * 7 * 7, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, 2622)

    def load_weights(self, path):
        """ Function to load luatorch pretrained weights
        Args:
            path: path to the luatorch pretrained checkpoint
        """
        model = torchfile.load(path)
        counter = 1
        block = 1
        for i, layer in enumerate(model.modules):
            if layer.weight is not None:
                if block <= 5:
                    # convolutional blocks conv_1_1 through conv_5_3
                    self_layer = getattr(self, "conv_%d_%d" % (block, counter))
                    counter += 1
                    if counter > self.block_size[block - 1]:
                        counter = 1
                        block += 1
                    self_layer.weight.data[...] = torch.tensor(layer.weight).view_as(self_layer.weight)[...]
                    self_layer.bias.data[...] = torch.tensor(layer.bias).view_as(self_layer.bias)[...]
                else:
                    # fully connected layers fc6 through fc8
                    self_layer = getattr(self, "fc%d" % (block))
                    block += 1
                    self_layer.weight.data[...] = torch.tensor(layer.weight).view_as(self_layer.weight)[...]
                    self_layer.bias.data[...] = torch.tensor(layer.bias).view_as(self_layer.bias)[...]

    def forward(self, x):
        """ Pytorch forward
        Args:
            x: input image (224x224)
        Returns: class logits
        """
        x = F.relu(self.conv_1_1(x))
        x = F.relu(self.conv_1_2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_2_1(x))
        x = F.relu(self.conv_2_2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_3_1(x))
        x = F.relu(self.conv_3_2(x))
        x = F.relu(self.conv_3_3(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_4_1(x))
        x = F.relu(self.conv_4_2(x))
        x = F.relu(self.conv_4_3(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_5_1(x))
        x = F.relu(self.conv_5_2(x))
        x = F.relu(self.conv_5_3(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc6(x))
        x = F.dropout(x, 0.5, self.training)
        x = F.relu(self.fc7(x))
        x = F.dropout(x, 0.5, self.training)
        return self.fc8(x)