FNO
This commit is contained in:
parent
aa3edb457c
commit
c0b1f656ce
1 changed files with 196 additions and 0 deletions
196
sCOCA_ML/models/FNO_models.py
Normal file
196
sCOCA_ML/models/FNO_models.py
Normal file
|
@ -0,0 +1,196 @@
|
|||
"""
|
||||
FNO: Fourier Neural Operator
|
||||
Implementation of the Fourier Neural Operator (FNO) architecture.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .base_class_models import BaseModel
|
||||
from .FiLM import FiLM
|
||||
|
||||
class FourierSpaceBlock3D(nn.Module):
|
||||
"""
|
||||
The Fourier Space Block applies the following operations:
|
||||
|
||||
1. Fourier Transform: Computes the Fourier transform of the input tensor.
|
||||
2. Filtering: Removes the high-frequency components by slicing the Fourier field.
|
||||
3. Linear Transformation: Applies a linear transformation in the Fourier space.
|
||||
4. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the Fourier features on style parameters.
|
||||
5. Inverse Fourier Transform: Computes the inverse Fourier transform to return to the real domain.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, filtering=None, style_dim=None):
|
||||
"""
|
||||
Initializes the FourierSpaceBlock with the given parameters.
|
||||
Parameters:
|
||||
- in_channels: Number of input channels.
|
||||
- out_channels: Number of output channels.
|
||||
- filtering: Optional parameter to specify the filtering modes for each dimension. filtering=(kx, ky, kz) where kx, ky, kz are the number of modes to keep in each dimension. If None, then all modes are kept and the linear transformation is the same for all of them.
|
||||
- style_dim: Dimension of the style vector for FiLM conditioning.
|
||||
"""
|
||||
super(FourierSpaceBlock3D, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.film = FiLM(out_channels, style_dim) if style_dim else None
|
||||
self.filtering = filtering if filtering is not None else (None, None, None)
|
||||
self.scale = (1 / (in_channels * out_channels))
|
||||
self.weights = nn.Parameter(
|
||||
self.scale * torch.randn(in_channels, out_channels, self.filtering[0] if self.filtering[0] else 1, self.filtering[1] if self.filtering[1] else 1, self.filtering[2] if self.filtering[2] else 1, 2)
|
||||
)
|
||||
|
||||
def compl_mul3d(self, input, weights):
|
||||
# input: (B, I, X, Y, Z), weights: (I, O, X, Y, Z)
|
||||
return torch.einsum("bixyz, ioxyz -> boxyz", input, weights)
|
||||
|
||||
|
||||
def forward(self, x, style=None):
|
||||
batchsize = x.shape[0]
|
||||
x_ft = torch.fft.rfftn(x, dim=[2, 3, 4])
|
||||
out_ft = torch.zeros(
|
||||
batchsize,
|
||||
self.out_channels,
|
||||
x.shape[2],
|
||||
x.shape[3],
|
||||
x.shape[4]//2 + 1,
|
||||
dtype=torch.cfloat,
|
||||
device=x.device
|
||||
)
|
||||
modes1 = self.filtering[0] if self.filtering[0] else x.shape[2]
|
||||
modes2 = self.filtering[1] if self.filtering[1] else x.shape[3]
|
||||
modes3 = self.filtering[2] if self.filtering[2] else x.shape[4] // 2 + 1
|
||||
x_ft_sub = x_ft[:, :, :modes1, :modes2, :modes3]
|
||||
w_complex = torch.view_as_complex(self.weights)
|
||||
out_ft[:, :, :modes1, :modes2, :modes3] = self.compl_mul3d(x_ft_sub, w_complex)
|
||||
if self.film is not None:
|
||||
out_ft = self.film(out_ft, style)
|
||||
x = torch.fft.irfftn(out_ft, s=(x.shape[2], x.shape[3], x.shape[4]), dim=[2, 3, 4])
|
||||
return x
|
||||
|
||||
|
||||
class RealSpaceBlock3D(nn.Module):
|
||||
"""
|
||||
The Real Space Block applies the following operations:
|
||||
|
||||
1. Convolution: Applies a 3D convolution to the input tensor.
|
||||
2. Activation: Applies a ReLU activation function.
|
||||
3. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the real features on style parameters.
|
||||
4. Repeats the steps for the second convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, style_dim=None):
|
||||
super(RealSpaceBlock3D, self).__init__()
|
||||
self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.film1 = FiLM(out_channels, style_dim) if style_dim else None
|
||||
self.film2 = FiLM(out_channels, style_dim) if style_dim else None
|
||||
|
||||
def forward(self, x, style=None):
|
||||
x = self.relu(self.conv1(x))
|
||||
if self.film1 is not None:
|
||||
x = self.film1(x, style)
|
||||
x = self.relu(self.conv2(x))
|
||||
if self.film2 is not None:
|
||||
x = self.film2(x, style)
|
||||
return x
|
||||
|
||||
|
||||
class FNOBlock3D(nn.Module):
|
||||
"""
|
||||
The Fourier Neural Operator (FNO) block applies the following operations:
|
||||
|
||||
A. Fourier Space Block:
|
||||
1. Fourier Transform: Computes the Fourier transform of the input tensor.
|
||||
2. Filtering: Removes the high-frequency components by slicing the Fourier field.
|
||||
3. Linear Transformation: Applies a linear transformation in the Fourier space.
|
||||
4. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the Fourier features on style parameters.
|
||||
5. Inverse Fourier Transform: Computes the inverse Fourier transform to return to the real domain.
|
||||
|
||||
B. Real Space Block:
|
||||
1. Convolution: Applies a 3D convolution to the input tensor.
|
||||
2. Activation: Applies a ReLU activation function.
|
||||
3. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the real features on style parameters.
|
||||
4. Repeats the steps for the second convolution.
|
||||
|
||||
C: Combination:
|
||||
1. Concatenation: Conctenates the outputs from the Fourier and real space blocks (channel size is doubled).
|
||||
2. Final Convolution: Applies a final convolution to reduce the channel size.
|
||||
3. Activation: Applies a ReLU activation function.
|
||||
4. FiLM: Applies Feature-wise Linear Modulation (FiLM) to condition the combined features on style parameters.
|
||||
5. Dropout: Applies dropout to the output tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, filtering=None, style_dim=None, dropout=0.05):
|
||||
super(FNOBlock3D, self).__init__()
|
||||
self.fourier_block = FourierSpaceBlock3D(in_channels, out_channels, filtering=filtering, style_dim=style_dim,)
|
||||
self.real_block = RealSpaceBlock3D(in_channels, out_channels, style_dim=style_dim)
|
||||
self.comb_conv = nn.Conv3d(2 * out_channels, out_channels, kernel_size=1, padding=0)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.film = FiLM(out_channels, style_dim) if style_dim else None
|
||||
self.dropout = nn.Dropout(dropout) if dropout > 0 else None
|
||||
|
||||
def forward(self, x, style=None):
|
||||
fourier_out = self.fourier_block(x, style)
|
||||
real_out = self.real_block(x, style)
|
||||
combined = torch.cat([fourier_out, real_out], dim=1)
|
||||
out = self.comb_conv(combined)
|
||||
out = self.relu(out)
|
||||
if self.film is not None:
|
||||
out = self.film(out, style)
|
||||
if self.dropout is not None:
|
||||
out = self.dropout(out)
|
||||
return out
|
||||
|
||||
|
||||
class FNO3D(BaseModel):
|
||||
"""
|
||||
Fourier Neural Operator (FNO) model for 3D data.
|
||||
Architecture:
|
||||
- Input: 3D tensor with shape (B, C, N, N, N) where B is the batch size, C is the number of channels, and N is the spatial dimension.
|
||||
- Embeddding: Applies a linear transformation to the input tensor to increase the number of channels.
|
||||
- FNO Blocks: Applies a series of FNO blocks to the embedded input tensor.
|
||||
- Output: Applies a final linear transformation to reduce the number of channels to the desired output channels.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
N: int = 128,
|
||||
in_channels: int = 2,
|
||||
out_channels: int = 1,
|
||||
latent_channels: int = 32,
|
||||
style_parameters: int = 2,
|
||||
device: torch.device = torch.device('cpu'),
|
||||
filtering: tuple|list|None = None,
|
||||
num_blocks: int = 4,
|
||||
dropout: float = 0.05):
|
||||
|
||||
"""
|
||||
Initializes the FNO3D model with the given parameters.
|
||||
"""
|
||||
|
||||
super(FNO3D, self).__init__(N=N, in_channels=in_channels, out_channels=out_channels, style_parameters=style_parameters, device=device)
|
||||
self.latent_channels = latent_channels
|
||||
self.filtering = filtering if filtering is not None else (self.N, self.N, self.N // 2 + 1) # We keep all modes by default
|
||||
self.num_blocks = num_blocks
|
||||
|
||||
self.embedding = nn.Conv3d(in_channels, latent_channels, kernel_size=1, padding=0)
|
||||
self.fno_blocks = nn.ModuleList([
|
||||
FNOBlock3D(latent_channels, latent_channels, filtering=self.filtering, style_dim=style_parameters, dropout=dropout)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
self.final_conv = nn.Conv3d(latent_channels, out_channels, kernel_size=1, padding=0)
|
||||
|
||||
|
||||
def forward(self, x, style=None):
|
||||
"""
|
||||
Forward pass of the FNO3D model.
|
||||
Parameters:
|
||||
- x: Input tensor of shape (B, C, N, N, N) where B is the batch size, C is the number of channels.
|
||||
- style: Style parameters tensor of shape (B, S) where S is the number of style parameters.
|
||||
Returns:
|
||||
- Output tensor of shape (B, C_out, N, N, N) where C_out is the number of output channels.
|
||||
"""
|
||||
x = self.embedding(x)
|
||||
for block in self.fno_blocks:
|
||||
x = block(x, style)
|
||||
x = self.final_conv(x)
|
||||
return x
|
Loading…
Add table
Add a link
Reference in a new issue