expansion

This commit is contained in:
Mayeul Aubin 2025-06-06 13:52:17 +02:00
parent 24c2d546db
commit c07ec8f8cf
6 changed files with 138 additions and 1 deletions

View file

@ -206,3 +206,21 @@ class GravPotDataset(Dataset):
def on_epoch_end(self):
"""Call this at the end of each epoch to regenerate offset + time choices."""
self._prepare_samples()
class SubDataset(Dataset):
def __init__(self, dataset: GravPotDataset, indices: list):
self.dataset = dataset
self.indices = indices
def __len__(self):
return len(self.indices)
def __getitem__(self, idx):
return self.dataset[self.indices[idx]]
def on_epoch_end(self):
self.dataset.on_epoch_end()