Skip to content
Snippets Groups Projects
Commit c71c8a35 authored by Gonzalo Martin Garcia's avatar Gonzalo Martin Garcia
Browse files

fixed UNet error

parent 66710703
No related branches found
No related tags found
No related merge requests found
......@@ -61,14 +61,14 @@ class Conditional_UNet_Res(nn.Module):
def forward(self, input, t, y):
# compute time mebedding
t_emb = self.time_embedder(t).to(input.device)
time_emb = self.time_embedder(t).to(input.device)
# time embedding learnable layers
t_emb0 = self.time_embedder0(t_emb)
t_emb1 = self.time_embedder1(t_emb)
t_emb2 = self.time_embedder2(t_emb)
t_emb3 = self.time_embedder3(t_emb)
t_emb4 = self.time_embedder4(t_emb)
time_emb0 = self.time_embedder0(time_emb)
time_emb1 = self.time_embedder1(time_emb)
time_emb2 = self.time_embedder2(time_emb)
time_emb3 = self.time_embedder3(time_emb)
time_emb4 = self.time_embedder4(time_emb)
# concat latent with masked image
cat = torch.concat((input, y), dim=1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment