expanding dataset can now start with size larger than 1

This commit is contained in:
Mayeul Aubin 2025-07-01 15:14:54 +02:00
parent 70fcad44f5
commit d2e4453958

View file

@ -311,7 +311,7 @@ class ExpandingGravPotDataset(GravPotDataset):
A dataset that slowly increases the number of samples to help with training stability. A dataset that slowly increases the number of samples to help with training stability.
""" """
def __init__(self, batch_size: int = 32, epoch_time_scale: int = 100, epoch_reduction: float = 2, **kwargs): def __init__(self, batch_size: int = 32, epoch_time_scale: int = 100, epoch_reduction: float = 2, init_len:int = 1, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.batch_size = batch_size self.batch_size = batch_size
self.epoch_time_scale = epoch_time_scale self.epoch_time_scale = epoch_time_scale
@ -320,7 +320,8 @@ class ExpandingGravPotDataset(GravPotDataset):
self.last_epoch_expansion = 0 self.last_epoch_expansion = 0
self.n_epochs_before_expanding = epoch_time_scale self.n_epochs_before_expanding = epoch_time_scale
self.global_max_len = kwargs.get('max_len', int(1e6)) self.global_max_len = kwargs.get('max_len', int(1e6))
self.max_len = 1 self.init_len = init_len
self.max_len = init_len
def on_epoch_end(self): def on_epoch_end(self):
""" """