ML_GravPotBCs/sCOCA_ML/models/e3nn_models.py
2025-06-05 17:30:26 +02:00

79 lines
3.4 KiB
Python

import torch
from .base_class_models import BaseModel
from torch import nn
from e3nn import o3
from e3nn.o3 import Irreps
from e3nn.nn import Gate
from e3nn.o3 import FullyConnectedTensorProduct
from e3nn.nn.models.v2106.gate_points_networks import SimpleNetwork
# Scalar field = trivial irreducible representation
class E3nnNet(BaseModel):
    """Equivariant message-passing network applied to a regular 3D voxel grid.

    Every voxel of each input cube becomes a graph node whose features are the
    voxel's scalar channels concatenated with that sample's style parameters.
    Nodes are connected to neighbours within ``radius`` and processed by e3nn's
    ``SimpleNetwork`` (``e3nn.nn.models.v2106``); the per-node outputs are
    reshaped back into a (B, out_channels, N, N, N) tensor.
    """

    def __init__(self,
                 N: int = 128,
                 in_channels: int = 2,
                 out_channels: int = 1,
                 style_parameters: int = 2,
                 hidden_channels: int = 16,
                 num_layers: int = 4,
                 radius: float = 2.5,
                 num_neighbors: int = 12,
                 device: torch.device = torch.device('cpu'),
                 lmax: int = 0):
        """
        E3nn-based model.
        Parameters:
        - N: Size of the input data: data will have the shape (B, C, N, N, N) with B the batch size, C the number of channels.
        - in_channels: Number of input channels (default is 2).
        - out_channels: Number of output channels (default is 1).
        - style_parameters: Number of style parameters (default is 2).
        - hidden_channels: Multiplicity of each hidden irrep (default is 16).
        - num_layers: Number of hidden message-passing layers (default is 4).
        - radius: Cut-off radius of the neighbourhood graph, in the units of the voxel
          coordinates, which span [-1, 1] along each axis (default is 2.5).
          NOTE(review): with coordinates in [-1, 1] the default radius covers most of
          the cube; confirm whether it was meant in units of grid spacings.
        - num_neighbors: Typical number of neighbours per node, used by e3nn to
          normalise aggregated messages (default is 12).
        - device: Device to load the model onto (default is CPU).
        - lmax: Maximum angular order of the hidden irreps and edge spherical harmonics
          (default is 0, i.e. scalar-only hidden features).
        This model uses e3nn to handle the spherical harmonics and irreducible representations.
        The input is expected to be a scalar field, which is represented as a trivial irreducible representation (0e).
        Each voxel's channel values are concatenated with the style parameters to form the node features;
        the voxel coordinates are passed separately as node positions.
        The output is reshaped back to the original 3D shape with out_channels output channels.
        """
        super().__init__(N=N,
                         in_channels=in_channels,
                         out_channels=out_channels,
                         style_parameters=style_parameters,
                         device=device)
        # Node features are all scalars (trivial irrep 0e): input channels plus
        # style parameters in, ``out_channels`` scalars out.
        irreps_input = Irreps(f"{in_channels + style_parameters}x0e")
        irreps_output = Irreps(f"{out_channels}x0e")
        # SimpleNetwork does not take an explicit list of hidden irreps: it builds
        # them from ``mul`` (multiplicity) and ``lmax``, and ``layers`` is an int.
        self.model = SimpleNetwork(
            irreps_in=irreps_input,
            irreps_out=irreps_output,
            max_radius=radius,
            num_neighbors=num_neighbors,
            num_nodes=N ** 3,
            mul=hidden_channels,
            layers=num_layers,
            lmax=lmax,
            # Per-voxel predictions are required, not one pooled vector per sample.
            pool_nodes=False,
        )
        # The submodule is created after super().__init__ returned, so move it
        # explicitly; BaseModel's own handling of ``device`` is not visible here.
        self.model.to(device)

    def forward(self, x: torch.Tensor, style: torch.Tensor) -> torch.Tensor:
        """Run the network on every voxel of a batch of cubes.

        Parameters:
        - x: Tensor of shape (B, C, N, N, N).
        - style: Tensor of shape (B, style_parameters); each row is broadcast to
          every voxel of the corresponding sample.

        Returns:
        - Tensor of shape (B, out_channels, N, N, N).
        """
        B, C, N, _, _ = x.shape
        num_voxels = N ** 3
        style_dim = style.shape[-1]

        # (B, C, N, N, N) -> (B*N^3, C): one feature row per voxel (graph node).
        feats = x.permute(0, 2, 3, 4, 1).reshape(-1, C)

        # Voxel coordinates on [-1, 1]^3, in the same (i, j, k) order as ``feats``,
        # built directly on x's device/dtype and tiled once per sample.
        axis = torch.linspace(-1, 1, N, device=x.device, dtype=x.dtype)
        grid = torch.stack(torch.meshgrid(axis, axis, axis, indexing='ij'), dim=-1).reshape(-1, 3)
        pos = grid.repeat(B, 1)

        # Graph id of every node. Without it all samples would share one graph and
        # coincident voxels of different samples would become zero-length edges.
        batch = torch.arange(B, device=x.device).repeat_interleave(num_voxels)

        # Broadcast each sample's style vector to all of its voxels.
        style = style.unsqueeze(1).expand(-1, num_voxels, -1).reshape(-1, style_dim)
        node_feats = torch.cat([feats, style], dim=-1)

        out = self.model({"pos": pos, "x": node_feats, "batch": batch})
        # (B*N^3, out_channels) -> (B, out_channels, N, N, N)
        return out.reshape(B, N, N, N, -1).permute(0, 4, 1, 2, 3)