changed FNO combination to add
This commit is contained in:
parent
1798195db4
commit
118455567b
1 changed files with 5 additions and 9 deletions
|
@ -112,18 +112,16 @@ class FNOBlock3D(nn.Module):
|
|||
4. Repeats the steps for the second convolution.
|
||||
|
||||
C: Combination:
|
||||
1. Concatenation: Conctenates the outputs from the Fourier and real space blocks (channel size is doubled).
|
||||
2. Final Convolution: Applies a final convolution to reduce the channel size.
|
||||
3. Activation: Applies a ReLU activation function.
|
||||
4. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the combined features on style parameters.
|
||||
5. Dropout: Applies dropout to the output tensor.
|
||||
1. Addition: adds the outputs from the Fourier and real space block.
|
||||
2. Activation: Applies a ReLU activation function.
|
||||
3. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the combined features on style parameters.
|
||||
4. Dropout: Applies dropout to the output tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, filtering=None, style_dim=None, dropout=0.05):
|
||||
super(FNOBlock3D, self).__init__()
|
||||
self.fourier_block = FourierSpaceBlock3D(in_channels, out_channels, filtering=filtering, style_dim=style_dim,)
|
||||
self.real_block = RealSpaceBlock3D(in_channels, out_channels, style_dim=style_dim)
|
||||
self.comb_conv = nn.Conv3d(2 * out_channels, out_channels, kernel_size=1, padding=0)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.film = FiLM(out_channels, style_dim) if style_dim else None
|
||||
self.dropout = nn.Dropout(dropout) if dropout > 0 else None
|
||||
|
@ -131,9 +129,7 @@ class FNOBlock3D(nn.Module):
|
|||
def forward(self, x, style=None):
|
||||
fourier_out = self.fourier_block(x, style)
|
||||
real_out = self.real_block(x, style)
|
||||
combined = torch.cat([fourier_out, real_out], dim=1)
|
||||
out = self.comb_conv(combined)
|
||||
out = self.relu(out)
|
||||
out = self.relu(real_out + fourier_out)
|
||||
if self.film is not None:
|
||||
out = self.film(out, style)
|
||||
if self.dropout is not None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue