Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Enable Grid-Search for TableVectorizer #814

Merged

Conversation

Vincent-Maladiere
Copy link
Member

@Vincent-Maladiere Vincent-Maladiere commented Nov 2, 2023

What does this PR fix/address?

Apply Gaël's suggestions and the outputs of discussion #796 to make grid-search possible.

from skrub import TableVectorizer
from skrub.datasets import fetch_employee_salaries

from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.ensemble import HistGradientBoostingRegressor

dataset = fetch_employee_salaries()
X, y = dataset.X.head(500), dataset.y.head(500)

pipe = make_pipeline(
    TableVectorizer(),
    HistGradientBoostingRegressor()
)

param_grid = {"tablevectorizer__high_cardinality_transformer__n_components": [10, 20]}

cv = GridSearchCV(pipe, param_grid)
cv.fit(X, y)

What does it change?

  • Replace the default None of the transformers with their global default
  • When the user actively sets transformers to None, they are turned into "passthrough" during fit (e.g. high_cardinality_encoder = None will result in high_cardinality_encoder_ = passthrough
  • Some revamp to make cloning during init and during fit more readable

@Vincent-Maladiere Vincent-Maladiere changed the title [ENH] Apply global sub-estimator default parameter [ENH] Enable Grid-Search for TableVectorizer Nov 2, 2023
@Vincent-Maladiere
Copy link
Member Author

I need to address a small docstring error

skrub/_table_vectorizer.py Outdated Show resolved Hide resolved
skrub/_table_vectorizer.py Show resolved Hide resolved
skrub/_table_vectorizer.py Outdated Show resolved Hide resolved
Copy link
Member

@jeromedockes jeromedockes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

from sklearn.utils.validation import check_is_fitted

from skrub import DatetimeEncoder, GapEncoder
from skrub._utils import parse_astype_error_message

HIGH_CARDINALITY_TRANSFORMER = GapEncoder(n_components=30)
LOW_CARDINALITY_TRANSFORMER = OneHotEncoder(
sparse_output=False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we expose the ColumnTransformer's sparse_threshold parameter but with our default transformers the default will always be dense (even if toe onehot encoder yields many zeros)

we could consider

  • pointing out in the doc that users need to change the transformers if they want sparse output
  • not exposing the sparse_threshold and always returning dense data
  • making the onehot encoder sparse by default

(not in this PR)

("numeric", self.numerical_transformer_, numeric_columns),
("datetime", self.datetime_transformer_, datetime_columns),
("low_card_cat", self.low_card_cat_transformer_, low_card_cat_columns),
("high_card_cat", self.high_card_cat_transformer_, high_card_cat_columns),
("low_card_cat", self.low_cardinality_transformer_, low_card_cat_columns),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not super important but you could propagate the change 'card_cat' -> 'cardinality' to local variables

along that line it would be nice if we picked one of "numeric" or "numerical" and used it all the time :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numeric, since it's shorter?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good! it's also the choice made by polars.selectors.numeric. pandas select_dtypes uses "number", "category"

@jeromedockes
Copy link
Member

I guess this one is ready to merge?

@Vincent-Maladiere
Copy link
Member Author

I think so!

@jeromedockes jeromedockes merged commit 4b11e62 into skrub-data:main Nov 9, 2023
21 checks passed
@Vincent-Maladiere Vincent-Maladiere deleted the make_tv_grid_searchable branch November 9, 2023 16:34
@GaelVaroquaux
Copy link
Member

Very nice. Congratulations!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants