From c7cf9fe7d45ebedcf7cc8c7d420eea27a54f4807 Mon Sep 17 00:00:00 2001 From: Mayeul Aubin Date: Mon, 30 Jun 2025 14:29:33 +0200 Subject: [PATCH] minor changes to momenta dataset --- sCOCA_ML/dataset/momenta_dataset.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sCOCA_ML/dataset/momenta_dataset.py b/sCOCA_ML/dataset/momenta_dataset.py index ffa062f..5fdc6bb 100644 --- a/sCOCA_ML/dataset/momenta_dataset.py +++ b/sCOCA_ML/dataset/momenta_dataset.py @@ -36,9 +36,10 @@ class MomentaDataset(Dataset): N:int=128, N_full:int=768, match_str:str='train', + max_len:int|None=None, device='cpu', initial_conditions_variables:tuple|list=['DM_delta', 'DM_phi'], - target_variable:str='gravpot', + target_variable:str='momenta', style_files:str='cosmo_and_time_parameters', style_keys:list|None=["a"], max_time:int=100): @@ -57,6 +58,7 @@ class MomentaDataset(Dataset): self.style_files = style_files self.style_keys = style_keys self.max_time = max_time + self.max_len = max_len self.root_dir = root_dir self.N = N @@ -149,6 +151,9 @@ class MomentaDataset(Dataset): offset_y = (base_offset_y + j * self.N) % self.N_full offset_z = (base_offset_z + k * self.N) % self.N_full self.samples.append((ID, selected_time, offset_x, offset_y, offset_z)) + + if self.max_len is not None and len(self.samples) > self.max_len: + self.samples = random.sample(self.samples, int(self.max_len)) def __len__(self): return len(self.samples)