diff --git a/dataloader/load.py b/dataloader/load.py
index 2fb3593a714a6f92b5cf66dde4d3bec72f00a871..11c3e2d7ee0d640f27ecc2d19b3d3635dfc77d80 100644
--- a/dataloader/load.py
+++ b/dataloader/load.py
@@ -43,13 +43,18 @@ class UnconditionalDataset_LHQ(Dataset):
             intermediate_size = 150
             theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) #Check dataloading.ipynb in analysis-depot for more details
-            transform_rotate = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)),
-                                           transforms.Resize(intermediate_size,antialias=True),
-                                           transforms.RandomRotation(theta/np.pi*180,interpolation=transforms.InterpolationMode.BILINEAR),
-                                           transforms.CenterCrop(img_size),transforms.RandomHorizontalFlip(p=0.5)])
+            transform_rotate = transforms.Compose([transforms.ToTensor(),
+                                           transforms.Resize(intermediate_size,antialias=True),
+                                           transforms.RandomRotation(theta/np.pi*180,interpolation=transforms.InterpolationMode.BILINEAR),
+                                           transforms.CenterCrop(img_size),
+                                           transforms.RandomHorizontalFlip(p=0.5),
+                                           transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])

-            transform_randomcrop = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)),
-                                                transforms.Resize(intermediate_size),transforms.RandomCrop(img_size),transforms.RandomHorizontalFlip(p=0.5)])
+            transform_randomcrop = transforms.Compose([transforms.ToTensor(),
+                                                transforms.Resize(intermediate_size),
+                                                transforms.RandomCrop(img_size),
+                                                transforms.RandomHorizontalFlip(p=0.5),
+                                                transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])

             self.transform = transforms.RandomChoice([transform_rotate,transform_randomcrop])

         else :
@@ -100,16 +105,16 @@ class UnconditionalDataset_CelebAHQ(Dataset):
             theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) #Check dataloading.ipynb in analysis-depot for more details
             transform_rotate_flip = transforms.Compose([transforms.ToTensor(),
-                                        transforms.Resize(intermediate_size,antialias=True),
-                                        transforms.RandomRotation((theta/np.pi*180),interpolation=transforms.InterpolationMode.BILINEAR),
-                                        transforms.CenterCrop(img_size),
-                                        transforms.RandomHorizontalFlip(p=0.5),
-                                        transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
+                                            transforms.Resize(intermediate_size,antialias=True),
+                                            transforms.RandomRotation((theta/np.pi*180),interpolation=transforms.InterpolationMode.BILINEAR),
+                                            transforms.CenterCrop(img_size),
+                                            transforms.RandomHorizontalFlip(p=0.5),
+                                            transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])

             transform_flip = transforms.Compose([transforms.ToTensor(),
-                                        transforms.Resize(img_size, antialias=True),
-                                        transforms.RandomHorizontalFlip(p=0.5),
-                                        transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
+                                            transforms.Resize(img_size, antialias=True),
+                                            transforms.RandomHorizontalFlip(p=0.5),
+                                            transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])

             self.transform = transforms.RandomChoice([transform_rotate_flip,transform_flip])

         else :
diff --git a/main.py b/main.py
index e7d5ca4233a6e9805cfd3940f001486649bf1446..0618633fcbad881645933098121cb7a385edf615 100644
--- a/main.py
+++ b/main.py
@@ -6,12 +6,12 @@
 from models.Framework import *
 from trainer.train import ddpm_trainer
 from evaluation.sample import ddpm_sampler
 from evaluation.evaluate import ddpm_evaluator
-from models.all_unets import *
+from models.unets import *

 import torch

 def train_func(f):
-    #load all settings
+    # Load Settings
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
     print(f"device: {device}\n\n")
     print(f"folderpath: {f}\n\n")
@@ -31,19 +31,17 @@ def train_func(f):
         training_setting = json.load(fp)
         training_setting["optimizer_class"] = eval(training_setting["optimizer_class"])

-
+    # init dataloaders
     batchsize = meta_setting["batchsize"]
     training_dataset = globals()[meta_setting["dataset"]](train = True,**dataset_setting)
     test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
-
     training_dataloader = torch.utils.data.DataLoader( training_dataset,batch_size=batchsize)
     test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize)
-
-
+    # init UNet
     net = globals()[meta_setting["modelname"]](**model_setting).to(device)
     #net = torch.compile(net)
     net = net.to(device)
-
+    # init Diffusion Model
     framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)

     print(f"META SETTINGS:\n\n {meta_setting}\n\n")
@@ -77,9 +75,8 @@ def sample_func(f):

     # init Unet
     net = globals()[meta_setting["modelname"]](**model_setting).to(device)
-    #net = torch.compile(net)
     net = net.to(device)
-    # init unconditional diffusion model
+    # init Diffusion Model
     framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)

     print(f"META SETTINGS:\n\n {meta_setting}\n\n")
diff --git a/models/all_unets.py b/models/unets.py
similarity index 100%
rename from models/all_unets.py
rename to models/unets.py
diff --git a/trainer/train.py b/trainer/train.py
index 5367dcb848f9c29c3d477012660e898f09535b1f..dd882fe8a27c1cec3f3df2f9f5005d57641ca2c2 100644
--- a/trainer/train.py
+++ b/trainer/train.py
@@ -85,7 +85,6 @@ class ModelEmaV2(nn.Module):


 # Training function for the unconditional diffusion model
-
 def ddpm_trainer(model,
                  device,
                  trainloader, testloader,
@@ -157,7 +156,7 @@ def ddpm_trainer(model,
             # load learning rate schedule state
             scheduler_state_dict = checkpoint['scheduler']
             scheduler.load_state_dict(scheduler_state_dict)
-            scheduler.last_epoch = last_epoch
+            scheduler.last_epoch = (last_epoch+1)*len(trainloader)
             # load ema model state
             if ema_training:
                 ema.module.load_state_dict(checkpoint['ema'])