bugfix
This commit is contained in:
parent
36f38ef256
commit
2456915038
1 changed files with 1 additions and 1 deletions
|
@ -285,7 +285,7 @@ def train_val_split(dataset: GravPotDataset, val_fraction: float = 0.2, seed: in
|
|||
- val_dataset: SubDataset for validation.
|
||||
"""
|
||||
from sklearn.model_selection import train_test_split
|
||||
train_ids, val_ids = train_test_split(dataset.ids, test_size=0.2, random_state=seed)
|
||||
train_ids, val_ids = train_test_split(dataset.ids, test_size=val_fraction, random_state=seed)
|
||||
train_dataset = SubDataset(dataset, train_ids)
|
||||
val_dataset = SubDataset(dataset, val_ids)
|
||||
train_dataset.dataset._prepare_samples()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue