gravpot dataset

This commit is contained in:
Mayeul Aubin 2025-06-05 16:32:50 +02:00
parent b09b866bed
commit 2b9830211e

View file

@ -0,0 +1,202 @@
import os
import random
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from glob import glob
import re
def read_cosmo_and_time_file(cosmo_and_time_file):
with open(cosmo_and_time_file, 'r') as f:
lines = f.readlines()
cosmo_and_time_params = {}
for line in lines:
if line.strip(): # Skip empty lines
key, value = line.split(':')
if key.strip() == 'ID':
cosmo_and_time_params['ID'] = value.strip()
elif key.strip() == 'nforce':
cosmo_and_time_params['nforce'] = int(value.strip())
else:
cosmo_and_time_params[key.strip()] = float(value.strip())
return cosmo_and_time_params
class GravPotDataset(Dataset):
INITIAL_CONDITIONS_DIR = 'initial_conditions'
TARGET_DIR = 'gravitational_potential'
STYLE_DIR = 'cosmo_and_time'
def __init__(self,
root_dir:str,
ids:list|None=None,
N:int=128,
N_full:int=768,
match_str:str='train',
device=torch.device('cpu'),
initial_conditions_variables:tuple|list=['DM_delta', 'DM_phi'],
target_variable:str='gravpot',
style_files:str='cosmo_and_time_parameters',
style_keys:list|None=["D1", "D2"],
max_time:int=100):
"""
Dataset for gravitational potential data.
Parameters:
- root_dir: Directory containing the dataset.
- ids: List of IDs to include in the dataset. If None, will discover IDs.
- N: Size of the chunks to read (N x N x N).
- N_full: Full size of the simulation box (N_full x N_full x N_full).
- device: Device to load tensors onto (default is CPU)."""
self.initial_conditions_variables = initial_conditions_variables
self.target_variable = target_variable
self.style_files = style_files
self.style_keys = style_keys
self.max_time = max_time
self.root_dir = root_dir
self.N = N
self.N_full = N_full
self.device = device
# Compute how many chunks per dimension
self.chunks_per_dim = N_full // N
self.chunks_per_entry = self.chunks_per_dim ** 3
# This will hold (ID, time, offsets)
self.samples = []
if ids is None:
ids = self.discover_ids(match_str=match_str)
self.ids = ids
# Build indexable sample list
self._prepare_samples()
def discover_ids(self, match_str:str='train'):
"""
Discover IDs that contain match_str and have at least one valid time step
with all required files present.
"""
pattern = os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, 'ICs_*')
files = glob(pattern)
valid_ids = []
for file in files:
match = re.search(r'ICs_(.+?)_'+f"{self.initial_conditions_variables[0]}.h5", os.path.basename(file))
if not match:
continue
ID = match.group(1)
if match_str not in ID:
continue
# Check if corresponding initial conditions file exists
for var in self.initial_conditions_variables[1:]:
ic_path = os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
if not os.path.exists(ic_path):
continue
# Check if at least one time has both target and style params
found_valid_time = False
for t in range(self.max_time):
target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_nforce{t}.h5')
style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
if os.path.exists(target_path) and os.path.exists(style_path):
found_valid_time = True
break
if found_valid_time:
valid_ids.append(ID)
return sorted(valid_ids)
def _get_valid_times(self, ID):
"""Returns valid time indices for which all required files exist."""
valid_times = []
for t in range(100): # arbitrary upper limit
target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_nforce{t}.h5')
style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
if all(os.path.exists(p) for p in [target_path, style_path]):
valid_times.append(t)
return valid_times
def _prepare_samples(self):
self.samples.clear()
for ID in self.ids:
times = self._get_valid_times(ID)
if not times:
continue
# Random offset for epoch start
base_offset_x = random.randint(0, self.N_full)
base_offset_y = random.randint(0, self.N_full)
base_offset_z = random.randint(0, self.N_full)
for i in range(self.chunks_per_dim):
for j in range(self.chunks_per_dim):
for k in range(self.chunks_per_dim):
selected_time = random.choice(times)
offset_x = (base_offset_x + i * self.N) % self.N_full
offset_y = (base_offset_y + j * self.N) % self.N_full
offset_z = (base_offset_z + k * self.N) % self.N_full
self.samples.append((ID, selected_time, offset_x, offset_y, offset_z))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
from pysbmy.field import read_field_chunk_3D_periodic
ID, t, ox, oy, oz = self.samples[idx]
# Filepaths
input_paths = [
os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
for var in self.initial_conditions_variables
]
target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_nforce{t}.h5')
style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
# Read 3D chunks
input_arrays = [
read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array
for file, varname in zip(input_paths, self.initial_conditions_variables)
]
target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array
# Stack the input arrays
input_tensor = np.stack(input_arrays, axis=0)
input_tensor = torch.tensor(input_tensor, dtype=torch.float32).to(self.device)
# Target
target_tensor = torch.tensor(target_array, dtype=torch.float32).unsqueeze(0).to(self.device)
# Style parameters
style_params = read_cosmo_and_time_file(style_path)
# Select only the specified style keys
if self.style_keys is not None:
style_params = [style_params[key] for key in self.style_keys if key in style_params]
else:
style_params = list(style_params.values())
# Convert to tensor
style_tensor = torch.tensor(style_params, dtype=torch.float32).to(self.device)
return {
'input': input_tensor,
'target': target_tensor,
'style': style_tensor,
'ID': ID,
'time': t,
'offset': (ox, oy, oz)
}
def on_epoch_end(self):
"""Call this at the end of each epoch to regenerate offset + time choices."""
self._prepare_samples()