This commit is contained in:
Mayeul Aubin 2025-06-25 17:11:50 +02:00
parent 36f38ef256
commit 2456915038

View file

@ -285,7 +285,7 @@ def train_val_split(dataset: GravPotDataset, val_fraction: float = 0.2, seed: in
- val_dataset: SubDataset for validation.
"""
from sklearn.model_selection import train_test_split
train_ids, val_ids = train_test_split(dataset.ids, test_size=0.2, random_state=seed)
train_ids, val_ids = train_test_split(dataset.ids, test_size=val_fraction, random_state=seed)
train_dataset = SubDataset(dataset, train_ids)
val_dataset = SubDataset(dataset, val_ids)
train_dataset.dataset._prepare_samples()