Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Made CrossValTypes, HoldoutValTypes to have split functions directly #108

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Prev Previous commit
Next Next commit
[fix] Fix pytest errors
  • Loading branch information
nabenabe0928 committed May 19, 2021
commit bef43231a592a5527f123c9da37394fb16c3844f
4 changes: 3 additions & 1 deletion autoPyTorch/datasets/resampling_strategy.py
Original file line number Diff line number Diff line change
@@ -32,9 +32,11 @@ def holdout_validation(
labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
) -> List[Tuple[np.ndarray, np.ndarray]]:

""" SKLearn requires shuffle=True for stratify """
train, val = train_test_split(
indices, test_size=val_share,
shuffle=shuffle, random_state=random_state,
shuffle=shuffle if labels_to_stratify is None else True,
random_state=random_state,
stratify=labels_to_stratify
)
return [(train, val)]