From 6dcf0313bb89ac2391d40530ebec44a102f31d70 Mon Sep 17 00:00:00 2001 From: Mayeul Aubin Date: Thu, 26 Jun 2025 11:31:48 +0200 Subject: [PATCH] max len for dataset gravpot --- sCOCA_ML/dataset/gravpot_dataset.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/sCOCA_ML/dataset/gravpot_dataset.py b/sCOCA_ML/dataset/gravpot_dataset.py index c380cda..2c637ec 100644 --- a/sCOCA_ML/dataset/gravpot_dataset.py +++ b/sCOCA_ML/dataset/gravpot_dataset.py @@ -36,6 +36,7 @@ class GravPotDataset(Dataset): N:int=128, N_full:int=768, match_str:str='train', + max_len:int|None=None, device='cpu', initial_conditions_variables:tuple|list=['DM_delta', 'DM_phi'], target_variable:str='gravpot', @@ -57,6 +58,7 @@ class GravPotDataset(Dataset): self.style_files = style_files self.style_keys = style_keys self.max_time = max_time + self.max_len = max_len self.root_dir = root_dir self.N = N @@ -129,7 +131,7 @@ class GravPotDataset(Dataset): - def _prepare_samples(self): + def _prepare_samples(self): self.samples.clear() for ID in self.ids: times = self._get_valid_times(ID) @@ -149,6 +151,9 @@ class GravPotDataset(Dataset): offset_y = (base_offset_y + j * self.N) % self.N_full offset_z = (base_offset_z + k * self.N) % self.N_full self.samples.append((ID, selected_time, offset_x, offset_y, offset_z)) + + if self.max_len is not None and len(self.samples) > self.max_len: + self.samples = random.sample(self.samples, int(self.max_len)) def __len__(self): return len(self.samples) @@ -288,6 +293,11 @@ def train_val_split(dataset: GravPotDataset, val_fraction: float = 0.2, seed: in train_ids, val_ids = train_test_split(dataset.ids, test_size=val_fraction, random_state=seed) train_dataset = SubDataset(dataset, train_ids) val_dataset = SubDataset(dataset, val_ids) + + if dataset.max_len is not None: + train_dataset.dataset.max_len = int(dataset.max_len * (1 - val_fraction)) + val_dataset.dataset.max_len = int(dataset.max_len * val_fraction) + train_dataset.dataset._prepare_samples() val_dataset.dataset._prepare_samples()