Commit 056a9793 authored by Srijeet Roy's avatar Srijeet Roy
eval update: eval directory, main, experiment_creator, readme

parent 5b1ea326
# Unconditional Diffusion Model
This repo presents our implementation of the papers *Denoising Diffusion Probabilistic Models* by Ho et al. and *Improved Denoising Diffusion Probabilistic Models* by Nichol and Dhariwal. The pipeline covers training, sampling, and evaluation of an unconditional diffusion model on the HPC.
We demonstrate our results by training unconditional diffusion models on the landscape (LHQ) and celebrity face (CelebA-HQ) datasets, generating realistic images at a resolution of 128x128px.
## Motivation
## Background
Unconditional diffusion models are ...
@@ -30,18 +28,23 @@ To train the model, follow the steps:
### Model Sampling
1. Make sure that the checkpoint file is inside the **trained_ddpm** folder within the experiment folder. Alternatively, you can create this folder manually and add the checkpoint file.
2. Make sure that the correct checkpoint name is given in the JSON file ```settings/sampling_setting.json```; otherwise the sampling will be done with randomly initialized weights.
3. Within the repository folder, run ```python main.py sample <path to experiment folder>/settings```
### Model Evaluation
...
To evaluate the performance of the unconditional diffusion model:
1. Ensure that the settings in ```settings/evaluation_setting.json``` are set to the desired values.
2. Within the repository folder, run ```python main.py evaluate <path to experiment folder>/settings```
3. For a detailed overview of the evaluation metrics, refer to [evaluation_readme](evaluation/evaluation_readme.md).
## Description
This repository houses our comprehensive pipeline, designed to conveniently train, sample from, and evaluate our unconditional diffusion model.
The pipeline is initiated via the experiment_creator.ipynb notebook, which is run separately on our local machine. This notebook allows for the configuration of every aspect of the diffusion model, including all hyperparameters. These configurations extend to the underlying UNet backbone as well as the training parameters, such as training from a checkpoint, the Weights & Biases run name for resumption, optimizer selection, adjustment of the CosineAnnealingLR learning-rate schedule parameters, and more. Moreover, it includes parameters for sampling images from, and evaluating, a trained diffusion model.
Upon execution, the notebook generates individual JSON files encapsulating all the hyperparameter information. When running the model on the HPC, we can choose between the operations 'train', 'sample', and 'evaluate'. These operations automatically extract the necessary hyperparameters from the JSON files and perform their respective tasks, a process managed by the main.py file (sketched below). The remaining files contain all the necessary functions, optimized for the HPC, to perform the aforementioned tasks.
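For illustration, the dispatch reduces to the following self-contained sketch; the stub functions here only echo the loaded settings, while the real implementations live in main.py and the trainer, sampling, and evaluation modules:

```python
import sys
import json

# Stub operations: each one loads its hyperparameters from the JSON files in
# the experiment's settings folder, exactly like the real pipeline does.
def train_func(settings_folder):
    with open(settings_folder + "/training_setting.json", "r") as fp:
        print("training with:", json.load(fp))

def sample_func(settings_folder):
    with open(settings_folder + "/sampling_setting.json", "r") as fp:
        print("sampling with:", json.load(fp))

def evaluate_func(settings_folder):
    with open(settings_folder + "/evaluation_setting.json", "r") as fp:
        print("evaluating with:", json.load(fp))

if __name__ == '__main__':
    # e.g. `python main.py train <path to experiment folder>/settings`
    functions = {'train': train_func, 'sample': sample_func, 'evaluate': evaluate_func}
    functions[sys.argv[1]](sys.argv[2])
```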
......
import os
import torch
import argparse
from PIL import Image
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from itertools import cycle
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio

def image_to_tensor(path, sample=10, device='cpu'):
    transform_resize = transforms.Compose([transforms.ToTensor(), transforms.Resize(128), transforms.Lambda(lambda x: (x * 255))])
    transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (x * 255))])
    filelist = os.listdir(path)
    if sample == 'all':
        sample_size = -1
    else:
        sample_size = int(sample)  # argparse may hand us a string
    image_names = []
    image_list = []
    for file in filelist:
        if file.endswith('.png'):
            filepath = os.path.join(path, file)
            image_names.append(file)
            im = Image.open(filepath)
            if im.size[0] != 128:
                im = transform_resize(im)
            else:
                im = transform(im)
            image_list.append(im)
            if len(image_list) == sample_size:
                break
    print(f'current sample size: {len(image_names)}')
    # convert list of tensors to a single tensor
    image_tensor = torch.stack(image_list).to(device)
    return image_tensor

if __name__ == '__main__':
    device = 'mps' if torch.backends.mps.is_available() else 'cpu'
    #device = 'cuda' if torch.cuda.is_available() else 'cpu'

    parser = argparse.ArgumentParser()
    parser.add_argument('-rp', '--realpath', required=True, help='path to real images', type=str)
    parser.add_argument('-gp', '--genpath', required=True, help='path to generated images', type=str)
    parser.add_argument('-s', '--sample', nargs='?', const=10, default=10,
                        help='how many generated samples to compare, default 10, int or "all"')
    parser.add_argument('-n', '--name', nargs='?', const='', default='',
                        help='name appendix')
    args = vars(parser.parse_args())

    path_to_real_images = args['realpath']
    path_to_generated_images = args['genpath']
    sample = args['sample']
    name_appendix = args['name']

    real_image_tensor = image_to_tensor(path_to_real_images, sample, device)
    generated_image_tensor = image_to_tensor(path_to_generated_images, sample, device)

    real_dataloader = DataLoader(real_image_tensor, batch_size=128, shuffle=False)
    generated_dataloader = DataLoader(generated_image_tensor, batch_size=128, shuffle=False)

    ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    psnr = PeakSignalNoiseRatio().to(device)
    # cycle the (smaller) generated set so every real batch has a counterpart
    for r, g in zip(real_dataloader, cycle(generated_dataloader)):
        r = r.to(device)
        g = g.to(device)
        ssim.update(preds=g, target=r)
        psnr.update(preds=g, target=r)
    ssim_score = ssim.compute()
    psnr_score = psnr.compute()
    print(f'SSIM: {ssim_score:0.3f}')
    print(f'PSNR: {psnr_score:0.3f}')

    txtfile = 'content_invariant_metrics.txt'
    if name_appendix != '':
        txtfile = 'content_invariant_metrics_' + name_appendix + '.txt'
    with open(os.path.join(os.getcwd(), txtfile), 'w') as fp:
        fp.write(f'SSIM: {ssim_score:0.3f}\n')
        fp.write(f'PSNR: {psnr_score:0.3f}\n')
import os
import argparse
import pickle
from pathlib import Path
import clip
from vggface import *
from torchvision.models import resnet50
from kNN import *
from metrics import *
if __name__ == '__main__':
    #device = "mps" if torch.backends.mps.is_available() else "cpu"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print('device:', device)

    print('Parsing arguments...')
    parser = argparse.ArgumentParser()
    parser.add_argument('-rp', '--realpath', required=True,
                        help='path to real images', type=str)
    parser.add_argument('-gp', '--genpath', required=True,
                        help='path to generated images', type=str)
    parser.add_argument('-d', '--data', nargs='?', const='lhq', default='lhq',
                        help='choose between "lhq" and "faces" dataset', type=str)
    parser.add_argument('--size', nargs='?', const=128, default=128,
                        help='resolution of image the model was trained on, default 128 (int)', type=int)
    parser.add_argument('-a', '--arch', nargs='?', const='clip', default='clip',
                        help='choose between "clip" and "cnn", default "clip"', type=str)
    parser.add_argument('-m', '--mode', nargs='?', const='kNN', default='kNN',
                        help='choose between "kNN" and "pairs" for closest pairs, default "kNN"', type=str)
    parser.add_argument('-k', '--k', nargs='?', const=3, default=3,
                        help='k for kNN, default 3', type=int)
    parser.add_argument('-s', '--sample', nargs='?', const=10, default=10,
                        help='how many generated samples to compare, default 10, int or "all"')
    parser.add_argument('-n', '--name', nargs='?', const='', default='',
                        help='name appendix for the plot file', type=str)
    parser.add_argument('--fid', nargs='?', const='no', default='no',
                        help='compute FID and inception score, choose between "yes" and "no"', type=str)
    args = vars(parser.parse_args())

    path_to_real_images = args['realpath']
    path_to_generated_images = args['genpath']
    dataset = args['data']
    arch = args['arch']
    mode = args['mode']
    k_kNN = args['k']
    sample = args['sample']
    name_appendix = args['name']
    fid_bool = args['fid']
    size = args['size']

    print('Start')
    output_path = Path(os.path.join(os.getcwd(), 'output'))
    if not output_path.is_dir():
        os.mkdir(output_path)
    txt_filename = 'output/evaluation_' + dataset + '_' + arch + '_' + mode + '-' + name_appendix + '.txt'
    with open(txt_filename, 'w') as f:
        f.write(f'Path to real images: {path_to_real_images}\n')
        f.write(f'Path to generated images: {path_to_generated_images}\n')
        f.write(f'Experiment on {dataset} dataset with images of resolution {size}x{size}\n')
        f.write(f'Using {arch} model to extract features\n')
        f.write(f'Plot of {mode} on {sample} samples\n')
        f.write(f'Quantitative metrics computed: {fid_bool}\n')
    # load data
    path_to_training_images = os.path.join(path_to_real_images, 'train')
    path_to_test_images = os.path.join(path_to_real_images, 'test')

    if fid_bool == 'yes':
        # metrics eval
        eval_images = image_to_tensor(path_to_test_images)
        generated = image_to_tensor(path_to_generated_images)

        print('Computing FID, KID & inception scores...')
        fid_score, kid_score, is_score = compute_scores(eval_images, generated, device)
        with open(txt_filename, 'a') as f:
            f.write(f'FID score: {fid_score}\n')
            f.write(f'KID score: {kid_score}\n')
            f.write(f'Inception score: {is_score}\n')

        print('Computing Clean FID scores...')
        clean_fid_score, clip_clean_fid_score = clean_fid(path_to_test_images, path_to_generated_images)
        with open(txt_filename, 'a') as f:
            f.write(f'Clean FID score: {clean_fid_score}\n')
            f.write(f'Clean FID score with CLIP features: {clip_clean_fid_score}\n')

        print('Computing FID infinity and IS infinity scores...')
        fid_infinity, is_infinity = fid_inf_is_inf(path_to_test_images, path_to_generated_images)
        with open(txt_filename, 'a') as f:
            f.write(f'FID Infinity score: {fid_infinity}\n')
            f.write(f'IS Infinity score: {is_infinity}\n')
    print('Loading model...', arch)
    feature_flag = False

    # kNN-based eval
    if dataset == 'lhq':
        print('Dataset ', dataset)
        pth = '/home/wn455752/repo/evaluation/features/lhq'
        # load pretrained ResNet50
        if arch == 'cnn':
            print('Loading pretrained ResNet50...')
            path_to_pretrained_weights = '/home/wn455752/repo/evaluation/pretrained/resnet50_places365_pretrained/resnet50_places365_weights.pth'
            weights = torch.load(path_to_pretrained_weights)
            model = resnet50().to(device)
            model.load_state_dict(weights)
            transform = transforms.Compose([transforms.ToTensor(),                  # transform PIL.Image to torch.Tensor
                                            transforms.Lambda(lambda x: x * 255)])  # scale values to VGG input range
            with torch.no_grad():
                model.eval()
            print('Checking for existing training dataset features...')
            # check for saved dataset features
            name_pth = Path(os.path.join(pth, 'resnet_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(pth, 'resnet_features/real_image_features.pt'))
            if feature_pth.is_file():  # check the feature file itself, not the name list again
                print('Loading existing training dataset features...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True
        # load CLIP
        elif arch == 'clip':
            print('Loading pretrained CLIP...')
            model, transform = clip.load("ViT-B/32", device=device)
            # check for saved dataset features
            print('Checking for existing training dataset features...')
            name_pth = Path(os.path.join(pth, 'clip_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(pth, 'clip_features/real_image_features.pt'))
            if feature_pth.is_file():
                print('Loading existing training dataset features...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True

    elif dataset == 'faces':
        print('Dataset ', dataset)
        pth = '/home/wn455752/repo/evaluation/features/faces'
        # load pretrained VGGFace
        if arch == 'cnn':
            print('Loading pretrained VGGFace...')
            path_to_pretrained_weights = '/home/wn455752/repo/evaluation/pretrained/vggface_pretrained/VGG_FACE.t7'
            model = VGG_16().to(device)
            model.load_weights(path=path_to_pretrained_weights)
            transform = transforms.Compose([transforms.ToTensor(),                  # transform PIL.Image to torch.Tensor
                                            transforms.Resize((224, 224)),          # resize to VGG input shape
                                            transforms.Lambda(lambda x: x * 255)])  # scale values to VGG input range
            with torch.no_grad():
                model.eval()
            # check for saved dataset features
            print('Checking for existing training dataset features...')
            name_pth = Path(os.path.join(pth, 'vggface_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(pth, 'vggface_features/real_image_features.pt'))
            if feature_pth.is_file():
                print('Loading existing training dataset features...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True
        # load CLIP
        elif arch == 'clip':
            print('Loading pretrained CLIP...')
            model, transform = clip.load("ViT-B/32", device=device)
            # check for saved dataset features
            print('Checking for existing training dataset features...')
            name_pth = Path(os.path.join(pth, 'clip_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(pth, 'clip_features/real_image_features.pt'))
            if feature_pth.is_file():
                print('Loading existing training dataset features...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True
    knn = kNN()
    # get images
    if not feature_flag:
        print('Collecting training images...')
        real_names, real_tensor = knn.get_images(path_to_training_images, transform)
        with open(name_pth, 'wb') as fp:
            pickle.dump(real_names, fp)
    print('Collecting generated images...')
    generated_names, generated_tensor = knn.get_images(path_to_generated_images, transform)

    # extract features
    if not feature_flag:
        print('Extracting features from training images...')
        real_features = knn.feature_extractor(real_tensor, model, device)
        torch.save(real_features, feature_pth)
    print('Extracting features from generated images...')
    generated_features = knn.feature_extractor(generated_tensor, model, device)

    if sample == 'all':
        sample_size = len(generated_names)
    else:
        sample_size = int(sample)

    if mode == 'kNN':
        print('Finding kNNs...')
        knn.kNN(output_path,  # the updated kNN signature expects the output path first
                real_names, generated_names,
                real_features, generated_features,
                path_to_training_images, path_to_generated_images,
                k=k_kNN,
                sample=sample_size,
                size=size,
                name_appendix=name_appendix)
    elif mode == 'pairs':
        print('Finding closest pairs...')
        knn.nearest_neighbor(output_path,
                             real_names, generated_names,
                             real_features, generated_features,
                             path_to_training_images, path_to_generated_images,
                             sample=sample_size,
                             size=size,
                             name_appendix=name_appendix)
    print('Finish!')
import os
import argparse
import pickle
from pathlib import Path

import torch
import clip
from torchvision import transforms
from torchvision.models import resnet50

from evaluation.helpers.vggface import *
from evaluation.helpers.kNN import *
from evaluation.helpers.metrics import *
def ddpm_evaluator(experiment_path, realpath, genpath, data='lhq', size=128, arch='clip', mode='kNN', k=3, sample=10, name_appendix='', fid='no'):
    '''
    Evaluates the images sampled from a trained diffusion model ('genpath') against the real
    dataset ('realpath', expected to contain 'train' and 'test' subdirectories). Optionally
    computes the quantitative metrics FID, KID and IS (plus their Clean and Infinity variants)
    and produces qualitative kNN or closest-pair plots. All results are stored under
    'experiment_path/eval_output'.
    '''
    #device = "mps" if torch.backends.mps.is_available() else "cpu"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print('device:', device)

    path_to_real_images = realpath
    path_to_generated_images = genpath
    dataset = data
    k_kNN = k
    fid_bool = fid

    print('Start')
    output_path = Path(os.path.join(experiment_path, 'eval_output'))
    if not output_path.is_dir():
        os.mkdir(output_path)
    # write all results relative to the evaluation output path
    os.chdir(output_path)

    txt_filename = 'evaluation_' + dataset + '_' + arch + '_' + mode + '-' + name_appendix + '.txt'
    with open(txt_filename, 'w') as f:
        f.write(f'Path to real images: {path_to_real_images}\n')
        f.write(f'Path to generated images: {path_to_generated_images}\n')
        f.write(f'Experiment on {dataset} dataset with images of resolution {size}x{size}\n')
        f.write(f'Using {arch} model to extract features\n')
        f.write(f'Plot of {mode} on {sample} samples\n')
        f.write(f'Quantitative metrics computed: {fid_bool}\n')

    # load data
    path_to_training_images = os.path.join(path_to_real_images, 'train')
    path_to_test_images = os.path.join(path_to_real_images, 'test')

    if fid_bool == 'yes':
        # metrics eval
        eval_images = image_to_tensor(path_to_test_images, device=device)
        generated = image_to_tensor(path_to_generated_images, device=device)

        print('Computing FID, KID & inception scores...')
        fid_score, kid_score, is_score = compute_scores(eval_images, generated, device)
        with open(txt_filename, 'a') as f:
            f.write(f'FID score: {fid_score}\n')
            f.write(f'KID score: {kid_score}\n')
            f.write(f'Inception score: {is_score}\n')

        print('Computing Clean FID scores...')
        clean_fid_score, clip_clean_fid_score = clean_fid(path_to_test_images, path_to_generated_images)
        with open(txt_filename, 'a') as f:
            f.write(f'Clean FID score: {clean_fid_score}\n')
            f.write(f'Clean FID score with CLIP features: {clip_clean_fid_score}\n')

        print('Computing FID infinity and IS infinity scores...')
        fid_infinity, is_infinity = fid_inf_is_inf(path_to_test_images, path_to_generated_images)
        with open(txt_filename, 'a') as f:
            f.write(f'FID Infinity score: {fid_infinity}\n')
            f.write(f'IS Infinity score: {is_infinity}\n')
        #print('Computing PSNR and SSIM scores...')
        #train_images = image_to_tensor(path_to_training_images, device=device)
        #psnr_score, ssim_score = compute_ssim_psnr_scores(train_images, generated, device)
        #with open(txt_filename, 'a') as f:
        #    f.write(f'PSNR score: {psnr_score}\n')
        #    f.write(f'SSIM score: {ssim_score}\n')
    print('Loading model...', arch)
    feature_flag = False

    # kNN-based eval
    if dataset == 'lhq':
        print('Dataset ', dataset)
        pth = '/home/wn455752/repo/evaluation/features/lhq'
        # load pretrained ResNet50
        if arch == 'cnn':
            print('Loading pretrained ResNet50...')
            path_to_pretrained_weights = '/home/wn455752/repo/evaluation/pretrained/resnet50_places365_pretrained/resnet50_places365_weights.pth'
            weights = torch.load(path_to_pretrained_weights)
            model = resnet50().to(device)
            model.load_state_dict(weights)
            transform = transforms.Compose([transforms.ToTensor(),                  # transform PIL.Image to torch.Tensor
                                            transforms.Lambda(lambda x: x * 255)])  # scale values to VGG input range
            with torch.no_grad():
                model.eval()
            print('Checking for existing training dataset features...')
            # check for saved dataset features
            name_pth = Path(os.path.join(pth, 'resnet_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(pth, 'resnet_features/real_image_features.pt'))
            if feature_pth.is_file():  # check the feature file itself, not the name list again
                print('Loading existing training dataset features...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True
        # load CLIP
        elif arch == 'clip':
            print('Loading pretrained CLIP...')
            model, transform = clip.load("ViT-B/32", device=device)
            # check for saved dataset features
            print('Checking for existing training dataset features...')
            name_pth = Path(os.path.join(pth, 'clip_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(pth, 'clip_features/real_image_features.pt'))
            if feature_pth.is_file():
                print('Loading existing training dataset features...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True

    elif dataset == 'faces':
        print('Dataset ', dataset)
        presaved_feature_pth = '/home/wn455752/repo/evaluation/features/faces'
        # load pretrained VGGFace
        if arch == 'cnn':
            print('Loading pretrained VGGFace...')
            path_to_pretrained_weights = '/home/wn455752/repo/evaluation/pretrained/vggface_pretrained/VGG_FACE.t7'
            model = VGG_16().to(device)
            model.load_weights(path=path_to_pretrained_weights)
            transform = transforms.Compose([transforms.ToTensor(),                  # transform PIL.Image to torch.Tensor
                                            transforms.Resize((224, 224)),          # resize to VGG input shape
                                            transforms.Lambda(lambda x: x * 255)])  # scale values to VGG input range
            with torch.no_grad():
                model.eval()
            # check for saved dataset features
            print('Checking for existing training dataset features...')
            name_pth = Path(os.path.join(presaved_feature_pth, 'vggface_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(presaved_feature_pth, 'vggface_features/real_image_features.pt'))
            if feature_pth.is_file():
                print('Loading existing training dataset features...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True
        # load CLIP
        elif arch == 'clip':
            print('Loading pretrained CLIP...')
            model, transform = clip.load("ViT-B/32", device=device)
            # check for saved dataset features
            print('Checking for existing training dataset features...')
            name_pth = Path(os.path.join(presaved_feature_pth, 'clip_features/real_name_list'))
            if name_pth.is_file():
                with open(name_pth, 'rb') as fp:
                    real_names = pickle.load(fp)
            feature_pth = Path(os.path.join(presaved_feature_pth, 'clip_features/real_image_features.pt'))
            if feature_pth.is_file():
                print('Loading existing training dataset features...')
                real_features = torch.load(feature_pth, map_location="cpu")
                real_features = real_features.to(device)
                feature_flag = True

    knn = kNN()
    # get images
    if not feature_flag:
        print('Collecting training images...')
        real_names, real_tensor = knn.get_images(path_to_training_images, transform)
        with open(name_pth, 'wb') as fp:
            pickle.dump(real_names, fp)
    print('Collecting generated images...')
    generated_names, generated_tensor = knn.get_images(path_to_generated_images, transform)

    # extract features
    if not feature_flag:
        print('Extracting features from training images...')
        real_features = knn.feature_extractor(real_tensor, model, device)
        torch.save(real_features, feature_pth)
    print('Extracting features from generated images...')
    generated_features = knn.feature_extractor(generated_tensor, model, device)

    if sample == 'all':
        sample_size = len(generated_names)
    else:
        sample_size = int(sample)

    if mode == 'kNN':
        print('Finding kNNs...')
        knn.kNN(output_path,
                real_names, generated_names,
                real_features, generated_features,
                path_to_training_images, path_to_generated_images,
                k=k_kNN,
                sample=sample_size,
                size=size,
                name_appendix=name_appendix)
    elif mode == 'pairs':
        print('Finding closest pairs...')
        knn.nearest_neighbor(output_path,
                             real_names, generated_names,
                             real_features, generated_features,
                             path_to_training_images, path_to_generated_images,
                             sample=sample_size,
                             size=size,
                             name_appendix=name_appendix)
    print('Finish!')
'''
if __name__ == '__main__':
    #device = "mps" if torch.backends.mps.is_available() else "cpu"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print('device:', device)

    print('Parsing arguments...')
    parser = argparse.ArgumentParser()
    parser.add_argument('-exp', '--exptpath', required=True,
                        help='path to experiment', type=str)
    parser.add_argument('-rp', '--realpath', required=True,
                        help='path to real images', type=str)
    parser.add_argument('-gp', '--genpath', required=True,
                        help='path to generated images', type=str)
    parser.add_argument('-d', '--data', nargs='?', const='lhq', default='lhq',
                        help='choose between "lhq" and "faces" dataset', type=str)
    parser.add_argument('--size', nargs='?', const=128, default=128,
                        help='resolution of image the model was trained on, default 128 (int)', type=int)
    parser.add_argument('-a', '--arch', nargs='?', const='clip', default='clip',
                        help='choose between "clip" and "cnn", default "clip"', type=str)
    parser.add_argument('-m', '--mode', nargs='?', const='kNN', default='kNN',
                        help='choose between "kNN" and "pairs" for closest pairs, default "kNN"', type=str)
    parser.add_argument('-k', '--k', nargs='?', const=3, default=3,
                        help='k for kNN, default 3', type=int)
    parser.add_argument('-s', '--sample', nargs='?', const=10, default=10,
                        help='how many generated samples to compare, default 10, int or "all"')
    parser.add_argument('-n', '--name', nargs='?', const='', default='',
                        help='name appendix for the plot file', type=str)
    parser.add_argument('--fid', nargs='?', const='no', default='no',
                        help='compute FID and inception score, choose between "yes" and "no"', type=str)
    args = vars(parser.parse_args())

    path_to_experiment_folder = args['exptpath']
    path_to_real_images = args['realpath']
    path_to_generated_images = args['genpath']
    dataset = args['data']
    arch = args['arch']
    mode = args['mode']
    k_kNN = args['k']
    sample = args['sample']
    name_appendix = args['name']
    fid_bool = args['fid']
    size = args['size']

    ddpm_evaluator(experiment_path=path_to_experiment_folder,
                   realpath=path_to_real_images,
                   genpath=path_to_generated_images,
                   data=dataset,
                   arch=arch,
                   size=size,
                   mode=mode,
                   k=k_kNN,
                   sample=sample,
                   name_appendix=name_appendix,
                   fid=fid_bool)
'''
File moved
@@ -4,14 +4,24 @@ We conduct two types of evaluation - qualitative and quantitative.
### Quantitative evaluations -
Quantitative metrics can be further categorised into two groups - content variant and content invariant metrics.

Content variant metrics are useful when the model can generate different samples from a noise vector. \
These evaluations are carried out to compare different backbone architectures of our unconditional diffusion model. \
A set of 10,000 generated samples from each model variant is compared with the test set of the real dataset. \
These evaluations include the following (a short sketch follows the list) -
1. FID score
2. Inception score
3. Clean FID score (with CLIP)
4. FID infinity and IS infinity scores
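
For reference, a minimal sketch of how the first group of scores can be computed with torchmetrics, mirroring compute_scores in evaluation/helpers/metrics.py (the function name here is illustrative):

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance

def content_variant_scores(real, generated, device='cpu'):
    # real/generated: uint8 tensors of shape (N, 3, 128, 128) with values in [0, 255]
    fid = FrechetInceptionDistance(normalize=False).to(device)
    kid = KernelInceptionDistance(normalize=False, subset_size=32).to(device)  # needs >= 32 images per set
    iscore = InceptionScore(normalize=False).to(device)
    fid.update(real.to(device), real=True)
    fid.update(generated.to(device), real=False)
    kid.update(real.to(device), real=True)
    kid.update(generated.to(device), real=False)
    iscore.update(generated.to(device))
    # kid.compute() and iscore.compute() return (mean, std) tuples
    return fid.compute(), kid.compute(), iscore.compute()
```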
Content invariant metrics are useful when the model output can be compared w.r.t. a ground truth. \
For example, our model can output the reconstructed version of an input training image (following the entire forward \
and reverse trajectories). \
These evaluations include the following (a short sketch follows the list) -
1. SSIM (Structural Similarity Index Measure)
2. PSNR (Peak Signal-to-Noise Ratio)
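
A matching sketch for the content invariant metrics, mirroring compute_ssim_psnr_scores in evaluation/helpers/metrics.py (again, the function name is illustrative):

```python
import torch
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio

def content_invariant_scores(real, generated, device='cpu'):
    # real/generated: float tensors of shape (N, 3, H, W) in [0, 1], aligned pairwise,
    # e.g. ground-truth training images and their reconstructions
    ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    psnr = PeakSignalNoiseRatio().to(device)
    ssim.update(preds=generated.to(device), target=real.to(device))
    psnr.update(preds=generated.to(device), target=real.to(device))
    return ssim.compute(), psnr.compute()
```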
### Qualitative evaluations -
......
File moved
@@ -65,7 +65,7 @@ class kNN():
        return features

    def kNN(self, output_path, real_names, generated_names,
            real_features, generated_features,
            path_to_real_images, path_to_generated_images,
            k=3,
@@ -108,7 +108,6 @@ class kNN():
                    break
        # savefig
        if not output_path.is_dir():
            os.mkdir(output_path)
        plot_name = f'{k}NN_{sample}_samples'
@@ -116,7 +115,7 @@ class kNN():
        plot_name = plot_name + '_' + name_appendix + '.png'
        fig.savefig(os.path.join(output_path, plot_name))

    def nearest_neighbor(self, output_path, real_names, generated_names,
                         real_features, generated_features,
                         path_to_real_images, path_to_generated_images,
                         sample=10, size=128,
@@ -163,7 +162,6 @@ class kNN():
            ax[i, 1].set_title(f'{real_names[real_img_idx][:-4]}, {knn_score:.2f}', fontsize=8)
        # savefig
        if not output_path.is_dir():
            os.mkdir(output_path)
        plot_name = f'closest_pairs_top_{sample}'
......
@@ -8,26 +8,35 @@ from itertools import cycle
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio
from cleanfid import fid
from evaluation.helpers.score_infinity import calculate_FID_infinity_path, calculate_IS_infinity_path

def image_to_tensor(path, sample='all', device='cpu'):
    transform_resize = transforms.Compose([transforms.ToTensor(), transforms.Resize(128), transforms.Lambda(lambda x: (x * 255).type(torch.uint8))])
    transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (x * 255).type(torch.uint8))])
    filelist = os.listdir(path)
    if sample == 'all':
        sample_size = -1
    else:
        sample_size = int(sample)  # guard against string input
    image_list = []
    for file in filelist:
        if file.endswith('.png'):
            filepath = os.path.join(path, file)
            im = Image.open(filepath)
            if im.size[0] != 128:
                im = transform_resize(im)
            else:
                im = transform(im)
            image_list.append(im)
            if len(image_list) == sample_size:
                break
    print(f'current sample size: {len(image_list)}')
    # convert list of tensors to a single tensor
    image_tensor = torch.stack(image_list).to(device)
    return image_tensor
def compute_scores(real, generated, device):
@@ -68,3 +77,19 @@ def fid_inf_is_inf(path_to_real_images, path_to_generated_images, batchsize=128)
    is_infinity = calculate_IS_infinity_path(path_to_generated_images, batch_size=batchsize)
    return fid_infinity, is_infinity
def compute_ssim_psnr_scores(real, generated, device):
    real_dataloader = DataLoader(real, batch_size=128, shuffle=False)
    generated_dataloader = DataLoader(generated, batch_size=128, shuffle=False)
    ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    psnr = PeakSignalNoiseRatio().to(device)
    # cycle the generated batches so every real batch has a counterpart
    for r, g in zip(real_dataloader, cycle(generated_dataloader)):
        r = r.to(device)
        g = g.to(device)
        ssim.update(preds=g, target=r)
        psnr.update(preds=g, target=r)
    ssim_score = ssim.compute()
    psnr_score = psnr.compute()
    return psnr_score, ssim_score
@@ -15,7 +15,7 @@ import glob
from tqdm import tqdm
from PIL import Image
from scipy import linalg
from evaluation.helpers.inception import *

class randn_sampler():
    """
......
File moved
import os
import re

import torch
from torchvision import transforms

def ddpm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15, sample_all=False, n_times=1):
    '''
    Samples a tensor of 'n_times'*'batch_size' images from a trained diffusion model with 'checkpoint'. The generated
    images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}', where e is the epoch
    of the checkpoint we are sampling from and j is an index separating the sampled batches for every call of this
    sampling function.

    model:           Diffusion model
    checkpoint:      Name of the saved .pth file containing the trained weights and biases
    experiment_path: Path to the experiment directory where the samples will be saved under the directory samples
    batch_size:      The number of images the model samples
    intermediate:    Bool value. If False, the sampling function will draw a batch of images; else it will
                     sample a single image, but store all the intermediate noised latents along the reverse chain
    sample_all:      If True, samples a batch of images for the given model at every stored checkpoint at once
    n_times:         Integer denoting how many times we draw a batch of 'batch_size'. If we want to draw 10k images,
                     the GPU can draw batches of 512 images 20 times to reach this goal.
    '''
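    # Example call (hypothetical checkpoint name): to draw ~10k images as 20 batches of 512,
    # one could run
    #   ddpm_sampler(model, 'model_epoch_100.pth', experiment_path, device, batch_size=512, n_times=20)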
    # If we want to sample from every checkpoint of the current model, recursively call this function for all checkpoints
    if sample_all:
        f = f'{experiment_path}trained_ddpm/'
        checkpoint_list = [checkpoint_i for checkpoint_i in os.listdir(f) if checkpoint_i.endswith(".pth")]
        for checkpoint_i in checkpoint_list:
            ddpm_sampler(model, checkpoint_i, experiment_path, device, sample_all=False)
        return 0

    # load model
    checkpoint_path = f'{experiment_path}trained_ddpm/{checkpoint}'
    try:
        checkpoint = torch.load(checkpoint_path)
        # load weights and biases of the U-Net
        net_state_dict = checkpoint['model']
        model.net.load_state_dict(net_state_dict)
        model = model.to(device)
    except Exception as e:
        print("Error loading checkpoint. Exception:", e)

    # create samples directory for the complete experiment
    output_dir = f'{experiment_path}samples/'
    os.makedirs(output_dir, exist_ok=True)

    # create sample directory for the current checkpoint epoch of the trained model
    model_name = os.path.basename(checkpoint_path)
    epoch = re.findall(r'\d+', model_name)
    if epoch:
        e = int(epoch[0])
    else:
        raise ValueError(f"No digit found in the filename: {model_name}")
    model_dir = os.path.join(output_dir, f'epoch_{e}')
    os.makedirs(model_dir, exist_ok=True)

    # create the sample directory for this sampling run for the current version of the model
    sample_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))]
    indx_list = [int(d.split('_')[1]) for d in sample_dir_list if d.startswith('sample_')]
    j = max(indx_list, default=-1) + 1
    sample_dir = os.path.join(model_dir, f'sample_{j}')
    os.makedirs(sample_dir, exist_ok=True)

    # transform back from (-1,1) to PIL images
    back2pil = transforms.Compose([transforms.Normalize(mean=(-1, -1, -1), std=(2, 2, 2)), transforms.ToPILImage()])

    n = n_times
    for k in range(n):
        # generate batch_size images
        if intermediate:
            generated = model.sample_intermediates_latents()
            name = 'sample_intermediate'
        else:
            generated = model.sample(batch_size=batch_size)
            name = 'sample'
        # store the raw generated tensor (note: overwritten on every batch k)
        torch.save(generated, os.path.join(sample_dir, f"image_tensor{j}"))
        # normalize to (-1,1), not needed after 70 epochs, model learns to adhere to (-1,1)
        #a = generated.min()
        #b = generated.max()
        #A,B = -1,1
        #generated = (generated-a)/(b-a)*(B-A)+A
        # transform tensors to PIL and save generated images
        for i in range(generated.size(0)):
            index = i + k * generated.size(0)  # global index across batches so filenames don't collide
            image = back2pil(generated[i])
            image_path = os.path.join(sample_dir, f'{name}_{j}_{index}.png')
            try:
                image.save(image_path)
            except Exception as e:
                print("Error saving image. Exception:", e)
%% Cell type:code id: tags:
``` python
from trainer.train import *
from dataloader.load import *
from models.Framework import *
from models.all_unets import *
import torch
from torch import nn
import os
```
%% Cell type:markdown id: tags:
# Prepare experiment
1. Choose Hyperparameter Settings
2. Run notebook on local machine to generate the experiment folder with the JSON files containing the settings
3. scp experiment folder to the HPC
4. Run the pipeline by adding the following to the batch file:
- Train Model: &emsp;&emsp;&emsp;&emsp;&emsp; `python main.py train "<absolute path of experiment folder in hpc>"`
- Sample Images: &emsp;&emsp;&emsp; `python main.py sample "<absolute path of experiment folder in hpc>"`
- Evaluate Model: &emsp;&emsp;&emsp; `python main.py evaluate "<absolute path of experiment folder in hpc>"`
%% Cell type:code id: tags:
``` python
import torch
####
# Settings
####
# Dataset path
datapath = "/work/lect0100/lhq_256"
# Experiment setup
run_name = 'main_test1' # WANDB and experiment folder Name!
checkpoint = None #'model_epoch_8.pth' # Name of checkpoint pth file or None
experiment_path = "/work/lect0100/main_experiment/" + run_name +'/'
#experiment_path = "/work/lect0100/main_experiment/" + run_name +'/'
# local
experiment_path = '/Users/roy/Desktop/Workspace/RWTH/SoSe 2023/Deep Learning Lab/DLL_vsc/diffusion_project/experiments/' + run_name + '/'
# Path to save generated experiment folder on local machine
local_path ="experiments/" + run_name + '/settings'
# Diffusion Model Settings
diffusion_steps = 1000
image_size = 128
channels = 3
# Training
batchsize = 32
epochs = 100
store_iter = 10
eval_iter = 500
learning_rate = 0.0001
optimizername = "torch.optim.AdamW"
optimizer_params = None
verbose = False
# checkpoint = None # (if no checkpoint training, i.e. random weights)
# Sampling
sample_size = 20
intermediate = False # True if you want to sample one image and all its intermediate latents
reconstruction = False # True if you want to sample reconstructed versions of training images
sample_all=False
# Evaluating
...
eval_realpath = '/Users/roy/Desktop/Workspace/RWTH/SoSe 2023/Deep Learning Lab/DLL_vsc/data/lhq' # path to real images (assumes the dir has two subdirs - train and test)
eval_genpath = '/Users/roy/Desktop/Workspace/RWTH/SoSe 2023/Deep Learning Lab/DLL_vsc/samples/lhq_samples/rf200_good_ones' # path to sampled images
eval_data='lhq' # choose between 'faces' and 'lhq'
eval_size=image_size # resolution of training images
eval_arch='clip' # choose between 'clip' and 'cnn'
eval_mode='kNN' # choose between 'kNN' and 'pairs' (for closest pairs)
eval_k_kNN=3 # choose k for kNN
eval_sample=10 # in case of kNN, find kNN of first 'sample' number of generated samples
# in case of pairs, find top 'sample' number of closest pairs of
# real-generated images from the entire set of generated samples
eval_name_appendix='' # name appendix
eval_fid='yes' # whether to compute FID, IS scores
###
# Advanced Settings Dictionaries
###
meta_setting = dict(modelname = "UNet_Res",
dataset = "UnconditionalDataset",
framework = "DDPM",
trainloop_function = "ddpm_trainer",
sampling_function = 'ddpm_sampler',
evaluation_function = 'ddpm_evaluator',
batchsize = batchsize
)
dataset_setting = dict(fpath = datapath,
img_size = image_size,
frac =0.8,
skip_first_n = 0,
ext = ".png",
transform=True
)
model_setting = dict( n_channels=64,
fctr = [1,2,4,4,8],
time_dim=256,
attention = True,
)
"""
outdated
model_setting = dict( channels_in=channels,
channels_out =channels ,
activation='relu', # activation function. Options: {'relu', 'leakyrelu', 'selu', 'gelu', 'silu'/'swish'}
weight_init='he', # weight initialization. Options: {'he', 'torch'}
projection_features=64, # number of image features after first convolution layer
time_dim=batchsize, # don't change!!
time_channels=diffusion_steps, # number of time channels #TODO same as diffusion steps?
num_stages=4, # number of stages in contracting/expansive path
stage_list=None, # specify number of features produced by stages
num_blocks=1, # number of ConvResBlock in each contracting/expansive path
num_groupnorm_groups=32, # number of groups used in Group Normalization inside a ConvResBlock
dropout=0.1, # drop-out to be applied inside a ConvResBlock
attention_list=None, # specify MHA pattern across stages
num_attention_heads=1,
)
"""
framework_setting = dict(
diffusion_steps = diffusion_steps, # don't change!!
out_shape = (channels,image_size,image_size), # don't change!!
noise_schedule = 'linear',
beta_1 = 1e-4,
beta_T = 0.02,
alpha_bar_lower_bound = 0.9,
var_schedule = 'same',
kl_loss = 'simplified',
recon_loss = 'none',
)
training_setting = dict(
epochs = epochs,
store_iter = store_iter,
eval_iter = eval_iter,
optimizer_class=optimizername,
optimizer_params = optimizer_params,
#optimizer_params=dict(lr=learning_rate), # don't change!
learning_rate = learning_rate,
run_name=run_name,
checkpoint= checkpoint,
experiment_path = experiment_path,
verbose = verbose,
T_max = 0.8*90000/32*100, # cosine lr param len(train_ds)/batchsize * total epochs to 0
eta_min= 1e-10, # cosine lr param
)
sampling_setting = dict(
checkpoint = checkpoint,
experiment_path = experiment_path,
batch_size = sample_size,
intermediate = intermediate,
reconstruction = reconstruction,
sample_all = sample_all
)
# TODO
evaluation_setting = dict(
experiment_path=experiment_path,
realpath=eval_realpath,
genpath=eval_genpath,
data=eval_data,
size=eval_size,
arch=eval_arch,
mode=eval_mode,
k=eval_k_kNN,
sample=eval_sample,
name_appendix=eval_name_appendix,
fid=eval_fid
)
```
%% Cell type:code id: tags:
``` python
import os
import json
f = local_path
if os.path.exists(f):
    print("path already exists, pick a new name!")
    print("break")
else:
    print("create folder")
    #os.mkdir(f)
    os.makedirs(f, exist_ok=True)
    print("folder created ")
    with open(f+"/meta_setting.json","w+") as fp:
        json.dump(meta_setting,fp)
    with open(f+"/dataset_setting.json","w+") as fp:
        json.dump(dataset_setting,fp)
    with open(f+"/model_setting.json","w+") as fp:
        json.dump(model_setting,fp)
    with open(f+"/framework_setting.json","w+") as fp:
        json.dump(framework_setting,fp)
    with open(f+"/training_setting.json","w+") as fp:
        json.dump(training_setting,fp)
    with open(f+"/sampling_setting.json","w+") as fp:
        json.dump(sampling_setting,fp)
    with open(f+"/evaluation_setting.json","w+") as fp:
        json.dump(evaluation_setting,fp)
    print("stored json files in folder")
print(meta_setting)
print(dataset_setting)
print(model_setting)
print(framework_setting)
print(training_setting)
print(sampling_setting)
print(evaluation_setting)
```
%% Output
create folder
folder created
stored json files in folder
{'modelname': 'UNet_Res', 'dataset': 'UnconditionalDataset', 'framework': 'DDPM', 'trainloop_function': 'ddpm_trainer', 'sampling_function': 'ddpm_sampler', 'evaluation_function': 'ddpm_evaluator', 'batchsize': 32}
{'fpath': '/work/lect0100/lhq_256', 'img_size': 128, 'frac': 0.8, 'skip_first_n': 0, 'ext': '.png', 'transform': True}
{'n_channels': 64, 'fctr': [1, 2, 4, 4, 8], 'time_dim': 256, 'attention': True}
{'diffusion_steps': 1000, 'out_shape': (3, 128, 128), 'noise_schedule': 'linear', 'beta_1': 0.0001, 'beta_T': 0.02, 'alpha_bar_lower_bound': 0.9, 'var_schedule': 'same', 'kl_loss': 'simplified', 'recon_loss': 'none'}
{'epochs': 100, 'store_iter': 10, 'eval_iter': 500, 'optimizer_class': 'torch.optim.AdamW', 'optimizer_params': None, 'learning_rate': 0.0001, 'run_name': 'main_test1', 'checkpoint': None, 'experiment_path': '/Users/roy/Desktop/Workspace/RWTH/SoSe 2023/Deep Learning Lab/DLL_vsc/diffusion_project/experiments/main_test1/', 'verbose': False, 'T_max': 225000.0, 'eta_min': 1e-10}
{'checkpoint': None, 'experiment_path': '/Users/roy/Desktop/Workspace/RWTH/SoSe 2023/Deep Learning Lab/DLL_vsc/diffusion_project/experiments/main_test1/', 'batch_size': 20, 'intermediate': False, 'reconstruction': False, 'sample_all': False}
{'experiment_path': '/Users/roy/Desktop/Workspace/RWTH/SoSe 2023/Deep Learning Lab/DLL_vsc/diffusion_project/experiments/main_test1/', 'realpath': '/Users/roy/Desktop/Workspace/RWTH/SoSe 2023/Deep Learning Lab/DLL_vsc/data/lhq', 'genpath': '/Users/roy/Desktop/Workspace/RWTH/SoSe 2023/Deep Learning Lab/DLL_vsc/samples/lhq_samples/rf200_good_ones', 'data': 'lhq', 'size': 128, 'arch': 'clip', 'mode': 'kNN', 'k': 3, 'sample': 10, 'name_appendix': '', 'fid': 'no'}
%% Cell type:code id: tags:
``` python
```
......
@@ -2,7 +2,7 @@
import json
import sys
from dataloader.load import *
from models.Framework import *
from trainer.train import ddpm_trainer
from evaluation.sample import ddpm_sampler
from evaluation.evaluate import ddpm_evaluator
@@ -31,16 +31,19 @@ def train_func(f):
        training_setting = json.load(fp)
    training_setting["optimizer_class"] = eval(training_setting["optimizer_class"])

    # init Dataloaders
    batchsize = meta_setting["batchsize"]
    training_dataset = globals()[meta_setting["dataset"]](train = True, **dataset_setting)
    test_dataset = globals()[meta_setting["dataset"]](train = False, **dataset_setting)
    training_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=batchsize)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize)

    # init UNet
    net = globals()[meta_setting["modelname"]](**model_setting).to(device)
    #net = torch.compile(net)
    net = net.to(device)

    # init Diffusion Model
    framework = globals()[meta_setting["framework"]](net = net, device=device, **framework_setting)

    print(f"META SETTINGS:\n\n {meta_setting}\n\n")
@@ -59,7 +62,7 @@ def sample_func(f):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"device: {device}\n\n")
    print(f"folderpath: {f}\n\n")

    # load all settings
    with open(f+"/meta_setting.json","r") as fp:
        meta_setting = json.load(fp)
@@ -72,14 +75,12 @@ def sample_func(f):
with open(f+"/sampling_setting.json","r") as fp:
sampling_setting = json.load(fp)
# set sampling batch_size
batchsize = sampling_setting["batch_size"]
# init UNet
# init Unet
batchsize = meta_setting["batchsize"]
net = globals()[meta_setting["modelname"]](**model_setting).to(device)
#net = torch.compile(net)
net = net.to(device)
# init Diffusion Model
# init unconditional diffusion model
framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)
print(f"META SETTINGS:\n\n {meta_setting}\n\n")
@@ -98,7 +99,6 @@ def evaluate_func(f):
print(f"device: {device}\n\n")
print(f"folderpath: {f}\n\n")
#load all settings
with open(f+"/meta_setting.json","r") as fp:
meta_setting = json.load(fp)
@@ -114,19 +114,6 @@ def evaluate_func(f):
with open(f+"/dataset_setting.json","r") as fp:
dataset_setting = json.load(fp)
# init Dataloader
batchsize = meta_setting["batchsize"]
test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize, shuffle=False)
# init UNet
net = globals()[meta_setting["modelname"]](**model_setting).to(device)
#net = torch.compile(net)
net = net.to(device)
# init Diffusion Model
framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)
print(f"META SETTINGS:\n\n {meta_setting}\n\n")
print(f"DATASET SETTINGS:\n\n {dataset_setting}\n\n")
print(f"MODEL SETTINGS:\n\n {model_setting}\n\n")
@@ -134,14 +121,19 @@ def evaluate_func(f):
print(f"EVALUATION SETTINGS:\n\n {evaluation_setting}\n\n")
print("\n\nSTART EVALUATION\n\n")
globals()[meta_setting["evaluation_function"]](model=framework, device=device, dataloader = test_dataloader,safepath = f,**evaluation_setting,)
#globals()[meta_setting["evaluation_function"]](model=framework, device=device, dataloader = test_dataloader,safepath = f,**evaluation_setting,)
globals()[meta_setting["evaluation_function"]](**evaluation_setting)
print("\n\nFINISHED EVALUATION\n\n")
# run training, sampling or evaluation
if __name__ == '__main__':
    print(sys.argv)
    functions = {'train': train_func, 'sample': sample_func, 'evaluate': evaluate_func}
    functions[sys.argv[1]](sys.argv[2])