From 45b3af9b8aefa4ef4cbc98c7cb75d577842e6daa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Jolivet?= Date: Tue, 21 Nov 2023 17:51:19 +0100 Subject: [PATCH] Move functions to test modules to _utils.py --- skrub/_dataframe/_namespace.py | 13 ------------- skrub/_utils.py | 8 ++++++++ skrub/tests/test_datetime_encoder.py | 2 +- skrub/tests/test_fuzzy_join.py | 2 +- skrub/tests/test_gap_encoder.py | 2 +- skrub/tests/test_interpolation_join.py | 2 +- skrub/tests/test_joiner.py | 2 +- skrub/tests/test_similarity_encoder.py | 2 +- 8 files changed, 14 insertions(+), 19 deletions(-) diff --git a/skrub/_dataframe/_namespace.py b/skrub/_dataframe/_namespace.py index d8773cdda..06c65a2ea 100644 --- a/skrub/_dataframe/_namespace.py +++ b/skrub/_dataframe/_namespace.py @@ -43,19 +43,6 @@ def is_polars(dataframe): return isinstance(dataframe, (pl.DataFrame, pl.LazyFrame)) -def is_namespace_pandas(px): - return px is pd - - -def is_namespace_polars(px): - if "polars" not in sys.modules: - return False - - import polars as pl - - return px is pl - - def get_df_namespace(*dfs): """Get the namespaces of dataframes. diff --git a/skrub/_utils.py b/skrub/_utils.py index 72b89e295..cde7be6b9 100644 --- a/skrub/_utils.py +++ b/skrub/_utils.py @@ -162,3 +162,11 @@ def atleast_2d_or_none(x): def clone_if_default(estimator, default_estimator): return clone(estimator) if estimator is default_estimator else estimator + + +def is_namespace_pandas(px): + return px.__name__ == "pandas" + + +def is_namespace_polars(px): + return px.__name__ == "polars" diff --git a/skrub/tests/test_datetime_encoder.py b/skrub/tests/test_datetime_encoder.py index 9838995f9..9bcfb85dc 100644 --- a/skrub/tests/test_datetime_encoder.py +++ b/skrub/tests/test_datetime_encoder.py @@ -7,7 +7,6 @@ from numpy.testing import assert_allclose, assert_array_equal from pandas.testing import assert_frame_equal -from skrub._dataframe._namespace import is_namespace_polars from skrub._dataframe._polars import POLARS_SETUP from skrub._datetime_encoder import ( TIME_LEVELS, @@ -15,6 +14,7 @@ _is_pandas_format_mixed_available, to_datetime, ) +from skrub._utils import is_namespace_polars MODULES = [pd] ASSERT_TUPLES = [(pd, assert_frame_equal)] diff --git a/skrub/tests/test_fuzzy_join.py b/skrub/tests/test_fuzzy_join.py index dc57a440a..e4fad1486 100644 --- a/skrub/tests/test_fuzzy_join.py +++ b/skrub/tests/test_fuzzy_join.py @@ -9,8 +9,8 @@ from sklearn.feature_extraction.text import HashingVectorizer from skrub import fuzzy_join -from skrub._dataframe._namespace import is_namespace_polars from skrub._dataframe._polars import POLARS_SETUP +from skrub._utils import is_namespace_polars MODULES = [pd] ASSERT_TUPLES = [(pd, assert_frame_equal)] diff --git a/skrub/tests/test_gap_encoder.py b/skrub/tests/test_gap_encoder.py index fa2540a26..7369aef90 100644 --- a/skrub/tests/test_gap_encoder.py +++ b/skrub/tests/test_gap_encoder.py @@ -5,8 +5,8 @@ from sklearn.model_selection import train_test_split from skrub import GapEncoder -from skrub._dataframe._namespace import is_namespace_polars from skrub._dataframe._polars import POLARS_SETUP +from skrub._utils import is_namespace_polars from skrub.datasets import fetch_midwest_survey from skrub.tests.utils import generate_data diff --git a/skrub/tests/test_interpolation_join.py b/skrub/tests/test_interpolation_join.py index c3c4cbd5a..58c74956c 100644 --- a/skrub/tests/test_interpolation_join.py +++ b/skrub/tests/test_interpolation_join.py @@ -6,8 +6,8 @@ from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor from skrub import InterpolationJoiner -from skrub._dataframe._namespace import is_namespace_polars from skrub._dataframe._polars import POLARS_SETUP +from skrub._utils import is_namespace_polars MODULES = [pd] diff --git a/skrub/tests/test_joiner.py b/skrub/tests/test_joiner.py index f73817764..478574d95 100644 --- a/skrub/tests/test_joiner.py +++ b/skrub/tests/test_joiner.py @@ -4,8 +4,8 @@ from pandas.testing import assert_frame_equal from skrub import Joiner -from skrub._dataframe._namespace import is_namespace_polars from skrub._dataframe._polars import POLARS_SETUP +from skrub._utils import is_namespace_polars MODULES = [pd] ASSERT_TUPLES = [(pd, assert_frame_equal)] diff --git a/skrub/tests/test_similarity_encoder.py b/skrub/tests/test_similarity_encoder.py index 97a729585..ba655a5dc 100644 --- a/skrub/tests/test_similarity_encoder.py +++ b/skrub/tests/test_similarity_encoder.py @@ -7,10 +7,10 @@ from sklearn.exceptions import NotFittedError from skrub import SimilarityEncoder -from skrub._dataframe._namespace import is_namespace_polars from skrub._dataframe._polars import POLARS_SETUP from skrub._similarity_encoder import ngram_similarity_matrix from skrub._string_distances import ngram_similarity +from skrub._utils import is_namespace_polars MODULES = [pd] INPUT_TYPES = ["list", "numpy", "pandas"]