diff --git a/experiment_creator.ipynb b/experiment_creator.ipynb
index c877ff2dac828a17f2b46056fd3f482ca2dd9de2..202045fdad5b4e1e674b59c27fd12c4383265ea6 100644
--- a/experiment_creator.ipynb
+++ b/experiment_creator.ipynb
@@ -46,8 +46,8 @@
     "datapath = \"/work/lect0100/lhq_256\"\n",
     "\n",
     "# Experiment setup\n",
-    "run_name = 'big_diffusers_fin' # WANDB and experiment folder Name!\n",
-    "checkpoint = 'model_epoch_8.pth' # Name of checkpoint pth file or None \n",
+    "run_name = 'batch_timesteps' # WANDB and experiment folder Name!\n",
+    "checkpoint = None #'model_epoch_8.pth' # Name of checkpoint pth file or None \n",
     "experiment_path = '/work/lect0100/experiments_gonzalo/'+ run_name +'/'\n",
     "\n",
     "# Path to save generated experiment folder on local machine\n",
@@ -55,16 +55,15 @@
     "\n",
     "# Diffusion Model Settings\n",
     "diffusion_steps = 200\n",
-    "image_size = 128\n",
+    "image_size = 64\n",
     "channels = 3\n",
     "\n",
     "# Training\n",
-    "batchsize = 8\n",
+    "batchsize = 32\n",
     "epochs = 30\n",
-    "store_iter = 3\n",
+    "store_iter = 1\n",
     "eval_iter = 500\n",
     "learning_rate = 0.0001\n",
-    "lr_schedule = False\n",
     "optimizername = \"torch.optim.AdamW\"\n",
     "optimizer_params = None\n",
     "verbose = True\n",
@@ -133,11 +132,12 @@
     " optimizer_params = optimizer_params,\n",
     " #optimizer_params=dict(lr=learning_rate), # don't change! \n",
     " learning_rate = learning_rate,\n",
-    " lr_schedule = lr_schedule,\n",
     " run_name=run_name,\n",
     " checkpoint= checkpoint,\n",
     " experiment_path = experiment_path,\n",
     " verbose = verbose,\n",
+    " T_max = 5*10000, # cosine lr param\n",
+    " eta_min= 1e-5, # cosine lr param\n",
     " )\n",
     "sampling_setting = dict( \n",
     " checkpoint = checkpoint, \n",
@@ -164,13 +164,13 @@
     "create folder\n",
     "folder created \n",
     "stored json files in folder\n",
-    "{'modelname': 'UNet_Unconditional_Diffusion_Bottleneck_Variant', 'dataset': 'UnconditionalDataset', 'framework': 'DDPM', 'trainloop_function': 'ddpm_trainer', 'sampling_function': 'ddpm_sampler', 'evaluation_function': 'ddpm_evaluator', 'batchsize': 8}\n",
-    "{'fpath': '/work/lect0100/lhq_256', 'img_size': 128, 'frac': 0.8, 'skip_first_n': 0, 'ext': '.png', 'transform': True}\n",
-    "{'channels_in': 3, 'channels_out': 3, 'activation': 'relu', 'weight_init': 'he', 'projection_features': 64, 'time_dim': 8, 'time_channels': 200, 'num_stages': 4, 'stage_list': None, 'num_blocks': 1, 'num_groupnorm_groups': 32, 'dropout': 0.1, 'attention_list': None, 'num_attention_heads': 1}\n",
-    "{'diffusion_steps': 200, 'out_shape': (3, 128, 128), 'noise_schedule': 'linear', 'beta_1': 0.0001, 'beta_T': 0.02, 'alpha_bar_lower_bound': 0.9, 'var_schedule': 'same', 'kl_loss': 'simplified', 'recon_loss': 'nll'}\n",
-    "{'epochs': 30, 'store_iter': 3, 'eval_iter': 500, 'optimizer_class': 'torch.optim.AdamW', 'optimizer_params': None, 'learning_rate': 0.0001, 'lr_schedule': False, 'run_name': 'big_diffusers_fin', 'checkpoint': 'model_epoch_8.pth', 'experiment_path': '/work/lect0100/experiments_gonzalo/big_diffusers_fin/', 'verbose': True}\n",
-    "{'checkpoint': 'model_epoch_8.pth', 'experiment_path': '/work/lect0100/experiments_gonzalo/big_diffusers_fin/', 'batch_size': 10, 'intermediate': False}\n",
-    "{'checkpoint': 'model_epoch_8.pth', 'experiment_path': '/work/lect0100/experiments_gonzalo/big_diffusers_fin/'}\n"
+    "{'modelname': 'UNet_Unconditional_Diffusion_Bottleneck_Variant', 'dataset': 'UnconditionalDataset', 'framework': 'DDPM', 'trainloop_function': 'ddpm_trainer', 'sampling_function': 'ddpm_sampler', 'evaluation_function': 'ddpm_evaluator', 'batchsize': 32}\n",
+    "{'fpath': '/work/lect0100/lhq_256', 'img_size': 64, 'frac': 0.8, 'skip_first_n': 0, 'ext': '.png', 'transform': True}\n",
+    "{'channels_in': 3, 'channels_out': 3, 'activation': 'relu', 'weight_init': 'he', 'projection_features': 64, 'time_dim': 32, 'time_channels': 200, 'num_stages': 4, 'stage_list': None, 'num_blocks': 1, 'num_groupnorm_groups': 32, 'dropout': 0.1, 'attention_list': None, 'num_attention_heads': 1}\n",
+    "{'diffusion_steps': 200, 'out_shape': (3, 64, 64), 'noise_schedule': 'linear', 'beta_1': 0.0001, 'beta_T': 0.02, 'alpha_bar_lower_bound': 0.9, 'var_schedule': 'same', 'kl_loss': 'simplified', 'recon_loss': 'nll'}\n",
+    "{'epochs': 30, 'store_iter': 1, 'eval_iter': 500, 'optimizer_class': 'torch.optim.AdamW', 'optimizer_params': None, 'learning_rate': 0.0001, 'run_name': 'batch_timesteps', 'checkpoint': None, 'experiment_path': '/work/lect0100/experiments_gonzalo/batch_timesteps/', 'verbose': True}\n",
+    "{'checkpoint': None, 'experiment_path': '/work/lect0100/experiments_gonzalo/batch_timesteps/', 'batch_size': 10, 'intermediate': False}\n",
+    "{'checkpoint': None, 'experiment_path': '/work/lect0100/experiments_gonzalo/batch_timesteps/'}\n"
     ]
    }
   ],
diff --git a/models/Framework.py b/models/Framework.py
index a25056c46a0a90b45e141c4c1c8d2ece97725c93..32ffa6c53f671fd9ce994f4b24641a39bbb81fee 100644
--- a/models/Framework.py
+++ b/models/Framework.py
@@ -68,9 +68,9 @@ class DDPM(nn.Module):
         self.recon_loss = recon_loss
         self.out_shape = out_shape
         # precomputed for efficiency reasons
-        self.noise_scaler = (1-alpha)/( self.sqrt_1_minus_alpha_bar)
-        self.mean_scaler = 1/torch.sqrt(self.alpha)
-        self.mse_weight = (self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar))
+        self.noise_scaler = ((1-alpha)/( self.sqrt_1_minus_alpha_bar))
+        self.mean_scaler = (1/torch.sqrt(self.alpha))
+        self.mse_weight = ((self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar)))

     @staticmethod
     def linear_schedule(diffusion_steps, beta_1, beta_T, device):
@@ -191,8 +191,8 @@ class DDPM(nn.Module):
         '''
         if t is None:
             t = self.diffusion_steps
-        elif t == 0:
-            return x_0, torch.zeros(x_0.shape, device = self.device)
+        elif torch.any(t == 0):
+            raise ValueError("The tensor 't' contains a timestep zero.")
         forward_noise = torch.randn(x_0.shape, device = self.device)
         x_T = self.noised_latent(forward_noise, x_0, t)
         return x_T , forward_noise
@@ -228,8 +228,8 @@
             mean (tensor): Batch of means for the complete noise distribution for each image in the batch x_0
             std (tensor): Batch of std scalars for the complete noise distribution for each image in the batch x_0
         '''
-        mean = self.sqrt_alpha_bar[t-1]*x_0
-        std = self.sqrt_1_minus_alpha_bar[t-1]
+        mean = self.sqrt_alpha_bar[t-1].view(-1, 1, 1, 1)*x_0
+        std = self.sqrt_1_minus_alpha_bar[t-1].view(-1, 1, 1, 1)
         return mean, std

     @torch.no_grad()
@@ -245,8 +245,8 @@
             mean (tensor): Batch of means for the individual noise distribution for each image in the batch x_t_1
             std (tensor): Batch of std scalars for the individual noise distribution for each image in the batch x_t_1
         '''
-        mean = torch.sqrt(1-self.beta[t-1])*x_t_1
-        std = torch.sqrt(self.beta[t-1])
+        mean = torch.sqrt(1-self.beta[t-1]).view(-1, 1, 1, 1)*x_t_1
+        std = torch.sqrt(self.beta[t-1]).view(-1, 1, 1, 1)
         return mean, std


@@ -288,8 +288,8 @@
             pred_noise (tensor): Predicted noise for each image in the batch x_t
         '''
         pred_noise = self.net(x_t,t,return_dict=False)[0]
-        mean = self.mean_scaler[t-1]*(x_t - self.noise_scaler[t-1]*pred_noise)
-        std = self.std[t-1]
+        mean = self.mean_scaler[t-1].view(-1, 1, 1, 1)*(x_t - self.noise_scaler[t-1].view(-1, 1, 1, 1)*pred_noise)
+        std = self.std[t-1].view(-1, 1, 1, 1)
         return mean, std, pred_noise


@@ -370,10 +370,9 @@
         Returns:
             x (tensor): Contains the self.diffusion_steps+1 denoised image tensors
         '''
-        # start with an image of pure noise and store it as part of the output
+        # start with an image of pure noise (batch_size 1) and store it as part of the output
         x_t_1 = torch.randn((1,) + tuple(self.out_shape), device=self.device)
-        #x_t_1 = torch.randn(self.out_shape, device=self.device)
-        x = torch.empty((self.diffusion_steps+1,) + tuple(self.out_shape), device=self.device)
+        x = torch.empty((self.diffusion_steps+1,1,) + tuple(self.out_shape), device=self.device)
         x[-1] = x_t_1
         # apply reverse trajectory
         for t in reversed(range(1, self.diffusion_steps+1)):
@@ -408,7 +407,7 @@
         '''
         Returns the mathematically correct weighted version of the simplified loss.
         '''
-        return self.mse_weight[t-1]*F.mse_loss(forward_noise, pred_noise)
+        return self.mse_weight[t-1].view(-1, 1, 1, 1)*F.mse_loss(forward_noise, pred_noise)


     # If t=0 and self.recon_loss == 'nll'
@@ -418,3 +417,6 @@
         denoising Gaussian distribution with mean mean_1 and standard deviation std_1.
         '''
         return -torch.distributions.Normal(mean_1, std_1).log_prob(x_0).mean()
+
+
+
diff --git a/trainer/train.py b/trainer/train.py
index b043512f41040a443b40c2fe96d3e91dc48b4d3d..13e919158b0d6b272773c107d2ca0a614fa08fc1 100644
--- a/trainer/train.py
+++ b/trainer/train.py
@@ -64,11 +64,12 @@ def ddpm_trainer(model,
                  optimizer_class=torch.optim.AdamW,
                  optimizer_params=None,
                  learning_rate = 0.001,
-                 lr_schedule = False,
                  verbose = False,
                  run_name=None,
                  checkpoint= None,
                  experiment_path = None,
+                 T_max = 5*10000, # None,
+                 eta_min= 1e-5,
                  **args
                  ):
     '''
@@ -79,23 +80,25 @@
         epochs: Number of epochs we train the model further.
         optimizer_class: PyTorch optimizer.
         optimizer_param: Parameters for the PyTorch optimizer.
-        learning_rate: For optimizer initialization or if a checkpoint exists for our manual learning rate.
-        lr_schedule: If True, manually sets the learning rate of the optimizer to the given one.
+        learning_rate: For optimizer initialization when training from scratch, i.e. without a checkpoint.
        verbose: If True, prints the running losses for every epoch.
        run_name: Run name for WandB. IF YOU TRAIN FROM CHECKPOINT MAKE SURE TO USE THE SAME 'run_name' FOR THE DATA TO BE LOGGED ON THE SAME WANDB RUN!
        trainloader: Loads the train dataset
        testloader: Loads the test dataset
        checkpoint: Name of the saved pth. file containing the trained weights and biases
+        T_max: CosineAnnealingLR scheduler argument (number of training steps in one full cosine cycle)
+        eta_min: CosineAnnealingLR scheduler argument (the schedule oscillates between the maximum lr 'learning_rate' and the minimum lr 'eta_min')

     '''

     # set optimizer parameters and learning rate
     if optimizer_params is None:
         optimizer_params = dict(lr=learning_rate)
-    else:
-        optimizer_params['lr'] = learning_rate
     optimizer = optimizer_class(model.net.parameters(), **optimizer_params)
+    # set cosine lr schedule (commonly used in diffusion models)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)
+
     # if checkpoint path is given, load the model from checkpoint
     last_epoch = -1
     if checkpoint:
@@ -112,10 +115,10 @@
             # load optimizer state
             optimizer_state_dict = checkpoint['optimizer']
             optimizer.load_state_dict(optimizer_state_dict)
-            # If we want to decrease the learning rate manually
-            if lr_schedule:
-                for param_group in optimizer.param_groups:
-                    param_group['lr'] = learning_rate
+            # load learning rate schedule state
+            scheduler_state_dict = checkpoint['scheduler']
+            scheduler.load_state_dict(scheduler_state_dict)
+            scheduler.last_epoch = last_epoch
         except Exception as e:
             print("Error loading checkpoint. Exception: ", e)

@@ -129,7 +132,7 @@
     low = 1
     if model.recon_loss == 'nll':
         low = 0
-
+
     # Using W&B
     with wandb.init(project='test-project', name=run_name, entity='gonzalomartingarcia0', id=run_name, resume=True) as run:

@@ -137,41 +140,57 @@
         run.config.learning_rate = learning_rate
         run.config.optimizer = optimizer.__class__.__name__
         run.watch(model.net)
-        # log the learning rate in each run s.t. for checkpoint training we can see at what times the learning rate has been
-        # manually stepped
-        wandb.log({"learning_rate": learning_rate})

         # training loop
         # last model was stored at epoch last_epoch, we continue training from there, i.e. last_epoch+1 (else we start at epoch 0)
         for epoch in range(last_epoch+1, (last_epoch+1)+epochs):
             running_trainloss = 0
             nr_train_batches = 0
-
+
             # train
             model.net.train()
             for idx,(x_0, _) in enumerate(trainloader):
                 x_0 = x_0.to(device)
-                t = torch.randint(low=low, high=model.diffusion_steps, size=(1,)).item()
+                t = torch.randint(low=low, high=model.diffusion_steps, size=(x_0.shape[0],), device = device)
                 optimizer.zero_grad()
+
+                # Define masks for zero and non-zero elements of t
+                mask_zero_t = (t == 0)
+                mask_non_zero_t = (t != 0)

-                if t>0:
-                    x_t, forward_noise = model.forward_trajectory(x_0,t)
-                    _, _, pred_noise = model.forward(x_t,t)
-                    loss = loss_func(forward_noise,pred_noise,t)
+                t[mask_zero_t] = 1
+                x_t, forward_noise = model.forward_trajectory(x_0,t)
+                mean, std, pred_noise = model.forward(x_t,t)
+
+                loss = 0
+                # Compute KL loss
+                if torch.any(mask_non_zero_t):
+                    loss = loss_func(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t], t[mask_non_zero_t])
                     running_trainloss += loss.item()
                     nr_train_batches += 1
-                    run.log({'loss': loss.item(), 'epoch': epoch, 'batch': idx})
-                else: # reconstruction loss
-                    x_1, _ = model.forward_trajectory(x_0,1)
-                    mean_1, std_1, _ = model.forward(x_1,1)
-                    loss = model.loss_recon(x_0, mean_1, std_1)
-                    run.log({'recon_loss': loss.item(),'epoch': epoch, 'batch': idx})
-
-                loss.backward()
+                    run.log({'loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})
+
+                # If reconstruction loss was drawn
+                if torch.any(mask_zero_t):
+                    recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t])
+                    loss += recon_loss
+                    run.log({'recon_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx})
+
+                loss.backward()
                 optimizer.step()
+                scheduler.step()
+
                 if verbose:
                     print(f"Loss in epoch {epoch}:{running_trainloss/nr_train_batches}")
-            run.log({'running_loss': running_trainloss/nr_train_batches})
+            run.log({'running_loss': running_trainloss/nr_train_batches})
+
+            # WORKING OLD VERSION
+            #x_t, forward_noise = model.forward_trajectory(x_0,t)
+            #_, _, pred_noise = model.forward(x_t,t)
+            #loss = loss_func(forward_noise,pred_noise,t)
+            #running_trainloss += loss.item()
+            #nr_train_batches += 1
+            #run.log({'loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})

             # evaluation
             if ((epoch+1) % eval_iter == 0) or ((epoch+1) % store_iter == 0):
@@ -182,21 +201,30 @@
                 with torch.no_grad():
                     for idx,(x_0,_) in enumerate(testloader):
                         x_0 = x_0.to(device)
-                        t = torch.randint(low=low,high=model.diffusion_steps, size=(1,)).item()
+                        t = torch.randint(low=low, high=model.diffusion_steps, size=(x_0.shape[0],), device = device)

-                        if t>0:
-                            x_t, forward_noise = model.forward_trajectory(x_0,t)
-                            _, _, pred_noise = model.forward(x_t,t)
-                            loss = loss_func(forward_noise,pred_noise, t)
+                        # Define masks for zero and non-zero elements of t
+                        mask_zero_t = (t == 0)
+                        mask_non_zero_t = (t != 0)
+
+                        t[mask_zero_t] = 1
+                        x_t, forward_noise = model.forward_trajectory(x_0,t)
+                        mean, std, pred_noise = model.forward(x_t,t)
+
+                        loss = 0
+                        # Compute KL loss
+                        if torch.any(mask_non_zero_t):
+                            loss = loss_func(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t], t[mask_non_zero_t])
                             running_testloss += loss.item()
                             nr_test_batches += 1
-                            run.log({'test_loss': loss.item(), 'epoch': epoch, 'batch': idx})
-                        else: # reconstruction loss
-                            x_1, _ = model.forward_trajectory(x_0,1)
-                            mean_1, std_1, _ = model.forward(x_1,1)
-                            loss = model.loss_recon(x_0, mean_1, std_1)
-                            run.log({'recon_test_loss': loss.item(), 'epoch': epoch, 'batch': idx})
-
+                            run.log({'test_loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})
+
+                        # If reconstruction loss was drawn
+                        if torch.any(mask_zero_t):
+                            recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t])
+                            loss += recon_loss
+                            run.log({'recon_test_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx})
+
                 if verbose:
                     print(f"Test loss in epoch {epoch}:{running_testloss/nr_test_batches}")
                 run.log({'running_test_loss': running_testloss/nr_test_batches})
@@ -209,11 +237,12 @@
                         'epoch': epoch,
                         'model': model.net.state_dict(),
                         'optimizer': optimizer.state_dict(),
+                        'scheduler': scheduler.state_dict(),
                         'running_loss': running_trainloss/nr_train_batches,
                         'running_test_loss': running_testloss/nr_test_batches,
                         }, os.path.join(save_dir, f"model_epoch_{epoch}.pth"))

-        # always store the last version of the model
+        # always store the last version of the model if we trained through all epochs
         final = (last_epoch+1)+epochs
         save_dir = os.path.join(experiment_path, 'trained_ddpm/')
         os.makedirs(save_dir, exist_ok=True)
@@ -221,6 +250,9 @@
                     'epoch': final,
                     'model': model.net.state_dict(),
                     'optimizer': optimizer.state_dict(),
+                    'scheduler': scheduler.state_dict(),
                     'running_loss': running_trainloss/nr_train_batches,
                     'running_test_loss': running_testloss/nr_test_batches,
                     }, os.path.join(save_dir, f"model_epoch_{final}.pth"))
+
+
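
Note on the .view(-1, 1, 1, 1) changes in models/Framework.py: once t is a batch of per-sample timesteps with shape (B,), indexing a 1-D schedule tensor returns a (B,)-shaped coefficient vector, which does not line up with the batch dimension of a (B, C, H, W) image tensor under PyTorch broadcasting; reshaping it to (B, 1, 1, 1) does. A minimal, self-contained sketch of the broadcasting rule (the shapes and the linear schedule below are illustrative stand-ins, not the project's exact code):

import torch

B, C, H, W = 4, 3, 64, 64                                    # hypothetical batch and image shape
diffusion_steps = 200

beta = torch.linspace(1e-4, 0.02, diffusion_steps)           # linear beta schedule, shape (200,)
sqrt_alpha_bar = torch.sqrt(torch.cumprod(1 - beta, dim=0))  # cumulative schedule, shape (200,)

x_0 = torch.randn(B, C, H, W)                                # batch of images
t = torch.randint(low=1, high=diffusion_steps, size=(B,))    # one timestep per sample

coeff = sqrt_alpha_bar[t - 1]            # shape (B,): would align with W, not with the batch dim
mean = coeff.view(-1, 1, 1, 1) * x_0     # shape (B, 1, 1, 1): broadcasts over C, H, W as intended
print(mean.shape)                        # torch.Size([4, 3, 64, 64])

The same reshape is applied to noise_scaler, mean_scaler, std and mse_weight in the diff so that every sample in the batch is scaled by the coefficient of its own timestep.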
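Note on resuming the cosine schedule: each checkpoint written by the trainer now carries a 'scheduler' entry next to 'model' and 'optimizer', and all three states are reloaded before training continues. A minimal sketch of that save/restore round trip, assuming a stand-in nn.Linear in place of the project's UNet and a file name chosen only for illustration:

import torch
from torch import nn

net = nn.Linear(8, 8)                                     # stand-in for model.net
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5*10000, eta_min=1e-5)

# save: same keys the trainer writes ('model', 'optimizer', 'scheduler')
torch.save({'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()}, 'checkpoint_demo.pth')

# resume: restore all three states so the cosine schedule continues where it stopped
ckpt = torch.load('checkpoint_demo.pth')
net.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
scheduler.load_state_dict(ckpt['scheduler'])

Restoring the scheduler state keeps get_last_lr() and the annealing position consistent with the number of optimizer steps taken before the checkpoint was written.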