Skip to content
Snippets Groups Projects
Commit d0cf545a authored by Gonzalo Martin Garcia's avatar Gonzalo Martin Garcia
Browse files

insert original dataloader for training on Landscapes (LHQ256).

parent 4dd87bb4
No related branches found
No related tags found
No related merge requests found
......@@ -20,69 +20,53 @@ class UnconditionalDataset(Dataset):
"""
### Create DataFrame
#file_list = []
#for root, dirs, files in os.walk(fpath, topdown=False):
# for name in sorted(files):
# file_list.append(os.path.join(root, name))
#
#df = pd.DataFrame({"Filepath":file_list},)
#self.df = df[df["Filepath"].str.endswith(ext)]
file_list = []
for root, dirs, files in os.walk(fpath, topdown=False):
for name in sorted(files):
file_list.append(os.path.join(root, name))
df = pd.DataFrame({"Filepath":file_list},)
self.df = df[df["Filepath"].str.endswith(ext)]
if skip_first_n:
self.df = self.df[skip_first_n:]
print(fpath)
if train:
fpath = os.path.join(fpath, 'train')
df_train = self.df.sample(frac=frac,random_state=2)
self.df = df_train
else:
fpath = os.path.join(fpath, 'valid')
file_list =[]
for root, dirs, files in os.walk(fpath, topdown=False):
for name in sorted(files):
file_list.append(os.path.join(root, name))
df = pd.DataFrame({"Filepath":file_list},)
self.df = df[df["Filepath"].str.endswith(ext)]
df_train = self.df.sample(frac=frac,random_state=2)
df_test = df.drop(df_train.index)
self.df = df_test
if transform:
# for training
intermediate_size = 150
theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) #Check dataloading.ipynb in analysis-depot for more details
transform_rotate = transforms.Compose([transforms.ToTensor(),
transform_rotate = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)),
transforms.Resize(intermediate_size,antialias=True),
transforms.RandomRotation(theta/np.pi*180,interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(img_size),transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
transforms.CenterCrop(img_size),transforms.RandomHorizontalFlip(p=0.5)])
transform_randomcrop = transforms.Compose([transforms.ToTensor(),
transforms.Resize(intermediate_size, antialias=True),
transforms.RandomCrop(img_size),transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
transform_randomcrop = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)),
transforms.Resize(intermediate_size),transforms.RandomCrop(img_size),transforms.RandomHorizontalFlip(p=0.5)])
self.transform = transforms.RandomChoice([transform_rotate,transform_randomcrop])
else :
# for evaluation
self.transform = transforms.Compose([transforms.ToTensor(),
transforms.Lambda(lambda x: (x * 255).type(torch.uint8)),
transforms.Resize(img_size)])
if train==False:
# for testing
self.transform = transforms.Compose([transforms.ToTensor(),
transforms.Resize(intermediate_size, antialias=True),
transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
def __len__(self):
return len(self.df)
def __getitem__(self,idx):
path = self.df.iloc[idx].Filepath
path = self.df.iloc[idx].Filepaths
img = Image.open(path)
return self.transform(img),0
def tensor2PIL(self,img):
back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()])
return back2pil(img)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment