Commit fb2e6b86 authored by Gonzalo Martin Garcia

Completed the class-conditional diffusion model. This version is intended to be trained on a class-labeled dataset. It uses classifier-free guided diffusion to boost the sample quality of the model when generating images from each class. The conditioning mechanism in the UNet (implemented together with Roy) simply adds an embedding of the class to the time embedding before passing them through the learnable reshaping layers of each block. Successfully trained on a three-class dataset: cats, dogs, and wildlife.
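For reference, the classifier-free guidance combination applied at sampling time can be sketched as follows. This is illustrative only: the guidance weight `w` and the combination rule follow the standard classifier-free guidance formulation, not code from this commit, and only `model.forward(x, t, y)` matches the U-Net interface used in this repo.

``` python
import torch

def guided_noise(model, x_t, t, y, w=2.0):
    """Classifier-free guidance sketch: blend conditional and unconditional predictions.
    `w` is a hypothetical guidance weight (w=0 recovers plain conditional sampling)."""
    eps_cond = model.forward(x_t, t, y=y)       # noise prediction with the class label
    eps_uncond = model.forward(x_t, t, y=None)  # unconditional prediction (label dropped)
    return (1 + w) * eps_cond - w * eps_uncond
```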
parent 80f54839
.DS_Store
*/__pycache__
*/trained_ddpm
root
experiments
trainer/__pycache__
wandb
# Conditional Diffusion
## Getting started
To make it easy for you to get started with GitLab, here's a list of recommended next steps.
Already a pro? Just edit this README.md and make it your own. Want to make it easy? [Use the template at the bottom](#editing-this-readme)!
## Add your files
- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files
- [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command:
```
cd existing_repo
git remote add origin https://git.rwth-aachen.de/diffusion-project/conditional-diffusion.git
git branch -M main
git push -uf origin main
```
## Integrate with your tools
- [ ] [Set up project integrations](https://git.rwth-aachen.de/diffusion-project/conditional-diffusion/-/settings/integrations)
## Collaborate with your team
- [ ] [Invite team members and collaborators](https://docs.gitlab.com/ee/user/project/members/)
- [ ] [Create a new merge request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html)
- [ ] [Automatically close issues from merge requests](https://docs.gitlab.com/ee/user/project/issues/managing_issues.html#closing-issues-automatically)
- [ ] [Enable merge request approvals](https://docs.gitlab.com/ee/user/project/merge_requests/approvals/)
- [ ] [Automatically merge when pipeline succeeds](https://docs.gitlab.com/ee/user/project/merge_requests/merge_when_pipeline_succeeds.html)
## Test and Deploy
Use the built-in continuous integration in GitLab.
- [ ] [Get started with GitLab CI/CD](https://docs.gitlab.com/ee/ci/quick_start/index.html)
- [ ] [Analyze your code for known vulnerabilities with Static Application Security Testing (SAST)](https://docs.gitlab.com/ee/user/application_security/sast/)
- [ ] [Deploy to Kubernetes, Amazon EC2, or Amazon ECS using Auto Deploy](https://docs.gitlab.com/ee/topics/autodevops/requirements.html)
- [ ] [Use pull-based deployments for improved Kubernetes management](https://docs.gitlab.com/ee/user/clusters/agent/)
- [ ] [Set up protected environments](https://docs.gitlab.com/ee/ci/environments/protected_environments.html)
***
# Editing this README
When you're ready to make this README your own, just edit this file and use the handy template below (or feel free to structure it however you want - this is just a starting point!). Thank you to [makeareadme.com](https://www.makeareadme.com/) for this template.
## Suggestions for a good README
Every project is different, so consider which of these sections apply to yours. The sections used in the template are suggestions for most open source projects. Also keep in mind that while a README can be too long and detailed, too long is better than too short. If you think your README is too long, consider utilizing another form of documentation rather than cutting out information.
## Name
Choose a self-explaining name for your project.
## Description
Let people know what your project can do specifically. Provide context and add a link to any reference visitors might be unfamiliar with. A list of Features or a Background subsection can also be added here. If there are alternatives to your project, this is a good place to list differentiating factors.
## Badges
On some READMEs, you may see small images that convey metadata, such as whether or not all the tests are passing for the project. You can use Shields to add some to your README. Many services also have instructions for adding a badge.
## Visuals
Depending on what you are making, it can be a good idea to include screenshots or even a video (you'll frequently see GIFs rather than actual videos). Tools like ttygif can help, but check out Asciinema for a more sophisticated method.
## Installation
Within a particular ecosystem, there may be a common way of installing things, such as using Yarn, NuGet, or Homebrew. However, consider the possibility that whoever is reading your README is a novice and would like more guidance. Listing specific steps helps remove ambiguity and gets people using your project as quickly as possible. If it only runs in a specific context like a particular programming language version or operating system or has dependencies that have to be installed manually, also add a Requirements subsection.
## Usage
Use examples liberally, and show the expected output if you can. It's helpful to have inline the smallest example of usage that you can demonstrate, while providing links to more sophisticated examples if they are too long to reasonably include in the README.
## Support
Tell people where they can go to for help. It can be any combination of an issue tracker, a chat room, an email address, etc.
## Roadmap
If you have ideas for releases in the future, it is a good idea to list them in the README.
## Contributing
State if you are open to contributions and what your requirements are for accepting them.
For people who want to make changes to your project, it's helpful to have some documentation on how to get started. Perhaps there is a script that they should run or some environment variables that they need to set. Make these steps explicit. These instructions could also be useful to your future self.
You can also document commands to lint the code or run tests. These steps help to ensure high code quality and reduce the likelihood that the changes inadvertently break something. Having instructions for running tests is especially helpful if it requires external setup, such as starting a Selenium server for testing in a browser.
## Authors and acknowledgment
Show your appreciation to those who have contributed to the project.
## License
For open source projects, say how it is licensed.
## Project status
If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers.
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import os
from PIL import Image
import pandas as pd
import numpy as np
class ConditionalDataset(Dataset):
def __init__(self, fpath, img_size, train, frac=0.8, skip_first_n=0, ext=".png", transform=True):
"""
Args:
fpath (string): Path to the folder where the images are stored
img_size (int): Size of the output image, img_size = height = width
train (Bool): If True, load the train split, otherwise the test split (taken from the 'train'/'valid' subfolders)
frac (float): Value in (0,1]; fraction for a (seeded) random shuffle and train/test division (kept for interface compatibility, unused here)
skip_first_n (int): Skips the first n values. Useful for datasets that are sorted by increasing likelihood (kept for interface compatibility, unused here)
ext (string): File type of the images (e.g. ".png")
transform (Bool): If True, apply the image augmentations used for diffusion training; otherwise apply the evaluation transform
"""
if train:
fpath = os.path.join(fpath, 'train')
else:
fpath = os.path.join(fpath, 'valid')
self.class_to_idx = {'cat': 0, 'dog': 1, 'wild': 2}
file_list =[]
class_list = []
for root, dirs, files in os.walk(fpath, topdown=False):
for name in sorted(files):
if name.endswith(ext):
file_list.append(os.path.join(root, name))
class_list.append(self.class_to_idx[os.path.basename(root)])
self.df = pd.DataFrame({"Filepath": file_list})
self.class_list = class_list
if transform:
# for training
intermediate_size = 137
theta = np.pi/4 -np.arccos(intermediate_size/(np.sqrt(2)*img_size)) #Check dataloading.ipynb in analysis-depot for more details
transform_rotate_flip = transforms.Compose([transforms.ToTensor(),
transforms.Resize(intermediate_size,antialias=True),
transforms.RandomRotation((theta/np.pi*180),interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(img_size),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
transform_flip = transforms.Compose([transforms.ToTensor(),
transforms.Resize(img_size, antialias=True),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
self.transform = transforms.RandomChoice([transform_rotate_flip,transform_flip])
else :
# for evaluation
self.transform = transforms.Compose([transforms.ToTensor(),
transforms.Lambda(lambda x: (x * 255).type(torch.uint8)),
transforms.Resize(img_size)])
def __len__(self):
return len(self.df)
def __getitem__(self,idx):
path = self.df.iloc[idx].Filepath
img = Image.open(path)
class_idx = self.class_list[idx]
return self.transform(img), class_idx
def tensor2PIL(self,img):
back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()])
return back2pil(img)
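A minimal usage sketch for the dataset class, assuming an AFHQ-style folder layout with 'train'/'valid' subfolders that each contain 'cat', 'dog' and 'wild' directories; the path is a placeholder:

``` python
from torch.utils.data import DataLoader

# hypothetical dataset root containing train/ and valid/ subfolders
train_ds = ConditionalDataset(fpath="/path/to/afhq", img_size=128, train=True)
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
x, y = next(iter(train_dl))  # x: (32, 3, 128, 128) normalized to [-1, 1]; y: class indices in {0, 1, 2}
```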
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
import re
import os
from PIL import Image
from torchvision import transforms
import torch
def cdm_evaluator(model,
device,
dataloader,
checkpoint,
experiment_path,
sample_idx=0,
**args,
):
'''
Takes a trained diffusion model from 'checkpoint' and evaluates its performance on the test
dataset 'dataloader' w.r.t. the three most important performance metrics: FID, IS and KID. This
function evolved from our evaluation function for the LDM upscaler and may be updated accordingly.
checkpoint: Name of the saved .pth file containing the trained weights and biases
experiment_path: Path to the experiment folder where the evaluation results will be stored
dataloader: Loads the test dataset for evaluation
sample_idx: Integer denoting which sample directory sample_{sample_idx} of the checkpoint model shall be used for evaluation
'''
checkpoint_path = f'{experiment_path}trained_cdm/{checkpoint}'
# create evaluation directory for the complete experiment (if first time sampling images)
output_dir = f'{experiment_path}evaluations/'
os.makedirs(output_dir, exist_ok=True)
# create evaluation directory for the current version of the trained model
model_name = os.path.basename(checkpoint_path)
epoch = re.findall(r'\d+', model_name)
if epoch:
e = int(epoch[0])
else:
raise ValueError(f"No digit found in the filename: {model_name}")
model_dir = os.path.join(output_dir,f'epoch_{e}')
os.makedirs(model_dir, exist_ok=True)
# create the evaluation directory for this evaluation run for the current version of the model
eval_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))]
indx_list = [int(d.split('_')[1]) for d in eval_dir_list if d.startswith('evaluation_')]
j = max(indx_list, default=-1) + 1
eval_dir = os.path.join(model_dir, f'evaluation_{j}')
os.makedirs(eval_dir, exist_ok=True)
# Compute metrics
eval_path = os.path.join(eval_dir, 'eval.txt')
# get sampled images
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (x * 255).type(torch.uint8))])
sample_path = os.path.join(f'{experiment_path}samples/',f'epoch_{e}',f'sample_{sample_idx}')
ignore_tensor = f'image_tensor{j}'
images = []
for samplename in os.listdir(sample_path):
if samplename == ignore_tensor:
continue
img = Image.open(os.path.join(sample_path, samplename))
img = transform(img)
images.append(img)
# split them into batches for GPU memory
generated = torch.stack(images).to(device)
generated_batches = torch.split(generated, dataloader.batch_size)
nr_generated_batches = len(generated_batches)
nr_real_batches = len(dataloader)
# Init FID, IS and KID scores
fid = FrechetInceptionDistance(normalize = False).to(device)
iscore = InceptionScore(normalize=False).to(device)
kid = KernelInceptionDistance(normalize=False, subset_size=32).to(device)
# Update scores for the full testing dataset w.r.t. the sampled batches
for idx,(data, _) in enumerate(dataloader):
data = data.to(device)
fid.update(data, real=True)
kid.update(data, real=True)
if idx < nr_generated_batches:
gen = generated_batches[idx].to(device)
fid.update(gen, real=False)
kid.update(gen, real=False)
iscore.update(gen)
# If there are sampled images left, add them too
for idx in range(nr_real_batches, nr_generated_batches):
gen = generated_batches[idx].to(device)
fid.update(gen, real=False)
kid.update(gen, real=False)
iscore.update(gen)
# compute total FID, IS and KID
fid_score = fid.compute()
i_score = iscore.compute()
kid_score = kid.compute()
# store results in txt file
with open(str(eval_path), 'a') as txt:
result = f'FID_epoch_{e}_sample_{sample_idx}:'
txt.write(result + str(fid_score.item()) + '\n')
result = f'KID_epoch_{e}_sample_{sample_idx}:'
txt.write(result + str(kid_score) + '\n')
result = f'IS_epoch_{e}_sample_{sample_idx}:'
txt.write(result + str(i_score) + '\n')
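A hedged example of how this evaluator is typically invoked (in practice the pipeline calls it via `python main.py evaluate ...`); `framework` and `test_dataloader` stand for an initialized DDPM model and the test split, and the checkpoint name is a placeholder:

``` python
cdm_evaluator(model=framework,
              device='cuda',
              dataloader=test_dataloader,
              checkpoint='model_epoch_99.pth',
              experiment_path='/work/lect0100/main_experiment/main_test1/',
              sample_idx=0)
```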
import os
import torch
from torchvision import transforms
import re
def cdm_sampler(model, checkpoint, experiment_path, device, intermediate=False, batch_size=15,sample_all=False,n_times=1):
'''
Samples a tensor of 'batch_size' images from a trained diffusion model with 'checkpoint'. The generated
images are stored in the directory 'experiment_path/samples/epoch_{e}/sample_{j}', where e is the epoch
of the model we are sampling from and j is an index separating the images of each call of the
sampling function for the given model.
model: Diffusion model
checkpoint: Name of the saved .pth file containing the trained weights and biases
experiment_path: Path to the experiment directory where the samples will be saved under the directory 'samples'
batch_size: The number of images to sample per draw
intermediate: Bool value. If False, the sampling function draws a batch of images; otherwise it samples a single
image but stores all intermediate noised latents along the reverse chain
sample_all: If True, samples a batch of images for the given model at every stored checkpoint at once
n_times: Integer denoting how many times a batch of 'batch_size' images is drawn. E.g., to draw 10k images,
the GPU can draw batches of 512 images 20 times.
'''
# If we want to sample from every checkpoint of the current model, recursively call this function for all checkpoints
if sample_all:
f = f'{experiment_path}trained_cdm/'
checkpoint_list = [checkpoint_i for checkpoint_i in os.listdir(f) if checkpoint_i.endswith(".pth")]
for checkpoint_i in checkpoint_list:
cdm_sampler(model, checkpoint_i, experiment_path, device, intermediate=intermediate, batch_size=batch_size, sample_all=False, n_times=n_times)
return 0
# load model
try:
checkpoint_path = f'{experiment_path}trained_cdm/{checkpoint}'
checkpoint = torch.load(checkpoint_path)
# load weights and biases of the U-Net
net_state_dict = checkpoint['model']
model.net.load_state_dict(net_state_dict)
model = model.to(device)
except Exception as e:
print("Error loading checkpoint. Exception:", e)
# create samples directory for the complete experiment (if first time sampling images)
output_dir = f'{experiment_path}samples/'
#output_dir = os.path.join(experiment_path,'/samples/')
os.makedirs(output_dir, exist_ok=True)
# create sample directory for the current version of the trained model
model_name = os.path.basename(checkpoint_path)
epoch = re.findall(r'\d+', model_name)
if epoch:
e = int(epoch[0])
else:
# no epoch number found in the checkpoint filename, default to epoch 0
e = 0
model_dir = os.path.join(output_dir,f'epoch_{e}')
os.makedirs(model_dir, exist_ok=True)
# create the sample directory for this sampling run for the current version of the model
sample_dir_list = [d for d in os.listdir(model_dir) if os.path.isdir(os.path.join(model_dir, d))]
indx_list = [int(d.split('_')[1]) for d in sample_dir_list if d.startswith('sample_')]
j = max(indx_list, default=-1) + 1
sample_dir = os.path.join(model_dir, f'sample_{j}')
os.makedirs(sample_dir, exist_ok=True)
# transform
back2pil = transforms.Compose([transforms.Normalize(mean=(-1,-1,-1),std=(2,2,2)),transforms.ToPILImage()])
# sample images and transform them to PIL
run_indx = 0
for k in range(n_times):
# sample random classes
y = torch.randint(0, 3, (batch_size,)).to(device)
# generate images
generated = model.sample(y=y, batch_size=y.size(0))
# save images
for i in range(generated.size(0)):
image = back2pil(generated[i])
image_path = os.path.join(sample_dir, f'sample_{j}_{run_indx}_{y[i].item()}.png')
try:
image.save(image_path)
except Exception as e:
print("Error saving image. Exception:", e)
run_indx += 1
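Since `cdm_sampler` draws random class labels, sampling a single fixed class is easiest by calling the framework's `sample` method directly; a sketch assuming the same `model.sample(y=..., batch_size=...)` interface used above:

``` python
import torch

# sample 8 images of class 'dog' (index 1 in ConditionalDataset.class_to_idx)
y = torch.full((8,), 1, dtype=torch.long, device=device)
generated = model.sample(y=y, batch_size=y.size(0))  # tensor of shape (8, C, H, W)
```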
%% Cell type:code id: tags:
``` python
from trainer.train import *
from dataloader.load import *
from models.Framework import *
from models.all_unets import *
import torch
from torch import nn
```
%% Cell type:markdown id: tags:
# Prepare experiment
1. Choose Hyperparameter Settings
2. Run the notebook on a local machine to generate the experiment folder with the JSON files containing the settings
3. scp experiment folder to the HPC
4. Run the pipeline by adding the following to the batch file:
- Train Model: &emsp;&emsp;&emsp;&emsp;&emsp; `python main.py train "<absolute path of experiment folder in hpc>"`
- Sample Images: &emsp;&emsp;&emsp; `python main.py sample "<absolute path of experiment folder in hpc>"`
- Evaluate Model: &emsp;&emsp;&emsp; `python main.py evaluate "<absolute path of experiment folder in hpc>"`
%% Cell type:code id: tags:
``` python
import torch
####
# Settings
####
# Dataset path
datapath = "/work/lect0100/lhq_256"
# Experiment setup
run_name = 'main_test1' # WANDB and experiment folder Name!
checkpoint = None #'model_epoch_8.pth' # Name of checkpoint pth file or None
experiment_path = "/work/lect0100/main_experiment/" + run_name +'/'
# Path to save generated experiment folder on local machine
local_path ="experiments/" + run_name + '/settings'
# Diffusion Model Settings
diffusion_steps = 1000
image_size = 128
channels = 3
# Training
batchsize = 32
epochs = 100
store_iter = 10
eval_iter = 500
learning_rate = 0.0001
optimizername = "torch.optim.AdamW"
optimizer_params = None
verbose = False
# checkpoint = None #(If no checkpoint training, ie. random weights)
# Sampling
sample_size = 20
intermediate = False # True if you want to sample one image and all its intermediate latents
sample_all=False
# Evaluating
...
###
# Advanced Settings Dictionaries
###
meta_setting = dict(modelname = "UNet_Res",
dataset = "UnconditionalDataset",
framework = "DDPM",
trainloop_function = "ddpm_trainer",
sampling_function = 'ddpm_sampler',
evaluation_function = 'ddpm_evaluator',
batchsize = batchsize
)
dataset_setting = dict(fpath = datapath,
img_size = image_size,
frac =0.8,
skip_first_n = 0,
ext = ".png",
transform=True
)
model_setting = dict( n_channels=64,
fctr = [1,2,4,4,8],
time_dim=256,
attention = True,
)
"""
outdated
model_setting = dict( channels_in=channels,
channels_out =channels ,
activation='relu', # activation function. Options: {'relu', 'leakyrelu', 'selu', 'gelu', 'silu'/'swish'}
weight_init='he', # weight initialization. Options: {'he', 'torch'}
projection_features=64, # number of image features after first convolution layer
time_dim=batchsize, # don't change!!!
time_channels=diffusion_steps, # number of time channels #TODO same as diffusion steps?
num_stages=4, # number of stages in contracting/expansive path
stage_list=None, # specify number of features produced by stages
num_blocks=1, # number of ConvResBlock in each contracting/expansive path
num_groupnorm_groups=32, # number of groups used in Group Normalization inside a ConvResBlock
dropout=0.1, # drop-out to be applied inside a ConvResBlock
attention_list=None, # specify MHA pattern across stages
num_attention_heads=1,
)
"""
framework_setting = dict(
diffusion_steps = diffusion_steps, # don't change!!
out_shape = (channels,image_size,image_size), # don't change!!
noise_schedule = 'linear',
beta_1 = 1e-4,
beta_T = 0.02,
alpha_bar_lower_bound = 0.9,
var_schedule = 'same',
kl_loss = 'simplified',
recon_loss = 'none',
)
training_setting = dict(
epochs = epochs,
store_iter = store_iter,
eval_iter = eval_iter,
optimizer_class=optimizername,
optimizer_params = optimizer_params,
#optimizer_params=dict(lr=learning_rate), # don't change!
learning_rate = learning_rate,
run_name=run_name,
checkpoint= checkpoint,
experiment_path = experiment_path,
verbose = verbose,
T_max = 0.8*90000/32*100, # cosine lr param: (len(train_ds)/batchsize) * total epochs until the lr decays to eta_min
eta_min= 1e-10, # cosine lr param
)
sampling_setting = dict(
checkpoint = checkpoint,
experiment_path = experiment_path,
batch_size = sample_size,
intermediate = intermediate,
sample_all = sample_all
)
# TODO
evaluation_setting = dict(
checkpoint = checkpoint,
experiment_path = experiment_path,
)
```
%% Cell type:code id: tags:
``` python
import os
import json
f = local_path
if os.path.exists(f):
print("path already exists, pick a new name!")
print("break")
else:
print("create folder")
#os.mkdir(f)
os.makedirs(f, exist_ok=True)
print("folder created ")
with open(f+"/meta_setting.json","w+") as fp:
json.dump(meta_setting,fp)
with open(f+"/dataset_setting.json","w+") as fp:
json.dump(dataset_setting,fp)
with open(f+"/model_setting.json","w+") as fp:
json.dump(model_setting,fp)
with open(f+"/framework_setting.json","w+") as fp:
json.dump(framework_setting,fp)
with open(f+"/training_setting.json","w+") as fp:
json.dump(training_setting,fp)
with open(f+"/sampling_setting.json","w+") as fp:
json.dump(sampling_setting,fp)
with open(f+"/evaluation_setting.json","w+") as fp:
json.dump(evaluation_setting,fp)
print("stored json files in folder")
print(meta_setting)
print(dataset_setting)
print(model_setting)
print(framework_setting)
print(training_setting)
print(sampling_setting)
print(evaluation_setting)
```
%% Output
create folder
folder created
stored json files in folder
{'modelname': 'UNet_Res', 'dataset': 'UnconditionalDataset', 'framework': 'DDPM', 'trainloop_function': 'ddpm_trainer', 'sampling_function': 'ddpm_sampler', 'evaluation_function': 'ddpm_evaluator', 'batchsize': 32}
{'fpath': '/work/lect0100/lhq_256', 'img_size': 128, 'frac': 0.8, 'skip_first_n': 0, 'ext': '.png', 'transform': True}
{'n_channels': 64, 'fctr': [1, 2, 4, 4, 8], 'time_dim': 256, 'attention': True}
{'diffusion_steps': 1000, 'out_shape': (3, 128, 128), 'noise_schedule': 'linear', 'beta_1': 0.0001, 'beta_T': 0.02, 'alpha_bar_lower_bound': 0.9, 'var_schedule': 'same', 'kl_loss': 'simplified', 'recon_loss': 'none'}
{'epochs': 100, 'store_iter': 10, 'eval_iter': 500, 'optimizer_class': 'torch.optim.AdamW', 'optimizer_params': None, 'learning_rate': 0.0001, 'run_name': 'main_test1', 'checkpoint': None, 'experiment_path': '/work/lect0100/main_experiment/main_test1/', 'verbose': False, 'T_max': 225000.0, 'eta_min': 1e-10}
{'checkpoint': None, 'experiment_path': '/work/lect0100/main_experiment/main_test1/', 'batch_size': 20, 'intermediate': False}
{'checkpoint': None, 'experiment_path': '/work/lect0100/main_experiment/main_test1/'}
%% Cell type:code id: tags:
``` python
```
main.py 0 → 100644
import json
import sys
from dataloader.load import *
from models.ConditionalDiffusionModel import *
from trainer.train import cdm_trainer
from evaluation.sample import cdm_sampler
from evaluation.evaluate import cdm_evaluator
from models.conditional_unet import *
import torch
def train_func(f):
#load all settings
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"device: {device}\n\n")
print(f"folderpath: {f}\n\n")
with open(f+"/meta_setting.json","r") as fp:
meta_setting = json.load(fp)
with open(f+"/dataset_setting.json","r") as fp:
dataset_setting = json.load(fp)
with open(f+"/model_setting.json","r") as fp:
model_setting = json.load(fp)
with open(f+"/framework_setting.json","r") as fp:
framework_setting = json.load(fp)
with open(f+"/training_setting.json","r") as fp:
training_setting = json.load(fp)
training_setting["optimizer_class"] = eval(training_setting["optimizer_class"])
batchsize = meta_setting["batchsize"]
training_dataset = globals()[meta_setting["dataset"]](train = True,**dataset_setting)
test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
training_dataloader = torch.utils.data.DataLoader(training_dataset,batch_size=batchsize,shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize,shuffle=True)
net = globals()[meta_setting["modelname"]](**model_setting).to(device)
#net = torch.compile(net)
net = net.to(device)
framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)
print(f"META SETTINGS:\n\n {meta_setting}\n\n")
print(f"DATASET SETTINGS:\n\n {dataset_setting}\n\n")
print(f"MODEL SETTINGS:\n\n {model_setting}\n\n")
print(f"FRAMEWORK SETTINGS:\n\n {framework_setting}\n\n")
print(f"TRAINING SETTINGS:\n\n {training_setting}\n\n")
print("\n\nSTART TRAINING\n\n")
globals()[meta_setting["trainloop_function"]](model=framework,device=device, trainloader = training_dataloader, testloader = test_dataloader,safepath = f,**training_setting,)
print("\n\nFINISHED TRAINING\n\n")
def sample_func(f):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"device: {device}\n\n")
print(f"folderpath: {f}\n\n")
with open(f+"/meta_setting.json","r") as fp:
meta_setting = json.load(fp)
with open(f+"/model_setting.json","r") as fp:
model_setting = json.load(fp)
with open(f+"/framework_setting.json","r") as fp:
framework_setting = json.load(fp)
with open(f+"/sampling_setting.json","r") as fp:
sampling_setting = json.load(fp)
# init Unet
net = globals()[meta_setting["modelname"]](**model_setting).to(device)
#net = torch.compile(net)
net = net.to(device)
# init conditional diffusion model
framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)
print(f"META SETTINGS:\n\n {meta_setting}\n\n")
print(f"MODEL SETTINGS:\n\n {model_setting}\n\n")
print(f"FRAMEWORK SETTINGS:\n\n {framework_setting}\n\n")
print(f"SAMPLING SETTINGS:\n\n {sampling_setting}\n\n")
print("\n\nSTART SAMPLING\n\n")
globals()[meta_setting["sampling_function"]](model=framework,device=device ,**sampling_setting,)
print("\n\nFINISHED SAMPLING\n\n")
def evaluate_func(f):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"device: {device}\n\n")
print(f"folderpath: {f}\n\n")
with open(f+"/meta_setting.json","r") as fp:
meta_setting = json.load(fp)
with open(f+"/model_setting.json","r") as fp:
model_setting = json.load(fp)
with open(f+"/framework_setting.json","r") as fp:
framework_setting = json.load(fp)
with open(f+"/evaluation_setting.json","r") as fp:
evaluation_setting = json.load(fp)
with open(f+"/dataset_setting.json","r") as fp:
dataset_setting = json.load(fp)
# load dataset
batchsize = meta_setting["batchsize"]
test_dataset = globals()[meta_setting["dataset"]](train = False,**dataset_setting)
#test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=len(test_dataset), shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize, shuffle=False)
# init Unet
net = globals()[meta_setting["modelname"]](**model_setting).to(device)
#net = torch.compile(net)
net = net.to(device)
# init conditional diffusion model
framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting)
print(f"META SETTINGS:\n\n {meta_setting}\n\n")
print(f"DATASET SETTINGS:\n\n {dataset_setting}\n\n")
print(f"MODEL SETTINGS:\n\n {model_setting}\n\n")
print(f"FRAMEWORK SETTINGS:\n\n {framework_setting}\n\n")
print(f"EVALUATION SETTINGS:\n\n {evaluation_setting}\n\n")
print("\n\nSTART EVALUATION\n\n")
globals()[meta_setting["evaluation_function"]](model=framework, device=device, dataloader = test_dataloader,safepath = f,**evaluation_setting,)
print("\n\nFINISHED EVALUATION\n\n")
if __name__ == '__main__':
print(sys.argv)
functions = {'train': train_func,'sample': sample_func,'evaluate': evaluate_func}
functions[sys.argv[1]](sys.argv[2])
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import einops
import numpy as np
# U-Net model
class Conditional_UNet_Res(nn.Module):
def __init__(self, attention, channels_in=3, nr_class=3, n_channels=64, fctr=[1,2,4,4,8], time_dim=256, **args):
"""
attention : (Bool) whether to use attention layers or not
channels_in : (Int) number of input image channels
nr_class : (Int) number of classes used for conditioning
n_channels : (Int) channel size after the first convolution
fctr : (list) factors for the subsequent channel sizes w.r.t. n_channels
time_dim : (Int) dimension of the time and class embedding vectors
"""
super().__init__()
channels_out = channels_in
fctr = np.asarray(fctr)*n_channels
# learned time embedding
self.time_embedder = TimeEmbedding(time_dim = time_dim)
# learned class embedding
self.class_embedder = nn.Embedding(nr_class, time_dim) #akin to the OpenAI GLIDE Diffusion Model
# learnable embedding layers
self.tc_embedder0 = torch.nn.Sequential(nn.Linear(time_dim,fctr[0]),nn.SELU(),nn.Linear(fctr[0],fctr[0]))
self.tc_embedder1 = torch.nn.Sequential(nn.Linear(time_dim,fctr[1]),nn.SELU(),nn.Linear(fctr[1],fctr[1]))
self.tc_embedder2 = torch.nn.Sequential(nn.Linear(time_dim,fctr[2]),nn.SELU(),nn.Linear(fctr[2],fctr[2]))
self.tc_embedder3 = torch.nn.Sequential(nn.Linear(time_dim,fctr[3]),nn.SELU(),nn.Linear(fctr[3],fctr[3]))
self.tc_embedder4 = torch.nn.Sequential(nn.Linear(time_dim,fctr[4]),nn.SELU(),nn.Linear(fctr[4],fctr[4]))
# first conv block
self.first_conv = nn.Conv2d(channels_in,fctr[0],kernel_size=3, padding='same', bias=True)
#down blocks
self.down1 = DownsampleBlock_Res(fctr[0],fctr[1],time_dim)
self.down2 = DownsampleBlock_Res(fctr[1],fctr[2],time_dim)
self.down3 = DownsampleBlock_Res(fctr[2],fctr[3],time_dim,attention=attention)
self.down4 = DownsampleBlock_Res(fctr[3],fctr[4],time_dim,attention=attention)
#middle layer
self.mid1 = MidBlock_Res(fctr[4],time_dim,attention=attention)
#up blocks
self.up1 = UpsampleBlock_Res(fctr[1],fctr[0],time_dim)
self.up2 = UpsampleBlock_Res(fctr[2],fctr[1],time_dim)
self.up3 = UpsampleBlock_Res(fctr[3],fctr[2],time_dim,attention=attention)
self.up4 = UpsampleBlock_Res(fctr[4],fctr[3],time_dim,attention=attention)
# final 1x1 conv
self.end_conv = nn.Conv2d(fctr[0], channels_out, kernel_size=1,bias=True)
# Attention Layers
self.mha21 = MHABlock(fctr[2])
self.mha22 = MHABlock(fctr[2])
self.mha31 = MHABlock(fctr[3])
self.mha32 = MHABlock(fctr[3])
self.mha41 = MHABlock(fctr[4])
self.mha42 = MHABlock(fctr[4])
def forward(self, input, t, y):
# compute time embedding
t_emb = self.time_embedder(t).to(input.device)
# compute class embedding if present
if y is not None:
c_emb = self.class_embedder(y).to(input.device)
else:
c_emb = torch.zeros_like(t_emb).to(input.device)
# combine both embeddings
tc_emb = t_emb + c_emb
# learnable layers
tc_emb0 = self.tc_embedder0(tc_emb)
tc_emb1 = self.tc_embedder1(tc_emb)
tc_emb2 = self.tc_embedder2(tc_emb)
tc_emb3 = self.tc_embedder3(tc_emb)
tc_emb4 = self.tc_embedder4(tc_emb)
# first two conv layers
x = self.first_conv(input) + tc_emb0[:,:,None,None]
# contracting path, conditioned on the combined time and class embedding
skip1,x = self.down1(x,tc_emb1)
skip2,x = self.down2(x,tc_emb2)
skip3,x = self.down3(x,tc_emb3)
skip4,x = self.down4(x,tc_emb4)
x = self.mid1(x,tc_emb4)
x = self.up4(x,skip4,tc_emb3)
x = self.up3(x,skip3,tc_emb2)
x = self.up2(x,skip2,tc_emb1)
x = self.up1(x,skip1,tc_emb0)
x = self.end_conv(x)
return x
#TimeEmbedding
class TimeEmbedding(nn.Module):
def __init__(self, time_dim=64):
super().__init__()
self.time_dim = time_dim
n = 10000
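# sinusoidal frequencies n^(-2i/time_dim) for i = 0..time_dim/2-1, as in the Transformer positional encoding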
self.factor = torch.pow(n*torch.ones(size=(time_dim//2,)),(-2/time_dim*torch.arange(time_dim//2)))
def forward(self, t):
"""
input is t (B,)
factor dim (time_dim,)
output is (B,time_dim)
"""
self.factor = self.factor.to(t.device)
theta = torch.outer(t,self.factor)
# shape of embedding: (B, time_dim)
emb = torch.zeros(t.size(0), self.time_dim,device=t.device)
emb[:, 0::2] = torch.sin(theta)
emb[:, 1::2] = torch.cos(theta)
return emb
# Self Attention
class MHABlock(nn.Module):
def __init__(self,
channels_in,
num_attention_heads=1 # number of attention heads in MHA
):
super().__init__()
self.channels_in = channels_in
self.num_attention_heads = num_attention_heads
self.self_attention = nn.MultiheadAttention(channels_in, num_heads=self.num_attention_heads)
def forward(self, x):
skip = x
batch_size,_,height,width = x.size()
x = x.permute(2, 3, 0, 1).reshape(height * width, batch_size, -1)
attn_output, _ = self.self_attention(x, x, x)
attn_output = attn_output.reshape(batch_size, -1, height, width)
return attn_output+skip
# Residual Convolution Block
class ConvBlock_Res(nn.Module):
def __init__(self,
channels_in, # number of input channels fed into the block
channels_out, # number of output channels produced by the block
time_dim,
attention,
num_groups=32, # number of groups used in Group Normalization; channels_in must be divisible by num_groups
):
super().__init__()
self.attention = attention
if self.attention:
self.attlayer = MHABlock(channels_in=channels_out)
# Convolution layer 1
self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding='same', bias=False)
self.gn1 = nn.GroupNorm(num_groups, channels_out)
self.act1 = nn.SiLU()
# Convolution layer 2
self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding='same', bias=False)
self.gn2 = nn.GroupNorm(num_groups, channels_out)
self.act2 = nn.SiLU()
# Convolution layer 3
self.conv3 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding='same', bias=False)
self.gn3 = nn.GroupNorm(num_groups, channels_out)
self.act3 = nn.SiLU()
#Convolution skip
if channels_in!=channels_out:
self.res_skip = nn.Conv2d(channels_in,channels_out,kernel_size=1)
else:
self.res_skip = nn.Identity()
nn.init.xavier_normal_(self.conv1.weight)
nn.init.xavier_normal_(self.conv2.weight)
nn.init.xavier_normal_(self.conv3.weight)
def forward(self, x, t):
res = self.res_skip(x)
# first convolution layer
x = self.act1(self.gn1(self.conv1(x)))
# inject the reshaped time/class embedding channel-wise
h = x + t[:,:,None,None]
# second and third convolution layers
h = self.act2(self.gn2(self.conv2(h)))
h = self.act3(self.gn3(self.conv3(h)))
if self.attention:
h = self.attlayer(h)
return h + res
# Down Sample
class DownsampleBlock_Res(nn.Module):
def __init__(self, channels_in, channels_out,time_dim,attention=False):
super().__init__()
self.pool = nn.MaxPool2d((2,2), stride=2)
self.convblock = ConvBlock_Res(channels_in, channels_out,time_dim,attention=attention)
def forward(self, x, t):
x = self.convblock(x, t)
h = self.pool(x)
return x,h
# Upsample Block
class UpsampleBlock_Res(nn.Module):
def __init__(self, channels_in, channels_out,time_dim,attention=False):
super().__init__()
self.upconv = nn.ConvTranspose2d(channels_in, channels_in, kernel_size=2, stride=2)
self.convblock = ConvBlock_Res(channels_in, channels_out,time_dim,attention=attention)
def forward(self, x, skip_x, t):
x = self.upconv(x)
# skip-connection - merge features from contracting path to its symmetric counterpart in expansive path
out = x + skip_x
out = self.convblock(out, t)
return out
# Middle Block
class MidBlock_Res(nn.Module):
def __init__(self,channels,time_dim,attention=False):
super().__init__()
self.convblock1 = ConvBlock_Res(channels,channels,time_dim,attention=attention)
self.convblock2 = ConvBlock_Res(channels,channels,time_dim,attention=False)
def forward(self,x,t):
x = self.convblock1(x,t)
return self.convblock2(x,t)
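A quick shape sanity check for the conditional U-Net, assuming the classes above are importable (sizes are illustrative; `attention=False` keeps the check light):

``` python
import torch

net = Conditional_UNet_Res(attention=False, channels_in=3, nr_class=3,
                           n_channels=64, fctr=[1, 2, 4, 4, 8], time_dim=256)
x = torch.randn(2, 3, 128, 128)       # batch of noised images
t = torch.randint(0, 1000, (2,))      # diffusion timesteps
y = torch.tensor([0, 2])              # class labels (cat, wild)
out = net(x, t, y)
assert out.shape == x.shape           # the predicted noise has the input's shape
```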
import numpy as np
import copy
import torch
from torch import nn
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import os
import wandb
from copy import deepcopy
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Simple Training function for the unconditional diffusion model
def simple_trainer(model,device,epochs,trainloader,testloader,bs,lr,T,criterion = nn.MSELoss()):
criterion.to(device)
optimizer = torch.optim.AdamW(model.parameters(),lr=lr,)
for epoch in range(epochs):
model.train()
running_trainloss = []
running_testloss = []
for idx,(x,_) in enumerate(trainloader):
x = x.to(device) # has to go to device
t = torch.randint(low=0,high=T,size=(1,)).item() # doesn't have to go to device
x_t,forward_noise = model.forward_trajectory(x,t)
optimizer.zero_grad()
mean,std,pred_noise = model.forward(x_t,t) # use forward() since the model is an nn.Module
loss = criterion(forward_noise,pred_noise)
loss.backward()
optimizer.step()
trainstep = epoch*bs+idx
running_trainloss.append(loss.cpu().item()) # must be moved to cpu before appending to the list
print(f"Trainloss in epoch {epoch}:{np.mean(running_trainloss)}")
model.eval()
with torch.no_grad():
for idx,(x,_) in enumerate(testloader):
x = x.to(device)
t = torch.randint(low=0,high=T,size=(1,)).item()
x_t,forward_noise = model.forward_trajectory(x,t)
optimizer.zero_grad()
mean,std,pred_noise = model.forward(x_t,t)
loss = criterion(forward_noise,pred_noise)
running_testloss.append(loss.cpu().item())
print(f"Testloss in step {epoch} :{np.mean(running_testloss)}")
# EMA class
# Important! This EMA class is not ours; it was taken from the PyTorch Image Models (timm) library. It performs an exponential moving
# average over the trained weights of a given model's network, which was suggested in the paper "Improved Denoising Diffusion Probabilistic Models"
# by Nichol and Dhariwal to stabilize and improve the training and generalization process.
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/model_ema.py
class ModelEmaV2(nn.Module):
def __init__(self, model, decay=0.9999, device=None):
super(ModelEmaV2, self).__init__()
# make a copy of the model for accumulating moving average of weights
self.module = deepcopy(model)
self.module.eval()
self.decay = decay
self.device = device # perform ema on different device from model if set
if self.device is not None:
self.module.to(device=device)
def _update(self, model, update_fn):
with torch.no_grad():
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
if self.device is not None:
model_v = model_v.to(device=self.device)
ema_v.copy_(update_fn(ema_v, model_v))
def update(self, model):
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
def set(self, model):
self._update(model, update_fn=lambda e, m: m)
# Training function for the conditional diffusion model
def cdm_trainer(model,
device,
trainloader, testloader,
store_iter = 10,
eval_iter = 10,
epochs = 50,
optimizer_class=torch.optim.AdamW,
optimizer_params=None,
learning_rate = 0.001,
verbose = False,
run_name=None,
checkpoint= None,
experiment_path = None,
T_max = 5*10000, # None,
eta_min= 1e-5,
ema_training = True,
decay = 0.9999,
**args
):
'''
model: Properly initialized DDPM model.
store_iter: Stores the trained DDPM every store_iter epochs.
experiment_path: Path to the model's experiment folder, where the trained model will be stored every store_iter epochs
eval_iter: Evaluates the trained DDPM on the test data every eval_iter epochs.
epochs: Number of epochs to train the model for.
optimizer_class: PyTorch optimizer.
optimizer_params: Parameters for the PyTorch optimizer.
learning_rate: For optimizer initialization when training from scratch, i.e. without a checkpoint
verbose: If True, prints the running losses for every epoch.
run_name: Run name for WandB. IF YOU TRAIN FROM A CHECKPOINT, MAKE SURE TO USE THE SAME
'run_name' FOR THE DATA TO BE LOGGED ON THE SAME WANDB RUN!
trainloader: Loads the train dataset
testloader: Loads the test dataset
checkpoint: Name of the saved .pth file containing the trained weights and biases
T_max: CosineAnnealingLR scheduler argument (number of training steps in a full cycle)
eta_min: CosineAnnealingLR scheduler argument (the scheduler oscillates between the highest lr 'learning_rate' and the minimum lr 'eta_min')
decay: EMA decay rate used to weight the EMA model's contribution when computing the weighted average between the trained
and EMA weights for the network's weight update
'''
# set optimizer parameters and learning rate
if optimizer_params is None:
optimizer_params = dict(lr=learning_rate)
optimizer = optimizer_class(model.net.parameters(), **optimizer_params)
# set lr cosine schedule (commonly used in diffusion models)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)
# set ema model for training
if ema_training:
ema = ModelEmaV2(model, decay=decay, device = model.device)
# if checkpoint path is given, load the model from checkpoint
last_epoch = -1
if checkpoint:
try:
checkpoint_path = f'{experiment_path}trained_cdm/{checkpoint}'
# Load the checkpoint
checkpoint = torch.load(checkpoint_path)
# update last_epoch
last_epoch = checkpoint['epoch']
# load weights and biases of the U-Net
model_state_dict = checkpoint['model']
model.net.load_state_dict(model_state_dict)
model = model.to(device)
# load optimizer state
optimizer_state_dict = checkpoint['optimizer']
optimizer.load_state_dict(optimizer_state_dict)
# load learning rate schedule state
scheduler_state_dict = checkpoint['scheduler']
scheduler.load_state_dict(scheduler_state_dict)
scheduler.last_epoch = last_epoch
# load ema model state
if ema_training:
ema.module.load_state_dict(checkpoint['ema'])
except Exception as e:
print("Error loading checkpoint. Exception: ", e)
# pick kl loss function
if model.kl_loss == 'weighted':
loss_func = model.loss_weighted
else:
loss_func = model.loss_simplified
# pick lowest timestep
low = 1
if model.recon_loss == 'nll':
low = 0
# Using W&B
with wandb.init(project='Unconditional Landscapes', name=run_name, entity='deep-lab-', id=run_name, resume=True) as run:
# Log some info
run.config.learning_rate = learning_rate
#run.config.update({"learning_rate": learning_rate}, allow_val_change=True)
run.config.optimizer = optimizer.__class__.__name__
#run.watch(model.net)
# training loop
# last model was stored at epoch last_epoch, we continue training from there, i.e. last_epoch+1 (else we start at epoch 0)
for epoch in range(last_epoch+1, (last_epoch+1)+epochs):
running_trainloss = 0
nr_train_batches = 0
# train
model.net.train()
for idx,(x_0, y) in enumerate(trainloader):
x_0 = x_0.to(device)
y = y.to(device)
t = torch.randint(low=low, high=model.diffusion_steps, size=(x_0.shape[0],), device = device)
optimizer.zero_grad()
# define masks for zero and non-zero elements of t
mask_zero_t = (t == 0)
mask_non_zero_t = (t != 0)
t[mask_zero_t] = 1
# apply noise
x_t, forward_noise = model.forward_trajectory(x_0,t)
# compute denoising step at time t under CFGD
rand_prob = torch.rand(x_0.shape[0]).to(device)
mask_condition = (rand_prob <= 0.9)
pred_noise = torch.zeros_like(x_t).to(device)
# with probability 0.9 the forward pass is conditioned on the class; with probability 0.1 the label is dropped (classifier-free guidance training)
if torch.any(mask_condition):
pred_noise[mask_condition] = model.forward(x_t[mask_condition], t[mask_condition], y = y[mask_condition])
if torch.any(~mask_condition):
pred_noise[~mask_condition] = model.forward(x_t[~mask_condition], t[~mask_condition], y=None)
mean, std = model.reverse_dist_param(x_t, pred_noise, t)
loss = 0
# compute kl loss
if torch.any(mask_non_zero_t):
loss = loss_func(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t], t[mask_non_zero_t])
running_trainloss += loss.item()
nr_train_batches += 1
run.log({'loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})
# if a reconstruction loss was drawn (t == 0)
if torch.any(mask_zero_t):
recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t])
loss += recon_loss
run.log({'recon_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx})
loss.backward()
optimizer.step()
if ema_training:
ema.update(model)
scheduler.step()
if verbose:
print(f"Loss in epoch {epoch}:{running_trainloss/nr_train_batches}")
run.log({'running_loss': running_trainloss/nr_train_batches})
# evaluation
if ((epoch+1) % eval_iter == 0) or ((epoch+1) % store_iter == 0):
running_testloss = 0
nr_test_batches = 0
model.net.eval()
with torch.no_grad():
for idx,(x_0,y) in enumerate(testloader):
x_0 = x_0.to(device)
y=y.to(device)
t = torch.randint(low=low, high=model.diffusion_steps, size=(x_0.shape[0],), device = device)
# Define masks for zero and non-zero elements of t
mask_zero_t = (t == 0)
mask_non_zero_t = (t != 0)
t[mask_zero_t] = 1
# apply noise
x_t, forward_noise = model.forward_trajectory(x_0,t)
# compute denoising step at time t under CFGD
rand_prob = torch.rand(x_0.shape[0]).to(device)
mask_condition = (rand_prob <= 0.9)
pred_noise = torch.zeros_like(x_t)
# with probability 0.9 the forward pass is conditioned on the class; with probability 0.1 the label is dropped (classifier-free guidance)
if torch.any(mask_condition):
pred_noise[mask_condition] = model.forward(x_t[mask_condition], t[mask_condition], y = y[mask_condition])
if torch.any(~mask_condition):
pred_noise[~mask_condition] = model.forward(x_t[~mask_condition], t[~mask_condition], y=None)
mean, std = model.reverse_dist_param(x_t, pred_noise, t)
loss = 0
# Compute kl loss
if torch.any(mask_non_zero_t):
loss = loss_func(forward_noise[mask_non_zero_t], pred_noise[mask_non_zero_t], t[mask_non_zero_t])
running_testloss += loss.item()
nr_test_batches += 1
run.log({'test_loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})
# if a reconstruction loss was drawn (t == 0)
if torch.any(mask_zero_t):
recon_loss = model.loss_recon(x_0[mask_zero_t], mean[mask_zero_t], std[mask_zero_t])
loss += recon_loss
run.log({'recon_test_loss': recon_loss.item(), 'epoch': epoch, 'batch': idx})
if verbose:
print(f"Test loss in epoch {epoch}:{running_testloss/nr_test_batches}")
run.log({'running_test_loss': running_testloss/nr_test_batches})
# store model
if ((epoch+1) % store_iter == 0):
save_dir = os.path.join(experiment_path, 'trained_cdm/')
os.makedirs(save_dir, exist_ok=True)
torch.save({
'epoch': epoch,
'model': model.net.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'ema' : ema.module.state_dict(),
'running_loss': running_trainloss/nr_train_batches,
'running_test_loss': running_testloss/nr_test_batches,
}, os.path.join(save_dir, f"model_epoch_{epoch}.pth"))
# always store the last version of the model if we trained through all epochs
final = (last_epoch+1)+epochs
save_dir = os.path.join(experiment_path, 'trained_cdm/')
os.makedirs(save_dir, exist_ok=True)
torch.save({
'epoch': final,
'model': model.net.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'ema' : ema.module.state_dict(),
'running_loss': running_trainloss/nr_train_batches,
'running_test_loss': running_testloss/nr_test_batches,
}, os.path.join(save_dir, f"model_epoch_{final}.pth"))
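The checkpoints written above contain both the raw and the EMA weights. A hedged sketch of restoring the EMA weights for sampling (the key names follow the `torch.save` calls above; `framework` stands for an initialized DDPM instance, and the path is a placeholder):

``` python
import torch

checkpoint = torch.load('trained_cdm/model_epoch_99.pth', map_location='cpu')
framework.load_state_dict(checkpoint['ema'])  # EMA copy of the full framework module
framework.net.eval()
```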