dataset that expands during training

Mayeul Aubin 2025-06-30 14:29:07 +02:00
parent 6dcf0313bb
commit b44c2344aa


@@ -55,6 +55,7 @@ class GravPotDataset(Dataset):
        self.initial_conditions_variables = initial_conditions_variables
        self.target_variable = target_variable
        self.match_str = match_str
        self.style_files = style_files
        self.style_keys = style_keys
        self.max_time = max_time
@@ -303,3 +304,65 @@ def train_val_split(dataset: GravPotDataset, val_fraction: float = 0.2, seed: in
    return train_dataset, val_dataset

class ExpandingGravPotDataset(GravPotDataset):
    """
    A dataset that gradually increases its number of samples to help with training stability.
    """

    def __init__(self, batch_size: int = 32, epoch_time_scale: int = 100, epoch_reduction: float = 2, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.epoch_time_scale = epoch_time_scale  # initial number of epochs between expansions
        self.epoch_reduction = epoch_reduction    # factor by which the expansion interval shrinks
        self.current_epoch = 0
        self.last_epoch_expansion = 0
        self.n_epochs_before_expanding = epoch_time_scale
        self.global_max_len = kwargs.get('max_len', int(1e6))  # hard cap on the dataset size
        self.max_len = 1  # start from a single sample and grow from there

    def on_epoch_end(self):
        """
        Call this at the end of each epoch to regenerate offset + time choices.
        Also expands the dataset size based on the current epoch.
        """
        self.current_epoch += 1
        # Expand the dataset every n_epochs_before_expanding epochs: multiply max_len
        # by batch_size (capped at global_max_len) and shrink the interval so that
        # later expansions come more frequently.
        if (self.current_epoch - self.last_epoch_expansion) >= self.n_epochs_before_expanding:
            self.max_len = min(self.global_max_len, int(self.max_len * self.batch_size))
            self.n_epochs_before_expanding = max(1, int(self.n_epochs_before_expanding / self.epoch_reduction))
            self.last_epoch_expansion = self.current_epoch
            print(f"Expanding dataset at epoch {self.current_epoch}, new max_len: {self.max_len}")
        self._prepare_samples()

def convert_GravPotDataset_to_ExpandingGravPotDataset(dataset: GravPotDataset, **kwargs):
    """
    Converts a GravPotDataset to an ExpandingGravPotDataset.
    """
    if not isinstance(dataset, GravPotDataset):
        raise TypeError("Input dataset must be an instance of GravPotDataset.")
    expanding_dataset = ExpandingGravPotDataset(
        root_dir=dataset.root_dir,
        ids=dataset.ids,
        N=dataset.N,
        N_full=dataset.N_full,
        match_str=dataset.match_str,
        max_len=dataset.max_len,
        device=dataset.device,
        initial_conditions_variables=dataset.initial_conditions_variables,
        target_variable=dataset.target_variable,
        style_files=dataset.style_files,
        style_keys=dataset.style_keys,
        max_time=dataset.max_time,
        batch_size=kwargs.get('batch_size', 32),
        epoch_time_scale=kwargs.get('epoch_time_scale', 100),
        epoch_reduction=kwargs.get('epoch_reduction', 2)
    )
    expanding_dataset._prepare_samples()  # Prepare initial samples
    return expanding_dataset
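
The expansion only takes effect if the training loop calls on_epoch_end() after every epoch. Below is a minimal usage sketch, assuming a standard PyTorch loop; the constructor arguments shown are illustrative placeholders (the real GravPotDataset configuration has more parameters, e.g. N, N_full, variables, style files), and the inner training step is left as a stub:

    from torch.utils.data import DataLoader

    # Illustrative arguments only; the real constructor requires the full
    # GravPotDataset configuration (N, N_full, variables, style files, ...).
    dataset = ExpandingGravPotDataset(
        root_dir='data/', ids=['sim_000', 'sim_001'],
        batch_size=32, epoch_time_scale=100, epoch_reduction=2,
        max_len=int(1e6),
    )
    for epoch in range(500):
        # Rebuild the loader every epoch, since len(dataset) grows as max_len expands.
        loader = DataLoader(dataset, batch_size=dataset.batch_size, shuffle=True)
        for batch in loader:
            pass  # forward/backward pass goes here
        dataset.on_epoch_end()  # regenerate samples and possibly expand max_len

With the defaults (batch_size=32, epoch_time_scale=100, epoch_reduction=2), expansions land at epochs 100, 150, 175, 187, ..., with max_len growing 32 → 1024 → 32768 and then hitting the 10^6 cap.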
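
Similarly, a minimal sketch of the converter on an already-configured base dataset (arguments are again placeholders):

    base = GravPotDataset(root_dir='data/', ids=['sim_000', 'sim_001'])
    expanding = convert_GravPotDataset_to_ExpandingGravPotDataset(
        base, batch_size=64, epoch_time_scale=50, epoch_reduction=2,
    )
    assert expanding.max_len == 1  # starts from a single sample regardless of base.max_len

Note that the base dataset's max_len becomes the cap (global_max_len) of the expanding dataset, while the starting size is always 1.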