bugfix UNetShrink

This commit is contained in:
Mayeul Aubin 2025-06-25 15:49:05 +02:00
parent 6a97fd27ec
commit 47cfefd61f

View file

@ -159,7 +159,7 @@ class UNet3D_Shrink(BaseModel):
2**(self.first_layer_channel_exponent + self.depth_enc), style_dim)
self.dec = []
for i in range(self.depth_dec - 1, -1, -1):
for i in range(self.depth_enc - 1, self.depth_dec - self.depth_enc -1, -1):
in_ch = 2**(self.first_layer_channel_exponent + i + 1)
out_ch = 2**(self.first_layer_channel_exponent + i)
skip_conn_ch = out_ch if i >= self.depth_dec-self.depth_enc else 0
@ -167,7 +167,7 @@ class UNet3D_Shrink(BaseModel):
self.dec = nn.ModuleList(self.dec)
self.final = nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=1)
self.final = nn.Conv3d(2**(self.first_layer_channel_exponent+shrink_factor_exponent), out_channels, kernel_size=1)
def forward(self, x, style):