Group split kfold #484

Merged
merged 9 commits into from
Jan 14, 2025

@jkeupp jkeupp commented Dec 16, 2024

This pull request introduces the GroupShuffleSplit functionality to the cross_validate method in bofire/surrogates/trainable.py. It adds a new parameter group_split_column to ensure that the splits are made such that the same group is not present in both training and testing sets.

Additionally, corresponding tests have been added in tests/bofire/surrogates/test_cross_validate.py to validate the new functionality.

For now this only works for stand-alone cross-validation calls of the surrogate. The next step, using it during hyperparameter optimization, will be part of a future PR.
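To illustrate the guarantee described above (no group appears in both the training and the testing set), here is a minimal plain-Python sketch of a group-wise shuffle split. It is not the code from this PR and does not call `GroupShuffleSplit`; the function name `group_shuffle_split`, its parameters, and the batch labels are assumptions made up for illustration.

```python
import random
from collections import defaultdict


def group_shuffle_split(groups, test_fraction=0.25, seed=0):
    """Split row indices so that each group lands entirely in train or test.

    Hypothetical helper for illustration only; not scikit-learn's or
    BoFire's implementation.
    """
    rows_by_group = defaultdict(list)
    for row, group in enumerate(groups):
        rows_by_group[group].append(row)

    # Shuffle the unique groups (not the rows), then hold out a share of them.
    unique_groups = sorted(rows_by_group)
    random.Random(seed).shuffle(unique_groups)
    n_test = max(1, round(test_fraction * len(unique_groups)))
    test_groups = set(unique_groups[:n_test])

    train_rows, test_rows = [], []
    for group, rows in rows_by_group.items():
        (test_rows if group in test_groups else train_rows).extend(rows)
    return sorted(train_rows), sorted(test_rows)


# Hypothetical data: eight experiments belonging to four batches.
batches = ["A", "A", "B", "B", "C", "C", "D", "D"]
train_rows, test_rows = group_shuffle_split(batches, test_fraction=0.25, seed=0)

train_batches = {batches[r] for r in train_rows}
test_batches = {batches[r] for r in test_rows}
assert train_batches.isdisjoint(test_batches)  # no batch is on both sides
```

Because whole groups are assigned to one side, the test fraction applies to the number of groups rather than the number of rows, so uneven group sizes give uneven row counts.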

@jduerholt jduerholt left a comment

Looks good. Thanks, mostly remarks with respect to testing.

assert model._check_valid_nfolds(20, 10) == 10

# Test invalid folds
with pytest.raises(ValueError):
Can you also test that the correct errors are thrown here? You can do this via the match keyword.

Also for the cases below.
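As a sketch of what this request could look like: `pytest.raises` accepts a `match` argument, a regular expression that is searched in the string form of the raised exception. The function `check_valid_nfolds` below is a hypothetical stand-in for the method under test, and its error messages are invented for illustration; they are not taken from BoFire.

```python
import pytest


def check_valid_nfolds(folds: int, n_samples: int) -> int:
    """Hypothetical stand-in for the surrogate's fold check (not BoFire's code)."""
    if n_samples == 0:
        raise ValueError("Experiments is empty.")
    if folds < 2:
        raise ValueError("Folds must be greater than 1.")
    # Cap the number of folds at the number of available samples.
    return min(folds, n_samples)


def test_check_valid_nfolds_errors():
    # The capped case mirrors the assertion shown in the diff context.
    assert check_valid_nfolds(20, 10) == 10

    # `match` pins each branch to its own message, so the test no longer
    # passes on an unrelated ValueError.
    with pytest.raises(ValueError, match="Experiments is empty"):
        check_valid_nfolds(5, 0)
    with pytest.raises(ValueError, match="greater than 1"):
        check_valid_nfolds(1, 10)
```

Since `match` is treated as a regular expression, messages containing characters such as parentheses or dots may need `re.escape` to be matched literally.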

if group_split_column is not None:
    # check if the group split column is present in the experiments
    if group_split_column not in experiments.columns:
        raise ValueError(
Can you also test these errors, so that they are thrown correctly?

@jduerholt
@jkeupp: Could you keep this in mind, so that we can integrate it soon? Best, Johannes

@jduerholt jduerholt left a comment

Looks good to me. Thanks!

@jduerholt jduerholt merged commit 4b6d7bf into main Jan 14, 2025
6 of 9 checks passed
@jduerholt jduerholt deleted the group_split_kfold branch January 14, 2025 07:28
dlinzner-bcs pushed a commit that referenced this pull request Jan 20, 2025
* add group kfold option in cross_validate of any trainable surrogate

* changed to GroupShuffleSplit, added test case

* improve docstring & add some inline comments in test

* refactor cross_validate & add tests

* improve tests, remove unnecessary case while checking group split col

* add push

* formatting

---------

Co-authored-by: Jim Boelrijk Valcon <[email protected]>