From 24c2d546dbe24931c0921ee263ae28f6a90cdd8c Mon Sep 17 00:00:00 2001 From: Mayeul Aubin Date: Thu, 5 Jun 2025 17:30:38 +0200 Subject: [PATCH] dataset improvement --- sCOCA_ML/dataset/gravpot_dataset.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/sCOCA_ML/dataset/gravpot_dataset.py b/sCOCA_ML/dataset/gravpot_dataset.py index 37b75ae..d369801 100644 --- a/sCOCA_ML/dataset/gravpot_dataset.py +++ b/sCOCA_ML/dataset/gravpot_dataset.py @@ -7,6 +7,7 @@ from glob import glob import re + def read_cosmo_and_time_file(cosmo_and_time_file): with open(cosmo_and_time_file, 'r') as f: lines = f.readlines() @@ -154,6 +155,10 @@ class GravPotDataset(Dataset): def __getitem__(self, idx): from pysbmy.field import read_field_chunk_3D_periodic + from io import BytesIO + from sbmy_control.low_level import stdout_redirector, stderr_redirector + f = BytesIO() + ID, t, ox, oy, oz = self.samples[idx] # Filepaths @@ -165,11 +170,12 @@ class GravPotDataset(Dataset): style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt') # Read 3D chunks - input_arrays = [ - read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array - for file, varname in zip(input_paths, self.initial_conditions_variables) - ] - target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array + with stdout_redirector(f): + input_arrays = [ + read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array + for file, varname in zip(input_paths, self.initial_conditions_variables) + ] + target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array # Stack the input arrays input_tensor = np.stack(input_arrays, axis=0)