improvements
This commit is contained in:
parent
6c526d7115
commit
58f9b27e6e
3 changed files with 153 additions and 7 deletions
|
@ -186,16 +186,18 @@ class GravPotDataset(Dataset):
|
|||
'style': style_path
|
||||
}
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
def get_data(self, ID, t, ox, oy, oz):
|
||||
"""
|
||||
Get the data for a specific ID, time, and offsets.
|
||||
Returns a dictionary with input, target, and style tensors.
|
||||
"""
|
||||
|
||||
from pysbmy.field import read_field_chunk_3D_periodic
|
||||
from io import BytesIO
|
||||
import torch
|
||||
from sbmy_control.low_level import stdout_redirector, stderr_redirector
|
||||
f = BytesIO()
|
||||
|
||||
ID, t, ox, oy, oz = self.samples[idx]
|
||||
|
||||
# Filepaths
|
||||
input_paths = [
|
||||
os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
|
||||
|
@ -236,6 +238,15 @@ class GravPotDataset(Dataset):
|
|||
'time': t,
|
||||
'offset': (ox, oy, oz)
|
||||
}
|
||||
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
|
||||
ID, t, ox, oy, oz = self.samples[idx]
|
||||
return self.get_data(ID, t, ox, oy, oz)
|
||||
|
||||
|
||||
|
||||
def on_epoch_end(self):
|
||||
"""Call this at the end of each epoch to regenerate offset + time choices."""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue