From 066a5b395824bf3eba0f98cd85518c486c3cd153 Mon Sep 17 00:00:00 2001
From: Mayeul Aubin
Date: Wed, 25 Jun 2025 17:12:01 +0200
Subject: [PATCH] momenta dataset

---
 sCOCA_ML/dataset/momenta_dataset.py | 296 ++++++++++++++++++++++++++++
 1 file changed, 296 insertions(+)
 create mode 100644 sCOCA_ML/dataset/momenta_dataset.py

diff --git a/sCOCA_ML/dataset/momenta_dataset.py b/sCOCA_ML/dataset/momenta_dataset.py
new file mode 100644
index 0000000..d8ecf1f
--- /dev/null
+++ b/sCOCA_ML/dataset/momenta_dataset.py
@@ -0,0 +1,296 @@
+import os
+import random
+import re
+from glob import glob
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+
+def read_cosmo_and_time_file(cosmo_and_time_file):
+    """Parse a cosmo_and_time text file of 'key: value' lines into a dict."""
+    cosmo_and_time_params = {}
+    with open(cosmo_and_time_file, 'r') as f:
+        for line in f:
+            if not line.strip():  # Skip empty lines
+                continue
+            key, value = line.split(':', 1)  # Split on the first ':' only
+            key = key.strip()
+            if key == 'ID':
+                cosmo_and_time_params['ID'] = value.strip()
+            elif key == 'nforce':
+                cosmo_and_time_params['nforce'] = int(value.strip())
+            else:
+                cosmo_and_time_params[key] = float(value.strip())
+    return cosmo_and_time_params
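+
+# Example of the expected cosmo_and_time file layout (an illustrative sketch
+# inferred from the parser above; actual files may carry more keys):
+#
+#   ID: train_0001
+#   nforce: 3
+#   D1: 0.4521
+#   D2: -0.0873
+#
+# for which read_cosmo_and_time_file returns
+#   {'ID': 'train_0001', 'nforce': 3, 'D1': 0.4521, 'D2': -0.0873}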
+
+
+class MomentaDataset(Dataset):
+
+    INITIAL_CONDITIONS_DIR = 'initial_conditions'
+    TARGET_DIR = 'momenta'
+    STYLE_DIR = 'cosmo_and_time'
+
+    def __init__(self,
+                 root_dir: str,
+                 ids: list | None = None,
+                 N: int = 128,
+                 N_full: int = 768,
+                 match_str: str = 'train',
+                 device='cpu',
+                 initial_conditions_variables: tuple | list = ('DM_delta', 'DM_phi'),
+                 target_variable: str = 'gravpot',
+                 style_files: str = 'cosmo_and_time_parameters',
+                 style_keys: tuple | list | None = ('D1', 'D2'),
+                 max_time: int = 100):
+        """
+        Dataset for residual momenta data (COCA).
+
+        Parameters:
+        - root_dir: Directory containing the dataset.
+        - ids: List of IDs to include in the dataset. If None, IDs are discovered on disk.
+        - N: Size of the chunks to read (N x N x N).
+        - N_full: Full size of the simulation box (N_full x N_full x N_full).
+        - match_str: Substring an ID must contain to be kept during discovery.
+        - device: Device to load tensors onto (default is CPU).
+        - initial_conditions_variables: Names of the input fields read from the ICs files.
+        - target_variable: Name of the target field.
+        - style_files: Base name of the style parameter text files.
+        - style_keys: Style parameters to keep, in order; if None, keep all of them.
+        - max_time: Upper bound (exclusive) on the time indices searched on disk.
+        """
+        super().__init__()
+
+        self.initial_conditions_variables = initial_conditions_variables
+        self.target_variable = target_variable
+        self.style_files = style_files
+        self.style_keys = style_keys
+        self.max_time = max_time
+
+        self.root_dir = root_dir
+        self.N = N
+        self.N_full = N_full
+        self.device = device
+
+        # Compute how many chunks per dimension
+        self.chunks_per_dim = N_full // N
+        self.chunks_per_entry = self.chunks_per_dim ** 3
+
+        # This will hold (ID, time, offsets) tuples
+        self.samples = []
+
+        if ids is None:
+            ids = self.discover_ids(match_str=match_str)
+        self.ids = ids
+
+        # Build indexable sample list
+        self._prepare_samples()
+
+    def discover_ids(self, match_str: str = 'train'):
+        """
+        Discover IDs that contain match_str and have at least one valid time step
+        with all required files present.
+        """
+        pattern = os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, 'ICs_*')
+        files = glob(pattern)
+
+        valid_ids = []
+        for file in files:
+            match = re.search(r'ICs_(.+?)_' + re.escape(f'{self.initial_conditions_variables[0]}.h5'),
+                              os.path.basename(file))
+            if not match:
+                continue
+
+            ID = match.group(1)
+            if match_str not in ID:
+                continue
+
+            # Skip the ID unless all remaining initial conditions files exist
+            if not all(os.path.exists(os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR,
+                                                   f'ICs_{ID}_{var}.h5'))
+                       for var in self.initial_conditions_variables[1:]):
+                continue
+
+            # Keep the ID if at least one time step has both target and style files
+            if self._get_valid_times(ID):
+                valid_ids.append(ID)
+
+        return sorted(valid_ids)
+
+    def _get_valid_times(self, ID):
+        """Return the time indices for which both the target and style files exist."""
+        valid_times = []
+        for t in range(self.max_time):
+            target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_{t}.h5')
+            style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
+            if os.path.exists(target_path) and os.path.exists(style_path):
+                valid_times.append(t)
+        return valid_times
+
+    def _prepare_samples(self):
+        self.samples.clear()
+        for ID in self.ids:
+            times = self._get_valid_times(ID)
+            if not times:
+                continue
+
+            # Random global offset for this epoch; chunks tile the box from
+            # there and wrap around periodically.
+            base_offset_x = random.randrange(self.N_full)
+            base_offset_y = random.randrange(self.N_full)
+            base_offset_z = random.randrange(self.N_full)
+
+            for i in range(self.chunks_per_dim):
+                for j in range(self.chunks_per_dim):
+                    for k in range(self.chunks_per_dim):
+                        selected_time = random.choice(times)
+                        offset_x = (base_offset_x + i * self.N) % self.N_full
+                        offset_y = (base_offset_y + j * self.N) % self.N_full
+                        offset_z = (base_offset_z + k * self.N) % self.N_full
+                        self.samples.append((ID, selected_time, offset_x, offset_y, offset_z))
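+
+    # Worked example of the chunking above, assuming the default sizes (which
+    # are illustrative, not mandated anywhere in this patch): with N_full=768
+    # and N=128 there are 768 // 128 = 6 chunks per dimension, i.e. 6**3 = 216
+    # samples per ID and per epoch. A base offset of (500, 0, 0) places chunk
+    # (1, 0, 0) at x = (500 + 128) % 768 = 628, so chunks may straddle the
+    # periodic boundary; read_field_chunk_periodic is expected to handle the
+    # wrap-around.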
+
+    def __len__(self):
+        return len(self.samples)
+
+    def files_from_samples(self, sample):
+        """
+        Return the paths to the files for a given sample.
+        """
+        ID, t, ox, oy, oz = sample
+        return self.files_from_ID_and_time(ID, t)
+
+    def files_from_ID_and_time(self, ID, t):
+        """
+        Return the paths to the files for a given ID and time.
+        """
+        input_paths = [
+            os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
+            for var in self.initial_conditions_variables
+        ]
+        target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_{t}.h5')
+        style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
+        return {
+            'input': input_paths,
+            'target': target_path,
+            'style': style_path
+        }
+
+    def get_data(self, ID, t, ox, oy, oz):
+        """
+        Get the data for a specific ID, time, and offsets.
+        Returns a dictionary with input, target, and style tensors.
+        """
+        from pysbmy.field import read_field_chunk_periodic
+
+        # Filepaths, built with the same naming convention used during discovery
+        paths = self.files_from_ID_and_time(ID, t)
+
+        # Read 3D chunks with periodic wrap-around
+        input_arrays = [
+            read_field_chunk_periodic(file, self.N, self.N, self.N, ox, oy, oz, name=varname).array
+            for file, varname in zip(paths['input'], self.initial_conditions_variables)
+        ]
+        target_array = read_field_chunk_periodic(paths['target'], self.N, self.N, self.N,
+                                                 ox, oy, oz, name=self.target_variable).array
+        target_array = np.moveaxis(target_array, -1, 0)  # Move channel dimension to the front
+
+        # Stack the input fields along a channel dimension
+        input_tensor = np.stack(input_arrays, axis=0)
+        input_tensor = torch.tensor(input_tensor, dtype=torch.float32).to(self.device)
+
+        # Target
+        target_tensor = torch.tensor(target_array, dtype=torch.float32).unsqueeze(0).to(self.device)
+
+        # Style parameters
+        style_params = read_cosmo_and_time_file(paths['style'])
+        if self.style_keys is not None:
+            # Keep only the requested style keys, in the order given
+            style_params = [style_params[key] for key in self.style_keys if key in style_params]
+        else:
+            style_params = list(style_params.values())
+        style_tensor = torch.tensor(style_params, dtype=torch.float32).to(self.device)
+
+        return {
+            'input': input_tensor,
+            'target': target_tensor,
+            'style': style_tensor,
+            'ID': ID,
+            'time': t,
+            'offset': (ox, oy, oz)
+        }
+
+    def __getitem__(self, idx):
+        ID, t, ox, oy, oz = self.samples[idx]
+        return self.get_data(ID, t, ox, oy, oz)
+
+    def on_epoch_end(self):
+        """Call this at the end of each epoch to regenerate offset and time choices."""
+        self._prepare_samples()
+
+
+class SubDataset(Dataset):
+    """View of a MomentaDataset restricted to a subset of IDs."""
+
+    def __init__(self, dataset: MomentaDataset, ID_list: list):
+        from copy import deepcopy
+        self.dataset = deepcopy(dataset)
+        self.ids = ID_list
+        self.dataset.ids = ID_list
+        self.dataset._prepare_samples()  # Rebuild samples for the restricted ID list
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        return self.dataset[idx]
+
+    def on_epoch_end(self):
+        self.dataset.ids = self.ids
+        self.dataset.on_epoch_end()
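+
+# Design note: SubDataset deep-copies its parent, so train and validation
+# splits resample chunk offsets and time steps independently at each
+# on_epoch_end(). The deepcopy is safe here because MomentaDataset keeps no
+# open file handles; files are only opened lazily inside get_data.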
+
+
+def train_val_split(dataset: MomentaDataset, val_fraction: float = 0.2, seed: int = 42):
+    """
+    Splits the dataset into training and validation sets.
+
+    Parameters:
+    - dataset: The MomentaDataset to split.
+    - val_fraction: Fraction of the dataset IDs to use for validation.
+    - seed: Random seed controlling the split.
+
+    Returns:
+    - train_dataset: SubDataset for training.
+    - val_dataset: SubDataset for validation.
+    """
+    from sklearn.model_selection import train_test_split
+    train_ids, val_ids = train_test_split(dataset.ids, test_size=val_fraction, random_state=seed)
+    train_dataset = SubDataset(dataset, train_ids)
+    val_dataset = SubDataset(dataset, val_ids)
+
+    return train_dataset, val_dataset
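+
+
+# Minimal usage sketch ('/path/to/dataset' is a placeholder; running this
+# requires pysbmy, scikit-learn, and data files laid out as the class expects):
+if __name__ == "__main__":
+    from torch.utils.data import DataLoader
+
+    dataset = MomentaDataset(root_dir='/path/to/dataset', match_str='train')
+    train_set, val_set = train_val_split(dataset, val_fraction=0.2)
+
+    train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
+    for batch in train_loader:
+        # batch['input'] has shape (B, n_input_fields, N, N, N);
+        # batch['style'] has shape (B, n_style_keys).
+        break
+
+    # Resample chunk offsets and time steps between epochs
+    train_set.on_epoch_end()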