Commit 984a5450 authored by Srijeet Roy

update flexibility with diff input shape, output directories, readme

parent f82fe3e1
evaluate_full.py

@@ -10,6 +10,7 @@ from torchvision.models import resnet50
from kNN import *
from metrics import *

if __name__ == '__main__':
    #device = "mps" if torch.backends.mps.is_available() else "cpu"
@@ -23,8 +24,10 @@ if __name__ == '__main__':
                        help='path to generated images', type=str)
    parser.add_argument('-d', '--data', nargs='?', const='lhq', default='lhq',
                        help='choose between "lhq" and "face" dataset', type=str)
    parser.add_argument('--size', nargs='?', const=128, default=128,
                        help='resolution of image the model was trained on, default 128 (int)', type=int)
    parser.add_argument('-a', '--arch', nargs='?', const='clip', default='clip',
                        help='choose between "clip" and "cnn", default "clip"', type=str)
    parser.add_argument('-m', '--mode', nargs='?', const='kNN', default='kNN',
                        help='choose between "kNN" and "pairs" for closest_pairs, default "kNN"', type=str)
    parser.add_argument('-k', '--k', nargs='?', const=3, default=3,
@@ -46,22 +49,25 @@ if __name__ == '__main__':
    sample = args['sample']
    name_appendix = args['name']
    fid_bool = args['fid']
    size = args['size']

    print('Start')
    output_path = Path(os.path.join(os.getcwd(), 'output'))
    if not output_path.is_dir():
        os.mkdir(output_path)
    txt_filename = 'output/evaluation_' + dataset + '_' + arch + '_' + mode + '-' + name_appendix + '.txt'
    with open(txt_filename, 'w') as f:
        f.write(f'Path to real images: {path_to_real_images}\n')
        f.write(f'Path to generated images: {path_to_generated_images}\n')
        f.write(f'Experiment on {dataset} dataset with images of resolution {size}x{size}\n')
        f.write(f'Using {arch} model to extract features\n')
        f.write(f'Plot of {mode} on {sample} samples\n')
        f.write(f'Quantitative metrics computed: {fid_bool}\n')

    # load data
    path_to_training_images = os.path.join(path_to_real_images, 'train')
    path_to_test_images = os.path.join(path_to_real_images, 'test')

    if fid_bool == 'yes':
        # metrics eval
        eval_images = image_to_tensor(path_to_test_images)
@@ -92,23 +98,19 @@ if __name__ == '__main__':
    # kNN-based eval
    if dataset == 'lhq':
        print('Dataset ', dataset)
        pth = '/home/wn455752/repo/evaluation/features/lhq'
        # load pretrained ResNet50
        if arch == 'cnn':
            print('Loading pretrained ResNet50...')
            path_to_pretrained_weights = '/home/wn455752/repo/evaluation/pretrained/resnet50_places365_pretrained/resnet50_places365_weights.pth'
            weights = torch.load(path_to_pretrained_weights)
            model = resnet50().to(device)
            model.load_state_dict(weights)
            transform = transforms.Compose([transforms.ToTensor(),                  # transform PIL.Image to torch.Tensor
                                            transforms.Lambda(lambda x: x * 255)])  # rescale [0,1] values to the 0-255 range the pretrained weights expect
            with torch.no_grad():
                model.eval()
            # check for saved dataset features
            print('Checking for existing training dataset features...')
            name_pth = Path(os.path.join(pth, 'resnet_features/real_name_list'))
            if name_pth.is_file():
@@ -116,22 +118,23 @@ if __name__ == '__main__':
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(pth, 'resnet_features/real_image_features.pt'))
            if feature_pth.is_file():
                print('Loading existing training dataset features...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True
        # load CLIP
        elif arch == 'clip':
            print('Loading pretrained CLIP...')
            model, transform = clip.load("ViT-B/32", device=device)
            # check for saved dataset features
            print('Checking for existing training dataset features...')
            name_pth = Path(os.path.join(pth, 'clip_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(pth, 'clip_features/real_image_features.pt'))
            if feature_pth.is_file():
                print('Loading existing training dataset features...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True
@@ -140,45 +143,45 @@ if __name__ == '__main__':
    elif dataset == 'faces':
        print('Dataset ', dataset)
        pth = '/home/wn455752/repo/evaluation/features/faces'
        # load pretrained VGGFace
        if arch == 'cnn':
            print('Loading pretrained VGGFace...')
            path_to_pretrained_weights = '/home/wn455752/repo/evaluation/pretrained/vggface_pretrained/VGG_FACE.t7'
            model = VGG_16().to(device)
            model.load_weights(path=path_to_pretrained_weights)
            transform = transforms.Compose([transforms.ToTensor(),                  # transform PIL.Image to torch.Tensor
                                            transforms.Resize((224,224)),           # resize to VGG input shape
                                            transforms.Lambda(lambda x: x * 255)])  # scale values to VGG input range
            with torch.no_grad():
                model.eval()
            # check for saved dataset features
            print('Checking for existing training dataset features...')
            name_pth = Path(os.path.join(pth, 'vggface_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(pth, 'vggface_features/real_image_features.pt'))
            if feature_pth.is_file():
                print('Loading existing training dataset features...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True
        # load CLIP
        elif arch == 'clip':
            print('Loading pretrained CLIP...')
            model, transform = clip.load("ViT-B/32", device=device)
            # check for saved dataset features
            print('Checking for existing training dataset features...')
            name_pth = Path(os.path.join(pth, 'clip_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(pth, 'clip_features/real_image_features.pt'))
            if feature_pth.is_file():
                print('Loading existing training dataset features...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True
@@ -186,7 +189,7 @@ if __name__ == '__main__':
    knn = kNN()
    # get images
    if not feature_flag:
        print('Collecting training images...')
        real_names, real_tensor = knn.get_images(path_to_training_images, transform, size=size)
        with open(name_pth, 'wb') as fp:
            pickle.dump(real_names, fp)
@@ -195,7 +198,7 @@ if __name__ == '__main__':
    # extract features
    if not feature_flag:
        print('Extracting features from training images...')
        real_features = knn.feature_extractor(real_tensor, model, device)
        torch.save(real_features, feature_pth)
    print('Extracting features from generated images...')
@@ -206,7 +209,6 @@ if __name__ == '__main__':
    else:
        sample_size = int(sample)

    if mode == 'kNN':
        print('Finding kNNs...')
        knn.kNN(real_names, generated_names,
@@ -214,11 +216,14 @@ if __name__ == '__main__':
                path_to_training_images, path_to_generated_images,
                k=k_kNN,
                sample=sample_size,
                size=size,
                name_appendix=name_appendix)
    elif mode == 'pairs':
        print('Finding closest pairs...')
        knn.nearest_neighbor(real_names, generated_names,
                             real_features, generated_features,
                             path_to_training_images, path_to_generated_images,
                             sample=sample_size,
                             size=size,
                             name_appendix=name_appendix)
    print('Finish!')
README.md

@@ -3,6 +3,7 @@
We conduct two types of evaluation - qualitative and quantitative.

### Quantitative evaluations -
<pre>
Quantitative evaluations are carried out to compare different backbone architectures of our unconditional diffusion model.
A set of 10,000 generated samples from each model variant is compared with the test set of the real dataset.
These evaluations include - (a code sketch follows the list)
@@ -10,36 +11,42 @@ These evaluations include -
2. Inception score
3. Clean FID score (with CLIP)
4. FID infinity and IS infinity scores
</pre>
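A minimal sketch of how such scores can be computed with the third-party clean-fid package - an assumption for illustration only, since the repository's own metrics.py is what actually runs; the directory paths are placeholders:

<pre>
# Hypothetical sketch, not the repository's metrics.py (pip install clean-fid)
from cleanfid import fid

real_dir = 'data/lhq/test'   # placeholder: test split of the real dataset
gen_dir = 'samples/run1'     # placeholder: directory of generated samples

score_fid = fid.compute_fid(real_dir, gen_dir)                  # standard FID
score_clip = fid.compute_fid(real_dir, gen_dir, mode="clean",
                             model_name="clip_vit_b_32")        # Clean FID with CLIP features
print(f'FID: {score_fid:.2f}, CLIP FID: {score_clip:.2f}')
</pre>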
### Qualitative evaluations -
<pre>
The aim of this set of evaluations is to qualitatively inspect whether our model has overfit to the training images.
For this, the entire set of 10,000 generated samples from the best performing model from quantitative evaluation is
compared with the training set of the real dataset.
Additionally, the quality check is also done on a hand-selected subset of best generations.

The comparison is implemented as MSE values between features of the generated and training samples. The features are
extracted by using a pretrained model (ResNet50-Places365/VGGFace or CLIP). Based on the MSE scores we compute -
(a distance-computation sketch follows the list)
1. kNN - plot the k nearest neighbors of the generated samples
2. Closest pairs - plot the top pairs with smallest MSE value
</pre>
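Both plots reduce to the same primitive: a distance from each generated feature vector to all training feature vectors, followed by a top-k lookup. A minimal sketch with illustrative tensor shapes (the actual logic lives in kNN.py):

<pre>
import torch

# Illustrative shapes: 10,000 training features and one generated feature, dimension 512.
real_features = torch.randn(10_000, 512)
gen_feature = torch.randn(512)

dist = torch.norm(real_features - gen_feature, dim=1, p=2)  # L2 distance to every training image
knn = dist.topk(3, largest=False)                           # indices/distances of the 3 nearest neighbors
print(knn.indices, knn.values)
</pre>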
### Arguments -
<pre>
Execution starts with the evaluate_full.py file. Input arguments are listed below; an example invocation follows the list.
</pre>

* <pre>-rp, --realpath : Path to real images (string) </pre>
* <pre>-gp, --genpath : Path to generated images (string) </pre>
* <pre>-d, --data : Choose between 'lhq' (for LHQ landscape dataset) and 'faces' (for CelebAHQ faces dataset).
             Default = 'lhq' (string) </pre>
* <pre>--size : Resolution of images the model was trained on, default 128 (int) </pre>
* <pre>-a, --arch : Choose between 'cnn' and 'clip'. The chosen pretrained model is used to extract features from the
             images. If 'cnn' is selected, the model is a ResNet50 pretrained on Places365 for the LHQ dataset and a
             pretrained VGGFace for the CelebAHQ dataset. Not relevant to the FID and IS computations. Default = 'clip' (string) </pre>
* <pre>-m, --mode : Choose between 'kNN' and 'pairs' (for closest pairs), default = 'kNN' (string) </pre>
* <pre>-k, --k : k value for kNN, default = 3 (int) </pre>
* <pre>-s, --sample : Choose between an int and 'all'. If mode is 'kNN', plot kNN for this many samples (the first s
             samples in the directory of generated images). If mode is 'pairs', plot the top s closest pairs from the
             entire directory of generated images. Default 10 (int or 'all') </pre>
* <pre>-n, --name : Name appendix (string) </pre>
* <pre>--fid : Choose between 'yes' and 'no'. Compute FID, Inception score and upgraded FID scores. Default 'no' (string) </pre>
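A typical invocation might look like this (directory paths are placeholders):

<pre>
python evaluate_full.py -rp data/lhq -gp samples/run1 -d lhq --size 128 -a clip -m kNN -k 3 -s 10 -n demo --fid yes
</pre>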
<pre>
Path to real images leads to a directory with two sub-directories - train and test.

data
@@ -50,8 +57,9 @@ data
| |_ train
| |_ test

CLIP and CNN (ResNet50 or VGGFace) features of training images are saved after the first execution. This alleviates
the need to recompute features of real images for different sets of generated samples.
</pre>
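The caching described above follows a simple check-then-compute pattern; a sketch with an illustrative cache path, where compute_features is a hypothetical stand-in for the actual CLIP/CNN extractor:

<pre>
import torch
from pathlib import Path

def compute_features() -> torch.Tensor:
    # hypothetical stand-in for the expensive CLIP/CNN feature extraction
    return torch.randn(10_000, 512)

feature_pth = Path('features/lhq/clip_features/real_image_features.pt')  # illustrative location
if feature_pth.is_file():
    real_features = torch.load(feature_pth, map_location='cpu')  # fast path: reuse cached features
else:
    feature_pth.parent.mkdir(parents=True, exist_ok=True)
    real_features = compute_features()        # slow path: only on the first execution
    torch.save(real_features, feature_pth)    # cache for subsequent runs
</pre>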
### Links
1. ResNet50 pretrained on Places365 - https://github.com/CSAILVision/places365
kNN.py
import os
from pathlib import Path
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
@@ -12,7 +13,7 @@ class kNN():
    def __init__(self):
        pass

    def get_images(self, path, transform, size=128, *args, **kwargs):
        '''
        returns
        names: list of filenames
@@ -30,8 +31,8 @@ class kNN():
            filepath = os.path.join(path, file)
            names.append(file)
            im = Image.open(filepath)
            if im.size[0] != size:
                im = im.resize((size, size))  # match the resolution the DDPM was trained on
            im = transform(im)
            images_list.append(im)
@@ -68,12 +69,14 @@ class kNN():
            real_features, generated_features,
            path_to_real_images, path_to_generated_images,
            k=3,
            sample=10, size=128,
            name_appendix='',
            *args, **kwargs):
        '''
        creates a plot with (generated image: k nearest real images) pairs
        '''
        if sample > 50:
            print('Cannot plot for more than 50 samples! sample <= 50')
            sample = 50  # clamp so the figure stays readable

        fig, ax = plt.subplots(sample, k+1, figsize=((k+1)*3, sample*2))
        for i in range(len(generated_features)):
@@ -94,6 +97,8 @@ class kNN():
                # draw the k real images
                for idx in knn.indices:
                    im = Image.open(os.path.join(path_to_real_images, real_names[idx.item()]))
                    if im.size[0] != size:
                        im = im.resize((size, size))
                    ax[i, j].imshow(im)
                    ax[i, j].set_xticks([])
                    ax[i, j].set_yticks([])
@@ -103,27 +108,32 @@ class kNN():
                break

        # savefig
        output_path = Path(os.path.join(os.getcwd(), 'output'))
        if not output_path.is_dir():
            os.mkdir(output_path)
        plot_name = f'{k}NN_{sample}_samples'
        if name_appendix != '':
            plot_name = plot_name + '_' + name_appendix
        fig.savefig(os.path.join(output_path, plot_name + '.png'))
    def nearest_neighbor(self, real_names, generated_names,
                         real_features, generated_features,
                         path_to_real_images, path_to_generated_images,
                         sample=10, size=128,
                         name_appendix='',
                         *args, **kwargs):
        print('Computing nearest neighbors...')
        if sample > 50:
            print('Cannot plot for more than 50 samples! sample <= 50')
            sample = 50  # clamp so the figure stays readable

        fig, ax = plt.subplots(sample, 2, figsize=(2*3, sample*2))
        nn_dict = OrderedDict()
        for i in range(len(generated_features)):
            # l2 norm of one generated feature and all real features
            #dist = torch.linalg.vector_norm(real_features - generated_features[i], ord=2, dim=1)  # no mps support
            dist = torch.norm(real_features - generated_features[i], dim=1, p=2)  # soon to be deprecated in favor of torch.linalg.vector_norm
            # nearest neighbor of the generated image
            knn = dist.topk(1, largest=False)
            # insert into the dict: generated_image -> (distance, index of the nearest neighbor)
@@ -145,13 +155,19 @@ class kNN():
            # draw the real image
            knn_score, real_img_idx = nn_dict_sorted[gen_names[i]]
            im = Image.open(os.path.join(path_to_real_images, real_names[real_img_idx]))
            if im.size[0] != size:
                im = im.resize((size, size))
            ax[i, 1].imshow(im)
            ax[i, 1].set_xticks([])
            ax[i, 1].set_yticks([])
            ax[i, 1].set_title(f'{real_names[real_img_idx][:-4]}, {knn_score:.2f}', fontsize=8)

        # savefig
        output_path = Path(os.path.join(os.getcwd(), 'output'))
        if not output_path.is_dir():
            os.mkdir(output_path)
        plot_name = f'closest_pairs_top_{sample}'
        if name_appendix != '':
            plot_name = plot_name + '_' + name_appendix
        fig.savefig(os.path.join(output_path, plot_name + '.png'))
\ No newline at end of file