Skip to content
Snippets Groups Projects
Commit 7e935796 authored by Gonzalo Martin Garcia's avatar Gonzalo Martin Garcia
Browse files

forgot to update no transform for the celebAHQ dataloader

parent 25d26f13
Branches
No related tags found
No related merge requests found
......@@ -75,12 +75,12 @@ class UnconditionalDataset_CelebAHQ(Dataset):
"""
Args:
fpath (string): Path to the folder where images are stored
img_size (int): size of output image img_size=height=width
ext (string): type of images used(eg .png)
img_size (int): Size of output image img_size=height=width
            ext (string): Type of images used (e.g. .png)
transform (Bool): Image augmentation for diffusion model
skip_first_n: skips the first n values. Usefull for datasets that are sorted by increasing Likeliehood
            skip_first_n: Skips the first n values. Useful for datasets that are sorted by increasing likelihood
            train (Bool): Choose dataset to be either train set or test set. frac (float) required
frac (float): value within (0,1] (seeded)random shuffles dataset, then divides into train and test set.
            frac (float): Value within (0,1]. A (seeded) random shuffle is applied to the dataset, which is then divided into train and test set.
"""
# they provide a fixed train and validation split
if train:
......@@ -96,7 +96,6 @@ class UnconditionalDataset_CelebAHQ(Dataset):
self.df = df[df["Filepath"].str.endswith(ext)]
if transform:
# for training
intermediate_size = 137
theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) #Check dataloading.ipynb in analysis-depot for more details
......@@ -107,7 +106,6 @@ class UnconditionalDataset_CelebAHQ(Dataset):
transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
transform_flip = transforms.Compose([transforms.ToTensor(),
transforms.Resize(img_size, antialias=True),
transforms.RandomHorizontalFlip(p=0.5),
......@@ -115,9 +113,7 @@ class UnconditionalDataset_CelebAHQ(Dataset):
self.transform = transforms.RandomChoice([transform_rotate_flip,transform_flip])
else :
# for evaluation
self.transform = transforms.Compose([transforms.ToTensor(),
transforms.Lambda(lambda x: (x * 255).type(torch.uint8)),
transforms.Resize(img_size)])
def __len__(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment