changed FNO combination to add

This commit is contained in:
Mayeul Aubin 2025-06-25 17:00:11 +02:00
parent 1798195db4
commit 118455567b

View file

@ -112,18 +112,16 @@ class FNOBlock3D(nn.Module):
4. Repeats the steps for the second convolution. 4. Repeats the steps for the second convolution.
C: Combination: C: Combination:
1. Concatenation: Conctenates the outputs from the Fourier and real space blocks (channel size is doubled). 1. Addition: adds the outputs from the Fourier and real space block.
2. Final Convolution: Applies a final convolution to reduce the channel size. 2. Activation: Applies a ReLU activation function.
3. Activation: Applies a ReLU activation function. 3. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the combined features on style parameters.
4. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the combined features on style parameters. 4. Dropout: Applies dropout to the output tensor.
5. Dropout: Applies dropout to the output tensor.
""" """
def __init__(self, in_channels, out_channels, filtering=None, style_dim=None, dropout=0.05): def __init__(self, in_channels, out_channels, filtering=None, style_dim=None, dropout=0.05):
super(FNOBlock3D, self).__init__() super(FNOBlock3D, self).__init__()
self.fourier_block = FourierSpaceBlock3D(in_channels, out_channels, filtering=filtering, style_dim=style_dim,) self.fourier_block = FourierSpaceBlock3D(in_channels, out_channels, filtering=filtering, style_dim=style_dim,)
self.real_block = RealSpaceBlock3D(in_channels, out_channels, style_dim=style_dim) self.real_block = RealSpaceBlock3D(in_channels, out_channels, style_dim=style_dim)
self.comb_conv = nn.Conv3d(2 * out_channels, out_channels, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.film = FiLM(out_channels, style_dim) if style_dim else None self.film = FiLM(out_channels, style_dim) if style_dim else None
self.dropout = nn.Dropout(dropout) if dropout > 0 else None self.dropout = nn.Dropout(dropout) if dropout > 0 else None
@ -131,9 +129,7 @@ class FNOBlock3D(nn.Module):
def forward(self, x, style=None): def forward(self, x, style=None):
fourier_out = self.fourier_block(x, style) fourier_out = self.fourier_block(x, style)
real_out = self.real_block(x, style) real_out = self.real_block(x, style)
combined = torch.cat([fourier_out, real_out], dim=1) out = self.relu(real_out + fourier_out)
out = self.comb_conv(combined)
out = self.relu(out)
if self.film is not None: if self.film is not None:
out = self.film(out, style) out = self.film(out, style)
if self.dropout is not None: if self.dropout is not None: