From d2e4453958ba3c28802d5d2f1009fc45d198439b Mon Sep 17 00:00:00 2001 From: Mayeul Aubin Date: Tue, 1 Jul 2025 15:14:54 +0200 Subject: [PATCH] expanding dataset can now start with size larger than 1 --- sCOCA_ML/dataset/gravpot_dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sCOCA_ML/dataset/gravpot_dataset.py b/sCOCA_ML/dataset/gravpot_dataset.py index 7a26da8..9339dfe 100644 --- a/sCOCA_ML/dataset/gravpot_dataset.py +++ b/sCOCA_ML/dataset/gravpot_dataset.py @@ -311,7 +311,7 @@ class ExpandingGravPotDataset(GravPotDataset): A dataset that slowly increases the number of samples to help with training stability. """ - def __init__(self, batch_size: int = 32, epoch_time_scale: int = 100, epoch_reduction: float = 2, **kwargs): + def __init__(self, batch_size: int = 32, epoch_time_scale: int = 100, epoch_reduction: float = 2, init_len:int = 1, **kwargs): super().__init__(**kwargs) self.batch_size = batch_size self.epoch_time_scale = epoch_time_scale @@ -320,7 +320,8 @@ class ExpandingGravPotDataset(GravPotDataset): self.last_epoch_expansion = 0 self.n_epochs_before_expanding = epoch_time_scale self.global_max_len = kwargs.get('max_len', int(1e6)) - self.max_len = 1 + self.init_len = init_len + self.max_len = init_len def on_epoch_end(self): """