From fb2e6b8611d23fbf60be10761a2ff11a3afe8159 Mon Sep 17 00:00:00 2001 From: gonzalomartingarcia0 <gonzalomartingarcia0@gmail.com> Date: Thu, 20 Jul 2023 14:12:39 +0200 Subject: [PATCH] Completed Class Conditional Diffusion Model. This version is intended to be trained on a class labeled dataset. It makes use of classifier-free guided diffusion to boost sample quality of the model when generating images from each class. The conditioning mechanism in the UNet (implemented together with Roy) simply adds an embedding of the class to the time embedding before passing them through the learnable reshaping layers for each block. Successfully trained on 3 class dataset, dogs, cats and wildlife. --- .gitignore | 7 + README.md | 91 ------ __init__.py | 0 dataloader/__init__.py | 0 dataloader/load.py | 76 +++++ evaluation/__init__.py | 0 evaluation/evaluate.py | 108 +++++++ evaluation/sample.py | 86 ++++++ experiment_creator.ipynb | 261 ++++++++++++++++ main.py | 153 +++++++++ models/ConditionalDiffusionModel.py | 463 ++++++++++++++++++++++++++++ models/__init__.py | 0 models/conditional_unet.py | 256 +++++++++++++++ trainer/__init__.py | 0 trainer/train.py | 321 +++++++++++++++++++ 15 files changed, 1731 insertions(+), 91 deletions(-) create mode 100644 .gitignore create mode 100644 __init__.py create mode 100644 dataloader/__init__.py create mode 100644 dataloader/load.py create mode 100644 evaluation/__init__.py create mode 100644 evaluation/evaluate.py create mode 100644 evaluation/sample.py create mode 100644 experiment_creator.ipynb create mode 100644 main.py create mode 100644 models/ConditionalDiffusionModel.py create mode 100644 models/__init__.py create mode 100644 models/conditional_unet.py create mode 100644 trainer/__init__.py create mode 100644 trainer/train.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..339aa3e --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +.DS_Store +*/__pycache__ +*/trained_ddpm +root +experiments +trainer/__pycache__ +wandb diff --git a/README.md b/README.md index 6cbd4e3..373dfc9 100644 --- a/README.md +++ b/README.md @@ -1,92 +1 @@ # Conditional Diffusion - - - -## Getting started - -To make it easy for you to get started with GitLab, here's a list of recommended next steps. - -Already a pro? Just edit this README.md and make it your own. Want to make it easy? [Use the template at the bottom](#editing-this-readme)! - -## Add your files - -- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files -- [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command: - -``` -cd existing_repo -git remote add origin https://git.rwth-aachen.de/diffusion-project/conditional-diffusion.git -git branch -M main -git push -uf origin main -``` - -## Integrate with your tools - -- [ ] [Set up project integrations](https://git.rwth-aachen.de/diffusion-project/conditional-diffusion/-/settings/integrations) - -## Collaborate with your team - -- [ ] [Invite team members and collaborators](https://docs.gitlab.com/ee/user/project/members/) -- [ ] [Create a new merge request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html) -- [ ] [Automatically close issues from merge requests](https://docs.gitlab.com/ee/user/project/issues/managing_issues.html#closing-issues-automatically) -- [ ] [Enable merge request approvals](https://docs.gitlab.com/ee/user/project/merge_requests/approvals/) -- [ ] [Automatically merge when pipeline succeeds](https://docs.gitlab.com/ee/user/project/merge_requests/merge_when_pipeline_succeeds.html) - -## Test and Deploy - -Use the built-in continuous integration in GitLab. - -- [ ] [Get started with GitLab CI/CD](https://docs.gitlab.com/ee/ci/quick_start/index.html) -- [ ] [Analyze your code for known vulnerabilities with Static Application Security Testing(SAST)](https://docs.gitlab.com/ee/user/application_security/sast/) -- [ ] [Deploy to Kubernetes, Amazon EC2, or Amazon ECS using Auto Deploy](https://docs.gitlab.com/ee/topics/autodevops/requirements.html) -- [ ] [Use pull-based deployments for improved Kubernetes management](https://docs.gitlab.com/ee/user/clusters/agent/) -- [ ] [Set up protected environments](https://docs.gitlab.com/ee/ci/environments/protected_environments.html) - -*** - -# Editing this README - -When you're ready to make this README your own, just edit this file and use the handy template below (or feel free to structure it however you want - this is just a starting point!). Thank you to [makeareadme.com](https://www.makeareadme.com/) for this template. - -## Suggestions for a good README -Every project is different, so consider which of these sections apply to yours. The sections used in the template are suggestions for most open source projects. Also keep in mind that while a README can be too long and detailed, too long is better than too short. If you think your README is too long, consider utilizing another form of documentation rather than cutting out information. - -## Name -Choose a self-explaining name for your project. - -## Description -Let people know what your project can do specifically. Provide context and add a link to any reference visitors might be unfamiliar with. A list of Features or a Background subsection can also be added here. If there are alternatives to your project, this is a good place to list differentiating factors. - -## Badges -On some READMEs, you may see small images that convey metadata, such as whether or not all the tests are passing for the project. You can use Shields to add some to your README. Many services also have instructions for adding a badge. - -## Visuals -Depending on what you are making, it can be a good idea to include screenshots or even a video (you'll frequently see GIFs rather than actual videos). Tools like ttygif can help, but check out Asciinema for a more sophisticated method. - -## Installation -Within a particular ecosystem, there may be a common way of installing things, such as using Yarn, NuGet, or Homebrew. However, consider the possibility that whoever is reading your README is a novice and would like more guidance. Listing specific steps helps remove ambiguity and gets people to using your project as quickly as possible. If it only runs in a specific context like a particular programming language version or operating system or has dependencies that have to be installed manually, also add a Requirements subsection. - -## Usage -Use examples liberally, and show the expected output if you can. It's helpful to have inline the smallest example of usage that you can demonstrate, while providing links to more sophisticated examples if they are too long to reasonably include in the README. - -## Support -Tell people where they can go to for help. It can be any combination of an issue tracker, a chat room, an email address, etc. - -## Roadmap -If you have ideas for releases in the future, it is a good idea to list them in the README. - -## Contributing -State if you are open to contributions and what your requirements are for accepting them. - -For people who want to make changes to your project, it's helpful to have some documentation on how to get started. Perhaps there is a script that they should run or some environment variables that they need to set. Make these steps explicit. These instructions could also be useful to your future self. - -You can also document commands to lint the code or run tests. These steps help to ensure high code quality and reduce the likelihood that the changes inadvertently break something. Having instructions for running tests is especially helpful if it requires external setup, such as starting a Selenium server for testing in a browser. - -## Authors and acknowledgment -Show your appreciation to those who have contributed to the project. - -## License -For open source projects, say how it is licensed. - -## Project status -If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers. diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataloader/__init__.py b/dataloader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataloader/load.py b/dataloader/load.py new file mode 100644 index 0000000..a9ce93c --- /dev/null +++ b/dataloader/load.py @@ -0,0 +1,76 @@ +import torch +from torch.utils.data import Dataset +from torchvision import transforms +import os +from PIL import Image +import pandas as pd +import numpy as np + +class ConditionalDataset(Dataset): + def __init__(self,fpath,img_size,train,frac =0.8,skip_first_n = 0,ext = ".png",transform=True ): + """ + Args: + fpath (string): Path to the folder where images are stored + img_size (int): size of output image img_size=height=width + ext (string): type of images used(eg .png) + transform (Bool): Image augmentation for diffusion model + skip_first_n: skips the first n values. Usefull for datasets that are sorted by increasing Likeliehood + train (Bool): Choose dataset to be either train set or test set. frac(float) required + frac (float): value within (0,1] (seeded)random shuffles dataset, then divides into train and test set. + """ + if train: + fpath = os.path.join(fpath, 'train') + else: + fpath = os.path.join(fpath, 'valid') + + self.class_to_idx = {'cat': 0, 'dog': 1, 'wild': 2} + + file_list =[] + class_list = [] + for root, dirs, files in os.walk(fpath, topdown=False): + for name in sorted(files): + if name.endswith(ext): + file_list.append(os.path.join(root, name)) + class_list.append(self.class_to_idx[os.path.basename(root)]) + #df = pd.DataFrame({"Filepath":file_list},) + #self.df = df[df["Filepath"].str.endswith(ext)] + self.df = pd.DataFrame({"Filepath": file_list}) + self.class_list = class_list + + if transform: + # for training + intermediate_size = 137 + theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) #Check dataloading.ipynb in analysis-depot for more details + + transform_rotate_flip = transforms.Compose([transforms.ToTensor(), + transforms.Resize(intermediate_size,antialias=True), + transforms.RandomRotation((theta/np.pi*180),interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(img_size), + transforms.RandomHorizontalFlip(p=0.5), + transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))]) + + + transform_flip = transforms.Compose([transforms.ToTensor(), + transforms.Resize(img_size, antialias=True), + transforms.RandomHorizontalFlip(p=0.5), + transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))]) + + self.transform = transforms.RandomChoice([transform_rotate_flip,transform_flip]) + else : + # for evaluation + self.transform = transforms.Compose([transforms.ToTensor(), + transforms.Lambda(lambda x: (x * 255).type(torch.uint8)), + transforms.Resize(img_size)]) + + def __len__(self): + return len(self.df) + + def __getitem__(self,idx): + path = self.df.iloc[idx].Filepath + img = Image.open(path) + class_idx = self.class_list[idx] + return self.transform(img), class_idx + + def tensor2PIL(self,img): + back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()]) + return back2pil(img) diff --git a/evaluation/__init__.py b/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py new file mode 100644 index 0000000..2c9c3a8 --- /dev/null +++ b/evaluation/evaluate.py @@ -0,0 +1,108 @@ +from torchmetrics.image.fid import FrechetInceptionDistance +from torchmetrics.image.inception import InceptionScore +from torchmetrics.image.kid import KernelInceptionDistance +import re +import os +from PIL import Image +from torchvision import transforms +import torch + +def cdm_evaluator(model, + device, + dataloader, + checkpoint, + experiment_path, + sample_idx=0, + **args, + ): + ''' + Takes a trained diffusion model from 'checkpoint' and evaluates its performance on the test + dataset 'dataloader' w.r.t. the three most important perfromance metrics; FID, IS, KID. We continue + the progress of our evaluation function for the LDM upscalaer and may update this function accordingly. + + checkpoint: Name of the saved pth. file containing the trained weights and biases + experiment_path: Path to the experiment folder where the evaluation results will be stored + dataloader: Loads the test dataset for evaluation + sample_idx: Integer that denotes which sample directory sample_{sample_idx} from the checkpoint model shall be used for evaluation + ''' + + checkpoint_path = f'{experiment_path}trained_cdm/{checkpoint}' + # create evaluation directory for the complete experiment (if first time sampling images) + output_dir = f'{experiment_path}evaluations/' + os.makedirs(output_dir, exist_ok=True) + + # create evaluation directory for the current version of the trained model + model_name = os.path.basename(checkpoint_path) + epoch = re.findall(r'\d+', model_name) + if epoch: + e = int(epoch[0]) + else: + raise ValueError(f"No digit found in the filename: {filename}") + model_dir = os.path.join(output_dir,f'epoch_{e}') + os.makedirs(model_dir, exist_ok=True) + + # create the evaluation directory for this evaluation run for the current version of the model + eval_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))] + indx_list = [int(d.split('_')[1]) for d in eval_dir_list if d.startswith('evaluation_')] + j = max(indx_list, default=-1) + 1 + eval_dir = os.path.join(model_dir, f'evaluation_{j}') + os.makedirs(eval_dir, exist_ok=True) + + # Compute metrics + eval_path = os.path.join(eval_dir, 'eval.txt') + + # get sampled images + transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (x * 255).type(torch.uint8))]) + sample_path = os.path.join(f'{experiment_path}samples/',f'epoch_{e}',f'sample_{sample_idx}') + ignore_tensor = f'image_tensor{j}' + images = [] + for samplename in os.listdir(sample_path): + if samplename == ignore_tensor: + continue + img = Image.open(os.path.join(sample_path, samplename)) + img = transform(img) + images.append(img) + # split them into batches for GPU memory + generated = torch.stack(images).to(device) + generated_batches = torch.split(generated, dataloader.batch_size) + nr_generated_batches = len(generated_batches) + nr_real_batches = len(dataloader) + + # Init FID, IS and KID scores + fid = FrechetInceptionDistance(normalize = False).to(device) + iscore = InceptionScore(normalize=False).to(device) + kid = KernelInceptionDistance(normalize=False, subset_size=32).to(device) + + # Update scores for the full testing dataset w.r.t. the sampled batches + for idx,(data, _) in enumerate(dataloader): + data = data.to(device) + fid.update(data, real=True) + kid.update(data, real=True) + if idx < nr_generated_batches: + gen = generated_batches[idx].to(device) + fid.update(gen, real=False) + kid.update(gen, real=False) + iscore.update(gen) + + # If there are sampled images left, add them too + for idx in range(nr_real_batches, nr_generated_batches): + gen = generated_batches[idx].to(device) + fid.update(gen, real=False) + kid.update(gen, real=False) + iscore.update(gen) + + # compute total FID, IS and KID + fid_score = fid.compute() + i_score = iscore.compute() + kid_score = kid.compute() + + # store results in txt file + with open(str(eval_path), 'a') as txt: + result = f'FID_epoch_{e}_sample_{sample_idx}:' + txt.write(result + str(fid_score.item()) + '\n') + result = f'KID_epoch_{e}_sample_{sample_idx}:' + txt.write(result + str(kid_score) + '\n') + result = f'IS_epoch_{e}_sample_{sample_idx}:' + txt.write(result + str(i_score) + '\n') + + diff --git a/evaluation/sample.py b/evaluation/sample.py new file mode 100644 index 0000000..a7a6e3b --- /dev/null +++ b/evaluation/sample.py @@ -0,0 +1,86 @@ +import os +import torch +from torchvision import transforms +import re + +def cdm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False,n_times=1): + ''' + Samples a tensor of 'batch_size' images from a trained diffusion model with 'checkpoint'. The generated + images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}. Where e is the epoch + w.r.t. the model which we are sampling form and j is an index separating images from each call of the + sampling function for the given mode. + + model: Diffusion model + checkpoint: Name of the saved pth. file containing the trained weights and biases + conditioning_path: Path to the conditioning input tensors from whihc the samples will be generated from + experiment_path: Path to the experiment directory where the samples will saved under the diectory samples + batch_size: The number of images to sample + intermediate: Bool value. If False the sampling function will draw a batch of images, else it will just + sample a single image, but store all the intermediate noised latents along the reverse chain + sample_all: If True, samples a batch of images for the given model at every stored checkpoint at once + n_times: Integer denoting how many times we draw a batch of 'batch_size'. If we want to draw 10k images + the GPU will draw batches of 512 images 20 times to reach this goal. + ''' + + # If we want to sample from every checkpoint of the current model, recursively call this function for all checkpoints + if sample_all: + f = f'{experiment_path}trained_cdm/' + checkpoint_list = [checkpoint_i for checkpoint_i in os.listdir(f) if checkpoint_i.endswith(".pth")] + for checkpoint_i in os.listdir(f): + if checkpoint_i.endswith(".pth"): + ldm_sampler(model, checkpoint_i, experiment_path, dataloader, device, batch_size=batch_size, intermediate=intermediate, sample_all=False) + return 0 + + # load model + try: + checkpoint_path = f'{experiment_path}trained_cdm/{checkpoint}' + checkpoint = torch.load(checkpoint_path) + # load weights and biases of the U-Net + net_state_dict = checkpoint['model'] + model.net.load_state_dict(net_state_dict) + model = model.to(device) + except Exception as e: + print("Error loading checkpoint. Exception:", e) + + # create samples directory for the complete experiment (if first time sampling images) + output_dir = f'{experiment_path}samples/' + #output_dir = os.path.join(experiment_path,'/samples/') + os.makedirs(output_dir, exist_ok=True) + + # create sample directory for the current version of the trained model + model_name = os.path.basename(checkpoint_path) + epoch = re.findall(r'\d+', model_name) + if epoch: + e = int(epoch[0]) + else: + #raise ValueError(f"No digit found in the filename: {filename}") + e = 0 + model_dir = os.path.join(output_dir,f'epoch_{e}') + os.makedirs(model_dir, exist_ok=True) + + # create the sample directory for this sampling run for the current version of the model + sample_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))] + indx_list = [int(d.split('_')[1]) for d in sample_dir_list if d.startswith('sample_')] + j = max(indx_list, default=-1) + 1 + sample_dir = os.path.join(model_dir, f'sample_{j}') + os.makedirs(sample_dir, exist_ok=True) + + # transform + back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()]) + + # sample upscaling latent encoding, decode and transform to PIL + run_indx = 0 + for k in range(n_times): + # sample random classes + y = torch.randint(0, 3, (batch_size,)).to(device) + # generate images + generated = model.sample(y=y, batch_size=y.size(0)) + # save images + for i in range(generated.size(0)): + image = back2pil(generated[i]) + image_path = os.path.join(sample_dir, f'sample_{j}_{run_indx}_{y[i]}.png') + try: + image.save(image_path) + except Exception as e: + print("Error saving image. Exception:", e) + run_indx += 1 diff --git a/experiment_creator.ipynb b/experiment_creator.ipynb new file mode 100644 index 0000000..bb4e02d --- /dev/null +++ b/experiment_creator.ipynb @@ -0,0 +1,261 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from trainer.train import *\n", + "from dataloader.load import *\n", + "from models.Framework import *\n", + "from models.all_unets import *\n", + "import torch \n", + "from torch import nn " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Prepare experiment\n", + "1. Choose Hyperparameter Settings\n", + "2. Run notebook on local maschine to generate experiment folder with the JSON files containing the settings\n", + "3. scp experiment folder to the HPC\n", + "4. Run Pipeline by adding following to batch file:\n", + "- Train Model:       `python main.py train \"<absolute path of experiment folder in hpc>\"`\n", + "- Sample Images:     `python main.py sample \"<absolute path of experiment folder in hpc>\"`\n", + "- Evaluate Model:     `python main.py evaluate \"<absolute path of experiment folder in hpc>\"`" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "import torch \n", + "\n", + "####\n", + "# Settings\n", + "####\n", + "\n", + "# Dataset path\n", + "datapath = \"/work/lect0100/lhq_256\"\n", + "\n", + "# Experiment setup\n", + "run_name = 'main_test1' # WANDB and experiment folder Name!\n", + "checkpoint = None #'model_epoch_8.pth' # Name of checkpoint pth file or None \n", + "experiment_path = \"/work/lect0100/main_experiment/\" + run_name +'/'\n", + "\n", + "# Path to save generated experiment folder on local machine\n", + "local_path =\"experiments/\" + run_name + '/settings'\n", + "\n", + "# Diffusion Model Settings\n", + "diffusion_steps = 1000\n", + "image_size = 128\n", + "channels = 3\n", + "\n", + "# Training\n", + "batchsize = 32\n", + "epochs = 100\n", + "store_iter = 10\n", + "eval_iter = 500\n", + "learning_rate = 0.0001\n", + "optimizername = \"torch.optim.AdamW\"\n", + "optimizer_params = None\n", + "verbose = False\n", + "# checkpoint = None #(If no checkpoint training, ie. random weights)\n", + "\n", + "# Sampling \n", + "sample_size = 20\n", + "intermediate = False # True if you want to sample one image and all ist intermediate latents\n", + "sample_all=False\n", + "\n", + "# Evaluating\n", + "...\n", + "\n", + "\n", + "\n", + "###\n", + "# Advanced Settings Dictionaries\n", + "###\n", + "\n", + "meta_setting = dict(modelname = \"UNet_Res\",\n", + " dataset = \"UnconditionalDataset\",\n", + " framework = \"DDPM\",\n", + " trainloop_function = \"ddpm_trainer\",\n", + " sampling_function = 'ddpm_sampler',\n", + " evaluation_function = 'ddpm_evaluator',\n", + " batchsize = batchsize\n", + " )\n", + "dataset_setting = dict(fpath = datapath,\n", + " img_size = image_size,\n", + " frac =0.8,\n", + " skip_first_n = 0,\n", + " ext = \".png\",\n", + " transform=True\n", + " )\n", + "\n", + "model_setting = dict( n_channels=64,\n", + " fctr = [1,2,4,4,8],\n", + " time_dim=256,\n", + " attention = True,\n", + " )\n", + "\"\"\"\n", + "outdated\n", + "model_setting = dict( channels_in=channels, \n", + " channels_out =channels , \n", + " activation='relu', # activation function. Options: {'relu', 'leakyrelu', 'selu', 'gelu', 'silu'/'swish'}\n", + " weight_init='he', # weight initialization. Options: {'he', 'torch'}\n", + " projection_features=64, # number of image features after first convolution layer\n", + " time_dim=batchsize, #dont chnage!!!\n", + " time_channels=diffusion_steps, # number of time channels #TODO same as diffusion steps? \n", + " num_stages=4, # number of stages in contracting/expansive path\n", + " stage_list=None, # specify number of features produced by stages\n", + " num_blocks=1, # number of ConvResBlock in each contracting/expansive path\n", + " num_groupnorm_groups=32, # number of groups used in Group Normalization inside a ConvResBlock\n", + " dropout=0.1, # drop-out to be applied inside a ConvResBlock\n", + " attention_list=None, # specify MHA pattern across stages\n", + " num_attention_heads=1,\n", + " )\n", + "\"\"\"\n", + "framework_setting = dict(\n", + " diffusion_steps = diffusion_steps, # dont change!!\n", + " out_shape = (channels,image_size,image_size), # dont change!!\n", + " noise_schedule = 'linear', \n", + " beta_1 = 1e-4, \n", + " beta_T = 0.02,\n", + " alpha_bar_lower_bound = 0.9,\n", + " var_schedule = 'same', \n", + " kl_loss = 'simplified', \n", + " recon_loss = 'none',\n", + " )\n", + "training_setting = dict(\n", + " epochs = epochs,\n", + " store_iter = store_iter,\n", + " eval_iter = eval_iter,\n", + " optimizer_class=optimizername, \n", + " optimizer_params = optimizer_params,\n", + " #optimizer_params=dict(lr=learning_rate), # don't change! \n", + " learning_rate = learning_rate,\n", + " run_name=run_name,\n", + " checkpoint= checkpoint,\n", + " experiment_path = experiment_path,\n", + " verbose = verbose,\n", + " T_max = 0.8*90000/32*100, # cosine lr param len(train_ds)/batchsize * total epochs to 0 \n", + " eta_min= 1e-10, # cosine lr param\n", + " )\n", + "sampling_setting = dict( \n", + " checkpoint = checkpoint, \n", + " experiment_path = experiment_path, \n", + " batch_size = sample_size,\n", + " intermediate = intermediate,\n", + " sample_all = sample_all\n", + " )\n", + "# TODO\n", + "evaluation_setting = dict(\n", + " checkpoint = checkpoint,\n", + " experiment_path = experiment_path,\n", + " ) " + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "create folder\n", + "folder created \n", + "stored json files in folder\n", + "{'modelname': 'UNet_Res', 'dataset': 'UnconditionalDataset', 'framework': 'DDPM', 'trainloop_function': 'ddpm_trainer', 'sampling_function': 'ddpm_sampler', 'evaluation_function': 'ddpm_evaluator', 'batchsize': 32}\n", + "{'fpath': '/work/lect0100/lhq_256', 'img_size': 128, 'frac': 0.8, 'skip_first_n': 0, 'ext': '.png', 'transform': True}\n", + "{'n_channels': 64, 'fctr': [1, 2, 4, 4, 8], 'time_dim': 256, 'attention': True}\n", + "{'diffusion_steps': 1000, 'out_shape': (3, 128, 128), 'noise_schedule': 'linear', 'beta_1': 0.0001, 'beta_T': 0.02, 'alpha_bar_lower_bound': 0.9, 'var_schedule': 'same', 'kl_loss': 'simplified', 'recon_loss': 'none'}\n", + "{'epochs': 100, 'store_iter': 10, 'eval_iter': 500, 'optimizer_class': 'torch.optim.AdamW', 'optimizer_params': None, 'learning_rate': 0.0001, 'run_name': 'main_test1', 'checkpoint': None, 'experiment_path': '/work/lect0100/main_experiment/main_test1/', 'verbose': False, 'T_max': 225000.0, 'eta_min': 1e-10}\n", + "{'checkpoint': None, 'experiment_path': '/work/lect0100/main_experiment/main_test1/', 'batch_size': 20, 'intermediate': False}\n", + "{'checkpoint': None, 'experiment_path': '/work/lect0100/main_experiment/main_test1/'}\n" + ] + } + ], + "source": [ + "import os\n", + "import json\n", + "f = local_path\n", + "if os.path.exists(f):\n", + " print(\"path already exists, pick a new name!\")\n", + " print(\"break\")\n", + "else:\n", + " print(\"create folder\")\n", + " #os.mkdir(f)\n", + " os.makedirs(f, exist_ok=True)\n", + " print(\"folder created \")\n", + " with open(f+\"/meta_setting.json\",\"w+\") as fp:\n", + " json.dump(meta_setting,fp)\n", + "\n", + " with open(f+\"/dataset_setting.json\",\"w+\") as fp:\n", + " json.dump(dataset_setting,fp)\n", + " \n", + " with open(f+\"/model_setting.json\",\"w+\") as fp:\n", + " json.dump(model_setting,fp)\n", + " \n", + " with open(f+\"/framework_setting.json\",\"w+\") as fp:\n", + " json.dump(framework_setting,fp)\n", + "\n", + " with open(f+\"/training_setting.json\",\"w+\") as fp:\n", + " json.dump(training_setting,fp)\n", + " \n", + " with open(f+\"/sampling_setting.json\",\"w+\") as fp:\n", + " json.dump(sampling_setting,fp)\n", + " \n", + " with open(f+\"/evaluation_setting.json\",\"w+\") as fp:\n", + " json.dump(evaluation_setting,fp)\n", + " \n", + " print(\"stored json files in folder\")\n", + " print(meta_setting)\n", + " print(dataset_setting)\n", + " print(model_setting)\n", + " print(framework_setting)\n", + " print(training_setting)\n", + " print(sampling_setting)\n", + " print(evaluation_setting)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env", + "language": "python", + "name": "env" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/main.py b/main.py new file mode 100644 index 0000000..701fa00 --- /dev/null +++ b/main.py @@ -0,0 +1,153 @@ + +import json +import sys +from dataloader.load import * +from models.ConditionalDiffusionModel import * +from trainer.train import cdm_trainer +from evaluation.sample import cdm_sampler +from evaluation.evaluate import cdm_evaluator +from models.conditional_unet import * +import torch + + +def train_func(f): + #load all settings + device = 'cuda' if torch.cuda.is_available() else 'cpu' + print(f"device: {device}\n\n") + print(f"folderpath: {f}\n\n") + with open(f+"/meta_setting.json","r") as fp: + meta_setting = json.load(fp) + + with open(f+"/dataset_setting.json","r") as fp: + dataset_setting = json.load(fp) + + with open(f+"/model_setting.json","r") as fp: + model_setting = json.load(fp) + + with open(f+"/framework_setting.json","r") as fp: + framework_setting = json.load(fp) + + with open(f+"/training_setting.json","r") as fp: + training_setting = json.load(fp) + training_setting["optimizer_class"] = eval(training_setting["optimizer_class"]) + + + batchsize = meta_setting["batchsize"] + training_dataset = globals()[meta_setting["dataset"]](train = True,**dataset_setting) + test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting) + + training_dataloader = torch.utils.data.DataLoader(training_dataset,batch_size=batchsize,shuffle=True) + test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize,shuffle=True) + + + net = globals()[meta_setting["modelname"]](**model_setting).to(device) + #net = torch.compile(net) + net = net.to(device) + + framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting) + + print(f"META SETTINGS:\n\n {meta_setting}\n\n") + print(f"DATASET SETTINGS:\n\n {dataset_setting}\n\n") + print(f"MODEL SETTINGS:\n\n {model_setting}\n\n") + print(f"FRAMEWORK SETTINGS:\n\n {framework_setting}\n\n") + print(f"TRAINING SETTINGS:\n\n {training_setting}\n\n") + + print("\n\nSTART TRAINING\n\n") + globals()[meta_setting["trainloop_function"]](model=framework,device=device, trainloader = training_dataloader, testloader = test_dataloader,safepath = f,**training_setting,) + print("\n\nFINISHED TRAINING\n\n") + + + +def sample_func(f): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + print(f"device: {device}\n\n") + print(f"folderpath: {f}\n\n") + + with open(f+"/meta_setting.json","r") as fp: + meta_setting = json.load(fp) + + with open(f+"/model_setting.json","r") as fp: + model_setting = json.load(fp) + + with open(f+"/framework_setting.json","r") as fp: + framework_setting = json.load(fp) + + with open(f+"/sampling_setting.json","r") as fp: + sampling_setting = json.load(fp) + + # init Unet + net = globals()[meta_setting["modelname"]](**model_setting).to(device) + #net = torch.compile(net) + net = net.to(device) + # init unconditional diffusion model + framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting) + + print(f"META SETTINGS:\n\n {meta_setting}\n\n") + print(f"MODEL SETTINGS:\n\n {model_setting}\n\n") + print(f"FRAMEWORK SETTINGS:\n\n {framework_setting}\n\n") + print(f"SAMPLING SETTINGS:\n\n {sampling_setting}\n\n") + + print("\n\nSTART SAMPLING\n\n") + globals()[meta_setting["sampling_function"]](model=framework,device=device ,**sampling_setting,) + print("\n\nFINISHED SAMPLING\n\n") + + + +def evaluate_func(f): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + print(f"device: {device}\n\n") + print(f"folderpath: {f}\n\n") + + with open(f+"/meta_setting.json","r") as fp: + meta_setting = json.load(fp) + + with open(f+"/model_setting.json","r") as fp: + model_setting = json.load(fp) + + with open(f+"/framework_setting.json","r") as fp: + framework_setting = json.load(fp) + + with open(f+"/evaluation_setting.json","r") as fp: + evaluation_setting = json.load(fp) + + with open(f+"/dataset_setting.json","r") as fp: + dataset_setting = json.load(fp) + + # load dataset + batchsize = meta_setting["batchsize"] + test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting) + #test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=len(test_dataset), shuffle=False) + test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize, shuffle=False) + + # init Unet + net = globals()[meta_setting["modelname"]](**model_setting).to(device) + #net = torch.compile(net) + net = net.to(device) + + # init unconditional diffusion model + framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting) + + print(f"META SETTINGS:\n\n {meta_setting}\n\n") + print(f"DATASET SETTINGS:\n\n {dataset_setting}\n\n") + print(f"MODEL SETTINGS:\n\n {model_setting}\n\n") + print(f"FRAMEWORK SETTINGS:\n\n {framework_setting}\n\n") + print(f"EVALUATION SETTINGS:\n\n {evaluation_setting}\n\n") + + print("\n\nSTART EVALUATION\n\n") + globals()[meta_setting["evaluation_function"]](model=framework, device=device, dataloader = test_dataloader,safepath = f,**evaluation_setting,) + print("\n\nFINISHED EVALUATION\n\n") + + + + + +if __name__ == '__main__': + + + print(sys.argv) + functions = {'train': train_func,'sample': sample_func,'evaluate': evaluate_func} + functions[sys.argv[1]](sys.argv[2]) + + + + diff --git a/models/ConditionalDiffusionModel.py b/models/ConditionalDiffusionModel.py new file mode 100644 index 0000000..4dbf5cb --- /dev/null +++ b/models/ConditionalDiffusionModel.py @@ -0,0 +1,463 @@ +import torch +from torch import nn +import torch.nn.functional as F +import math + +class CDM(nn.Module): + + def __init__(self, + net=None, + diffusion_steps = 50, + out_shape = (3,128,128), + conditional_shape = (3), + noise_schedule = 'linear', + beta_1 = 1e-4, + beta_T = 0.02, + alpha_bar_lower_bound = 0.9, + var_schedule = 'same', + kl_loss = 'simplified', + recon_loss = 'none', + guidance_score = 3, # for classifier-free guided diffusion + device=None): + ''' + net: U-Net + diffusion_steps: Length of the Markov chain + out_shape: Shape of the models's in- and output images + conditional_shape: Shape of the low resolution image the DM is conditioned on for Super Resolution + noise_schedule: Methods of initialization for the noise dist. variances, 'linear', 'cosine' or bounded_cosine + beta_1, beta_T: Variances for the first and last noise dist. (only for the 'linear' noise schedule) + alpha_bar_lower_bound: Upper bound for the varaince of the complete noise dist. (only for the 'cosine_bounded' noise schedule) + var_schedule: Options to initialize or learn the denoising dist. variances, 'same', 'true' + kl_loss: Choice between the mathematically correct 'weighted' or in practice most commonly used 'simplified' KL loss + recon_loss: Is 'none' to ignore the reconstruction loss or 'nll' to compute the negative log likelihood + ''' + super(CDM,self).__init__() + self.device = device + + # initialize the beta's, alpha's and alpha_bar's for the given noise schedule + if noise_schedule == 'linear': + beta, alpha, alpha_bar = self.linear_schedule(diffusion_steps, beta_1, beta_T, device=self.device) + elif noise_schedule == 'cosine': + beta, alpha, alpha_bar = self.cosine_schedule(diffusion_steps, device=self.device) + elif noise_schedule == 'cosine_bounded': + beta, alpha, alpha_bar = self.bounded_cosine_schedule(diffusion_steps, alpha_bar_lower_bound, device=self.device) + else: + raise ValueError('Unimplemented noise scheduler') + + # initialize the denoising varainces for the given varaince schedule + if var_schedule == 'same': + var = beta + elif var_schedule == 'true': + var = [beta[0]] + [((1-alpha_bar[t-1])/(1-alpha_bar[t]))*beta[t] for t in range (1,diffusion_steps)] + var = torch.tensor(var, device=self.device) + else: + raise ValueError('Unimplemented variance scheduler') + + # check for invalid kl_loss argument + if (kl_loss != 'simplified') & (kl_loss != 'weighted'): + raise ValueError("Unimplemented loss function") + + self.net = net + self.guidance_score = guidance_score + self.diffusion_steps = diffusion_steps + self.noise_schedule = noise_schedule + self.var_schedule = var_schedule + self.beta = beta + self.alpha = alpha + self.alpha_bar = alpha_bar + self.sqrt_1_minus_alpha_bar = torch.sqrt(1-alpha_bar) # for forward std + self.sqrt_alpha_bar = torch.sqrt(alpha_bar) # for forward mean + self.var = var + self.std = torch.sqrt(self.var) + self.kl_loss = kl_loss + self.recon_loss = recon_loss + self.out_shape = out_shape + self.conditional_shape = conditional_shape + # precomputed for efficiency reasons + self.noise_scaler = (1-alpha)/( self.sqrt_1_minus_alpha_bar) + self.mean_scaler = 1/torch.sqrt(self.alpha) + self.mse_weight = (self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar)) + + @staticmethod + def linear_schedule(diffusion_steps, beta_1, beta_T, device): + '''' + Function that returns the noise distribution hyperparameters for the linear schedule. + + Parameters: + diffusion_steps (int): Length of the Markov chain. + beta_1 (float): Variance of the first noise distribution. + beta_T (float): Variance of the last noise distribution. + + Returns: + beta (tensor): Linearly scaled from beta[0] = beta_1 to beta[-1] = beta_T, length is diffusion_steps. + alpha (tensor): Length is diffusion_steps. + alpha_bar (tensor): Length is diffusion_steps. + ''' + beta = torch.linspace(beta_1, beta_T, diffusion_steps,device=device) + alpha = 1 - beta + alpha_bar = torch.cumprod(alpha, dim=0) + return beta, alpha, alpha_bar + + @staticmethod + def cosine_schedule(diffusion_steps, device): + ''' + Function that returns the noise distribution hyperparameters for the cosine schedule. + From "Improved Denoising Diffusion Probabilistic Models" by Nichol and Dhariwal. + + Parameters: + diffusion_steps (int): Length of the Markov chain. + + Returns: + beta (tensor): Length is diffusion_steps. + alpha (tensor): Length is diffusion_steps. + alpha_bar (tensor): Follows a sigmoid-like curve with a linear drop-off in the middle. + Length is diffusion_steps. + ''' + cosine_0 = CDM.cosine(torch.tensor(0, device=device), diffusion_steps= diffusion_steps) + alpha_bar = [CDM.cosine(torch.tensor(t, device=device),diffusion_steps = diffusion_steps)/cosine_0 + for t in range(1, diffusion_steps+1)] + shift = [1] + alpha_bar[:-1] + beta = 1 - torch.div(torch.tensor(alpha_bar, device=device), torch.tensor(shift, device=device)) + beta = torch.clamp(beta, min =0, max = 0.999).to(device) #suggested by paper + alpha = 1 - beta + alpha_bar = torch.tensor(alpha_bar,device=device) + return beta, alpha, alpha_bar + + @staticmethod + def bounded_cosine_schedule(diffusion_steps, alpha_bar_lower_bound, device): + ''' + Function that returns the noise distribution hyperparameters for our experimental version of a + bounded cosine schedule. Benefits are still unproven. It still has a linear drop-off in alpha_bar, + but it's not sigmoidal and the betas are no longer smooth. + + Parameters: + diffusion_steps (int): Length of the Markov chain + + Returns: + beta (tensor): Length is diffusion_steps + alpha (tensor): Length is diffusion_steps + alpha_bar (tensor): Bounded between (alpha_bar_lower_bound, 1) with a linear drop-off in the middle. + Length is diffusion_steps + ''' + # get cosine alpha_bar (that range from 1 to 0) + _, _, alpha_bar = CDM.cosine_schedule(diffusion_steps, device) + # apply min max normalization on alpha_bar (range from lower_bound to 0.999) + min_val = torch.min(alpha_bar) + max_val = torch.max(alpha_bar) + alpha_bar = (alpha_bar - min_val) / (max_val - min_val) + alpha_bar = alpha_bar * (0.9999 - alpha_bar_lower_bound) + alpha_bar_lower_bound # for 0.9999=>beta_1 = 1e-4 + # recompute beta, alpha and alpha_bar + alpha_bar = alpha_bar.tolist() + shift = [1] + alpha_bar[:-1] + beta = 1 - torch.div(torch.tensor(alpha_bar, device = device), torch.tensor(shift, device=device)) + beta = torch.clamp(beta, min=0, max = 0.999) + beta = torch.tensor(sorted(beta), device = device) + alpha = 1 - beta + alpha_bar = torch.cumprod(alpha, dim=0) + return beta, alpha, alpha_bar + + @staticmethod + def cosine(t, diffusion_steps, s = 0.008): + ''' + Helper function that computes the cosine function from "Improved Denoising Diffusion Probabilistic Models" + by Nichol and Dhariwal, used for the cosine noise schedules. + + Parameters: + t (int): Current timestep + diffusion_steps (int): Length of the Markov chain + s (float): Offset value suggested by the paper. Should be chosen such that sqrt(beta[0]) ~ 1/127.5 + (for small T=50, this is not possible) + + Returns: + (numpy.float64): Value of the cosine function at timestep t + ''' + return (torch.cos((((t/diffusion_steps)+s)*math.pi)/((1+s)*2)))**2 + + + #### + # Important to note: Timesteps are adjusted to the range t in [1, diffusion_steps] akin to the paper + # equations, where x_0 denotes the input image and x_t the noised latent after adding noise t times. + # Both trajectories work on batches assuming shape=(batch_size, channels, height, width). + #### + + # Forward Trajectory Functions: + + @torch.no_grad() + def forward_trajectory(self, x_0, t = None): + ''' + Applies noise t times to each input image in the batch x_0. + + Parameters: + x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1] + t (tensor): Batch of timesteps, by default goes through full forward trajectory + + Returns: + x_T (tensor): Batch of noised images at timestep t + forward_noise (tensor): Batch of noise parameters from the noise distribution reparametrization used to draw x_T + ''' + if t is None: + t = torch.full((x_0.shape[0],), self.diffusion_steps, device = self.device) + elif torch.any(t == 0): + raise ValueError("The tensor 't' contains a timestep zero.") + forward_noise = torch.randn(x_0.shape, device = self.device) + x_T = self.noised_latent(forward_noise, x_0, t) + return x_T , forward_noise + + @torch.no_grad() + def noised_latent(self, forward_noise, x_0, t): + ''' + Given a batch of noise parameters, this function recomputes the batch of noised images at their respective timesteps t. + This allows us to avoid storing all the intermediate latents x_t along the forward trajectory. + + Parameters: + forward_noise (tensor): Batch of noise parameters from the noise distribution reparametrization used to draw x_t + x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1] + t (tensor): Batch of timesteps + + Returns: + x_t (tensor): Batch of noised images at timestep t + ''' + mean, std = self.forward_dist_param(x_0, t) + x_t = mean + std*forward_noise + return x_t + + @torch.no_grad() + def forward_dist_param(self, x_0, t): + ''' + Computes the parameters of the complete noise distribution. + + Parameters: + x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1] + t (tensor): Batch of timesteps + + Returns: + mean (tensor): Batch of means for the complete noise distribution for each image in the batch x_0 + std (tensor): Batch of std scalars for the complete noise distribution for each image in the batch x_0 + ''' + mean = self.sqrt_alpha_bar[t-1][:,None,None,None]*x_0 + std = self.sqrt_1_minus_alpha_bar[t-1][:,None,None,None] + return mean, std + + @torch.no_grad() + def single_forward_dist_param(self, x_t_1, t): + ''' + Computes the parameters of the individual noise distribution. + + Parameters: + x_t_1 (tensor): Batch of noised images at timestep t-1 + t (tensor): Batch of timesteps + + Returns: + mean (tensor): Batch of means for the individual noise distribution for each image in the batch x_t_1 + std (tensor): Batch of std scalars for the individual noise distribution for each image in the batch x_t_1 + ''' + mean = torch.sqrt(1-self.beta[t-1])[:,None,None,None]*x_t_1 + std = torch.sqrt(self.beta[t-1])[:,None,None,None] + return mean, std + + + # Reverse Trajectory Functions: + + def reverse_trajectory(self, x_t, t, y=None): + ''' + Draws a denoised images x_{t-1} by reparametrizing the denoising distribution at times t for the current noised + latents x_t. + + Parameters: + x_t (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1] + t (tensor): Batch of timestep + y (tensor): Batch of conditional information for each input image + + Returns: + x_t_1 (tensor): Batch of denoised images at timestep t-1 + ''' + noise = torch.randn(x_t.shape, device=self.device) + pred_noise = self.forward(x_t, t, y) + mean, std = self.reverse_dist_param(x_t, pred_noise, t) #self.forward(x_t, t, c) + x_t_1 = mean + std*noise + return x_t_1 + + def forward(self, x_t, t, y=None): + ''' + Passes the current noised images x_t and timesteps t through the U-Net in order to compute the + predicted noise, which is later used to determine the current denoising distribution parameters + (mean and std) in the reverse trajectory. + Since the CDM class is inheriting from the nn.Module class, this function is required to share + the name 'forward'. This naming scheme does not refer to the forward trajectory, but the forward + pass of the model itself, which concerns to the reverse trajectory. + + Parameters: + x_t (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1] + t (tensor): Batch of timesteps + y (tensor): Batch of conditional information for each input image + + Returns: + mean (tensor): Batch of means for the complete noise dist. for each image in the batch x_t + std (tensor): Batch of std scalars for the complete noise dist. for each image in the batch x_t + pred_noise (tensor): Predicted noise for each image in the batch x_t + ''' + pred_noise = self.net(x_t,t, y) + return pred_noise + + def reverse_dist_param(self, x_t, pred_noise, t): + ''' + Computes the parameters of the reverse denoising distribution at times t for a given predicted noise batch. + + Parameters: + pred_noise (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1] + t (tensor): Batch of timesteps + + Returns: + mean (tensor): Batch of means for the complete noise distribution for each image in the batch x_0 + std (tensor): Batch of std scalars for the complete noise distribution for each image in the batch x_0 + ''' + mean = self.mean_scaler[t-1][:,None,None,None]*(x_t - self.noise_scaler[t-1][:,None,None,None]*pred_noise) + std = self.std[t-1][:,None,None,None] + return mean, std + + + # Forward and Reverse Trajectory: + + @torch.no_grad() + def complete_trajectory(self, x_0, y=None): + ''' + Takes a batch of images and applies both trajectories sequentially, i.e. first adds noise to all + images along the forward chain and later removes the noise with the reverse chain. + This function will be used in the evaluation pipeline as a means to evaluate its performance on + how well it is able to reconstruct/recover the training images after applying the forward trajectory. + + Parameters: + x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1] + y (tensor): Batch of conditional information for each input image + + Returns: + x_0_recon (tensor): Batch of images given by the model reconstruction of x_0 + ''' + # apply forward trajectory + x_0_recon, _ = self.forward_trajectory(x_0) + # apply reverse trajectory + for t in reversed(range(1, self.diffusion_steps + 1)): + # draw noise used in the denoising dist. reparametrization + if t > 1: + noise = torch.randn(x_0_recon.shape, device=self.device) + else: + noise = torch.zeros(x_0_recon.shape, device=self.device) + # set timestep batch with all entries as t + t_batch = torch.full((x_0_recon.shape[0],), t ,device = self.device) + # get denoising dist. param + pred_noise = self.forward(x_0_recon, t_batch, y) + mean, std = self.reverse_dist_param(x_0_recon, pred_noise, t_batch) + # compute the drawn denoised latent at time t + x_0_recon = mean + std * noise + return x_0_recon + + + # Sampling Functions: + + @torch.no_grad() + def sample(self, batch_size = 10, x_T=None, y = None): + ''' + Samples batch_size images by passing a batch of randomly drawn noise parameters through the complete + reverse trajectory. The last denoising step is deterministic as suggested by the paper + "Denoising Diffusion Probabilistic Models" by Ho et al. + + Parameters: + batch_size (int): Number of images to be sampled/generated from the diffusion model + x_T (tensor): Input of the reverse trajectory. Batch of noised images usually drawn + from an isotropic Gaussian, but can be set manually if desired. + y (tensor): Batch of conditional information for each input image + + Returns: + x_t_1 (tensor): Batch of sampled/generated images + ''' + # start with a batch of isotropic noise images (or given arguemnt) + if x_T: + x_t_1 = x_T + else: + x_t_1 = torch.randn((batch_size,)+tuple(self.out_shape), device=self.device) + # apply reverse trajectory + for t in reversed(range(1, self.diffusion_steps+1)): + # draw noise used in the denoising dist. reparametrization + if t>1: + noise = torch.randn(x_t_1.shape, device=self.device) + else: + noise = torch.zeros(x_t_1.shape, device=self.device) + # set timestep batch with all entries as t + t_batch = torch.full((x_t_1.shape[0],), t ,device = self.device) + # get classififer-free guided diffusion noise parameter + pred_noise_cond = self.forward(x_t_1, t_batch, y) # param with conditioning + pred_noise_uncond = self.forward(x_t_1, t_batch, y=None) # param without conditioning + # linear interpolation of the two + pred_noise = torch.lerp(pred_noise_uncond, pred_noise_cond, self.guidance_score) + mean, std = self.reverse_dist_param(x_t_1, pred_noise, t_batch) + # compute the drawn densoined latent at time t + x_t_1 = mean + std*noise + return x_t_1 + + @torch.no_grad() + def sample_intermediates_latents(self, y): + ''' + Samples a single image and provides all intermediate denoised images that were drawn along the reverse + trajectory. The last denoising step is deterministic as suggested by the paper "Denoising Diffusion + Probabilistic Models" by Ho et al. + + Parameters: + y (tensor): Batch of conditional information for each input image + + Returns: + x (tensor): Contains the self.diffusion_steps+1 denoised image tensors + ''' + # start with an image of pure noise (batch_size 1) and store it as part of the output + x_t_1 = torch.randn((1,) + tuple(self.out_shape), device=self.device) + x = torch.empty((self.diffusion_steps+1,) + tuple(self.out_shape), device=self.device) + x[-1] = x_t_1.squeeze(0) + # apply reverse trajectory + for t in reversed(range(1, self.diffusion_steps+1)): + # draw noise used in the denoising dist. reparametrization + if t>1: + noise = torch.randn(x_t_1.shape, device=self.device) + else: + noise = torch.zeros(x_t_1.shape, device=self.device) + # set timestep batch with all entries as t + t_batch = torch.full((x_t_1.shape[0],), t ,device = self.device) + # get classififer-free guided diffusion noise parameter + pred_noise_cond = self.forward(x_t_1, t_batch, y) # param with conditioning + pred_noise_uncond = self.forward(x_t_1, t_batch, y=None) # param without conditioning + # linear interpolation of the two + pred_noise = torch.lerp(pred_noise_uncond, pred_noise_cond, self.guidance_score) + mean, std = self.reverse_dist_param(x_t_1, pred_noise, t_batch) + # compute the drawn densoined latent at time t + x_t_1 = mean + std*noise + # store noised image + x[t-1] = x_t_1.squeeze(0) + #x_sq = x.squeeze(1) + #return x_sq + return x + + + # Loss functions + + def loss_simplified(self, forward_noise, pred_noise, t=None): + ''' + Returns the Mean Squared Error (MSE) between the forward_noise used to compute the noised images x_t + along the forward trajectory and the predicted noise computed by the U-Net with the noised images + x_t and timestep t. + ''' + return F.mse_loss(forward_noise, pred_noise) + + + def loss_weighted(self, forward_noise, pred_noise, t): + ''' + Returns the mathematically correct weighted version of the simplified loss. + ''' + return self.mse_weight[t-1][:,None,None,None]*F.mse_loss(forward_noise, pred_noise) + + + # If t=0 and self.recon_loss == 'nll' + def loss_recon(self, x_0, mean_1, std_1): + ''' + Returns the reconstruction loss given by the mean negative log-likelihood of x_0 under the last + denoising Gaussian distribution with mean mean_1 and standard deviation std_1. + ''' + return -torch.distributions.Normal(mean_1, std_1).log_prob(x_0).mean() + diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/conditional_unet.py b/models/conditional_unet.py new file mode 100644 index 0000000..8f59e64 --- /dev/null +++ b/models/conditional_unet.py @@ -0,0 +1,256 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms +import einops +import numpy as np + +# U-Net model +class Conditional_UNet_Res(nn.Module): + + def __init__(self, attention,channels_in=3, nr_class=3,n_channels=64,fctr = [1,2,4,4,8],time_dim=256,**args): + """ + attention : (Bool) wether to use attention layers or not + channels_in : (Int) + n_channels : (Int) Channel size after first convolution + fctr : (list) list of factors for further channel size wrt n_channels + time_dim : (Int) dimenison size for time end class embeding vector + """ + super().__init__() + channels_out = channels_in + fctr = np.asarray(fctr)*n_channels + + # learned time embedding + self.time_embedder = TimeEmbedding(time_dim = time_dim) + # learned class embedding + self.class_embedder = nn.Embedding(nr_class, time_dim) #akin to the OpenAI GLIDE Diffusion Model + + # learnable embedding layers + self.tc_embedder0 = torch.nn.Sequential(nn.Linear(time_dim,fctr[0]),nn.SELU(),nn.Linear(fctr[0],fctr[0])) + self.tc_embedder1 = torch.nn.Sequential(nn.Linear(time_dim,fctr[1]),nn.SELU(),nn.Linear(fctr[1],fctr[1])) + self.tc_embedder2 = torch.nn.Sequential(nn.Linear(time_dim,fctr[2]),nn.SELU(),nn.Linear(fctr[2],fctr[2])) + self.tc_embedder3 = torch.nn.Sequential(nn.Linear(time_dim,fctr[3]),nn.SELU(),nn.Linear(fctr[3],fctr[3])) + self.tc_embedder4 = torch.nn.Sequential(nn.Linear(time_dim,fctr[4]),nn.SELU(),nn.Linear(fctr[4],fctr[4])) + + # first conv block + self.first_conv = nn.Conv2d(channels_in,fctr[0],kernel_size=3, padding='same', bias=True) + + #down blocks + self.down1 = DownsampleBlock_Res(fctr[0],fctr[1],time_dim) + self.down2 = DownsampleBlock_Res(fctr[1],fctr[2],time_dim) + self.down3 = DownsampleBlock_Res(fctr[2],fctr[3],time_dim,attention=attention) + self.down4 = DownsampleBlock_Res(fctr[3],fctr[4],time_dim,attention=attention) + + #middle layer + self.mid1 = MidBlock_Res(fctr[4],time_dim,attention=attention) + + #up blocks + self.up1 = UpsampleBlock_Res(fctr[1],fctr[0],time_dim) + self.up2 = UpsampleBlock_Res(fctr[2],fctr[1],time_dim) + self.up3 = UpsampleBlock_Res(fctr[3],fctr[2],time_dim,attention=attention) + self.up4 = UpsampleBlock_Res(fctr[4],fctr[3],time_dim,attention=attention) + + # final 1x1 conv + self.end_conv = nn.Conv2d(fctr[0], channels_out, kernel_size=1,bias=True) + + # Attention Layers + self.mha21 = MHABlock(fctr[2]) + self.mha22 = MHABlock(fctr[2]) + self.mha31 = MHABlock(fctr[3]) + self.mha32 = MHABlock(fctr[3]) + self.mha41 = MHABlock(fctr[4]) + self.mha42 = MHABlock(fctr[4]) + + def forward(self, input, t, y): + # compute time mebedding + t_emb = self.time_embedder(t).to(input.device) + # compute class embedding if present + if y is not None: + c_emb = self.class_embedder(y).to(input.device) + else: + c_emb = torch.zeros_like(t_emb).to(input.device) + # combine both embeddings + tc_emb = t_emb + c_emb + + # learnable layers + tc_emb0 = self.tc_embedder0(tc_emb) + tc_emb1 = self.tc_embedder1(tc_emb) + tc_emb2 = self.tc_embedder2(tc_emb) + tc_emb3 = self.tc_embedder3(tc_emb) + tc_emb4 = self.tc_embedder4(tc_emb) + + # first two conv layers + x = self.first_conv(input) + tc_emb0[:,:,None,None] + + #time and class mb + skip1,x = self.down1(x,tc_emb1) + skip2,x = self.down2(x,tc_emb2) + skip3,x = self.down3(x,tc_emb3) + skip4,x = self.down4(x,tc_emb4) + + x = self.mid1(x,tc_emb4) + + x = self.up4(x,skip4,tc_emb3) + x = self.up3(x,skip3,tc_emb2) + x = self.up2(x,skip2,tc_emb1) + x = self.up1(x,skip1,tc_emb0) + x = self.end_conv(x) + + return x + + + +#TimeEmbedding +class TimeEmbedding(nn.Module): + + def __init__(self, time_dim=64): + super().__init__() + + self.time_dim = time_dim + n = 10000 + self.factor = torch.pow(n*torch.ones(size=(time_dim//2,)),(-2/time_dim*torch.arange(time_dim//2))) + + def forward(self, t): + """ + input is t (B,) + factor dim (time_dim,) + output is (B,time_dim) + """ + self.factor = self.factor.to(t.device) + theta = torch.outer(t,self.factor) + + + # shape of embedding [time_channels, dim] + emb = torch.zeros(t.size(0), self.time_dim,device=t.device) + emb[:, 0::2] = torch.sin(theta) + emb[:, 1::2] = torch.cos(theta) + + return emb + +# Self Attention +class MHABlock(nn.Module): + + def __init__(self, + channels_in, + num_attention_heads=1 # number of attention heads in MHA + ): + super().__init__() + + self.channels_in = channels_in + self.num_attention_heads = num_attention_heads + self.self_attention = nn.MultiheadAttention(channels_in, num_heads=self.num_attention_heads) + + def forward(self, x): + skip = x + batch_size,_,height,width = x.size() + + x = x.permute(2, 3, 0, 1).reshape(height * width, batch_size, -1) + attn_output, _ = self.self_attention(x, x, x) + attn_output = attn_output.reshape(batch_size, -1, height, width) + + return attn_output+skip + +# Residual Convolution Block +class ConvBlock_Res(nn.Module): + + def __init__(self, + channels_in, # number of input channels fed into the block + channels_out, # number of output channels produced by the block + time_dim, + attention, + num_groups=32, # number of groups used in Group Normalization; channels_in must be divisible by num_groups + ): + super().__init__() + + self.attention = attention + if self.attention: + self.attlayer = MHABlock(channels_in=channels_out) + + # Convolution layer 1 + self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding='same', bias=False) + self.gn1 = nn.GroupNorm(num_groups, channels_out) + self.act1 = nn.SiLU() + + # Convolution layer 2 + self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding='same', bias=False) + self.gn2 = nn.GroupNorm(num_groups, channels_out) + self.act2 = nn.SiLU() + + # Convolution layer 3 + self.conv3 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding='same', bias=False) + self.gn3 = nn.GroupNorm(num_groups, channels_out) + self.act3 = nn.SiLU() + + #Convolution skip + if channels_in!=channels_out: + self.res_skip = nn.Conv2d(channels_in,channels_out,kernel_size=1) + else: + self.res_skip = nn.Identity() + + nn.init.xavier_normal_(self.conv1.weight) + nn.init.xavier_normal_(self.conv2.weight) + nn.init.xavier_normal_(self.conv3.weight) + + def forward(self, x, t): + res = self.res_skip(x) + # second convolution layer + x = self.act1(self.gn1(self.conv1(x))) + + + h =x + t[:,:,None,None] + + + # third convolution layer + h = self.act2(self.gn2(self.conv2(h))) + + h = self.act3(self.gn3(self.conv3(h))) + + if self.attention: + h = self.attlayer(h) + + return h +res + +# Down Sample +class DownsampleBlock_Res(nn.Module): + + def __init__(self, channels_in, channels_out,time_dim,attention=False): + super().__init__() + + + self.pool = nn.MaxPool2d((2,2), stride=2) + self.convblock = ConvBlock_Res(channels_in, channels_out,time_dim,attention=attention) + + def forward(self, x, t): + + x = self.convblock(x, t) + h = self.pool(x) + return x,h + +# Upsample Block +class UpsampleBlock_Res(nn.Module): + + def __init__(self, channels_in, channels_out,time_dim,attention=False): + super().__init__() + + self.upconv = nn.ConvTranspose2d(channels_in, channels_in, kernel_size=2, stride=2) + self.convblock = ConvBlock_Res(channels_in, channels_out,time_dim,attention=attention) + + def forward(self, x, skip_x, t): + x = self.upconv(x) + + # skip-connection - merge features from contracting path to its symmetric counterpart in expansive path + out = x + skip_x + + out = self.convblock(out, t) + return out + +# Middle Block +class MidBlock_Res(nn.Module): + def __init__(self,channels,time_dim,attention=False): + super().__init__() + self.convblock1 = ConvBlock_Res(channels,channels,time_dim,attention=attention) + self.convblock2 = ConvBlock_Res(channels,channels,time_dim,attention=False) + def forward(self,x,t): + x = self.convblock1(x,t) + return self.convblock2(x,t) + diff --git a/trainer/__init__.py b/trainer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trainer/train.py b/trainer/train.py new file mode 100644 index 0000000..ac5dce8 --- /dev/null +++ b/trainer/train.py @@ -0,0 +1,321 @@ +import numpy as np +import copy +import torch +from torch import nn +from torchvision import datasets,transforms +from torch.utils.data import DataLoader +import matplotlib.pyplot as plt +import numpy as np +import torch.nn.functional as F +import os +import wandb +from copy import deepcopy + +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +# Simple Training function for the unconditional diffusion model +def simple_trainer(model,device,epochs,trainloader,testloader,bs,lr,T,criterion = nn.MSELoss()): + criterion.to(device) + optimizer = torch.optim.AdamW(model.parameters(),lr=lr,) + + for epoch in range(epochs): + model.train() + running_trainloss = [] + running_testloss = [] + for idx,(x,_) in enumerate(trainloader): + x = x.to(device) # has to go to device + t = torch.randint(low=0,high=T,size=(1,)).item() # doesn't have to go to device + x_t,forward_noise = model.forward_trajectory(x,t) + optimizer.zero_grad() + mean,std,pred_noise = model.forward(x_t,t) # changed to forward sinnce model is a NN module + + loss = criterion(forward_noise,pred_noise) + loss.backward() + optimizer.step() + trainstep = epoch*bs+idx + running_trainloss.append(loss.cpu().item()) # MUST be on cou before appending to list + + + print(f"Trainloss in epoch {epoch}:{np.mean(running_trainloss)}") + + + model.eval() + with torch.no_grad(): + for idx,(x,_) in enumerate(testloader): + x = x.to(device) + t = torch.randint(low=0,high=T,size=(1,)).item() + x_t,forward_noise = model.forward_trajectory(x,t) + optimizer.zero_grad() + mean,std,pred_noise = model.forward(x_t,t) + loss = criterion(forward_noise,pred_noise) + running_testloss.append(loss.cpu().item()) + + print(f"Testloss in step {epoch} :{np.mean(running_testloss)}") + + +# EMA class +# Important! This EMA class code is not ours and was taken from the Pytorch Image Models library called timm and performs exponential moving +# average on the trained weights for a given models neural net which was suggested by the paper "Improved Denoising Diffusion Probabilistic Models" +# by Nichol and Dhariwal to stabilize and improve the training and generalization process. +# https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/model_ema.py +class ModelEmaV2(nn.Module): + def __init__(self, model, decay=0.9999, device=None): + super(ModelEmaV2, self).__init__() + # make a copy of the model for accumulating moving average of weights + self.module = deepcopy(model) + self.module.eval() + self.decay = decay + self.device = device # perform ema on different device from model if set + if self.device is not None: + self.module.to(device=device) + + def _update(self, model, update_fn): + with torch.no_grad(): + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + if self.device is not None: + model_v = model_v.to(device=self.device) + ema_v.copy_(update_fn(ema_v, model_v)) + + def update(self, model): + self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + + def set(self, model): + self._update(model, update_fn=lambda e, m: m) + + + +# Training function for the unconditional diffusion model + +def cdm_trainer(model, + device, + trainloader, testloader, + store_iter = 10, + eval_iter = 10, + epochs = 50, + optimizer_class=torch.optim.AdamW, + optimizer_params=None, + learning_rate = 0.001, + verbose = False, + run_name=None, + checkpoint= None, + experiment_path = None, + T_max = 5*10000, # None, + eta_min= 1e-5, + ema_training = True, + decay = 0.9999, + **args + ): + ''' + model: Properly initialized DDPM model. + store_iter: Stores the trained DDPM every store_iter epochs. + experiment_path: Path to the models experiment folder, where the trained model will be stored every store_iter epochs + eval_iter: Evaluates the trained DDPM on testing data every eval_iter epochs. + epochs: Number of epochs we train the model further. + optimizer_class: PyTorch optimizer. + optimizer_param: Parameters for the PyTorch optimizer. + learning_rate: For optimizer initialization when training from zero, i.e. no checkpoint + verbose: If True, prints the running losses for every epoch. + run_name: Run name for WandB. IF YOU TRAIN FROM CHECKPOINT MAKE SURE TO USE THE SAME + 'run_name' FOR THE DATA TO BE LOGGED ON THE SAME WANDB RUN! + trainloader: Loads the train dataset + testloader: Loads the test dataset + checkpoint: Name of the saved pth. file containing the trained weights and biases + T_max: CosineAnnealingLR scheduler argument (nr of steps in training for a full cycle) + eta_min: CosineAnnealingLR scheduler argument (scheduler oscillates between highest lr 'leraning_rate' and minimum lr 'eta_min') + decay: EMA decay rate that is used to weight the effect of the ema model when computing the weighted avg between trained and + ema weights for the networks weight update + ''' + + # set optimizer parameters and learning rate + if optimizer_params is None: + optimizer_params = dict(lr=learning_rate) + optimizer = optimizer_class(model.net.parameters(), **optimizer_params) + + # set lr cosine schedule (comonly used in diffusion models) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min) + + # set ema model for training + if ema_training: + ema = ModelEmaV2(model, decay=decay, device = model.device) + + # if checkpoint path is given, load the model from checkpoint + last_epoch = -1 + if checkpoint: + try: + checkpoint_path = f'{experiment_path}trained_cdm/{checkpoint}' + # Load the checkpoint + checkpoint = torch.load(checkpoint_path) + # update last_epoch + last_epoch = checkpoint['epoch'] + # load weights and biases of the U-Net + model_state_dict = checkpoint['model'] + model.net.load_state_dict(model_state_dict) + model = model.to(device) + # load optimizer state + optimizer_state_dict = checkpoint['optimizer'] + optimizer.load_state_dict(optimizer_state_dict) + # load learning rate schedule state + scheduler_state_dict = checkpoint['scheduler'] + scheduler.load_state_dict(scheduler_state_dict) + scheduler.last_epoch = last_epoch + # load ema model state + if ema_training: + ema.module.load_state_dict(checkpoint['ema']) + except Exception as e: + print("Error loading checkpoint. Exception: ", e) + + # pick kl loss function + if model.kl_loss == 'weighted': + loss_func = model.loss_weighted + else: + loss_func = model.loss_simplified + + # pick lowest timestep + low = 1 + if model.recon_loss == 'nll': + low = 0 + + # Using W&B + with wandb.init(project='Unconditional Landscapes', name=run_name, entity='deep-lab-', id=run_name, resume=True) as run: + + # Log some info + run.config.learning_rate = learning_rate + #run.config.update({"learning_rate": learning_rate}, allow_val_change=True) + run.config.optimizer = optimizer.__class__.__name__ + #run.watch(model.net) + + # training loop + # last model was stored at epoch last_epoch, we continue training from there, i.e. last_epoch+1 (else we start at epoch 0) + for epoch in range(last_epoch+1, (last_epoch+1)+epochs): + running_trainloss = 0 + nr_train_batches = 0 + + # train + model.net.train() + for idx,(x_0, y) in enumerate(trainloader): + x_0 = x_0.to(device) + y = y.to(device) + t = torch.randint(low=low, high=model.diffusion_steps, size=(x_0.shape[0],), device = device) + optimizer.zero_grad() + + # define masks for zero and non-zero elements of t + mask_zero_t = (t == 0) + mask_non_zero_t = (t != 0) + t[mask_zero_t] = 1 + + # apply noise + x_t, forward_noise = model.forward_trajectory(x_0,t) + + # compute denoising step at time t under CFGD + rand_prob = torch.rand(x_0.shape[0]).to(device) + mask_condition = (rand_prob <= 0.9) + pred_noise = torch.zeros_like(x_t).to(device) + # for every image with porb. of 90% we apply forward pass cond. on class, 10% prob. without class + if torch.any(mask_condition): + pred_noise[mask_condition] = model.forward(x_t[mask_condition], t[mask_condition], y = y[mask_condition]) + if torch.any(~mask_condition): + pred_noise[~mask_condition] = model.forward(x_t[~mask_condition], t[~mask_condition], y=None) + mean, std = model.reverse_dist_param(x_t, pred_noise, t) + + loss = 0 + # compute kl loss + if torch.any(mask_non_zero_t): + loss = loss_func(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t], t[mask_non_zero_t]) + running_trainloss += loss.item() + nr_train_batches += 1 + run.log({'loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx}) + + # if reconstrcution loss was drawn + if torch.any(mask_zero_t): + recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t]) + loss += recon_loss + run.log({'recon_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx}) + + loss.backward() + optimizer.step() + if ema_training: + ema.update(model) + scheduler.step() + + if verbose: + print(f"Loss in epoch {epoch}:{running_trainloss/nr_train_batches}") + run.log({'running_loss': running_trainloss/nr_train_batches}) + + # evaluation + if ((epoch+1) % eval_iter == 0) or ((epoch+1) % store_iter == 0): + running_testloss = 0 + nr_test_batches = 0 + + model.net.eval() + with torch.no_grad(): + for idx,(x_0,y) in enumerate(testloader): + x_0 = x_0.to(device) + y=y.to(device) + t = torch.randint(low=low, high=model.diffusion_steps, size=(x_0.shape[0],), device = device) + + # Define masks for zero and non-zero elements of t + mask_zero_t = (t == 0) + mask_non_zero_t = (t != 0) + t[mask_zero_t] = 1 + + # apply noise + x_t, forward_noise = model.forward_trajectory(x_0,t) + + # compute denoising step at time t under CFGD + rand_prob = torch.rand(x_0.shape[0]) + mask_condition = (rand_prob <= 0.9) + pred_noise = torch.zeros_like(x_t) + # for every image with porb. of 90% we apply forward pass cond. on class, 10% prob. without class + if torch.any(mask_condition): + pred_noise[mask_condition] = model.forward(x_t[mask_condition], t[mask_condition], y = y[mask_condition]) + if torch.any(~mask_condition): + pred_noise[~mask_condition] = model.forward(x_t[~mask_condition], t[~mask_condition], y=None) + mean, std = model.reverse_dist_param(x_t, pred_noise, t) + + loss = 0 + # Compute kl loss + if torch.any(mask_non_zero_t): + loss = loss_func(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t], t[mask_non_zero_t]) + running_testloss += loss.item() + nr_test_batches += 1 + run.log({'test_loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx}) + + # If reconstrcution loss was drawn + if torch.any(mask_zero_t): + recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t]) + loss += recon_loss + run.log({'recon_test_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx}) + + if verbose: + print(f"Test loss in epoch {epoch}:{running_testloss/nr_test_batches}") + run.log({'running_test_loss': running_testloss/nr_test_batches}) + + # store model + if ((epoch+1) % store_iter == 0): + save_dir = os.path.join(experiment_path, 'trained_cdm/') + os.makedirs(save_dir, exist_ok=True) + torch.save({ + 'epoch': epoch, + 'model': model.net.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scheduler': scheduler.state_dict(), + 'ema' : ema.module.state_dict(), + 'running_loss': running_trainloss/nr_train_batches, + 'running_test_loss': running_testloss/nr_test_batches, + }, os.path.join(save_dir, f"model_epoch_{epoch}.pth")) + + # always store the last version of the model if we trained through all epochs + final = (last_epoch+1)+epochs + save_dir = os.path.join(experiment_path, 'trained_cdm/') + os.makedirs(save_dir, exist_ok=True) + torch.save({ + 'epoch': final, + 'model': model.net.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scheduler': scheduler.state_dict(), + 'ema' : ema.module.state_dict(), + 'running_loss': running_trainloss/nr_train_batches, + 'running_test_loss': running_testloss/nr_test_batches, + }, os.path.join(save_dir, f"model_epoch_{final}.pth")) + + -- GitLab