max len for dataset gravpot
This commit is contained in:
parent
e4308c003b
commit
6dcf0313bb
1 changed files with 11 additions and 1 deletions
|
@ -36,6 +36,7 @@ class GravPotDataset(Dataset):
|
|||
N:int=128,
|
||||
N_full:int=768,
|
||||
match_str:str='train',
|
||||
max_len:int|None=None,
|
||||
device='cpu',
|
||||
initial_conditions_variables:tuple|list=['DM_delta', 'DM_phi'],
|
||||
target_variable:str='gravpot',
|
||||
|
@ -57,6 +58,7 @@ class GravPotDataset(Dataset):
|
|||
self.style_files = style_files
|
||||
self.style_keys = style_keys
|
||||
self.max_time = max_time
|
||||
self.max_len = max_len
|
||||
|
||||
self.root_dir = root_dir
|
||||
self.N = N
|
||||
|
@ -129,7 +131,7 @@ class GravPotDataset(Dataset):
|
|||
|
||||
|
||||
|
||||
def _prepare_samples(self):
|
||||
def _prepare_samples(self):
|
||||
self.samples.clear()
|
||||
for ID in self.ids:
|
||||
times = self._get_valid_times(ID)
|
||||
|
@ -149,6 +151,9 @@ class GravPotDataset(Dataset):
|
|||
offset_y = (base_offset_y + j * self.N) % self.N_full
|
||||
offset_z = (base_offset_z + k * self.N) % self.N_full
|
||||
self.samples.append((ID, selected_time, offset_x, offset_y, offset_z))
|
||||
|
||||
if self.max_len is not None and len(self.samples) > self.max_len:
|
||||
self.samples = random.sample(self.samples, int(self.max_len))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
@ -288,6 +293,11 @@ def train_val_split(dataset: GravPotDataset, val_fraction: float = 0.2, seed: in
|
|||
train_ids, val_ids = train_test_split(dataset.ids, test_size=val_fraction, random_state=seed)
|
||||
train_dataset = SubDataset(dataset, train_ids)
|
||||
val_dataset = SubDataset(dataset, val_ids)
|
||||
|
||||
if dataset.max_len is not None:
|
||||
train_dataset.dataset.max_len = int(dataset.max_len * (1 - val_fraction))
|
||||
val_dataset.dataset.max_len = int(dataset.max_len * val_fraction)
|
||||
|
||||
train_dataset.dataset._prepare_samples()
|
||||
val_dataset.dataset._prepare_samples()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue