diff --git a/CHANGES.rst b/CHANGES.rst
index 94045576e..5cd655994 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -15,6 +15,17 @@ It is currently undergoing fast development and backward compatibility is not en
 New features
 ------------
 
+Changes
+-------
+
+Bug fixes
+---------
+
+Maintenance
+-----------
+
+Release 0.4.1
+=============
 
 Changes
 -------
@@ -22,16 +33,17 @@ Changes
    :pr:`1190` by :user: `Mojdeh Rastgoo`.
 
-* A new parameter `verbose` has been added to the :class:`TableReport` to toggle on or off the
+* A new parameter ``verbose`` has been added to :class:`TableReport` to toggle on or off the
   printing of progress information when a report is being generated.
   :pr:`1182` by :user:`Priscilla Baah`.
 
-* A parameter `verbose` has been added to the :func:`patch_display` to toggle on or off the
+* A parameter ``verbose`` has been added to :func:`patch_display` to toggle on or off the
   printing of progress information when a table report is being generated.
   :pr:`1188` by :user:`Priscilla Baah`.
 
 * :func:`tabular_learner` accepts the alias ``"regression"`` for the option
-  ``"regressor"`` and ``"classification"`` for ``"classifier"``.
-  :pr:`1180` by :user:`Mojdeh Rastgoo `.
+  ``"regressor"`` and ``"classification"`` for ``"classifier"``.
+  :pr:`1180` by :user:`Mojdeh Rastgoo `.
 
 Bug fixes
 ---------
@@ -39,12 +51,19 @@ Bug fixes
   configuration which could cause plots not to display inline in jupyter
   notebooks any more. This has been fixed in skrub in :pr:`1172` by
   :user:`Jérôme Dockès ` and the matplotlib issue can be tracked
-  [here](https://github.com/matplotlib/matplotlib/issues/25041).
+  `here <https://github.com/matplotlib/matplotlib/issues/25041>`_.
+
+* The labels on bar plots in the ``TableReport`` for columns of object dtypes
+  that have a repr spanning multiple lines could be unreadable. This has been
+  fixed in :pr:`1196` by :user:`Jérôme Dockès `.
+
+* Improve the performance of :func:`deduplicate` by removing some unnecessary
+  computations. :pr:`1193` by :user:`Jérôme Dockès `.
 
 Maintenance
 -----------
 
-* Make `skrub` compatible with scikit-learn 1.6.
-  :pr:`1135` by :user:`Guillaume Lemaitre `.
+* Make ``skrub`` compatible with scikit-learn 1.6.
+  :pr:`1169` by :user:`Guillaume Lemaitre `.
 
 Release 0.4.0
 =============
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index cc74f86a9..e47cd89ea 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -124,15 +124,52 @@ See the relevant sections above on how to do this.
 Setting up the environment
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Follow the steps in the :ref:`installation_instructions` > "From Source" section
-to set up your environment, install the required development dependencies, and
-run the tests.
+To contribute, you will first have to run through some steps:
+
+- Set up your environment by forking the repository (`GitHub docs on
+  forking and
+  cloning `__).
+- Create and activate a new virtual environment:
+
+  - With `venv `__, create
+    the env with ``python -m venv env_skrub`` and then activate it with
+    ``source env_skrub/bin/activate``.
+  - With
+    `conda `__,
+    create the env with ``conda create -n env_skrub`` and activate it with
+    ``conda activate env_skrub``.
+
+- While at the root of your local copy of skrub and within the new env,
+  install the required development dependencies by running
+  ``pip install --editable ".[dev, lint, test, doc]"``.
+- Run ``pre-commit install`` to activate some checks that will run every
+  time you do a ``git commit`` (mostly, formatting checks).
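+
+As a quick sanity check that the environment is set up correctly, you can
+import skrub within the new env and print its version (the exact version
+number will differ):
+
+.. code:: python
+
+   import skrub
+   print(skrub.__version__)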
+
+If you want to make sure that everything runs properly, you can run all
+the tests with the command ``pytest -s skrub/tests``; note that this may
+take a long time. Some tests may raise warnings such as:
 
-When starting to work on a new issue, it's recommended to create a new branch:
+.. code:: text
+
+   UserWarning: Only pandas and polars DataFrames are supported, but input is a Numpy array. Please convert Numpy arrays to DataFrames before passing them to skrub transformers. Converting to pandas DataFrame with columns ['0', '1', …].
+     warnings.warn(
+
+These warnings are expected, and you may proceed with the next steps without
+worrying about them. However, no tests should fail at this point: if any do,
+please let us know.
 
-.. code:: console
+Now that the development environment is ready, you may start working on
+the new issue by creating a new branch:
+
+.. code:: sh
 
-   git switch -c branch_name
+   git checkout -b my-branch-name-eg-fix-issue-123
+   # make some changes
+   git add ./the/file-i-changed
+   git commit -m "my message"
+   git push --set-upstream origin my-branch-name-eg-fix-issue-123
+
+At this point, if you visit the `pull requests
+page `__ again, GitHub should show a
+banner asking if you want to open a pull request from your new branch.
 
 .. _implementation guidelines:
 
@@ -183,7 +220,8 @@ Additionally, you might have updated the internal dataframe API
 in ``skrub/_dataframe/tests/test_common.py`` to add a test for the
 ``amazing_function``.
 
-Run each updated test file using ``pytest``:
+Run each updated test file using ``pytest``
+(`pytest docs <https://docs.pytest.org/en/stable/>`__):
 
 .. code:: sh
 
@@ -193,10 +231,20 @@ Run each updated test file using ``pytest``:
 
 The ``-vsl`` flag provides more information when running the tests.
 
+It is also possible to run a specific test or a set of tests, using
+``pytest the_file.py::the_test`` or
+``pytest the_file.py -k 'test_name_pattern'``. This is helpful to avoid
+having to run the whole test suite.
+
+If you work on Windows, you might have some issues with the working
+directory when invoking ``pytest`` directly; ``python -m pytest ...``
+should be more robust.
+
 Once you are satisfied with your changes, you can run all the tests to
 make sure that your change did not break code elsewhere:
 
 .. code:: sh
 
+   pytest -s skrub/tests
 
 Finally, sync your changes with the remote repository and wait for CI to run.
@@ -229,6 +277,21 @@ the docstrings. Check for possible problems by running
 
    pytest skrub/path/to/file
 
+
+Formatting and pre-commit checks
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Formatting the code well helps with code development and maintenance,
+which is why skrub requires that all commits follow a specific set of
+formatting rules to ensure code quality.
+
+Luckily, these checks are performed automatically by the ``pre-commit``
+tool (`pre-commit docs `__) before any commit
+can be pushed. Something worth noting is that if the ``pre-commit``
+hooks format some files, the commit will be canceled: you will have to
+stage the changes made by ``pre-commit`` and commit again.
+
+
 Submitting your code
 ^^^^^^^^^^^^^^^^^^^^
 
@@ -237,17 +300,10 @@ a PR by clicking the "Compare & pull request" button on GitHub,
 targeting the skrub repository.
 
-Integration
-^^^^^^^^^^^
-
-Community consensus is key in the integration process. Expect a minimum
-of 1 to 3 reviews depending on the size of the change before we consider
-merging the PR.
-
-Please be mindful that maintainers are volunteers, so review times may vary.
-
 Continuous Integration (CI)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
+After creating your PR, CI tools will proceed to run all the tests on all
+configurations supported by skrub.
 
 - **Github Actions**: Used for testing skrub across various platforms (Linux,
   macOS, Windows)
@@ -273,18 +329,30 @@ actions are taken.
 
 Note that by default the documentation is built, but only the examples that are
 directly modified by the pull request are executed.
 
-- If the remote repository was changed, you might need to run
-  ``pre-commit run --all-files`` to make sure that the formatting is
-  correct.
-- If a specific test environment fails, it is possible to run the tests
-  in the environment that is failing by using pixi. For example if the
-  env is ``ci-py309-min-optional-deps``, it is possible to replicate it
-  using the following command:
+CI tests all the configurations supported by skrub, so tests may fail on
+configurations different from the one you are developing with. If this is the
+case, it is possible to run the tests in the failing environment using pixi.
+For example, if the env is ``ci-py309-min-optional-deps``, it is possible to
+replicate it using the following command:
 
 .. code:: sh
 
    pixi run -e ci-py309-min-optional-deps pytest skrub/tests/path/to/test
 
+This command downloads the specific environment onto your machine, so you can
+test it locally and apply fixes, or get a clearer idea of where the code is
+failing to discuss with the maintainers.
+
+Finally, if the remote repository was changed, you might need to run
+``pre-commit run --all-files`` to make sure that the formatting is
+correct.
+
+Integration
+^^^^^^^^^^^
+
+Community consensus is key in the integration process. Expect a minimum
+of 1 to 3 reviews depending on the size of the change before we consider
+merging the PR.
 
 Building the documentation
diff --git a/README.rst b/README.rst
index c4a4b6fb1..e86152c67 100644
--- a/README.rst
+++ b/README.rst
@@ -32,7 +32,8 @@ The goal of skrub is to bridge the gap between tabular data sources and machine-
 
 skrub provides high-level tools for joining dataframes (``Joiner``, ``AggJoiner``, ...),
 encoding columns (``MinHashEncoder``, ``ToCategorical``, ...), building a pipeline
-(``TableVectorizer``, ``tabular_learner``, ...), and more.
+(``TableVectorizer``, ``tabular_learner``, ...), and interactively exploring your data (``TableReport``).
+
 
 >>> from skrub.datasets import fetch_employee_salaries
 >>> dataset = fetch_employee_salaries()
@@ -69,5 +70,8 @@ The best way to support the development of skrub is to spread the word! Also, if
 you already are a skrub user, we would love to hear about your use cases and
 challenges in the `Discussions `_ section.
 
 To report a bug or suggest enhancements, please
-`open an issue `_ and/or
-`submit a pull request `_.
+`open an issue `_.
+
+If you want to contribute directly to the library, check the
+`how to contribute `_ page on
+the website for more information.
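The README now points readers to ``TableReport`` for interactive exploration.
A minimal sketch of how it fits with the ``verbose`` parameter introduced in
this release, reusing the employee-salaries dataset from the README snippet
(``report.open()`` displays the generated report in a browser):

.. code:: python

   from skrub import TableReport
   from skrub.datasets import fetch_employee_salaries

   dataset = fetch_employee_salaries()

   # verbose toggles the progress messages described in the changelog entry
   report = TableReport(dataset.X, verbose=0)
   report.open()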
diff --git a/benchmarks/bench_minhash_batch_number.py b/benchmarks/bench_minhash_batch_number.py index 6fda60f5b..e3554b67f 100644 --- a/benchmarks/bench_minhash_batch_number.py +++ b/benchmarks/bench_minhash_batch_number.py @@ -15,11 +15,9 @@ import numpy as np import pandas as pd import seaborn as sns -import sklearn from joblib import Parallel, delayed, effective_n_jobs from sklearn.base import BaseEstimator, TransformerMixin from sklearn.utils import gen_even_slices, murmurhash3_32 -from sklearn.utils.fixes import parse_version from utils import default_parser, find_result, monitor from skrub._fast_hash import ngram_min_hash @@ -34,11 +32,6 @@ # flake8: noqa: E501 -sklearn_below_1_6 = parse_version( - parse_version(sklearn.__version__).base_version -) < parse_version("1.6") - - class MinHashEncoder(BaseEstimator, TransformerMixin): """ Encode string categorical features as a numeric array, minhash method @@ -133,20 +126,16 @@ def __init__( self.batch_per_job = batch_per_job self.n_jobs = n_jobs - if sklearn_below_1_6: - - def _more_tags(self): - """ - Used internally by sklearn to ease the estimator checks. - """ - return {"X_types": ["categorical"]} - - else: + def _more_tags(self): + """ + Used internally by sklearn to ease the estimator checks. + """ + return {"X_types": ["categorical"]} - def __sklearn_tags__(self): - tags = super().__sklearn_tags__() - tags.input_tags.categorical = True - return tags + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.input_tags.categorical = True + return tags def _get_murmur_hash(self, string): """ diff --git a/doc/version.json b/doc/version.json index d038b7f0c..e5470d869 100644 --- a/doc/version.json +++ b/doc/version.json @@ -5,8 +5,8 @@ "url": "https://skrub-data.org/dev/" }, { - "name": "0.4.0 (stable)", - "version": "0.4.0", + "name": "0.4.1 (stable)", + "version": "0.4.1", "url": "https://skrub-data.org/stable/", "preferred": true } diff --git a/skrub/_dataframe/tests/test_common.py b/skrub/_dataframe/tests/test_common.py index b95998767..a9830ea05 100644 --- a/skrub/_dataframe/tests/test_common.py +++ b/skrub/_dataframe/tests/test_common.py @@ -500,13 +500,11 @@ def test_to_datetime(df_module): s = df_module.make_column("", ["01/02/2020", "02/01/2021", "bad"]) with pytest.raises(ValueError): ns.to_datetime(s, "%m/%d/%Y", True) - df_module.assert_column_equal( - ns.to_datetime(s, "%m/%d/%Y", False), - df_module.make_column("", [datetime(2020, 1, 2), datetime(2021, 2, 1), None]), + assert ns.to_list(ns.to_datetime(s, "%m/%d/%Y", False)) == ns.to_list( + df_module.make_column("", [datetime(2020, 1, 2), datetime(2021, 2, 1), None]) ) - df_module.assert_column_equal( - ns.to_datetime(s, "%d/%m/%Y", False), - df_module.make_column("", [datetime(2020, 2, 1), datetime(2021, 1, 2), None]), + assert ns.to_list(ns.to_datetime(s, "%d/%m/%Y", False)) == ns.to_list( + df_module.make_column("", [datetime(2020, 2, 1), datetime(2021, 1, 2), None]) ) dt_col = ns.col(df_module.example_dataframe, "datetime-col") assert ns.to_datetime(dt_col, None) is dt_col diff --git a/skrub/_datetime_encoder.py b/skrub/_datetime_encoder.py index e62033932..fea47ded0 100644 --- a/skrub/_datetime_encoder.py +++ b/skrub/_datetime_encoder.py @@ -1,8 +1,6 @@ from datetime import datetime, timezone import pandas as pd -import sklearn -from sklearn.utils.fixes import parse_version from sklearn.utils.validation import check_is_fitted try: @@ -13,6 +11,7 @@ from . 
import _dataframe as sbd from ._dispatch import dispatch from ._on_each_column import RejectColumn, SingleColumnTransformer +from ._sklearn_compat import TransformerTags __all__ = ["DatetimeEncoder"] @@ -28,11 +27,6 @@ ] -sklearn_below_1_6 = parse_version( - parse_version(sklearn.__version__).base_version -) < parse_version("1.6") - - @dispatch def _is_date(col): raise NotImplementedError() @@ -134,7 +128,7 @@ class DatetimeEncoder(SingleColumnTransformer): 0 2024-05-13 12:05:36 1 NaT 2 2024-05-15 13:46:02 - Name: login, dtype: datetime64[ns] + Name: login, dtype: datetime64[...] >>> from skrub import DatetimeEncoder >>> DatetimeEncoder().fit_transform(login) @@ -237,7 +231,7 @@ class DatetimeEncoder(SingleColumnTransformer): 0 2024-05-13 07:05:36-03:00 1 NaT 2 2024-05-15 08:46:02-03:00 - Name: login, dtype: datetime64[ns, America/Sao_Paulo] + Name: login, dtype: datetime64[..., America/Sao_Paulo] >>> encoder.transform(login_sp)['login_hour'] 0 7.0 1 NaN @@ -331,16 +325,10 @@ def _check_params(self): f"'resolution' options are {allowed}, got {self.resolution!r}." ) - if sklearn_below_1_6: - - def _more_tags(self): - return {"preserves_dtype": []} - - else: - - def __sklearn_tags__(self): - tags = super().__sklearn_tags__() - from sklearn.utils import TransformerTags + def _more_tags(self): + return {"preserves_dtype": []} - tags.transformer_tags = TransformerTags() - return tags + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.transformer_tags = TransformerTags(preserves_dtype=[]) + return tags diff --git a/skrub/_deduplicate.py b/skrub/_deduplicate.py index ce6426bd4..3aeb5dc7b 100644 --- a/skrub/_deduplicate.py +++ b/skrub/_deduplicate.py @@ -72,7 +72,7 @@ def _guess_clusters(Z, distance_mat, n_jobs=None): int number of clusters that maximize the silhouette score. """ - max_clusters = distance_mat.shape[0] + max_clusters = Z.shape[0] n_clusters = np.arange(2, max_clusters) # silhouette score needs a redundant distance matrix redundant_dist = squareform(distance_mat) diff --git a/skrub/_fixes.py b/skrub/_fixes.py deleted file mode 100644 index 49ee8a219..000000000 --- a/skrub/_fixes.py +++ /dev/null @@ -1,19 +0,0 @@ -import sklearn -from sklearn.utils.fixes import parse_version - -sklearn_version = parse_version(parse_version(sklearn.__version__).base_version) - - -if sklearn_version < parse_version("1.6"): - from sklearn.utils._tags import _safe_tags as get_tags # noqa -else: - from sklearn.utils import get_tags # noqa - - -def _check_n_features(estimator, X, *, reset): - if hasattr(estimator, "_check_n_features"): - estimator._check_n_features(X, reset=reset) - else: - from sklearn.utils.validation import _check_n_features - - _check_n_features(estimator, X, reset=reset) diff --git a/skrub/_interpolation_joiner.py b/skrub/_interpolation_joiner.py index 14cbef4a1..d5cdd1cfe 100644 --- a/skrub/_interpolation_joiner.py +++ b/skrub/_interpolation_joiner.py @@ -1,5 +1,4 @@ import warnings -from dataclasses import is_dataclass import joblib import numpy as np @@ -12,8 +11,8 @@ from . import _dataframe as sbd from . import _join_utils, _utils from . 
import _selectors as s -from ._fixes import get_tags from ._minhash_encoder import MinHashEncoder +from ._sklearn_compat import get_tags from ._table_vectorizer import TableVectorizer DEFAULT_REGRESSOR = HistGradientBoostingRegressor() @@ -404,14 +403,7 @@ def _get_assignments_for_estimator(table, estimator): def _handles_multioutput(estimator): - tags = get_tags(estimator) - if isinstance(tags, dict): - # scikit-learn < 1.6 - return tags.get("multioutput", False) - elif is_dataclass(tags): - # scikit-learn >= 1.6 - return tags.target_tags.multi_output - return False + return get_tags(estimator).target_tags.multi_output def _fit(key_values, target_table, estimator, propagate_exceptions): diff --git a/skrub/_on_each_column.py b/skrub/_on_each_column.py index 68705d6c2..fd736d6e6 100644 --- a/skrub/_on_each_column.py +++ b/skrub/_on_each_column.py @@ -38,7 +38,7 @@ class RejectColumn(ValueError): >>> df = pd.DataFrame(dict(a=['2020-02-02'], b=[12.5])) >>> ToDatetime().fit_transform(df['a']) 0 2020-02-02 - Name: a, dtype: datetime64[ns] + Name: a, dtype: datetime64[...] >>> ToDatetime().fit_transform(df['b']) Traceback (most recent call last): ... @@ -340,7 +340,7 @@ class OnEachColumn(TransformerMixin, BaseEstimator): dtype: object >>> ToDatetime().fit_transform(df["birthday"]) 0 2024-01-29 - Name: birthday, dtype: datetime64[ns] + Name: birthday, dtype: datetime64[...] >>> ToDatetime().fit_transform(df["city"]) Traceback (most recent call last): ... @@ -373,7 +373,7 @@ class OnEachColumn(TransformerMixin, BaseEstimator): datetime column. >>> transformed.dtypes - birthday datetime64[ns] + birthday datetime64[...] city object dtype: object >>> to_datetime.transformers_ diff --git a/skrub/_reporting/_utils.py b/skrub/_reporting/_utils.py index b5f0e802b..c2962fc72 100644 --- a/skrub/_reporting/_utils.py +++ b/skrub/_reporting/_utils.py @@ -43,8 +43,8 @@ def quantiles(column): def ellide_string(s, max_len=30): """Shorten a string so it can be used as a plot axis title or label.""" - if not isinstance(s, str): - return s + s = str(s) + # normalize whitespace s = re.sub(r"\s+", " ", s) if len(s) <= max_len: diff --git a/skrub/_reporting/tests/test_utils.py b/skrub/_reporting/tests/test_utils.py index 10f7c3e82..06c4759e4 100644 --- a/skrub/_reporting/tests/test_utils.py +++ b/skrub/_reporting/tests/test_utils.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize( "s_in, s_out", [ - (1, 1), + (1, "1"), ("aa", "aa"), ("a\na", "a a"), ("a" * 70, "a" * 30 + "…\u200e"), @@ -55,6 +55,16 @@ def test_ellide_string_empty(): assert _utils.ellide_string(" a", 1) == "…" +def test_ellide_non_string(): + # non-regression for #1195: objects in columns must be converted to strings + # before elliding and plotting + class A: + def __repr__(self): + return "one\ntwo\nthree" + + assert _utils.ellide_string(A()) == "one two three" + + @pytest.mark.parametrize( "n_in, n_out", [ diff --git a/skrub/_selectors/_selectors.py b/skrub/_selectors/_selectors.py index 08fd58a94..a3a9a8a80 100644 --- a/skrub/_selectors/_selectors.py +++ b/skrub/_selectors/_selectors.py @@ -312,8 +312,8 @@ def any_date(): 0 2020-03-02 10:30:00 2020-03-02 10:30:00+00:00 2020-03-02 10:30:00 >>> df.dtypes - dt datetime64[ns] - tzdt datetime64[ns, UTC] + dt datetime64[...] 
+ tzdt datetime64[..., UTC] str_ object dtype: object diff --git a/skrub/_similarity_encoder.py b/skrub/_similarity_encoder.py index c29b53637..d67b831aa 100644 --- a/skrub/_similarity_encoder.py +++ b/skrub/_similarity_encoder.py @@ -3,6 +3,7 @@ which encodes similarity instead of equality of values. """ + import numpy as np import pandas as pd import sklearn @@ -13,18 +14,13 @@ from sklearn.utils.fixes import parse_version from sklearn.utils.validation import check_is_fitted -from ._fixes import _check_n_features +from ._sklearn_compat import _check_n_features from ._string_distances import get_ngram_count, preprocess # Ignore lines too long, first docstring lines can't be cut # flake8: noqa: E501 -sklearn_below_1_6 = parse_version( - parse_version(sklearn.__version__).base_version -) < parse_version("1.6") - - def _ngram_similarity_one_sample_inplace( x_count_vector, vocabulary_count_matrix, @@ -555,26 +551,22 @@ def _ngram_similarity_fast( return np.nan_to_num(out, copy=False) - if sklearn_below_1_6: - - def _more_tags(self): - return { - "X_types": ["2darray", "categorical", "string"], - "preserves_dtype": [], - "allow_nan": True, - "_xfail_checks": { - "check_estimator_sparse_data": ( - "Cannot create sparse matrix with strings." - ), - "check_estimators_dtypes": "We only support string dtypes.", - }, - } - - else: - - def __sklearn_tags__(self): - tags = super().__sklearn_tags__() - tags.input_tags.categorical = True - tags.input_tags.string = True - tags.transformer_tags.preserves_dtype = [] - return tags + def _more_tags(self): + return { + "X_types": ["2darray", "categorical", "string"], + "preserves_dtype": [], + "allow_nan": True, + "_xfail_checks": { + "check_estimator_sparse_data": ( + "Cannot create sparse matrix with strings." + ), + "check_estimators_dtypes": "We only support string dtypes.", + }, + } + + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.input_tags.categorical = True + tags.input_tags.string = True + tags.transformer_tags.preserves_dtype = [] + return tags diff --git a/skrub/_sklearn_compat.py b/skrub/_sklearn_compat.py new file mode 100644 index 000000000..676521cb7 --- /dev/null +++ b/skrub/_sklearn_compat.py @@ -0,0 +1,572 @@ +"""Ease developer experience to support multiple versions of scikit-learn. + +This file is intended to be vendored in your project if you do not want to depend on +`sklearn-compat` as a package. Then, you can import directly from this file. + +Be aware that depending on `sklearn-compat` does not add any additional dependencies: +we are only depending on `scikit-learn`. + +Version: 0.1.0 +""" + +from __future__ import annotations + +import platform +import sys +from dataclasses import dataclass, field + +import sklearn +from sklearn.utils._param_validation import validate_parameter_constraints +from sklearn.utils.fixes import parse_version + +sklearn_version = parse_version(parse_version(sklearn.__version__).base_version) + + +######################################################################################## +# The following code does not depend on the sklearn version +######################################################################################## + + +# parameters validation +class ParamsValidationMixin: + """Mixin class to validate parameters.""" + + def _validate_params(self): + """Validate types and values of constructor parameters. + + The expected type and values must be defined in the `_parameter_constraints` + class attribute, which is a dictionary `param_name: list of constraints`. 
See + the docstring of `validate_parameter_constraints` for a description of the + accepted constraints. + """ + if hasattr(self, "_parameter_constraints"): + validate_parameter_constraints( + self._parameter_constraints, + self.get_params(deep=False), + caller_name=self.__class__.__name__, + ) + + +# tags infrastructure +def _dataclass_args(): + if sys.version_info < (3, 10): + return {} + return {"slots": True} + + +def get_tags(estimator): + """Get estimator tags in a consistent format across different sklearn versions. + + This function provides compatibility between sklearn versions before and after 1.6. + It returns either a Tags object (sklearn >= 1.6) or a converted Tags object from + the dictionary format (sklearn < 1.6) containing metadata about the estimator's + requirements and capabilities. + + Parameters + ---------- + estimator : estimator object + A scikit-learn estimator instance. + + Returns + ------- + tags : Tags + An object containing metadata about the estimator's requirements and + capabilities (e.g., input types, fitting requirements, classifier/regressor + specific tags). + """ + try: + from sklearn.utils._tags import get_tags + + return get_tags(estimator) + except ImportError: + from sklearn.utils._tags import _safe_tags + + return _to_new_tags(_safe_tags(estimator), estimator) + + +def _to_new_tags(old_tags, estimator=None): + """Utility function convert old tags (dictionary) to new tags (dataclass).""" + input_tags = InputTags( + one_d_array="1darray" in old_tags["X_types"], + two_d_array="2darray" in old_tags["X_types"], + three_d_array="3darray" in old_tags["X_types"], + sparse="sparse" in old_tags["X_types"], + categorical="categorical" in old_tags["X_types"], + string="string" in old_tags["X_types"], + dict="dict" in old_tags["X_types"], + positive_only=old_tags["requires_positive_X"], + allow_nan=old_tags["allow_nan"], + pairwise=old_tags["pairwise"], + ) + target_tags = TargetTags( + required=old_tags["requires_y"], + one_d_labels="1dlabels" in old_tags["X_types"], + two_d_labels="2dlabels" in old_tags["X_types"], + positive_only=old_tags["requires_positive_y"], + multi_output=old_tags["multioutput"] or old_tags["multioutput_only"], + single_output=not old_tags["multioutput_only"], + ) + if estimator is not None and ( + hasattr(estimator, "transform") or hasattr(estimator, "fit_transform") + ): + transformer_tags = TransformerTags( + preserves_dtype=old_tags["preserves_dtype"], + ) + else: + transformer_tags = None + estimator_type = getattr(estimator, "_estimator_type", None) + if estimator_type == "classifier": + classifier_tags = ClassifierTags( + poor_score=old_tags["poor_score"], + multi_class=not old_tags["binary_only"], + multi_label=old_tags["multilabel"], + ) + else: + classifier_tags = None + if estimator_type == "regressor": + regressor_tags = RegressorTags( + poor_score=old_tags["poor_score"], + multi_label=old_tags["multilabel"], + ) + else: + regressor_tags = None + return Tags( + estimator_type=estimator_type, + target_tags=target_tags, + transformer_tags=transformer_tags, + classifier_tags=classifier_tags, + regressor_tags=regressor_tags, + input_tags=input_tags, + array_api_support=old_tags.get("array_api_support", False), + no_validation=old_tags["no_validation"], + non_deterministic=old_tags["non_deterministic"], + requires_fit=old_tags["requires_fit"], + _skip_test=old_tags["_skip_test"], + ) + + +######################################################################################## +# Upgrading for scikit-learn 1.4 
+######################################################################################## + + +if sklearn_version < parse_version("1.4"): + + def _is_fitted(estimator, attributes=None, all_or_any=all): + """Determine if an estimator is fitted + + Parameters + ---------- + estimator : estimator instance + Estimator instance for which the check is performed. + + attributes : str, list or tuple of str, default=None + Attribute name(s) given as string or a list/tuple of strings + Eg.: ``["coef_", "estimator_", ...], "coef_"`` + + If `None`, `estimator` is considered fitted if there exist an + attribute that ends with a underscore and does not start with double + underscore. + + all_or_any : callable, {all, any}, default=all + Specify whether all or any of the given attributes must exist. + + Returns + ------- + fitted : bool + Whether the estimator is fitted. + """ + if attributes is not None: + if not isinstance(attributes, (list, tuple)): + attributes = [attributes] + return all_or_any([hasattr(estimator, attr) for attr in attributes]) + + if hasattr(estimator, "__sklearn_is_fitted__"): + return estimator.__sklearn_is_fitted__() + + fitted_attrs = [ + v for v in vars(estimator) if v.endswith("_") and not v.startswith("__") + ] + return len(fitted_attrs) > 0 + +else: + from sklearn.utils.validation import _is_fitted # noqa: F401 + + +######################################################################################## +# Upgrading for scikit-learn 1.5 +######################################################################################## + + +if sklearn_version < parse_version("1.5"): + # chunking + # extmath + # fixes + from sklearn.utils import ( + _IS_32BIT, # noqa: F401 + _approximate_mode, # noqa: F401 + _in_unstable_openblas_configuration, # noqa: F401 + gen_batches, # noqa: F401 + gen_even_slices, # noqa: F401 + get_chunk_n_rows, # noqa: F401 + safe_sqr, # noqa: F401 + ) + from sklearn.utils import _chunk_generator as chunk_generator # noqa: F401 + + _IS_WASM = platform.machine() in ["wasm32", "wasm64"] + # indexing + # mask + # missing + # optional dependencies + # user interface + # validation + from sklearn.utils import ( + _determine_key_type, # noqa: F401 + _get_column_indices, # noqa: F401 + _print_elapsed_time, # noqa: F401 + _safe_assign, # noqa: F401 + _safe_indexing, # noqa: F401 + _to_object_array, # noqa: F401 + axis0_safe_slice, # noqa: F401 + check_matplotlib_support, # noqa: F401 + check_pandas_support, # noqa: F401 + indices_to_mask, # noqa: F401 + is_scalar_nan, # noqa: F401 + resample, # noqa: F401 + safe_mask, # noqa: F401 + shuffle, # noqa: F401 + ) + from sklearn.utils import _is_pandas_na as is_pandas_na # noqa: F401 +else: + # chunking + from sklearn.utils._chunking import ( + chunk_generator, # noqa: F401 + gen_batches, # noqa: F401 + gen_even_slices, # noqa: F401 + get_chunk_n_rows, # noqa: F401 + ) + + # indexing + from sklearn.utils._indexing import ( + _determine_key_type, # noqa: F401 + _get_column_indices, # noqa: F401 + _safe_assign, # noqa: F401 + _safe_indexing, # noqa: F401 + resample, # noqa: F401 + shuffle, # noqa: F401 + ) + + # mask + from sklearn.utils._mask import ( + axis0_safe_slice, # noqa: F401 + indices_to_mask, # noqa: F401 + safe_mask, # noqa: F401 + ) + + # missing + from sklearn.utils._missing import ( + is_pandas_na, # noqa: F401 + is_scalar_nan, # noqa: F401 + ) + + # optional dependencies + from sklearn.utils._optional_dependencies import ( # noqa: F401 + check_matplotlib_support, + check_pandas_support, # noqa: F401 + ) + + # 
user interface + from sklearn.utils._user_interface import _print_elapsed_time # noqa: F401 + + # extmath + from sklearn.utils.extmath import ( + _approximate_mode, # noqa: F401 + safe_sqr, # noqa: F401 + ) + + # fixes + from sklearn.utils.fixes import ( + _IS_32BIT, # noqa: F401 + _IS_WASM, # noqa: F401 + _in_unstable_openblas_configuration, # noqa: F401 + ) + + # validation + from sklearn.utils.validation import _to_object_array # noqa: F401 + +######################################################################################## +# Upgrading for scikit-learn 1.6 +######################################################################################## + + +if sklearn_version < parse_version("1.6"): + # test_common + from sklearn.utils.estimator_checks import _construct_instance + + def _construct_instances(Estimator): + yield _construct_instance(Estimator) + + # validation + def validate_data(_estimator, /, **kwargs): + if "ensure_all_finite" in kwargs: + force_all_finite = kwargs.pop("ensure_all_finite") + else: + force_all_finite = True + return _estimator._validate_data(**kwargs, force_all_finite=force_all_finite) + + def _check_n_features(estimator, X, *, reset): + return estimator._check_n_features(X, reset=reset) + + def _check_feature_names(estimator, X, *, reset): + return estimator._check_feature_names(X, reset=reset) + + # tags infrastructure + @dataclass(**_dataclass_args()) + class InputTags: + """Tags for the input data. + + Parameters + ---------- + one_d_array : bool, default=False + Whether the input can be a 1D array. + + two_d_array : bool, default=True + Whether the input can be a 2D array. Note that most common + tests currently run only if this flag is set to ``True``. + + three_d_array : bool, default=False + Whether the input can be a 3D array. + + sparse : bool, default=False + Whether the input can be a sparse matrix. + + categorical : bool, default=False + Whether the input can be categorical. + + string : bool, default=False + Whether the input can be an array-like of strings. + + dict : bool, default=False + Whether the input can be a dictionary. + + positive_only : bool, default=False + Whether the estimator requires positive X. + + allow_nan : bool, default=False + Whether the estimator supports data with missing values encoded as `np.nan`. + + pairwise : bool, default=False + This boolean attribute indicates whether the data (`X`), + :term:`fit` and similar methods consists of pairwise measures + over samples rather than a feature representation for each + sample. It is usually `True` where an estimator has a + `metric` or `affinity` or `kernel` parameter with value + 'precomputed'. Its primary purpose is to support a + :term:`meta-estimator` or a cross validation procedure that + extracts a sub-sample of data intended for a pairwise + estimator, where the data needs to be indexed on both axes. + Specifically, this tag is used by + `sklearn.utils.metaestimators._safe_split` to slice rows and + columns. + """ + + one_d_array: bool = False + two_d_array: bool = True + three_d_array: bool = False + sparse: bool = False + categorical: bool = False + string: bool = False + dict: bool = False + positive_only: bool = False + allow_nan: bool = False + pairwise: bool = False + + @dataclass(**_dataclass_args()) + class TargetTags: + """Tags for the target data. + + Parameters + ---------- + required : bool + Whether the estimator requires y to be passed to `fit`, + `fit_predict` or `fit_transform` methods. 
The tag is ``True`` + for estimators inheriting from `~sklearn.base.RegressorMixin` + and `~sklearn.base.ClassifierMixin`. + + one_d_labels : bool, default=False + Whether the input is a 1D labels (y). + + two_d_labels : bool, default=False + Whether the input is a 2D labels (y). + + positive_only : bool, default=False + Whether the estimator requires a positive y (only applicable + for regression). + + multi_output : bool, default=False + Whether a regressor supports multi-target outputs or a classifier supports + multi-class multi-output. + + single_output : bool, default=True + Whether the target can be single-output. This can be ``False`` if the + estimator supports only multi-output cases. + """ + + required: bool + one_d_labels: bool = False + two_d_labels: bool = False + positive_only: bool = False + multi_output: bool = False + single_output: bool = True + + @dataclass(**_dataclass_args()) + class TransformerTags: + """Tags for the transformer. + + Parameters + ---------- + preserves_dtype : list[str], default=["float64"] + Applies only on transformers. It corresponds to the data types + which will be preserved such that `X_trans.dtype` is the same + as `X.dtype` after calling `transformer.transform(X)`. If this + list is empty, then the transformer is not expected to + preserve the data type. The first value in the list is + considered as the default data type, corresponding to the data + type of the output when the input data type is not going to be + preserved. + """ + + preserves_dtype: list[str] = field(default_factory=lambda: ["float64"]) + + @dataclass(**_dataclass_args()) + class ClassifierTags: + """Tags for the classifier. + + Parameters + ---------- + poor_score : bool, default=False + Whether the estimator fails to provide a "reasonable" test-set + score, which currently for classification is an accuracy of + 0.83 on ``make_blobs(n_samples=300, random_state=0)``. The + datasets and values are based on current estimators in scikit-learn + and might be replaced by something more systematic. + + multi_class : bool, default=True + Whether the classifier can handle multi-class + classification. Note that all classifiers support binary + classification. Therefore this flag indicates whether the + classifier is a binary-classifier-only or not. + + multi_label : bool, default=False + Whether the classifier supports multi-label output. + """ + + poor_score: bool = False + multi_class: bool = True + multi_label: bool = False + + @dataclass(**_dataclass_args()) + class RegressorTags: + """Tags for the regressor. + + Parameters + ---------- + poor_score : bool, default=False + Whether the estimator fails to provide a "reasonable" test-set + score, which currently for regression is an R2 of 0.5 on + ``make_regression(n_samples=200, n_features=10, + n_informative=1, bias=5.0, noise=20, random_state=42)``. The + dataset and values are based on current estimators in scikit-learn + and might be replaced by something more systematic. + + multi_label : bool, default=False + Whether the regressor supports multilabel output. + """ + + poor_score: bool = False + multi_label: bool = False + + @dataclass(**_dataclass_args()) + class Tags: + """Tags for the estimator. + + See :ref:`estimator_tags` for more information. + + Parameters + ---------- + estimator_type : str or None + The type of the estimator. Can be one of: + - "classifier" + - "regressor" + - "transformer" + - "clusterer" + - "outlier_detector" + - "density_estimator" + + target_tags : :class:`TargetTags` + The target(y) tags. 
+ + transformer_tags : :class:`TransformerTags` or None + The transformer tags. + + classifier_tags : :class:`ClassifierTags` or None + The classifier tags. + + regressor_tags : :class:`RegressorTags` or None + The regressor tags. + + array_api_support : bool, default=False + Whether the estimator supports Array API compatible inputs. + + no_validation : bool, default=False + Whether the estimator skips input-validation. This is only meant for + stateless and dummy transformers! + + non_deterministic : bool, default=False + Whether the estimator is not deterministic given a fixed ``random_state``. + + requires_fit : bool, default=True + Whether the estimator requires to be fitted before calling one of + `transform`, `predict`, `predict_proba`, or `decision_function`. + + _skip_test : bool, default=False + Whether to skip common tests entirely. Don't use this unless + you have a *very good* reason. + + input_tags : :class:`InputTags` + The input data(X) tags. + """ + + estimator_type: str | None + target_tags: TargetTags + transformer_tags: TransformerTags | None = None + classifier_tags: ClassifierTags | None = None + regressor_tags: RegressorTags | None = None + array_api_support: bool = False + no_validation: bool = False + non_deterministic: bool = False + requires_fit: bool = True + _skip_test: bool = False + input_tags: InputTags = field(default_factory=InputTags) + +else: + # test_common + # tags infrastructure + from sklearn.utils import ( + ClassifierTags, + InputTags, + RegressorTags, + Tags, + TargetTags, + TransformerTags, + ) + from sklearn.utils._test_common.instance_generator import ( + _construct_instances, # noqa: F401 + ) + + # validation + from sklearn.utils.validation import ( + _check_feature_names, # noqa: F401 + _check_n_features, # noqa: F401 + validate_data, # noqa: F401 + ) diff --git a/skrub/_table_vectorizer.py b/skrub/_table_vectorizer.py index 78463f997..4e58cc554 100644 --- a/skrub/_table_vectorizer.py +++ b/skrub/_table_vectorizer.py @@ -3,12 +3,10 @@ from typing import Iterable import numpy as np -import sklearn from sklearn.base import BaseEstimator, TransformerMixin, clone from sklearn.pipeline import make_pipeline from sklearn.preprocessing import OneHotEncoder from sklearn.utils._estimator_html_repr import _VisualBlock -from sklearn.utils.fixes import parse_version from sklearn.utils.validation import check_is_fitted from . import _dataframe as sbd @@ -30,11 +28,6 @@ __all__ = ["TableVectorizer"] -sklearn_below_1_6 = parse_version( - parse_version(sklearn.__version__).base_version -) < parse_version("1.6") - - class PassThrough(SingleColumnTransformer): def fit_transform(self, column, y=None): return column @@ -665,27 +658,23 @@ def _sk_visual_block_(self): # scikit-learn compatibility - if sklearn_below_1_6: - - def _more_tags(self): - """ - Used internally by sklearn to ease the estimator checks. - """ - return { - "X_types": ["2darray", "string"], - "allow_nan": [True], - "_xfail_checks": { - "check_complex_data": "Passthrough complex columns as-is.", - }, - } - - else: + def _more_tags(self): + """ + Used internally by sklearn to ease the estimator checks. 
+ """ + return { + "X_types": ["2darray", "string"], + "allow_nan": [True], + "_xfail_checks": { + "check_complex_data": "Passthrough complex columns as-is.", + }, + } - def __sklearn_tags__(self): - tags = super().__sklearn_tags__() - tags.input_tags.string = True - tags.input_tags.allow_nan = True - return tags + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.input_tags.string = True + tags.input_tags.allow_nan = True + return tags def get_feature_names_out(self): """Return the column names of the output of ``transform`` as a list of strings. diff --git a/skrub/_tabular_learner.py b/skrub/_tabular_learner.py index 607406f15..2342cfa7a 100644 --- a/skrub/_tabular_learner.py +++ b/skrub/_tabular_learner.py @@ -1,5 +1,3 @@ -from dataclasses import is_dataclass - import sklearn from sklearn import ensemble from sklearn.base import BaseEstimator @@ -8,8 +6,8 @@ from sklearn.preprocessing import OrdinalEncoder, StandardScaler from sklearn.utils.fixes import parse_version -from ._fixes import get_tags from ._minhash_encoder import MinHashEncoder +from ._sklearn_compat import get_tags from ._table_vectorizer import TableVectorizer from ._to_categorical import ToCategorical @@ -273,15 +271,7 @@ def tabular_learner(estimator, *, n_jobs=None): high_cardinality=MinHashEncoder(), ) steps = [vectorizer] - try: - tags = get_tags(estimator) - if is_dataclass(tags): - allow_nan = tags.input_tags.allow_nan - else: - allow_nan = tags.get("allow_nan", False) - except TypeError: - allow_nan = False - if not allow_nan: + if not get_tags(estimator).input_tags.allow_nan: steps.append(SimpleImputer(add_indicator=True)) if not isinstance(estimator, _TREE_ENSEMBLE_CLASSES): steps.append(StandardScaler()) diff --git a/skrub/_to_datetime.py b/skrub/_to_datetime.py index 943bf4906..0a21a28f0 100644 --- a/skrub/_to_datetime.py +++ b/skrub/_to_datetime.py @@ -145,7 +145,7 @@ class ToDatetime(SingleColumnTransformer): 0 2024-05-05 13:17:52 1 NaT 2 2024-05-07 13:17:52 - Name: when, dtype: datetime64[ns] + Name: when, dtype: datetime64[...] The attributes ``format_``, ``output_dtype_``, ``output_time_zone_`` record information about the conversion result. @@ -153,7 +153,7 @@ class ToDatetime(SingleColumnTransformer): >>> to_dt.format_ '%Y-%m-%dT%H:%M:%S' >>> to_dt.output_dtype_ - dtype('>> to_dt.output_time_zone_ is None True @@ -164,7 +164,7 @@ class ToDatetime(SingleColumnTransformer): 0 2024-05-05 13:17:52 1 NaT 2 2024-05-07 13:17:52 - Name: when, dtype: datetime64[ns] + Name: when, dtype: datetime64[...] >>> ToDatetime(format="%d/%m/%Y").fit_transform(s) Traceback (most recent call last): @@ -179,7 +179,7 @@ class ToDatetime(SingleColumnTransformer): 0 2024-05-05 13:17:52+02:00 1 NaT 2 2024-05-07 13:17:52+02:00 - Name: when, dtype: datetime64[ns, Europe/Paris] + Name: when, dtype: datetime64[..., Europe/Paris] >>> to_dt.fit_transform(s) is s True @@ -188,7 +188,7 @@ class ToDatetime(SingleColumnTransformer): >>> to_dt.format_ is None True >>> to_dt.output_dtype_ - datetime64[ns, Europe/Paris] + datetime64[..., Europe/Paris] >>> to_dt.output_time_zone_ 'Europe/Paris' @@ -220,13 +220,13 @@ class ToDatetime(SingleColumnTransformer): 0 2024-05-05 13:17:52 1 NaT 2 2024-05-07 13:17:52 - Name: when, dtype: datetime64[ns] + Name: when, dtype: datetime64[...] >>> s = pd.Series(["05/05/2024", None, "07/05/2024"], name="when") >>> to_dt.transform(s) 0 NaT 1 NaT 2 NaT - Name: when, dtype: datetime64[ns] + Name: when, dtype: datetime64[...] 
**Time zones** @@ -237,7 +237,7 @@ class ToDatetime(SingleColumnTransformer): >>> to_dt.fit_transform(s) 0 2020-01-01 02:00:00+00:00 1 2020-01-01 01:00:00+00:00 - dtype: datetime64[ns, UTC] + dtype: datetime64[..., UTC] >>> to_dt.format_ '%Y-%m-%dT%H:%M:%S%z' >>> to_dt.output_time_zone_ @@ -249,7 +249,7 @@ class ToDatetime(SingleColumnTransformer): >>> to_dt.fit_transform(s) 0 2020-01-01 04:00:00 1 2020-01-01 04:00:00 - dtype: datetime64[ns] + dtype: datetime64[...] >>> to_dt.output_time_zone_ is None True @@ -262,10 +262,10 @@ class ToDatetime(SingleColumnTransformer): >>> s_paris 0 2024-05-07 14:24:49+02:00 1 2024-05-06 14:24:49+02:00 - dtype: datetime64[ns, Europe/Paris] + dtype: datetime64[..., Europe/Paris] >>> to_dt = ToDatetime().fit(s_paris) >>> to_dt.output_dtype_ - datetime64[ns, Europe/Paris] + datetime64[..., Europe/Paris] Here our converter is set to output datetimes with nanosecond resolution, localized in "Europe/Paris". @@ -276,7 +276,7 @@ class ToDatetime(SingleColumnTransformer): >>> s_london 0 2024-05-07 13:24:49+01:00 1 2024-05-06 13:24:49+01:00 - dtype: datetime64[ns, Europe/London] + dtype: datetime64[..., Europe/London] Here the timezone is "Europe/London" and the times are offset by 1 hour. During ``transform`` datetimes will be converted to the original dtype and the @@ -285,7 +285,7 @@ class ToDatetime(SingleColumnTransformer): >>> to_dt.transform(s_london) 0 2024-05-07 14:24:49+02:00 1 2024-05-06 14:24:49+02:00 - dtype: datetime64[ns, Europe/Paris] + dtype: datetime64[..., Europe/Paris] Moreover, we may have to transform a timezone-naive column whereas the transformer was fitted on a timezone-aware column. Note that is somewhat a @@ -296,7 +296,7 @@ class ToDatetime(SingleColumnTransformer): >>> s_naive 0 2024-05-07 12:24:49 1 2024-05-06 12:24:49 - dtype: datetime64[ns] + dtype: datetime64[...] In this case, we make the arbitrary choice to assume that the timezone-naive datetimes are in UTC. @@ -304,7 +304,7 @@ class ToDatetime(SingleColumnTransformer): >>> to_dt.transform(s_naive) 0 2024-05-07 14:24:49+02:00 1 2024-05-06 14:24:49+02:00 - dtype: datetime64[ns, Europe/Paris] + dtype: datetime64[..., Europe/Paris] Conversely, a transformer fitted on a timezone-naive column can convert timezone-aware columns. Here also, we assume the naive datetimes were in UTC. @@ -313,7 +313,7 @@ class ToDatetime(SingleColumnTransformer): >>> to_dt.transform(s_london) 0 2024-05-07 12:24:49 1 2024-05-06 12:24:49 - dtype: datetime64[ns] + dtype: datetime64[...] **``%d/%m/%Y`` vs ``%m/%d/%Y``** @@ -324,7 +324,7 @@ class ToDatetime(SingleColumnTransformer): >>> s = pd.Series(["05/23/2024"]) >>> to_dt.fit_transform(s) 0 2024-05-23 - dtype: datetime64[ns] + dtype: datetime64[...] >>> to_dt.format_ '%m/%d/%Y' @@ -334,7 +334,7 @@ class ToDatetime(SingleColumnTransformer): >>> s = pd.Series(["23/05/2024"]) >>> to_dt.fit_transform(s) 0 2024-05-23 - dtype: datetime64[ns] + dtype: datetime64[...] >>> to_dt.format_ '%d/%m/%Y' @@ -343,7 +343,7 @@ class ToDatetime(SingleColumnTransformer): >>> s = pd.Series(["03/05/2024"]) >>> to_dt.fit_transform(s) 0 2024-03-05 - dtype: datetime64[ns] + dtype: datetime64[...] >>> to_dt.format_ '%m/%d/%Y' diff --git a/skrub/_to_float32.py b/skrub/_to_float32.py index 26e8303ba..4cf1185a2 100644 --- a/skrub/_to_float32.py +++ b/skrub/_to_float32.py @@ -158,7 +158,7 @@ class ToFloat32(SingleColumnTransformer): >>> to_float.fit_transform(pd.to_datetime(pd.Series(['2024-05-13'], name='s'))) Traceback (most recent call last): ... 
- skrub._on_each_column.RejectColumn: Refusing to cast column 's' with dtype 'datetime64[ns]' to numbers. + skrub._on_each_column.RejectColumn: Refusing to cast column 's' with dtype 'datetime64[...]' to numbers. float32 columns are passed through: diff --git a/skrub/_to_str.py b/skrub/_to_str.py index 5cf7de2e3..e8ff19cac 100644 --- a/skrub/_to_str.py +++ b/skrub/_to_str.py @@ -100,7 +100,7 @@ class ToStr(SingleColumnTransformer): >>> to_str.fit_transform(pd.to_datetime(pd.Series(['2020-02-02']))) Traceback (most recent call last): ... - skrub._on_each_column.RejectColumn: Refusing to convert None with dtype 'datetime64[ns]' to strings. + skrub._on_each_column.RejectColumn: Refusing to convert None with dtype 'datetime64[...]' to strings. However, once a column has been accepted, the output of ``transform`` will always be strings: diff --git a/skrub/tests/test_sklearn.py b/skrub/tests/test_sklearn.py index f50e3f89a..00ee1e48c 100644 --- a/skrub/tests/test_sklearn.py +++ b/skrub/tests/test_sklearn.py @@ -12,7 +12,7 @@ SimilarityEncoder, TableVectorizer, ) -from skrub._fixes import get_tags +from skrub._sklearn_compat import get_tags def _enforce_estimator_tags_X_monkey_patch( @@ -22,120 +22,59 @@ def _enforce_estimator_tags_X_monkey_patch( having only strings with some encoders. """ tags = get_tags(estimator) - if isinstance(tags, dict): - # Estimators with `1darray` in `X_types` tag only accept - # X of shape (`n_samples`,) - if "1darray" in tags["X_types"]: - X = X[:, 0] + if tags.input_tags.one_d_array: + X = X[:, 0] + if X_test is not None: + X_test = X_test[:, 0] # pragma: no cover + # Estimators with a `requires_positive_X` tag only accept + # strictly positive data + if tags.input_tags.positive_only: + X = X - X.min() + if X_test is not None: + X_test = X_test - X_test.min() # pragma: no cover + if tags.input_tags.categorical: + X = np.round((X - X.min())) + if X_test is not None: + X_test = np.round((X_test - X_test.min())) # pragma: no cover + if tags.input_tags.string: + X = X.astype(object) + for i in range(X.shape[0]): + for j in range(X.shape[1]): + X[i, j] = str(X[i, j]) if X_test is not None: - X_test = X_test[:, 0] # pragma: no cover - # Estimators with a `requires_positive_X` tag only accept - # strictly positive data - if tags["requires_positive_X"]: - X = X - X.min() + X_test = X_test.astype(object) + for i in range(X_test.shape[0]): + for j in range(X_test.shape[1]): + X_test[i, j] = str(X_test[i, j]) + elif tags.input_tags.allow_nan: + X = X.astype(np.float64) if X_test is not None: - X_test = X_test - X_test.min() # pragma: no cover - if "categorical" in tags["X_types"]: - X = np.round((X - X.min())) + X_test = X_test.astype(np.float64) # pragma: no cover + else: + X = X.astype(np.int32) if X_test is not None: - X_test = np.round((X_test - X_test.min())) # pragma: no cover - if "string" in tags["X_types"]: - # Note: this part is the monkey patch - X = X.astype(object) - for i in range(X.shape[0]): - for j in range(X.shape[1]): - X[i, j] = str(X[i, j]) - if X_test is not None: - X_test = X_test.astype(object) - for i in range(X_test.shape[0]): - for j in range(X_test.shape[1]): - X_test[i, j] = str(X_test[i, j]) - elif tags["allow_nan"]: - X = X.astype(np.float64) - if X_test is not None: - X_test = X_test.astype(np.float64) # pragma: no cover - else: - X = X.astype(np.int32) - if X_test is not None: - X_test = X_test.astype(np.int32) # pragma: no cover + X_test = X_test.astype(np.int32) # pragma: no cover - if estimator.__class__.__name__ == "SkewedChi2Sampler": - # 
SkewedChi2Sampler requires X > -skewdness in transform - X = X - X.min() - if X_test is not None: - X_test = X_test - X_test.min() # pragma: no cover - - X_res = X - - # Pairwise estimators only accept - # X of shape (`n_samples`, `n_samples`) - if _is_pairwise_metric(estimator): - X_res = pairwise_distances(X, metric="euclidean") - if X_test is not None: - X_test = pairwise_distances( - X_test, X, metric="euclidean" - ) # pragma: no cover - elif tags["pairwise"]: - X_res = kernel(X, X) - if X_test is not None: - X_test = kernel(X_test, X) # pragma: no cover - else: - # scikit-learn >= 1.6 - # Estimators with `1darray` in `X_types` tag only accept - # X of shape (`n_samples`,) - if tags.input_tags.one_d_array: - X = X[:, 0] - if X_test is not None: - X_test = X_test[:, 0] # pragma: no cover - # Estimators with a `requires_positive_X` tag only accept - # strictly positive data - if tags.input_tags.positive_only: - X = X - X.min() - if X_test is not None: - X_test = X_test - X_test.min() # pragma: no cover - if tags.input_tags.categorical: - X = np.round((X - X.min())) - if X_test is not None: - X_test = np.round((X_test - X_test.min())) # pragma: no cover - if tags.input_tags.string: - X = X.astype(object) - for i in range(X.shape[0]): - for j in range(X.shape[1]): - X[i, j] = str(X[i, j]) - if X_test is not None: - X_test = X_test.astype(object) - for i in range(X_test.shape[0]): - for j in range(X_test.shape[1]): - X_test[i, j] = str(X_test[i, j]) - elif tags.input_tags.allow_nan: - X = X.astype(np.float64) - if X_test is not None: - X_test = X_test.astype(np.float64) # pragma: no cover - else: - X = X.astype(np.int32) - if X_test is not None: - X_test = X_test.astype(np.int32) # pragma: no cover - - if estimator.__class__.__name__ == "SkewedChi2Sampler": - # SkewedChi2Sampler requires X > -skewdness in transform - X = X - X.min() - if X_test is not None: - X_test = X_test - X_test.min() # pragma: no cover + if estimator.__class__.__name__ == "SkewedChi2Sampler": + # SkewedChi2Sampler requires X > -skewdness in transform + X = X - X.min() + if X_test is not None: + X_test = X_test - X_test.min() # pragma: no cover - X_res = X + X_res = X - # Pairwise estimators only accept - # X of shape (`n_samples`, `n_samples`) - if _is_pairwise_metric(estimator): - X_res = pairwise_distances(X, metric="euclidean") - if X_test is not None: - X_test = pairwise_distances( - X_test, X, metric="euclidean" - ) # pragma: no cover - elif tags.input_tags.pairwise: - X_res = kernel(X, X) - if X_test is not None: - X_test = kernel(X_test, X) # pragma: no cover + # Pairwise estimators only accept + # X of shape (`n_samples`, `n_samples`) + if _is_pairwise_metric(estimator): + X_res = pairwise_distances(X, metric="euclidean") + if X_test is not None: + X_test = pairwise_distances( + X_test, X, metric="euclidean" + ) # pragma: no cover + elif tags.input_tags.pairwise: + X_res = kernel(X, X) + if X_test is not None: + X_test = kernel(X_test, X) # pragma: no cover if X_test is not None: return X_res, X_test return X_res diff --git a/skrub/tests/test_table_vectorizer.py b/skrub/tests/test_table_vectorizer.py index a16cf17f1..ec1fc023b 100644 --- a/skrub/tests/test_table_vectorizer.py +++ b/skrub/tests/test_table_vectorizer.py @@ -235,11 +235,11 @@ def test_duplicate_column_names(): ( X, { - "pd_datetime": "datetime64[ns]", - "np_datetime": "datetime64[ns]", - "dmy-": "datetime64[ns]", - "ymd/": "datetime64[ns]", - "ymd/_hms:": "datetime64[ns]", + "pd_datetime": "datetime", + "np_datetime": "datetime", + "dmy-": 
"datetime", + "ymd/": "datetime", + "ymd/_hms:": "datetime", }, ), # Test other types detection @@ -285,7 +285,10 @@ def test_auto_cast(X, dict_expected_types): vectorizer = passthrough_vectorizer() X_trans = vectorizer.fit_transform(X) for col in X_trans.columns: - assert dict_expected_types[col] == X_trans[col].dtype + if dict_expected_types[col] == "datetime": + assert sbd.is_any_date(X_trans[col]) + else: + assert dict_expected_types[col] == X_trans[col].dtype def test_auto_cast_missing_categories():