expanding dataset can now start with size larger than 1
This commit is contained in:
parent
70fcad44f5
commit
d2e4453958
1 changed files with 3 additions and 2 deletions
|
@ -311,7 +311,7 @@ class ExpandingGravPotDataset(GravPotDataset):
|
||||||
A dataset that slowly increases the number of samples to help with training stability.
|
A dataset that slowly increases the number of samples to help with training stability.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, batch_size: int = 32, epoch_time_scale: int = 100, epoch_reduction: float = 2, **kwargs):
|
def __init__(self, batch_size: int = 32, epoch_time_scale: int = 100, epoch_reduction: float = 2, init_len:int = 1, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.epoch_time_scale = epoch_time_scale
|
self.epoch_time_scale = epoch_time_scale
|
||||||
|
@ -320,7 +320,8 @@ class ExpandingGravPotDataset(GravPotDataset):
|
||||||
self.last_epoch_expansion = 0
|
self.last_epoch_expansion = 0
|
||||||
self.n_epochs_before_expanding = epoch_time_scale
|
self.n_epochs_before_expanding = epoch_time_scale
|
||||||
self.global_max_len = kwargs.get('max_len', int(1e6))
|
self.global_max_len = kwargs.get('max_len', int(1e6))
|
||||||
self.max_len = 1
|
self.init_len = init_len
|
||||||
|
self.max_len = init_len
|
||||||
|
|
||||||
def on_epoch_end(self):
|
def on_epoch_end(self):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue