diff --git a/sCOCA_ML/models/UNet_models.py b/sCOCA_ML/models/UNet_models.py index 59a4398..8c99007 100644 --- a/sCOCA_ML/models/UNet_models.py +++ b/sCOCA_ML/models/UNet_models.py @@ -159,7 +159,7 @@ class UNet3D_Shrink(BaseModel): 2**(self.first_layer_channel_exponent + self.depth_enc), style_dim) self.dec = [] - for i in range(self.depth_dec - 1, -1, -1): + for i in range(self.depth_enc - 1, self.depth_dec - self.depth_enc -1, -1): in_ch = 2**(self.first_layer_channel_exponent + i + 1) out_ch = 2**(self.first_layer_channel_exponent + i) skip_conn_ch = out_ch if i >= self.depth_dec-self.depth_enc else 0 @@ -167,7 +167,7 @@ class UNet3D_Shrink(BaseModel): self.dec = nn.ModuleList(self.dec) - self.final = nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=1) + self.final = nn.Conv3d(2**(self.first_layer_channel_exponent+shrink_factor_exponent), out_channels, kernel_size=1) def forward(self, x, style):