Skip to content

Commit

Permalink
linter
Browse files Browse the repository at this point in the history
  • Loading branch information
mrastgoo committed Dec 6, 2024
1 parent 4ac84e4 commit 3c364be
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
21 changes: 11 additions & 10 deletions skrub/_tabular_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ def tabular_learner(estimator, *, n_jobs=None):
Parameters
----------
estimator : {"regressor", "classifier"} or scikit-learn estimator
estimator : {"regressor", "regression", "classifier", "classification"} or scikit-learn estimator
The estimator to use as the final step in the pipeline. Based on the type of
estimator, the previous preprocessing steps and their respective parameters are
chosen. The possible values are:
- ``'regressor'``: a :obj:`~sklearn.ensemble.HistGradientBoostingRegressor`
- ``'regressor'`` or ``'regression'``: a :obj:`~sklearn.ensemble.HistGradientBoostingRegressor`
is used as the final step;
- ``'classifier'``: a :obj:`~sklearn.ensemble.HistGradientBoostingClassifier`
- ``'classifier'`` or ``'classification'``: a :obj:`~sklearn.ensemble.HistGradientBoostingClassifier`
is used as the final step;
- a scikit-learn estimator: the provided estimator is used as the final step.
Expand Down Expand Up @@ -106,24 +106,24 @@ def tabular_learner(estimator, *, n_jobs=None):
We can easily get a default pipeline for regression or classification:
>>> tabular_learner('regressor') # doctest: +SKIP
>>> tabular_learner('regression') # doctest: +SKIP
Pipeline(steps=[('tablevectorizer',
TableVectorizer(high_cardinality=MinHashEncoder(),
low_cardinality=ToCategorical())),
('histgradientboostingregressor',
HistGradientBoostingRegressor(categorical_features='from_dtype'))])
When requesting a ``'regressor'``, the last step of the pipeline is set to a
When requesting a ``'regression'``, the last step of the pipeline is set to a
:obj:`~sklearn.ensemble.HistGradientBoostingRegressor`.
>>> tabular_learner('classifier') # doctest: +SKIP
>>> tabular_learner('classification') # doctest: +SKIP
Pipeline(steps=[('tablevectorizer',
TableVectorizer(high_cardinality=MinHashEncoder(),
low_cardinality=ToCategorical())),
('histgradientboostingclassifier',
HistGradientBoostingClassifier(categorical_features='from_dtype'))])
When requesting a ``'classifier'``, the last step of the pipeline is set to a
When requesting a ``'classification'``, the last step of the pipeline is set to a
:obj:`~sklearn.ensemble.HistGradientBoostingClassifier`.
This pipeline can be applied to rich tabular data:
Expand Down Expand Up @@ -227,18 +227,19 @@ def tabular_learner(estimator, *, n_jobs=None):
cat_feat_kwargs = {"categorical_features": "from_dtype"}

if isinstance(estimator, str):
if estimator == "classifier":
if estimator in ("classifier", "classification"):
return tabular_learner(
ensemble.HistGradientBoostingClassifier(**cat_feat_kwargs),
n_jobs=n_jobs,
)
if estimator == "regressor":
if estimator in ("regressor", "regression"):
return tabular_learner(
ensemble.HistGradientBoostingRegressor(**cat_feat_kwargs),
n_jobs=n_jobs,
)
raise ValueError(
"If ``estimator`` is a string it should be 'regressor' or 'classifier'."
"If ``estimator`` is a string it should be 'regressor', 'regression',"
" 'classifier' or 'classification'."
)
if isinstance(estimator, type) and issubclass(estimator, BaseEstimator):
raise TypeError(
Expand Down
11 changes: 8 additions & 3 deletions skrub/tests/test_tabular_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
)


@pytest.mark.parametrize("learner_kind", ["regressor", "classifier"])
@pytest.mark.parametrize(
"learner_kind", ["regressor", "regression", "classifier", "classification"]
)
def test_default_pipeline(learner_kind):
p = tabular_learner(learner_kind)
tv, learner = [e for _, e in p.steps]
Expand All @@ -26,14 +28,17 @@ def test_default_pipeline(learner_kind):
else:
assert isinstance(tv.low_cardinality, ToCategorical)
assert learner.categorical_features == "from_dtype"
if learner_kind == "regressor":
if learner_kind in ("regressor", "regression"):
assert isinstance(learner, ensemble.HistGradientBoostingRegressor)
else:
assert isinstance(learner, ensemble.HistGradientBoostingClassifier)


def test_bad_learner():
with pytest.raises(ValueError, match=".*should be 'regressor' or 'classifier'"):
with pytest.raises(
ValueError,
match=".*should be 'regressor', 'regression', 'classifier' or 'classification'",
):
tabular_learner("bad")
with pytest.raises(
TypeError, match=".*Pass an instance of HistGradientBoostingRegressor"
Expand Down

0 comments on commit 3c364be

Please sign in to comment.