Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Unconditional Diffusion
Manage
Activity
Members
Code
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Locked files
Deploy
Model registry
Analyze
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Diffusion Project
Unconditional Diffusion
Commits
984a5450
Commit
984a5450
authored
1 year ago
by
Srijeet Roy
Browse files
Options
Downloads
Patches
Plain Diff
update flexibility with diff input shape, output directories, readme
parent
f82fe3e1
No related branches found
No related tags found
No related merge requests found
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
evaluation/eval_full/evaluate_full.py
+34
-29
34 additions, 29 deletions
evaluation/eval_full/evaluate_full.py
evaluation/eval_full/evaluation_readme.md
+48
-40
48 additions, 40 deletions
evaluation/eval_full/evaluation_readme.md
evaluation/eval_full/kNN.py
+28
-12
28 additions, 12 deletions
evaluation/eval_full/kNN.py
with
110 additions
and
81 deletions
evaluation/eval_full/evaluate_full.py
+
34
−
29
View file @
984a5450
...
@@ -10,6 +10,7 @@ from torchvision.models import resnet50
...
@@ -10,6 +10,7 @@ from torchvision.models import resnet50
from
kNN
import
*
from
kNN
import
*
from
metrics
import
*
from
metrics
import
*
if
__name__
==
'
__main__
'
:
if
__name__
==
'
__main__
'
:
#device = "mps" if torch.backends.mps.is_available() else "cpu"
#device = "mps" if torch.backends.mps.is_available() else "cpu"
...
@@ -23,8 +24,10 @@ if __name__ == '__main__':
...
@@ -23,8 +24,10 @@ if __name__ == '__main__':
help
=
'
path to generated images
'
,
type
=
str
)
help
=
'
path to generated images
'
,
type
=
str
)
parser
.
add_argument
(
'
-d
'
,
'
--data
'
,
nargs
=
'
?
'
,
const
=
'
lhq
'
,
default
=
'
lhq
'
,
parser
.
add_argument
(
'
-d
'
,
'
--data
'
,
nargs
=
'
?
'
,
const
=
'
lhq
'
,
default
=
'
lhq
'
,
help
=
'
choose between
"
lhq
"
and
"
face
"
dataset
'
,
type
=
str
)
help
=
'
choose between
"
lhq
"
and
"
face
"
dataset
'
,
type
=
str
)
parser
.
add_argument
(
'
-a
'
,
'
--arch
'
,
nargs
=
'
?
'
,
const
=
'
cnn
'
,
default
=
'
cnn
'
,
parser
.
add_argument
(
'
--size
'
,
nargs
=
'
?
'
,
const
=
128
,
default
=
128
,
help
=
'
choose between
"
clip
"
and
"
cnn
"
, default
"
cnn
"'
,
type
=
str
)
help
=
'
resolution of image the model was trained on, default 128 (int)
'
,
type
=
int
)
parser
.
add_argument
(
'
-a
'
,
'
--arch
'
,
nargs
=
'
?
'
,
const
=
'
clip
'
,
default
=
'
clip
'
,
help
=
'
choose between
"
clip
"
and
"
cnn
"
, default
"
clip
"'
,
type
=
str
)
parser
.
add_argument
(
'
-m
'
,
'
--mode
'
,
nargs
=
'
?
'
,
const
=
'
kNN
'
,
default
=
'
kNN
'
,
parser
.
add_argument
(
'
-m
'
,
'
--mode
'
,
nargs
=
'
?
'
,
const
=
'
kNN
'
,
default
=
'
kNN
'
,
help
=
'
choose between
"
kNN
"
and
"
pairs
"
for closest_pairs, default
"
kNN
"'
,
type
=
str
)
help
=
'
choose between
"
kNN
"
and
"
pairs
"
for closest_pairs, default
"
kNN
"'
,
type
=
str
)
parser
.
add_argument
(
'
-k
'
,
'
--k
'
,
nargs
=
'
?
'
,
const
=
3
,
default
=
3
,
parser
.
add_argument
(
'
-k
'
,
'
--k
'
,
nargs
=
'
?
'
,
const
=
3
,
default
=
3
,
...
@@ -46,22 +49,25 @@ if __name__ == '__main__':
...
@@ -46,22 +49,25 @@ if __name__ == '__main__':
sample
=
args
[
'
sample
'
]
sample
=
args
[
'
sample
'
]
name_appendix
=
args
[
'
name
'
]
name_appendix
=
args
[
'
name
'
]
fid_bool
=
args
[
'
fid
'
]
fid_bool
=
args
[
'
fid
'
]
size
=
args
[
'
size
'
]
print
(
'
Start
'
)
print
(
'
Start
'
)
output_path
=
Path
(
os
.
path
.
join
(
os
.
getcwd
(),
'
output
'
))
if
not
output_path
.
is_dir
():
os
.
mkdir
(
output_path
)
txt_filename
=
'
output/evaluation_
'
+
dataset
+
'
_
'
+
arch
+
'
_
'
+
mode
+
'
-
'
+
name_appendix
+
'
.txt
'
txt_filename
=
'
output/evaluation_
'
+
dataset
+
'
_
'
+
arch
+
'
_
'
+
mode
+
'
-
'
+
name_appendix
+
'
.txt
'
with
open
(
txt_filename
,
'
w
'
)
as
f
:
with
open
(
txt_filename
,
'
w
'
)
as
f
:
f
.
write
(
f
'
Path to real images:
{
path_to_real_images
}
\n
'
)
f
.
write
(
f
'
Path to real images:
{
path_to_real_images
}
\n
'
)
f
.
write
(
f
'
Path to generated images:
{
path_to_generated_images
}
\n
'
)
f
.
write
(
f
'
Path to generated images:
{
path_to_generated_images
}
\n
'
)
f
.
write
(
f
'
Experiment on
{
dataset
}
dataset
\n
'
)
f
.
write
(
f
'
Experiment on
{
dataset
}
dataset
with images of resolution
{
size
}
x
{
size
}
\n
'
)
f
.
write
(
f
'
Using
{
arch
}
model to extract features
\n
'
)
f
.
write
(
f
'
Using
{
arch
}
model to extract features
\n
'
)
f
.
write
(
f
'
Plot of
{
mode
}
on
{
sample
}
samples
\n
'
)
f
.
write
(
f
'
Plot of
{
mode
}
on
{
sample
}
samples
\n
'
)
f
.
write
(
f
'
Quantitative metrics computed:
{
fid_bool
}
\n
'
)
# load data
# load data
path_to_training_images
=
os
.
path
.
join
(
path_to_real_images
,
'
train
'
)
path_to_training_images
=
os
.
path
.
join
(
path_to_real_images
,
'
train
'
)
path_to_test_images
=
os
.
path
.
join
(
path_to_real_images
,
'
test
'
)
path_to_test_images
=
os
.
path
.
join
(
path_to_real_images
,
'
test
'
)
if
fid_bool
==
'
yes
'
:
if
fid_bool
==
'
yes
'
:
# load data
#path_to_training_images = os.path.join(path_to_real_images, 'train')
#path_to_test_images = os.path.join(path_to_real_images, 'test')
# metrics eval
# metrics eval
eval_images
=
image_to_tensor
(
path_to_test_images
)
eval_images
=
image_to_tensor
(
path_to_test_images
)
...
@@ -92,23 +98,19 @@ if __name__ == '__main__':
...
@@ -92,23 +98,19 @@ if __name__ == '__main__':
# kNN-based eval
# kNN-based eval
if
dataset
==
'
lhq
'
:
if
dataset
==
'
lhq
'
:
print
(
'
Dataset
'
,
dataset
)
print
(
'
Dataset
'
,
dataset
)
#pth = '/Users/roy/Desktop/Workspace/RWTH/SoSe 2023/Deep Learning Lab/DLL_vsc/data/features/lhq_features'
pth
=
'
/home/wn455752/repo/evaluation/features/lhq
'
pth
=
'
/home/wn455752/repo/evaluation/features/lhq
'
# load pretrained ResNet50
# load pretrained ResNet50
if
arch
==
'
cnn
'
:
if
arch
==
'
cnn
'
:
#path_to_pretrained_weights = '/Users/roy/Desktop/Workspace/RWTH/SoSe 2023/Deep Learning Lab/DLL_vsc/pretrained/resnet50_places365_pretrained/resnet50_places365_weights.pth'
print
(
'
Loading pretrained ResNet50...
'
)
print
(
'
loading model...
'
)
path_to_pretrained_weights
=
'
/home/wn455752/repo/evaluation/pretrained/resnet50_places365_pretrained/resnet50_places365_weights.pth
'
path_to_pretrained_weights
=
'
/home/wn455752/repo/evaluation/pretrained/resnet50_places365_pretrained/resnet50_places365_weights.pth
'
print
(
'
loading weights...
'
)
weights
=
torch
.
load
(
path_to_pretrained_weights
)
weights
=
torch
.
load
(
path_to_pretrained_weights
)
model
=
resnet50
().
to
(
device
)
model
=
resnet50
().
to
(
device
)
print
(
'
initializing model with pretrained weights
'
)
model
.
load_state_dict
(
weights
)
model
.
load_state_dict
(
weights
)
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
# transform PIL.Image to torch.Tensor
transforms
.
Lambda
(
lambda
x
:
x
*
255
)])
transforms
.
Lambda
(
lambda
x
:
x
*
255
)])
# scale values to VGG input range
with
torch
.
no_grad
():
with
torch
.
no_grad
():
model
.
eval
()
model
.
eval
()
print
(
'
c
hecking for
saved
dataset features
'
)
print
(
'
C
hecking for
existing training
dataset features
...
'
)
# check for saved dataset features
# check for saved dataset features
name_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
resnet_features/real_name_list
'
))
name_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
resnet_features/real_name_list
'
))
if
name_pth
.
is_file
():
if
name_pth
.
is_file
():
...
@@ -116,22 +118,23 @@ if __name__ == '__main__':
...
@@ -116,22 +118,23 @@ if __name__ == '__main__':
real_names
=
pickle
.
load
(
fp
)
real_names
=
pickle
.
load
(
fp
)
feature_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
resnet_features/real_image_features.pt
'
))
feature_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
resnet_features/real_image_features.pt
'
))
if
name_pth
.
is_file
():
if
name_pth
.
is_file
():
print
(
'
Loading
ResNet features of real imag
es...
'
)
print
(
'
Loading
existing training dataset featur
es...
'
)
real_features
=
torch
.
load
(
feature_pth
,
map_location
=
"
cpu
"
)
real_features
=
torch
.
load
(
feature_pth
,
map_location
=
"
cpu
"
)
real_features
=
real_features
.
to
(
device
)
real_features
=
real_features
.
to
(
device
)
feature_flag
=
True
feature_flag
=
True
# load CLIP
# load CLIP
elif
arch
==
'
clip
'
:
elif
arch
==
'
clip
'
:
print
(
'
l
oading
model
...
'
)
print
(
'
L
oading
pretrained CLIP
...
'
)
model
,
transform
=
clip
.
load
(
"
ViT-B/32
"
,
device
=
device
)
model
,
transform
=
clip
.
load
(
"
ViT-B/32
"
,
device
=
device
)
# check for saved dataset features
# check for saved dataset features
print
(
'
Checking for existing training dataset features...
'
)
name_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
clip_features/real_name_list
'
))
name_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
clip_features/real_name_list
'
))
if
name_pth
.
is_file
():
if
name_pth
.
is_file
():
with
open
(
name_pth
,
'
rb
'
)
as
fp
:
with
open
(
name_pth
,
'
rb
'
)
as
fp
:
real_names
=
pickle
.
load
(
fp
)
real_names
=
pickle
.
load
(
fp
)
feature_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
clip_features/real_image_features.pt
'
))
feature_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
clip_features/real_image_features.pt
'
))
if
name_pth
.
is_file
():
if
name_pth
.
is_file
():
print
(
'
Loading
CLIP features of real imag
es...
'
)
print
(
'
Loading
existing training dataset featur
es...
'
)
real_features
=
torch
.
load
(
feature_pth
,
map_location
=
"
cpu
"
)
real_features
=
torch
.
load
(
feature_pth
,
map_location
=
"
cpu
"
)
real_features
=
real_features
.
to
(
device
)
real_features
=
real_features
.
to
(
device
)
feature_flag
=
True
feature_flag
=
True
...
@@ -140,45 +143,45 @@ if __name__ == '__main__':
...
@@ -140,45 +143,45 @@ if __name__ == '__main__':
elif
dataset
==
'
faces
'
:
elif
dataset
==
'
faces
'
:
print
(
'
Dataset
'
,
dataset
)
print
(
'
Dataset
'
,
dataset
)
#pth = '/Users/roy/Desktop/Workspace/RWTH/SoSe 2023/Deep Learning Lab/DLL_vsc/data/features/face_features'
pth
=
'
/home/wn455752/repo/evaluation/features/faces
'
pth
=
'
/home/wn455752/repo/evaluation/features/faces
'
# load pretrained VGGFace
# load pretrained VGGFace
if
arch
==
'
cnn
'
:
if
arch
==
'
cnn
'
:
print
(
'
loading model...
'
)
print
(
'
Loading pretrained VGGFace...
'
)
#path_to_pretrained_weights = '/Users/roy/Desktop/Workspace/RWTH/SoSe 2023/Deep Learning Lab/DLL_vsc/pretrained/vggface_pretrained/VGG_FACE.t7'
path_to_pretrained_weights
=
'
/home/wn455752/repo/evaluation/pretrained/vggface_pretrained/VGG_FACE.t7
'
path_to_pretrained_weights
=
'
/home/wn455752/repo/evaluation/pretrained/vggface_pretrained/VGG_FACE.t7
'
model
=
VGG_16
().
to
(
device
)
model
=
VGG_16
().
to
(
device
)
model
.
load_weights
(
path
=
path_to_pretrained_weights
)
model
.
load_weights
(
path
=
path_to_pretrained_weights
)
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
# transform PIL.Image to torch.Tensor
transforms
.
Resize
((
224
,
224
)),
transforms
.
Resize
((
224
,
224
)),
# resize to VGG input shape
transforms
.
Lambda
(
lambda
x
:
x
*
255
)])
transforms
.
Lambda
(
lambda
x
:
x
*
255
)])
# scale values to VGG input range
with
torch
.
no_grad
():
with
torch
.
no_grad
():
model
.
eval
()
model
.
eval
()
# check for saved dataset features
# check for saved dataset features
print
(
'
Checking for existing training dataset features...
'
)
name_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
vggface_features/real_name_list
'
))
name_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
vggface_features/real_name_list
'
))
if
name_pth
.
is_file
():
if
name_pth
.
is_file
():
with
open
(
name_pth
,
'
rb
'
)
as
fp
:
with
open
(
name_pth
,
'
rb
'
)
as
fp
:
real_names
=
pickle
.
load
(
fp
)
real_names
=
pickle
.
load
(
fp
)
feature_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
vggface_features/real_image_features.pt
'
))
feature_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
vggface_features/real_image_features.pt
'
))
if
name_pth
.
is_file
():
if
name_pth
.
is_file
():
print
(
'
Loading
VGGFace features of real imag
es...
'
)
print
(
'
Loading
existing training dataset featur
es...
'
)
real_features
=
torch
.
load
(
feature_pth
,
map_location
=
"
cpu
"
)
real_features
=
torch
.
load
(
feature_pth
,
map_location
=
"
cpu
"
)
real_features
=
real_features
.
to
(
device
)
real_features
=
real_features
.
to
(
device
)
feature_flag
=
True
feature_flag
=
True
# load CLIP
# load CLIP
elif
arch
==
'
clip
'
:
elif
arch
==
'
clip
'
:
print
(
'
l
oading
model
...
'
)
print
(
'
L
oading
pretrained CLIP
...
'
)
model
,
transform
=
clip
.
load
(
"
ViT-B/32
"
,
device
=
device
)
model
,
transform
=
clip
.
load
(
"
ViT-B/32
"
,
device
=
device
)
# check for saved dataset features
# check for saved dataset features
print
(
'
Checking for existing training dataset features...
'
)
name_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
clip_features/real_name_list
'
))
name_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
clip_features/real_name_list
'
))
if
name_pth
.
is_file
():
if
name_pth
.
is_file
():
with
open
(
name_pth
,
'
rb
'
)
as
fp
:
with
open
(
name_pth
,
'
rb
'
)
as
fp
:
real_names
=
pickle
.
load
(
fp
)
real_names
=
pickle
.
load
(
fp
)
feature_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
clip_features/real_image_features.pt
'
))
feature_pth
=
Path
(
os
.
path
.
join
(
pth
,
'
clip_features/real_image_features.pt
'
))
if
name_pth
.
is_file
():
if
name_pth
.
is_file
():
print
(
'
Loading
CLIP features of real imag
es...
'
)
print
(
'
Loading
existing training dataset featur
es...
'
)
real_features
=
torch
.
load
(
feature_pth
,
map_location
=
"
cpu
"
)
real_features
=
torch
.
load
(
feature_pth
,
map_location
=
"
cpu
"
)
real_features
=
real_features
.
to
(
device
)
real_features
=
real_features
.
to
(
device
)
feature_flag
=
True
feature_flag
=
True
...
@@ -186,7 +189,7 @@ if __name__ == '__main__':
...
@@ -186,7 +189,7 @@ if __name__ == '__main__':
knn
=
kNN
()
knn
=
kNN
()
# get images
# get images
if
not
feature_flag
:
if
not
feature_flag
:
print
(
'
Collecting
real
images...
'
)
print
(
'
Collecting
training
images...
'
)
real_names
,
real_tensor
=
knn
.
get_images
(
path_to_training_images
,
transform
)
real_names
,
real_tensor
=
knn
.
get_images
(
path_to_training_images
,
transform
)
with
open
(
name_pth
,
'
wb
'
)
as
fp
:
with
open
(
name_pth
,
'
wb
'
)
as
fp
:
pickle
.
dump
(
real_names
,
fp
)
pickle
.
dump
(
real_names
,
fp
)
...
@@ -195,7 +198,7 @@ if __name__ == '__main__':
...
@@ -195,7 +198,7 @@ if __name__ == '__main__':
# extract features
# extract features
if
not
feature_flag
:
if
not
feature_flag
:
print
(
'
Extracting features from
real
images...
'
)
print
(
'
Extracting features from
training
images...
'
)
real_features
=
knn
.
feature_extractor
(
real_tensor
,
model
,
device
)
real_features
=
knn
.
feature_extractor
(
real_tensor
,
model
,
device
)
torch
.
save
(
real_features
,
feature_pth
)
torch
.
save
(
real_features
,
feature_pth
)
print
(
'
Extracting features from generated images...
'
)
print
(
'
Extracting features from generated images...
'
)
...
@@ -206,7 +209,6 @@ if __name__ == '__main__':
...
@@ -206,7 +209,6 @@ if __name__ == '__main__':
else
:
else
:
sample_size
=
int
(
sample
)
sample_size
=
int
(
sample
)
if
mode
==
'
kNN
'
:
if
mode
==
'
kNN
'
:
print
(
'
Finding kNNs...
'
)
print
(
'
Finding kNNs...
'
)
knn
.
kNN
(
real_names
,
generated_names
,
knn
.
kNN
(
real_names
,
generated_names
,
...
@@ -214,11 +216,14 @@ if __name__ == '__main__':
...
@@ -214,11 +216,14 @@ if __name__ == '__main__':
path_to_training_images
,
path_to_generated_images
,
path_to_training_images
,
path_to_generated_images
,
k
=
k_kNN
,
k
=
k_kNN
,
sample
=
sample_size
,
sample
=
sample_size
,
size
=
size
,
name_appendix
=
name_appendix
)
name_appendix
=
name_appendix
)
elif
mode
==
'
pairs
'
:
elif
mode
==
'
pairs
'
:
print
(
'
Finding closest pairs...
'
)
knn
.
nearest_neighbor
(
real_names
,
generated_names
,
knn
.
nearest_neighbor
(
real_names
,
generated_names
,
real_features
,
generated_features
,
real_features
,
generated_features
,
path_to_training_images
,
path_to_generated_images
,
path_to_training_images
,
path_to_generated_images
,
sample
=
sample_size
,
sample
=
sample_size
,
size
=
size
,
name_appendix
=
name_appendix
)
name_appendix
=
name_appendix
)
print
(
'
Finish!
'
)
print
(
'
Finish!
'
)
This diff is collapsed.
Click to expand it.
evaluation/eval_full/evaluation_readme.md
+
48
−
40
View file @
984a5450
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
We conduct two types of evaluation - qualitative and quantitative.
We conduct two types of evaluation - qualitative and quantitative.
### Quantitative evaluations -
### Quantitative evaluations -
<pre>
Quantitative evaluations are carried out to compare different backbone architectures of our unconditional diffusion model.
Quantitative evaluations are carried out to compare different backbone architectures of our unconditional diffusion model.
A set of 10,000 generated samples from each model variant is compared with the test set of the real dataset.
A set of 10,000 generated samples from each model variant is compared with the test set of the real dataset.
These evaluations include -
These evaluations include -
...
@@ -10,36 +11,42 @@ These evaluations include -
...
@@ -10,36 +11,42 @@ These evaluations include -
2.
Inception score
2.
Inception score
3.
Clean FID score (with CLIP)
3.
Clean FID score (with CLIP)
4.
FID infinity and IS infinity scores
4.
FID infinity and IS infinity scores
</pre>
### Qualitative evaluations -
### Qualitative evaluations -
The aim of this set of evaluations is to qualitatively inspect whether our model has overfit to the training images. For this,
<pre>
the entire set of 10,000 generated samples from the best performing model from quanititative evaluation is compared with the
The aim of this set of evaluations is to qualitatively inspect whether our model has overfit to the training images.
training set of the real dataset. Additionally, the quality check is also done on a hand-selected subset of best generations.
For this, the entire set of 10,000 generated samples from the best performing model from quanititative evaluation is
compared with the training set of the real dataset.
Additionally, the quality check is also done on a hand-selected subset of best generations.
The comparison is implemented as MSE values between features of the generated and training samples. The features are
extracted
The comparison is implemented as MSE values between features of the generated and training samples. The features are
by using a pretrained model (ResNet50-Places365/VGGFace or CLIP). Based on the MSE scores we compute -
extracted
by using a pretrained model (ResNet50-Places365/VGGFace or CLIP). Based on the MSE scores we compute -
1.
kNN - plot the k nearest neighbors of the generated samples
1.
kNN - plot the k nearest neighbors of the generated samples
2.
Closest pairs - plot the top pairs with smallest MSE value
2.
Closest pairs - plot the top pairs with smallest MSE value
</pre>
### Argumnets -
<pre>
Execution starts with evaluate_full.py file. Input arguments are -
Execution starts with evaluate_full.py file. Input arguments are -
</pre>
*
-rp, --realpath : Path to real images (string)
*
<pre>
-rp, --realpath : Path to real images (string)
</pre>
*
-gp, --genpath : Path to generated images (string)
*
<pre>
-gp, --genpath : Path to generated images (string)
</pre>
*
-d, --data : Choose between 'lhq' (for LHQ landscape dataset) and 'faces' (for CelebAHQ faces dataset).
*
<pre>
-d, --data : Choose between 'lhq' (for LHQ landscape dataset) and 'faces' (for CelebAHQ faces dataset).
Default = 'lhq' (string)
Default = 'lhq' (string)
</pre>
*
-a, --arch : Choose between 'cnn' and 'clip'. Chosen pretrained model is used to extract features from the images.
*
<pre>
--size : Resolution of images the model was trained on, default 128 (int)
</pre>
*
<pre>
-a, --arch : Choose between 'cnn' and 'clip'. Chosen pretrained model is used to extract features from the images.
</pre>
If 'cnn' is selected, for LHQ dataset the model is a ResNet50 pretrained on Places365 dataset and for
If 'cnn' is selected, for LHQ dataset the model is a ResNet50 pretrained on Places365 dataset and for
CelebAHQ dataset the model is a pretrained VGGFace. Default = 'c
nn
' (string)
CelebAHQ dataset the model is a pretrained VGGFace.
Not relevant in computing FID, IS scores.
Default = 'c
lip
' (string)
</pre>
*
-m, --mode : Choose between 'kNN' and 'pairs' (for closest pairs), default = 'kNN' (string)
*
<pre>
-m, --mode : Choose between 'kNN' and 'pairs' (for closest pairs), default = 'kNN' (string)
</pre>
*
-k, --k : k value for kNN, default = 3 (int)
*
<pre>
-k, --k : k value for kNN, default = 3 (int)
</pre>
*
-s, --sample : Choose between an int and 'all'. If mode is 'kNN', plot kNN for this many samples (first s samples
*
<pre>
-s, --sample : Choose between an int and 'all'. If mode is 'kNN', plot kNN for this many samples (first s samples
in the directory of generated images). If mode is 'pairs', plot the top s closest pairs from entire
in the directory of generated images). If mode is 'pairs', plot the top s closest pairs from entire
directory of generated images. Default 10 (int or 'all')
directory of generated images. Default 10 (int or 'all')
</pre>
*
-n, --name : Name appendix (string)
*
<pre>
-n, --name : Name appendix (string)
</pre>
*
--fid : Choose between 'yes' and 'no'. Compute FID, Inception score and upgraded FID scores. Default 'no' (string)
*
<pre>
--fid : Choose between 'yes' and 'no'. Compute FID, Inception score and upgraded FID scores. Default 'no' (string)
</pre>
<pre>
Path to real images leads to a directory with two sub-directories - train and test.
Path to real images leads to a directory with two sub-directories - train and test.
data
data
...
@@ -50,8 +57,9 @@ data
...
@@ -50,8 +57,9 @@ data
| |_ train
| |_ train
| |_ test
| |_ test
CLIP and CNN (ResNet50 or VGGFace) features of training images are saved after the first execution. This alleviates the need
CLIP and CNN (ResNet50 or VGGFace) features of training images are saved after the first execution. This alleviates the need
\
to recompute features of real images for different sets of generated samples.
to recompute features of real images for different sets of generated samples.
</pre>
### Links
### Links
1.
ResNet50 pretrained on Places365 - https://github.com/CSAILVision/places365
1.
ResNet50 pretrained on Places365 - https://github.com/CSAILVision/places365
...
...
This diff is collapsed.
Click to expand it.
evaluation/eval_full/kNN.py
+
28
−
12
View file @
984a5450
import
os
import
os
from
pathlib
import
Path
import
torch
import
torch
import
torchvision.transforms
as
transforms
import
torchvision.transforms
as
transforms
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
...
@@ -12,7 +13,7 @@ class kNN():
...
@@ -12,7 +13,7 @@ class kNN():
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
def
get_images
(
self
,
path
,
transform
,
*
args
,
**
kwargs
):
def
get_images
(
self
,
path
,
transform
,
size
=
128
,
*
args
,
**
kwargs
):
'''
'''
returns
returns
names: list of filenames
names: list of filenames
...
@@ -30,8 +31,8 @@ class kNN():
...
@@ -30,8 +31,8 @@ class kNN():
filepath
=
os
.
path
.
join
(
path
,
file
)
filepath
=
os
.
path
.
join
(
path
,
file
)
names
.
append
(
file
)
names
.
append
(
file
)
im
=
Image
.
open
(
filepath
)
im
=
Image
.
open
(
filepath
)
if
im
.
size
[
0
]
!=
128
:
if
im
.
size
[
0
]
!=
size
:
im
=
im
.
resize
((
128
,
128
))
# DDPM was trained on 128x128 images
im
=
im
.
resize
((
size
,
size
))
# DDPM was trained on 128x128 images
im
=
transform
(
im
)
im
=
transform
(
im
)
images_list
.
append
(
im
)
images_list
.
append
(
im
)
...
@@ -68,12 +69,14 @@ class kNN():
...
@@ -68,12 +69,14 @@ class kNN():
real_features
,
generated_features
,
real_features
,
generated_features
,
path_to_real_images
,
path_to_generated_images
,
path_to_real_images
,
path_to_generated_images
,
k
=
3
,
k
=
3
,
sample
=
10
,
sample
=
10
,
size
=
128
,
name_appendix
=
''
,
name_appendix
=
''
,
*
args
,
**
kwargs
):
*
args
,
**
kwargs
):
'''
'''
creates a plot with (generated image: k nearest real images) pairs
creates a plot with (generated image: k nearest real images) pairs
'''
'''
if
sample
>
50
:
print
(
'
Cannot plot for more than 50 samples! sample <= 50
'
)
fig
,
ax
=
plt
.
subplots
(
sample
,
k
+
1
,
figsize
=
((
k
+
1
)
*
3
,
sample
*
2
))
fig
,
ax
=
plt
.
subplots
(
sample
,
k
+
1
,
figsize
=
((
k
+
1
)
*
3
,
sample
*
2
))
for
i
in
range
(
len
(
generated_features
)):
for
i
in
range
(
len
(
generated_features
)):
...
@@ -94,6 +97,8 @@ class kNN():
...
@@ -94,6 +97,8 @@ class kNN():
# draw the k real images
# draw the k real images
for
idx
in
knn
.
indices
:
for
idx
in
knn
.
indices
:
im
=
Image
.
open
(
os
.
path
.
join
(
path_to_real_images
,
real_names
[
idx
.
item
()]))
im
=
Image
.
open
(
os
.
path
.
join
(
path_to_real_images
,
real_names
[
idx
.
item
()]))
if
im
.
size
[
0
]
!=
size
:
im
=
im
.
resize
((
size
,
size
))
ax
[
i
,
j
].
imshow
(
im
)
ax
[
i
,
j
].
imshow
(
im
)
ax
[
i
,
j
].
set_xticks
([])
ax
[
i
,
j
].
set_xticks
([])
ax
[
i
,
j
].
set_yticks
([])
ax
[
i
,
j
].
set_yticks
([])
...
@@ -103,27 +108,32 @@ class kNN():
...
@@ -103,27 +108,32 @@ class kNN():
break
break
# savefig
# savefig
output_path
=
Path
(
os
.
path
.
join
(
os
.
getcwd
(),
'
output
'
))
if
not
output_path
.
is_dir
():
os
.
mkdir
(
output_path
)
plot_name
=
f
'
{
k
}
NN_
{
sample
}
_samples
'
plot_name
=
f
'
{
k
}
NN_
{
sample
}
_samples
'
if
name_appendix
!=
''
:
if
name_appendix
!=
''
:
plot_name
=
plot_name
+
name_appendix
plot_name
=
plot_name
+
'
_
'
+
name_appendix
+
'
.png
'
fig
.
savefig
(
'
output/
'
+
plot_name
+
'
.png
'
)
fig
.
savefig
(
os
.
path
.
join
(
output_path
,
plot_name
)
)
def
nearest_neighbor
(
self
,
real_names
,
generated_names
,
def
nearest_neighbor
(
self
,
real_names
,
generated_names
,
real_features
,
generated_features
,
real_features
,
generated_features
,
path_to_real_images
,
path_to_generated_images
,
path_to_real_images
,
path_to_generated_images
,
sample
=
10
,
sample
=
10
,
size
=
128
,
name_appendix
=
''
,
name_appendix
=
''
,
*
args
,
**
kwargs
):
*
args
,
**
kwargs
):
print
(
'
Computing nearest neighbors...
'
)
print
(
'
Computing nearest neighbors...
'
)
if
sample
>
50
:
print
(
'
Cannot plot for more than 50 samples! sample <= 50
'
)
fig
,
ax
=
plt
.
subplots
(
sample
,
2
,
figsize
=
(
2
*
3
,
sample
*
2
))
fig
,
ax
=
plt
.
subplots
(
sample
,
2
,
figsize
=
(
2
*
3
,
sample
*
2
))
nn_dict
=
OrderedDict
()
nn_dict
=
OrderedDict
()
for
i
in
range
(
len
(
generated_features
)):
for
i
in
range
(
len
(
generated_features
)):
# l2 norm of one generated feature and all real features
# l2 norm of one generated feature and all real features
#dist = torch.linalg.vector_norm(real_features - generated_features[i], ord=2, dim=1)
#dist = torch.linalg.vector_norm(real_features - generated_features[i], ord=2, dim=1) # no mps support
dist
=
torch
.
norm
(
real_features
-
generated_features
[
i
],
dim
=
1
,
p
=
2
)
dist
=
torch
.
norm
(
real_features
-
generated_features
[
i
],
dim
=
1
,
p
=
2
)
# soon to be deprecated
# nearest neighbor of the generated image
# nearest neighbor of the generated image
knn
=
dist
.
topk
(
1
,
largest
=
False
)
knn
=
dist
.
topk
(
1
,
largest
=
False
)
# insert to the dict: generated_image: (distance, index of the nearest neighbor)
# insert to the dict: generated_image: (distance, index of the nearest neighbor)
...
@@ -145,13 +155,19 @@ class kNN():
...
@@ -145,13 +155,19 @@ class kNN():
# draw the real image
# draw the real image
knn_score
,
real_img_idx
=
nn_dict_sorted
[
gen_names
[
i
]]
knn_score
,
real_img_idx
=
nn_dict_sorted
[
gen_names
[
i
]]
im
=
Image
.
open
(
os
.
path
.
join
(
path_to_real_images
,
real_names
[
real_img_idx
]))
im
=
Image
.
open
(
os
.
path
.
join
(
path_to_real_images
,
real_names
[
real_img_idx
]))
if
im
.
size
[
0
]
!=
size
:
im
=
im
.
resize
((
size
,
size
))
ax
[
i
,
1
].
imshow
(
im
)
ax
[
i
,
1
].
imshow
(
im
)
ax
[
i
,
1
].
set_xticks
([])
ax
[
i
,
1
].
set_xticks
([])
ax
[
i
,
1
].
set_yticks
([])
ax
[
i
,
1
].
set_yticks
([])
ax
[
i
,
1
].
set_title
(
f
'
{
real_names
[
real_img_idx
][
:
-
4
]
}
,
{
knn_score
:
.
2
f
}
'
,
fontsize
=
8
)
ax
[
i
,
1
].
set_title
(
f
'
{
real_names
[
real_img_idx
][
:
-
4
]
}
,
{
knn_score
:
.
2
f
}
'
,
fontsize
=
8
)
#savefig
#savefig
output_path
=
Path
(
os
.
path
.
join
(
os
.
getcwd
(),
'
output
'
))
if
not
output_path
.
is_dir
():
os
.
mkdir
(
output_path
)
plot_name
=
f
'
closest_pairs_top_
{
sample
}
'
plot_name
=
f
'
closest_pairs_top_
{
sample
}
'
if
name_appendix
!=
''
:
if
name_appendix
!=
''
:
plot_name
=
plot_name
+
name_appendix
plot_name
=
plot_name
+
'
_
'
+
name_appendix
+
'
.png
'
fig
.
savefig
(
'
output/
'
+
plot_name
+
'
.png
'
)
fig
.
savefig
(
os
.
path
.
join
(
output_path
,
plot_name
))
\ No newline at end of file
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment