dataset improvement
This commit is contained in:
parent
26af105195
commit
24c2d546db
1 changed files with 11 additions and 5 deletions
|
@ -7,6 +7,7 @@ from glob import glob
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def read_cosmo_and_time_file(cosmo_and_time_file):
|
def read_cosmo_and_time_file(cosmo_and_time_file):
|
||||||
with open(cosmo_and_time_file, 'r') as f:
|
with open(cosmo_and_time_file, 'r') as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
|
@ -154,6 +155,10 @@ class GravPotDataset(Dataset):
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
from pysbmy.field import read_field_chunk_3D_periodic
|
from pysbmy.field import read_field_chunk_3D_periodic
|
||||||
|
from io import BytesIO
|
||||||
|
from sbmy_control.low_level import stdout_redirector, stderr_redirector
|
||||||
|
f = BytesIO()
|
||||||
|
|
||||||
ID, t, ox, oy, oz = self.samples[idx]
|
ID, t, ox, oy, oz = self.samples[idx]
|
||||||
|
|
||||||
# Filepaths
|
# Filepaths
|
||||||
|
@ -165,11 +170,12 @@ class GravPotDataset(Dataset):
|
||||||
style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
|
style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
|
||||||
|
|
||||||
# Read 3D chunks
|
# Read 3D chunks
|
||||||
input_arrays = [
|
with stdout_redirector(f):
|
||||||
read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array
|
input_arrays = [
|
||||||
for file, varname in zip(input_paths, self.initial_conditions_variables)
|
read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array
|
||||||
]
|
for file, varname in zip(input_paths, self.initial_conditions_variables)
|
||||||
target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array
|
]
|
||||||
|
target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array
|
||||||
|
|
||||||
# Stack the input arrays
|
# Stack the input arrays
|
||||||
input_tensor = np.stack(input_arrays, axis=0)
|
input_tensor = np.stack(input_arrays, axis=0)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue