From 47cfefd61f3c558ed5eb27095ccb2464985f9b55 Mon Sep 17 00:00:00 2001 From: Mayeul Aubin Date: Wed, 25 Jun 2025 15:49:05 +0200 Subject: [PATCH] bugfix UNetShrink --- sCOCA_ML/models/UNet_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sCOCA_ML/models/UNet_models.py b/sCOCA_ML/models/UNet_models.py index 59a4398..8c99007 100644 --- a/sCOCA_ML/models/UNet_models.py +++ b/sCOCA_ML/models/UNet_models.py @@ -159,7 +159,7 @@ class UNet3D_Shrink(BaseModel): 2**(self.first_layer_channel_exponent + self.depth_enc), style_dim) self.dec = [] - for i in range(self.depth_dec - 1, -1, -1): + for i in range(self.depth_enc - 1, self.depth_dec - self.depth_enc -1, -1): in_ch = 2**(self.first_layer_channel_exponent + i + 1) out_ch = 2**(self.first_layer_channel_exponent + i) skip_conn_ch = out_ch if i >= self.depth_dec-self.depth_enc else 0 @@ -167,7 +167,7 @@ class UNet3D_Shrink(BaseModel): self.dec = nn.ModuleList(self.dec) - self.final = nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=1) + self.final = nn.Conv3d(2**(self.first_layer_channel_exponent+shrink_factor_exponent), out_channels, kernel_size=1) def forward(self, x, style):