Diffusion Project / Unconditional Diffusion · Commits

Commit f5cd01fe
Authored 1 year ago by Srijeet Roy
delete redundant UNet models
Parent: 04d7a206
Changes: 2 changed files with 0 additions and 584 deletions
  models/unet.py                            +0 −138
  models/unet_unconditional_diffusion.py    +0 −446
models/unet.py  deleted 100644 → 0  (+0 −138)

# -*- coding: utf-8 -*-
"""
UNet.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1BdiIHZYyESTyt-NVRoJXUBMlKreOExkL
"""

'''
Implementation of U-Net architecture
Structure: Input -> Contracting Path -> Expansive Path -> Output
Contracting Path progressively downsamples the input
Expansive Path incrementally upsamples the output of contracting path
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms


# Two 3x3 conv layers
class ConvBlock(nn.Module):

    def __init__(self, channels_in, channels_out):
        super().__init__()
        # first conv layer with He initialization and batch normalization
        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels_out)
        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
        # second conv layer with He initialization and batch normalization
        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels_out)
        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')

    def forward(self, x):
        x = self.conv1(x)
        #x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        #x = self.bn2(x)
        x = F.relu(x)
        return x


# Downsampling block - maxpool halves the resolution, followed by ConvBlock operation
class DownsampleBlock(nn.Module):

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.pool = nn.MaxPool2d((2, 2), stride=2)
        self.convblock = ConvBlock(channels_in, channels_out)

    def forward(self, x):
        x = self.pool(x)
        x = self.convblock(x)
        return x


# Upsampling block - double the resolution with ConvTranspose, followed by ConvBlock operation
class UpsampleBlock(nn.Module):

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(channels_in, channels_out, kernel_size=2, stride=2)
        self.convblock = ConvBlock(channels_in, channels_out)

    def forward(self, x, down_x):
        x = self.upconv(x)
        # skip-connection - merge features from contracting path to its symmetric counterpart in expansive path
        down_x = transforms.CenterCrop(size=(x.shape[2], x.shape[3]))(down_x)
        x = torch.cat([x, down_x], dim=1)
        x = self.convblock(x)
        return x


# U-Net model
class UNet(nn.Module):

    def __init__(self, channels_in, channels_out, n_channels=64, n_blocks=4, *kwargs, **args):
        super().__init__()
        # number of channels to be produced by the first conv block
        self.n_channels = n_channels
        # number of downsampling/upsampling blocks
        self.n_blocks = n_blocks
        # first conv block
        self.first_conv = ConvBlock(channels_in, self.n_channels)
        # downsampling and upsampling blocks
        down_channels = []
        up_channels = []
        for i in (2 ** p for p in range(self.n_blocks)):
            down_channels.append((self.n_channels * i, self.n_channels * i * 2))
            up_channels.insert(0, (self.n_channels * i * 2, self.n_channels * i))
        self.downsample = nn.ModuleList([DownsampleBlock(c_in, c_out) for c_in, c_out in down_channels])
        self.upsample = nn.ModuleList([UpsampleBlock(c_in, c_out) for c_in, c_out in up_channels])
        # final 1x1 conv
        self.end_conv = nn.Conv2d(self.n_channels, channels_out, kernel_size=1)

    def forward(self, x, t):
        # to store feature maps from contracting path
        skip = []
        # first two conv layers
        x = self.first_conv(x)
        skip.insert(0, x)
        # downsampling blocks
        for i in range(self.n_blocks):
            x = self.downsample[i](x)
            # store feature maps
            if i < self.n_blocks - 1:
                skip.insert(0, x)
        # upsampling blocks (with skip-connections)
        for i in range(self.n_blocks):
            x = self.upsample[i](x, skip[i])
        # final 1x1 conv layer
        x = self.end_conv(x)
        return x
\ No newline at end of file
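
For reference, a minimal sketch of how the deleted UNet could have been exercised. The shapes and values below are illustrative assumptions (not taken from the repository's training code), and the import path refers to the file as it existed before this commit; the spatial size must be divisible by 2**n_blocks for the skip connections to line up.

import torch
from models.unet import UNet  # path as it existed before this commit deleted the file

# hypothetical settings: 3-channel input and output, default 64 base channels and 4 blocks
model = UNet(channels_in=3, channels_out=3)
x = torch.randn(1, 3, 128, 128)   # one 128x128 RGB image; 128 is divisible by 2**4
t = torch.tensor(10)              # accepted by forward() but unused in this plain UNet
out = model(x, t)
print(out.shape)                  # expected: torch.Size([1, 3, 128, 128])
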
models/unet_unconditional_diffusion.py  deleted 100644 → 0  (+0 −446)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

"""
TimeEmbedding - generates time embedding for a time step
input (0-d tensor) -> tensor of shape [time_channels, time_dim]
Arguments:
time_dim: int, default=64,
dimensionality of the time embedding (has to match batch size)
time_channels: int, default=256,
number of channels in the time embedding
"""
class
TimeEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
time_dim
=
64
,
time_channels
=
256
):
super
().
__init__
()
self
.
time_dim
=
time_dim
self
.
time_channels
=
time_channels
# argument to sin/cos fn: t / 10000^(i / d) where i = 2k or 2k + 1 - https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
self
.
factor
=
torch
.
exp
(
torch
.
arange
(
0
,
self
.
time_dim
,
2
)
*
(
-
torch
.
log
(
torch
.
tensor
(
10000.0
))
/
self
.
time_dim
))
def
forward
(
self
,
t
):
# if t = tensor.torch(int), t.shape = []
# change it so that t.shape = [1]
if
len
(
t
.
shape
)
==
0
:
t
=
t
.
unsqueeze
(
0
)
t
=
t
.
unsqueeze
(
1
)
*
self
.
factor
# shape of embedding [time_channels, dim]
emb
=
torch
.
zeros
(
self
.
time_channels
,
self
.
time_dim
)
emb
[:,
0
::
2
]
=
torch
.
sin
(
t
)
emb
[:,
1
::
2
]
=
torch
.
cos
(
t
)
return
emb
"""
ConvResBlock - building block of the U-Net architecture
input -> Conv1 -(+ time embedding) -> Conv2 -(+ residual) -> Multi-head attention
Arguments:
channels_in : int,
number of input channels fed into the block
channels_out: int,
number of output channels produced by the block
activation: {
'
relu
'
,
'
leakyrelu
'
,
'
selu
'
,
'
gelu
'
,
'
silu
'
/
'
swish
'
}, default=
'
relu
'
,
activation function in the neural network
weight_init: {
'
he
'
,
'
torch
'
}, default=
'
he
'
,
weight initializer for convolution layers; choose between He
initialization and PyTorch
'
s default initialization
time_channels: int, default=256,
number of channels for time embedding
num_groups: int, default=32,
number of groups used in Group Normalization; channels_in must be
divisible by num_groups
dropout: float, default=0.1,
drop-out to be applied
attention: boolean, default=False,
whether Multi-head attention (MHA) is applied or not
num_attention_heads: int, default=1,
number of attention heads in MHA
"""
class ConvResBlock(nn.Module):

    def __init__(self,
                 channels_in,             # number of input channels fed into the block
                 channels_out,            # number of output channels produced by the block
                 activation,              # activation function. Options: {'relu', 'leakyrelu', 'selu', 'gelu', 'silu'/'swish'}
                 weight_init='he',        # weight initialization. Options: {'he', 'torch'}
                 time_channels=64,        # number of channels for time embedding
                 num_groups=32,           # number of groups used in Group Normalization; channels_in must be divisible by num_groups
                 dropout=0.1,             # drop-out to be applied
                 attention=False,         # boolean: whether Multi-head attention (MHA) is applied or not
                 num_attention_heads=1    # number of attention heads in MHA
                 ):
        super().__init__()

        self.activation = activation

        # Convolution layer 1
        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
        self.gn1 = nn.GroupNorm(num_groups, channels_out)
        self.act1 = self.activation

        # Convolution layer 2
        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
        self.gn2 = nn.GroupNorm(num_groups, channels_out)
        self.act2 = self.activation

        if weight_init == 'he':
            nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
            nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')

        # Drop-out
        self.dropout = nn.Dropout(dropout)

        # Residual connection
        self.residual = nn.Identity()
        if channels_in != channels_out:
            self.residual = nn.Conv2d(channels_in, channels_out, kernel_size=1)
            if weight_init == 'he':
                nn.init.kaiming_uniform_(self.residual.weight, nonlinearity='relu')
            self.residual_act = self.activation

        # Time embedding - map time embedding to have the same number of channels as image activation
        self.time_emb = nn.Linear(time_channels, channels_out)
        self.time_act = self.activation

        # Multi-head attention
        self.attention = attention
        self.num_attention_heads = num_attention_heads
        self.self_attention = nn.Identity()
        if self.attention:
            self.self_attention = nn.MultiheadAttention(channels_out, num_heads=self.num_attention_heads)

    def forward(self, x, t):
        # store input, to be used as residual
        res = self.residual(x)
        if isinstance(self.residual, nn.Conv2d):
            res = self.residual_act(res)

        # first convolution layer
        x = self.act1(self.gn1(self.conv1(x)))

        # add temporal information with time embedding
        t = self.time_act(self.time_emb(t.T))
        x += t[:, :, None, None]

        # Drop-out
        x = self.dropout(x)

        # second convolution layer
        x = self.act2(self.gn2(self.conv2(x)))

        # add residual
        x += res

        # apply self-attention
        if self.attention:
            batch_size = x.shape[0]
            height = x.shape[2]
            width = x.shape[3]
            sequence_length = height * width
            x = x.permute(2, 3, 0, 1).reshape(sequence_length, batch_size, -1)
            x, _ = self.self_attention(x, x, x)
            x = x.reshape(batch_size, -1, height, width)
        return x
"""
UNet_Unconditional_Diffusion - the U-Net architecture
Arguments:
channels_in: int,
number of input channels to the U-Net; for RGB images, channels_in = 3
channels_out: int,
number of output channels
activation: {
'
relu
'
,
'
leakyrelu
'
,
'
selu
'
,
'
gelu
'
,
'
silu
'
/
'
swish
'
}, default=
'
relu
'
,
activation function in the neural network
weight_init: {
'
he
'
,
'
torch
'
}, default=
'
he
'
,
weight initializer for convolution layers; choose between He
initialization and PyTorch
'
s default initialization
projection_features: int, default=64,
number of image features after first convolution layer
time_dim: int, default=64,
dimensionality of the time embedding (has to match batch size)
time_channels: int, default=256,
number of time channels
num_stages: int, default=4,
number of stages in contracting/expansive path
attention_list: int list, default=None,
specify number of features produced by stages
num_blocks: int, default=2,
number of ConvResBlock in each contracting/expansive path
num_groupnorm_groups: int, default=32,
number of groups used in Group Normalization inside a ConvResBlock;
channels_in to a ConvResBlock must be divisible by num_groups
dropout: float, default=0.1,
drop-out to be applied
attention_list: boolean list, default=None,
specify MHA pattern across stages
num_attention_heads: int, default=1,
number of attention heads in MHA inside a ConvResBlock
"""
class UNet_Unconditional_Diffusion(nn.Module):

    def __init__(self,
                 channels_in,                 # number of input channels to the U-Net; for RGB images, channels_in = 3
                 channels_out,                # number of output channels
                 activation='relu',           # activation function. Options: {'relu', 'leakyrelu', 'selu', 'gelu', 'silu'/'swish'}
                 weight_init='he',            # weight initialization. Options: {'he', 'torch'}
                 projection_features=64,      # number of image features after first convolution layer
                 time_dim=64,                 # dimensionality of the time embedding (has to match batch size)
                 time_channels=256,           # number of time channels
                 num_stages=4,                # number of stages in contracting/expansive path
                 stage_list=None,             # specify number of features produced by stages
                 num_blocks=2,                # number of ConvResBlocks in each stage of the contracting/expansive path
                 num_groupnorm_groups=32,     # number of groups used in Group Normalization inside a ConvResBlock
                 dropout=0.1,                 # drop-out to be applied inside a ConvResBlock
                 attention_list=None,         # specify MHA pattern across stages
                 num_attention_heads=1,       # number of attention heads in MHA inside a ConvResBlock
                 **args):
        super().__init__()

        self.channels_in = channels_in
        self.channels_out = channels_out

        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'leakyrelu':
            self.activation = nn.LeakyReLU()
        elif activation == 'selu':
            self.activation = nn.SELU()
        elif activation == 'gelu':
            self.activation = nn.GELU()
        elif activation == 'swish' or activation == 'silu':
            self.activation = nn.SiLU()

        # number of channels to be produced by the first conv block - image projection
        self.projection_features = projection_features
        self.first_conv = nn.Conv2d(channels_in, self.projection_features, kernel_size=3, padding=1)
        if weight_init == 'he':
            nn.init.kaiming_uniform_(self.first_conv.weight, nonlinearity='relu')
        self.first_act = self.activation

        # number of time channels
        self.time_dim = time_dim
        self.time_channels = time_channels
        self.time_embedding = TimeEmbedding(time_dim=self.time_dim, time_channels=self.time_channels)

        # number of downsampling/upsampling stages
        self.num_stages = num_stages
        # number of ConvResBlocks in each downsampling/upsampling step
        self.num_blocks = num_blocks

        if attention_list is None:
            # boolean list assigning attention blocks in the contracting and expansive path
            # default - first half of contracting path has no attention, second half does;
            # first half of expansive path has attention, second half doesn't
            self.attention_list = []
            for i in range(self.num_stages):
                if i < self.num_stages // 2:
                    self.attention_list.append(False)
                else:
                    self.attention_list.append(True)
        else:
            self.attention_list = attention_list
            # [False, False, True, True] - paper implementation for similar 4 stage U-Net

        # number of features produced by each stage
        if stage_list is None:
            # default - successive stages double the number of channels
            self.stages = [projection_features * 2 ** i for i in range(1, self.num_stages + 1)]
        else:
            self.stages = stage_list
            # [64, 128, 256, 1024] - paper implementation for similar 4 stage U-Net

        # contracting path
        contracting_path = []
        # number of channels to go into the first ConvResBlock = number of output channels from first conv layer
        c_in = c_out = projection_features
        # there are num_stages number of stages
        # each stage has num_blocks number of ConvRes+Attention blocks
        # each stage (except for the last) ends with a downsampling layer - maxpool
        for i in range(self.num_stages):
            c_out = self.stages[i]
            for _ in range(num_blocks):
                contracting_path.append(ConvResBlock(channels_in=c_in,
                                                     channels_out=c_out,
                                                     activation=self.activation,
                                                     weight_init=weight_init,
                                                     time_channels=self.time_channels,
                                                     num_groups=num_groupnorm_groups,
                                                     dropout=dropout,
                                                     attention=self.attention_list[i],
                                                     num_attention_heads=num_attention_heads))
                c_in = c_out
            # downsample, if it is not the last stage
            if i < self.num_stages - 1:
                contracting_path.append(nn.MaxPool2d((2, 2), stride=2))
        self.contracting_path = nn.ModuleList(contracting_path)

        # the bottleneck block
        self.midblock1 = ConvResBlock(channels_in=c_out, channels_out=c_out,
                                      activation=self.activation, weight_init=weight_init,
                                      time_channels=self.time_channels, num_groups=num_groupnorm_groups,
                                      dropout=dropout, attention=True,
                                      num_attention_heads=num_attention_heads)
        self.midblock2 = ConvResBlock(channels_in=c_out, channels_out=c_out,
                                      activation=self.activation, weight_init=weight_init,
                                      time_channels=self.time_channels, num_groups=num_groupnorm_groups,
                                      dropout=dropout, attention=False,
                                      num_attention_heads=num_attention_heads)

        # expansive path
        expansive_path = []
        # input to the expansive path = output of midblock = input to midblock = output of contracting path
        c_in = c_out = self.stages[-1]
        # there are num_stages number of stages
        # each stage has num_blocks number of ConvRes+Attention blocks and then 1 more to halve the number of channels
        # each stage (except for the last) ends with an upsampling layer - Transposed convolution
        for i in reversed(range(self.num_stages)):
            # channels_in = c_in + c_out to account for the incoming skip connections from contracting path
            for _ in range(self.num_blocks):
                expansive_path.append(ConvResBlock(channels_in=c_in + c_out,
                                                   channels_out=c_out,
                                                   activation=self.activation,
                                                   weight_init=weight_init,
                                                   time_channels=self.time_channels,
                                                   num_groups=num_groupnorm_groups,
                                                   dropout=dropout,
                                                   attention=self.attention_list[i],
                                                   num_attention_heads=num_attention_heads))
            if i > 0:
                c_out = self.stages[i - 1]
            else:
                c_out = self.projection_features
            expansive_path.append(ConvResBlock(channels_in=c_in + c_out,
                                               channels_out=c_out,
                                               activation=self.activation,
                                               weight_init=weight_init,
                                               time_channels=self.time_channels,
                                               num_groups=num_groupnorm_groups,
                                               dropout=dropout,
                                               attention=self.attention_list[i],
                                               num_attention_heads=num_attention_heads))
            c_in = c_out
            # upsample, if it is not the last stage
            if i > 0:
                expansive_path.append(nn.ConvTranspose2d(c_in, c_in, kernel_size=4, stride=2, padding=1))
        self.expansive_path = nn.ModuleList(expansive_path)

        # final convolution layer
        self.end_gn = nn.GroupNorm(8, c_in)
        self.end_conv = nn.Conv2d(c_in, self.channels_out, kernel_size=3, padding=1)
        if weight_init == 'he':
            nn.init.kaiming_uniform_(self.end_conv.weight, nonlinearity='relu')
        self.end_act = self.activation

    def forward(self, x, t):
        t = torch.tensor(t)

        # to store feature maps from contracting path
        skip = []

        # time embedding for time step t (int)
        t = self.time_embedding(t).to('cuda')

        # first conv layer to project input image (3, *, *) into (projection_features=64, *, *)
        x = self.first_act(self.first_conv(x))
        # store initial projection
        skip.append(x)

        # contracting path
        for i in range(len(self.contracting_path)):
            if isinstance(self.contracting_path[i], ConvResBlock):
                x = self.contracting_path[i](x, t)
            else:
                x = self.contracting_path[i](x)
            # store feature maps
            skip.append(x)

        x = self.midblock1(x, t)
        x = self.midblock2(x, t)

        # expansive path
        for i in range(len(self.expansive_path)):
            # add channels coming from skip connections (doesn't apply for upsampling ConvTranspose2D layer)
            if isinstance(self.expansive_path[i], ConvResBlock):
                x = torch.cat((x, skip.pop()), dim=1)
                x = self.expansive_path[i](x, t)
            else:
                x = self.expansive_path[i](x)

        # final conv layer
        x = self.end_gn(x)
        x = self.end_act(self.end_conv(x))
        return x
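
For orientation, a minimal sketch of how the deleted UNet_Unconditional_Diffusion might have been driven. Batch size, image size, and device handling are illustrative assumptions, not taken from the repository's training code: forward() moves the time embedding to 'cuda' unconditionally, so the sketch assumes a CUDA device; the batch size must equal time_dim because of how the time embedding is broadcast onto the feature maps; and with num_stages=4 the spatial size must be divisible by 2**(num_stages - 1) = 8.

import torch
from models.unet_unconditional_diffusion import UNet_Unconditional_Diffusion  # path before this commit deleted the file

# hypothetical settings: time_dim is set equal to the batch size (4); 64x64 RGB images
model = UNet_Unconditional_Diffusion(channels_in=3, channels_out=3, time_dim=4).to('cuda')
x = torch.randn(4, 3, 64, 64, device='cuda')   # batch of four 64x64 RGB images
t = 10                                          # integer diffusion time step; forward() wraps it in a tensor
out = model(x, t)
print(out.shape)                                # expected: torch.Size([4, 3, 64, 64])
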