Skip to content

Commit

Permalink
Merge pull request #543 from MLecardonnel/feature/scikit_learn_fix
Browse files Browse the repository at this point in the history
Scikit learn fix for versions above 1.4.0
  • Loading branch information
guillaume-vignal authored Mar 28, 2024
2 parents 9110e95 + 4902b1e commit 641405b
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 15 deletions.
2 changes: 1 addition & 1 deletion requirements.dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ nbsphinx==0.8.8
sphinx_material==0.0.35
pytest>=6.2.5
pytest-cov>=2.8.1
scikit-learn>=1.0.1,<1.4
scikit-learn>=1.4.0
xgboost>=1.0.0
nbformat>4.2.0
numba>=0.53.1
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"dash-table>=5.0.0",
"nbformat>4.2.0",
"numba>=0.53.1",
"scikit-learn>=1.0.1,<1.4",
"scikit-learn>=1.4.0",
"category_encoders>=2.6.0",
"scipy>=0.19.1",
]
Expand Down
16 changes: 8 additions & 8 deletions shapash/utils/columntransformer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import pandas as pd
from sklearn.preprocessing import FunctionTransformer

from shapash.utils.category_encoder_backend import (
category_encoder_binary,
Expand Down Expand Up @@ -91,7 +92,7 @@ def inv_transform_ct(x_in, encoding):

# columns not encode
elif name_encoding == "remainder":
if ct_encoding == "passthrough":
if isinstance(ct_encoding, FunctionTransformer):
nb_col = len(col_encoding)
frame = x_in.iloc[:, init : init + nb_col]
else:
Expand Down Expand Up @@ -249,7 +250,7 @@ def calc_inv_contrib_ct(x_contrib, encoding, agg_columns):
init += nb_col

elif name_encoding == "remainder":
if ct_encoding == "passthrough":
if isinstance(ct_encoding, FunctionTransformer):
nb_col = len(col_encoding)
frame = x_contrib.iloc[:, init : init + nb_col]
rst = pd.concat([rst, frame], axis=1)
Expand Down Expand Up @@ -366,7 +367,9 @@ def get_feature_names(column_transformer):
List of returned features names when ColumnTransformer is applied.
"""
feature_names = []
l_transformers = list(column_transformer._iter(fitted=True))
l_transformers = list(
column_transformer._iter(fitted=True, column_as_labels=False, skip_drop=True, skip_empty_columns=True)
)

for name, trans, column, _ in l_transformers:
feature_names.extend(get_names(name, trans, column, column_transformer))
Expand Down Expand Up @@ -463,11 +466,8 @@ def get_col_mapping_ct(encoder, x_encoded):
else:
raise NotImplementedError(f"Estimator not supported : {estimator}")

elif estimator == "passthrough":
try:
features_out = encoder.feature_names_in_[features]
except Exception:
features_out = encoder._feature_names_in[features] # for oldest sklearn version
elif isinstance(estimator, FunctionTransformer):
features_out = encoder.feature_names_in_[features]
for f_name in features_out:
dict_col_mapping[f_name] = [x_encoded.columns.to_list()[idx_encoded]]
idx_encoded += 1
Expand Down
4 changes: 3 additions & 1 deletion shapash/utils/transform.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""
Transform Module
"""

import re

import numpy as np
import pandas as pd
from sklearn.preprocessing import FunctionTransformer

from shapash.utils.category_encoder_backend import (
get_col_mapping_ce,
Expand Down Expand Up @@ -185,7 +187,7 @@ def check_transformers(list_encoding):
if (str(type(ct_encoding)) not in supported_sklearn) and (
str(type(ct_encoding)) not in supported_category_encoder
):
if str(type(ct_encoding)) != "<class 'str'>":
if not isinstance(ct_encoding, str) and not isinstance(ct_encoding, FunctionTransformer):
raise ValueError("One of the encoders used in ColumnTransformers isn't supported.")

elif str(type(enc)) in supported_category_encoder:
Expand Down
9 changes: 5 additions & 4 deletions tests/unit_tests/utils/test_columntransformer_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Unit test of Inverse Transform
"""

import unittest

import catboost as cb
Expand Down Expand Up @@ -959,25 +960,25 @@ def test_get_names_1(self):
enc_4.fit(train)

feature_names_1 = []
l_transformers = list(enc_1._iter(fitted=True))
l_transformers = list(enc_1._iter(fitted=True, column_as_labels=False, skip_drop=True, skip_empty_columns=True))

for name, trans, column, _ in l_transformers:
feature_names_1.extend(get_names(name, trans, column, enc_1))

feature_names_2 = []
l_transformers = list(enc_2._iter(fitted=True))
l_transformers = list(enc_2._iter(fitted=True, column_as_labels=False, skip_drop=True, skip_empty_columns=True))

for name, trans, column, _ in l_transformers:
feature_names_2.extend(get_names(name, trans, column, enc_2))

feature_names_3 = []
l_transformers = list(enc_3._iter(fitted=True))
l_transformers = list(enc_3._iter(fitted=True, column_as_labels=False, skip_drop=True, skip_empty_columns=True))

for name, trans, column, _ in l_transformers:
feature_names_3.extend(get_names(name, trans, column, enc_3))

feature_names_4 = []
l_transformers = list(enc_4._iter(fitted=True))
l_transformers = list(enc_4._iter(fitted=True, column_as_labels=False, skip_drop=True, skip_empty_columns=True))

for name, trans, column, _ in l_transformers:
feature_names_4.extend(get_names(name, trans, column, enc_4))
Expand Down

0 comments on commit 641405b

Please sign in to comment.