many improvements

parent c07ec8f8cf
commit 6c526d7115

4 changed files with 219 additions and 53 deletions
@@ -36,7 +36,7 @@ class GravPotDataset(Dataset):
                  N:int=128,
                  N_full:int=768,
                  match_str:str='train',
-                 device=torch.device('cpu'),
+                 device='cpu',
                  initial_conditions_variables:tuple|list=['DM_delta', 'DM_phi'],
                  target_variable:str='gravpot',
                  style_files:str='cosmo_and_time_parameters',
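The only change in this hunk is the `device` default: the string `'cpu'` is equivalent to `torch.device('cpu')` everywhere PyTorch expects a device, and it avoids building a `torch.device` object in the signature. A quick standalone sanity check (not part of this commit):

```python
import torch

# 'cpu' and torch.device('cpu') are interchangeable wherever a device is expected.
x = torch.zeros(3)
assert x.to('cpu').device == x.to(torch.device('cpu')).device == torch.device('cpu')
```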
@@ -50,6 +50,7 @@ class GravPotDataset(Dataset):
         - N: Size of the chunks to read (N x N x N).
+        - N_full: Full size of the simulation box (N_full x N_full x N_full).
         - device: Device to load tensors onto (default is CPU)."""
         super().__init__()

         self.initial_conditions_variables = initial_conditions_variables
         self.target_variable = target_variable
@@ -152,10 +153,44 @@ class GravPotDataset(Dataset):
     def __len__(self):
         return len(self.samples)

+    def files_from_samples(self, sample):
+        """
+        Return the paths to the files for a given sample.
+        """
+        ID, t, ox, oy, oz = sample
+        input_paths = [
+            os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
+            for var in self.initial_conditions_variables
+        ]
+        target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_nforce{t}.h5')
+        style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
+        return {
+            'input': input_paths,
+            'target': target_path,
+            'style': style_path
+        }
+
+    def files_from_ID_and_time(self, ID, t):
+        """
+        Return the paths to the files for a given ID and time.
+        """
+        input_paths = [
+            os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
+            for var in self.initial_conditions_variables
+        ]
+        target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_nforce{t}.h5')
+        style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
+        return {
+            'input': input_paths,
+            'target': target_path,
+            'style': style_path
+        }
+
+
     def __getitem__(self, idx):
         from pysbmy.field import read_field_chunk_3D_periodic
         from io import BytesIO
         import torch
         from sbmy_control.low_level import stdout_redirector, stderr_redirector
         f = BytesIO()

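The two new helpers centralize the path logic that `__getitem__` previously built inline. A hypothetical usage sketch (the `root_dir` path and the `ID`/`t` values are made up; real IDs come from `dataset.ids`, and the constructor arguments are assumed):

```python
# Hypothetical values for illustration only.
dataset = GravPotDataset(root_dir='/path/to/data')
paths = dataset.files_from_ID_and_time(ID='0001', t=5)
print(paths['input'])   # one ICs_<ID>_<var>.h5 per initial-conditions variable
print(paths['target'])  # .../gravpot_0001_nforce5.h5
print(paths['style'])   # .../cosmo_and_time_parameters_0001_nforce5.txt
```

One possible follow-up: `files_from_samples` unpacks `(ID, t, ox, oy, oz)` but never uses the offsets, so it could delegate to `files_from_ID_and_time(ID, t)` and drop the duplicated body.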
@@ -170,12 +205,11 @@ class GravPotDataset(Dataset):
         style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')

         # Read 3D chunks
-        input_arrays = [
-            read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array
-            for file, varname in zip(input_paths, self.initial_conditions_variables)
-        ]
-        target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array
+        with stdout_redirector(f):
+            input_arrays = [
+                read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array
+                for file, varname in zip(input_paths, self.initial_conditions_variables)
+            ]
+            target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array

         # Stack the input arrays
         input_tensor = np.stack(input_arrays, axis=0)

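Wrapping the chunk reads in `stdout_redirector(f)` captures the console output that `read_field_chunk_3D_periodic` emits into the `BytesIO` buffer `f`, instead of letting it spam DataLoader worker logs. The real redirector lives in `sbmy_control.low_level` and is not shown in this diff; below is a minimal sketch of the idea, assuming only Python-level writes need capturing (a reader that prints from compiled code would require redirecting the OS-level file descriptor as well):

```python
import sys
from contextlib import contextmanager
from io import BytesIO

@contextmanager
def stdout_redirector(stream: BytesIO):
    """Minimal sketch: temporarily route sys.stdout into `stream`."""
    class _Writer:
        def write(self, s):
            stream.write(s.encode())  # str -> bytes for the BytesIO target
        def flush(self):
            pass

    old_stdout = sys.stdout
    sys.stdout = _Writer()
    try:
        yield
    finally:
        sys.stdout = old_stdout
```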
@@ -206,21 +240,45 @@ class GravPotDataset(Dataset):
     def on_epoch_end(self):
         """Call this at the end of each epoch to regenerate offset + time choices."""
         self._prepare_samples()


 class SubDataset(Dataset):
-    def __init__(self, dataset: GravPotDataset, indices: list):
-        self.dataset = dataset
-        self.indices = indices
+    def __init__(self, dataset: GravPotDataset, ID_list: list):
+        from copy import deepcopy
+        self.dataset = deepcopy(dataset)
+        self.ids = ID_list
+        self.dataset.ids = ID_list

     def __len__(self):
-        return len(self.indices)
+        return len(self.dataset)

     def __getitem__(self, idx):
-        return self.dataset[self.indices[idx]]
+        return self.dataset[idx]
+
+    def on_epoch_end(self):
+        self.dataset.ids = self.ids
+        self.dataset.on_epoch_end()
+
+
+def train_val_split(dataset: GravPotDataset, val_fraction: float = 0.2, seed: int = 42):
+    """
+    Splits the dataset into training and validation sets.
+
+    Parameters:
+    - dataset: The GravPotDataset to split.
+    - val_fraction: Fraction of the dataset to use for validation.
+    - seed: Random seed for a reproducible split.
+
+    Returns:
+    - train_dataset: SubDataset for training.
+    - val_dataset: SubDataset for validation.
+    """
+    from sklearn.model_selection import train_test_split
+    # Split on simulation IDs (not sample indices) so chunks from one
+    # simulation never appear in both sets.
+    train_ids, val_ids = train_test_split(dataset.ids, test_size=val_fraction, random_state=seed)
+    train_dataset = SubDataset(dataset, train_ids)
+    val_dataset = SubDataset(dataset, val_ids)
+    train_dataset.dataset._prepare_samples()
+    val_dataset.dataset._prepare_samples()
+
+    return train_dataset, val_dataset
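Taken together, `SubDataset` and `train_val_split` switch the split from sample indices to simulation IDs, so chunks of one simulation can no longer leak between training and validation, and each `SubDataset` owns a deep copy that regenerates its offsets independently. A hedged usage sketch (the `root_dir` path, batch size, and epoch count are illustrative, not from this commit):

```python
from torch.utils.data import DataLoader

dataset = GravPotDataset(root_dir='/path/to/data')  # constructor args assumed
train_set, val_set = train_val_split(dataset, val_fraction=0.2, seed=42)

train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
val_loader = DataLoader(val_set, batch_size=4, shuffle=False)

for epoch in range(10):
    for batch in train_loader:
        ...  # training step
    train_set.on_epoch_end()  # re-draws offsets/times for the train IDs only
```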