-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Group split kfold #484
Group split kfold #484
Conversation
…into group_split_kfold
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. Thanks, mostly remarks with respect to testing.
assert model._check_valid_nfolds(20, 10) == 10 | ||
|
||
# Test invalid folds | ||
with pytest.raises(ValueError): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also test that the correct errors are thrown here? you can do this via the match
keyword.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also for the cases below.
if group_split_column is not None: | ||
# check if the group split column is present in the experiments | ||
if group_split_column not in experiments.columns: | ||
raise ValueError( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also test these errors, so that they are throw correctly?
@jkeupp: Do you keep this in mind, so that we can integrate it soon? Best, Johannes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me. Thanks!
* add group kfold option in cross_validate of any traainable surrogate * changed to GroupShuffleSplit, added test case * improve docstring & add some inline comments in test * refactor cross_validate & add tests * imrpve tests, remove unnecessary case while checking group split col * add push * formatting --------- Co-authored-by: Jim Boelrijk Valcon <[email protected]>
This pull request introduces the GroupShuffleSplit functionality to the cross_validate method in bofire/surrogates/trainable.py. It adds a new parameter group_split_column to ensure that the splits are made such that the same group is not present in both training and testing sets.
Additionally, corresponding tests have been added in tests/bofire/surrogates/test_cross_validate.py to validate the new functionality.
For now this only works for stand-alone cross_validation calls of the surrogate. Next step, which is to use it during hyperparameter optimization will be part of a future PR.