changed FNO combination to add

This commit is contained in:
Mayeul Aubin 2025-06-25 17:00:11 +02:00
parent 1798195db4
commit 118455567b

View file

@ -112,18 +112,16 @@ class FNOBlock3D(nn.Module):
4. Repeats the steps for the second convolution.
C: Combination:
1. Concatenation: Conctenates the outputs from the Fourier and real space blocks (channel size is doubled).
2. Final Convolution: Applies a final convolution to reduce the channel size.
3. Activation: Applies a ReLU activation function.
4. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the combined features on style parameters.
5. Dropout: Applies dropout to the output tensor.
1. Addition: adds the outputs from the Fourier and real space block.
2. Activation: Applies a ReLU activation function.
3. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the combined features on style parameters.
4. Dropout: Applies dropout to the output tensor.
"""
def __init__(self, in_channels, out_channels, filtering=None, style_dim=None, dropout=0.05):
super(FNOBlock3D, self).__init__()
self.fourier_block = FourierSpaceBlock3D(in_channels, out_channels, filtering=filtering, style_dim=style_dim,)
self.real_block = RealSpaceBlock3D(in_channels, out_channels, style_dim=style_dim)
self.comb_conv = nn.Conv3d(2 * out_channels, out_channels, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)
self.film = FiLM(out_channels, style_dim) if style_dim else None
self.dropout = nn.Dropout(dropout) if dropout > 0 else None
@ -131,9 +129,7 @@ class FNOBlock3D(nn.Module):
def forward(self, x, style=None):
fourier_out = self.fourier_block(x, style)
real_out = self.real_block(x, style)
combined = torch.cat([fourier_out, real_out], dim=1)
out = self.comb_conv(combined)
out = self.relu(out)
out = self.relu(real_out + fourier_out)
if self.film is not None:
out = self.film(out, style)
if self.dropout is not None: