Skip to content
Snippets Groups Projects
Commit adcd00d3 authored by Gonzalo Martin Garcia's avatar Gonzalo Martin Garcia
Browse files

minor changes to possible bugs in dataloader and training function

parent 968a1b01
No related branches found
No related tags found
No related merge requests found
......@@ -43,13 +43,18 @@ class UnconditionalDataset_LHQ(Dataset):
intermediate_size = 150
theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) #Check dataloading.ipynb in analysis-depot for more details
transform_rotate = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)),
transform_rotate = transforms.Compose([transforms.ToTensor(),
transforms.Resize(intermediate_size,antialias=True),
transforms.RandomRotation(theta/np.pi*180,interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(img_size),transforms.RandomHorizontalFlip(p=0.5)])
transforms.CenterCrop(img_size),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
transform_randomcrop = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)),
transforms.Resize(intermediate_size),transforms.RandomCrop(img_size),transforms.RandomHorizontalFlip(p=0.5)])
transform_randomcrop = transforms.Compose([transforms.ToTensor(),
transforms.Resize(intermediate_size),
transforms.RandomCrop(img_size),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
self.transform = transforms.RandomChoice([transform_rotate,transform_randomcrop])
else :
......
......@@ -6,12 +6,12 @@ from models.Framework import *
from trainer.train import ddpm_trainer
from evaluation.sample import ddpm_sampler
from evaluation.evaluate import ddpm_evaluator
from models.all_unets import *
from models.unets import *
import torch
def train_func(f):
#load all settings
# Load Settings
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"device: {device}\n\n")
print(f"folderpath: {f}\n\n")
......@@ -31,19 +31,17 @@ def train_func(f):
training_setting = json.load(fp)
training_setting["optimizer_class"] = eval(training_setting["optimizer_class"])
# init dataloaders
batchsize = meta_setting["batchsize"]
training_dataset = globals()[meta_setting["dataset"]](train = True,**dataset_setting)
test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
training_dataloader = torch.utils.data.DataLoader( training_dataset,batch_size=batchsize)
test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize)
# init UNet
net = globals()[meta_setting["modelname"]](**model_setting).to(device)
#net = torch.compile(net)
net = net.to(device)
# init Diffusion Model
framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)
print(f"META SETTINGS:\n\n {meta_setting}\n\n")
......@@ -77,9 +75,8 @@ def sample_func(f):
# init Unet
net = globals()[meta_setting["modelname"]](**model_setting).to(device)
#net = torch.compile(net)
net = net.to(device)
# init unconditional diffusion model
# init Diffusion Model
framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)
print(f"META SETTINGS:\n\n {meta_setting}\n\n")
......
File moved
......@@ -85,7 +85,6 @@ class ModelEmaV2(nn.Module):
# Training function for the unconditional diffusion model
def ddpm_trainer(model,
device,
trainloader, testloader,
......@@ -157,7 +156,7 @@ def ddpm_trainer(model,
# load learning rate schedule state
scheduler_state_dict = checkpoint['scheduler']
scheduler.load_state_dict(scheduler_state_dict)
scheduler.last_epoch = last_epoch
scheduler.last_epoch = (last_epoch+1)*len(trainloader)
# load ema model state
if ema_training:
ema.module.load_state_dict(checkpoint['ema'])
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment.