max len for dataset gravpot

This commit is contained in:
Mayeul Aubin 2025-06-26 11:31:48 +02:00
parent e4308c003b
commit 6dcf0313bb

View file

@ -36,6 +36,7 @@ class GravPotDataset(Dataset):
N:int=128,
N_full:int=768,
match_str:str='train',
max_len:int|None=None,
device='cpu',
initial_conditions_variables:tuple|list=['DM_delta', 'DM_phi'],
target_variable:str='gravpot',
@ -57,6 +58,7 @@ class GravPotDataset(Dataset):
self.style_files = style_files
self.style_keys = style_keys
self.max_time = max_time
self.max_len = max_len
self.root_dir = root_dir
self.N = N
@ -129,7 +131,7 @@ class GravPotDataset(Dataset):
def _prepare_samples(self):
def _prepare_samples(self):
self.samples.clear()
for ID in self.ids:
times = self._get_valid_times(ID)
@ -149,6 +151,9 @@ class GravPotDataset(Dataset):
offset_y = (base_offset_y + j * self.N) % self.N_full
offset_z = (base_offset_z + k * self.N) % self.N_full
self.samples.append((ID, selected_time, offset_x, offset_y, offset_z))
if self.max_len is not None and len(self.samples) > self.max_len:
self.samples = random.sample(self.samples, int(self.max_len))
def __len__(self):
return len(self.samples)
@ -288,6 +293,11 @@ def train_val_split(dataset: GravPotDataset, val_fraction: float = 0.2, seed: in
train_ids, val_ids = train_test_split(dataset.ids, test_size=val_fraction, random_state=seed)
train_dataset = SubDataset(dataset, train_ids)
val_dataset = SubDataset(dataset, val_ids)
if dataset.max_len is not None:
train_dataset.dataset.max_len = int(dataset.max_len * (1 - val_fraction))
val_dataset.dataset.max_len = int(dataset.max_len * val_fraction)
train_dataset.dataset._prepare_samples()
val_dataset.dataset._prepare_samples()