diff --git a/experiment_creator.ipynb b/experiment_creator.ipynb index 202045fdad5b4e1e674b59c27fd12c4383265ea6..17963ce4ccd8bc3c2e41dc95cea6aca9c9c783ea 100644 --- a/experiment_creator.ipynb +++ b/experiment_creator.ipynb @@ -11,12 +11,13 @@ "from trainer.train import *\n", "from dataloader.load import *\n", "from models.Framework import *\n", - "from models.unet_unconditional_diffusion import *\n", + "from models.all_unets import *\n", "import torch \n", "from torch import nn " ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -83,7 +84,7 @@ "# Advanced Settings Dictionaries\n", "###\n", "\n", - "meta_setting = dict(modelname = \"UNet_Unconditional_Diffusion_Bottleneck_Variant\",\n", + "meta_setting = dict(modelname = \"UNet_Res\",\n", " dataset = \"UnconditionalDataset\",\n", " framework = \"DDPM\",\n", " trainloop_function = \"ddpm_trainer\",\n", @@ -98,6 +99,13 @@ " ext = \".png\",\n", " transform=True\n", " )\n", + "\n", + "model_setting = dict( n_channels=64,\n", + " fctr = [1,2,4,4,8],\n", + " time_dim=256,\n", + " )\n", + "\"\"\"\n", + "outdated\n", "model_setting = dict( channels_in=channels, \n", " channels_out =channels , \n", " activation='relu', # activation function. Options: {'relu', 'leakyrelu', 'selu', 'gelu', 'silu'/'swish'}\n", @@ -113,6 +121,7 @@ " attention_list=None, # specify MHA pattern across stages\n", " num_attention_heads=1,\n", " )\n", + "\"\"\"\n", "framework_setting = dict(\n", " diffusion_steps = diffusion_steps, # dont change!!\n", " out_shape = (channels,image_size,image_size), # dont change!!\n", @@ -235,9 +244,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.9 (pytorch)", + "display_name": "env", "language": "python", - "name": "pytorch" + "name": "env" }, "language_info": { "codemirror_mode": { @@ -249,7 +258,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/main.py b/main.py index c3e5c9d8aa399ccfc6abcd3a51e3d754081ce990..dedee957f94dd6cd4a20341633a8e30ee494726c 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,14 @@ -from diffusers import UNet2DModel + +import json +import sys +from dataloader.load import * +from models.Framework import * from trainer.train import ddpm_trainer from evaluation.sample import ddpm_sampler from evaluation.evaluate import ddpm_evaluator +from models.all_unets import * +import torch + def train_func(f): #load all settings @@ -33,32 +40,10 @@ def train_func(f): test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize) - #model = globals()[meta_setting["modelname"]](**model_setting).to(device) - #net = torch.compile(model) - net = UNet2DModel( - sample_size=64, - in_channels=3, - out_channels=3, - layers_per_block=2, - block_out_channels=(128, 128, 256, 256, 512, 512), - down_block_types=( - "DownBlock2D", - "DownBlock2D", - "DownBlock2D", - "DownBlock2D", - "AttnDownBlock2D", - "DownBlock2D", - ), - up_block_types=( - "UpBlock2D", - "AttnUpBlock2D", - "UpBlock2D", - "UpBlock2D", - "UpBlock2D", - "UpBlock2D", - ), - ) + net = globals()[meta_setting["modelname"]](**model_setting).to(device) + #net = torch.compile(net) net = net.to(device) + framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting) print(f"META SETTINGS:\n\n {meta_setting}\n\n") @@ -92,31 +77,8 @@ def sample_func(f): # init Unet batchsize = meta_setting["batchsize"] - #model = globals()[meta_setting["modelname"]](**model_setting).to(device) - 
#net = torch.compile(model) - net = UNet2DModel( - sample_size=64, - in_channels=3, - out_channels=3, - layers_per_block=2, - block_out_channels=(128, 128, 256, 256, 512, 512), - down_block_types=( - "DownBlock2D", - "DownBlock2D", - "DownBlock2D", - "DownBlock2D", - "AttnDownBlock2D", - "DownBlock2D", - ), - up_block_types=( - "UpBlock2D", - "AttnUpBlock2D", - "UpBlock2D", - "UpBlock2D", - "UpBlock2D", - "UpBlock2D", - ), - ) + net = globals()[meta_setting["modelname"]](**model_setting).to(device) + #net = torch.compile(net) net = net.to(device) # init unconditional diffusion model framework = globals()[meta_setting["framework"]](net = net,device=device, **framework_setting) @@ -158,31 +120,8 @@ def evaluate_func(f): test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=batchsize) # init Unet - #model = globals()[meta_setting["modelname"]](**model_setting).to(device) - #net = torch.compile(model) - net = UNet2DModel( - sample_size=64, - in_channels=3, - out_channels=3, - layers_per_block=2, - block_out_channels=(128, 128, 256, 256, 512, 512), - down_block_types=( - "DownBlock2D", - "DownBlock2D", - "DownBlock2D", - "DownBlock2D", - "AttnDownBlock2D", - "DownBlock2D", - ), - up_block_types=( - "UpBlock2D", - "AttnUpBlock2D", - "UpBlock2D", - "UpBlock2D", - "UpBlock2D", - "UpBlock2D", - ), - ) + net = globals()[meta_setting["modelname"]](**model_setting).to(device) + #net = torch.compile(net) net = net.to(device) # init unconditional diffusion model @@ -201,30 +140,12 @@ def evaluate_func(f): - -def pipeline_func(f): - # TODO - #train_func(f) - generate_func(f) - #evaluate_func(f) - -def hello(name): - print(f'Hello {name}!') if __name__ == '__main__': - import json - import sys - from trainer.train import * - from dataloader.load import * - from models.Framework import * - from models.unet_unconditional_diffusion import * - from models.unet import UNet - import torch - from torch import nn print(sys.argv) - functions = {'train': train_func,'sample': sample_func,'evaluate': evaluate_func,"hello":hello} + functions = {'train': train_func,'sample': sample_func,'evaluate': evaluate_func} functions[sys.argv[1]](sys.argv[2]) diff --git a/models/Framework.py b/models/Framework.py index 2b136d01484ebf1a5ce67ab340b80867f6a406e4..84783d3cf4d12350f2d1de54947dcd7ffb49c766 100644 --- a/models/Framework.py +++ b/models/Framework.py @@ -287,7 +287,7 @@ class DDPM(nn.Module): std (tensor): Batch of std scalars for the complete noise dist. 
                                        for each image in the batch x_t
             pred_noise (tensor): Predicted noise for each image in the batch x_t
         '''
-        pred_noise = self.net(x_t,t,return_dict=False)[0]
+        pred_noise = self.net(x_t,t)
         mean = self.mean_scaler[t-1][:,None,None,None]*(x_t - self.noise_scaler[t-1][:,None,None,None]*pred_noise)
         std = self.std[t-1][:,None,None,None]
         return mean, std, pred_noise
diff --git a/models/all_unets.py b/models/all_unets.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fe4f6bf77f5c11b20af250614a67d7f1e170e46
--- /dev/null
+++ b/models/all_unets.py
@@ -0,0 +1,240 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+import einops
+import numpy as np
+
+# U-Net model
+class UNet_Res(nn.Module):
+
+    def __init__(self, attention, channels_in=3, n_channels=64, fctr=[1,2,4,4,8], time_dim=256, **args):
+        """
+        attention   : (Bool) whether to use attention layers or not
+        channels_in : (Int)  number of input image channels
+        n_channels  : (Int)  channel size after the first convolution
+        fctr        : (list) channel multipliers applied to n_channels at each stage
+        time_dim    : (Int)  dimension of the time embedding vector
+        """
+        super().__init__()
+        channels_out = channels_in
+        fctr = np.asarray(fctr)*n_channels
+        # sinusoidal time embedding followed by learned per-stage projections
+        self.time_embedder = TimeEmbedding(time_dim = time_dim)
+        self.time_embedder0 = torch.nn.Sequential(nn.Linear(time_dim,fctr[0]),nn.SELU(),nn.Linear(fctr[0],fctr[0]))
+        self.time_embedder1 = torch.nn.Sequential(nn.Linear(time_dim,fctr[1]),nn.SELU(),nn.Linear(fctr[1],fctr[1]))
+        self.time_embedder2 = torch.nn.Sequential(nn.Linear(time_dim,fctr[2]),nn.SELU(),nn.Linear(fctr[2],fctr[2]))
+        self.time_embedder3 = torch.nn.Sequential(nn.Linear(time_dim,fctr[3]),nn.SELU(),nn.Linear(fctr[3],fctr[3]))
+        self.time_embedder4 = torch.nn.Sequential(nn.Linear(time_dim,fctr[4]),nn.SELU(),nn.Linear(fctr[4],fctr[4]))
+
+        # first conv block
+        self.first_conv = nn.Conv2d(channels_in,fctr[0],kernel_size=3, padding='same', bias=True)
+
+        # down blocks
+        self.down1 = DownsampleBlock_Res(fctr[0],fctr[1],time_dim)
+        self.down2 = DownsampleBlock_Res(fctr[1],fctr[2],time_dim)
+        self.down3 = DownsampleBlock_Res(fctr[2],fctr[3],time_dim,attention=attention)
+        self.down4 = DownsampleBlock_Res(fctr[3],fctr[4],time_dim,attention=attention)
+
+        # middle layer
+        self.mid1 = MidBlock_Res(fctr[4],time_dim,attention=attention)
+
+
+        # up blocks
+        self.up1 = UpsampleBlock_Res(fctr[1],fctr[0],time_dim)
+        self.up2 = UpsampleBlock_Res(fctr[2],fctr[1],time_dim)
+        self.up3 = UpsampleBlock_Res(fctr[3],fctr[2],time_dim,attention=attention)
+        self.up4 = UpsampleBlock_Res(fctr[4],fctr[3],time_dim)
+
+        # final 1x1 conv
+        self.end_conv = nn.Conv2d(fctr[0], channels_out, kernel_size=1,bias=True)
+
+        # Attention layers (currently unused in forward)
+        self.mha21 = MHABlock(fctr[2])
+        self.mha22 = MHABlock(fctr[2])
+        self.mha31 = MHABlock(fctr[3])
+        self.mha32 = MHABlock(fctr[3])
+        self.mha41 = MHABlock(fctr[4])
+        self.mha42 = MHABlock(fctr[4])
+
+    def forward(self, input, t):
+        t_emb = self.time_embedder(t).to(input.device)
+
+        t_emb0 = self.time_embedder0(t_emb)
+        t_emb1 = self.time_embedder1(t_emb)
+        t_emb2 = self.time_embedder2(t_emb)
+        t_emb3 = self.time_embedder3(t_emb)
+        t_emb4 = self.time_embedder4(t_emb)
+
+        # stem convolution; t_emb0 is broadcast over the spatial dimensions
+        x = self.first_conv(input) + t_emb0[:,:,None,None]
+
+        skip1 = x
+        skip1,x = self.down1(x,t_emb1)
+        skip2,x = self.down2(x,t_emb2)
+        skip3,x = self.down3(x,t_emb3)
+        skip4,x = self.down4(x,t_emb4)
+
+        x = self.mid1(x,t_emb4)
+
+        x = self.up4(x,skip4,t_emb3)
+        x = self.up3(x,skip3,t_emb2)
+        x = self.up2(x,skip2,t_emb1)
+        x = self.up1(x,skip1,t_emb0)
+        x = self.end_conv(x)
+
+        return x
+
+
+
+
+# TimeEmbedding
+class TimeEmbedding(nn.Module):
+
+    def __init__(self, time_dim=64):
+        super().__init__()
+
+        self.time_dim = time_dim
+        n = 10000
+        self.factor = torch.pow(n*torch.ones(size=(time_dim//2,)),(-2/time_dim*torch.arange(time_dim//2)))
+
+    def forward(self, t):
+        """
+        input is t with shape (B,)
+        factor has shape (time_dim//2,)
+        output has shape (B, time_dim)
+        """
+        self.factor = self.factor.to(t.device)
+        theta = torch.outer(t,self.factor)
+
+
+        # shape of embedding: (B, time_dim)
+        emb = torch.zeros(t.size(0), self.time_dim,device=t.device)
+        emb[:, 0::2] = torch.sin(theta)
+        emb[:, 1::2] = torch.cos(theta)
+
+        return emb
+
+# Self Attention
+class MHABlock(nn.Module):
+
+    def __init__(self,
+                 channels_in,
+                 num_attention_heads=1  # number of attention heads in MHA
+                 ):
+        super().__init__()
+
+        self.channels_in = channels_in
+        self.num_attention_heads = num_attention_heads
+        self.self_attention = nn.MultiheadAttention(channels_in, num_heads=self.num_attention_heads)
+
+    def forward(self, x):
+        skip = x
+        batch_size,_,height,width = x.size()
+
+        x = x.permute(2, 3, 0, 1).reshape(height * width, batch_size, -1)  # (H*W, B, C) token sequence
+        attn_output, _ = self.self_attention(x, x, x)
+        attn_output = attn_output.permute(1, 2, 0).reshape(batch_size, -1, height, width)  # back to (B, C, H, W)
+
+        return attn_output + skip
+
+# Residual Convolution Block
+class ConvBlock_Res(nn.Module):
+
+    def __init__(self,
+                 channels_in,   # number of input channels fed into the block
+                 channels_out,  # number of output channels produced by the block
+                 time_dim,
+                 attention,
+                 num_groups=32, # number of groups used in Group Normalization; channels_out must be divisible by num_groups
+                 ):
+        super().__init__()
+
+        self.attention = attention
+        if self.attention:
+            self.attlayer = MHABlock(channels_in=channels_out)
+
+        # Convolution layer 1
+        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding='same', bias=True)
+        self.gn1 = nn.GroupNorm(num_groups, channels_out)
+        self.act1 = nn.SiLU()
+
+        # Convolution layer 2
+        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding='same', bias=True)
+        self.gn2 = nn.GroupNorm(num_groups, channels_out)
+        self.act2 = nn.SiLU()
+
+        # Convolution layer 3
+        self.conv3 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding='same', bias=True)
+        self.gn3 = nn.GroupNorm(num_groups, channels_out)
+        self.act3 = nn.SiLU()
+
+        # 1x1 convolution for the residual skip
+        self.res_skip = nn.Conv2d(channels_in,channels_out,kernel_size=1)
+
+        nn.init.xavier_normal_(self.conv1.weight)
+        nn.init.xavier_normal_(self.conv2.weight)
+        nn.init.xavier_normal_(self.conv3.weight)
+
+    def forward(self, x, t):
+        res = self.res_skip(x)
+        # first convolution layer
+        x = self.act1(self.gn1(self.conv1(x)))
+
+
+        h = x + t[:,:,None,None]  # add the time embedding channel-wise
+
+
+        # second and third convolution layers
+        h = self.act2(self.gn2(self.conv2(h)))
+
+        h = self.act3(self.gn3(self.conv3(h)))
+
+        if self.attention:
+            h = self.attlayer(h)
+
+        return h + res
+
+# Downsample Block
+class DownsampleBlock_Res(nn.Module):
+
+    def __init__(self, channels_in, channels_out,time_dim,attention=False):
+        super().__init__()
+
+
+        self.pool = nn.MaxPool2d((2,2), stride=2)
+        self.convblock = ConvBlock_Res(channels_in, channels_out,time_dim,attention=attention)
+
+    def forward(self, x, t):
+
+        x = self.convblock(x, t)
+        h = self.pool(x)
+        return x,h
+
+# Upsample Block
+class UpsampleBlock_Res(nn.Module):
+
+    def __init__(self, channels_in, channels_out,time_dim,attention=False):
super().__init__() + + self.upconv = nn.ConvTranspose2d(channels_in, channels_in, kernel_size=2, stride=2) + self.convblock = ConvBlock_Res(channels_in, channels_out,time_dim,attention=attention) + + def forward(self, x, skip_x, t): + x = self.upconv(x) + + # skip-connection - merge features from contracting path to its symmetric counterpart in expansive path + out = x + skip_x + + out = self.convblock(out, t) + return out + +# Middle Block +class MidBlock_Res(nn.Module): + def __init__(self,channels,time_dim,attention=False): + super().__init__() + self.convblock1 = ConvBlock_Res(channels,channels,time_dim,attention=attention) + self.convblock2 = ConvBlock_Res(channels,channels,time_dim,attention=False) + def forward(self,x,t): + x = self.convblock1(x,t) + return self.convblock2(x,t) + diff --git a/requirements.txt b/requirements.txt index c21bf8dcc4f54368c489ecfddf8da777cd049a7a..18a2715a942079fda50b4e5a746d299ce4c0a4d9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ wandb torch torchvision torchaudio +einops \ No newline at end of file
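
A minimal smoke test for the new UNet_Res (a sketch, not part of the patch). Note that `attention` has no default in UNet_Res.__init__, so it has to be supplied on top of the notebook's model_setting; `attention=True`, the batch size of 4, the 64x64 resolution, and the timestep range of 1000 are assumptions chosen here for illustration only.

import torch
from models.all_unets import UNet_Res

model_setting = dict(n_channels=64, fctr=[1, 2, 4, 4, 8], time_dim=256)  # as in the notebook
net = UNet_Res(attention=True, **model_setting)   # `attention` must be passed explicitly

x = torch.randn(4, 3, 64, 64)                     # batch of four 3-channel 64x64 images
t = torch.randint(1, 1000, (4,))                  # one diffusion timestep per image
eps = net(x, t)                                   # predicted noise, same shape as the input
assert eps.shape == x.shape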
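
MHABlock flattens the spatial grid into a token sequence, because nn.MultiheadAttention by default expects (sequence, batch, channels). The round trip below is a small sanity sketch of that layout change and its inverse; the shapes are arbitrary example values.

import torch

B, C, H, W = 2, 256, 8, 8
x = torch.randn(B, C, H, W)
seq = x.permute(2, 3, 0, 1).reshape(H * W, B, C)  # tokens = spatial positions
back = seq.permute(1, 2, 0).reshape(B, C, H, W)   # inverse mapping back to feature maps
assert torch.equal(back, x)                       # the round trip is lossless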