Skip to content
Snippets Groups Projects
Commit d0cf545a authored by Gonzalo Martin Garcia's avatar Gonzalo Martin Garcia
Browse files

insert original dataloader for training on Landscapes (LHQ256).

parent 4dd87bb4
No related branches found
No related tags found
No related merge requests found
...@@ -20,69 +20,53 @@ class UnconditionalDataset(Dataset): ...@@ -20,69 +20,53 @@ class UnconditionalDataset(Dataset):
""" """
### Create DataFrame ### Create DataFrame
#file_list = [] file_list = []
#for root, dirs, files in os.walk(fpath, topdown=False): for root, dirs, files in os.walk(fpath, topdown=False):
# for name in sorted(files): for name in sorted(files):
# file_list.append(os.path.join(root, name)) file_list.append(os.path.join(root, name))
#
#df = pd.DataFrame({"Filepath":file_list},) df = pd.DataFrame({"Filepath":file_list},)
#self.df = df[df["Filepath"].str.endswith(ext)] self.df = df[df["Filepath"].str.endswith(ext)]
if skip_first_n: if skip_first_n:
self.df = self.df[skip_first_n:] self.df = self.df[skip_first_n:]
print(fpath)
if train: if train:
fpath = os.path.join(fpath, 'train') df_train = self.df.sample(frac=frac,random_state=2)
self.df = df_train
else: else:
fpath = os.path.join(fpath, 'valid') df_train = self.df.sample(frac=frac,random_state=2)
df_test = df.drop(df_train.index)
file_list =[] self.df = df_test
for root, dirs, files in os.walk(fpath, topdown=False):
for name in sorted(files):
file_list.append(os.path.join(root, name))
df = pd.DataFrame({"Filepath":file_list},)
self.df = df[df["Filepath"].str.endswith(ext)]
if transform: if transform:
# for training
intermediate_size = 150 intermediate_size = 150
theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) #Check dataloading.ipynb in analysis-depot for more details theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) #Check dataloading.ipynb in analysis-depot for more details
transform_rotate = transforms.Compose([transforms.ToTensor(), transform_rotate = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)),
transforms.Resize(intermediate_size,antialias=True), transforms.Resize(intermediate_size,antialias=True),
transforms.RandomRotation(theta/np.pi*180,interpolation=transforms.InterpolationMode.BILINEAR), transforms.RandomRotation(theta/np.pi*180,interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(img_size),transforms.RandomHorizontalFlip(p=0.5), transforms.CenterCrop(img_size),transforms.RandomHorizontalFlip(p=0.5)])
transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
transform_randomcrop = transforms.Compose([transforms.ToTensor(), transform_randomcrop = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)),
transforms.Resize(intermediate_size, antialias=True), transforms.Resize(intermediate_size),transforms.RandomCrop(img_size),transforms.RandomHorizontalFlip(p=0.5)])
transforms.RandomCrop(img_size),transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
self.transform = transforms.RandomChoice([transform_rotate,transform_randomcrop]) self.transform = transforms.RandomChoice([transform_rotate,transform_randomcrop])
else : else :
# for evaluation
self.transform = transforms.Compose([transforms.ToTensor(), self.transform = transforms.Compose([transforms.ToTensor(),
transforms.Lambda(lambda x: (x * 255).type(torch.uint8)),
transforms.Resize(img_size)]) transforms.Resize(img_size)])
if train==False:
# for testing
self.transform = transforms.Compose([transforms.ToTensor(),
transforms.Resize(intermediate_size, antialias=True),
transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
def __len__(self): def __len__(self):
return len(self.df) return len(self.df)
def __getitem__(self,idx): def __getitem__(self,idx):
path = self.df.iloc[idx].Filepath path = self.df.iloc[idx].Filepaths
img = Image.open(path) img = Image.open(path)
return self.transform(img),0 return self.transform(img),0
def tensor2PIL(self,img): def tensor2PIL(self,img):
back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()]) back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()])
return back2pil(img) return back2pil(img)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment