import os import random import numpy as np import torch from torch.utils.data import Dataset, DataLoader from glob import glob import re def read_cosmo_and_time_file(cosmo_and_time_file): with open(cosmo_and_time_file, 'r') as f: lines = f.readlines() cosmo_and_time_params = {} for line in lines: if line.strip(): # Skip empty lines key, value = line.split(':') if key.strip() == 'ID': cosmo_and_time_params['ID'] = value.strip() elif key.strip() == 'nforce': cosmo_and_time_params['nforce'] = int(value.strip()) else: cosmo_and_time_params[key.strip()] = float(value.strip()) return cosmo_and_time_params class MomentaDataset(Dataset): INITIAL_CONDITIONS_DIR = 'initial_conditions' TARGET_DIR = 'momenta' STYLE_DIR = 'cosmo_and_time' def __init__(self, root_dir:str, ids:list|None=None, N:int=128, N_full:int=768, match_str:str='train', device='cpu', initial_conditions_variables:tuple|list=['DM_delta', 'DM_phi'], target_variable:str='gravpot', style_files:str='cosmo_and_time_parameters', style_keys:list|None=["D1", "D2"], max_time:int=100): """ Dataset for residual momenta data (COCA). Parameters: - root_dir: Directory containing the dataset. - ids: List of IDs to include in the dataset. If None, will discover IDs. - N: Size of the chunks to read (N x N x N). - N_full: Full size of the simulation box (N_full x N_full x N_full). - device: Device to load tensors onto (default is CPU).""" super().__init__() self.initial_conditions_variables = initial_conditions_variables self.target_variable = target_variable self.style_files = style_files self.style_keys = style_keys self.max_time = max_time self.root_dir = root_dir self.N = N self.N_full = N_full self.device = device # Compute how many chunks per dimension self.chunks_per_dim = N_full // N self.chunks_per_entry = self.chunks_per_dim ** 3 # This will hold (ID, time, offsets) self.samples = [] if ids is None: ids = self.discover_ids(match_str=match_str) self.ids = ids # Build indexable sample list self._prepare_samples() def discover_ids(self, match_str:str='train'): """ Discover IDs that contain match_str and have at least one valid time step with all required files present. """ pattern = os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, 'ICs_*') files = glob(pattern) valid_ids = [] for file in files: match = re.search(r'ICs_(.+?)_'+f"{self.initial_conditions_variables[0]}.h5", os.path.basename(file)) if not match: continue ID = match.group(1) if match_str not in ID: continue # Check if corresponding initial conditions file exists for var in self.initial_conditions_variables[1:]: ic_path = os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5') if not os.path.exists(ic_path): continue # Check if at least one time has both target and style params found_valid_time = False for t in range(self.max_time): target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_{t}.h5') style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt') if os.path.exists(target_path) and os.path.exists(style_path): found_valid_time = True break if found_valid_time: valid_ids.append(ID) return sorted(valid_ids) def _get_valid_times(self, ID): """Returns valid time indices for which all required files exist.""" valid_times = [] for t in range(100): # arbitrary upper limit target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_{t}.h5') style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt') if all(os.path.exists(p) for p in [target_path, style_path]): valid_times.append(t) return valid_times def _prepare_samples(self): self.samples.clear() for ID in self.ids: times = self._get_valid_times(ID) if not times: continue # Random offset for epoch start base_offset_x = random.randint(0, self.N_full) base_offset_y = random.randint(0, self.N_full) base_offset_z = random.randint(0, self.N_full) for i in range(self.chunks_per_dim): for j in range(self.chunks_per_dim): for k in range(self.chunks_per_dim): selected_time = random.choice(times) offset_x = (base_offset_x + i * self.N) % self.N_full offset_y = (base_offset_y + j * self.N) % self.N_full offset_z = (base_offset_z + k * self.N) % self.N_full self.samples.append((ID, selected_time, offset_x, offset_y, offset_z)) def __len__(self): return len(self.samples) def files_from_samples(self, sample): """ Return the paths to the files for a given sample. """ ID, t, ox, oy, oz = sample input_paths = [ os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5') for var in self.initial_conditions_variables ] target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_{t}.h5') style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt') return { 'input': input_paths, 'target': target_path, 'style': style_path } def files_from_ID_and_time(self, ID, t): """ Return the paths to the files for a given ID and time. """ input_paths = [ os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5') for var in self.initial_conditions_variables ] target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_{t}.h5') style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt') return { 'input': input_paths, 'target': target_path, 'style': style_path } def get_data(self, ID, t, ox, oy, oz): """ Get the data for a specific ID, time, and offsets. Returns a dictionary with input, target, and style tensors. """ from pysbmy.field import read_field_chunk_periodic from io import BytesIO import torch from sbmy_control.low_level import stdout_redirector, stderr_redirector f = BytesIO() # Filepaths input_paths = [ os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5') for var in self.initial_conditions_variables ] target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_nforce{t}.h5') style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt') # Read 3D chunks input_arrays = [ read_field_chunk_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array for file, varname in zip(input_paths, self.initial_conditions_variables) ] target_array = read_field_chunk_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array target_array = np.moveaxis(target_array, -1, 0) # Move channel dimension to the front # Stack the input arrays input_tensor = np.stack(input_arrays, axis=0) input_tensor = torch.tensor(input_tensor, dtype=torch.float32).to(self.device) # Target target_tensor = torch.tensor(target_array, dtype=torch.float32).unsqueeze(0).to(self.device) # Style parameters style_params = read_cosmo_and_time_file(style_path) # Select only the specified style keys if self.style_keys is not None: style_params = [style_params[key] for key in self.style_keys if key in style_params] else: style_params = list(style_params.values()) # Convert to tensor style_tensor = torch.tensor(style_params, dtype=torch.float32).to(self.device) return { 'input': input_tensor, 'target': target_tensor, 'style': style_tensor, 'ID': ID, 'time': t, 'offset': (ox, oy, oz) } def __getitem__(self, idx): ID, t, ox, oy, oz = self.samples[idx] return self.get_data(ID, t, ox, oy, oz) def on_epoch_end(self): """Call this at the end of each epoch to regenerate offset + time choices.""" self._prepare_samples() class SubDataset(Dataset): def __init__(self, dataset: MomentaDataset, ID_list: list): from copy import deepcopy self.dataset = deepcopy(dataset) self.ids = ID_list self.dataset.ids = ID_list def __len__(self): return len(self.dataset) def __getitem__(self, idx): return self.dataset[idx] def on_epoch_end(self): self.dataset.ids = self.ids self.dataset.on_epoch_end() def train_val_split(dataset: MomentaDataset, val_fraction: float = 0.2, seed: int = 42): """ Splits the dataset into training and validation sets. Parameters: - dataset: The GravPotDataset to split. - val_fraction: Fraction of the dataset to use for validation. Returns: - train_dataset: SubDataset for training. - val_dataset: SubDataset for validation. """ from sklearn.model_selection import train_test_split train_ids, val_ids = train_test_split(dataset.ids, test_size=val_fraction, random_state=seed) train_dataset = SubDataset(dataset, train_ids) val_dataset = SubDataset(dataset, val_ids) train_dataset.dataset._prepare_samples() val_dataset.dataset._prepare_samples() return train_dataset, val_dataset