diff --git a/docs/source/libraries/sklearn.rst b/docs/source/libraries/sklearn.rst index cff4b42a..d05be37a 100644 --- a/docs/source/libraries/sklearn.rst +++ b/docs/source/libraries/sklearn.rst @@ -260,8 +260,14 @@ Currently the following transformers are supported out of the box: * nested FeatureUnions and Pipelines; * SelectorMixin-based transformers: SelectPercentile_, SelectKBest_, GenericUnivariateSelect_, VarianceThreshold_, - RFE_, RFECV_, SelectFromModel_, RandomizedLogisticRegression_. - + RFE_, RFECV_, SelectFromModel_, RandomizedLogisticRegression_; +* scalers from sklearn.preprocessing: MinMaxScaler_, StandardScaler_, + MaxAbsScaler_, RobustScaler_. + +.. _MinMaxScaler: http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html +.. _StandardScaler: http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html#sklearn.preprocessing.StandardScaler +.. _MaxAbsScaler: http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MaxAbsScaler.html#sklearn.preprocessing.MaxAbsScaler +.. _RobustScaler: http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.RobustScaler.html#sklearn.preprocessing.RobustScaler .. _GenericUnivariateSelect: http://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.GenericUnivariateSelect.html .. _SelectPercentile: http://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectPercentile.html .. _SelectKBest: http://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectKBest.html @@ -273,6 +279,7 @@ Currently the following transformers are supported out of the box: .. _Pipeline: http://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html#sklearn.pipeline.Pipeline .. _singledispatch: https://pypi.python.org/pypi/singledispatch + .. _sklearn-unhashing: Reversing hashing trick diff --git a/eli5/sklearn/transform.py b/eli5/sklearn/transform.py index 8afd7c72..d753f2ab 100644 --- a/eli5/sklearn/transform.py +++ b/eli5/sklearn/transform.py @@ -8,6 +8,12 @@ RandomizedLogisticRegression, RandomizedLasso, ) +from sklearn.preprocessing import ( # type: ignore + MinMaxScaler, + StandardScaler, + MaxAbsScaler, + RobustScaler, +) from eli5.transform import transform_feature_names from eli5.sklearn.utils import get_feature_names as _get_feature_names @@ -25,6 +31,19 @@ def _select_names(est, in_names=None): return [in_names[i] for i in np.flatnonzero(mask)] +# Scaling + +@transform_feature_names.register(MinMaxScaler) +@transform_feature_names.register(StandardScaler) +@transform_feature_names.register(MaxAbsScaler) +@transform_feature_names.register(RobustScaler) +def _transform_scaling(est, in_names=None): + if in_names is None: + in_names = _get_feature_names(est, feature_names=in_names, + num_features=est.scale_.shape[0]) + return [name for name in in_names] + + # Pipelines @transform_feature_names.register(Pipeline) diff --git a/tests/test_sklearn_transform.py b/tests/test_sklearn_transform.py index 3decd235..4cceeeb7 100644 --- a/tests/test_sklearn_transform.py +++ b/tests/test_sklearn_transform.py @@ -20,7 +20,14 @@ RandomizedLogisticRegression, RandomizedLasso, # TODO: add tests and document ) -from sklearn.pipeline import FeatureUnion +from sklearn.preprocessing import ( + MinMaxScaler, + StandardScaler, + MaxAbsScaler, + RobustScaler, +) +from sklearn.pipeline import FeatureUnion, make_pipeline + from eli5 import transform_feature_names @@ -41,6 +48,20 @@ def selection_score_func(X, y): @pytest.mark.parametrize('transformer,expected', [ (MyFeatureExtractor(), ['f1', 'f2', 'f3']), + + (make_pipeline(StandardScaler(), MyFeatureExtractor()), + ['f1', 'f2', 'f3']), + (make_pipeline(MinMaxScaler(), MyFeatureExtractor()), + ['f1', 'f2', 'f3']), + (make_pipeline(MaxAbsScaler(), MyFeatureExtractor()), + ['f1', 'f2', 'f3']), + (make_pipeline(RobustScaler(), MyFeatureExtractor()), + ['f1', 'f2', 'f3']), + (StandardScaler(), ['', '', '', '']), + (MinMaxScaler(), ['', '', '', '']), + (MaxAbsScaler(), ['', '', '', '']), + (RobustScaler(), ['', '', '', '']), + (SelectKBest(selection_score_func, k=1), ['']), (SelectKBest(selection_score_func, k=2),