Skip to content

Commit

Permalink
MAINT compatibility sklearn 1.6.0
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Nov 29, 2024
1 parent f100059 commit cba5486
Show file tree
Hide file tree
Showing 10 changed files with 653 additions and 22 deletions.
5 changes: 5 additions & 0 deletions benchmarks/bench_minhash_batch_number.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ def _more_tags(self):
"""
return {"X_types": ["categorical"]}

def __sklearn_tags__(self):
    """Return the inherited estimator tags with categorical input enabled.

    The tags object is obtained from the parent class and its
    ``input_tags.categorical`` flag is switched on before it is returned.
    """
    estimator_tags = super().__sklearn_tags__()
    accepted_inputs = estimator_tags.input_tags
    accepted_inputs.categorical = True
    return estimator_tags

def _get_murmur_hash(self, string):
"""
Encode a string using murmur hashing function.
Expand Down
5 changes: 4 additions & 1 deletion skrub/_dataframe/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@
"total_seconds",
]

# Installed pandas version, reduced to its ``base_version`` and re-parsed so it
# can be compared against thresholds such as ``parse_version("3.0")``.
# NOTE(review): presumably ``base_version`` drops pre-release/dev suffixes so
# that development builds compare like the final release — confirm.
pandas_version = parse_version(parse_version(pd.__version__).base_version)

#
# Inspecting containers' type and module
# ======================================
Expand Down Expand Up @@ -330,7 +332,8 @@ def _concat_horizontal_pandas(*dataframes):
init_index = dataframes[0].index
dataframes = [df.reset_index(drop=True) for df in dataframes]
dataframes = _join_utils.make_column_names_unique(*dataframes)
result = pd.concat(dataframes, axis=1, copy=False)
kwargs = {"copy": False} if pandas_version < parse_version("3.0") else {}
result = pd.concat(dataframes, axis=1, **kwargs)
result.index = init_index
return result

Expand Down
9 changes: 9 additions & 0 deletions skrub/_datetime_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from . import _dataframe as sbd
from ._dispatch import dispatch
from ._on_each_column import RejectColumn, SingleColumnTransformer
from ._sklearn_compat import TransformerTags

__all__ = ["DatetimeEncoder"]

Expand Down Expand Up @@ -323,3 +324,11 @@ def _check_params(self):
raise ValueError(
f"'resolution' options are {allowed}, got {self.resolution!r}."
)

def _more_tags(self):
return {"preserves_dtype": []}

def __sklearn_tags__(self):
    """Return the inherited estimator tags with fresh transformer tags.

    The parent's tags object is fetched, then its ``transformer_tags``
    attribute is replaced by a new ``TransformerTags`` instance whose
    ``preserves_dtype`` is an empty list.
    """
    estimator_tags = super().__sklearn_tags__()
    no_preserved_dtype = TransformerTags(preserves_dtype=[])
    estimator_tags.transformer_tags = no_preserved_dtype
    return estimator_tags
4 changes: 2 additions & 2 deletions skrub/_interpolation_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
HistGradientBoostingClassifier,
HistGradientBoostingRegressor,
)
from sklearn.utils._tags import _safe_tags

from . import _dataframe as sbd
from . import _join_utils, _utils
from . import _selectors as s
from ._minhash_encoder import MinHashEncoder
from ._sklearn_compat import get_tags
from ._table_vectorizer import TableVectorizer

DEFAULT_REGRESSOR = HistGradientBoostingRegressor()
Expand Down Expand Up @@ -403,7 +403,7 @@ def _get_assignments_for_estimator(table, estimator):


def _handles_multioutput(estimator):
    """Return whether ``estimator`` declares support for multi-output targets.

    The answer is read from the ``target_tags.multi_output`` field of the
    tags object returned by ``get_tags`` (imported from ``._sklearn_compat``).

    Parameters
    ----------
    estimator : object
        The estimator whose tags are inspected.

    Returns
    -------
    The value of the estimator's ``target_tags.multi_output`` tag.
    """
    # Only the ``get_tags`` lookup is kept: the earlier
    # ``_safe_tags(estimator).get("multioutput", False)`` return relied on a
    # private scikit-learn helper and made this statement unreachable.
    return get_tags(estimator).target_tags.multi_output


def _fit(key_values, target_table, estimator, propagate_exceptions):
Expand Down
12 changes: 10 additions & 2 deletions skrub/_similarity_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sklearn.utils.fixes import parse_version
from sklearn.utils.validation import check_is_fitted

from ._sklearn_compat import _check_n_features
from ._string_distances import get_ngram_count, preprocess

# Ignore lines too long, first docstring lines can't be cut
Expand Down Expand Up @@ -334,7 +335,7 @@ def fit(self, X, y=None):
X[mask] = self.handle_missing

Xlist, n_samples, n_features = self._check_X(X)
self._check_n_features(X, reset=True)
_check_n_features(self, X, reset=True)

if self.handle_unknown not in ["error", "ignore"]:
raise ValueError(
Expand Down Expand Up @@ -453,7 +454,7 @@ def transform(self, X, fast=True):
X[mask] = self.handle_missing

Xlist, n_samples, n_features = self._check_X(X)
self._check_n_features(X, reset=False)
_check_n_features(self, X, reset=False)

for i in range(n_features):
Xi = Xlist[i]
Expand Down Expand Up @@ -562,3 +563,10 @@ def _more_tags(self):
"check_estimators_dtypes": "We only support string dtypes.",
},
}

def __sklearn_tags__(self):
    """Return the inherited estimator tags adjusted for this encoder.

    On the tags object obtained from the parent class, the input tags
    ``categorical`` and ``string`` are both switched on, and the
    transformer tag ``preserves_dtype`` is set to an empty list.
    """
    estimator_tags = super().__sklearn_tags__()
    accepted_inputs = estimator_tags.input_tags
    accepted_inputs.categorical = True
    accepted_inputs.string = True
    estimator_tags.transformer_tags.preserves_dtype = []
    return estimator_tags
Loading

0 comments on commit cba5486

Please sign in to comment.