Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions docs/source/libraries/sklearn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,14 @@ Currently the following transformers are supported out of the box:
* nested FeatureUnions and Pipelines;
* SelectorMixin-based transformers: SelectPercentile_,
SelectKBest_, GenericUnivariateSelect_, VarianceThreshold_,
RFE_, RFECV_, SelectFromModel_, RandomizedLogisticRegression_.

RFE_, RFECV_, SelectFromModel_, RandomizedLogisticRegression_;
* scalers from sklearn.preprocessing: MinMaxScaler_, StandardScaler_,
MaxAbsScaler_, RobustScaler_.

.. _MinMaxScaler: http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html
.. _StandardScaler: http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html#sklearn.preprocessing.StandardScaler
.. _MaxAbsScaler: http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MaxAbsScaler.html#sklearn.preprocessing.MaxAbsScaler
.. _RobustScaler: http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.RobustScaler.html#sklearn.preprocessing.RobustScaler
.. _GenericUnivariateSelect: http://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.GenericUnivariateSelect.html
.. _SelectPercentile: http://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectPercentile.html
.. _SelectKBest: http://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectKBest.html
Expand All @@ -273,6 +279,7 @@ Currently the following transformers are supported out of the box:
.. _Pipeline: http://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html#sklearn.pipeline.Pipeline
.. _singledispatch: https://pypi.python.org/pypi/singledispatch


.. _sklearn-unhashing:

Reversing hashing trick
Expand Down
19 changes: 19 additions & 0 deletions eli5/sklearn/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
RandomizedLogisticRegression,
RandomizedLasso,
)
from sklearn.preprocessing import ( # type: ignore
MinMaxScaler,
StandardScaler,
MaxAbsScaler,
RobustScaler,
)

from eli5.transform import transform_feature_names
from eli5.sklearn.utils import get_feature_names as _get_feature_names
Expand All @@ -25,6 +31,19 @@ def _select_names(est, in_names=None):
return [in_names[i] for i in np.flatnonzero(mask)]


# Scaling

@transform_feature_names.register(MinMaxScaler)
@transform_feature_names.register(StandardScaler)
@transform_feature_names.register(MaxAbsScaler)
@transform_feature_names.register(RobustScaler)
def _transform_scaling(est, in_names=None):
    """Return output feature names for a fitted scaler.

    The names are passed through unchanged: this handler returns exactly
    the names it was given (or the generated defaults), one per feature.

    Parameters
    ----------
    est : MinMaxScaler, StandardScaler, MaxAbsScaler or RobustScaler
        A fitted scaler.
    in_names : iterable of str, optional
        Names of the input features. When omitted, names are obtained
        from ``_get_feature_names`` using the number of features the
        scaler was fitted on.

    Returns
    -------
    list
        A new list holding the feature names.

    Raises
    ------
    AttributeError
        If ``in_names`` is None and the number of features cannot be
        determined from ``est`` (e.g. the scaler is not fitted).
    """
    if in_names is None:
        in_names = _get_feature_names(
            est,
            feature_names=in_names,
            num_features=_scaler_num_features(est),
        )
    return list(in_names)


def _scaler_num_features(est):
    """Return the number of features the fitted scaler ``est`` has seen."""
    # ``scale_`` is None for StandardScaler(with_std=False) and
    # RobustScaler(with_scaling=False), so reading ``scale_.shape``
    # unconditionally fails on such (perfectly valid) fitted scalers.
    # Fall back to the other per-feature fitted attributes.
    for attr in ('scale_', 'mean_', 'center_'):
        fitted = getattr(est, attr, None)
        if fitted is not None:
            return fitted.shape[0]
    # Newer scikit-learn releases record the feature count directly.
    n_features = getattr(est, 'n_features_in_', None)
    if n_features is not None:
        return n_features
    # Keep AttributeError so callers that handled the original failure
    # mode (missing/None ``scale_``) still catch the same exception type.
    raise AttributeError(
        "Cannot determine the number of features of {!r}; "
        "is the scaler fitted?".format(est)
    )


# Pipelines

@transform_feature_names.register(Pipeline)
Expand Down
23 changes: 22 additions & 1 deletion tests/test_sklearn_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
RandomizedLogisticRegression,
RandomizedLasso, # TODO: add tests and document
)
from sklearn.pipeline import FeatureUnion
from sklearn.preprocessing import (
MinMaxScaler,
StandardScaler,
MaxAbsScaler,
RobustScaler,
)
from sklearn.pipeline import FeatureUnion, make_pipeline

from eli5 import transform_feature_names


Expand All @@ -41,6 +48,20 @@ def selection_score_func(X, y):

@pytest.mark.parametrize('transformer,expected', [
(MyFeatureExtractor(), ['f1', 'f2', 'f3']),

(make_pipeline(StandardScaler(), MyFeatureExtractor()),
['f1', 'f2', 'f3']),
(make_pipeline(MinMaxScaler(), MyFeatureExtractor()),
['f1', 'f2', 'f3']),
(make_pipeline(MaxAbsScaler(), MyFeatureExtractor()),
['f1', 'f2', 'f3']),
(make_pipeline(RobustScaler(), MyFeatureExtractor()),
['f1', 'f2', 'f3']),
(StandardScaler(), ['<NAME0>', '<NAME1>', '<NAME2>', '<NAME3>']),
(MinMaxScaler(), ['<NAME0>', '<NAME1>', '<NAME2>', '<NAME3>']),
(MaxAbsScaler(), ['<NAME0>', '<NAME1>', '<NAME2>', '<NAME3>']),
(RobustScaler(), ['<NAME0>', '<NAME1>', '<NAME2>', '<NAME3>']),

(SelectKBest(selection_score_func, k=1),
['<NAME3>']),
(SelectKBest(selection_score_func, k=2),
Expand Down