Commit a729fb59 authored by Tobias Seibel

corrected model

parent 910ae792
@@ -44,7 +44,7 @@ class UNet_Res(nn.Module):
         self.up1 = UpsampleBlock_Res(fctr[1],fctr[0],time_dim)
         self.up2 = UpsampleBlock_Res(fctr[2],fctr[1],time_dim)
         self.up3 = UpsampleBlock_Res(fctr[3],fctr[2],time_dim,attention=attention)
-        self.up4 = UpsampleBlock_Res(fctr[4],fctr[3],time_dim)
+        self.up4 = UpsampleBlock_Res(fctr[4],fctr[3],time_dim,attention=attention)
         # final 1x1 conv
         self.end_conv = nn.Conv2d(fctr[0], channels_out, kernel_size=1,bias=True)
@@ -69,7 +69,6 @@ class UNet_Res(nn.Module):
         # first two conv layers
         x = self.first_conv(input) + t_emb0[:,:,None,None]
         #timemb
         skip1,x = self.down1(x,t_emb1)
         skip2,x = self.down2(x,t_emb2)
         skip3,x = self.down3(x,t_emb3)
@@ -154,22 +153,25 @@ class ConvBlock_Res(nn.Module):
             self.attlayer = MHABlock(channels_in=channels_out)
         # Convolution layer 1
-        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding='same', bias=True)
+        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding='same', bias=False)
         self.gn1 = nn.GroupNorm(num_groups, channels_out)
         self.act1 = nn.SiLU()
         # Convolution layer 2
-        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding='same', bias=True)
+        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding='same', bias=False)
         self.gn2 = nn.GroupNorm(num_groups, channels_out)
         self.act2 = nn.SiLU()
         # Convolution layer 3
-        self.conv3 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding='same', bias=True)
+        self.conv3 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding='same', bias=False)
         self.gn3 = nn.GroupNorm(num_groups, channels_out)
         self.act3 = nn.SiLU()
         #Convolution skip
         if channels_in!=channels_out:
             self.res_skip = nn.Conv2d(channels_in,channels_out,kernel_size=1)
         else:
             self.res_skip = nn.Identity()
         nn.init.xavier_normal_(self.conv1.weight)
         nn.init.xavier_normal_(self.conv2.weight)
@@ -238,5 +240,189 @@ class MidBlock_Res(nn.Module):
         x = self.convblock1(x,t)
         return self.convblock2(x,t)
"""
UNet_Res_Bottleneck
"""
class UNet_Res_Bottleneck(nn.Module):
def __init__(self, attention,channels_in=3, n_channels=64,fctr = [1,2,4,4,8],time_dim=256,**args):
"""
attention : (Bool) wether to use attention layers or not
channels_in : (Int)
n_channels : (Int) Channel size after first convolution
fctr : (list) list of factors for further channel size wrt n_channels
time_dim : (Int) dimenison size for time embeding vector
"""
+        super().__init__()
+        channels_out = channels_in
+        fctr = np.asarray(fctr) * n_channels
+
+        # learned time embeddings
+        self.time_embedder  = TimeEmbedding(time_dim=time_dim)
+        self.time_embedder0 = torch.nn.Sequential(nn.Linear(time_dim, fctr[0]), nn.SELU(), nn.Linear(fctr[0], fctr[0]))
+        self.time_embedder1 = torch.nn.Sequential(nn.Linear(time_dim, fctr[1]), nn.SELU(), nn.Linear(fctr[1], fctr[1]))
+        self.time_embedder2 = torch.nn.Sequential(nn.Linear(time_dim, fctr[2]), nn.SELU(), nn.Linear(fctr[2], fctr[2]))
+        self.time_embedder3 = torch.nn.Sequential(nn.Linear(time_dim, fctr[3]), nn.SELU(), nn.Linear(fctr[3], fctr[3]))
+        self.time_embedder4 = torch.nn.Sequential(nn.Linear(time_dim, fctr[4]), nn.SELU(), nn.Linear(fctr[4], fctr[4]))
+
+        # first conv block
+        self.first_conv = nn.Conv2d(channels_in, fctr[0], kernel_size=3, padding='same', bias=True)
+
+        # down blocks
+        self.down1 = DownsampleBlock_Res_Bottleneck(fctr[0], fctr[1], time_dim)
+        self.down2 = DownsampleBlock_Res_Bottleneck(fctr[1], fctr[2], time_dim)
+        self.down3 = DownsampleBlock_Res_Bottleneck(fctr[2], fctr[3], time_dim, attention=attention)
+        self.down4 = DownsampleBlock_Res_Bottleneck(fctr[3], fctr[4], time_dim, attention=attention)
+
+        # middle layer
+        self.mid1 = MidBlock_Res_Bottleneck(fctr[4], time_dim, attention=attention)
+
+        # up blocks
+        self.up1 = UpsampleBlock_Res_Bottleneck(fctr[1], fctr[0], time_dim)
+        self.up2 = UpsampleBlock_Res_Bottleneck(fctr[2], fctr[1], time_dim)
+        self.up3 = UpsampleBlock_Res_Bottleneck(fctr[3], fctr[2], time_dim, attention=attention)
+        self.up4 = UpsampleBlock_Res_Bottleneck(fctr[4], fctr[3], time_dim, attention=attention)
+
+        # final 1x1 conv
+        self.end_conv = nn.Conv2d(fctr[0], channels_out, kernel_size=1, bias=True)
+
+        # attention layers
+        self.mha21 = MHABlock(fctr[2])
+        self.mha22 = MHABlock(fctr[2])
+        self.mha31 = MHABlock(fctr[3])
+        self.mha32 = MHABlock(fctr[3])
+        self.mha41 = MHABlock(fctr[4])
+        self.mha42 = MHABlock(fctr[4])
+    def forward(self, input, t):
+        t_emb  = self.time_embedder(t).to(input.device)
+        t_emb0 = self.time_embedder0(t_emb)
+        t_emb1 = self.time_embedder1(t_emb)
+        t_emb2 = self.time_embedder2(t_emb)
+        t_emb3 = self.time_embedder3(t_emb)
+        t_emb4 = self.time_embedder4(t_emb)
+
+        # first conv layer plus broadcast time embedding
+        x = self.first_conv(input) + t_emb0[:, :, None, None]
+
+        # contracting path
+        skip1, x = self.down1(x, t_emb1)
+        skip2, x = self.down2(x, t_emb2)
+        skip3, x = self.down3(x, t_emb3)
+        skip4, x = self.down4(x, t_emb4)
+
+        # bottleneck
+        x = self.mid1(x, t_emb4)
+
+        # expansive path with skip connections
+        x = self.up4(x, skip4, t_emb3)
+        x = self.up3(x, skip3, t_emb2)
+        x = self.up2(x, skip2, t_emb1)
+        x = self.up1(x, skip1, t_emb0)
+
+        x = self.end_conv(x)
+        return x
+# Residual convolution block (1x1 -> 3x3 -> 1x1 bottleneck)
+class ConvBlock_Res_Bottleneck(nn.Module):
+    def __init__(self,
+                 channels_in,   # number of input channels fed into the block
+                 channels_out,  # number of output channels produced by the block
+                 time_dim,
+                 attention,
+                 num_groups=32, # number of groups used in Group Normalization; channels_out must be divisible by num_groups
+                 ):
+        super().__init__()
+
+        self.attention = attention
+        if self.attention:
+            self.attlayer = MHABlock(channels_in=channels_out)
+
+        # Convolution layer 1 (1x1)
+        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=1, padding='same', bias=False)
+        self.gn1 = nn.GroupNorm(num_groups, channels_out)
+        self.act1 = nn.SiLU()
+
+        # Convolution layer 2 (3x3)
+        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding='same', bias=False)
+        self.gn2 = nn.GroupNorm(num_groups, channels_out)
+        self.act2 = nn.SiLU()
+
+        # Convolution layer 3 (1x1)
+        self.conv3 = nn.Conv2d(channels_out, channels_out, kernel_size=1, padding='same', bias=False)
+        self.gn3 = nn.GroupNorm(num_groups, channels_out)
+        self.act3 = nn.SiLU()
+
+        # Convolution on the residual skip if the channel count changes
+        self.res_skip = nn.Identity()
+        if channels_in != channels_out:
+            self.res_skip = nn.Conv2d(channels_in, channels_out, kernel_size=1)
+
+        nn.init.xavier_normal_(self.conv1.weight)
+        nn.init.xavier_normal_(self.conv2.weight)
+        nn.init.xavier_normal_(self.conv3.weight)
+    def forward(self, x, t):
+        res = self.res_skip(x)
+
+        # first convolution layer (1x1), then add the broadcast time embedding
+        x = self.act1(self.gn1(self.conv1(x)))
+        h = x + t[:, :, None, None]
+
+        # second (3x3) and third (1x1) convolution layers
+        h = self.act2(self.gn2(self.conv2(h)))
+        h = self.act3(self.gn3(self.conv3(h)))
+
+        if self.attention:
+            h = self.attlayer(h)
+
+        return h + res
+# Downsample block
+class DownsampleBlock_Res_Bottleneck(nn.Module):
+    def __init__(self, channels_in, channels_out, time_dim, attention=False):
+        super().__init__()
+        self.pool = nn.MaxPool2d((2, 2), stride=2)
+        self.convblock = ConvBlock_Res_Bottleneck(channels_in, channels_out, time_dim, attention=attention)
+
+    def forward(self, x, t):
+        x = self.convblock(x, t)
+        h = self.pool(x)
+        return x, h  # full-resolution features for the skip connection, downsampled features
+
+# Upsample block
+class UpsampleBlock_Res_Bottleneck(nn.Module):
+    def __init__(self, channels_in, channels_out, time_dim, attention=False):
+        super().__init__()
+        self.upconv = nn.ConvTranspose2d(channels_in, channels_in, kernel_size=2, stride=2)
+        self.convblock = ConvBlock_Res_Bottleneck(channels_in, channels_out, time_dim, attention=attention)
+
+    def forward(self, x, skip_x, t):
+        x = self.upconv(x)
+        # skip connection: merge features from the contracting path with their
+        # symmetric counterpart in the expansive path
+        out = x + skip_x
+        out = self.convblock(out, t)
+        return out
+
+# Middle block
+class MidBlock_Res_Bottleneck(nn.Module):
+    def __init__(self, channels, time_dim, attention=False):
+        super().__init__()
+        self.convblock1 = ConvBlock_Res_Bottleneck(channels, channels, time_dim, attention=attention)
+        self.convblock2 = ConvBlock_Res_Bottleneck(channels, channels, time_dim, attention=False)
+
+    def forward(self, x, t):
+        x = self.convblock1(x, t)
+        return self.convblock2(x, t)
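
For orientation, a minimal usage sketch of the new UNet_Res_Bottleneck follows. It is not part of the commit: the import path `model`, the image resolution, and the timestep format are assumptions, and the sketch also presumes that TimeEmbedding, MHABlock, and the numpy/torch imports are defined in the same module as the classes above.

# Minimal smoke test for the new bottleneck U-Net (assumptions noted above).
import torch
from model import UNet_Res_Bottleneck   # hypothetical import path

net = UNet_Res_Bottleneck(attention=True, channels_in=3, n_channels=64)

x = torch.randn(4, 3, 64, 64)        # batch of 4 RGB images, 64x64 pixels (assumed size)
t = torch.randint(0, 1000, (4,))     # one diffusion timestep per sample (assumed format)

with torch.no_grad():
    out = net(x, t)

# channels_out is set equal to channels_in and the four pool/upsample stages cancel,
# so the output should match the input shape.
print(out.shape)   # expected: torch.Size([4, 3, 64, 64])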