diff --git a/models/unet.py b/models/unet.py
deleted file mode 100644
index de08e65fccb4e4ef9571c5b9d5d77e7ca2e1435f..0000000000000000000000000000000000000000
--- a/models/unet.py
+++ /dev/null
@@ -1,138 +0,0 @@
-# -*- coding: utf-8 -*-
-"""UNet.ipynb
-
-Automatically generated by Colaboratory.
-
-Original file is located at
-    https://colab.research.google.com/drive/1BdiIHZYyESTyt-NVRoJXUBMlKreOExkL
-"""
-
-'''
-Implementation of the U-Net architecture
-
-Structure: Input -> Contracting Path -> Expansive Path -> Output
-
-The contracting path progressively downsamples the input.
-
-The expansive path incrementally upsamples the output of the contracting path.
-'''
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision.transforms as transforms
-
-# Two 3x3 conv layers
-class ConvBlock(nn.Module):
-
-    def __init__(self, channels_in, channels_out):
-        super().__init__()
-
-        # first conv layer with He initialization and batch normalization
-        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(channels_out)
-        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
-
-        # second conv layer with He initialization and batch normalization
-        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
-        self.bn2 = nn.BatchNorm2d(channels_out)
-        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
-
-    def forward(self, x):
-        x = self.conv1(x)
-        #x = self.bn1(x)
-        x = F.relu(x)
-        x = self.conv2(x)
-        #x = self.bn2(x)
-        x = F.relu(x)
-        return x
-
-
-# Downsampling block - maxpool halves the resolution, followed by a ConvBlock
-class DownsampleBlock(nn.Module):
-
-    def __init__(self, channels_in, channels_out):
-        super().__init__()
-
-        self.pool = nn.MaxPool2d((2,2), stride=2)
-        self.convblock = ConvBlock(channels_in, channels_out)
-
-    def forward(self, x):
-        x = self.pool(x)
-        x = self.convblock(x)
-        return x
-
-
-# Upsampling block - ConvTranspose doubles the resolution, followed by a ConvBlock
-class UpsampleBlock(nn.Module):
-
-    def __init__(self, channels_in, channels_out):
-        super().__init__()
-
-        self.upconv = nn.ConvTranspose2d(channels_in, channels_out, kernel_size=2, stride=2)
-        self.convblock = ConvBlock(channels_in, channels_out)
-
-    def forward(self, x, down_x):
-        x = self.upconv(x)
-
-        # skip-connection - merge features from the contracting path with their symmetric counterpart in the expansive path
-        down_x = transforms.CenterCrop(size=(x.shape[2], x.shape[3]))(down_x)
-        x = torch.cat([x, down_x], dim=1)
-
-        x = self.convblock(x)
-        return x
-
-
-# U-Net model
-class UNet(nn.Module):
-
-    def __init__(self, channels_in, channels_out, n_channels=64, n_blocks=4, *args, **kwargs):
-        super().__init__()
-
-        # number of channels to be produced by the first conv block
-        self.n_channels = n_channels
-        # number of downsampling/upsampling blocks
-        self.n_blocks = n_blocks
-
-        # first conv block
-        self.first_conv = ConvBlock(channels_in, self.n_channels)
-
-        # downsampling and upsampling blocks
-        down_channels = []
-        up_channels = []
-        for i in (2**p for p in range(self.n_blocks)):
-            down_channels.append((self.n_channels * i, self.n_channels * i * 2))
-            up_channels.insert(0, (self.n_channels * i * 2, self.n_channels * i))
-
-        self.downsample = nn.ModuleList([DownsampleBlock(c_in, c_out) for c_in, c_out in down_channels])
-
-        self.upsample = nn.ModuleList([UpsampleBlock(c_in, c_out) for c_in, c_out in up_channels])
-
-        # final 1x1 conv
-        self.end_conv = nn.Conv2d(self.n_channels, channels_out, kernel_size=1)
-
-    def forward(self, x, t):
-        # t is not used by this model
-
-        # to store feature maps from the contracting path
-        skip = []
-
-        # first two conv layers
-        x = self.first_conv(x)
-        skip.insert(0, x)
-
-        # downsampling blocks
-        for i in range(self.n_blocks):
-            x = self.downsample[i](x)
-            # store feature maps
-            if i < self.n_blocks - 1:
-                skip.insert(0, x)
-
-        # upsampling blocks (with skip-connections)
-        for i in range(self.n_blocks):
-            x = self.upsample[i](x, skip[i])
-
-        # final 1x1 conv layer
-        x = self.end_conv(x)
-
-        return x
\ No newline at end of file
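A minimal usage sketch for the deleted models/unet.py (not taken from the repository; the names model, x, t, y are illustrative). It assumes an input whose height and width are divisible by 2**n_blocks = 16 so the skip connections line up without cropping:

import torch
from models.unet import UNet   # import path as it existed before this deletion

model = UNet(channels_in=3, channels_out=3)   # defaults: n_channels=64, n_blocks=4
x = torch.randn(1, 3, 256, 256)               # 256 is divisible by 2**4 = 16
t = torch.tensor(0)                           # accepted for signature compatibility, unused here
y = model(x, t)                               # -> torch.Size([1, 3, 256, 256])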
diff --git a/models/unet_unconditional_diffusion.py b/models/unet_unconditional_diffusion.py
deleted file mode 100644
index 3f56aaaf1956a6073bf7fe65296c99ab80f5da16..0000000000000000000000000000000000000000
--- a/models/unet_unconditional_diffusion.py
+++ /dev/null
@@ -1,446 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision.transforms as transforms
-
-
-"""
-TimeEmbedding - generates the sinusoidal time embedding for a time step
-
-input (0-d tensor) -> tensor of shape [time_channels, time_dim]
-
-Arguments:
-
-time_dim: int, default=64,
-    dimensionality of the time embedding (has to match the batch size)
-
-time_channels: int, default=256,
-    number of channels in the time embedding
-
-"""
-class TimeEmbedding(nn.Module):
-
-    def __init__(self, time_dim=64, time_channels=256):
-        super().__init__()
-
-        self.time_dim = time_dim
-        self.time_channels = time_channels
-
-        # argument to the sin/cos fn: t / 10000^(i / d) where i = 2k or 2k + 1 - https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
-        self.factor = torch.exp(torch.arange(0, self.time_dim, 2) * (- torch.log(torch.tensor(10000.0)) / self.time_dim))
-
-    def forward(self, t):
-
-        # if t = torch.tensor(int), t.shape = []
-        # change it so that t.shape = [1]
-        if len(t.shape) == 0:
-            t = t.unsqueeze(0)
-
-        t = t.unsqueeze(1) * self.factor
-
-        # shape of the embedding: [time_channels, time_dim]
-        emb = torch.zeros(self.time_channels, self.time_dim)
-        emb[:, 0::2] = torch.sin(t)
-        emb[:, 1::2] = torch.cos(t)
-
-        return emb
-
-
-"""
-ConvResBlock - building block of the U-Net architecture
-
-input -> Conv1 -(+ time embedding) -> Conv2 -(+ residual) -> Multi-head attention
-
-Arguments:
-
-channels_in: int,
-    number of input channels fed into the block
-
-channels_out: int,
-    number of output channels produced by the block
-
-activation: nn.Module,
-    activation module used inside the block (the U-Net passes an instance such as nn.ReLU())
-
-weight_init: {'he', 'torch'}, default='he',
-    weight initializer for convolution layers; choose between He
-    initialization and PyTorch's default initialization
-
-time_channels: int, default=64,
-    number of channels for the time embedding
-
-num_groups: int, default=32,
-    number of groups used in Group Normalization; channels_out must be
-    divisible by num_groups
-
-dropout: float, default=0.1,
-    drop-out to be applied
-
-attention: boolean, default=False,
-    whether Multi-head attention (MHA) is applied or not
-
-num_attention_heads: int, default=1,
-    number of attention heads in MHA
-
-"""
-class ConvResBlock(nn.Module):
-
-    def __init__(self,
-                 channels_in,             # number of input channels fed into the block
-                 channels_out,            # number of output channels produced by the block
-                 activation,              # activation module, e.g. nn.ReLU()
-                 weight_init='he',        # weight initialization. Options: {'he', 'torch'}
-                 time_channels=64,        # number of channels for the time embedding
-                 num_groups=32,           # number of groups used in Group Normalization; channels_out must be divisible by num_groups
-                 dropout=0.1,             # drop-out to be applied
-                 attention=False,         # boolean: whether Multi-head attention (MHA) is applied or not
-                 num_attention_heads=1    # number of attention heads in MHA
-                 ):
-        super().__init__()
-
-        self.activation = activation
-
-        # Convolution layer 1
-        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
-        self.gn1 = nn.GroupNorm(num_groups, channels_out)
-        self.act1 = self.activation
-
-        # Convolution layer 2
-        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
-        self.gn2 = nn.GroupNorm(num_groups, channels_out)
-        self.act2 = self.activation
-
-        if weight_init=='he':
-            nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
-            nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
-
-        # Drop-out
-        self.dropout = nn.Dropout(dropout)
-
-        # Residual connection
-        self.residual = nn.Identity()
-        if channels_in != channels_out:
-            self.residual = nn.Conv2d(channels_in, channels_out, kernel_size=1)
-            if weight_init=='he':
-                nn.init.kaiming_uniform_(self.residual.weight, nonlinearity='relu')
-        self.residual_act = self.activation
-
-        # Time embedding - map the time embedding to the same number of channels as the image activation
-        self.time_emb = nn.Linear(time_channels, channels_out)
-        self.time_act = self.activation
-
-        # Multi-head attention
-        self.attention = attention
-        self.num_attention_heads = num_attention_heads
-        self.self_attention = nn.Identity()
-        if self.attention:
-            self.self_attention = nn.MultiheadAttention(channels_out, num_heads=self.num_attention_heads)
-
-    def forward(self, x, t):
-
-        # store the input, to be used as the residual
-        res = self.residual(x)
-
-        if isinstance(self.residual, nn.Conv2d):
-            res = self.residual_act(res)
-
-        # first convolution layer
-        x = self.act1(self.gn1(self.conv1(x)))
-
-        # add temporal information with the time embedding
-        t = self.time_act(self.time_emb(t.T))
-        x += t[:, :, None, None]
-
-        # Drop-out
-        x = self.dropout(x)
-
-        # second convolution layer
-        x = self.act2(self.gn2(self.conv2(x)))
-
-        # add the residual
-        x += res
-
-        # apply self-attention
-        if self.attention:
-            batch_size = x.shape[0]
-            height = x.shape[2]
-            width = x.shape[3]
-            sequence_length = height * width
-            # flatten the spatial dimensions into a sequence of shape (sequence_length, batch_size, channels)
-            x = x.permute(2, 3, 0, 1).reshape(sequence_length, batch_size, -1)
-            x, _ = self.self_attention(x, x, x)
-            # restore the (batch_size, channels, height, width) layout
-            x = x.permute(1, 2, 0).reshape(batch_size, -1, height, width)
-        return x
-
-
-"""
-UNet_Unconditional_Diffusion - the U-Net architecture
-
-Arguments:
-
-channels_in: int,
-    number of input channels to the U-Net; for RGB images, channels_in = 3
-
-channels_out: int,
-    number of output channels
-
-activation: {'relu', 'leakyrelu', 'selu', 'gelu', 'silu'/'swish'}, default='relu',
-    activation function in the neural network
-
-weight_init: {'he', 'torch'}, default='he',
-    weight initializer for convolution layers; choose between He
-    initialization and PyTorch's default initialization
-
-projection_features: int, default=64,
-    number of image features after the first convolution layer
-
-time_dim: int, default=64,
-    dimensionality of the time embedding (has to match the batch size)
-
-time_channels: int, default=256,
-    number of time channels
-
-num_stages: int, default=4,
-    number of stages in the contracting/expansive path
-
-stage_list: int list, default=None,
-    number of features produced by each stage
-
-num_blocks: int, default=2,
-    number of ConvResBlocks in each contracting/expansive stage
-
-num_groupnorm_groups: int, default=32,
-    number of groups used in Group Normalization inside a ConvResBlock;
-    the channels_out of a ConvResBlock must be divisible by num_groups
-
-dropout: float, default=0.1,
-    drop-out to be applied
-
-attention_list: boolean list, default=None,
-    MHA pattern across stages
-
-num_attention_heads: int, default=1,
-    number of attention heads in MHA inside a ConvResBlock
-
-"""
-class UNet_Unconditional_Diffusion(nn.Module):
-
-    def __init__(self,
-                 channels_in,                # number of input channels to the U-Net; for RGB images, channels_in = 3
-                 channels_out,               # number of output channels
-                 activation='relu',          # activation function. Options: {'relu', 'leakyrelu', 'selu', 'gelu', 'silu'/'swish'}
-                 weight_init='he',           # weight initialization. Options: {'he', 'torch'}
-                 projection_features=64,     # number of image features after the first convolution layer
-                 time_dim=64,                # dimensionality of the time embedding (has to match the batch size)
-                 time_channels=256,          # number of time channels
-                 num_stages=4,               # number of stages in the contracting/expansive path
-                 stage_list=None,            # number of features produced by each stage
-                 num_blocks=2,               # number of ConvResBlocks in each contracting/expansive stage
-                 num_groupnorm_groups=32,    # number of groups used in Group Normalization inside a ConvResBlock
-                 dropout=0.1,                # drop-out to be applied inside a ConvResBlock
-                 attention_list=None,        # MHA pattern across stages
-                 num_attention_heads=1,      # number of attention heads in MHA inside a ConvResBlock
-                 **kwargs
-                 ):
-        super().__init__()
-
-        self.channels_in = channels_in
-        self.channels_out = channels_out
-
-        if activation=='relu':
-            self.activation = nn.ReLU()
-        elif activation=='leakyrelu':
-            self.activation = nn.LeakyReLU()
-        elif activation=='selu':
-            self.activation = nn.SELU()
-        elif activation=='gelu':
-            self.activation = nn.GELU()
-        elif activation=='swish' or activation=='silu':
-            self.activation = nn.SiLU()
-
-        # number of channels to be produced by the first conv block - image projection
-        self.projection_features = projection_features
-        self.first_conv = nn.Conv2d(channels_in, self.projection_features, kernel_size=3, padding=1)
-        if weight_init=='he':
-            nn.init.kaiming_uniform_(self.first_conv.weight, nonlinearity='relu')
-        self.first_act = self.activation
-
-        # time embedding
-        self.time_dim = time_dim
-        self.time_channels = time_channels
-
-        self.time_embedding = TimeEmbedding(time_dim=self.time_dim, time_channels=self.time_channels)
-
-        # number of downsampling/upsampling stages
-        self.num_stages = num_stages
-        # number of ConvResBlocks in each downsampling/upsampling stage
-        self.num_blocks = num_blocks
-
-        if attention_list is None:
-            # boolean list assigning attention blocks in the contracting and expansive path
-            # default - the first half of the contracting path has no attention, the second half does;
-            # the first half of the expansive path has attention, the second half doesn't
-            self.attention_list = []
-            for i in range(self.num_stages):
-                if i < self.num_stages//2:
-                    self.attention_list.append(False)
-                else:
-                    self.attention_list.append(True)
-        else:
-            self.attention_list = attention_list   # [False, False, True, True] - paper implementation for a similar 4-stage U-Net
-
-        # number of features produced by each stage
-        if stage_list is None:
-            # default - successive stages double the number of channels
-            self.stages = [projection_features * 2**i for i in range(1, self.num_stages+1)]
-        else:
-            self.stages = stage_list   # e.g. [128, 256, 512, 1024] for a 4-stage U-Net
-
-        # contracting path
-        contracting_path = []
-
-        # number of channels going into the first ConvResBlock = number of output channels from the first conv layer
-        c_in = c_out = projection_features
-
-        # there are num_stages stages
-        # each stage has num_blocks ConvRes+Attention blocks
-        # each stage (except for the last) ends with a downsampling layer - maxpool
-        for i in range(self.num_stages):
-
-            c_out = self.stages[i]
-
-            for _ in range(num_blocks):
-                contracting_path.append(ConvResBlock(channels_in=c_in,
-                                                     channels_out=c_out,
-                                                     activation=self.activation,
-                                                     weight_init=weight_init,
-                                                     time_channels=self.time_channels,
-                                                     num_groups=num_groupnorm_groups,
-                                                     dropout=dropout,
-                                                     attention=self.attention_list[i],
-                                                     num_attention_heads=num_attention_heads))
-                c_in = c_out
-
-            # downsample, if it is not the last stage
-            if i < self.num_stages - 1:
-                contracting_path.append(nn.MaxPool2d((2,2), stride=2))
-
-        self.contracting_path = nn.ModuleList(contracting_path)
-
-        # the bottleneck blocks
-        self.midblock1 = ConvResBlock(channels_in=c_out,
-                                      channels_out=c_out,
-                                      activation=self.activation,
-                                      weight_init=weight_init,
-                                      time_channels=self.time_channels,
-                                      num_groups=num_groupnorm_groups,
-                                      dropout=dropout,
-                                      attention=True,
-                                      num_attention_heads=num_attention_heads)
-        self.midblock2 = ConvResBlock(channels_in=c_out,
-                                      channels_out=c_out,
-                                      activation=self.activation,
-                                      weight_init=weight_init,
-                                      time_channels=self.time_channels,
-                                      num_groups=num_groupnorm_groups,
-                                      dropout=dropout,
-                                      attention=False,
-                                      num_attention_heads=num_attention_heads)
-
-        # expansive path
-        expansive_path = []
-
-        # input to the expansive path = output of the midblocks = input to the midblocks = output of the contracting path
-        c_in = c_out = self.stages[-1]
-
-        # there are num_stages stages
-        # each stage has num_blocks ConvRes+Attention blocks and then one more block to reduce the number of channels
-        # each stage (except for the last) ends with an upsampling layer - transposed convolution
-        for i in reversed(range(self.num_stages)):
-
-            # channels_in = c_in + c_out to account for the incoming skip connections from the contracting path
-            for _ in range(self.num_blocks):
-                expansive_path.append(ConvResBlock(channels_in=c_in + c_out,
-                                                   channels_out=c_out,
-                                                   activation=self.activation,
-                                                   weight_init=weight_init,
-                                                   time_channels=self.time_channels,
-                                                   num_groups=num_groupnorm_groups,
-                                                   dropout=dropout,
-                                                   attention=self.attention_list[i],
-                                                   num_attention_heads=num_attention_heads))
-
-            if i > 0:
-                c_out = self.stages[i-1]
-            else:
-                c_out = self.projection_features
-            expansive_path.append(ConvResBlock(channels_in=c_in + c_out,
-                                               channels_out=c_out,
-                                               activation=self.activation,
-                                               weight_init=weight_init,
-                                               time_channels=self.time_channels,
-                                               num_groups=num_groupnorm_groups,
-                                               dropout=dropout,
-                                               attention=self.attention_list[i],
-                                               num_attention_heads=num_attention_heads))
-            c_in = c_out
-
-            # upsample, if it is not the last stage
-            if i > 0:
-                expansive_path.append(nn.ConvTranspose2d(c_in, c_in, kernel_size=4, stride=2, padding=1))
-
-        self.expansive_path = nn.ModuleList(expansive_path)
-
-        # final convolution layer
-        self.end_gn = nn.GroupNorm(8, c_in)
-        self.end_conv = nn.Conv2d(c_in, self.channels_out, kernel_size=3, padding=1)
-        if weight_init=='he':
-            nn.init.kaiming_uniform_(self.end_conv.weight, nonlinearity='relu')
-        self.end_act = self.activation
-
-
-    def forward(self, x, t):
-        # accept a Python int or a 0-d tensor as the time step
-        t = torch.as_tensor(t)
-
-        # to store feature maps from the contracting path
-        skip = []
-
-        # time embedding for time step t; keep it on the same device as the input
-        t = self.time_embedding(t).to(x.device)
-
-        # first conv layer to project the input image (3, *, *) into (projection_features=64, *, *)
-        x = self.first_act(self.first_conv(x))
-
-        # store the initial projection
-        skip.append(x)
-
-        # contracting path
-        for i in range(len(self.contracting_path)):
-            if isinstance(self.contracting_path[i], ConvResBlock):
-                x = self.contracting_path[i](x, t)
-            else:
-                x = self.contracting_path[i](x)
-
-            # store feature maps
-            skip.append(x)
-
-        # bottleneck
-        x = self.midblock1(x, t)
-        x = self.midblock2(x, t)
-
-        # expansive path
-        for i in range(len(self.expansive_path)):
-
-            # concatenate channels coming from the skip connections (does not apply to the upsampling ConvTranspose2d layers)
-            if isinstance(self.expansive_path[i], ConvResBlock):
-                x = torch.cat((x, skip.pop()), dim=1)
-                x = self.expansive_path[i](x, t)
-            else:
-                x = self.expansive_path[i](x)
-
-        # final conv layer
-        x = self.end_gn(x)
-        x = self.end_act(self.end_conv(x))
-        return x
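A minimal usage sketch for the deleted diffusion U-Net (not taken from the repository; the names model, x, noise_pred are illustrative). It relies on the quirks documented above: the batch size must equal time_dim, the spatial size must be divisible by 2**(num_stages-1) = 8 because of the three max-pooling steps, and the time embedding is moved to the input's device inside forward:

import torch
from models.unet_unconditional_diffusion import UNet_Unconditional_Diffusion   # path before this deletion

model = UNet_Unconditional_Diffusion(channels_in=3, channels_out=3, time_dim=4)
emb = model.time_embedding(torch.tensor(10))   # [time_channels, time_dim] -> torch.Size([256, 4])

x = torch.randn(4, 3, 32, 32)    # batch size 4 == time_dim; 32 is divisible by 2**3 = 8
noise_pred = model(x, 10)        # second argument is an integer diffusion time step
print(noise_pred.shape)          # torch.Size([4, 3, 32, 32])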