Diffusion Project / Unconditional Diffusion / Commits

Commit 38147ff1
authored 2 years ago by Gonzalo Martin Garcia
EMA training added. Fixed some minor bugs.
Parent: cf57c965
Showing 3 changed files with 79 additions and 45 deletions:

  main.py               1 addition, 1 deletion
  models/Framework.py   32 additions, 32 deletions
  trainer/train.py      46 additions, 12 deletions
main.py  (+1, −1)

@@ -36,7 +36,7 @@ def train_func(f):
     #model = globals()[meta_setting["modelname"]](**model_setting).to(device)
     #net = torch.compile(model)
     net = UNet2DModel(
-        sample_size=128,
+        sample_size=64,
         in_channels=3,
         out_channels=3,
         layers_per_block=2,
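For context, a minimal sketch of what a forward pass through the reconfigured network looks like after this change. It assumes the UNet2DModel here is the diffusers implementation (the rest of the diff calls it as net(x_t, t, return_dict=False)[0]); batch size and timestep range are illustrative only:

    import torch
    from diffusers import UNet2DModel

    # Same arguments as in main.py after this commit (sample_size reduced from 128 to 64).
    net = UNet2DModel(sample_size=64, in_channels=3, out_channels=3, layers_per_block=2)

    x_t = torch.randn(4, 3, 64, 64)                 # batch of noised images
    t = torch.randint(1, 1000, (4,))                # one timestep per image
    pred_noise = net(x_t, t, return_dict=False)[0]  # predicted noise, shape (4, 3, 64, 64)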
models/Framework.py  (+32, −32)
@@ -68,9 +68,9 @@ class DDPM(nn.Module):
         self.recon_loss = recon_loss
         self.out_shape = out_shape
         # precomputed for efficiency reasons
-        self.noise_scaler = ((1-alpha)/(self.sqrt_1_minus_alpha_bar))
-        self.mean_scaler = (1/torch.sqrt(self.alpha))
-        self.mse_weight = ((self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar)))
+        self.noise_scaler = (1-alpha)/(self.sqrt_1_minus_alpha_bar)
+        self.mean_scaler = 1/torch.sqrt(self.alpha)
+        self.mse_weight = (self.beta**2)/(2*self.var*self.alpha*(1-self.alpha_bar))

     @staticmethod
     def linear_schedule(diffusion_steps, beta_1, beta_T, device):
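For reference, these precomputed tensors correspond to the standard DDPM reverse-process mean parametrization; this is a sketch of the intended math inferred from the variable names (noise_scaler, mean_scaler) and from how they are combined later in DDPM.forward, in the usual Ho et al. (2020) notation:

    \mu_\theta(x_t, t)
      = \underbrace{\frac{1}{\sqrt{\alpha_t}}}_{\text{mean\_scaler}}
        \Bigl( x_t
          - \underbrace{\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}}_{\text{noise\_scaler}}
            \,\epsilon_\theta(x_t, t) \Bigr)

The commit only drops redundant outer parentheses; the computed values are unchanged.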
@@ -183,14 +183,14 @@ class DDPM(nn.Module):
         Parameters:
         x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
-        t   (int): Timestep, by default goes through full forward trajectory
+        t   (tensor): Batch of timesteps, by default goes through full forward trajectory
         Returns:
         x_T (tensor): Batch of noised images at timestep t
         forward_noise (tensor): Batch of noise parameters from the noise distribution reparametrization used to draw x_T
         '''
         if t is None:
-            t = self.diffusion_steps
+            t = torch.full((x_0.shape[0],), self.diffusion_steps, device=self.device)
         elif torch.any(t == 0):
             raise ValueError("The tensor 't' contains a timestep zero.")
         forward_noise = torch.randn(x_0.shape, device=self.device)
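A minimal sketch of the batched timestep convention introduced here, seen from the caller's side (names and sizes are illustrative, not from the repository):

    import torch

    batch_size, diffusion_steps = 16, 1000

    # One timestep per image, drawn uniformly from {1, ..., T}; timestep 0 is reserved
    # for the reconstruction term, which is why forward_trajectory rejects t == 0.
    t = torch.randint(1, diffusion_steps + 1, (batch_size,))

    # Default case inside forward_trajectory when t is None: every image gets t = T.
    t_full = torch.full((batch_size,), diffusion_steps)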
@@ -200,13 +200,13 @@ class DDPM(nn.Module):
     @torch.no_grad()
     def noised_latent(self, forward_noise, x_0, t):
         '''
-        Given a batch of noise parameters, this function recomputes the batch of noised images at timestep t.
+        Given a batch of noise parameters, this function recomputes the batch of noised images at their respective timesteps t.
         This allows us to avoid storing all the intermediate latents x_t along the forward trajectory.

         Parameters:
         forward_noise (tensor): Batch of noise parameters from the noise distribution reparametrization used to draw x_t
         x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
-        t   (int): Timestep
+        t   (tensor): Batch of timesteps
         Returns:
         x_t (tensor): Batch of noised images at timestep t
@@ -222,14 +222,14 @@ class DDPM(nn.Module):
         Parameters:
         x_0 (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
-        t   (int): Timestep
+        t   (tensor): Batch of timesteps
         Returns:
         mean (tensor): Batch of means for the complete noise distribution for each image in the batch x_0
         std (tensor): Batch of std scalars for the complete noise distribution for each image in the batch x_0
         '''
-        mean = self.sqrt_alpha_bar[t-1].view(-1,1,1,1)*x_0
-        std = self.sqrt_1_minus_alpha_bar[t-1].view(-1,1,1,1)
+        mean = self.sqrt_alpha_bar[t-1][:,None,None,None]*x_0
+        std = self.sqrt_1_minus_alpha_bar[t-1][:,None,None,None]
         return mean, std

     @torch.no_grad()
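The indexing change above (and the analogous ones further down) is cosmetic: indexing with [:, None, None, None] produces the same (B, 1, 1, 1) broadcast shape as .view(-1, 1, 1, 1). A small self-contained check, not part of the repository, illustrating the equivalence:

    import torch

    coeff = torch.rand(8)          # per-timestep scalar coefficients, one per batch element
    x = torch.randn(8, 3, 64, 64)  # batch of images

    a = coeff.view(-1, 1, 1, 1) * x        # old spelling
    b = coeff[:, None, None, None] * x     # new spelling used in this commit

    assert a.shape == b.shape == (8, 3, 64, 64)
    assert torch.equal(a, b)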
@@ -239,14 +239,14 @@ class DDPM(nn.Module):
         Parameters:
         x_t_1 (tensor): Batch of noised images at timestep t-1
-        t   (int): Timestep
+        t   (tensor): Batch of timesteps
         Returns:
         mean (tensor): Batch of means for the individual noise distribution for each image in the batch x_t_1
         std (tensor): Batch of std scalars for the individual noise distribution for each image in the batch x_t_1
         '''
-        mean = torch.sqrt(1-self.beta[t-1]).view(-1,1,1,1)*x_t_1
-        std = torch.sqrt(self.beta[t-1]).view(-1,1,1,1)
+        mean = torch.sqrt(1-self.beta[t-1])[:,None,None,None]*x_t_1
+        std = torch.sqrt(self.beta[t-1])[:,None,None,None]
         return mean, std
@@ -254,12 +254,12 @@ class DDPM(nn.Module):
     def reverse_trajectory(self, x_t, t):
         '''
-        Draws a denoised image x_{t-1} by reparametrizing the denoising distribution at time t for the current noised
-        latent x_t.
+        Draws denoised images x_{t-1} by reparametrizing the denoising distribution at times t for the current noised
+        latents x_t.

         Parameters:
         x_t (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
-        t   (int): Timestep
+        t   (tensor): Batch of timesteps
         Returns:
         x_t_1 (tensor): Batch of denoised images at timestep t-1
@@ -271,16 +271,16 @@ class DDPM(nn.Module):
     def forward(self, x_t, t):
         '''
-        Passes the current noised image x_t and timestep t through the U-Net in order to compute the
-        predicted noise, which is later used to determine the current denoising distribution parameters in the
-        reverse trajectory.
+        Passes the current noised images x_t and timesteps t through the U-Net in order to compute the
+        predicted noise, which is later used to determine the current denoising distribution parameters (mean and std) in the
+        reverse trajectory.
         Since the DDPM class is inheriting from the nn.Module class, this function is required to share
         the name 'forward'. This naming scheme does not refer to the forward trajectory, but the forward
         pass of the model itself, which concerns the reverse trajectory.

         Parameters:
         x_t (tensor): Batch of input images, with color channels assumed to be normalized between [-1,1]
-        t   (int): Timestep
+        t   (tensor): Batch of timesteps
         Returns:
         mean (tensor): Batch of means for the complete noise dist. for each image in the batch x_t
@@ -288,8 +288,8 @@ class DDPM(nn.Module):
         pred_noise (tensor): Predicted noise for each image in the batch x_t
         '''
         pred_noise = self.net(x_t, t, return_dict=False)[0]
-        mean = self.mean_scaler[t-1].view(-1,1,1,1)*(x_t - self.noise_scaler[t-1].view(-1,1,1,1)*pred_noise)
-        std = self.std[t-1].view(-1,1,1,1)
+        mean = self.mean_scaler[t-1][:,None,None,None]*(x_t - self.noise_scaler[t-1][:,None,None,None]*pred_noise)
+        std = self.std[t-1][:,None,None,None]
         return mean, std, pred_noise
@@ -319,7 +319,7 @@ class DDPM(nn.Module):
             else:
                 noise = torch.zeros(x_0_recon.shape, device=self.device)
             # get denoising dist. param
-            mean, std, _ = self.forward(x_0_recon, t)
+            mean, std, _ = self.forward(x_0_recon, torch.full((x_0_recon.shape[0],), t, device=self.device))
             # compute the drawn denoised latent at time t
             x_0_recon = mean + std*noise
         return x_0_recon
@@ -355,7 +355,7 @@ class DDPM(nn.Module):
             else:
                 noise = torch.zeros(x_t_1.shape, device=self.device)
             # get denoising dist. param
-            mean, std, _ = self.forward(x_t_1, t)
+            mean, std, _ = self.forward(x_t_1, torch.full((x_t_1.shape[0],), t, device=self.device))
             # compute the drawn denoised latent at time t
             x_t_1 = mean + std*noise
         return x_t_1
@@ -372,8 +372,8 @@ class DDPM(nn.Module):
         '''
         # start with an image of pure noise (batch_size 1) and store it as part of the output
         x_t_1 = torch.randn((1,)+tuple(self.out_shape), device=self.device)
-        x = torch.empty((self.diffusion_steps+1, 1,)+tuple(self.out_shape), device=self.device)
-        x[-1] = x_t_1
+        x = torch.empty((self.diffusion_steps+1,)+tuple(self.out_shape), device=self.device)
+        x[-1] = x_t_1.squeeze(0)
         # apply reverse trajectory
         for t in reversed(range(1, self.diffusion_steps+1)):
             # draw noise used in the denoising dist. reparametrization
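The reshaping change in this hunk drops the singleton batch dimension from the stored sampling trajectory instead of squeezing it out at the end. A small shape sketch, with out_shape assumed to be (3, 64, 64) purely for illustration:

    import torch

    diffusion_steps, out_shape = 4, (3, 64, 64)

    x_t_1 = torch.randn((1,) + out_shape)                # (1, 3, 64, 64), batch of one
    x = torch.empty((diffusion_steps + 1,) + out_shape)  # (5, 3, 64, 64) after this commit
    x[-1] = x_t_1.squeeze(0)                             # squeeze removes the batch dim before storing

    assert x.shape == (diffusion_steps + 1, 3, 64, 64)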
@@ -382,14 +382,14 @@ class DDPM(nn.Module):
             else:
                 noise = torch.zeros(x_t_1.shape, device=self.device)
             # get denoising dist. param
-            mean, std, _ = self.forward(x_t_1, t)
+            mean, std, _ = self.forward(x_t_1, torch.full((x_t_1.shape[0],), t, device=self.device))
             # compute the drawn denoised latent at time t
             x_t_1 = mean + std*noise
             # store noised image
-            x[t-1] = x_t_1
-        x_sq = x.squeeze(1)
-        return x_sq
-        # return x
+            x[t-1] = x_t_1.squeeze(0)
+        # x_sq = x.squeeze(1)
+        # return x_sq
+        return x

     # Loss functions
@@ -407,7 +407,7 @@ class DDPM(nn.Module):
         '''
         Returns the mathematically correct weighted version of the simplified loss.
         '''
-        return self.mse_weight[t-1].view(-1,1,1,1)*F.mse_loss(forward_noise, pred_noise)
+        return self.mse_weight[t-1][:,None,None,None]*F.mse_loss(forward_noise, pred_noise)
         # If t=0 and self.recon_loss == 'nll'
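The per-timestep factor applied here is the mse_weight tensor precomputed in the constructor. A sketch of the intended formula, inferred from that definition and written in the usual DDPM notation:

    w_t = \frac{\beta_t^{2}}{2\,\sigma_t^{2}\,\alpha_t\,(1-\bar{\alpha}_t)},
    \qquad
    L_t = w_t \,\bigl\lVert \epsilon - \epsilon_\theta(x_t, t) \bigr\rVert^{2}

i.e. the weighting that turns the simplified noise-prediction MSE back into the variational-bound term.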
trainer/train.py  (+46, −12)
 import numpy as np
+import copy
 import torch
 from torch import nn
 from torchvision import datasets, transforms
@@ -10,6 +10,7 @@ import numpy as np
 import torch.nn.functional as F
 import os
 import wandb
+from copy import deepcopy

 device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -53,6 +54,37 @@ def simple_trainer(model,device,epochs,trainloader,testloader,bs,lr,T,criterion
             print(f"Testloss in step {epoch}: {np.mean(running_testloss)}")

+# EMA class
+# Important! This EMA class code is not ours; it was taken from the PyTorch Image Models library (timm). It performs an exponential moving
+# average over the trained weights of a given model's neural net, which was suggested by the paper "Improved Denoising Diffusion Probabilistic Models"
+# by Nichol and Dhariwal to stabilize and improve the training and generalization process.
+# https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/model_ema.py
+class ModelEmaV2(nn.Module):
+    def __init__(self, model, decay=0.9999, device=None):
+        super(ModelEmaV2, self).__init__()
+        # make a copy of the model for accumulating moving average of weights
+        self.module = deepcopy(model)
+        self.module.eval()
+        self.decay = decay
+        self.device = device  # perform ema on different device from model if set
+        if self.device is not None:
+            self.module.to(device=device)
+
+    def _update(self, model, update_fn):
+        with torch.no_grad():
+            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+                if self.device is not None:
+                    model_v = model_v.to(device=self.device)
+                ema_v.copy_(update_fn(ema_v, model_v))
+
+    def update(self, model):
+        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
+
+    def set(self, model):
+        self._update(model, update_fn=lambda e, m: m)
+
 # Training function for the unconditional diffusion model
 def ddpm_trainer(model,
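A minimal sketch of how ModelEmaV2 is wired into a training loop, using the decay default added in this commit; model, optimizer, scheduler, trainloader and loss_fn are illustrative stand-ins for the objects ddpm_trainer actually builds:

    # Assumes `model`, `optimizer`, `scheduler`, `trainloader`, and `loss_fn` are already set up.
    ema = ModelEmaV2(model, decay=0.9999, device=model.device)

    for x_0, _ in trainloader:
        optimizer.zero_grad()
        loss = loss_fn(model, x_0)
        loss.backward()
        optimizer.step()
        ema.update(model)   # EMA weights track the trained weights after every optimizer step
        scheduler.step()

    # ema.module now holds the smoothed copy of the model, typically used for sampling/evaluation.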
@@ -70,6 +102,8 @@ def ddpm_trainer(model,
                  experiment_path = None,
                  T_max = 5*10000, # None,
                  eta_min = 1e-5,
+                 ema_training = True,
+                 decay = 0.9999,
                  **args
                  ):
     '''
@@ -89,6 +123,8 @@ def ddpm_trainer(model,
     checkpoint:      Name of the saved pth. file containing the trained weights and biases
     T_max:           CosineAnnealingLR scheduler argument (nr of steps in training for a full cycle)
     eta_min:         CosineAnnealingLR scheduler argument (scheduler oscillates between highest lr 'leraning_rate' and minimum lr 'eta_min')
+    decay:           EMA decay rate that is used to weight the effect of the ema model when computing the weighted avg between trained and
+                     ema weights for the networks weight update
     '''
     # set optimizer parameters and learning rate
@@ -133,13 +169,17 @@ def ddpm_trainer(model,
     if model.recon_loss == 'nll':
         low = 0

+    # EMA
+    if ema_training:
+        ema = ModelEmaV2(model, decay=decay, device=model.device)
+
     # Using W&B
     with wandb.init(project='test-project', name=run_name, entity='gonzalomartingarcia0', id=run_name, resume=True) as run:

         # Log some info
         run.config.learning_rate = learning_rate
         run.config.optimizer = optimizer.__class__.__name__
-        run.watch(model.net)
+        #run.watch(model.net)

         # training loop
         # last model was stored at epoch last_epoch, we continue training from there, i.e. last_epoch+1 (else we start at epoch 0)
@@ -178,20 +218,14 @@ def ddpm_trainer(model,
                 loss.backward()
                 optimizer.step()
+                if ema_training:
+                    ema.update(model)
                 scheduler.step()

             if verbose:
                 print(f"Loss in epoch {epoch}: {running_trainloss/nr_train_batches}")
                 run.log({'running_loss': running_trainloss/nr_train_batches})

-            # WORKING OLD VERSION
-            #x_t, forward_noise = model.forward_trajectory(x_0,t)
-            #_, _, pred_noise = model.forward(x_t,t)
-            #loss = loss_func(forward_noise,pred_noise,t)
-            #running_trainloss += loss.item()
-            #nr_train_batches += 1
-            #run.log({'loss': loss.item(), "learning_rate": scheduler.get_last_lr()[0], 'epoch': epoch, 'batch': idx})

             # evaluation
             if ((epoch+1) % eval_iter == 0) or ((epoch+1) % store_iter == 0):
                 running_testloss = 0