diff --git a/sCOCA_ML/dataset/gravpot_dataset.py b/sCOCA_ML/dataset/gravpot_dataset.py index 7a26da8..9339dfe 100644 --- a/sCOCA_ML/dataset/gravpot_dataset.py +++ b/sCOCA_ML/dataset/gravpot_dataset.py @@ -311,7 +311,7 @@ class ExpandingGravPotDataset(GravPotDataset): A dataset that slowly increases the number of samples to help with training stability. """ - def __init__(self, batch_size: int = 32, epoch_time_scale: int = 100, epoch_reduction: float = 2, **kwargs): + def __init__(self, batch_size: int = 32, epoch_time_scale: int = 100, epoch_reduction: float = 2, init_len:int = 1, **kwargs): super().__init__(**kwargs) self.batch_size = batch_size self.epoch_time_scale = epoch_time_scale @@ -320,7 +320,8 @@ class ExpandingGravPotDataset(GravPotDataset): self.last_epoch_expansion = 0 self.n_epochs_before_expanding = epoch_time_scale self.global_max_len = kwargs.get('max_len', int(1e6)) - self.max_len = 1 + self.init_len = init_len + self.max_len = init_len def on_epoch_end(self): """