From b44c2344aafff7f8c08f72b36147c8c2ed56bd10 Mon Sep 17 00:00:00 2001
From: Mayeul Aubin
Date: Mon, 30 Jun 2025 14:29:07 +0200
Subject: [PATCH] Add a dataset that expands during training

---
 sCOCA_ML/dataset/gravpot_dataset.py | 63 +++++++++++++++++++++++++++++
 1 file changed, 63 insertions(+)

diff --git a/sCOCA_ML/dataset/gravpot_dataset.py b/sCOCA_ML/dataset/gravpot_dataset.py
index 2c637ec..7a26da8 100644
--- a/sCOCA_ML/dataset/gravpot_dataset.py
+++ b/sCOCA_ML/dataset/gravpot_dataset.py
@@ -55,6 +55,7 @@ class GravPotDataset(Dataset):
 
         self.initial_conditions_variables = initial_conditions_variables
         self.target_variable = target_variable
+        self.match_str = match_str
         self.style_files = style_files
         self.style_keys = style_keys
         self.max_time = max_time
@@ -303,3 +304,65 @@ def train_val_split(dataset: GravPotDataset, val_fraction: float = 0.2, seed: in
 
 
     return train_dataset, val_dataset
+
+
+class ExpandingGravPotDataset(GravPotDataset):
+    """
+    A dataset that gradually increases the number of available samples to help with training stability.
+    """
+
+    def __init__(self, batch_size: int = 32, epoch_time_scale: int = 100, epoch_reduction: float = 2, **kwargs):
+        super().__init__(**kwargs)
+        self.batch_size = batch_size
+        self.epoch_time_scale = epoch_time_scale
+        self.epoch_reduction = epoch_reduction
+        self.current_epoch = 0
+        self.last_epoch_expansion = 0
+        self.n_epochs_before_expanding = epoch_time_scale
+        self.global_max_len = kwargs.get('max_len', int(1e6))  # hard upper bound on the dataset size
+        self.max_len = 1  # start tiny; on_epoch_end() grows this on a schedule
+
+    def on_epoch_end(self):
+        """
+        Call this at the end of each epoch to regenerate offset + time choices.
+        Also expands the dataset size based on the current epoch.
+        """
+        self.current_epoch += 1
+
+        # Expand the dataset every n_epochs_before_expanding epochs, shrinking that interval each time
+        if (self.current_epoch - self.last_epoch_expansion) >= self.n_epochs_before_expanding:
+            self.n_epochs_before_expanding = max(1, int(self.n_epochs_before_expanding / self.epoch_reduction))
+            self.last_epoch_expansion = self.current_epoch
+            self.max_len = int(self.max_len * self.batch_size)
+            self.max_len = min(self.global_max_len, self.max_len)  # clamp after growing so the bound actually holds
+            print(f"Expanding dataset at epoch {self.current_epoch}, new max_len: {self.max_len}")
+        self._prepare_samples()
+
+
+
+def convert_GravPotDataset_to_ExpandingGravPotDataset(dataset: GravPotDataset, **kwargs):
+    """
+    Converts a GravPotDataset to an ExpandingGravPotDataset.
+    """
+    if not isinstance(dataset, GravPotDataset):
+        raise TypeError("Input dataset must be an instance of GravPotDataset.")
+
+    dataset = ExpandingGravPotDataset(
+        root_dir=dataset.root_dir,
+        ids=dataset.ids,
+        N=dataset.N,
+        N_full=dataset.N_full,
+        match_str=dataset.match_str,
+        max_len=dataset.max_len,
+        device=dataset.device,
+        initial_conditions_variables=dataset.initial_conditions_variables,
+        target_variable=dataset.target_variable,
+        style_files=dataset.style_files,
+        style_keys=dataset.style_keys,
+        max_time=dataset.max_time,
+        batch_size=kwargs.get('batch_size', 32),
+        epoch_time_scale=kwargs.get('epoch_time_scale', 100),
+        epoch_reduction=kwargs.get('epoch_reduction', 2)
+    )
+    dataset._prepare_samples()  # Prepare initial samples
+    return dataset
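
Note (editor's sketch, not part of the patch): the expansion only takes effect if
the training loop calls on_epoch_end() once per epoch. A minimal sketch of such a
loop follows; the root_dir and ids values, the epoch count, and the loop body are
placeholders, and only the ExpandingGravPotDataset calls come from the patch:

    from torch.utils.data import DataLoader

    # Placeholder setup: "data/" and ids stand in for real paths and simulation IDs.
    dataset = ExpandingGravPotDataset(root_dir="data/", ids=ids, batch_size=32,
                                      epoch_time_scale=100, epoch_reduction=2)
    dataset._prepare_samples()

    for epoch in range(300):
        # Rebuild the loader every epoch, since len(dataset) grows over time.
        loader = DataLoader(dataset, batch_size=32, shuffle=True)
        for batch in loader:
            ...  # forward/backward pass goes here
        dataset.on_epoch_end()  # regenerates samples and, on schedule, expands max_len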
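For intuition, with the defaults (batch_size=32, epoch_time_scale=100,
epoch_reduction=2) and a global max_len of 1e6, the schedule works out as follows:
max_len stays at 1 until epoch 100, then jumps to 32 and the wait shrinks to 50
epochs; at epoch 150 max_len becomes 1024, at 175 it is 32768, and at 187 the
multiplication would give 1048576, which the clamp cuts to 1000000. From then on
each on_epoch_end() call only regenerates the offset + time choices.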