From 6fe38f9b458411e40066deb0b3a7501e87d10593 Mon Sep 17 00:00:00 2001
From: Srijeet Roy <srijeet.11@gmail.com>
Date: Mon, 31 Jul 2023 14:06:51 +0200
Subject: [PATCH] Update evaluation pipeline

---
 evaluation/evaluate.py | 276 +++++++++++++++++++++++++----------------
 1 file changed, 172 insertions(+), 104 deletions(-)

diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py
index 316953c..2df0f2d 100644
--- a/evaluation/evaluate.py
+++ b/evaluation/evaluate.py
@@ -1,108 +1,176 @@
-from torchmetrics.image.fid import FrechetInceptionDistance
-from torchmetrics.image.inception import InceptionScore
-from torchmetrics.image.kid import KernelInceptionDistance
-import re
 import os
-from PIL import Image
-from torchvision import transforms
+import argparse
+import pickle
+from pathlib import Path
 import torch
+from torchvision import transforms
+import clip
+from evaluation.helpers.vggface import *
+from torchvision.models import resnet50
+
+from evaluation.helpers.kNN import *
+from evaluation.helpers.metrics import *
+
+
+def cdm_evaluator(experiment_path, realpath, genpath, size=128, arch='clip', mode='both', k=3, sample=10, name_appendix='', fid='no'):
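+    '''
+    Evaluate generated samples against the real dataset, quantitatively and qualitatively.
+
+    experiment_path: experiment folder; results are written to experiment_path/eval_output
+    realpath:        path to real images (expects 'train' and 'test' subdirectories)
+    genpath:         path to generated samples
+    size:            image resolution
+    arch:            feature-extraction architecture - 'cnn' or 'clip'
+    mode:            qualitative eval mode - 'kNN', 'pairs' (closest pairs) or 'both'
+    k:               number of nearest neighbours if mode is 'kNN' or 'both'
+    sample:          how many generated images to evaluate - an integer or 'all'
+                     (for kNN, the first 'sample' images; for pairs, the top 'sample' closest ones)
+    name_appendix:   suffix for the evaluation output filenames
+    fid:             whether to compute quantitative metrics (FID, IS and variants) - 'yes' or 'no'
+    '''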
+    
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
-def cdm_evaluator(model, 
-                  device,
-                  dataloader,
-                  checkpoint,
-                  experiment_path,
-                  sample_idx=0,
-                  **args,
-                  ):
-  '''
-  Takes a trained diffusion model from 'checkpoint' and evaluates its performance on the test 
-  dataset 'dataloader' w.r.t. the three most important perfromance metrics; FID, IS, KID. We continue
-  the progress of our evaluation function for the LDM upscalaer and may update this function accordingly.
+    print('device:', device)
+
+    path_to_real_images = realpath          # path to real images (expects 'train' and 'test' subdirectories)
+    path_to_generated_images = genpath      # path to generated samples
+    k_kNN = k                               # number of nearest neighbours for mode 'kNN' or 'both'
+    fid_bool = fid                          # whether to compute the quantitative scores - 'yes' or 'no'
+
+    print('Start')
+    # create the output folder (experiment_path/eval_output) if it does not exist
+    output_path = Path(os.path.join(experiment_path, 'eval_output'))
+    if not output_path.is_dir():
+        os.mkdir(output_path)
+
+    # change the working directory to the output folder
+    os.chdir(output_path)
+    # create an output text file and store the evaluation metadata
+    txt_filename = 'evaluation_' + arch + '_' + mode + '-' + name_appendix + '.txt'
+    with open(txt_filename, 'w') as f:
+        f.write(f'Path to real images: {path_to_real_images}\n')
+        f.write(f'Path to generated images: {path_to_generated_images}\n')
+        f.write(f'Experiment on AFHQ dataset with images of resolution {size}x{size}\n')
+        f.write(f'Using {arch} model to extract features\n')
+        f.write(f'Plot of {mode} on {sample} samples\n')
+        f.write(f'Quantitative metrics computed: {fid_bool}\n')
+
-  checkpoint:      Name of the saved pth. file containing the trained weights and biases  
-  experiment_path: Path to the experiment folder where the evaluation results will be stored  
-  dataloader:      Loads the test dataset for evaluation
-  sample_idx:      Integer that denotes which sample directory sample_{sample_idx} from the checkpoint model shall be used for evaluation
-  '''
-
-  checkpoint_path = f'{experiment_path}trained_cdm/{checkpoint}'
-  # create evaluation directory for the complete experiment (if first time sampling images)
-  output_dir = f'{experiment_path}evaluations/'
-  os.makedirs(output_dir, exist_ok=True)
-
-  # create evaluation directory for the current version of the trained model
-  model_name = os.path.basename(checkpoint_path)
-  epoch = re.findall(r'\d+', model_name)
-  if epoch:
-      e = int(epoch[0])
-  else:
-      raise ValueError(f"No digit found in the filename")
-  model_dir = os.path.join(output_dir,f'epoch_{e}')
-  os.makedirs(model_dir, exist_ok=True)
-
-  # create the evaluation directory for this evaluation run for the current version of the model    
-  eval_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))]
-  indx_list = [int(d.split('_')[1]) for d in eval_dir_list if d.startswith('evaluation_')]
-  j = max(indx_list, default=-1) + 1
-  eval_dir = os.path.join(model_dir, f'evaluation_{j}')
-  os.makedirs(eval_dir, exist_ok=True)
-  
-  # Compute metrics  
-  eval_path = os.path.join(eval_dir, 'eval.txt')
- 
-  # get sampled images
-  transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (x * 255).type(torch.uint8))])
-  sample_path =  os.path.join(f'{experiment_path}samples/',f'epoch_{e}',f'sample_{sample_idx}')
-  ignore_tensor = f'image_tensor{j}'
-  images = []
-  for samplename in os.listdir(sample_path):
-        if samplename == ignore_tensor:
-            continue
-        img = Image.open(os.path.join(sample_path, samplename))
-        img = transform(img)
-        images.append(img)
-  # split them into batches for GPU memory
-  generated = torch.stack(images).to(device)
-  generated_batches = torch.split(generated, dataloader.batch_size)
-  nr_generated_batches = len(generated_batches)
-  nr_real_batches = len(dataloader)
-
-  # Init FID, IS and KID scores
-  fid = FrechetInceptionDistance(normalize = False).to(device)
-  iscore = InceptionScore(normalize=False).to(device)
-  kid = KernelInceptionDistance(normalize=False, subset_size=32).to(device)
-
-  # Update scores for the full testing dataset w.r.t. the sampled batches
-  for idx,(data, _) in enumerate(dataloader):
-    data = data.to(device)
-    fid.update(data, real=True)
-    kid.update(data, real=True)
-    if idx < nr_generated_batches:
-        gen = generated_batches[idx].to(device)
-        fid.update(gen, real=False)
-        kid.update(gen, real=False)
-        iscore.update(gen)
-
-  # If there are sampled images left, add them too
-  for idx in range(nr_real_batches, nr_generated_batches):
-    gen = generated_batches[idx].to(device)
-    fid.update(gen, real=False)
-    kid.update(gen, real=False)
-    iscore.update(gen)
- 
-  # compute total FID, IS and KID 
-  fid_score = fid.compute()
-  i_score = iscore.compute()
-  kid_score = kid.compute()
-
-  # store results in txt file
-  with open(str(eval_path), 'a') as txt:
-    result = f'FID_epoch_{e}_sample_{sample_idx}:'
-    txt.write(result + str(fid_score.item()) + '\n')
-    result = f'KID_epoch_{e}_sample_{sample_idx}:'
-    txt.write(result + str(kid_score) + '\n')
-    result =  f'IS_epoch_{e}_sample_{sample_idx}:'
-    txt.write(result + str(i_score) + '\n')
-
- 
+    # datapaths
+    path_to_training_images = os.path.join(path_to_real_images, 'train')
+    path_to_test_images = os.path.join(path_to_real_images, 'test')
+
+    # compute quantitative metrics (FID, IS and variants)
+    if fid_bool == 'yes':
+
+        # convert images to tensors
+        eval_images = image_to_tensor(path_to_test_images, device=device)
+        generated = image_to_tensor(path_to_generated_images, device=device)
+        # standard FID, KID and Inception scores
+        print('Computing FID, KID & inception scores...')
+        fid_score, kid_score, is_score = compute_scores(eval_images, generated, device)
+        with open(txt_filename, 'a') as f:
+            f.write(f'FID score: {fid_score}\n')
+            f.write(f'KID score: {kid_score}\n')
+            f.write(f'Inception score: {is_score}\n')
+
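+        # 'Clean FID' (Parmar et al.) recomputes FID with a consistent image-resizing
+        # pipeline; the second variant uses CLIP features instead of Inception features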
+        print('Computing Clean FID scores...')
+        clean_fid_score, clip_clean_fid_score = clean_fid(path_to_test_images, path_to_generated_images)
+        with open(txt_filename, 'a') as f:
+            f.write(f'Clean FID score: {clean_fid_score}\n')
+            f.write(f'Clean FID score with CLIP features: {clip_clean_fid_score}\n')
+        
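+        # FID-infinity / IS-infinity extrapolate the metrics to an infinite sample
+        # size to reduce the bias of the finite-sample estimators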
+        print('Computing FID infinity and IS infinity scores...')
+        fid_infinity, is_infinity = fid_inf_is_inf(path_to_test_images, path_to_generated_images) 
+        with open(txt_filename, 'a') as f:
+            f.write(f'FID Infinity score: {fid_infinity}\n')
+            f.write(f'IS Infinity score: {is_infinity}\n')
+
+    print(f'Loading model {arch}...')
+    feature_flag = False
+
+    # Qualitative evaluation setup
+
+    # path to pre-saved training image features
+    # NOTE: machine-specific cache location - adjust for your environment
+    pth = Path('/home/wn455752/repo/evaluation/features/afhq/clip_features')
+    # load CLIP
+    print('Loading pretrained CLIP...')
+    model, transform = clip.load("ViT-B/32", device=device)
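+    # clip.load returns the model together with its matching preprocessing transform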
+    # check for saved dataset features
+    print('Checking for existing training dataset features...')
+    if pth.is_dir():
+        name_pth = Path(os.path.join(str(pth), 'real_name_list'))
+        feature_pth = Path(os.path.join(str(pth), 'real_image_features.pt'))
+        # use the cache only if both the name list and the feature tensor exist
+        if name_pth.is_file() and feature_pth.is_file():
+            print('Loading existing training dataset features...')
+            with open(name_pth, 'rb') as fp:
+                real_names = pickle.load(fp)
+            real_features = torch.load(feature_pth, map_location="cpu")
+            real_features = real_features.to(device)
+            feature_flag = True
+    else:
+        # first run: create the cache directory for the training image features
+        os.makedirs(pth)
+        name_pth = Path(os.path.join(str(pth), 'real_name_list'))
+        feature_pth = Path(os.path.join(str(pth), 'real_image_features.pt'))
+
+
+    # Qualitative Evaluations
+
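+    # the kNN helper bundles image loading, feature extraction and the plotting
+    # routines (kNN grids and closest-pairs figures) used below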
+    knn = kNN()
+    # collect images from directories and store in a tensor
+    if not feature_flag:
+        print('Collecting training images...')
+        real_names, real_tensor = knn.get_images(path_to_training_images, transform)
+        with open(name_pth, 'wb') as fp:
+            pickle.dump(real_names, fp)
+    print('Collecting generated images...')
+    generated_names, generated_tensor = knn.get_images(path_to_generated_images, transform)
+
+    # extract features from image tensors
+    if not feature_flag:
+        print('Extracting features from training images...')
+        real_features = knn.feature_extractor(real_tensor, model, device)
+        torch.save(real_features, feature_pth)
+    print('Extracting features from generated images...')
+    generated_features = knn.feature_extractor(generated_tensor, model, device)
+
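+    # evaluate all generated images, or only a subset of size 'sample'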
+    if sample == 'all':
+        sample_size = len(generated_names)
+    else:
+        sample_size = int(sample)
+
+    # run the kNN and/or closest-pairs visualisations depending on the chosen mode
+    if mode in ('kNN', 'both'):
+        print('Finding kNNs...')
+        knn.kNN(output_path,
+                real_names, generated_names,
+                real_features, generated_features,
+                path_to_training_images, path_to_generated_images,
+                k=k_kNN,
+                sample=sample_size,
+                size=size,
+                name_appendix=name_appendix)
+    if mode in ('pairs', 'both'):
+        print('Finding closest pairs...')
+        knn.nearest_neighbor(output_path,
+                             real_names, generated_names,
+                             real_features, generated_features,
+                             path_to_training_images, path_to_generated_images,
+                             sample=sample_size,
+                             size=size,
+                             name_appendix=name_appendix)
+    print('Finished!')
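+
+# Minimal CLI entry point - a sketch that wires the imported (but otherwise unused)
+# argparse module to cdm_evaluator; the argument names simply mirror the function
+# parameters above and are not part of any established interface.
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Evaluate generated samples against real images')
+    parser.add_argument('--experiment_path', required=True)
+    parser.add_argument('--realpath', required=True)
+    parser.add_argument('--genpath', required=True)
+    parser.add_argument('--size', type=int, default=128)
+    parser.add_argument('--arch', default='clip', choices=['clip', 'cnn'])
+    parser.add_argument('--mode', default='both', choices=['kNN', 'pairs', 'both'])
+    parser.add_argument('--k', type=int, default=3)
+    parser.add_argument('--sample', default='10')   # an integer or 'all'
+    parser.add_argument('--name_appendix', default='')
+    parser.add_argument('--fid', default='no', choices=['yes', 'no'])
+    args = parser.parse_args()
+    cdm_evaluator(args.experiment_path, args.realpath, args.genpath,
+                  size=args.size, arch=args.arch, mode=args.mode, k=args.k,
+                  sample=args.sample, name_appendix=args.name_appendix, fid=args.fid)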
-- 
GitLab