import os
import random
import re
from glob import glob

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


def read_cosmo_and_time_file(cosmo_and_time_file):
    """Parse a cosmology/time parameter file into a dictionary.

    Each non-empty line is expected to hold a ``key: value`` pair. The 'ID'
    value is kept as a string, 'nforce' is cast to int, and all other values
    are cast to float.
    """
    with open(cosmo_and_time_file, 'r') as f:
        lines = f.readlines()

    cosmo_and_time_params = {}
    for line in lines:
        if not line.strip():  # Skip empty lines
            continue
        key, value = line.split(':', 1)
        key = key.strip()
        if key == 'ID':
            cosmo_and_time_params['ID'] = value.strip()
        elif key == 'nforce':
            cosmo_and_time_params['nforce'] = int(value.strip())
        else:
            cosmo_and_time_params[key] = float(value.strip())
    return cosmo_and_time_params
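
# Example of the parameter-file layout this parser assumes (illustrative values
# only; the exact keys depend on what the simulation pipeline writes):
#
#     ID: train_0001
#     nforce: 3
#     D1: 0.78
#     D2: -0.21
#
# read_cosmo_and_time_file(path) would then return
#     {'ID': 'train_0001', 'nforce': 3, 'D1': 0.78, 'D2': -0.21}

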
class GravPotDataset(Dataset):

    INITIAL_CONDITIONS_DIR = 'initial_conditions'
    TARGET_DIR = 'gravitational_potential'
    STYLE_DIR = 'cosmo_and_time'

    def __init__(self,
                 root_dir: str,
                 ids: list | None = None,
                 N: int = 128,
                 N_full: int = 768,
                 match_str: str = 'train',
                 max_len: int | None = None,
                 device='cpu',
                 initial_conditions_variables: tuple | list = ('DM_delta', 'DM_phi'),
                 target_variable: str = 'gravpot',
                 style_files: str = 'cosmo_and_time_parameters',
                 style_keys: tuple | list | None = ('D1', 'D2'),
                 max_time: int = 100):
        """
        Dataset for gravitational potential data.

        Parameters:
        - root_dir: Directory containing the dataset.
        - ids: List of IDs to include in the dataset. If None, IDs are discovered from root_dir.
        - N: Size of the chunks to read (N x N x N).
        - N_full: Full size of the simulation box (N_full x N_full x N_full).
        - match_str: Substring an ID must contain to be discovered (e.g. 'train').
        - max_len: If set, randomly subsample the chunk list to at most this many samples.
        - device: Device to load tensors onto (default is CPU).
        - initial_conditions_variables: Field names read from the initial-conditions files.
        - target_variable: Field name of the target (gravitational potential).
        - style_files: Base name of the cosmology/time parameter files.
        - style_keys: Keys extracted from the parameter files as style inputs; None keeps all values.
        - max_time: Upper bound (exclusive) on the nforce/time index searched on disk.
        """
        super().__init__()

        self.initial_conditions_variables = initial_conditions_variables
        self.target_variable = target_variable
        self.match_str = match_str
        self.style_files = style_files
        self.style_keys = style_keys
        self.max_time = max_time
        self.max_len = max_len

        self.root_dir = root_dir
        self.N = N
        self.N_full = N_full
        self.device = device

        # Compute how many chunks per dimension
        self.chunks_per_dim = N_full // N
        self.chunks_per_entry = self.chunks_per_dim ** 3

        # This will hold (ID, time, offsets)
        self.samples = []

        if ids is None:
            ids = self.discover_ids(match_str=match_str)
        self.ids = ids

        # Build indexable sample list
        self._prepare_samples()
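
    # With the defaults N_full=768 and N=128, chunks_per_dim = 768 // 128 = 6,
    # so a single simulation box yields 6**3 = 216 chunk samples per epoch
    # (before any max_len subsampling).
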
    def discover_ids(self, match_str: str = 'train'):
        """
        Discover IDs that contain match_str and have at least one valid time step
        with all required files present.
        """
        pattern = os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, 'ICs_*')
        files = glob(pattern)

        valid_ids = []
        for file in files:
            match = re.search(rf'ICs_(.+?)_{re.escape(self.initial_conditions_variables[0])}\.h5', os.path.basename(file))
            if not match:
                continue

            ID = match.group(1)
            if match_str not in ID:
                continue

            # Skip the ID if any of the remaining initial-conditions files is missing
            # (a bare `continue` here would only skip the inner loop, not the ID).
            missing_ic = False
            for var in self.initial_conditions_variables[1:]:
                ic_path = os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
                if not os.path.exists(ic_path):
                    missing_ic = True
                    break
            if missing_ic:
                continue

            # Check if at least one time has both target and style params
            found_valid_time = False
            for t in range(self.max_time):
                target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_nforce{t}.h5')
                style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
                if os.path.exists(target_path) and os.path.exists(style_path):
                    found_valid_time = True
                    break

            if found_valid_time:
                valid_ids.append(ID)

        return sorted(valid_ids)

    def _get_valid_times(self, ID):
        """Return the time indices for which all required files exist."""
        valid_times = []
        for t in range(self.max_time):
            target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_nforce{t}.h5')
            style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
            if all(os.path.exists(p) for p in (target_path, style_path)):
                valid_times.append(t)
        return valid_times

    def _prepare_samples(self):
        self.samples.clear()
        for ID in self.ids:
            times = self._get_valid_times(ID)
            if not times:
                continue

            # Random base offset, redrawn at the start of each epoch
            base_offset_x = random.randint(0, self.N_full - 1)
            base_offset_y = random.randint(0, self.N_full - 1)
            base_offset_z = random.randint(0, self.N_full - 1)

            for i in range(self.chunks_per_dim):
                for j in range(self.chunks_per_dim):
                    for k in range(self.chunks_per_dim):
                        selected_time = random.choice(times)
                        offset_x = (base_offset_x + i * self.N) % self.N_full
                        offset_y = (base_offset_y + j * self.N) % self.N_full
                        offset_z = (base_offset_z + k * self.N) % self.N_full
                        self.samples.append((ID, selected_time, offset_x, offset_y, offset_z))

        if self.max_len is not None and len(self.samples) > self.max_len:
            self.samples = random.sample(self.samples, int(self.max_len))

    def __len__(self):
        return len(self.samples)

    def files_from_samples(self, sample):
        """
        Return the paths to the files for a given sample.
        """
        ID, t, ox, oy, oz = sample
        return self.files_from_ID_and_time(ID, t)

    def files_from_ID_and_time(self, ID, t):
        """
        Return the paths to the files for a given ID and time.
        """
        input_paths = [
            os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
            for var in self.initial_conditions_variables
        ]
        target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_nforce{t}.h5')
        style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
        return {
            'input': input_paths,
            'target': target_path,
            'style': style_path
        }
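
    # Expected on-disk layout, derived from the path construction above
    # (file and directory names are illustrative, not checked against a real run):
    #
    #     root_dir/
    #         initial_conditions/ICs_{ID}_{var}.h5
    #         gravitational_potential/{target_variable}_{ID}_nforce{t}.h5
    #         cosmo_and_time/{style_files}_{ID}_nforce{t}.txt
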
    def get_data(self, ID, t, ox, oy, oz):
        """
        Get the data for a specific ID, time, and offsets.
        Returns a dictionary with input, target, and style tensors.
        """
        # Imported here so the module can be used without pysbmy installed.
        from pysbmy.field import read_field_chunk_3D_periodic

        # Filepaths
        input_paths = [
            os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
            for var in self.initial_conditions_variables
        ]
        target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_nforce{t}.h5')
        style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')

        # Read N x N x N chunks at the given offsets (periodic boundaries)
        input_arrays = [
            read_field_chunk_3D_periodic(file, self.N, self.N, self.N, ox, oy, oz, name=varname).array
            for file, varname in zip(input_paths, self.initial_conditions_variables)
        ]
        target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array

        # Stack the input arrays along a channel dimension
        input_tensor = np.stack(input_arrays, axis=0)
        input_tensor = torch.tensor(input_tensor, dtype=torch.float32).to(self.device)

        # Target gets a singleton channel dimension
        target_tensor = torch.tensor(target_array, dtype=torch.float32).unsqueeze(0).to(self.device)

        # Style parameters
        style_params = read_cosmo_and_time_file(style_path)
        # Select only the specified style keys
        if self.style_keys is not None:
            style_params = [style_params[key] for key in self.style_keys if key in style_params]
        else:
            style_params = list(style_params.values())
        # Convert to tensor
        style_tensor = torch.tensor(style_params, dtype=torch.float32).to(self.device)

        return {
            'input': input_tensor,
            'target': target_tensor,
            'style': style_tensor,
            'ID': ID,
            'time': t,
            'offset': (ox, oy, oz)
        }

    def __getitem__(self, idx):
        ID, t, ox, oy, oz = self.samples[idx]
        return self.get_data(ID, t, ox, oy, oz)

    def on_epoch_end(self):
        """Call this at the end of each epoch to regenerate offset + time choices."""
        self._prepare_samples()


class SubDataset(Dataset):
    def __init__(self, dataset: GravPotDataset, ID_list: list):
        from copy import deepcopy
        super().__init__()
        self.dataset = deepcopy(dataset)
        self.ids = ID_list
        self.dataset.ids = ID_list
        # Rebuild the sample list so it only covers the restricted ID list
        self.dataset._prepare_samples()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def on_epoch_end(self):
        self.dataset.ids = self.ids
        self.dataset.on_epoch_end()


def train_val_split(dataset: GravPotDataset, val_fraction: float = 0.2, seed: int = 42):
    """
    Splits the dataset into training and validation sets by ID.

    Parameters:
    - dataset: The GravPotDataset to split.
    - val_fraction: Fraction of the IDs to use for validation.
    - seed: Random seed passed to the underlying train/test split.

    Returns:
    - train_dataset: SubDataset for training.
    - val_dataset: SubDataset for validation.
    """
    from sklearn.model_selection import train_test_split
    train_ids, val_ids = train_test_split(dataset.ids, test_size=val_fraction, random_state=seed)
    train_dataset = SubDataset(dataset, train_ids)
    val_dataset = SubDataset(dataset, val_ids)

    # Share the subsampling budget between the two splits
    if dataset.max_len is not None:
        train_dataset.dataset.max_len = int(dataset.max_len * (1 - val_fraction))
        val_dataset.dataset.max_len = int(dataset.max_len * val_fraction)

    train_dataset.dataset._prepare_samples()
    val_dataset.dataset._prepare_samples()

    return train_dataset, val_dataset

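
# Minimal usage sketch for the split (the path and fractions are illustrative):
#
#     dataset = GravPotDataset(root_dir='/path/to/data', match_str='train')
#     train_set, val_set = train_val_split(dataset, val_fraction=0.2)
#     train_loader = DataLoader(train_set, batch_size=8, shuffle=True)

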
class ExpandingGravPotDataset(GravPotDataset):
    """
    A dataset that slowly increases the number of samples to help with training stability.
    """

    def __init__(self, batch_size: int = 32, epoch_time_scale: int = 100, epoch_reduction: float = 2, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.epoch_time_scale = epoch_time_scale
        self.epoch_reduction = epoch_reduction
        self.current_epoch = 0
        self.last_epoch_expansion = 0
        self.n_epochs_before_expanding = epoch_time_scale
        # Cap on the expansion; fall back to 1e6 if max_len was not provided (or is None)
        self.global_max_len = kwargs.get('max_len') or int(1e6)
        self.max_len = 1

    def on_epoch_end(self):
        """
        Call this at the end of each epoch to regenerate offset + time choices.
        Also expands the dataset size based on the current epoch.
        """
        self.current_epoch += 1

        # Expand dataset size every n_epochs_before_expanding epochs
        if (self.current_epoch - self.last_epoch_expansion) >= self.n_epochs_before_expanding:
            self.n_epochs_before_expanding = int(self.n_epochs_before_expanding / self.epoch_reduction)
            self.last_epoch_expansion = self.current_epoch
            # Grow by a factor of batch_size, but never beyond global_max_len
            self.max_len = min(self.global_max_len, int(self.max_len * self.batch_size))
            print(f"Expanding dataset at epoch {self.current_epoch}, new max_len: {self.max_len}")
        self._prepare_samples()

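
# With the defaults (batch_size=32, epoch_time_scale=100, epoch_reduction=2) the
# schedule works out as follows: the dataset starts with max_len=1, expands at
# epochs 100, 150, 175, ... (the waiting interval halves each time), and each
# expansion multiplies max_len by 32 (1 -> 32 -> 1024 -> 32768 -> ...), capped
# at global_max_len.

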
def convert_GravPotDataset_to_ExpandingGravPotDataset(dataset: GravPotDataset, **kwargs):
    """
    Converts a GravPotDataset to an ExpandingGravPotDataset.
    """
    if not isinstance(dataset, GravPotDataset):
        raise TypeError("Input dataset must be an instance of GravPotDataset.")

    dataset = ExpandingGravPotDataset(
        root_dir=dataset.root_dir,
        ids=dataset.ids,
        N=dataset.N,
        N_full=dataset.N_full,
        match_str=dataset.match_str,
        max_len=dataset.max_len,
        device=dataset.device,
        initial_conditions_variables=dataset.initial_conditions_variables,
        target_variable=dataset.target_variable,
        style_files=dataset.style_files,
        style_keys=dataset.style_keys,
        max_time=dataset.max_time,
        batch_size=kwargs.get('batch_size', 32),
        epoch_time_scale=kwargs.get('epoch_time_scale', 100),
        epoch_reduction=kwargs.get('epoch_reduction', 2)
    )
    dataset._prepare_samples()  # Prepare initial samples
    return dataset
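

if __name__ == "__main__":
    # Minimal end-to-end sketch: the root_dir and hyperparameters below are
    # illustrative placeholders, not values shipped with this module.
    dataset = GravPotDataset(root_dir='/path/to/gravpot_data', match_str='train', N=128, N_full=768)
    dataset = convert_GravPotDataset_to_ExpandingGravPotDataset(dataset, batch_size=8)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    for epoch in range(3):
        for batch in loader:
            # batch['input']: (B, n_vars, N, N, N), batch['target']: (B, 1, N, N, N),
            # batch['style']: (B, n_style_keys)
            pass  # training step would go here
        # Redraw offsets/times and (for the expanding dataset) grow max_len
        dataset.on_epoch_end()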