diff --git a/dataloader/load.py b/dataloader/load.py index 67f0fce00e15f4195d2909511843dd1fca861f3e..7fba61ce564989371d89384ba0c3b2e63faf2121 100644 --- a/dataloader/load.py +++ b/dataloader/load.py @@ -71,17 +71,17 @@ class UnconditionalDataset_LHQ(Dataset): # Dataset used when training on CelebAHQ. class UnconditionalDataset_CelebAHQ(Dataset): - def __init__(self,fpath,img_size,train,frac =0.8,skip_first_n = 0,ext = ".png",transform=True ): + def __init__(self,fpath,img_size,train,frac =0.8,skip_first_n = 0,ext = ".png",transform=True): """ Args: - fpath (string): Path to the folder where images are stored - img_size (int): size of output image img_size=height=width - ext (string): type of images used(eg .png) + fpath (string): Path to the folder where images are stored + img_size (int): Size of output image img_size=height=width + ext (string): Type of images used(eg .png) transform (Bool): Image augmentation for diffusion model - skip_first_n: skips the first n values. Usefull for datasets that are sorted by increasing Likeliehood - train (Bool): Choose dataset to be either train set or test set. frac(float) required - frac (float): value within (0,1] (seeded)random shuffles dataset, then divides into train and test set. - """ + skip_first_n: Skips the first n values. Usefull for datasets that are sorted by increasing Likeliehood + train (Bool): Choose dataset to be either train set or test set. frac(float) required + frac (float): Value within (0,1] (seeded)random shuffles dataset, then divides into train and test set. + """ # they provide a fixed train and validation split if train: fpath = os.path.join(fpath, 'train') @@ -96,7 +96,6 @@ class UnconditionalDataset_CelebAHQ(Dataset): self.df = df[df["Filepath"].str.endswith(ext)] if transform: - # for training intermediate_size = 137 theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) #Check dataloading.ipynb in analysis-depot for more details @@ -107,7 +106,6 @@ class UnconditionalDataset_CelebAHQ(Dataset): transforms.RandomHorizontalFlip(p=0.5), transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))]) - transform_flip = transforms.Compose([transforms.ToTensor(), transforms.Resize(img_size, antialias=True), transforms.RandomHorizontalFlip(p=0.5), @@ -115,9 +113,7 @@ class UnconditionalDataset_CelebAHQ(Dataset): self.transform = transforms.RandomChoice([transform_rotate_flip,transform_flip]) else : - # for evaluation self.transform = transforms.Compose([transforms.ToTensor(), - transforms.Lambda(lambda x: (x * 255).type(torch.uint8)), transforms.Resize(img_size)]) def __len__(self):