improvements

This commit is contained in:
Mayeul Aubin 2025-06-24 09:26:26 +02:00
parent 6c526d7115
commit 58f9b27e6e
3 changed files with 153 additions and 7 deletions

View file

@ -186,16 +186,18 @@ class GravPotDataset(Dataset):
'style': style_path
}
def __getitem__(self, idx):
def get_data(self, ID, t, ox, oy, oz):
"""
Get the data for a specific ID, time, and offsets.
Returns a dictionary with input, target, and style tensors.
"""
from pysbmy.field import read_field_chunk_3D_periodic
from io import BytesIO
import torch
from sbmy_control.low_level import stdout_redirector, stderr_redirector
f = BytesIO()
ID, t, ox, oy, oz = self.samples[idx]
# Filepaths
input_paths = [
os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
@ -236,6 +238,15 @@ class GravPotDataset(Dataset):
'time': t,
'offset': (ox, oy, oz)
}
def __getitem__(self, idx):
ID, t, ox, oy, oz = self.samples[idx]
return self.get_data(ID, t, ox, oy, oz)
def on_epoch_end(self):
"""Call this at the end of each epoch to regenerate offset + time choices."""