import torch
from e3nn.o3 import Irreps
from e3nn.nn.models.v2106.gate_points_networks import SimpleNetwork

from .base_class_models import BaseModel


# Scalar field = trivial irreducible representation (0e)
class E3nnNet(BaseModel):
    def __init__(self,
                 N: int = 128,
                 in_channels: int = 2,
                 out_channels: int = 1,
                 style_parameters: int = 2,
                 hidden_channels: int = 16,
                 num_layers: int = 4,
                 radius: float = 2.5,
                 num_neighbors: int = 12,
                 device: torch.device = torch.device('cpu')):
        """
        E3nn-based model.

        Parameters:
        - N: Size of the input data: data will have the shape (B, C, N, N, N),
          with B the batch size and C the number of channels.
        - in_channels: Number of input channels (default is 2).
        - out_channels: Number of output channels (default is 1).
        - style_parameters: Number of style parameters (default is 2).
        - hidden_channels: Number of hidden channels (default is 16).
        - num_layers: Number of hidden (message-passing) layers (default is 4).
        - radius: Radius for the neighborhood search (default is 2.5).
        - num_neighbors: Number of neighbors to consider (default is 12).
        - device: Device to load the model onto (default is CPU).

        This model uses e3nn to handle the spherical harmonics and irreducible
        representations. The input is expected to be a scalar field, represented
        as the trivial irreducible representation (0e). The model consists of a
        simple equivariant message-passing network with gated nonlinearities and
        fully connected radial networks. The input is reshaped to a 2D tensor in
        which each voxel's features are concatenated with the style parameters
        and paired with the voxel's 3D position. The output is reshaped back to
        the original 3D shape with a single output channel.
        """
        super().__init__(N=N, in_channels=in_channels, out_channels=out_channels,
                         style_parameters=style_parameters, device=device)

        irreps_input = Irreps(f"{in_channels + style_parameters}x0e")  # input channels + style parameters
        irreps_output = Irreps(f"{out_channels}x0e")                   # scalar output

        self.model = SimpleNetwork(
            irreps_in=irreps_input,
            irreps_out=irreps_output,
            max_radius=radius,
            num_neighbors=num_neighbors,
            num_nodes=N**3,        # only used to normalise pooled outputs (pooling disabled below)
            mul=hidden_channels,   # multiplicity (width) of the internal hidden irreps
            layers=num_layers,     # number of message-passing layers
            pool_nodes=False,      # keep one output per voxel instead of pooling per sample
        )

    def forward(self, x, style):
        # Reshape x: (B, C, N, N, N) -> (B*N^3, C)
        B, C, N, _, _ = x.shape
        x = x.permute(0, 2, 3, 4, 1).reshape(-1, C)

        # Voxel positions on a regular grid in [-1, 1]^3, repeated for each batch element
        grid = torch.linspace(-1, 1, N, device=x.device, dtype=x.dtype)
        pos = torch.stack(torch.meshgrid(grid, grid, grid, indexing='ij'), dim=-1)
        pos = pos.reshape(-1, 3).repeat(B, 1)

        # Graph membership: voxels of the same sample form one graph
        batch = torch.arange(B, device=x.device).repeat_interleave(N**3)

        # Expand style to each voxel and concatenate with the scalar features
        style = style.unsqueeze(1).expand(-1, N**3, -1).reshape(-1, style.shape[-1])
        x = torch.cat([x, style], dim=-1)  # simple concat of style params

        out = self.model({'pos': pos, 'x': x, 'batch': batch})
        out = out.reshape(B, N, N, N, -1).permute(0, 4, 1, 2, 3)
        return out
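

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative, not part of the model). It assumes
# the file is executed as a module (e.g. `python -m <package>.<this_module>`)
# so the relative import of BaseModel resolves, and that torch_cluster is
# installed, since SimpleNetwork builds its radius graph with it. A tiny grid
# is used on purpose: the default N=128 would create millions of voxel nodes
# per sample.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    N, B = 8, 2  # small grid and batch for a quick forward-pass check
    model = E3nnNet(N=N, in_channels=2, out_channels=1, style_parameters=2,
                    hidden_channels=8, num_layers=2, radius=0.5, num_neighbors=6)

    x = torch.randn(B, 2, N, N, N)   # two scalar input fields per sample
    style = torch.randn(B, 2)        # two style parameters per sample

    with torch.no_grad():
        y = model(x, style)

    print(y.shape)  # expected: torch.Size([2, 1, 8, 8, 8])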