Skip to content

Commit

Permalink
Test polars support (skrub-data#826)
Browse files Browse the repository at this point in the history
* Rename polars_missing_msg in maj

* Test polars inputs in test_joiner

* Test polars inputs in test_deduplicate

* Test polars inputs in test_fuzzy_join

* Test polars inputs in test_minhash_encoder

* Rename list of tested modules to MODULES

* Test polars inputs in test_gap_encoder. Add dict of possible NULL options for pd or pl

* Test polars inputs in test_datetime_encore. Lots of tests don't pass

* Store comparison utils in list of tuples instead of dictionaries

* Adapt test_interpolation_join for polars. All tests xfail because df.drop() got an unexpected argument 'axis'

* Remove NULL dict in test_gap_encoder

* Xfail set_output to polars in test_similarity_encoder

* Create dfs with pandas, then convert them in px.df in test_interpolation_join

* Format

* Remove pl testing in test_deduplicate as it isn't dependent on it

* Format

* Create pd.DataFrames first in test_datetime_encoder

* Fix error when polars isn't available

* Create function to test if the polars module is available. Use it to xfail specific tests

* Move functions to test modules to _utils.py

* Move functions to test modules into a new _test_utils.py in _dataframe

* Rename is_namespace into is_module
  • Loading branch information
TheooJ authored Nov 22, 2023
1 parent 9eb21f8 commit b787db9
Show file tree
Hide file tree
Showing 11 changed files with 477 additions and 205 deletions.
6 changes: 6 additions & 0 deletions skrub/_dataframe/_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def is_module_pandas(px):
return px.__name__ == "pandas"


def is_module_polars(px):
return px.__name__ == "polars"
4 changes: 2 additions & 2 deletions skrub/_dataframe/tests/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
}
)
else:
polars_missing_msg = "Polars is not available"
pytest.skip(reason=polars_missing_msg, allow_module_level=True)
POLARS_MISSING_MSG = "Polars is not available"
pytest.skip(reason=POLARS_MISSING_MSG, allow_module_level=True)


def test_join():
Expand Down
6 changes: 3 additions & 3 deletions skrub/tests/test_agg_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@
)


assert_tuples = [(main, pd, assert_frame_equal)]
ASSERT_TUPLES = [(main, pd, assert_frame_equal)]
if POLARS_SETUP:
assert_tuples.append((pl.DataFrame(main), pl, assert_frame_equal_pl))
ASSERT_TUPLES.append((pl.DataFrame(main), pl, assert_frame_equal_pl))


@pytest.mark.parametrize("use_X_placeholder", [False, True])
@pytest.mark.parametrize(
"X, px, assert_frame_equal_",
assert_tuples,
ASSERT_TUPLES,
)
def test_simple_fit_transform(use_X_placeholder, X, px, assert_frame_equal_):
aux = X if not use_X_placeholder else "X"
Expand Down
77 changes: 62 additions & 15 deletions skrub/tests/test_datetime_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,25 @@
from pandas.api.types import is_datetime64_any_dtype
from pandas.testing import assert_frame_equal

from skrub._dataframe._polars import POLARS_SETUP
from skrub._dataframe._test_utils import is_module_polars
from skrub._datetime_encoder import (
TIME_LEVELS,
DatetimeEncoder,
_is_pandas_format_mixed_available,
to_datetime,
)

MODULES = [pd]
ASSERT_TUPLES = [(pd, assert_frame_equal)]

if POLARS_SETUP:
import polars as pl
from polars.testing import assert_frame_equal as assert_frame_equal_pl

MODULES.append(pl)
ASSERT_TUPLES.append((pl, assert_frame_equal_pl))

NANOSECONDS_FORMAT = (
"%Y-%m-%d %H:%M:%S.%f" if _is_pandas_format_mixed_available() else None
)
Expand Down Expand Up @@ -120,6 +132,7 @@ def get_mixed_datetime_format(as_array=False):
return df


@pytest.mark.parametrize("px", MODULES)
@pytest.mark.parametrize("as_array", [True, False])
@pytest.mark.parametrize(
"get_data_func, features, format",
Expand All @@ -136,6 +149,7 @@ def get_mixed_datetime_format(as_array=False):
)
@pytest.mark.parametrize("resolution", TIME_LEVELS)
def test_fit(
px,
as_array,
get_data_func,
features,
Expand Down Expand Up @@ -176,6 +190,7 @@ def test_fit(
assert enc.get_feature_names_out() == expected_feature_names


@pytest.mark.parametrize("px", MODULES)
@pytest.mark.parametrize(
"get_data_func, expected_datetime_columns",
[
Expand All @@ -185,16 +200,20 @@ def test_fit(
(get_mixed_type_dataframe, ["a", "e"]),
],
)
def test_to_datetime(get_data_func, expected_datetime_columns):
def test_to_datetime(px, get_data_func, expected_datetime_columns):
if is_module_polars(px):
pytest.xfail(reason="AssertionError is raised when using Polars.")
X = get_data_func()
X = to_datetime(X)
X = pd.DataFrame(X)
X = px.DataFrame(X)
datetime_columns = [col for col in X.columns if is_datetime64_any_dtype(X[col])]
assert_array_equal(datetime_columns, expected_datetime_columns)


def test_format_nan():
@pytest.mark.parametrize("px", MODULES)
def test_format_nan(px):
X = get_nan_datetime()
X = px.DataFrame(X)
enc = DatetimeEncoder().fit(X)
expected_index_to_format = {
0: "%Y-%m-%d %H:%M:%S",
Expand All @@ -204,14 +223,18 @@ def test_format_nan():
assert enc.index_to_format_ == expected_index_to_format


def test_format_nz():
@pytest.mark.parametrize("px", MODULES)
def test_format_nz(px):
X = get_tz_datetime()
X = px.DataFrame(X)
enc = DatetimeEncoder().fit(X)
assert enc.index_to_format_ == {0: "%Y-%m-%d %H:%M:%S%z"}


def test_resolution_none():
@pytest.mark.parametrize("px", MODULES)
def test_resolution_none(px):
X = get_datetime()
px.DataFrame(X)
enc = DatetimeEncoder(
resolution=None,
add_total_seconds=False,
Expand All @@ -223,8 +246,10 @@ def test_resolution_none():
assert enc.get_feature_names_out() == []


def test_transform_date():
@pytest.mark.parametrize("px", MODULES)
def test_transform_date(px):
X = get_date()
X = px.DataFrame(X)
enc = DatetimeEncoder(
add_total_seconds=False,
)
Expand All @@ -242,8 +267,10 @@ def test_transform_date():
assert_array_equal(X_trans, expected_result)


def test_transform_datetime():
@pytest.mark.parametrize("px", MODULES)
def test_transform_datetime(px):
X = get_datetime()
X = px.DataFrame(X)
enc = DatetimeEncoder(
resolution="second",
add_total_seconds=False,
Expand All @@ -260,8 +287,10 @@ def test_transform_datetime():
assert_array_equal(X_trans, expected_X_trans)


def test_transform_tz():
@pytest.mark.parametrize("px", MODULES)
def test_transform_tz(px):
X = get_tz_datetime()
X = px.DataFrame(X)
enc = DatetimeEncoder(
add_total_seconds=True,
)
Expand All @@ -277,8 +306,10 @@ def test_transform_tz():
assert_allclose(X_trans, expected_X_trans)


def test_transform_nan():
@pytest.mark.parametrize("px", MODULES)
def test_transform_nan(px):
X = get_nan_datetime()
X = px.DataFrame(X)
enc = DatetimeEncoder(
add_total_seconds=True,
)
Expand Down Expand Up @@ -341,8 +372,17 @@ def test_transform_nan():
assert_allclose(X_trans, expected_X_trans)


def test_mixed_type_dataframe():
@pytest.mark.parametrize("px", MODULES)
def test_mixed_type_dataframe(px):
if is_module_polars(px):
pytest.xfail(
reason=(
"to_datetime(X) raises polars.exceptions.ComputeError: cannot cast"
" 'Object' type"
)
)
X = get_mixed_type_dataframe()
X = px.DataFrame(X)
enc = DatetimeEncoder().fit(X)
assert enc.index_to_format_ == {0: "%Y-%m-%d", 4: "%d/%m/%Y"}

Expand All @@ -361,19 +401,23 @@ def test_mixed_type_dataframe():
assert X_dt.dtype == np.object_


def test_indempotency():
@pytest.mark.parametrize("px, assert_frame_equal_", ASSERT_TUPLES)
def test_indempotency(px, assert_frame_equal_):
df = get_mixed_datetime_format()
df = px.DataFrame(df)
df_dt = to_datetime(df)
df_dt_2 = to_datetime(df_dt)
assert_frame_equal(df_dt, df_dt_2)
assert_frame_equal_(df_dt, df_dt_2)

X_trans = DatetimeEncoder().fit_transform(df)
X_trans_2 = DatetimeEncoder().fit_transform(df_dt)
assert_array_equal(X_trans, X_trans_2)


def test_datetime_encoder_invalid_params():
@pytest.mark.parametrize("px", MODULES)
def test_datetime_encoder_invalid_params(px):
X = get_datetime()
X = px.DataFrame(X)

with pytest.raises(ValueError, match=r"(?=.*'resolution' options)"):
DatetimeEncoder(resolution="hello").fit(X)
Expand Down Expand Up @@ -437,8 +481,10 @@ def test_to_datetime_format_param():
assert_array_equal(out, expected_out)


def test_mixed_datetime_format():
@pytest.mark.parametrize("px, assert_frame_equal_", ASSERT_TUPLES)
def test_mixed_datetime_format(px, assert_frame_equal_):
df = get_mixed_datetime_format()
df = px.DataFrame(df)

df_dt = to_datetime(df)
expected_df_dt = pd.DataFrame(
Expand All @@ -451,7 +497,8 @@ def test_mixed_datetime_format():
]
)
)
assert_frame_equal(df_dt, expected_df_dt)
expected_df_dt = px.DataFrame(expected_df_dt)
assert_frame_equal_(df_dt, expected_df_dt)

series_dt = to_datetime(df["a"])
expected_series_dt = expected_df_dt["a"]
Expand Down
14 changes: 7 additions & 7 deletions skrub/tests/test_deduplicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@


@pytest.mark.parametrize(
["entries_per_category", "prob_mistake_per_letter"],
[[[500, 100, 1500], 0.05], [[100, 100], 0.02], [[200, 50, 30, 200, 800], 0.01]],
"entries_per_category, prob_mistake_per_letter",
[([500, 100, 1500], 0.05), ([100, 100], 0.02), ([200, 50, 30, 200, 800], 0.01)],
)
def test_deduplicate(
entries_per_category: list[int],
prob_mistake_per_letter: float,
seed: int = 123,
) -> None:
):
rng = np.random.RandomState(seed)

# hard coded to fix ground truth string similarities
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_deduplicate(
assert np.isin(unique_other_analyzer, recovered_categories).all()


def test_compute_ngram_distance() -> None:
def test_compute_ngram_distance():
words = np.array(["aac", "aaa", "aaab", "aaa", "aaab", "aaa", "aaab", "aaa"])
distance = compute_ngram_distance(words)
distance = squareform(distance)
Expand All @@ -70,15 +70,15 @@ def test_compute_ngram_distance() -> None:
assert np.allclose(distance[words == un_word][:, words == un_word], 0)


def test__guess_clusters() -> None:
def test__guess_clusters():
words = np.array(["aac", "aaa", "aaab", "aaa", "aaab", "aaa", "aaab", "aaa"])
distance = compute_ngram_distance(words)
Z = linkage(distance, method="average")
n_clusters = _guess_clusters(Z, distance)
assert n_clusters == len(np.unique(words))


def test__create_spelling_correction(seed: int = 123) -> None:
def test__create_spelling_correction(seed: int = 123):
rng = np.random.RandomState(seed)
n_clusters = 3
samples_per_cluster = 10
Expand Down Expand Up @@ -116,7 +116,7 @@ def default_deduplicate(n: int = 500, random_state=0):
return X, y


def test_parallelism() -> None:
def test_parallelism():
"""Tests that parallelism works with different backends and n_jobs."""

X, y = default_deduplicate(n=200)
Expand Down
Loading

0 comments on commit b787db9

Please sign in to comment.