diff --git a/sCOCA_ML/models/FNO_models.py b/sCOCA_ML/models/FNO_models.py index f8234af..2f14198 100644 --- a/sCOCA_ML/models/FNO_models.py +++ b/sCOCA_ML/models/FNO_models.py @@ -112,18 +112,16 @@ class FNOBlock3D(nn.Module): 4. Repeats the steps for the second convolution. C: Combination: - 1. Concatenation: Conctenates the outputs from the Fourier and real space blocks (channel size is doubled). - 2. Final Convolution: Applies a final convolution to reduce the channel size. - 3. Activation: Applies a ReLU activation function. - 4. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the combined features on style parameters. - 5. Dropout: Applies dropout to the output tensor. + 1. Addition: adds the outputs from the Fourier and real space block. + 2. Activation: Applies a ReLU activation function. + 3. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the combined features on style parameters. + 4. Dropout: Applies dropout to the output tensor. """ def __init__(self, in_channels, out_channels, filtering=None, style_dim=None, dropout=0.05): super(FNOBlock3D, self).__init__() self.fourier_block = FourierSpaceBlock3D(in_channels, out_channels, filtering=filtering, style_dim=style_dim,) self.real_block = RealSpaceBlock3D(in_channels, out_channels, style_dim=style_dim) - self.comb_conv = nn.Conv3d(2 * out_channels, out_channels, kernel_size=1, padding=0) self.relu = nn.ReLU(inplace=True) self.film = FiLM(out_channels, style_dim) if style_dim else None self.dropout = nn.Dropout(dropout) if dropout > 0 else None @@ -131,9 +129,7 @@ class FNOBlock3D(nn.Module): def forward(self, x, style=None): fourier_out = self.fourier_block(x, style) real_out = self.real_block(x, style) - combined = torch.cat([fourier_out, real_out], dim=1) - out = self.comb_conv(combined) - out = self.relu(out) + out = self.relu(real_out + fourier_out) if self.film is not None: out = self.film(out, style) if self.dropout is not None: