dataset improvement

This commit is contained in:
Mayeul Aubin 2025-06-05 17:30:38 +02:00
parent 26af105195
commit 24c2d546db

View file

@ -7,6 +7,7 @@ from glob import glob
import re import re
def read_cosmo_and_time_file(cosmo_and_time_file): def read_cosmo_and_time_file(cosmo_and_time_file):
with open(cosmo_and_time_file, 'r') as f: with open(cosmo_and_time_file, 'r') as f:
lines = f.readlines() lines = f.readlines()
@ -154,6 +155,10 @@ class GravPotDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
from pysbmy.field import read_field_chunk_3D_periodic from pysbmy.field import read_field_chunk_3D_periodic
from io import BytesIO
from sbmy_control.low_level import stdout_redirector, stderr_redirector
f = BytesIO()
ID, t, ox, oy, oz = self.samples[idx] ID, t, ox, oy, oz = self.samples[idx]
# Filepaths # Filepaths
@ -165,11 +170,12 @@ class GravPotDataset(Dataset):
style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt') style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
# Read 3D chunks # Read 3D chunks
input_arrays = [ with stdout_redirector(f):
read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array input_arrays = [
for file, varname in zip(input_paths, self.initial_conditions_variables) read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array
] for file, varname in zip(input_paths, self.initial_conditions_variables)
target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array ]
target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array
# Stack the input arrays # Stack the input arrays
input_tensor = np.stack(input_arrays, axis=0) input_tensor = np.stack(input_arrays, axis=0)