Skip to content

Commit

Permalink
Revert "MAINT adapt for scikit-learn 1.6 (#1135)"
Browse files Browse the repository at this point in the history
This reverts commit 18af508.
  • Loading branch information
jeromedockes authored Dec 10, 2024
1 parent 18af508 commit 7e5e190
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 272 deletions.
4 changes: 0 additions & 4 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ Bug fixes
:user:`Jérôme Dockès <jeromedockes>` and the matplotlib issue can be tracked
[here](https://github.com/matplotlib/matplotlib/issues/25041).

Maintenance
-----------
* Make `skrub` compatible with scikit-learn 1.6.
:pr:`1135` by :user:`Guillaume Lemaitre <glemaitre>`.

Release 0.4.0
=============
Expand Down
26 changes: 5 additions & 21 deletions benchmarks/bench_minhash_batch_number.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn
from joblib import Parallel, delayed, effective_n_jobs
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import gen_even_slices, murmurhash3_32
from sklearn.utils.fixes import parse_version
from utils import default_parser, find_result, monitor

from skrub._fast_hash import ngram_min_hash
Expand All @@ -34,11 +32,6 @@
# flake8: noqa: E501


sklearn_below_1_6 = parse_version(
parse_version(sklearn.__version__).base_version
) < parse_version("1.6")


class MinHashEncoder(BaseEstimator, TransformerMixin):
"""
Encode string categorical features as a numeric array, minhash method
Expand Down Expand Up @@ -133,20 +126,11 @@ def __init__(
self.batch_per_job = batch_per_job
self.n_jobs = n_jobs

if sklearn_below_1_6:

def _more_tags(self):
"""
Used internally by sklearn to ease the estimator checks.
"""
return {"X_types": ["categorical"]}

else:

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.categorical = True
return tags
def _more_tags(self):
"""
Used internally by sklearn to ease the estimator checks.
"""
return {"X_types": ["categorical"]}

def _get_murmur_hash(self, string):
"""
Expand Down
5 changes: 1 addition & 4 deletions skrub/_dataframe/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@
"total_seconds",
]

pandas_version = parse_version(parse_version(pd.__version__).base_version)

#
# Inspecting containers' type and module
# ======================================
Expand Down Expand Up @@ -332,8 +330,7 @@ def _concat_horizontal_pandas(*dataframes):
init_index = dataframes[0].index
dataframes = [df.reset_index(drop=True) for df in dataframes]
dataframes = _join_utils.make_column_names_unique(*dataframes)
kwargs = {"copy": False} if pandas_version < parse_version("3.0") else {}
result = pd.concat(dataframes, axis=1, **kwargs)
result = pd.concat(dataframes, axis=1, copy=False)
result.index = init_index
return result

Expand Down
21 changes: 0 additions & 21 deletions skrub/_datetime_encoder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from datetime import datetime, timezone

import pandas as pd
import sklearn
from sklearn.utils.fixes import parse_version
from sklearn.utils.validation import check_is_fitted

try:
Expand All @@ -28,11 +26,6 @@
]


sklearn_below_1_6 = parse_version(
parse_version(sklearn.__version__).base_version
) < parse_version("1.6")


@dispatch
def _is_date(col):
raise NotImplementedError()
Expand Down Expand Up @@ -330,17 +323,3 @@ def _check_params(self):
raise ValueError(
f"'resolution' options are {allowed}, got {self.resolution!r}."
)

if sklearn_below_1_6:

def _more_tags(self):
return {"preserves_dtype": []}

else:

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
from sklearn.utils import TransformerTags

tags.transformer_tags = TransformerTags()
return tags
19 changes: 0 additions & 19 deletions skrub/_fixes.py

This file was deleted.

12 changes: 2 additions & 10 deletions skrub/_interpolation_joiner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import warnings
from dataclasses import is_dataclass

import joblib
import numpy as np
Expand All @@ -8,11 +7,11 @@
HistGradientBoostingClassifier,
HistGradientBoostingRegressor,
)
from sklearn.utils._tags import _safe_tags

from . import _dataframe as sbd
from . import _join_utils, _utils
from . import _selectors as s
from ._fixes import get_tags
from ._minhash_encoder import MinHashEncoder
from ._table_vectorizer import TableVectorizer

Expand Down Expand Up @@ -404,14 +403,7 @@ def _get_assignments_for_estimator(table, estimator):


def _handles_multioutput(estimator):
tags = get_tags(estimator)
if isinstance(tags, dict):
# scikit-learn < 1.6
return tags.get("multioutput", False)
elif is_dataclass(tags):
# scikit-learn >= 1.6
return tags.target_tags.multi_output
return False
return _safe_tags(estimator).get("multioutput", False)


def _fit(key_values, target_table, estimator, propagate_exceptions):
Expand Down
46 changes: 15 additions & 31 deletions skrub/_similarity_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
which encodes similarity instead of equality of values.
"""


import numpy as np
import pandas as pd
import sklearn
Expand All @@ -13,18 +14,12 @@
from sklearn.utils.fixes import parse_version
from sklearn.utils.validation import check_is_fitted

from ._fixes import _check_n_features
from ._string_distances import get_ngram_count, preprocess

# Ignore lines too long, first docstring lines can't be cut
# flake8: noqa: E501


sklearn_below_1_6 = parse_version(
parse_version(sklearn.__version__).base_version
) < parse_version("1.6")


def _ngram_similarity_one_sample_inplace(
x_count_vector,
vocabulary_count_matrix,
Expand Down Expand Up @@ -339,7 +334,7 @@ def fit(self, X, y=None):
X[mask] = self.handle_missing

Xlist, n_samples, n_features = self._check_X(X)
_check_n_features(self, X, reset=True)
self._check_n_features(X, reset=True)

if self.handle_unknown not in ["error", "ignore"]:
raise ValueError(
Expand Down Expand Up @@ -458,7 +453,7 @@ def transform(self, X, fast=True):
X[mask] = self.handle_missing

Xlist, n_samples, n_features = self._check_X(X)
_check_n_features(self, X, reset=False)
self._check_n_features(X, reset=False)

for i in range(n_features):
Xi = Xlist[i]
Expand Down Expand Up @@ -555,26 +550,15 @@ def _ngram_similarity_fast(

return np.nan_to_num(out, copy=False)

if sklearn_below_1_6:

def _more_tags(self):
return {
"X_types": ["2darray", "categorical", "string"],
"preserves_dtype": [],
"allow_nan": True,
"_xfail_checks": {
"check_estimator_sparse_data": (
"Cannot create sparse matrix with strings."
),
"check_estimators_dtypes": "We only support string dtypes.",
},
}

else:

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.categorical = True
tags.input_tags.string = True
tags.transformer_tags.preserves_dtype = []
return tags
def _more_tags(self):
return {
"X_types": ["2darray", "categorical", "string"],
"preserves_dtype": [],
"allow_nan": True,
"_xfail_checks": {
"check_estimator_sparse_data": (
"Cannot create sparse matrix with strings."
),
"check_estimators_dtypes": "We only support string dtypes.",
},
}
39 changes: 11 additions & 28 deletions skrub/_table_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
from typing import Iterable

import numpy as np
import sklearn
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils._estimator_html_repr import _VisualBlock
from sklearn.utils.fixes import parse_version
from sklearn.utils.validation import check_is_fitted

from . import _dataframe as sbd
Expand All @@ -30,11 +28,6 @@
__all__ = ["TableVectorizer"]


sklearn_below_1_6 = parse_version(
parse_version(sklearn.__version__).base_version
) < parse_version("1.6")


class PassThrough(SingleColumnTransformer):
def fit_transform(self, column, y=None):
return column
Expand Down Expand Up @@ -665,27 +658,17 @@ def _sk_visual_block_(self):

# scikit-learn compatibility

if sklearn_below_1_6:

def _more_tags(self):
"""
Used internally by sklearn to ease the estimator checks.
"""
return {
"X_types": ["2darray", "string"],
"allow_nan": [True],
"_xfail_checks": {
"check_complex_data": "Passthrough complex columns as-is.",
},
}

else:

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.string = True
tags.input_tags.allow_nan = True
return tags
def _more_tags(self):
"""
Used internally by sklearn to ease the estimator checks.
"""
return {
"X_types": ["2darray", "string"],
"allow_nan": [True],
"_xfail_checks": {
"check_complex_data": "Passthrough complex columns as-is.",
},
}

def get_feature_names_out(self):
"""Return the column names of the output of ``transform`` as a list of strings.
Expand Down
15 changes: 3 additions & 12 deletions skrub/_tabular_learner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from dataclasses import is_dataclass

import sklearn
from sklearn import ensemble
from sklearn.base import BaseEstimator
Expand All @@ -8,7 +6,6 @@
from sklearn.preprocessing import OrdinalEncoder, StandardScaler
from sklearn.utils.fixes import parse_version

from ._fixes import get_tags
from ._minhash_encoder import MinHashEncoder
from ._table_vectorizer import TableVectorizer
from ._to_categorical import ToCategorical
Expand Down Expand Up @@ -273,15 +270,9 @@ def tabular_learner(estimator, *, n_jobs=None):
high_cardinality=MinHashEncoder(),
)
steps = [vectorizer]
try:
tags = get_tags(estimator)
if is_dataclass(tags):
allow_nan = tags.input_tags.allow_nan
else:
allow_nan = tags.get("allow_nan", False)
except TypeError:
allow_nan = False
if not allow_nan:
if not hasattr(estimator, "_get_tags") or not estimator._get_tags().get(
"allow_nan", False
):
steps.append(SimpleImputer(add_indicator=True))
if not isinstance(estimator, _TREE_ENSEMBLE_CLASSES):
steps.append(StandardScaler())
Expand Down
2 changes: 0 additions & 2 deletions skrub/_to_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ def _get_time_zone_pandas(col):
return None
if hasattr(tz, "zone"):
return tz.zone
if hasattr(tz, "key"):
return tz.key
return tz.tzname(None)


Expand Down
Loading

0 comments on commit 7e5e190

Please sign in to comment.