improvements

This commit is contained in:
Mayeul Aubin 2025-06-24 09:26:26 +02:00
parent 6c526d7115
commit 58f9b27e6e
3 changed files with 153 additions and 7 deletions

View file

@ -56,7 +56,9 @@ class UNet3D(BaseModel):
in_channels: int = 2,
out_channels: int = 1,
style_dim: int = 2,
device: torch.device = torch.device('cpu')):
device: torch.device = torch.device('cpu'),
first_layer_channel_exponent: int = 3,
):
"""
3D U-Net model with optional FiLM layers for style conditioning.
Parameters:
@ -78,7 +80,7 @@ class UNet3D(BaseModel):
import numpy as np
self.depth = np.floor(np.log2(N)).astype(int) - 1 # Depth of the U-Net based on input size N
self.first_layer_channel_exponent = 3
self.first_layer_channel_exponent = first_layer_channel_exponent
self.enc=[]