Commit 4dd87bb4 authored by Gonzalo Martin Garcia

unnecessary merge conflict all_unets.py

Parents: 934372c0, a729fb59
@@ -69,7 +69,6 @@ class UNet_Res(nn.Module):
         # first two conv layers
         x = self.first_conv(input) + t_emb0[:,:,None,None]
         #timemb
         skip1,x = self.down1(x,t_emb1)
         skip2,x = self.down2(x,t_emb2)
         skip3,x = self.down3(x,t_emb3)
@@ -169,10 +168,10 @@ class ConvBlock_Res(nn.Module):
         self.act3 = nn.SiLU()
         #Convolution skip
+        self.res_skip = nn.Identity()
         if channels_in!=channels_out:
             self.res_skip = nn.Conv2d(channels_in,channels_out,kernel_size=1)
-            #self.res_skip = nn.Conv2d(channels_in,channels_out,kernel_size=1)
-        else:
-            self.res_skip = nn.Identity()
         nn.init.xavier_normal_(self.conv1.weight)
         nn.init.xavier_normal_(self.conv2.weight)
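This hunk replaces the if/else with an `nn.Identity()` default that is overwritten by a 1x1 projection only when the channel counts differ; the behavior is unchanged. A minimal standalone sketch (not part of the commit) checking both paths:

import torch
import torch.nn as nn

def make_skip(channels_in, channels_out):
    # Identity by default; overwritten with a 1x1 projection only when the
    # channel counts differ -- equivalent to the old if/else version.
    res_skip = nn.Identity()
    if channels_in != channels_out:
        res_skip = nn.Conv2d(channels_in, channels_out, kernel_size=1)
    return res_skip

x = torch.randn(2, 64, 16, 16)
assert torch.equal(make_skip(64, 64)(x), x)             # identity path
assert make_skip(64, 128)(x).shape == (2, 128, 16, 16)  # projection path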
@@ -241,5 +240,189 @@ class MidBlock_Res(nn.Module):
         x = self.convblock1(x,t)
         return self.convblock2(x,t)

The remaining added lines introduce the new bottleneck variant:
"""
UNet_Res_Bottleneck
"""
class UNet_Res_Bottleneck(nn.Module):
def __init__(self, attention,channels_in=3, n_channels=64,fctr = [1,2,4,4,8],time_dim=256,**args):
"""
attention : (Bool) wether to use attention layers or not
channels_in : (Int)
n_channels : (Int) Channel size after first convolution
fctr : (list) list of factors for further channel size wrt n_channels
time_dim : (Int) dimenison size for time embeding vector
"""
        super().__init__()
        channels_out = channels_in
        fctr = np.asarray(fctr)*n_channels

        # learned time embeddings
        self.time_embedder = TimeEmbedding(time_dim=time_dim)
        self.time_embedder0 = torch.nn.Sequential(nn.Linear(time_dim,fctr[0]),nn.SELU(),nn.Linear(fctr[0],fctr[0]))
        self.time_embedder1 = torch.nn.Sequential(nn.Linear(time_dim,fctr[1]),nn.SELU(),nn.Linear(fctr[1],fctr[1]))
        self.time_embedder2 = torch.nn.Sequential(nn.Linear(time_dim,fctr[2]),nn.SELU(),nn.Linear(fctr[2],fctr[2]))
        self.time_embedder3 = torch.nn.Sequential(nn.Linear(time_dim,fctr[3]),nn.SELU(),nn.Linear(fctr[3],fctr[3]))
        self.time_embedder4 = torch.nn.Sequential(nn.Linear(time_dim,fctr[4]),nn.SELU(),nn.Linear(fctr[4],fctr[4]))

        # first conv block
        self.first_conv = nn.Conv2d(channels_in, fctr[0], kernel_size=3, padding='same', bias=True)

        # down blocks
        self.down1 = DownsampleBlock_Res_Bottleneck(fctr[0],fctr[1],time_dim)
        self.down2 = DownsampleBlock_Res_Bottleneck(fctr[1],fctr[2],time_dim)
        self.down3 = DownsampleBlock_Res_Bottleneck(fctr[2],fctr[3],time_dim,attention=attention)
        self.down4 = DownsampleBlock_Res_Bottleneck(fctr[3],fctr[4],time_dim,attention=attention)

        # middle layer
        self.mid1 = MidBlock_Res_Bottleneck(fctr[4],time_dim,attention=attention)

        # up blocks
        self.up1 = UpsampleBlock_Res_Bottleneck(fctr[1],fctr[0],time_dim)
        self.up2 = UpsampleBlock_Res_Bottleneck(fctr[2],fctr[1],time_dim)
        self.up3 = UpsampleBlock_Res_Bottleneck(fctr[3],fctr[2],time_dim,attention=attention)
        self.up4 = UpsampleBlock_Res_Bottleneck(fctr[4],fctr[3],time_dim,attention=attention)

        # final 1x1 conv
        self.end_conv = nn.Conv2d(fctr[0], channels_out, kernel_size=1, bias=True)

        # Attention Layers
        self.mha21 = MHABlock(fctr[2])
        self.mha22 = MHABlock(fctr[2])
        self.mha31 = MHABlock(fctr[3])
        self.mha32 = MHABlock(fctr[3])
        self.mha41 = MHABlock(fctr[4])
        self.mha42 = MHABlock(fctr[4])
    def forward(self, input, t):
        # project the shared time embedding to each stage's channel width
        t_emb  = self.time_embedder(t).to(input.device)
        t_emb0 = self.time_embedder0(t_emb)
        t_emb1 = self.time_embedder1(t_emb)
        t_emb2 = self.time_embedder2(t_emb)
        t_emb3 = self.time_embedder3(t_emb)
        t_emb4 = self.time_embedder4(t_emb)

        # first conv layer plus its time embedding
        x = self.first_conv(input) + t_emb0[:,:,None,None]

        # contracting path
        skip1,x = self.down1(x,t_emb1)
        skip2,x = self.down2(x,t_emb2)
        skip3,x = self.down3(x,t_emb3)
        skip4,x = self.down4(x,t_emb4)

        # bottleneck
        x = self.mid1(x,t_emb4)

        # expansive path
        x = self.up4(x,skip4,t_emb3)
        x = self.up3(x,skip3,t_emb2)
        x = self.up2(x,skip2,t_emb1)
        x = self.up1(x,skip1,t_emb0)

        # final 1x1 conv
        x = self.end_conv(x)
        return x
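Each stage receives its own projection of the shared time embedding, sized to that stage's channel width so it can be added to the feature maps. A minimal standalone sketch of the projection sizes (not part of the commit; values assume the defaults n_channels=64, fctr=[1,2,4,4,8]):

import torch
import torch.nn as nn

time_dim, widths = 256, [64, 128, 256, 256, 512]  # fctr = [1,2,4,4,8] * 64
t_emb = torch.randn(4, time_dim)                  # shared embedding, batch of 4
projs = [nn.Sequential(nn.Linear(time_dim, w), nn.SELU(), nn.Linear(w, w))
         for w in widths]
stage_embs = [p(t_emb) for p in projs]            # one embedding per stage
assert [e.shape[1] for e in stage_embs] == widths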
# Residual Convolution Block
class ConvBlock_Res_Bottleneck(nn.Module):
    def __init__(self,
                 channels_in,    # number of input channels fed into the block
                 channels_out,   # number of output channels produced by the block
                 time_dim,
                 attention,
                 num_groups=32,  # number of groups used in Group Normalization; channels_out must be divisible by num_groups
                 ):
        super().__init__()
        self.attention = attention
        if self.attention:
            self.attlayer = MHABlock(channels_in=channels_out)

        # convolution layer 1 (1x1)
        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=1, padding='same', bias=False)
        self.gn1 = nn.GroupNorm(num_groups, channels_out)
        self.act1 = nn.SiLU()

        # convolution layer 2 (3x3)
        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding='same', bias=False)
        self.gn2 = nn.GroupNorm(num_groups, channels_out)
        self.act2 = nn.SiLU()

        # convolution layer 3 (1x1)
        self.conv3 = nn.Conv2d(channels_out, channels_out, kernel_size=1, padding='same', bias=False)
        self.gn3 = nn.GroupNorm(num_groups, channels_out)
        self.act3 = nn.SiLU()

        # convolution skip
        self.res_skip = nn.Identity()
        if channels_in != channels_out:
            self.res_skip = nn.Conv2d(channels_in, channels_out, kernel_size=1)

        nn.init.xavier_normal_(self.conv1.weight)
        nn.init.xavier_normal_(self.conv2.weight)
        nn.init.xavier_normal_(self.conv3.weight)
    def forward(self, x, t):
        res = self.res_skip(x)

        # first convolution layer
        x = self.act1(self.gn1(self.conv1(x)))
        # add the time embedding as a per-channel bias
        h = x + t[:,:,None,None]

        # second and third convolution layers
        h = self.act2(self.gn2(self.conv2(h)))
        h = self.act3(self.gn3(self.conv3(h)))

        if self.attention:
            h = self.attlayer(h)

        return h + res
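The `t[:,:,None,None]` indexing above turns the (B, C) time embedding into a (B, C, 1, 1) tensor that broadcasts as a per-channel bias over every spatial position. A standalone illustration (not part of the commit):

import torch

h = torch.randn(2, 128, 8, 8)   # feature map (B, C, H, W)
t = torch.randn(2, 128)         # time embedding (B, C)
out = h + t[:, :, None, None]   # reshaped to (B, C, 1, 1), broadcast over H, W
assert out.shape == h.shape
assert torch.allclose(out[0, :, 3, 5], h[0, :, 3, 5] + t[0])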
# Downsample Block
class DownsampleBlock_Res_Bottleneck(nn.Module):
    def __init__(self, channels_in, channels_out, time_dim, attention=False):
        super().__init__()
        self.pool = nn.MaxPool2d((2,2), stride=2)
        self.convblock = ConvBlock_Res_Bottleneck(channels_in, channels_out, time_dim, attention=attention)

    def forward(self, x, t):
        x = self.convblock(x, t)
        h = self.pool(x)
        # return the pre-pool features (skip connection) and the pooled features
        return x, h
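The block returns two tensors: the full-resolution features, kept for the decoder's skip connection, and their pooled version for the next encoder stage. A hedged shape check (a sketch assuming the classes above are in scope):

import torch

block = DownsampleBlock_Res_Bottleneck(64, 128, time_dim=256)
x = torch.randn(2, 64, 32, 32)
t = torch.randn(2, 128)                 # embedding sized to channels_out
skip, h = block(x, t)
assert skip.shape == (2, 128, 32, 32)   # full resolution, kept for the decoder
assert h.shape == (2, 128, 16, 16)      # halved by the 2x2 max pool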
# Upsample Block
class UpsampleBlock_Res_Bottleneck(nn.Module):
    def __init__(self, channels_in, channels_out, time_dim, attention=False):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(channels_in, channels_in, kernel_size=2, stride=2)
        self.convblock = ConvBlock_Res_Bottleneck(channels_in, channels_out, time_dim, attention=attention)

    def forward(self, x, skip_x, t):
        x = self.upconv(x)
        # skip connection: merge features from the contracting path into the
        # symmetric stage of the expansive path (by addition)
        out = x + skip_x
        out = self.convblock(out, t)
        return out
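Unlike the original U-Net, which concatenates encoder and decoder features along the channel axis, this variant merges them by addition, so the channel counts must already match. A hedged shape check (a sketch assuming the classes above are in scope):

import torch

up = UpsampleBlock_Res_Bottleneck(128, 64, time_dim=256)
x = torch.randn(2, 128, 16, 16)      # decoder features
skip = torch.randn(2, 128, 32, 32)   # encoder features at the target resolution
t = torch.randn(2, 64)               # embedding sized to channels_out
out = up(x, skip, t)
assert out.shape == (2, 64, 32, 32)  # 2x upsampled, projected to 64 channels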
# Middle Block
class MidBlock_Res_Bottleneck(nn.Module):
    def __init__(self, channels, time_dim, attention=False):
        super().__init__()
        self.convblock1 = ConvBlock_Res_Bottleneck(channels, channels, time_dim, attention=attention)
        self.convblock2 = ConvBlock_Res_Bottleneck(channels, channels, time_dim, attention=False)

    def forward(self, x, t):
        x = self.convblock1(x, t)
        return self.convblock2(x, t)
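Putting it together, a hedged end-to-end smoke test for the new class (a sketch under the assumptions noted in the comments, not part of the commit):

import torch

# Assumes TimeEmbedding and MHABlock (defined elsewhere in all_unets.py) are in
# scope, and that TimeEmbedding accepts a 1-D tensor of integer timesteps.
# H and W must be divisible by 16, since the four down blocks each halve them.
model = UNet_Res_Bottleneck(attention=False)
x = torch.randn(2, 3, 64, 64)     # (batch, channels_in, H, W)
t = torch.randint(0, 1000, (2,))  # diffusion timesteps
out = model(x, t)
assert out.shape == x.shape       # per-pixel output with the input's shape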