From c4410482ec88746cdc926fdf72769cbe1cf0d00b Mon Sep 17 00:00:00 2001
From: Srijeet Roy <srijeet.11@gmail.com>
Date: Wed, 19 Jul 2023 14:44:17 +0200
Subject: [PATCH] add full evaluation pipeline

---
 evaluation/eval_full/__init__.py          |   0
 evaluation/eval_full/evaluate_full.py     | 224 +++++++++++
 evaluation/eval_full/evaluation_readme.md |  60 +++
 evaluation/eval_full/inception.py         |  74 ++++
 evaluation/eval_full/kNN.py               | 157 ++++++++
 evaluation/eval_full/metrics.py           |  70 ++++
 evaluation/eval_full/score_infinity.py    | 433 ++++++++++++++++++++++
 evaluation/eval_full/vggface.py           |  93 +++++
 8 files changed, 1111 insertions(+)
 create mode 100644 evaluation/eval_full/__init__.py
 create mode 100644 evaluation/eval_full/evaluate_full.py
 create mode 100644 evaluation/eval_full/evaluation_readme.md
 create mode 100644 evaluation/eval_full/inception.py
 create mode 100644 evaluation/eval_full/kNN.py
 create mode 100644 evaluation/eval_full/metrics.py
 create mode 100644 evaluation/eval_full/score_infinity.py
 create mode 100644 evaluation/eval_full/vggface.py

diff --git a/evaluation/eval_full/__init__.py b/evaluation/eval_full/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/evaluation/eval_full/evaluate_full.py b/evaluation/eval_full/evaluate_full.py
new file mode 100644
index 0000000..8d1c724
--- /dev/null
+++ b/evaluation/eval_full/evaluate_full.py
@@ -0,0 +1,224 @@
+import os
+import argparse
+import pickle
+from pathlib import Path
+
+import torch
+import torchvision.transforms as transforms
+import clip
+from torchvision.models import resnet50
+
+from vggface import *
+from kNN import *
+from metrics import *
+
+if __name__ == '__main__':
+
+    #device = "mps" if torch.backends.mps.is_available() else "cpu"
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print('device:', device)
+    print('Parsing arguments...')
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-rp', '--realpath', required=True, 
+                        help='path to real images', type=str)
+    parser.add_argument('-gp', '--genpath', required=True,
+                        help='path to generated images', type=str)
+    parser.add_argument('-d', '--data', nargs='?', const='lhq', default='lhq',
+                        help='choose between "lhq" and "faces" dataset', type=str)
+    parser.add_argument('-a', '--arch', nargs='?', const='cnn', default='cnn', 
+                        help='choose between "clip" and "cnn", default "cnn"', type=str)
+    parser.add_argument('-m', '--mode', nargs='?', const='kNN', default='kNN', 
+                        help='choose between "kNN" and "pairs" for closest_pairs, default "kNN"', type=str)
+    parser.add_argument('-k', '--k', nargs='?', const=3, default=3,
+                        help='k for kNN, default 3', type=int)
+    parser.add_argument('-s', '--sample', nargs='?', const=10, default=10,
+                        help='how many generated samples to compare, default 10, int or "all"')
+    parser.add_argument('-n', '--name', nargs='?', const='', default='', 
+                        help='name appendix for the plot file', type=str)
+    parser.add_argument('--fid', nargs='?', const='no', default='no',
+                        help='compute FID, KID, IS, Clean FID and FID/IS infinity scores, choose between "yes" and "no"', type=str)
+    args = vars(parser.parse_args())
+
+    path_to_real_images = args['realpath']
+    path_to_generated_images = args['genpath']
+    dataset = args['data']
+    arch = args['arch']
+    mode = args['mode']
+    k_kNN = args['k']
+    sample = args['sample']
+    name_appendix = args['name']
+    fid_bool = args['fid']
+    print('Start')
+    txt_filename = 'output/evaluation_' + dataset + '_' + arch + '_' + mode + '-' + name_appendix + '.txt'
+    with open(txt_filename, 'w') as f:
+        f.write(f'Path to real images: {path_to_real_images}\n')
+        f.write(f'Path to generated images: {path_to_generated_images}\n')
+        f.write(f'Experiment on {dataset} dataset\n')
+        f.write(f'Using {arch} model to extract features\n')
+        f.write(f'Plot of {mode} on {sample} samples\n')
+    
+    # load data
+    path_to_training_images = os.path.join(path_to_real_images, 'train')
+    path_to_test_images = os.path.join(path_to_real_images, 'test')
+    if fid_bool == 'yes':
+        # metrics-based evaluation on the test split
+        eval_images = image_to_tensor(path_to_test_images)
+        generated = image_to_tensor(path_to_generated_images)
+    
+        print('Computing FID, KID & inception scores...')
+        fid_score, kid_score, is_score = compute_scores(eval_images, generated, device)
+        with open(txt_filename, 'a') as f:
+            f.write(f'FID score: {fid_score}\n')
+            f.write(f'KID score: {kid_score}\n')
+            f.write(f'Inception score: {is_score}\n')
+
+        print('Computing Clean FID scores...')
+        clean_fid_score, clip_clean_fid_score = clean_fid(path_to_test_images, path_to_generated_images)
+        with open(txt_filename, 'a') as f:
+            f.write(f'Clean FID score: {clean_fid_score}\n')
+            f.write(f'Clean FID score with CLIP features: {clip_clean_fid_score}\n')
+        
+        print('Computing FID infinity and IS infinity scores...')
+        fid_infinity, is_infinity = fid_inf_is_inf(path_to_test_images, path_to_generated_images) 
+        with open(txt_filename, 'a') as f:
+            f.write(f'FID Infinity score: {fid_infinity}\n')
+            f.write(f'IS Infinity score: {is_infinity}\n')
+    
+    print('Loading model...', arch)   
+    feature_flag = False
+    
+    # kNN-based eval
+    if dataset == 'lhq':
+        print('Dataset ', dataset)
+        pth = '/home/wn455752/repo/evaluation/features/lhq'
+        # load pretrained ResNet50 
+        if arch == 'cnn':
+            print('loading model...')
+            path_to_pretrained_weights = '/home/wn455752/repo/evaluation/pretrained/resnet50_places365_pretrained/resnet50_places365_weights.pth'
+            print('loading weights...')
+            weights = torch.load(path_to_pretrained_weights)
+            model = resnet50().to(device)
+            print('initializing model with pretrained weights')
+            model.load_state_dict(weights)
+            transform = transforms.Compose([transforms.ToTensor(), 
+                                            transforms.Lambda(lambda x: x * 255)])
+            model.eval()
+            print('checking for saved dataset features')
+            # check for saved dataset features
+            name_pth = Path(os.path.join(pth, 'resnet_features/real_name_list'))
+            if name_pth.is_file():
+                with open(name_pth, 'rb') as fp:
+                    real_names = pickle.load(fp)
+            feature_pth = Path(os.path.join(pth, 'resnet_features/real_image_features.pt'))
+            if feature_pth.is_file():
+                print('Loading ResNet features of real images...')
+                real_features = torch.load(feature_pth, map_location="cpu")
+                real_features = real_features.to(device)
+                feature_flag = True
+        # load CLIP
+        elif arch == 'clip':
+            print('loading model...')
+            model, transform = clip.load("ViT-B/32", device=device)
+            # check for saved dataset features
+            name_pth = Path(os.path.join(pth, 'clip_features/real_name_list'))
+            if name_pth.is_file():
+                with open(name_pth, 'rb') as fp:
+                    real_names = pickle.load(fp)
+            feature_pth = Path(os.path.join(pth, 'clip_features/real_image_features.pt'))
+            if feature_pth.is_file():
+                print('Loading CLIP features of real images...')
+                real_features = torch.load(feature_pth, map_location="cpu")
+                real_features = real_features.to(device)
+                feature_flag = True
+
+
+
+    elif dataset == 'faces':
+        print('Dataset ', dataset)
+        pth = '/home/wn455752/repo/evaluation/features/faces' 
+        # load pretrained VGGFace
+        if arch == 'cnn':
+            print('loading model...')
+            path_to_pretrained_weights = '/home/wn455752/repo/evaluation/pretrained/vggface_pretrained/VGG_FACE.t7'
+            model = VGG_16().to(device)
+            model.load_weights(path=path_to_pretrained_weights)
+            transform = transforms.Compose([transforms.ToTensor(),
+                                        transforms.Resize((224,224)),
+                                        transforms.Lambda(lambda x: x * 255)])
+            model.eval()
+            
+            # check for saved dataset features
+            name_pth = Path(os.path.join(pth, 'vggface_features/real_name_list'))
+            if name_pth.is_file():
+                with open(name_pth, 'rb') as fp:
+                    real_names = pickle.load(fp)
+            feature_pth = Path(os.path.join(pth, 'vggface_features/real_image_features.pt'))
+            if feature_pth.is_file():
+                print('Loading VGGFace features of real images...')
+                real_features = torch.load(feature_pth, map_location="cpu")
+                real_features = real_features.to(device)
+                feature_flag = True
+
+        # load CLIP
+        elif arch == 'clip':
+            print('loading model...')
+            model, transform = clip.load("ViT-B/32", device=device)
+            # check for saved dataset features
+            name_pth = Path(os.path.join(pth, 'clip_features/real_name_list'))
+            if name_pth.is_file():
+                with open(name_pth, 'rb') as fp:
+                    real_names = pickle.load(fp)
+            feature_pth = Path(os.path.join(pth, 'clip_features/real_image_features.pt'))
+            if feature_pth.is_file():
+                print('Loading CLIP features of real images...')
+                real_features = torch.load(feature_pth, map_location="cpu")
+                real_features = real_features.to(device)
+                feature_flag = True
+
+    knn = kNN()
+    # get images
+    if not feature_flag:
+        print('Collecting real images...')
+        real_names, real_tensor = knn.get_images(path_to_training_images, transform)
+        with open(name_pth, 'wb') as fp:
+            pickle.dump(real_names, fp)
+    print('Collecting generated images...')
+    generated_names, generated_tensor = knn.get_images(path_to_generated_images, transform)
+
+    # extract features
+    if not feature_flag:
+        print('Extracting features from real images...')
+        real_features = knn.feature_extractor(real_tensor, model, device)
+        torch.save(real_features, feature_pth)
+    print('Extracting features from generated images...')
+    generated_features = knn.feature_extractor(generated_tensor, model, device)
+
+    if sample == 'all':
+        sample_size = len(generated_names)
+    else:
+        sample_size = int(sample)
+
+
+    if mode == 'kNN':
+        print('Finding kNNs...')
+        knn.kNN(real_names, generated_names, 
+            real_features, generated_features, 
+            path_to_training_images, path_to_generated_images, 
+            k=k_kNN, 
+            sample=sample_size, 
+            name_appendix=name_appendix)
+    elif mode == 'pairs':
+        knn.nearest_neighbor(real_names, generated_names, 
+                         real_features, generated_features, 
+                         path_to_training_images, path_to_generated_images, 
+                         sample=sample_size, 
+                         name_appendix=name_appendix)
+    print('Finished!')
diff --git a/evaluation/eval_full/evaluation_readme.md b/evaluation/eval_full/evaluation_readme.md
new file mode 100644
index 0000000..f28b92b
--- /dev/null
+++ b/evaluation/eval_full/evaluation_readme.md
@@ -0,0 +1,60 @@
+# Evaluation Pipeline
+
+We conduct two types of evaluation - qualitative and quantitative.
+
+### Quantitative evaluations
+Quantitative evaluations are carried out to compare different backbone architectures of our unconditional diffusion model.
+A set of 10,000 generated samples from each model variant is compared with the test set of the real dataset.
+These evaluations include:
+1. FID score
+2. KID score
+3. Inception score
+4. Clean FID score (with and without CLIP features)
+5. FID infinity and IS infinity scores
+
+### Qualitative evaluations
+The aim of this set of evaluations is to qualitatively inspect whether our model has overfit to the training images. For this,
+the entire set of 10,000 generated samples from the best-performing model of the quantitative evaluation is compared with the
+training set of the real dataset. Additionally, the same check is run on a hand-selected subset of the best generations.
+
+The comparison is implemented as L2 distances between features of the generated and training samples. The features are extracted
+with a pretrained model (ResNet50-Places365/VGGFace or CLIP). Based on these distances we compute (see the sketch below):
+1. kNN - plot the k nearest neighbors of each generated sample
+2. Closest pairs - plot the generated-training pairs with the smallest distance
+
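+A minimal sketch of the retrieval step used for both modes (the actual implementation, including plotting, lives in
+kNN.py; the toy tensors below only stand in for the output of kNN.feature_extractor):
+
+```python
+import torch
+
+# Toy feature tensors; in the pipeline these come from kNN.feature_extractor.
+real_features = torch.randn(1000, 512)
+generated_features = torch.randn(10, 512)
+
+k = 3
+for i in range(generated_features.shape[0]):
+    # L2 distance between one generated feature and every real feature
+    dist = torch.linalg.vector_norm(real_features - generated_features[i], ord=2, dim=1)
+    # indices and distances of the k closest real images
+    knn = dist.topk(k, largest=False)
+    print(i, knn.indices.tolist(), [round(v, 2) for v in knn.values.tolist()])
+```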
+
+Execution starts with the evaluate_full.py file. The input arguments are listed below; an example invocation follows the list.
+
+* -rp, --realpath : Path to real images (string)
+* -gp, --genpath  : Path to generated images (string)
+* -d, --data      : Choose between 'lhq' (for LHQ landscape dataset) and 'faces' (for CelebAHQ faces dataset). 
+                    Default = 'lhq' (string)
+* -a, --arch      : Choose between 'cnn' and 'clip'. Chosen pretrained model is used to extract features from the images.
+                    If 'cnn' is selected, for LHQ dataset the model is a ResNet50 pretrained on Places365 dataset and for
+                    CelebAHQ dataset the model is a pretrained VGGFace. Default = 'cnn' (string)
+* -m, --mode      : Choose between 'kNN' and 'pairs' (for closest pairs), default = 'kNN' (string)
+* -k, --k         : k value for kNN, default = 3 (int)
+* -s, --sample    : Choose between an int and 'all'. If mode is 'kNN', plot kNN for this many samples (first s samples 
+                    in the directory of generated images). If mode is 'pairs', plot the top s closest pairs from entire 
+                    directory of generated images. Default 10 (int or 'all')
+* -n, --name      : Name appendix (string)
+* --fid           : Choose between 'yes' and 'no'. Compute FID, KID, Inception score, Clean FID and FID/IS infinity scores. Default 'no' (string)
+
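+An illustrative invocation (the dataset and sample paths below are placeholders, not paths shipped with this repository):
+
+```
+python evaluation/eval_full/evaluate_full.py \
+    -rp data/lhq -gp samples/lhq_run1 \
+    -d lhq -a clip -m kNN -k 3 -s 10 -n run1 --fid yes
+```
+
+Note that the text report and the plots are written to an output/ directory relative to the working directory, so that
+directory must exist before running.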
+
+The path to real images must point to a directory with two sub-directories: train and test.
+
+    data
+    |_ lhq
+    |    |_ train
+    |    |_ test
+    |_ celebahq256_imgs
+    |    |_ train
+    |    |_ test
+
+CLIP and CNN (ResNet50 or VGGFace) features of training images are saved after the first execution. This alleviates the need
+to recompute features of real images for different sets of generated samples.
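+
+The cache root is currently hardcoded in evaluate_full.py; with the sub-directory and file names used there, the expected
+layout is roughly:
+
+    features
+    |_ lhq
+    |    |_ resnet_features    (real_name_list, real_image_features.pt)
+    |    |_ clip_features      (real_name_list, real_image_features.pt)
+    |_ faces
+    |    |_ vggface_features   (real_name_list, real_image_features.pt)
+    |    |_ clip_features      (real_name_list, real_image_features.pt)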
+
+### Links
+1. ResNet50 pretrained on Places365 - https://github.com/CSAILVision/places365
+2. Pretrained VGGFace - https://www.robots.ox.ac.uk/~vgg/software/vgg_face/
+3. Clean FID - https://github.com/GaParmar/clean-fid/tree/main
+4. FID infinity, IS infinity - https://github.com/mchong6/FID_IS_infinity/tree/master
diff --git a/evaluation/eval_full/inception.py b/evaluation/eval_full/inception.py
new file mode 100644
index 0000000..246e3bb
--- /dev/null
+++ b/evaluation/eval_full/inception.py
@@ -0,0 +1,74 @@
+# Source - https://github.com/mchong6/FID_IS_infinity
+
+import torch
+import torch.nn as nn
+from torch.nn import Parameter as P
+from torchvision.models.inception import inception_v3
+import torch.nn.functional as F
+
+# Module that wraps the inception network to enable use with dataparallel and
+# returning pool features and logits.
+class WrapInception(nn.Module):
+    def __init__(self, net):
+        super(WrapInception,self).__init__()
+        self.net = net
+        self.mean = P(torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1),
+                      requires_grad=False)
+        self.std = P(torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1),
+                     requires_grad=False)
+    def forward(self, x):
+        x = (x - self.mean) / self.std
+        # Upsample if necessary
+        if x.shape[2] != 299 or x.shape[3] != 299:
+            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)
+        # 299 x 299 x 3
+        x = self.net.Conv2d_1a_3x3(x)
+        # 149 x 149 x 32
+        x = self.net.Conv2d_2a_3x3(x)
+        # 147 x 147 x 32
+        x = self.net.Conv2d_2b_3x3(x)
+        # 147 x 147 x 64
+        x = F.max_pool2d(x, kernel_size=3, stride=2)
+        # 73 x 73 x 64
+        x = self.net.Conv2d_3b_1x1(x)
+        # 73 x 73 x 80
+        x = self.net.Conv2d_4a_3x3(x)
+        # 71 x 71 x 192
+        x = F.max_pool2d(x, kernel_size=3, stride=2)
+        # 35 x 35 x 192
+        x = self.net.Mixed_5b(x)
+        # 35 x 35 x 256
+        x = self.net.Mixed_5c(x)
+        # 35 x 35 x 288
+        x = self.net.Mixed_5d(x)
+        # 35 x 35 x 288
+        x = self.net.Mixed_6a(x)
+        # 17 x 17 x 768
+        x = self.net.Mixed_6b(x)
+        # 17 x 17 x 768
+        x = self.net.Mixed_6c(x)
+        # 17 x 17 x 768
+        x = self.net.Mixed_6d(x)
+        # 17 x 17 x 768
+        x = self.net.Mixed_6e(x)
+        # 17 x 17 x 768
+        x = self.net.Mixed_7a(x)
+        # 8 x 8 x 1280
+        x = self.net.Mixed_7b(x)
+        # 8 x 8 x 2048
+        x = self.net.Mixed_7c(x)
+        # 8 x 8 x 2048
+        pool = torch.mean(x.view(x.size(0), x.size(1), -1), 2)
+        # 1 x 1 x 2048
+        logits = self.net.fc(F.dropout(pool, training=False).view(pool.size(0), -1))
+        # 1000 (num_classes)
+        return pool, logits
+
+# Load and wrap the Inception model
+def load_inception_net(parallel=False):
+    inception_model = inception_v3(pretrained=True, transform_input=False)
+    inception_model = WrapInception(inception_model.eval()).cuda()
+    if parallel:
+        inception_model = nn.DataParallel(inception_model)
+    return inception_model
diff --git a/evaluation/eval_full/kNN.py b/evaluation/eval_full/kNN.py
new file mode 100644
index 0000000..185de21
--- /dev/null
+++ b/evaluation/eval_full/kNN.py
@@ -0,0 +1,157 @@
+import os
+import torch
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+from PIL import Image
+import matplotlib.pyplot as plt
+from collections import OrderedDict
+
+
+class kNN():
+
+    def __init__(self):
+        pass
+
+    def get_images(self, path, transform, *args, **kwargs):
+        '''
+        returns 
+        names: list of filenames
+        image_tensor: tensor with all images
+        '''
+        # path to real image files
+        image_files = os.listdir(path)
+        # list to store filenames
+        names = []
+        # list to store images (transformed to tensors)
+        images_list = []
+        
+        for file in image_files:
+            if file.endswith('.png'):
+                filepath = os.path.join(path, file)
+                names.append(file)
+                im = Image.open(filepath)
+                if im.size[0] != 128:
+                    im = im.resize((128,128))              # DDPM was trained on 128x128 images
+                im = transform(im)  
+                images_list.append(im)
+        
+        # tensor with all real image tensors
+        image_tensor = torch.stack(images_list)
+
+        return names, image_tensor
+
+    def feature_extractor(self, images, model, device='cpu', bs=128, *args, **kwargs):
+        '''
+        returns
+        features: features extracted by the given model (CLIP or CNN) for all images
+        '''
+        # extract features for real and generated images
+        dataloader = DataLoader(images, batch_size=bs, shuffle=False)
+        features_list = []
+        if model._get_name() == 'CLIP':
+            with torch.no_grad():
+                for item in dataloader:
+                    features = model.encode_image(item.to(device))
+                    features_list.append(features)
+        else:
+            with torch.no_grad():
+                for item in dataloader:
+                    features = model(item.to(device))
+                    features_list.append(features)
+
+        features = torch.cat(features_list, dim=0)
+        return features
+
+
+    def kNN(self, real_names, generated_names, 
+            real_features, generated_features, 
+            path_to_real_images, path_to_generated_images, 
+            k=3, 
+            sample=10, 
+            name_appendix='',
+            *args, **kwargs):
+        '''
+        creates a plot with (generated image: k nearest real images) pairs
+        '''
+        fig, ax = plt.subplots(sample, k+1, figsize=((k+1)*3,sample*2))
+
+        for i in range(len(generated_features)):
+            # l2 norm of one generated feature and all real features
+            dist = torch.linalg.vector_norm(real_features - generated_features[i], ord=2, dim=1)
+            
+            # draw the generated image
+            im = Image.open(os.path.join(path_to_generated_images, generated_names[i]))
+            ax[i, 0].imshow(im)
+            ax[i, 0].set_xticks([])
+            ax[i, 0].set_yticks([])
+            ax[i, 0].set_title(f'Generated: {generated_names[i].split("_")[2][:-4]}', fontsize=8)
+            
+            # kNN of the generated image
+            knn = dist.topk(k, largest=False)
+            j=1
+
+            # draw the k real images
+            for idx in knn.indices:
+                im = Image.open(os.path.join(path_to_real_images, real_names[idx.item()]))
+                ax[i, j].imshow(im)
+                ax[i, j].set_xticks([])
+                ax[i, j].set_yticks([])
+                ax[i, j].set_title(f'{real_names[idx.item()][:-4]}, {knn.values[j-1].item():.2f}', fontsize=8)
+                j += 1
+            if i == sample-1:
+                break
+        
+        # savefig
+        
+        plot_name = f'{k}NN_{sample}_samples'
+        if name_appendix != '':
+            plot_name = plot_name + name_appendix 
+        fig.savefig('output/' + plot_name + '.png')
+
+    def nearest_neighbor(self, real_names, generated_names, 
+                    real_features, generated_features, 
+                    path_to_real_images, path_to_generated_images, 
+                    sample=10, 
+                    name_appendix='',
+                    *args, **kwargs):
+        
+        print('Computing nearest neighbors...')
+        fig, ax = plt.subplots(sample, 2, figsize=(2*3,sample*2))
+        nn_dict = OrderedDict()
+        
+        for i in range(len(generated_features)):
+            # l2 norm of one generated feature and all real features
+            #dist = torch.linalg.vector_norm(real_features - generated_features[i], ord=2, dim=1)
+            dist = torch.norm(real_features - generated_features[i], dim=1, p=2)
+            # nearest neighbor of the generated image
+            knn = dist.topk(1, largest=False)
+            # insert to the dict: generated_image: (distance, index of the nearest neighbor)
+            nn_dict[generated_names[i]] = (knn.values.item(), knn.indices.item())
+        print('Finding closest pairs...')
+        # sort to get the generated-real pairs that were the closest
+        nn_dict_sorted = OrderedDict(sorted(nn_dict.items(), key=lambda item: item[1][0]))
+        # names of the generated images that look closest to the real images
+        gen_names = list(nn_dict_sorted.keys())
+        print('Drawing the plot...')
+        for i in range(sample):
+            # draw the generated image
+            im = Image.open(os.path.join(path_to_generated_images, gen_names[i]))
+            ax[i, 0].imshow(im)
+            ax[i, 0].set_xticks([])
+            ax[i, 0].set_yticks([])
+            ax[i, 0].set_title(f'Generated: {gen_names[i].split("_")[2][:-4]}', fontsize=8)
+            
+            # draw the real image
+            knn_score, real_img_idx = nn_dict_sorted[gen_names[i]]
+            im = Image.open(os.path.join(path_to_real_images, real_names[real_img_idx]))
+            ax[i, 1].imshow(im)
+            ax[i, 1].set_xticks([])
+            ax[i, 1].set_yticks([])
+            ax[i, 1].set_title(f'{real_names[real_img_idx][:-4]}, {knn_score:.2f}', fontsize=8)
+                
+        #savefig
+        plot_name = f'closest_pairs_top_{sample}'
+        if name_appendix != '':
+            plot_name = plot_name + name_appendix
+        fig.savefig('output/' + plot_name + '.png')
\ No newline at end of file
diff --git a/evaluation/eval_full/metrics.py b/evaluation/eval_full/metrics.py
new file mode 100644
index 0000000..fc0d75c
--- /dev/null
+++ b/evaluation/eval_full/metrics.py
@@ -0,0 +1,70 @@
+import os
+import torch
+from tqdm import tqdm
+from PIL import Image
+from torchvision import transforms
+from torch.utils.data import DataLoader
+from itertools import cycle
+from torchmetrics.image.fid import FrechetInceptionDistance
+from torchmetrics.image.inception import InceptionScore
+from torchmetrics.image.kid import KernelInceptionDistance
+from cleanfid import fid
+from score_infinity import calculate_FID_infinity_path, calculate_IS_infinity_path
+
+def image_to_tensor(path):
+
+    transform = transforms.Compose([transforms.ToTensor(),
+                                transforms.ConvertImageDtype(dtype=torch.float),
+                                transforms.Lambda(lambda x: (x * 255).type(torch.uint8))])
+    file_list = os.listdir(path)
+    img_list = []
+    #for file in tqdm(file_list, desc='Loading images'):
+    for file in file_list:
+        if file.endswith('.png'):
+            img = Image.open(os.path.join(path, file))
+            if img.size[0] != 128:
+                img = img.resize((128,128))
+            img = transform(img)
+            img_list.append(img)
+
+    return torch.stack(img_list)
+
+
+def compute_scores(real, generated, device):
+
+    real_dataloader = DataLoader(real, batch_size=128, shuffle=True)
+    generated_dataloader = DataLoader(generated, batch_size=128, shuffle=True)
+
+    fid = FrechetInceptionDistance().to(device)
+    kid = KernelInceptionDistance().to(device)
+    inception = InceptionScore().to(device)
+
+    for r, g in zip(real_dataloader, cycle(generated_dataloader)):
+        r = r.to(device)
+        g = g.to(device)
+        fid.update(r, real=True)
+        fid.update(g, real=False)
+        kid.update(r, real=True)
+        kid.update(g, real=False)
+        inception.update(g)
+    
+    fid_score = fid.compute()
+    kid_score = kid.compute()
+    is_score = inception.compute()
+    return fid_score, kid_score, is_score
+
+
+def clean_fid(path_to_real_images, path_to_generated_images):
+
+    clean_fid_score = fid.compute_fid(path_to_real_images, path_to_generated_images, mode="clean", num_workers=0)
+    clip_clean_fid_score = fid.compute_fid(path_to_real_images, path_to_generated_images, mode="clean", model_name="clip_vit_b_32")
+
+    return clean_fid_score, clip_clean_fid_score
+
+
+def fid_inf_is_inf(path_to_real_images, path_to_generated_images, batchsize=128):
+
+    fid_infinity = calculate_FID_infinity_path(path_to_real_images, path_to_generated_images, batch_size=batchsize)
+    is_infinity = calculate_IS_infinity_path(path_to_generated_images, batch_size=batchsize)
+
+    return fid_infinity, is_infinity
diff --git a/evaluation/eval_full/score_infinity.py b/evaluation/eval_full/score_infinity.py
new file mode 100644
index 0000000..7081d40
--- /dev/null
+++ b/evaluation/eval_full/score_infinity.py
@@ -0,0 +1,433 @@
+# Source - https://github.com/mchong6/FID_IS_infinity
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+import torchvision.transforms as transforms
+from botorch.sampling.qmc import NormalQMCEngine
+import numpy as np
+import math
+from sklearn.linear_model import LinearRegression
+import os
+import glob
+from tqdm import tqdm
+from PIL import Image
+from scipy import linalg 
+from inception import *
+
+class randn_sampler():
+    """
+    Generates z~N(0,1) using random sampling or scrambled Sobol sequences.
+    Args:
+        ndim: (int)
+            The dimension of z.
+        use_sobol: (bool)
+            If True, sample z from scrambled Sobol sequence. Else, sample 
+            from standard normal distribution.
+            Default: False
+        use_inv: (bool)
+            If True, use inverse CDF to transform z from U[0,1] to N(0,1).
+            Else, use Box-Muller transformation.
+            Default: True
+        cache: (bool)
+            If True, we cache some amount of Sobol points and reorder them.
+            This is mainly used for training GANs when we use two separate
+            Sobol generators which helps stabilize the training.
+            Default: False
+            
+    Examples::
+
+        >>> sampler = randn_sampler(128, True)
+        >>> z = sampler.draw(10) # Generates [10, 128] vector
+    """
+
+    def __init__(self, ndim, use_sobol=False, use_inv=True, cache=False):
+        self.ndim = ndim
+        self.cache = cache
+        if use_sobol:
+            self.sampler = NormalQMCEngine(d=ndim, inv_transform=use_inv)
+            self.cached_points = torch.tensor([])
+        else:
+            self.sampler = None
+
+    def draw(self, batch_size):
+        if self.sampler is None:
+            return torch.randn([batch_size, self.ndim])
+        else:
+            if self.cache:
+                if len(self.cached_points) < batch_size:
+                    # sample from sampler and reorder the points
+                    self.cached_points = self.sampler.draw(int(1e6))[torch.randperm(int(1e6))]
+
+                # Sample without replacement from cached points
+                samples = self.cached_points[:batch_size]
+                self.cached_points = self.cached_points[batch_size:]
+                return samples
+            else:
+                return self.sampler.draw(batch_size)
+
+def calculate_FID_infinity(gen_model, ndim, batch_size, gt_path, num_im=50000, num_points=15):
+    """
+    Calculates effectively unbiased FID_inf using extrapolation
+    Args:
+        gen_model: (nn.Module)
+            The trained generator. Generator takes in z~N(0,1) and outputs
+            an image of [-1, 1].
+        ndim: (int)
+            The dimension of z.
+        batch_size: (int)
+            The batch size of generator
+        gt_path: (str)
+            Path to saved FID statistics of true data.
+        num_im: (int)
+            Number of images we are generating to evaluate FID_inf.
+            Default: 50000
+        num_points: (int)
+            Number of FID_N we evaluate to fit a line.
+            Default: 15
+    """
+    # load pretrained inception model 
+    inception_model = load_inception_net()
+
+    # define a sobol_inv sampler
+    z_sampler = randn_sampler(ndim, True)
+
+    # get all activations of generated images
+    activations, _ =  accumulate_activations(gen_model, inception_model, num_im, z_sampler, batch_size)
+
+    fids = []
+
+    # Choose the number of images to evaluate FID_N at regular intervals over N
+    fid_batches = np.linspace(5000, num_im, num_points).astype('int32')
+
+    # Evaluate FID_N
+    for fid_batch_size in fid_batches:
+        # shuffle and take a random subset (sampling without replacement)
+        np.random.shuffle(activations)
+        fid_activations = activations[:fid_batch_size]
+        fids.append(calculate_FID(inception_model, fid_activations, gt_path))
+    fids = np.array(fids).reshape(-1, 1)
+    
+    # Fit linear regression
+    reg = LinearRegression().fit(1/fid_batches.reshape(-1, 1), fids)
+    fid_infinity = reg.predict(np.array([[0]]))[0,0]
+
+    return fid_infinity
+
+def calculate_FID_infinity_path(real_path, fake_path, batch_size=50, min_fake=5000, num_points=15):
+    """
+    Calculates effectively unbiased FID_inf using extrapolation given 
+    paths to real and fake data
+    Args:
+        real_path: (str)
+            Path to real dataset or precomputed .npz statistics.
+        fake_path: (str)
+            Path to fake dataset.
+        batch_size: (int)
+            The batch size for dataloader.
+            Default: 50
+        min_fake: (int)
+            Minimum number of images to evaluate FID on.
+            Default: 5000
+        num_points: (int)
+            Number of FID_N we evaluate to fit a line.
+            Default: 15
+    """
+    # load pretrained inception model 
+    inception_model = load_inception_net()
+
+    # get all activations of generated images
+    if real_path.endswith('.npz'):
+        real_m, real_s = load_path_statistics(real_path)
+    else:
+        real_act, _ = compute_path_statistics(real_path, batch_size, model=inception_model)
+        real_m, real_s = np.mean(real_act, axis=0), np.cov(real_act, rowvar=False)
+
+    fake_act, _ = compute_path_statistics(fake_path, batch_size, model=inception_model)
+
+    num_fake = len(fake_act)
+    assert num_fake > min_fake, \
+        'number of fake data must be greater than the minimum point for extrapolation'
+
+    fids = []
+
+    # Choose the number of images to evaluate FID_N at regular intervals over N
+    fid_batches = np.linspace(min_fake, num_fake, num_points).astype('int32')
+
+    # Evaluate FID_N
+    for fid_batch_size in fid_batches:
+        # shuffle and take a random subset (sampling without replacement)
+        np.random.shuffle(fake_act)
+        fid_activations = fake_act[:fid_batch_size]
+        m, s = np.mean(fid_activations, axis=0), np.cov(fid_activations, rowvar=False)
+        FID = numpy_calculate_frechet_distance(m, s, real_m, real_s)
+        fids.append(FID)
+    fids = np.array(fids).reshape(-1, 1)
+    
+    # Fit linear regression
+    reg = LinearRegression().fit(1/fid_batches.reshape(-1, 1), fids)
+    fid_infinity = reg.predict(np.array([[0]]))[0,0]
+
+    return fid_infinity
+
+def calculate_IS_infinity(gen_model, ndim, batch_size, num_im=50000, num_points=15):
+    """
+    Calculates effectively unbiased IS_inf using extrapolation
+    Args:
+        gen_model: (nn.Module)
+            The trained generator. Generator takes in z~N(0,1) and outputs
+            an image of [-1, 1].
+        ndim: (int)
+            The dimension of z.
+        batch_size: (int)
+            The batch size of generator
+        num_im: (int)
+            Number of images we are generating to evaluate IS_inf.
+            Default: 50000
+        num_points: (int)
+            Number of IS_N we evaluate to fit a line.
+            Default: 15
+    """
+    # load pretrained inception model 
+    inception_model = load_inception_net()
+
+    # define a sobol_inv sampler
+    z_sampler = randn_sampler(ndim, True)
+
+    # get all activations of generated images
+    _, logits =  accumulate_activations(gen_model, inception_model, num_im, z_sampler, batch_size)
+
+    IS = []
+
+    # Choose the number of images to evaluate IS_N at regular intervals over N
+    IS_batches = np.linspace(5000, num_im, num_points).astype('int32')
+
+    # Evaluate IS_N
+    for IS_batch_size in IS_batches:
+        # shuffle and take a random subset (sampling without replacement)
+        np.random.shuffle(logits)
+        IS_logits = logits[:IS_batch_size]
+        IS.append(calculate_inception_score(IS_logits)[0])
+    IS = np.array(IS).reshape(-1, 1)
+    
+    # Fit linear regression
+    reg = LinearRegression().fit(1/IS_batches.reshape(-1, 1), IS)
+    IS_infinity = reg.predict(np.array([[0]]))[0,0]
+
+    return IS_infinity
+
+def calculate_IS_infinity_path(path, batch_size=50, min_fake=5000, num_points=15):
+    """
+    Calculates effectively unbiased IS_inf using extrapolation given
+    a path to fake data
+    Args:
+        path: (str)
+            Path to fake dataset.
+        batch_size: (int)
+            The batch size for dataloader.
+            Default: 50
+        min_fake: (int)
+            Minimum number of images to evaluate IS on.
+            Default: 5000
+        num_points: (int)
+            Number of IS_N we evaluate to fit a line.
+            Default: 15
+    """
+    # load pretrained inception model 
+    inception_model = load_inception_net()
+
+    # get all activations of generated images
+    _, logits = compute_path_statistics(path, batch_size, model=inception_model)
+
+    num_fake = len(logits)
+    assert num_fake > min_fake, \
+        'number of fake data must be greater than the minimum point for extrapolation'
+
+    IS = []
+
+    # Choose the number of images to evaluate IS_N at regular intervals over N
+    IS_batches = np.linspace(min_fake, num_fake, num_points).astype('int32')
+
+    # Evaluate IS_N
+    for IS_batch_size in IS_batches:
+        # shuffle and take a random subset (sampling without replacement)
+        np.random.shuffle(logits)
+        IS_logits = logits[:IS_batch_size]
+        IS.append(calculate_inception_score(IS_logits)[0])
+    IS = np.array(IS).reshape(-1, 1)
+    
+    # Fit linear regression
+    reg = LinearRegression().fit(1/IS_batches.reshape(-1, 1), IS)
+    IS_infinity = reg.predict(np.array([[0]]))[0,0]
+
+    return IS_infinity
+
+################# Functions for calculating and saving dataset inception statistics ##################
+class im_dataset(Dataset):
+    def __init__(self, data_dir):
+        self.data_dir = data_dir
+        self.imgpaths = self.get_imgpaths()
+        
+        self.transform = transforms.Compose([
+                       transforms.Resize(64),
+                       transforms.CenterCrop(64),
+                       transforms.ToTensor()])
+
+    def get_imgpaths(self):
+        paths = glob.glob(os.path.join(self.data_dir, "**/*.jpg"), recursive=True) +\
+            glob.glob(os.path.join(self.data_dir, "**/*.png"), recursive=True)
+        return paths
+    
+    def __getitem__(self, idx):
+        img_name = self.imgpaths[idx]
+        image = self.transform(Image.open(img_name))
+        return image
+
+    def __len__(self):
+        return len(self.imgpaths)
+
+def load_path_statistics(path):
+    """
+    Given path to dataset npz file, load and return mu and sigma
+    """
+    if path.endswith('.npz'):
+        f = np.load(path)
+        m, s = f['mu'][:], f['sigma'][:]
+        f.close()
+        return m, s
+    else:
+        raise RuntimeError('Invalid path: %s' % path)
+        
+def compute_path_statistics(path, batch_size, model=None):
+    """
+    Given path to a dataset, load and compute mu and sigma.
+    """
+    if not os.path.exists(path):
+        raise RuntimeError('Invalid path: %s' % path)
+        
+    if model is None:
+        model = load_inception_net()
+    dataset = im_dataset(path)
+    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, drop_last=False)
+    return get_activations(dataloader, model)
+
+def get_activations(dataloader, model):
+    """
+    Get inception activations from dataset
+    """
+    pool = []
+    logits = []
+
+    for images in tqdm(dataloader):
+        images = images.cuda()
+        with torch.no_grad():
+            pool_val, logits_val = model(images)
+            pool += [pool_val]
+            logits += [F.softmax(logits_val, 1)]
+
+    return torch.cat(pool, 0).cpu().numpy(), torch.cat(logits, 0).cpu().numpy()
+
+def accumulate_activations(gen_model, inception_model, num_im, z_sampler, batch_size):
+    """
+    Generate images and compute their Inception activations.
+    """
+    pool, logits = [], []
+    for i in range(math.ceil(num_im/batch_size)):
+        with torch.no_grad():
+            z = z_sampler.draw(batch_size).cuda()
+            fake_img = to_img(gen_model(z))
+
+            pool_val, logits_val = inception_model(fake_img)
+            pool += [pool_val]
+            logits += [F.softmax(logits_val, 1)]
+
+    pool =  torch.cat(pool, 0)[:num_im]
+    logits = torch.cat(logits, 0)[:num_im]
+
+    return pool.cpu().numpy(), logits.cpu().numpy()
+
+def to_img(x):
+    """
+    Normalizes an image from [-1, 1] to [0, 1]
+    """
+    x = 0.5 * (x + 1)
+    x = x.clamp(0, 1)
+    return x
+
+
+
+####################### Functions to help calculate FID and IS #######################
+def calculate_FID(model, act, gt_npz):
+    """
+    calculate score given activations and path to npz
+    """
+    data_m, data_s = load_path_statistics(gt_npz)
+    gen_m, gen_s = np.mean(act, axis=0), np.cov(act, rowvar=False)
+    FID = numpy_calculate_frechet_distance(gen_m, gen_s, data_m, data_s)
+
+    return FID
+
+def calculate_inception_score(pred, num_splits=1):
+    scores = []
+    for index in range(num_splits):
+        pred_chunk = pred[index * (pred.shape[0] // num_splits): (index + 1) * (pred.shape[0] // num_splits), :]
+        kl_inception = pred_chunk * (np.log(pred_chunk) - np.log(np.expand_dims(np.mean(pred_chunk, 0), 0)))
+        kl_inception = np.mean(np.sum(kl_inception, 1))
+        scores.append(np.exp(kl_inception))
+    return np.mean(scores), np.std(scores)
+
+
+def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+    """Numpy implementation of the Frechet Distance.
+    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+    and X_2 ~ N(mu_2, C_2) is
+            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+    Stable version by Dougal J. Sutherland.
+    Params:
+    -- mu1   : Numpy array containing the activations of a layer of the
+               inception net (like returned by the function 'get_predictions')
+               for generated samples.
+    -- mu2   : The sample mean over activations, precalculated on a
+               representative data set.
+    -- sigma1: The covariance matrix over activations for generated samples.
+    -- sigma2: The covariance matrix over activations, precalculated on a
+               representative data set.
+    Returns:
+    --   : The Frechet Distance.
+    """
+
+    mu1 = np.atleast_1d(mu1)
+    mu2 = np.atleast_1d(mu2)
+
+    sigma1 = np.atleast_2d(sigma1)
+    sigma2 = np.atleast_2d(sigma2)
+
+    assert mu1.shape == mu2.shape, \
+        'Training and test mean vectors have different lengths'
+    assert sigma1.shape == sigma2.shape, \
+        'Training and test covariances have different dimensions'
+
+    diff = mu1 - mu2
+
+    # Product might be almost singular
+    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+    if not np.isfinite(covmean).all():
+        msg = ('fid calculation produces singular product; '
+               'adding %s to diagonal of cov estimates') % eps
+        print(msg)
+        offset = np.eye(sigma1.shape[0]) * eps
+        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+    # Numerical error might give slight imaginary component
+    if np.iscomplexobj(covmean):
+        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+            m = np.max(np.abs(covmean.imag))
+            raise ValueError('Imaginary component {}'.format(m))
+        covmean = covmean.real
+
+    tr_covmean = np.trace(covmean)
+
+    return (diff.dot(diff) + np.trace(sigma1) +
+            np.trace(sigma2) - 2 * tr_covmean)
diff --git a/evaluation/eval_full/vggface.py b/evaluation/eval_full/vggface.py
new file mode 100644
index 0000000..2e4086e
--- /dev/null
+++ b/evaluation/eval_full/vggface.py
@@ -0,0 +1,93 @@
+# VGG16 model from https://github.com/prlz77/vgg-face.pytorch
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchfile
+
+class VGG_16(nn.Module):
+    """
+    Main Class
+    """
+
+    def __init__(self):
+        """
+        Constructor
+        """
+        super().__init__()
+        self.block_size = [2, 2, 3, 3, 3]
+        self.conv_1_1 = nn.Conv2d(3, 64, 3, stride=1, padding=1)
+        self.conv_1_2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
+        self.conv_2_1 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
+        self.conv_2_2 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
+        self.conv_3_1 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
+        self.conv_3_2 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
+        self.conv_3_3 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
+        self.conv_4_1 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
+        self.conv_4_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
+        self.conv_4_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
+        self.conv_5_1 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
+        self.conv_5_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
+        self.conv_5_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
+        self.fc6 = nn.Linear(512 * 7 * 7, 4096)
+        self.fc7 = nn.Linear(4096, 4096)
+        self.fc8 = nn.Linear(4096, 2622)
+
+    def load_weights(self, path):
+        """ Load the pretrained lua-torch (VGGFace) weights
+
+        Args:
+            path: path to the pretrained lua-torch weight file
+        """
+        model = torchfile.load(path)
+        counter = 1
+        block = 1
+        for i, layer in enumerate(model.modules):
+            if layer.weight is not None:
+                if block <= 5:
+                    self_layer = getattr(self, "conv_%d_%d" % (block, counter))
+                    counter += 1
+                    if counter > self.block_size[block - 1]:
+                        counter = 1
+                        block += 1
+                    self_layer.weight.data[...] = torch.tensor(layer.weight).view_as(self_layer.weight)[...]
+                    self_layer.bias.data[...] = torch.tensor(layer.bias).view_as(self_layer.bias)[...]
+                else:
+                    self_layer = getattr(self, "fc%d" % (block))
+                    block += 1
+                    self_layer.weight.data[...] = torch.tensor(layer.weight).view_as(self_layer.weight)[...]
+                    self_layer.bias.data[...] = torch.tensor(layer.bias).view_as(self_layer.bias)[...]
+
+    def forward(self, x):
+        """ Pytorch forward
+
+        Args:
+            x: input image (224x224)
+
+        Returns: class logits
+
+        """
+        x = F.relu(self.conv_1_1(x))
+        x = F.relu(self.conv_1_2(x))
+        x = F.max_pool2d(x, 2, 2)
+        x = F.relu(self.conv_2_1(x))
+        x = F.relu(self.conv_2_2(x))
+        x = F.max_pool2d(x, 2, 2)
+        x = F.relu(self.conv_3_1(x))
+        x = F.relu(self.conv_3_2(x))
+        x = F.relu(self.conv_3_3(x))
+        x = F.max_pool2d(x, 2, 2)
+        x = F.relu(self.conv_4_1(x))
+        x = F.relu(self.conv_4_2(x))
+        x = F.relu(self.conv_4_3(x))
+        x = F.max_pool2d(x, 2, 2)
+        x = F.relu(self.conv_5_1(x))
+        x = F.relu(self.conv_5_2(x))
+        x = F.relu(self.conv_5_3(x))
+        x = F.max_pool2d(x, 2, 2)
+        x = x.view(x.size(0), -1)
+        x = F.relu(self.fc6(x))
+        x = F.dropout(x, 0.5, self.training)
+        x = F.relu(self.fc7(x))
+        x = F.dropout(x, 0.5, self.training)
+        return self.fc8(x)
\ No newline at end of file
-- 
GitLab