improvements
This commit is contained in:
parent
6c526d7115
commit
58f9b27e6e
3 changed files with 153 additions and 7 deletions
|
@ -56,7 +56,9 @@ class UNet3D(BaseModel):
|
|||
in_channels: int = 2,
|
||||
out_channels: int = 1,
|
||||
style_dim: int = 2,
|
||||
device: torch.device = torch.device('cpu')):
|
||||
device: torch.device = torch.device('cpu'),
|
||||
first_layer_channel_exponent: int = 3,
|
||||
):
|
||||
"""
|
||||
3D U-Net model with optional FiLM layers for style conditioning.
|
||||
Parameters:
|
||||
|
@ -78,7 +80,7 @@ class UNet3D(BaseModel):
|
|||
import numpy as np
|
||||
|
||||
self.depth = np.floor(np.log2(N)).astype(int) - 1 # Depth of the U-Net based on input size N
|
||||
self.first_layer_channel_exponent = 3
|
||||
self.first_layer_channel_exponent = first_layer_channel_exponent
|
||||
|
||||
self.enc=[]
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue