bugfix UNetShrink
This commit is contained in:
parent
6a97fd27ec
commit
47cfefd61f
1 changed files with 2 additions and 2 deletions
|
@ -159,7 +159,7 @@ class UNet3D_Shrink(BaseModel):
|
|||
2**(self.first_layer_channel_exponent + self.depth_enc), style_dim)
|
||||
|
||||
self.dec = []
|
||||
for i in range(self.depth_dec - 1, -1, -1):
|
||||
for i in range(self.depth_enc - 1, self.depth_dec - self.depth_enc -1, -1):
|
||||
in_ch = 2**(self.first_layer_channel_exponent + i + 1)
|
||||
out_ch = 2**(self.first_layer_channel_exponent + i)
|
||||
skip_conn_ch = out_ch if i >= self.depth_dec-self.depth_enc else 0
|
||||
|
@ -167,7 +167,7 @@ class UNet3D_Shrink(BaseModel):
|
|||
|
||||
self.dec = nn.ModuleList(self.dec)
|
||||
|
||||
self.final = nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=1)
|
||||
self.final = nn.Conv3d(2**(self.first_layer_channel_exponent+shrink_factor_exponent), out_channels, kernel_size=1)
|
||||
|
||||
|
||||
def forward(self, x, style):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue