Compare commits

...

2 commits

SHA1 Message Date
24c2d546db dataset improvement 2025-06-05 17:30:38 +02:00
26af105195 ChatGPT models 2025-06-05 17:30:26 +02:00
5 changed files with 204 additions and 5 deletions

View file

@@ -7,6 +7,7 @@ from glob import glob
 import re

 def read_cosmo_and_time_file(cosmo_and_time_file):
     with open(cosmo_and_time_file, 'r') as f:
         lines = f.readlines()
@@ -154,6 +155,10 @@ class GravPotDataset(Dataset):
     def __getitem__(self, idx):
         from pysbmy.field import read_field_chunk_3D_periodic
+        from io import BytesIO
+        from sbmy_control.low_level import stdout_redirector, stderr_redirector
+        f = BytesIO()
         ID, t, ox, oy, oz = self.samples[idx]

         # Filepaths
@@ -165,11 +170,12 @@ class GravPotDataset(Dataset):
         style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')

         # Read 3D chunks
-        input_arrays = [
-            read_field_chunk_3D_periodic(file, self.N, self.N, self.N, ox, oy, oz, name=varname).array
-            for file, varname in zip(input_paths, self.initial_conditions_variables)
-        ]
-        target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array
+        with stdout_redirector(f):
+            input_arrays = [
+                read_field_chunk_3D_periodic(file, self.N, self.N, self.N, ox, oy, oz, name=varname).array
+                for file, varname in zip(input_paths, self.initial_conditions_variables)
+            ]
+            target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array

         # Stack the input arrays
         input_tensor = np.stack(input_arrays, axis=0)
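
The change above silences the chunk reader by routing its console output into an in-memory buffer. As a rough illustration of the mechanism, the standard library's contextlib.redirect_stdout captures Python-level prints the same way (a minimal sketch, not sbmy_control's actual implementation; since f is a BytesIO, the real stdout_redirector presumably works at the OS file-descriptor level so that output from compiled extensions is captured too):

import io
from contextlib import redirect_stdout

buf = io.StringIO()
with redirect_stdout(buf):
    print("chatty library output")  # captured in buf instead of reaching the console
print("captured:", buf.getvalue().strip())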

View file

@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_class_models import BaseModel

class FiLM(nn.Module):
    """Feature-wise linear modulation: scale and shift feature maps from a style vector."""

    def __init__(self, num_features, style_dim):
        super(FiLM, self).__init__()
        self.gamma = nn.Linear(style_dim, num_features)
        self.beta = nn.Linear(style_dim, num_features)

    def forward(self, x, style):
        # (B, num_features) -> (B, num_features, 1, 1, 1) to broadcast over (B, C, D, H, W)
        gamma = self.gamma(style).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        beta = self.beta(style).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        return gamma * x + beta

class UNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, style_dim=None):
        super(UNetBlock, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        # Separate norm per conv: the two layers should not share BatchNorm statistics
        self.norm1 = nn.BatchNorm3d(out_channels)
        self.norm2 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.film = FiLM(out_channels, style_dim) if style_dim else None

    def forward(self, x, style=None):
        x = self.relu(self.norm1(self.conv1(x)))
        x = self.relu(self.norm2(self.conv2(x)))
        if self.film:
            x = self.film(x, style)
        return x

class UNet3D(BaseModel):
    def __init__(self, N: int = 128,
                 in_channels: int = 2,
                 out_channels: int = 1,
                 style_dim: int = 2,
                 device: torch.device = torch.device('cpu')):
        """
        3D U-Net model with optional FiLM layers for style conditioning.

        Parameters:
        - N: Size of the input data: data will have the shape (B, C, N, N, N) with B the batch size, C the number of channels. N must be divisible by 4 because of the two pooling stages.
        - in_channels: Number of input channels (default is 2).
        - out_channels: Number of output channels (default is 1).
        - style_dim: Dimension of the style vector (default is 2).
        - device: Device to load the model onto (default is CPU).

        This model implements a 3D U-Net architecture with downsampling and upsampling blocks.
        The model uses convolutional layers with ReLU activations and batch normalization.
        The FiLM layers condition the feature maps on the style parameters.
        """
        super().__init__(N=N,
                         in_channels=in_channels,
                         out_channels=out_channels,
                         style_parameters=style_dim,
                         device=device)

        # Encoder
        self.enc1 = UNetBlock(in_channels, 32, style_dim)
        self.pool1 = nn.MaxPool3d(2)
        self.enc2 = UNetBlock(32, 64, style_dim)
        self.pool2 = nn.MaxPool3d(2)

        self.bottleneck = UNetBlock(64, 128, style_dim)

        # Decoder (input channels are doubled by the skip connections)
        self.up2 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.dec2 = UNetBlock(128, 64)
        self.up1 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
        self.dec1 = UNetBlock(64, 32)
        self.final = nn.Conv3d(32, out_channels, kernel_size=1)

    def forward(self, x, style):
        e1 = self.enc1(x, style)
        e2 = self.enc2(self.pool1(e1), style)
        b = self.bottleneck(self.pool2(e2), style)
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.final(d1)
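
A quick smoke test for the shapes (illustrative usage, not part of the commit; assumes the class above is in scope):

import torch

model = UNet3D(N=32, in_channels=2, out_channels=1, style_dim=2)
x = torch.randn(1, 2, 32, 32, 32)  # (B, C, N, N, N)
style = torch.randn(1, 2)          # (B, style_dim)
with torch.no_grad():
    y = model(x, style)
print(y.shape)  # torch.Size([1, 1, 32, 32, 32])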

View file

View file

@@ -0,0 +1,38 @@
import torch


class BaseModel(torch.nn.Module):
    def __init__(self,
                 N: int = 128,
                 in_channels: int = 2,
                 out_channels: int = 1,
                 style_parameters: int = 2,
                 device: torch.device = torch.device('cpu')):
        """
        Base class for all models.

        Parameters:
        - N: Size of the input data: data will have the shape (B, C, N, N, N) with B the batch size, C the number of channels.
        - in_channels: Number of input channels (default is 2).
        - out_channels: Number of output channels (default is 1).
        - style_parameters: Number of style parameters (default is 2).
        - device: Device to load the model onto (default is CPU).
        """
        super().__init__()
        self.N = N
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.style_parameters = style_parameters
        self.device = device
        # Note: submodules that a subclass creates after calling this __init__
        # are not moved automatically and must be moved to the device themselves.
        self.to(self.device)

    def forward(self, x, style):
        """
        Forward pass of the model. Must be implemented in subclasses.

        Parameters:
        - x: Input tensor of shape (B, C, N, N, N) where B is the batch size, C is the number of channels.
        - style: Style parameters tensor of shape (B, S) where S is the number of style parameters.

        Returns:
        - Output tensor of shape (B, C_out, N, N, N) where C_out is the number of output channels.
        """
        raise NotImplementedError("Forward method must be implemented in subclasses.")
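
To make the contract concrete, a minimal subclass could look like this (illustrative only, not part of the commit):

import torch

class IdentityModel(BaseModel):
    """Smallest possible BaseModel: a 1x1x1 convolution that ignores the style vector."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.conv = torch.nn.Conv3d(self.in_channels, self.out_channels, kernel_size=1)
        self.to(self.device)  # re-apply: conv was created after BaseModel's __init__ moved the module

    def forward(self, x, style):
        return self.conv(x)  # (B, C_in, N, N, N) -> (B, C_out, N, N, N)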

View file

@@ -0,0 +1,79 @@
import torch
from torch import nn

from e3nn import o3
from e3nn.o3 import Irreps
from e3nn.nn import Gate
from e3nn.o3 import FullyConnectedTensorProduct
from e3nn.nn.models.v2106.gate_points_networks import SimpleNetwork

from .base_class_models import BaseModel

# Scalar field = trivial irreducible representation (0e)
class E3nnNet(BaseModel):
    def __init__(self,
                 N: int = 128,
                 in_channels: int = 2,
                 out_channels: int = 1,
                 style_parameters: int = 2,
                 hidden_channels: int = 16,
                 num_layers: int = 4,
                 radius: float = 2.5,
                 num_neighbors: int = 12,
                 device: torch.device = torch.device('cpu')):
        """
        E3nn-based model.

        Parameters:
        - N: Size of the input data: data will have the shape (B, C, N, N, N) with B the batch size, C the number of channels.
        - in_channels: Number of input channels (default is 2).
        - out_channels: Number of output channels (default is 1).
        - style_parameters: Number of style parameters (default is 2).
        - hidden_channels: Number of hidden channels (default is 16).
        - num_layers: Number of hidden layers (default is 4).
        - radius: Radius for the neighborhood search (default is 2.5).
        - num_neighbors: Number of neighbors to consider (default is 12).
        - device: Device to load the model onto (default is CPU).

        This model uses e3nn to handle the spherical harmonics and irreducible representations.
        The input is expected to be a scalar field, represented as the trivial irreducible representation (0e).
        The model consists of a simple network with fully connected layers and a gate mechanism.
        The input is flattened to a 2D tensor of per-voxel features, with the style parameters concatenated to each voxel's channel values; the voxel positions are passed to the network separately.
        The output is reshaped back to the original 3D shape with a single output channel.
        """
        super().__init__(N=N,
                         in_channels=in_channels,
                         out_channels=out_channels,
                         style_parameters=style_parameters,
                         device=device)

        irreps_input = Irreps(f"{in_channels + style_parameters}x0e")  # input channels + style parameters
        irreps_hidden = Irreps(f"{hidden_channels}x0e")                # hidden layers
        irreps_output = Irreps(f"{out_channels}x0e")                   # scalar output

        self.model = SimpleNetwork(
            irreps_in=irreps_input,
            irreps_out=irreps_output,
            layers=[irreps_hidden] * num_layers,
            radius=radius,
            num_neighbors=num_neighbors,
        )
    def forward(self, x, style):
        # Flatten the voxel grid: (B, C, N, N, N) -> (B*N^3, C)
        B, C, N, _, _ = x.shape
        x = x.permute(0, 2, 3, 4, 1).reshape(-1, C)

        # One 3D position in [-1, 1]^3 per voxel, repeated for each batch element: (B*N^3, 3)
        pos = torch.stack(torch.meshgrid(
            torch.linspace(-1, 1, N),
            torch.linspace(-1, 1, N),
            torch.linspace(-1, 1, N),
            indexing='ij'
        ), dim=-1).reshape(-1, 3).repeat(B, 1, 1).reshape(-1, 3).to(x.device)

        # Broadcast the style vector to every voxel and concatenate it to the features
        style = style.unsqueeze(1).expand(-1, N**3, -1).reshape(-1, style.shape[-1])
        x = torch.cat([x, style], dim=-1)  # simple concatenation of the style parameters

        out = self.model(pos, x)

        # Back to a 3D scalar field: (B*N^3, 1) -> (B, 1, N, N, N)
        out = out.reshape(B, N, N, N, 1).permute(0, 4, 1, 2, 3)
        return out
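
The only nontrivial bookkeeping in this forward pass is the voxel flattening; its round trip can be checked without e3nn installed (standalone sketch, small sizes chosen for speed):

import torch

B, C, N = 2, 2, 4
x = torch.randn(B, C, N, N, N)
flat = x.permute(0, 2, 3, 4, 1).reshape(-1, C)  # (B*N^3, C) per-voxel features
pos = torch.stack(torch.meshgrid(
    torch.linspace(-1, 1, N),
    torch.linspace(-1, 1, N),
    torch.linspace(-1, 1, N),
    indexing='ij'
), dim=-1).reshape(-1, 3).repeat(B, 1, 1).reshape(-1, 3)  # (B*N^3, 3) positions
restored = flat.reshape(B, N, N, N, C).permute(0, 4, 1, 2, 3)
assert torch.equal(restored, x)  # the flattening is exactly invertible
print(flat.shape, pos.shape)  # torch.Size([128, 2]) torch.Size([128, 3])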