minor changes to momenta dataset
This commit is contained in:
parent
b44c2344aa
commit
c7cf9fe7d4
1 changed files with 6 additions and 1 deletions
|
@ -36,9 +36,10 @@ class MomentaDataset(Dataset):
|
|||
N:int=128,
|
||||
N_full:int=768,
|
||||
match_str:str='train',
|
||||
max_len:int|None=None,
|
||||
device='cpu',
|
||||
initial_conditions_variables:tuple|list=['DM_delta', 'DM_phi'],
|
||||
target_variable:str='gravpot',
|
||||
target_variable:str='momenta',
|
||||
style_files:str='cosmo_and_time_parameters',
|
||||
style_keys:list|None=["a"],
|
||||
max_time:int=100):
|
||||
|
@ -57,6 +58,7 @@ class MomentaDataset(Dataset):
|
|||
self.style_files = style_files
|
||||
self.style_keys = style_keys
|
||||
self.max_time = max_time
|
||||
self.max_len = max_len
|
||||
|
||||
self.root_dir = root_dir
|
||||
self.N = N
|
||||
|
@ -149,6 +151,9 @@ class MomentaDataset(Dataset):
|
|||
offset_y = (base_offset_y + j * self.N) % self.N_full
|
||||
offset_z = (base_offset_z + k * self.N) % self.N_full
|
||||
self.samples.append((ID, selected_time, offset_x, offset_y, offset_z))
|
||||
|
||||
if self.max_len is not None and len(self.samples) > self.max_len:
|
||||
self.samples = random.sample(self.samples, int(self.max_len))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue