diff --git a/models/all_unets.py b/models/all_unets.py index 9fe4f6bf77f5c11b20af250614a67d7f1e170e46..fabd628cbb06b993e30b3c8cade3903936fe8aa8 100644 --- a/models/all_unets.py +++ b/models/all_unets.py @@ -69,7 +69,7 @@ class UNet_Res(nn.Module): # first two conv layers x = self.first_conv(input) + t_emb0[:,:,None,None] #timemb - skip1 =x + skip1,x = self.down1(x,t_emb1) skip2,x = self.down2(x,t_emb2) skip3,x = self.down3(x,t_emb3) @@ -238,3 +238,5 @@ class MidBlock_Res(nn.Module): x = self.convblock1(x,t) return self.convblock2(x,t) + +