# ---------------------------------------------------------------------------
# UNet_Res_Bottleneck
# ---------------------------------------------------------------------------


def _time_projection(time_dim, width):
    """Build the MLP that maps the shared time embedding to ``width`` features."""
    return torch.nn.Sequential(
        nn.Linear(time_dim, width),
        nn.SELU(),
        nn.Linear(width, width),
    )


class UNet_Res_Bottleneck(nn.Module):
    """U-Net built from residual 1x1 -> 3x3 -> 1x1 convolution blocks.

    The network has one input convolution, four down-sampling stages, a
    middle block, four up-sampling stages and a final 1x1 convolution that
    maps back to ``channels_in`` channels. Skip connections are merged by
    addition (not concatenation).
    """

    def __init__(self, attention, channels_in=3, n_channels=64,
                 fctr=(1, 2, 4, 4, 8), time_dim=256, **args):
        """
        attention   : (Bool) whether to use attention layers or not; forwarded
                      to down3, down4, mid1, up3 and up4
        channels_in : (Int) number of input channels; the output has the same
                      number of channels
        n_channels  : (Int) channel size after the first convolution
        fctr        : (sequence) factors for the channel size of each stage
                      w.r.t. n_channels; the first five entries are used
        time_dim    : (Int) dimension size of the time embedding vector
        **args      : extra keyword arguments are accepted and ignored
        """
        super().__init__()
        channels_out = channels_in

        # Plain Python ints (rather than numpy scalars) for the layer sizes.
        widths = [int(f * n_channels) for f in fctr]

        # learned time embeddings: one shared embedder plus one projection
        # per channel width, so each stage receives a vector of matching size
        self.time_embedder = TimeEmbedding(time_dim=time_dim)
        self.time_embedder0 = _time_projection(time_dim, widths[0])
        self.time_embedder1 = _time_projection(time_dim, widths[1])
        self.time_embedder2 = _time_projection(time_dim, widths[2])
        self.time_embedder3 = _time_projection(time_dim, widths[3])
        self.time_embedder4 = _time_projection(time_dim, widths[4])

        # first conv block
        self.first_conv = nn.Conv2d(channels_in, widths[0], kernel_size=3,
                                    padding='same', bias=True)

        # down blocks
        self.down1 = DownsampleBlock_Res_Bottleneck(widths[0], widths[1], time_dim)
        self.down2 = DownsampleBlock_Res_Bottleneck(widths[1], widths[2], time_dim)
        self.down3 = DownsampleBlock_Res_Bottleneck(widths[2], widths[3], time_dim,
                                                    attention=attention)
        self.down4 = DownsampleBlock_Res_Bottleneck(widths[3], widths[4], time_dim,
                                                    attention=attention)

        # middle layer
        self.mid1 = MidBlock_Res_Bottleneck(widths[4], time_dim, attention=attention)

        # up blocks (upN consumes the skip tensor produced by downN)
        self.up1 = UpsampleBlock_Res_Bottleneck(widths[1], widths[0], time_dim)
        self.up2 = UpsampleBlock_Res_Bottleneck(widths[2], widths[1], time_dim)
        self.up3 = UpsampleBlock_Res_Bottleneck(widths[3], widths[2], time_dim,
                                                attention=attention)
        self.up4 = UpsampleBlock_Res_Bottleneck(widths[4], widths[3], time_dim,
                                                attention=attention)

        # final 1x1 conv
        self.end_conv = nn.Conv2d(widths[0], channels_out, kernel_size=1, bias=True)

        # Attention layers.
        # NOTE(review): none of these six modules is referenced in forward();
        # they are kept only so that parameter names / state_dict keys stay
        # unchanged. Confirm whether they can be dropped.
        self.mha21 = MHABlock(widths[2])
        self.mha22 = MHABlock(widths[2])
        self.mha31 = MHABlock(widths[3])
        self.mha32 = MHABlock(widths[3])
        self.mha41 = MHABlock(widths[4])
        self.mha42 = MHABlock(widths[4])

    def forward(self, input, t):
        """Run the network on ``input`` conditioned on time step ``t``.

        ``t`` is passed to the time embedder; the result is moved to the
        device of ``input``. Because the four stages halve and later double
        the spatial size and skips are merged by addition, the spatial sizes
        of ``input`` need to survive four exact halvings (i.e. be divisible
        by 16) for the shapes to line up.
        """
        t_emb = self.time_embedder(t).to(input.device)

        t_emb0 = self.time_embedder0(t_emb)
        t_emb1 = self.time_embedder1(t_emb)
        t_emb2 = self.time_embedder2(t_emb)
        t_emb3 = self.time_embedder3(t_emb)
        t_emb4 = self.time_embedder4(t_emb)

        # input convolution, with the time embedding broadcast over H and W
        x = self.first_conv(input) + t_emb0[:, :, None, None]

        # contracting path: each stage returns (skip tensor, pooled tensor)
        skip1, x = self.down1(x, t_emb1)
        skip2, x = self.down2(x, t_emb2)
        skip3, x = self.down3(x, t_emb3)
        skip4, x = self.down4(x, t_emb4)

        x = self.mid1(x, t_emb4)

        # expansive path: the embedding size follows each block's output width
        x = self.up4(x, skip4, t_emb3)
        x = self.up3(x, skip3, t_emb2)
        x = self.up2(x, skip2, t_emb1)
        x = self.up1(x, skip1, t_emb0)

        return self.end_conv(x)


# Residual Convolution Block
class ConvBlock_Res_Bottleneck(nn.Module):
    """Residual block: 1x1 conv -> (+ time embedding) -> 3x3 conv -> 1x1 conv.

    Every convolution is followed by GroupNorm and SiLU. The block input is
    added back to the result, through a 1x1 convolution when the channel
    count changes and unchanged otherwise.
    """

    def __init__(self,
                 channels_in,   # number of input channels fed into the block
                 channels_out,  # number of output channels produced by the block
                 time_dim,      # accepted for a uniform signature; not used here
                 attention,     # if truthy, an MHABlock is applied before the residual sum
                 num_groups=32,  # groups for Group Normalization; channels_out must be divisible by num_groups
                 ):
        super().__init__()

        self.attention = attention
        if self.attention:
            self.attlayer = MHABlock(channels_in=channels_out)

        # Convolution layer 1 (1x1, changes the channel count)
        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=1,
                               padding='same', bias=False)
        self.gn1 = nn.GroupNorm(num_groups, channels_out)
        self.act1 = nn.SiLU()

        # Convolution layer 2 (3x3)
        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3,
                               padding='same', bias=False)
        self.gn2 = nn.GroupNorm(num_groups, channels_out)
        self.act2 = nn.SiLU()

        # Convolution layer 3 (1x1)
        self.conv3 = nn.Conv2d(channels_out, channels_out, kernel_size=1,
                               padding='same', bias=False)
        self.gn3 = nn.GroupNorm(num_groups, channels_out)
        self.act3 = nn.SiLU()

        # Convolution skip: project the residual only when the widths differ
        if channels_in != channels_out:
            self.res_skip = nn.Conv2d(channels_in, channels_out, kernel_size=1)
        else:
            self.res_skip = nn.Identity()

        nn.init.xavier_normal_(self.conv1.weight)
        nn.init.xavier_normal_(self.conv2.weight)
        nn.init.xavier_normal_(self.conv3.weight)

    def forward(self, x, t):
        """Apply the block to ``x``; ``t`` is indexed as a 2-D tensor whose
        second dimension must match ``channels_out``."""
        res = self.res_skip(x)

        # first convolution layer
        x = self.act1(self.gn1(self.conv1(x)))

        # inject the time embedding, broadcast over the two trailing dims
        h = x + t[:, :, None, None]

        # second convolution layer
        h = self.act2(self.gn2(self.conv2(h)))

        # third convolution layer
        h = self.act3(self.gn3(self.conv3(h)))

        if self.attention:
            h = self.attlayer(h)

        return h + res


# Down Sample
class DownsampleBlock_Res_Bottleneck(nn.Module):
    """Residual conv block followed by 2x2 max pooling."""

    def __init__(self, channels_in, channels_out, time_dim, attention=False):
        super().__init__()

        self.pool = nn.MaxPool2d((2, 2), stride=2)
        self.convblock = ConvBlock_Res_Bottleneck(channels_in, channels_out,
                                                  time_dim, attention=attention)

    def forward(self, x, t):
        """Return ``(features, pooled)``: the conv-block output, used as the
        skip connection, and its max-pooled version for the next stage."""
        x = self.convblock(x, t)
        h = self.pool(x)
        return x, h


# Upsample Block
class UpsampleBlock_Res_Bottleneck(nn.Module):
    """Transposed-conv 2x upsampling, additive skip merge, residual conv block."""

    def __init__(self, channels_in, channels_out, time_dim, attention=False):
        super().__init__()

        # the upsampling keeps the channel count so it can be added to the skip
        self.upconv = nn.ConvTranspose2d(channels_in, channels_in,
                                         kernel_size=2, stride=2)
        self.convblock = ConvBlock_Res_Bottleneck(channels_in, channels_out,
                                                  time_dim, attention=attention)

    def forward(self, x, skip_x, t):
        """Upsample ``x``, add ``skip_x`` and run the conv block."""
        x = self.upconv(x)
        # skip-connection - merge features from contracting path to its
        # symmetric counterpart in expansive path (element-wise sum)
        out = x + skip_x
        out = self.convblock(out, t)
        return out


# Middle Block
class MidBlock_Res_Bottleneck(nn.Module):
    """Two residual conv blocks at constant width; only the first one
    receives the ``attention`` flag, the second is built with attention off."""

    def __init__(self, channels, time_dim, attention=False):
        super().__init__()
        self.convblock1 = ConvBlock_Res_Bottleneck(channels, channels, time_dim,
                                                   attention=attention)
        self.convblock2 = ConvBlock_Res_Bottleneck(channels, channels, time_dim,
                                                   attention=False)

    def forward(self, x, t):
        """Apply both conv blocks in sequence with the same embedding ``t``."""
        x = self.convblock1(x, t)
        return self.convblock2(x, t)