minor changes to momenta dataset

This commit is contained in:
Mayeul Aubin 2025-06-30 14:29:33 +02:00
parent b44c2344aa
commit c7cf9fe7d4

View file

@ -36,9 +36,10 @@ class MomentaDataset(Dataset):
N:int=128,
N_full:int=768,
match_str:str='train',
max_len:int|None=None,
device='cpu',
initial_conditions_variables:tuple|list=['DM_delta', 'DM_phi'],
target_variable:str='gravpot',
target_variable:str='momenta',
style_files:str='cosmo_and_time_parameters',
style_keys:list|None=["a"],
max_time:int=100):
@ -57,6 +58,7 @@ class MomentaDataset(Dataset):
self.style_files = style_files
self.style_keys = style_keys
self.max_time = max_time
self.max_len = max_len
self.root_dir = root_dir
self.N = N
@ -149,6 +151,9 @@ class MomentaDataset(Dataset):
offset_y = (base_offset_y + j * self.N) % self.N_full
offset_z = (base_offset_z + k * self.N) % self.N_full
self.samples.append((ID, selected_time, offset_x, offset_y, offset_z))
if self.max_len is not None and len(self.samples) > self.max_len:
self.samples = random.sample(self.samples, int(self.max_len))
def __len__(self):
return len(self.samples)