Compare commits
2 commits
2b9830211e
...
24c2d546db
Author | SHA1 | Date | |
---|---|---|---|
24c2d546db | |||
26af105195 |
5 changed files with 204 additions and 5 deletions
|
@ -7,6 +7,7 @@ from glob import glob
|
|||
import re
|
||||
|
||||
|
||||
|
||||
def read_cosmo_and_time_file(cosmo_and_time_file):
|
||||
with open(cosmo_and_time_file, 'r') as f:
|
||||
lines = f.readlines()
|
||||
|
@ -154,6 +155,10 @@ class GravPotDataset(Dataset):
|
|||
|
||||
def __getitem__(self, idx):
|
||||
from pysbmy.field import read_field_chunk_3D_periodic
|
||||
from io import BytesIO
|
||||
from sbmy_control.low_level import stdout_redirector, stderr_redirector
|
||||
f = BytesIO()
|
||||
|
||||
ID, t, ox, oy, oz = self.samples[idx]
|
||||
|
||||
# Filepaths
|
||||
|
@ -165,11 +170,12 @@ class GravPotDataset(Dataset):
|
|||
style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
|
||||
|
||||
# Read 3D chunks
|
||||
input_arrays = [
|
||||
read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array
|
||||
for file, varname in zip(input_paths, self.initial_conditions_variables)
|
||||
]
|
||||
target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array
|
||||
with stdout_redirector(f):
|
||||
input_arrays = [
|
||||
read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array
|
||||
for file, varname in zip(input_paths, self.initial_conditions_variables)
|
||||
]
|
||||
target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array
|
||||
|
||||
# Stack the input arrays
|
||||
input_tensor = np.stack(input_arrays, axis=0)
|
||||
|
|
76
sCOCA_ML/models/UNet_models.py
Normal file
76
sCOCA_ML/models/UNet_models.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .base_class_models import BaseModel
|
||||
|
||||
class FiLM(nn.Module):
|
||||
def __init__(self, num_features, style_dim):
|
||||
super(FiLM, self).__init__()
|
||||
self.gamma = nn.Linear(style_dim, num_features)
|
||||
self.beta = nn.Linear(style_dim, num_features)
|
||||
|
||||
def forward(self, x, style):
|
||||
gamma = self.gamma(style).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
beta = self.beta(style).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
return gamma * x + beta
|
||||
|
||||
class UNetBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, style_dim=None):
|
||||
super(UNetBlock, self).__init__()
|
||||
self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
|
||||
self.norm = nn.BatchNorm3d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.film = FiLM(out_channels, style_dim) if style_dim else None
|
||||
|
||||
def forward(self, x, style=None):
|
||||
x = self.relu(self.norm(self.conv1(x)))
|
||||
x = self.relu(self.norm(self.conv2(x)))
|
||||
if self.film:
|
||||
x = self.film(x, style)
|
||||
return x
|
||||
|
||||
class UNet3D(BaseModel):
|
||||
def __init__(self, N: int = 128,
|
||||
in_channels: int = 2,
|
||||
out_channels: int = 1,
|
||||
style_dim: int = 2,
|
||||
device: torch.device = torch.device('cpu')):
|
||||
"""
|
||||
3D U-Net model with optional FiLM layers for style conditioning.
|
||||
Parameters:
|
||||
- N: Size of the input data: data will have the shape (B, C, N, N, N) with B the batch size, C the number of channels.
|
||||
- in_channels: Number of input channels (default is 2).
|
||||
- out_channels: Number of output channels (default is 1).
|
||||
- style_dim: Dimension of the style vector (default is 2).
|
||||
- device: Device to load the model onto (default is CPU).
|
||||
This model implements a 3D U-Net architecture with downsampling and upsampling blocks.
|
||||
The model uses convolutional layers with ReLU activations and batch normalization.
|
||||
The FiLM layers are used to condition the feature maps on style parameters.
|
||||
"""
|
||||
|
||||
super().init(N=N,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
style_parameters=style_dim,
|
||||
device=device)
|
||||
|
||||
self.enc1 = UNetBlock(in_channels, 32, style_dim)
|
||||
self.pool1 = nn.MaxPool3d(2)
|
||||
self.enc2 = UNetBlock(32, 64, style_dim)
|
||||
self.pool2 = nn.MaxPool3d(2)
|
||||
self.bottleneck = UNetBlock(64, 128, style_dim)
|
||||
|
||||
self.up2 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
|
||||
self.dec2 = UNetBlock(128, 64)
|
||||
self.up1 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
|
||||
self.dec1 = UNetBlock(64, 32)
|
||||
self.final = nn.Conv3d(32, out_channels, kernel_size=1)
|
||||
|
||||
def forward(self, x, style):
|
||||
e1 = self.enc1(x, style)
|
||||
e2 = self.enc2(self.pool1(e1), style)
|
||||
b = self.bottleneck(self.pool2(e2), style)
|
||||
d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
|
||||
d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
|
||||
return self.final(d1)
|
0
sCOCA_ML/models/__init__.py
Normal file
0
sCOCA_ML/models/__init__.py
Normal file
38
sCOCA_ML/models/base_class_models.py
Normal file
38
sCOCA_ML/models/base_class_models.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
import torch
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self,
|
||||
N:int=128,
|
||||
in_channels:int=2,
|
||||
out_channels:int=1,
|
||||
style_parameters:int=2,
|
||||
device: torch.device = torch.device('cpu')):
|
||||
"""
|
||||
Base class for all models.
|
||||
Parameters:
|
||||
- N: Size of the input data: data will have the shape (B, C, N,N,N) with B the batch size, C the number of channels.
|
||||
- in_channels: Number of input channels (default is 2).
|
||||
- out_channels: Number of output channels (default is 1).
|
||||
- style_parameters: Number of style parameters (default is 2).
|
||||
- device: Device to load the model onto (default is CPU).
|
||||
"""
|
||||
super().__init__()
|
||||
self.N = N
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.style_parameters = style_parameters
|
||||
self.device = device
|
||||
self.to(self.device)
|
||||
|
||||
def forward(self, x, style):
|
||||
"""
|
||||
Forward pass of the model.
|
||||
Should be implemented in subclasses.
|
||||
Parameters:
|
||||
- x: Input tensor of shape (B, C, N, N, N) where B is the batch size, C is the number of channels.
|
||||
- style: Style parameters tensor of shape (B, S) where S is the number of style parameters.
|
||||
Returns:
|
||||
- Output tensor of shape (B, C_out, N, N, N) where C_out is the number of output channels.
|
||||
"""
|
||||
raise NotImplementedError("Forward method must be implemented in subclasses.")
|
79
sCOCA_ML/models/e3nn_models.py
Normal file
79
sCOCA_ML/models/e3nn_models.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
import torch
|
||||
from .base_class_models import BaseModel
|
||||
from torch import nn
|
||||
from e3nn import o3
|
||||
from e3nn.o3 import Irreps
|
||||
from e3nn.nn import Gate
|
||||
from e3nn.o3 import FullyConnectedTensorProduct
|
||||
from e3nn.nn.models.v2106.gate_points_networks import SimpleNetwork
|
||||
|
||||
# Scalar field = trivial irreducible representation
|
||||
|
||||
class E3nnNet(BaseModel):
|
||||
|
||||
def __init__(self,
|
||||
N: int = 128,
|
||||
in_channels: int = 2,
|
||||
out_channels: int = 1,
|
||||
style_parameters: int = 2,
|
||||
hidden_channels: int = 16,
|
||||
num_layers: int = 4,
|
||||
radius: float = 2.5,
|
||||
num_neighbors: int = 12,
|
||||
device: torch.device = torch.device('cpu')):
|
||||
"""
|
||||
E3nn-based model.
|
||||
Parameters:
|
||||
- N: Size of the input data: data will have the shape (B, C, N, N, N) with B the batch size, C the number of channels.
|
||||
- in_channels: Number of input channels (default is 2).
|
||||
- out_channels: Number of output channels (default is 1).
|
||||
- style_parameters: Number of style parameters (default is 2).
|
||||
- hidden_channels: Number of hidden channels (default is 16).
|
||||
- num_layers: Number of hidden layers (default is 4).
|
||||
- radius: Radius for the neighborhood search (default is 2.5).
|
||||
- num_neighbors: Number of neighbors to consider (default is 12).
|
||||
- device: Device to load the model onto (default is CPU).
|
||||
This model uses e3nn to handle the spherical harmonics and irreducible representations.
|
||||
The input is expected to be a scalar field, which is represented as a trivial irreducible representation (0e).
|
||||
The model consists of a simple network with fully connected layers and a gate mechanism.
|
||||
The input is reshaped to a 2D tensor where each voxel's position is concatenated with the style parameters.
|
||||
The output is reshaped back to the original 3D shape with a single output channel.
|
||||
"""
|
||||
|
||||
super().__init__(N=N,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
style_parameters=style_parameters,
|
||||
device=device)
|
||||
|
||||
irreps_input = Irreps(f"{in_channels+style_parameters}x0e") # input channels + style parameters
|
||||
irreps_hidden = Irreps(f"{hidden_channels}x0e") # hidden layers
|
||||
irreps_output = Irreps(f"{out_channels}x0e") # scalar output
|
||||
|
||||
self.model = SimpleNetwork(
|
||||
irreps_in=irreps_input,
|
||||
irreps_out=irreps_output,
|
||||
layers=[irreps_hidden] * num_layers,
|
||||
radius=radius,
|
||||
num_neighbors=num_neighbors,
|
||||
)
|
||||
|
||||
|
||||
def forward(self, x, style):
|
||||
# Reshape x: (B, C, N, N, N) -> (B*N^3, C)
|
||||
B, C, N, _, _ = x.shape
|
||||
x = x.permute(0, 2, 3, 4, 1).reshape(-1, C)
|
||||
pos = torch.stack(torch.meshgrid(
|
||||
torch.linspace(-1, 1, N),
|
||||
torch.linspace(-1, 1, N),
|
||||
torch.linspace(-1, 1, N),
|
||||
indexing='ij'
|
||||
), dim=-1).reshape(-1, 3).repeat(B, 1, 1).reshape(-1, 3).to(x.device)
|
||||
|
||||
# Expand style to each voxel
|
||||
style = style.unsqueeze(1).expand(-1, N**3, -1).reshape(-1, style.shape[-1])
|
||||
x = torch.cat([x, style], dim=-1) # Simple concat of style params
|
||||
|
||||
out = self.model(pos, x)
|
||||
out = out.reshape(B, N, N, N, 1).permute(0, 4, 1, 2, 3)
|
||||
return out
|
Loading…
Add table
Add a link
Reference in a new issue