Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Unconditional Diffusion
Manage
Activity
Members
Code
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Locked files
Deploy
Model registry
Analyze
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Diffusion Project
Unconditional Diffusion
Commits
4dd87bb4
Commit
4dd87bb4
authored
1 year ago
by
Gonzalo Martin Garcia
Browse files
Options
Downloads
Plain Diff
unnecessary merge conflict all_unets.py
parents
934372c0
a729fb59
No related branches found
No related tags found
No related merge requests found
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
models/all_unets.py
+188
-5
188 additions, 5 deletions
models/all_unets.py
with
188 additions
and
5 deletions
models/all_unets.py
+
188
−
5
View file @
4dd87bb4
...
@@ -69,7 +69,6 @@ class UNet_Res(nn.Module):
...
@@ -69,7 +69,6 @@ class UNet_Res(nn.Module):
# first two conv layers
# first two conv layers
x
=
self
.
first_conv
(
input
)
+
t_emb0
[:,:,
None
,
None
]
x
=
self
.
first_conv
(
input
)
+
t_emb0
[:,:,
None
,
None
]
#timemb
#timemb
skip1
,
x
=
self
.
down1
(
x
,
t_emb1
)
skip1
,
x
=
self
.
down1
(
x
,
t_emb1
)
skip2
,
x
=
self
.
down2
(
x
,
t_emb2
)
skip2
,
x
=
self
.
down2
(
x
,
t_emb2
)
skip3
,
x
=
self
.
down3
(
x
,
t_emb3
)
skip3
,
x
=
self
.
down3
(
x
,
t_emb3
)
...
@@ -169,10 +168,10 @@ class ConvBlock_Res(nn.Module):
...
@@ -169,10 +168,10 @@ class ConvBlock_Res(nn.Module):
self
.
act3
=
nn
.
SiLU
()
self
.
act3
=
nn
.
SiLU
()
#Convolution skip
#Convolution skip
self
.
res_skip
=
nn
.
Identity
()
if
channels_in
!=
channels_out
:
if
channels_in
!=
channels_out
:
self
.
res_skip
=
nn
.
Conv2d
(
channels_in
,
channels_out
,
kernel_size
=
1
)
self
.
res_skip
=
nn
.
Conv2d
(
channels_in
,
channels_out
,
kernel_size
=
1
)
#self.res_skip = nn.Conv2d(channels_in,channels_out,kernel_size=1)
else
:
self
.
res_skip
=
nn
.
Identity
()
nn
.
init
.
xavier_normal_
(
self
.
conv1
.
weight
)
nn
.
init
.
xavier_normal_
(
self
.
conv1
.
weight
)
nn
.
init
.
xavier_normal_
(
self
.
conv2
.
weight
)
nn
.
init
.
xavier_normal_
(
self
.
conv2
.
weight
)
...
@@ -241,5 +240,189 @@ class MidBlock_Res(nn.Module):
...
@@ -241,5 +240,189 @@ class MidBlock_Res(nn.Module):
x
=
self
.
convblock1
(
x
,
t
)
x
=
self
.
convblock1
(
x
,
t
)
return
self
.
convblock2
(
x
,
t
)
return
self
.
convblock2
(
x
,
t
)
"""
UNet_Res_Bottleneck
"""
class
UNet_Res_Bottleneck
(
nn
.
Module
):
def
__init__
(
self
,
attention
,
channels_in
=
3
,
n_channels
=
64
,
fctr
=
[
1
,
2
,
4
,
4
,
8
],
time_dim
=
256
,
**
args
):
"""
attention : (Bool) wether to use attention layers or not
channels_in : (Int)
n_channels : (Int) Channel size after first convolution
fctr : (list) list of factors for further channel size wrt n_channels
time_dim : (Int) dimenison size for time embeding vector
"""
super
().
__init__
()
channels_out
=
channels_in
fctr
=
np
.
asarray
(
fctr
)
*
n_channels
# learned time embeddings
self
.
time_embedder
=
TimeEmbedding
(
time_dim
=
time_dim
)
self
.
time_embedder0
=
torch
.
nn
.
Sequential
(
nn
.
Linear
(
time_dim
,
fctr
[
0
]),
nn
.
SELU
(),
nn
.
Linear
(
fctr
[
0
],
fctr
[
0
]))
self
.
time_embedder1
=
torch
.
nn
.
Sequential
(
nn
.
Linear
(
time_dim
,
fctr
[
1
]),
nn
.
SELU
(),
nn
.
Linear
(
fctr
[
1
],
fctr
[
1
]))
self
.
time_embedder2
=
torch
.
nn
.
Sequential
(
nn
.
Linear
(
time_dim
,
fctr
[
2
]),
nn
.
SELU
(),
nn
.
Linear
(
fctr
[
2
],
fctr
[
2
]))
self
.
time_embedder3
=
torch
.
nn
.
Sequential
(
nn
.
Linear
(
time_dim
,
fctr
[
3
]),
nn
.
SELU
(),
nn
.
Linear
(
fctr
[
3
],
fctr
[
3
]))
self
.
time_embedder4
=
torch
.
nn
.
Sequential
(
nn
.
Linear
(
time_dim
,
fctr
[
4
]),
nn
.
SELU
(),
nn
.
Linear
(
fctr
[
4
],
fctr
[
4
]))
# first conv block
self
.
first_conv
=
nn
.
Conv2d
(
channels_in
,
fctr
[
0
],
kernel_size
=
3
,
padding
=
'
same
'
,
bias
=
True
)
#down blocks
self
.
down1
=
DownsampleBlock_Res_Bottleneck
(
fctr
[
0
],
fctr
[
1
],
time_dim
)
self
.
down2
=
DownsampleBlock_Res_Bottleneck
(
fctr
[
1
],
fctr
[
2
],
time_dim
)
self
.
down3
=
DownsampleBlock_Res_Bottleneck
(
fctr
[
2
],
fctr
[
3
],
time_dim
,
attention
=
attention
)
self
.
down4
=
DownsampleBlock_Res_Bottleneck
(
fctr
[
3
],
fctr
[
4
],
time_dim
,
attention
=
attention
)
#middle layer
self
.
mid1
=
MidBlock_Res_Bottleneck
(
fctr
[
4
],
time_dim
,
attention
=
attention
)
#up blocks
self
.
up1
=
UpsampleBlock_Res_Bottleneck
(
fctr
[
1
],
fctr
[
0
],
time_dim
)
self
.
up2
=
UpsampleBlock_Res_Bottleneck
(
fctr
[
2
],
fctr
[
1
],
time_dim
)
self
.
up3
=
UpsampleBlock_Res_Bottleneck
(
fctr
[
3
],
fctr
[
2
],
time_dim
,
attention
=
attention
)
self
.
up4
=
UpsampleBlock_Res_Bottleneck
(
fctr
[
4
],
fctr
[
3
],
time_dim
,
attention
=
attention
)
# final 1x1 conv
self
.
end_conv
=
nn
.
Conv2d
(
fctr
[
0
],
channels_out
,
kernel_size
=
1
,
bias
=
True
)
# Attention Layers
self
.
mha21
=
MHABlock
(
fctr
[
2
])
self
.
mha22
=
MHABlock
(
fctr
[
2
])
self
.
mha31
=
MHABlock
(
fctr
[
3
])
self
.
mha32
=
MHABlock
(
fctr
[
3
])
self
.
mha41
=
MHABlock
(
fctr
[
4
])
self
.
mha42
=
MHABlock
(
fctr
[
4
])
def
forward
(
self
,
input
,
t
):
t_emb
=
self
.
time_embedder
(
t
).
to
(
input
.
device
)
t_emb0
=
self
.
time_embedder0
(
t_emb
)
t_emb1
=
self
.
time_embedder1
(
t_emb
)
t_emb2
=
self
.
time_embedder2
(
t_emb
)
t_emb3
=
self
.
time_embedder3
(
t_emb
)
t_emb4
=
self
.
time_embedder4
(
t_emb
)
# first two conv layers
x
=
self
.
first_conv
(
input
)
+
t_emb0
[:,:,
None
,
None
]
#timemb
skip1
,
x
=
self
.
down1
(
x
,
t_emb1
)
skip2
,
x
=
self
.
down2
(
x
,
t_emb2
)
skip3
,
x
=
self
.
down3
(
x
,
t_emb3
)
skip4
,
x
=
self
.
down4
(
x
,
t_emb4
)
x
=
self
.
mid1
(
x
,
t_emb4
)
x
=
self
.
up4
(
x
,
skip4
,
t_emb3
)
x
=
self
.
up3
(
x
,
skip3
,
t_emb2
)
x
=
self
.
up2
(
x
,
skip2
,
t_emb1
)
x
=
self
.
up1
(
x
,
skip1
,
t_emb0
)
x
=
self
.
end_conv
(
x
)
return
x
# Residual Convolution Block
class
ConvBlock_Res_Bottleneck
(
nn
.
Module
):
def
__init__
(
self
,
channels_in
,
# number of input channels fed into the block
channels_out
,
# number of output channels produced by the block
time_dim
,
attention
,
num_groups
=
32
,
# number of groups used in Group Normalization; channels_in must be divisible by num_groups
):
super
().
__init__
()
self
.
attention
=
attention
if
self
.
attention
:
self
.
attlayer
=
MHABlock
(
channels_in
=
channels_out
)
# Convolution layer 1
self
.
conv1
=
nn
.
Conv2d
(
channels_in
,
channels_out
,
kernel_size
=
1
,
padding
=
'
same
'
,
bias
=
False
)
self
.
gn1
=
nn
.
GroupNorm
(
num_groups
,
channels_out
)
self
.
act1
=
nn
.
SiLU
()
# Convolution layer 2
self
.
conv2
=
nn
.
Conv2d
(
channels_out
,
channels_out
,
kernel_size
=
3
,
padding
=
'
same
'
,
bias
=
False
)
self
.
gn2
=
nn
.
GroupNorm
(
num_groups
,
channels_out
)
self
.
act2
=
nn
.
SiLU
()
# Convolution layer 3
self
.
conv3
=
nn
.
Conv2d
(
channels_out
,
channels_out
,
kernel_size
=
1
,
padding
=
'
same
'
,
bias
=
False
)
self
.
gn3
=
nn
.
GroupNorm
(
num_groups
,
channels_out
)
self
.
act3
=
nn
.
SiLU
()
#Convolution skip
self
.
res_skip
=
nn
.
Identity
()
if
channels_in
!=
channels_out
:
self
.
res_skip
=
nn
.
Conv2d
(
channels_in
,
channels_out
,
kernel_size
=
1
)
nn
.
init
.
xavier_normal_
(
self
.
conv1
.
weight
)
nn
.
init
.
xavier_normal_
(
self
.
conv2
.
weight
)
nn
.
init
.
xavier_normal_
(
self
.
conv3
.
weight
)
def
forward
(
self
,
x
,
t
):
res
=
self
.
res_skip
(
x
)
# second convolution layer
x
=
self
.
act1
(
self
.
gn1
(
self
.
conv1
(
x
)))
h
=
x
+
t
[:,:,
None
,
None
]
# third convolution layer
h
=
self
.
act2
(
self
.
gn2
(
self
.
conv2
(
h
)))
h
=
self
.
act3
(
self
.
gn3
(
self
.
conv3
(
h
)))
if
self
.
attention
:
h
=
self
.
attlayer
(
h
)
return
h
+
res
# Down Sample
class
DownsampleBlock_Res_Bottleneck
(
nn
.
Module
):
def
__init__
(
self
,
channels_in
,
channels_out
,
time_dim
,
attention
=
False
):
super
().
__init__
()
self
.
pool
=
nn
.
MaxPool2d
((
2
,
2
),
stride
=
2
)
self
.
convblock
=
ConvBlock_Res_Bottleneck
(
channels_in
,
channels_out
,
time_dim
,
attention
=
attention
)
def
forward
(
self
,
x
,
t
):
x
=
self
.
convblock
(
x
,
t
)
h
=
self
.
pool
(
x
)
return
x
,
h
# Upsample Block
class
UpsampleBlock_Res_Bottleneck
(
nn
.
Module
):
def
__init__
(
self
,
channels_in
,
channels_out
,
time_dim
,
attention
=
False
):
super
().
__init__
()
self
.
upconv
=
nn
.
ConvTranspose2d
(
channels_in
,
channels_in
,
kernel_size
=
2
,
stride
=
2
)
self
.
convblock
=
ConvBlock_Res_Bottleneck
(
channels_in
,
channels_out
,
time_dim
,
attention
=
attention
)
def
forward
(
self
,
x
,
skip_x
,
t
):
x
=
self
.
upconv
(
x
)
# skip-connection - merge features from contracting path to its symmetric counterpart in expansive path
out
=
x
+
skip_x
out
=
self
.
convblock
(
out
,
t
)
return
out
# Middle Block
class
MidBlock_Res_Bottleneck
(
nn
.
Module
):
def
__init__
(
self
,
channels
,
time_dim
,
attention
=
False
):
super
().
__init__
()
self
.
convblock1
=
ConvBlock_Res_Bottleneck
(
channels
,
channels
,
time_dim
,
attention
=
attention
)
self
.
convblock2
=
ConvBlock_Res_Bottleneck
(
channels
,
channels
,
time_dim
,
attention
=
False
)
def
forward
(
self
,
x
,
t
):
x
=
self
.
convblock1
(
x
,
t
)
return
self
.
convblock2
(
x
,
t
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment