diff --git a/sCOCA_ML/dataset/gravpot_dataset.py b/sCOCA_ML/dataset/gravpot_dataset.py index 7315f57..c380cda 100644 --- a/sCOCA_ML/dataset/gravpot_dataset.py +++ b/sCOCA_ML/dataset/gravpot_dataset.py @@ -285,7 +285,7 @@ def train_val_split(dataset: GravPotDataset, val_fraction: float = 0.2, seed: in - val_dataset: SubDataset for validation. """ from sklearn.model_selection import train_test_split - train_ids, val_ids = train_test_split(dataset.ids, test_size=0.2, random_state=seed) + train_ids, val_ids = train_test_split(dataset.ids, test_size=val_fraction, random_state=seed) train_dataset = SubDataset(dataset, train_ids) val_dataset = SubDataset(dataset, val_ids) train_dataset.dataset._prepare_samples()