diff --git a/models/unet.py b/models/unet.py
deleted file mode 100644
index de08e65fccb4e4ef9571c5b9d5d77e7ca2e1435f..0000000000000000000000000000000000000000
--- a/models/unet.py
+++ /dev/null
@@ -1,138 +0,0 @@
-# -*- coding: utf-8 -*-
-"""UNet.ipynb
-
-Automatically generated by Colaboratory.
-
-Original file is located at
- https://colab.research.google.com/drive/1BdiIHZYyESTyt-NVRoJXUBMlKreOExkL
-"""
-
-'''
-Implementation of the U-Net architecture
-
-Structure: Input -> Contracting Path -> Expansive Path -> Output
-
-The contracting path progressively downsamples the input
-
-The expansive path incrementally upsamples the output of the contracting path,
-merging skip connections from the contracting path along the way
-
-'''
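-
-# Minimal usage sketch (illustrative, not from the original file) for the classes defined below,
-# run on a dummy 256x256 RGB input:
-#
-#   model = UNet(channels_in=3, channels_out=1, n_channels=64, n_blocks=4)
-#   x = torch.randn(1, 3, 256, 256)
-#   out = model(x, t=None)            # t is unused by this plain U-Net
-#   out.shape                         # torch.Size([1, 1, 256, 256])
-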
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision.transforms as transforms
-
-# Two 3x3 conv layers
-class ConvBlock(nn.Module):
-
- def __init__(self, channels_in, channels_out):
- super().__init__()
-
- # first conv layer with He initialization and batch normalization
- self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
- self.bn1 = nn.BatchNorm2d(channels_out)
- nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
-
- # second conv layer with He initialization and batch normalization
- self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
- self.bn2 = nn.BatchNorm2d(channels_out)
- nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
-
- def forward(self, x):
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = F.relu(x)
-        x = self.conv2(x)
-        x = self.bn2(x)
-        x = F.relu(x)
- return x
-
-
-# Downsampling block - maxpool halves the resolution, followed by ConvBlock operation
-class DownsampleBlock(nn.Module):
-
- def __init__(self, channels_in, channels_out):
- super().__init__()
-
- self.pool = nn.MaxPool2d((2,2), stride=2)
- self.convblock = ConvBlock(channels_in, channels_out)
-
- def forward(self, x):
- x = self.pool(x)
- x = self.convblock(x)
- return x
-
-
-# Upsampling block - double the resolution with ConvTranspose, followed by ConvBlock operation
-class UpsampleBlock(nn.Module):
-
- def __init__(self, channels_in, channels_out):
- super().__init__()
-
- self.upconv = nn.ConvTranspose2d(channels_in, channels_out, kernel_size=2, stride=2)
- self.convblock = ConvBlock(channels_in, channels_out)
-
- def forward(self, x, down_x):
- x = self.upconv(x)
-
-        # skip-connection - merge features from the contracting path with their symmetric counterpart in the expansive path
- down_x = transforms.CenterCrop(size=(x.shape[2], x.shape[3]))(down_x)
- x = torch.cat([x, down_x], dim=1)
-
- x = self.convblock(x)
- return x
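-
-# Illustrative shape walk-through for UpsampleBlock(1024, 512) on a hypothetical 16x16 bottleneck:
-#   x: (B, 1024, 16, 16) --ConvTranspose2d--> (B, 512, 32, 32)
-#   down_x: (B, 512, 32, 32) --CenterCrop--> (B, 512, 32, 32)   (a no-op when sizes already match)
-#   concat -> (B, 1024, 32, 32) --ConvBlock--> (B, 512, 32, 32)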
-
-
-# U-Net model
-class UNet(nn.Module):
-
-    def __init__(self, channels_in, channels_out, n_channels=64, n_blocks=4, *args, **kwargs):
- super().__init__()
-
- # number of channels to be produced by the first conv block
- self.n_channels = n_channels
- # number of downsampling/upsampling blocks
- self.n_blocks = n_blocks
-
- # first conv block
- self.first_conv = ConvBlock(channels_in, self.n_channels)
-
- # downsampling and upsampling blocks
- down_channels = []
- up_channels = []
- for i in (2**p for p in range(self.n_blocks)):
- down_channels.append((self.n_channels * i, self.n_channels * i * 2))
- up_channels.insert(0,(self.n_channels * i * 2, self.n_channels * i))
-
-
- self.downsample = nn.ModuleList([DownsampleBlock(c_in, c_out) for c_in, c_out in down_channels])
-
- self.upsample = nn.ModuleList([UpsampleBlock(c_in, c_out) for c_in, c_out in up_channels])
-
- # final 1x1 conv
- self.end_conv = nn.Conv2d(self.n_channels, channels_out, kernel_size=1)
-
-    def forward(self, x, t):
-        # t is accepted for interface compatibility but is not used by this plain U-Net
-
- # to store feature maps from contracting path
- skip = []
-
- # first two conv layers
- x = self.first_conv(x)
- skip.insert(0, x)
-
- # downsampling blocks
- for i in range(self.n_blocks):
- x = self.downsample[i](x)
- # store feature maps
- if i < self.n_blocks - 1:
- skip.insert(0, x)
-
- # upsampling blocks (with skip-connections)
- for i in range(self.n_blocks):
- x = self.upsample[i](x, skip[i])
-
- # final 1x1 conv layer
- x = self.end_conv(x)
-
- return x
\ No newline at end of file
diff --git a/models/unet_unconditional_diffusion.py b/models/unet_unconditional_diffusion.py
deleted file mode 100644
index 3f56aaaf1956a6073bf7fe65296c99ab80f5da16..0000000000000000000000000000000000000000
--- a/models/unet_unconditional_diffusion.py
+++ /dev/null
@@ -1,446 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision.transforms as transforms
-
-
-"""
-TimeEmbedding - generates time embedding for a time step
-
-input (0-d tensor) -> tensor of shape [time_channels, time_dim]
-
-Arguments:
-
-time_dim: int, default=64,
- dimensionality of the time embedding (has to match batch size)
-
-time_channels: int, default=256,
- number of channels in the time embedding
-
-"""
-class TimeEmbedding(nn.Module):
-
- def __init__(self, time_dim=64, time_channels=256):
- super().__init__()
-
- self.time_dim = time_dim
- self.time_channels = time_channels
-
- # argument to sin/cos fn: t / 10000^(i / d) where i = 2k or 2k + 1 - https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
- self.factor = torch.exp(torch.arange(0, self.time_dim, 2) * (- torch.log(torch.tensor(10000.0)) / self.time_dim))
-
-
- def forward(self, t):
-
-        # if t = torch.tensor(int), t.shape = []
- # change it so that t.shape = [1]
- if len(t.shape) == 0:
- t = t.unsqueeze(0)
-
- t = t.unsqueeze(1) * self.factor
-
- # shape of embedding [time_channels, dim]
- emb = torch.zeros(self.time_channels, self.time_dim)
- emb[:, 0::2] = torch.sin(t)
- emb[:, 1::2] = torch.cos(t)
-
- return emb
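-
-# Worked example (illustrative): with time_dim=8, factor = exp(-ln(10000)/8 * [0, 2, 4, 6])
-# = [1.0, 0.1, 0.01, 0.001]. For t = 5 the even columns of each row hold sin(5 * factor) and the
-# odd columns hold cos(5 * factor); all time_channels rows carry the same embedding. Downstream,
-# the transposed embedding is broadcast across the batch, which is why time_dim must match the
-# batch size.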
-
-
-
-"""
-ConvResBlock - building block of the U-Net architecture
-
-input -> Conv1 -(+ time embedding) -> Conv2 -(+ residual) -> Multi-head attention
-
-Arguments:
-
-channels_in : int,
- number of input channels fed into the block
-
-channels_out: int,
- number of output channels produced by the block
-
-activation: nn.Module,
-    activation module applied throughout the block, e.g. nn.ReLU() or nn.SiLU()
-
-weight_init: {'he', 'torch'}, default='he',
- weight initializer for convolution layers; choose between He
- initialization and PyTorch's default initialization
-
-time_channels: int, default=64,
-    number of channels for time embedding
-
-num_groups: int, default=32,
-    number of groups used in Group Normalization; channels_out must be
-    divisible by num_groups
-
-dropout: float, default=0.1,
- drop-out to be applied
-
-attention: boolean, default=False,
- whether Multi-head attention (MHA) is applied or not
-
-num_attention_heads: int, default=1,
- number of attention heads in MHA
-
-"""
-class ConvResBlock(nn.Module):
-
- def __init__(self,
- channels_in, # number of input channels fed into the block
- channels_out, # number of output channels produced by the block
-                 activation, # activation module (nn.Module instance), e.g. nn.ReLU() or nn.SiLU()
- weight_init='he', # weight initialization. Options: {'he', 'torch'}
- time_channels=64, # number of channels for time embedding
-                 num_groups=32, # number of groups used in Group Normalization; channels_out must be divisible by num_groups
- dropout=0.1, # drop-out to be applied
- attention=False, # boolean: whether Multi-head attention (MHA) is applied or not
- num_attention_heads=1 # number of attention heads in MHA
- ):
- super().__init__()
-
- self.activation = activation
-
- # Convolution layer 1
- self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
- self.gn1 = nn.GroupNorm(num_groups, channels_out)
- self.act1 = self.activation
-
- # Convolution layer 2
- self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
- self.gn2 = nn.GroupNorm(num_groups, channels_out)
- self.act2 = self.activation
-
- if weight_init=='he':
- nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
- nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
-
- # Drop-out
- self.dropout = nn.Dropout(dropout)
-
- # Residual connection
- self.residual = nn.Identity()
- if channels_in != channels_out:
- self.residual = nn.Conv2d(channels_in, channels_out, kernel_size=1)
- if weight_init=='he':
- nn.init.kaiming_uniform_(self.residual.weight, nonlinearity='relu')
- self.residual_act = self.activation
-
- # Time embedding - map time embedding to have the same number of channels as image activation
- self.time_emb = nn.Linear(time_channels, channels_out)
- self.time_act = self.activation
-
- # Multi-head attention
- self.attention = attention
- self.num_attention_heads = num_attention_heads
- self.self_attention = nn.Identity()
- if self.attention:
- self.self_attention = nn.MultiheadAttention(channels_out, num_heads=self.num_attention_heads)
-
- def forward(self, x, t):
-
- # store input, to be used as residual
- res = self.residual(x)
-
- if isinstance(self.residual, nn.Conv2d):
- res = self.residual_act(res)
-
- # first convolution layer
- x = self.act1(self.gn1(self.conv1(x)))
-
-        # add temporal information with the time embedding
-        # (out-of-place add: autograd may still need the activation output)
-        t = self.time_act(self.time_emb(t.T))
-        x = x + t[:, :, None, None]
-
- # Drop-out
- x = self.dropout(x)
-
- # second convolution layer
- x = self.act2(self.gn2(self.conv2(x)))
-
-        # add residual (out-of-place, for the same reason)
-        x = x + res
-
- # apply self-attention
- if self.attention:
- batch_size = x.shape[0]
- height = x.shape[2]
- width = x.shape[3]
- sequence_length = height * width
- x = x.permute(2, 3, 0, 1).reshape(sequence_length, batch_size, -1)
- x, _ = self.self_attention(x, x, x)
-            x = x.permute(1, 2, 0).reshape(batch_size, -1, height, width) # back to (batch, channels, height, width)
- return x
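-
-# Usage sketch (illustrative, not from the original file): a block mapping 64 -> 128 channels,
-# with attention, driven by a 256-channel time embedding whose time_dim matches the batch size.
-#   emb = TimeEmbedding(time_dim=16, time_channels=256)(torch.tensor(1000))
-#   block = ConvResBlock(64, 128, activation=nn.SiLU(), time_channels=256, attention=True)
-#   x = torch.randn(16, 64, 32, 32)
-#   y = block(x, emb)                 # -> (16, 128, 32, 32)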
-
-
-"""
-UNet_Unconditional_Diffusion - the U-Net architecture
-
-Arguments:
-
-channels_in: int,
- number of input channels to the U-Net; for RGB images, channels_in = 3
-
-channels_out: int,
- number of output channels
-
-activation: {'relu', 'leakyrelu', 'selu', 'gelu', 'silu'/'swish'}, default='relu',
- activation function in the neural network
-
-weight_init: {'he', 'torch'}, default='he',
- weight initializer for convolution layers; choose between He
- initialization and PyTorch's default initialization
-
-projection_features: int, default=64,
- number of image features after first convolution layer
-
-time_dim: int, default=64,
- dimensionality of the time embedding (has to match batch size)
-
-time_channels: int, default=256,
- number of time channels
-
-num_stages: int, default=4,
- number of stages in contracting/expansive path
-
-stage_list: int list, default=None,
-    number of feature channels produced by each stage; if None, successive
-    stages double the channel count starting from projection_features
-
-num_blocks: int, default=2,
-    number of ConvResBlocks in each stage of the contracting/expansive path
-
-num_groupnorm_groups: int, default=32,
- number of groups used in Group Normalization inside a ConvResBlock;
-    channels_out of each ConvResBlock must be divisible by num_groupnorm_groups
-
-dropout: float, default=0.1,
- drop-out to be applied
-
-attention_list: boolean list, default=None,
-    per-stage flags specifying where multi-head attention is applied; if None,
-    the deeper half of the stages use attention
-
-num_attention_heads: int, default=1,
- number of attention heads in MHA inside a ConvResBlock
-
-"""
-class UNet_Unconditional_Diffusion(nn.Module):
-
- def __init__(self,
- channels_in, # number of input channels to the U-Net; for RGB images, channels_in = 3
- channels_out, # number of output channels
- activation='relu', # activation function. Options: {'relu', 'leakyrelu', 'selu', 'gelu', 'silu'/'swish'}
- weight_init='he', # weight initialization. Options: {'he', 'torch'}
- projection_features=64, # number of image features after first convolution layer
- time_dim=64, # dimensionality of the time embedding (has to match batch size)
- time_channels=256, # number of time channels
- num_stages=4, # number of stages in contracting/expansive path
- stage_list=None, # specify number of features produced by stages
- num_blocks=2, # number of ConvResBlock in each contracting/expansive path
- num_groupnorm_groups=32, # number of groups used in Group Normalization inside a ConvResBlock
- dropout=0.1, # drop-out to be applied inside a ConvResBlock
- attention_list=None, # specify MHA pattern across stages
- num_attention_heads=1, # number of attention heads in MHA inside a ConvResBlock
-                 **kwargs
- ):
- super().__init__()
-
- self.channels_in = channels_in
- self.channels_out = channels_out
-
- if activation=='relu':
- self.activation = nn.ReLU()
- elif activation=='leakyrelu':
- self.activation = nn.LeakyReLU()
- elif activation=='selu':
- self.activation = nn.SELU()
- elif activation=='gelu':
- self.activation = nn.GELU()
- elif activation=='swish' or activation=='silu':
- self.activation = nn.SiLU()
-
- # number of channels to be produced by the first conv block - image projection
- self.projection_features = projection_features
- self.first_conv = nn.Conv2d(channels_in, self.projection_features, kernel_size=3, padding=1)
- if weight_init=='he':
- nn.init.kaiming_uniform_(self.first_conv.weight, nonlinearity='relu')
- self.first_act = self.activation
-
-        # time embedding dimensionality and number of time channels
- self.time_dim = time_dim
- self.time_channels = time_channels
-
- self.time_embedding = TimeEmbedding(time_dim=self.time_dim, time_channels=self.time_channels)
-
- # number of downsampling/upsampling stages
- self.num_stages = num_stages
- # number of ConvResBlocks in each downsampling/upsampling step
- self.num_blocks = num_blocks
-
- if attention_list is None:
- # boolean list assigning attention blocks in the contracting and expansive path
- # default - first half of contracting path has no attention, second half does;
- # first half of expansive path has attention, second half doesn't
- self.attention_list=[]
- for i in range(self.num_stages):
- if i < self.num_stages//2:
- self.attention_list.append(False)
- else:
- self.attention_list.append(True)
- else:
- self.attention_list = attention_list # [False, False, True, True] - paper implementation for similar 4 stage U-Net
-
- # number of features produced by each stage
- if stage_list is None:
- # default - successive stages double the number of channels
- self.stages = [projection_features * 2**i for i in range(1, self.num_stages+1)]
- else:
-            self.stages = stage_list # e.g. [128, 256, 512, 1024] for a 4-stage U-Net, matching the original U-Net channel progression
-
- # contracting path
- contracting_path = []
-
- # number of channels to go into the first ConvResBlock = number of output channels from first conv layer
- c_in = c_out = projection_features
-
- # there are num_stages number of stages
- # each stage has num_blocks number of ConvRes+Attention blocks
- # each stage (except for the last) ends with a downsampling layer - maxpool
- for i in range(self.num_stages):
-
- c_out = self.stages[i]
-
- for _ in range(num_blocks):
- contracting_path.append(ConvResBlock(channels_in=c_in,
- channels_out=c_out,
- activation=self.activation,
- weight_init=weight_init,
- time_channels=self.time_channels,
- num_groups=num_groupnorm_groups,
- dropout=dropout,
- attention=self.attention_list[i],
- num_attention_heads=num_attention_heads))
- c_in = c_out
-
-
- # downsample, if it is not the last stage
- if i < self.num_stages - 1:
- contracting_path.append(nn.MaxPool2d((2,2), stride=2))
-
- self.contracting_path = nn.ModuleList(contracting_path)
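-
-        # With the defaults (projection_features=64, num_stages=4, num_blocks=2) the contracting
-        # path works out to (illustrative):
-        #   ConvRes 64->128,   ConvRes 128->128,   MaxPool,
-        #   ConvRes 128->256,  ConvRes 256->256,   MaxPool,
-        #   ConvRes 256->512,  ConvRes 512->512,   MaxPool,
-        #   ConvRes 512->1024, ConvRes 1024->1024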
-
- # the bottleneck block
-
- self.midblock1 = ConvResBlock(channels_in=c_out,
- channels_out=c_out,
- activation=self.activation,
- weight_init=weight_init,
- time_channels=self.time_channels,
- num_groups=num_groupnorm_groups,
- dropout=dropout,
- attention=True,
- num_attention_heads=num_attention_heads)
- self.midblock2 = ConvResBlock(channels_in=c_out,
- channels_out=c_out,
- activation=self.activation,
- weight_init=weight_init,
- time_channels=self.time_channels,
- num_groups=num_groupnorm_groups,
- dropout=dropout,
- attention=False,
- num_attention_heads=num_attention_heads)
-
- # expansive path
- expansive_path = []
-
- # input to the expansive path = output of midblock = input to midblock = output of contracting path
- c_in = c_out = self.stages[-1]
-
- # there are num_stages number of stages
- # each stage has num_blocks number of ConvRes+Attention blocks and then 1 more to halve the number of channels
- # each stage (except for the last) ends with an upsampling layer - Transposed convolution
- for i in reversed(range(self.num_stages)):
-
- # channels_in = c_in + c_out to account for the incoming skip connections from contracting path
- for _ in range(self.num_blocks):
- expansive_path.append(ConvResBlock(channels_in=c_in + c_out,
- channels_out=c_out,
- activation=self.activation,
- weight_init=weight_init,
- time_channels=self.time_channels,
- num_groups=num_groupnorm_groups,
- dropout=dropout,
- attention=self.attention_list[i],
- num_attention_heads=num_attention_heads))
-
- if i > 0:
- c_out = self.stages[i-1]
- else:
- c_out = self.projection_features
- expansive_path.append(ConvResBlock(channels_in=c_in + c_out,
- channels_out=c_out,
- activation=self.activation,
- weight_init=weight_init,
- time_channels=self.time_channels,
- num_groups=num_groupnorm_groups,
- dropout=dropout,
- attention=self.attention_list[i],
- num_attention_heads=num_attention_heads))
- c_in = c_out
- # upsample, if it is not the last stage
- if i > 0:
- expansive_path.append(nn.ConvTranspose2d(c_in, c_in, kernel_size=4, stride=2, padding=1))
-
- self.expansive_path = nn.ModuleList(expansive_path)
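-
-        # For the same defaults the expansive path mirrors this, with channels_in widened by the
-        # incoming skip connections (illustrative):
-        #   ConvRes 2048->1024 (x2), ConvRes 1536->512, ConvTranspose 512->512,
-        #   ConvRes 1024->512  (x2), ConvRes 768->256,  ConvTranspose 256->256,
-        #   ConvRes 512->256   (x2), ConvRes 384->128,  ConvTranspose 128->128,
-        #   ConvRes 256->128   (x2), ConvRes 192->64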
-
- # final convolution layer
- self.end_gn = nn.GroupNorm(8, c_in)
- self.end_conv = nn.Conv2d(c_in, self.channels_out, kernel_size=3, padding=1)
- if weight_init=='he':
- nn.init.kaiming_uniform_(self.end_conv.weight, nonlinearity='relu')
- self.end_act = self.activation
-
-
-
- def forward(self, x, t):
-        # accept either a Python int or a tensor as the time step
-        t = torch.as_tensor(t)
-
- # to store feature maps from contracting path
- skip = []
-
-        # time embedding for time step t, moved to the same device as the input
-        t = self.time_embedding(t).to(x.device)
-
- # first conv layer to project input image (3, *, *) into (projection_features=64, *, *)
- x = self.first_act(self.first_conv(x))
-
- # store initial projection
- skip.append(x)
-
- # contracting path
- for i in range(len(self.contracting_path)):
- if isinstance(self.contracting_path[i], ConvResBlock):
- x = self.contracting_path[i](x, t)
- else:
- x = self.contracting_path[i](x)
-
- # store feature maps
- skip.append(x)
-
- x = self.midblock1(x, t)
- x = self.midblock2(x, t)
-
- # expansive path
- for i in range(len(self.expansive_path)):
-
- # add channels coming from skip connections (doesn't apply for upsampling ConvTranspose2D layer)
- if isinstance(self.expansive_path[i], ConvResBlock):
- x = torch.cat((x, skip.pop()), dim=1)
- x = self.expansive_path[i](x,t)
- else:
- x = self.expansive_path[i](x)
-
-        # final layers: normalization -> activation -> convolution, so the predicted noise is not
-        # clamped to a non-negative range by the activation
-        x = self.end_act(self.end_gn(x))
-        x = self.end_conv(x)
-        return x
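-
-# End-to-end usage sketch (illustrative, not from the original file); the batch size must equal
-# time_dim because the time embedding is broadcast over the batch:
-#   model = UNet_Unconditional_Diffusion(channels_in=3, channels_out=3, activation='silu',
-#                                        time_dim=8, time_channels=256)
-#   x = torch.randn(8, 3, 64, 64)
-#   eps = model(x, 500)               # predicted noise, shape (8, 3, 64, 64)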