Commit 4eac7622 authored by Srijeet Roy

update kNN, add psnr, ssim metric

parent 7e935796
import os
import torch
import argparse
from PIL import Image
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from itertools import cycle
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio

def image_to_tensor(path, sample=10, device='cpu'):
    # Load up to `sample` PNG images from `path` as one tensor with pixel values in [0, 255].
    transform_resize = transforms.Compose([transforms.ToTensor(), transforms.Resize(128), transforms.Lambda(lambda x: (x * 255))])
    transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (x * 255))])

    filelist = os.listdir(path)
    # argparse passes `sample` as a string when it is given on the command line, so cast it here
    if sample == 'all':
        sample_size = -1
    else:
        sample_size = int(sample)
    image_names = []
    image_list = []
    for file in filelist:
        if file.endswith('.png'):
            filepath = os.path.join(path, file)
            image_names.append(file)
            im = Image.open(filepath)
            # resize only when the image width is not already 128
            if im.size[0] != 128:
                im = transform_resize(im)
            else:
                im = transform(im)
            image_list.append(im)
            if len(image_list) == sample_size:
                break
    print(f'current sample size: {len(image_names)}')
    # convert list of tensors to tensor
    image_tensor = torch.stack(image_list).to(device)
    return image_tensor


if __name__ == '__main__':
    device = 'mps' if torch.backends.mps.is_available() else 'cpu'
    #device = 'cuda' if torch.cuda.is_available() else 'cpu'

    parser = argparse.ArgumentParser()
    parser.add_argument('-rp', '--realpath', required=True, help='path to real images', type=str)
    parser.add_argument('-gp', '--genpath', required=True, help='path to generated images', type=str)
    parser.add_argument('-s', '--sample', nargs='?', const=10, default=10,
                        help='how many generated samples to compare, default 10, int or "all"')
    parser.add_argument('-n', '--name', nargs='?', const='', default='',
                        help='name appendix')
    args = vars(parser.parse_args())

    path_to_real_images = args['realpath']
    path_to_generated_images = args['genpath']
    sample = args['sample']
    name_appendix = args['name']

    real_image_tensor = image_to_tensor(path_to_real_images, sample, device)
    generated_image_tensor = image_to_tensor(path_to_generated_images, sample, device)
    real_dataloader = DataLoader(real_image_tensor, batch_size=128, shuffle=False)
    generated_dataloader = DataLoader(generated_image_tensor, batch_size=128, shuffle=False)

    # image_to_tensor scales pixels to [0, 255], so the SSIM data range has to match
    ssim = StructuralSimilarityIndexMeasure(data_range=255.0).to(device)
    psnr = PeakSignalNoiseRatio().to(device)

    # accumulate both metrics batch by batch; generated batches repeat if they run out first
    for r, g in zip(real_dataloader, cycle(generated_dataloader)):
        r = r.to(device)
        g = g.to(device)
        ssim.update(preds=g, target=r)
        psnr.update(preds=g, target=r)
    ssim_score = ssim.compute()
    psnr_score = psnr.compute()
    print(f'SSIM: {ssim_score:0.3f}')
    print(f'PSNR: {psnr_score:0.3f}')

    # write the scores to a text file in the current working directory
    txtfile = 'content_invariant_metrics.txt'
    if name_appendix != '':
        txtfile = 'content_invariant_metrics_' + name_appendix + '.txt'
    with open(os.path.join(os.getcwd(), txtfile), 'w') as fp:
        fp.write(f'SSIM: {ssim_score:0.3f}\n')
        fp.write(f'PSNR: {psnr_score:0.3f}\n')
\ No newline at end of file
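The script above scales pixels to [0, 255] before computing the metrics, so any data range handed to a metric should describe that same scale. As a rough illustration of why this matters, here is a minimal pure-Python sketch of PSNR, 10 * log10(data_range^2 / MSE). The function `psnr` and the four-pixel images are hypothetical and are not part of the commit; the script itself relies on torchmetrics.

```python
import math


def psnr(pred, target, data_range):
    """PSNR in dB between two equally sized flat pixel sequences."""
    if len(pred) != len(target):
        raise ValueError("inputs must have the same number of pixels")
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(data_range ** 2 / mse)


# Hypothetical 4-pixel images on the [0, 255] scale.
real = [0.0, 64.0, 128.0, 255.0]
fake = [1.0, 63.0, 129.0, 254.0]  # every pixel is off by exactly 1, so MSE = 1

score_255 = psnr(fake, real, data_range=255.0)  # data range matches the pixel scale
score_1 = psnr(fake, real, data_range=1.0)      # data range does not match the pixel scale
print(f"{score_255:0.3f} dB vs {score_1:0.3f} dB")
```

The same pair of images scores very differently depending on the assumed data range, which is why the range should follow the preprocessing.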
......@@ -3,7 +3,7 @@
We conduct two types of evaluation - qualitative and quantitative.
### Quantitative evaluations -
<pre>
Quantitative evaluations are carried out to compare different backbone architectures of our unconditional diffusion model.
A set of 10,000 generated samples from each model variant is compared with the test set of the real dataset.
These evaluations include -
......@@ -11,10 +11,10 @@ These evaluations include -
2. Inception score
3. Clean FID score (with CLIP)
4. FID infinity and IS infinity scores
</pre>
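FID infinity and IS infinity are usually described as extrapolations: the score is computed at several sample sizes N, a line is fitted against 1/N, and the intercept is taken as the value for an infinite sample. The sketch below shows only that extrapolation step in pure Python. The helper `extrapolate_to_infinity` and the FID values are hypothetical, and this README does not state how evaluate_full.py computes these scores internally.

```python
def extrapolate_to_infinity(sample_sizes, scores):
    """Fit score = intercept + slope * (1 / N) by least squares and return the intercept."""
    xs = [1.0 / n for n in sample_sizes]
    n_points = len(xs)
    mean_x = sum(xs) / n_points
    mean_y = sum(scores) / n_points
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, scores))
    slope = sxy / sxx
    return mean_y - slope * mean_x  # value of the fitted line at 1/N = 0


# Hypothetical FID values that follow FID(N) = 12 + 5000 / N exactly.
sizes = [1000, 2000, 5000, 10000]
fids = [12.0 + 5000.0 / n for n in sizes]

fid_inf = extrapolate_to_infinity(sizes, fids)
print(f"extrapolated FID: {fid_inf:0.3f}")
```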
### Qualitative evaluations -
<pre>
The aim of this set of evaluations is to qualitatively inspect whether our model has overfit to the training images.
For this, the entire set of 10,000 generated samples from the best-performing model in the quantitative evaluation is
compared with the training set of the real dataset.
......@@ -24,18 +24,18 @@ The comparison is implemented as MSE values between features of the generated an
extracted by using a pretrained model (ResNet50-Places365/VGGFace or CLIP). Based on the MSE scores we compute -
1. kNN - plot the k nearest neighbors of the generated samples
2. Closest pairs - plot the top pairs with smallest MSE value
</pre>
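The two views built from the MSE scores can be sketched as follows. This is a minimal pure-Python illustration on made-up 2-D features, not the repository's implementation: the function names and feature values are assumptions, and the real features come from the pretrained model.

```python
def mse(a, b):
    """Mean squared error between two feature vectors of equal length."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def k_nearest_neighbors(gen_feats, real_feats, k):
    """For each generated feature, indices of the k real features with the smallest MSE."""
    neighbors = []
    for g in gen_feats:
        dists = sorted((mse(g, r), j) for j, r in enumerate(real_feats))
        neighbors.append([j for _, j in dists[:k]])
    return neighbors


def closest_pairs(gen_feats, real_feats, top):
    """The `top` (generated index, real index) pairs with the smallest MSE overall."""
    pairs = sorted(
        (mse(g, r), i, j)
        for i, g in enumerate(gen_feats)
        for j, r in enumerate(real_feats)
    )
    return [(i, j) for _, i, j in pairs[:top]]


# Hypothetical features: three real images, two generated images.
real_feats = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
gen_feats = [[0.1, 0.0], [4.0, 5.0]]

knn = k_nearest_neighbors(gen_feats, real_feats, k=2)
pairs = closest_pairs(gen_feats, real_feats, top=2)
print(knn)    # neighbors per generated sample
print(pairs)  # globally closest (generated, real) pairs
```

kNN answers "which training images look most like this sample", while closest pairs ranks all generated/real combinations at once, which is the more direct check for memorized training images.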
### Arguments -
<pre>
Execution starts with the evaluate_full.py file. The input arguments are -
</pre>
* <pre>-rp, --realpath : Path to real images (string) </pre>
* <pre>-gp, --genpath : Path to generated images (string) </pre>
* <pre>-d, --data : Choose between 'lhq' (for LHQ landscape dataset) and 'faces' (for CelebAHQ faces dataset).
Default = 'lhq' (string)</pre>
* <pre>--size : Resolution of images the model was trained on, default 128 (int) </pre>
* <pre>-a, --arch : Choose between 'cnn' and 'clip'. The chosen pretrained model is used to extract features from the images.
 If 'cnn' is selected, the model is a ResNet50 pretrained on the Places365 dataset for the LHQ dataset and a
 pretrained VGGFace for the CelebAHQ dataset. Not used when computing FID and IS scores. Default = 'clip' (string) </pre>
* <pre>-m, --mode : Choose between 'kNN' and 'pairs' (for closest pairs), default = 'kNN' (string) </pre>
......@@ -46,9 +46,10 @@ Execution starts with evaluate_full.py file. Input arguments are -
* <pre>-n, --name : Name appendix (string) </pre>
* <pre>--fid : Choose between 'yes' and 'no'. Compute FID, Inception score and upgraded FID scores. Default 'no' (string) </pre>
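As a rough sketch, the options documented above map onto an `argparse` parser like the one below. It covers only the arguments visible in this list (some are elided in this diff view), and it is not the actual parser in evaluate_full.py. Marking the two paths as required, the empty default for `--name`, and the path `generated_samples` are assumptions.

```python
import argparse


def build_parser():
    # Sketch of the documented options only; defaults follow the descriptions in the list.
    parser = argparse.ArgumentParser(prog='evaluate_full.py')
    parser.add_argument('-rp', '--realpath', type=str, required=True,  # required: assumption
                        help='path to real images')
    parser.add_argument('-gp', '--genpath', type=str, required=True,   # required: assumption
                        help='path to generated images')
    parser.add_argument('-d', '--data', type=str, choices=['lhq', 'faces'], default='lhq',
                        help='dataset: LHQ landscapes or CelebAHQ faces')
    parser.add_argument('--size', type=int, default=128,
                        help='resolution of images the model was trained on')
    parser.add_argument('-a', '--arch', type=str, choices=['cnn', 'clip'], default='clip',
                        help='pretrained feature extractor')
    parser.add_argument('-m', '--mode', type=str, choices=['kNN', 'pairs'], default='kNN',
                        help='kNN or closest pairs')
    parser.add_argument('-n', '--name', type=str, default='',           # default: assumption
                        help='name appendix')
    parser.add_argument('--fid', type=str, choices=['yes', 'no'], default='no',
                        help='compute FID, Inception score and upgraded FID scores')
    return parser


# Hypothetical invocation: closest-pairs mode on the LHQ data with default settings otherwise.
args = build_parser().parse_args(['-rp', 'data/lhq', '-gp', 'generated_samples', '-m', 'pairs'])
print(vars(args))
```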
<pre>
Path to real images leads to a directory with two sub-directories - train and test.
<pre>
data
|_ lhq
| |_ train
......@@ -56,10 +57,10 @@ data
|_ celebahq256_imgs
| |_ train
| |_ test
</pre>
CLIP and CNN (ResNet50 or VGGFace) features of the training images are saved after the first execution. This removes the need \
to recompute the features of real images for different sets of generated samples.
</pre>
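The save-once behaviour described above follows a common load-or-compute pattern, sketched here with the standard library. The helper `load_or_compute_features`, the file name `real_train_features.pkl`, and the stand-in extractor are all hypothetical. The file format and location the repository actually uses are not stated in this README.

```python
import os
import pickle
import tempfile


def load_or_compute_features(cache_path, compute_fn):
    """Return cached features if the file exists; otherwise compute them and save them."""
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as fp:
            return pickle.load(fp)
    features = compute_fn()
    with open(cache_path, 'wb') as fp:
        pickle.dump(features, fp)
    return features


calls = []


def fake_extractor():
    # Stand-in for the expensive CLIP/CNN forward pass over the training images.
    calls.append(1)
    return [[0.1, 0.2], [0.3, 0.4]]


with tempfile.TemporaryDirectory() as tmp:
    cache_file = os.path.join(tmp, 'real_train_features.pkl')
    first = load_or_compute_features(cache_file, fake_extractor)   # computes and saves
    second = load_or_compute_features(cache_file, fake_extractor)  # loads from disk

print(f'extractor ran {len(calls)} time(s)')
```

A cache like this should be keyed on the dataset and the feature extractor, since CLIP and CNN features for the same images are not interchangeable.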
### Links
1. ResNet50 pretrained on Places365 - https://github.com/CSAILVision/places365
......