dataset that expands during training
This commit is contained in:
parent 6dcf0313bb
commit b44c2344aa
1 changed file with 63 additions and 0 deletions
@@ -55,6 +55,7 @@ class GravPotDataset(Dataset):
        self.initial_conditions_variables = initial_conditions_variables
        self.target_variable = target_variable
        self.match_str = match_str
        self.style_files = style_files
        self.style_keys = style_keys
        self.max_time = max_time

@@ -303,3 +304,65 @@ def train_val_split(dataset: GravPotDataset, val_fraction: float = 0.2, seed: int
    return train_dataset, val_dataset

class ExpandingGravPotDataset(GravPotDataset):
    """
    A dataset that slowly increases the number of samples to help with training stability.
    """

    def __init__(self, batch_size: int = 32, epoch_time_scale: int = 100, epoch_reduction: float = 2, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.epoch_time_scale = epoch_time_scale
        self.epoch_reduction = epoch_reduction
        self.current_epoch = 0
        self.last_epoch_expansion = 0
        self.n_epochs_before_expanding = epoch_time_scale
        # Keep the user-supplied max_len as the hard cap and start small:
        # the dataset grows back toward this cap as training progresses.
        self.global_max_len = kwargs.get('max_len', int(1e6))
        self.max_len = 1

    def on_epoch_end(self):
        """
        Call this at the end of each epoch to regenerate offset + time choices.
        Also expands the dataset size based on the current epoch.
        """
        self.current_epoch += 1

        # Expand dataset size every n_epochs_before_expanding epochs
        if (self.current_epoch - self.last_epoch_expansion) >= self.n_epochs_before_expanding:
            # Shrink the waiting period geometrically, keeping at least one
            # epoch between expansions.
            self.n_epochs_before_expanding = max(1, int(self.n_epochs_before_expanding / self.epoch_reduction))
            self.last_epoch_expansion = self.current_epoch
            # Grow max_len by a factor of batch_size, clamping *after* the
            # multiplication so it never exceeds the global cap.
            self.max_len = min(self.global_max_len, int(self.max_len * self.batch_size))
            print(f"Expanding dataset at epoch {self.current_epoch}, new max_len: {self.max_len}")
        self._prepare_samples()

def convert_GravPotDataset_to_ExpandingGravPotDataset(dataset: GravPotDataset, **kwargs):
    """
    Converts a GravPotDataset to an ExpandingGravPotDataset.
    """
    if not isinstance(dataset, GravPotDataset):
        raise TypeError("Input dataset must be an instance of GravPotDataset.")

    dataset = ExpandingGravPotDataset(
        root_dir=dataset.root_dir,
        ids=dataset.ids,
        N=dataset.N,
        N_full=dataset.N_full,
        match_str=dataset.match_str,
        max_len=dataset.max_len,
        device=dataset.device,
        initial_conditions_variables=dataset.initial_conditions_variables,
        target_variable=dataset.target_variable,
        style_files=dataset.style_files,
        style_keys=dataset.style_keys,
        max_time=dataset.max_time,
        batch_size=kwargs.get('batch_size', 32),
        epoch_time_scale=kwargs.get('epoch_time_scale', 100),
        epoch_reduction=kwargs.get('epoch_reduction', 2)
    )
    dataset._prepare_samples()  # Prepare initial samples
    return dataset
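
With the defaults above (batch_size=32, epoch_time_scale=100, epoch_reduction=2), the schedule is geometric in both directions: max_len multiplies by batch_size at each expansion while the wait between expansions halves. A minimal standalone sketch of the resulting schedule, mirroring the on_epoch_end logic rather than importing the class:

    # Assumed defaults: batch_size=32, epoch_time_scale=100,
    # epoch_reduction=2, global_max_len=1e6.
    batch_size, epoch_reduction = 32, 2
    global_max_len = int(1e6)
    max_len, n_before, last = 1, 100, 0
    for epoch in range(1, 301):
        if (epoch - last) >= n_before:
            n_before = max(1, int(n_before / epoch_reduction))
            last = epoch
            max_len = min(global_max_len, int(max_len * batch_size))
            print(f"epoch {epoch}: max_len -> {max_len}")
    # epoch 100: max_len -> 32
    # epoch 150: max_len -> 1024
    # epoch 175: max_len -> 32768
    # epoch 187: max_len -> 1000000 (capped at global_max_len)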
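Since on_epoch_end must be called manually, a typical training loop would look like the hedged sketch below. It assumes a standard PyTorch setup; dataset, n_epochs, and the training step are hypothetical placeholders, not part of this commit.

    from torch.utils.data import DataLoader

    # Wrap an existing GravPotDataset via the converter from this commit.
    expanding = convert_GravPotDataset_to_ExpandingGravPotDataset(
        dataset, batch_size=32, epoch_time_scale=100, epoch_reduction=2
    )

    for epoch in range(n_epochs):
        # Rebuild the loader each epoch, since the dataset length grows over time.
        loader = DataLoader(expanding, batch_size=32, shuffle=True)
        for batch in loader:
            ...  # forward/backward pass (placeholder)
        expanding.on_epoch_end()  # regenerate samples and possibly expand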