changed FNO combination to add
This commit is contained in:
parent
1798195db4
commit
118455567b
1 changed files with 5 additions and 9 deletions
|
@ -112,18 +112,16 @@ class FNOBlock3D(nn.Module):
|
||||||
4. Repeats the steps for the second convolution.
|
4. Repeats the steps for the second convolution.
|
||||||
|
|
||||||
C: Combination:
|
C: Combination:
|
||||||
1. Concatenation: Conctenates the outputs from the Fourier and real space blocks (channel size is doubled).
|
1. Addition: adds the outputs from the Fourier and real space block.
|
||||||
2. Final Convolution: Applies a final convolution to reduce the channel size.
|
2. Activation: Applies a ReLU activation function.
|
||||||
3. Activation: Applies a ReLU activation function.
|
3. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the combined features on style parameters.
|
||||||
4. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the combined features on style parameters.
|
4. Dropout: Applies dropout to the output tensor.
|
||||||
5. Dropout: Applies dropout to the output tensor.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, filtering=None, style_dim=None, dropout=0.05):
|
def __init__(self, in_channels, out_channels, filtering=None, style_dim=None, dropout=0.05):
|
||||||
super(FNOBlock3D, self).__init__()
|
super(FNOBlock3D, self).__init__()
|
||||||
self.fourier_block = FourierSpaceBlock3D(in_channels, out_channels, filtering=filtering, style_dim=style_dim,)
|
self.fourier_block = FourierSpaceBlock3D(in_channels, out_channels, filtering=filtering, style_dim=style_dim,)
|
||||||
self.real_block = RealSpaceBlock3D(in_channels, out_channels, style_dim=style_dim)
|
self.real_block = RealSpaceBlock3D(in_channels, out_channels, style_dim=style_dim)
|
||||||
self.comb_conv = nn.Conv3d(2 * out_channels, out_channels, kernel_size=1, padding=0)
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
self.relu = nn.ReLU(inplace=True)
|
||||||
self.film = FiLM(out_channels, style_dim) if style_dim else None
|
self.film = FiLM(out_channels, style_dim) if style_dim else None
|
||||||
self.dropout = nn.Dropout(dropout) if dropout > 0 else None
|
self.dropout = nn.Dropout(dropout) if dropout > 0 else None
|
||||||
|
@ -131,9 +129,7 @@ class FNOBlock3D(nn.Module):
|
||||||
def forward(self, x, style=None):
|
def forward(self, x, style=None):
|
||||||
fourier_out = self.fourier_block(x, style)
|
fourier_out = self.fourier_block(x, style)
|
||||||
real_out = self.real_block(x, style)
|
real_out = self.real_block(x, style)
|
||||||
combined = torch.cat([fourier_out, real_out], dim=1)
|
out = self.relu(real_out + fourier_out)
|
||||||
out = self.comb_conv(combined)
|
|
||||||
out = self.relu(out)
|
|
||||||
if self.film is not None:
|
if self.film is not None:
|
||||||
out = self.film(out, style)
|
out = self.film(out, style)
|
||||||
if self.dropout is not None:
|
if self.dropout is not None:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue