Skip to content

Commit

Permalink
bugfix/1835: Keep nulls in polars when dropping invalid rows and nullable=True (#1890)
Browse files Browse the repository at this point in the history

* don't drop null values when dropping invalid rows in polars and nullable=True

Signed-off-by: Jacob Baldwin <[email protected]>

* move the fix for keeping nulls in polars when nullable=True to the check backend

Signed-off-by: Jacob Baldwin <[email protected]>

---------

Signed-off-by: Jacob Baldwin <[email protected]>
  • Loading branch information
baldwinj30 authored Jan 3, 2025
1 parent 1b8e925 commit 813420d
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 8 deletions.
4 changes: 4 additions & 0 deletions pandera/backends/polars/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def postprocess_lazyframe_output(
) -> CheckResult:
"""Postprocesses the result of applying the check function."""
results = pl.LazyFrame(check_output.collect())
if self.check.ignore_na:
results = results.with_columns(
pl.col(CHECK_OUTPUT_KEY) | pl.col(CHECK_OUTPUT_KEY).is_null()
)
passed = results.select([pl.col(CHECK_OUTPUT_KEY).all()])
failure_cases = pl.concat(
[check_obj.lazyframe, results], how="horizontal"
Expand Down
21 changes: 17 additions & 4 deletions tests/polars/test_polars_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,32 @@ def _column_check_fn_scalar_out(data: pa.PolarsData) -> pl.LazyFrame:


@pytest.mark.parametrize(
"check_fn, invalid_data, expected_output",
"check_fn, invalid_data, expected_output, ignore_na",
[
[_column_check_fn_df_out, [-1, 2, 3, -2], [False, True, True, False]],
[_column_check_fn_scalar_out, [-1, 2, 3, -2], [False]],
[
_column_check_fn_df_out,
[-1, 2, 3, -2],
[False, True, True, False],
False,
],
[_column_check_fn_scalar_out, [-1, 2, 3, -2], [False], False],
[
_column_check_fn_df_out,
[-1, 2, 3, None],
[False, True, True, True],
True,
],
[_column_check_fn_scalar_out, [-1, 2, 3, None], [False], True],
],
)
def test_polars_column_check(
column_lf,
check_fn,
invalid_data,
expected_output,
ignore_na,
):
check = pa.Check(check_fn)
check = pa.Check(check_fn, ignore_na=ignore_na)
check_result = check(column_lf, column="col")
assert check_result.check_passed.collect().item()

Expand Down
47 changes: 43 additions & 4 deletions tests/polars/test_polars_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,16 +313,20 @@ def custom_check(data: PolarsData):
"column_mod,filter_expr",
[
({"int_col": pl.Series([-1, 1, 1])}, pl.col("int_col").ge(0)),
({"string_col": pl.Series([*"013"])}, pl.col("string_col").ne("d")),
({"string_col": pl.Series([*"013"])}, pl.col("string_col").ne("3")),
(
{
"int_col": pl.Series([-1, 1, 1]),
"string_col": pl.Series([*"013"]),
},
pl.col("int_col").ge(0) & pl.col("string_col").ne("d"),
pl.col("int_col").ge(0) & pl.col("string_col").ne("3"),
),
({"int_col": pl.lit(-1)}, pl.col("int_col").ge(0)),
({"int_col": pl.lit("d")}, pl.col("string_col").ne("d")),
({"string_col": pl.lit("d")}, pl.col("string_col").ne("d")),
(
{"int_col": pl.Series([None, 1, 1])},
pl.col("int_col").is_not_null(),
),
],
)
@pytest.mark.parametrize("lazy", [False, True])
Expand All @@ -334,7 +338,7 @@ def test_drop_invalid_rows(
ldf_schema_with_check,
):
ldf_schema_with_check.drop_invalid_rows = True
modified_data = ldf_basic.with_columns(column_mod)
modified_data = ldf_basic.with_columns(**column_mod)
if lazy:
validated_data = modified_data.pipe(
ldf_schema_with_check.validate,
Expand All @@ -350,6 +354,41 @@ def test_drop_invalid_rows(
)


@pytest.mark.parametrize(
    "schema_column_updates,column_mod,filter_expr",
    [
        (
            {"int_col": {"nullable": True}},
            {"int_col": pl.Series([None, -1, 1])},
            pl.col("int_col").is_null() | (pl.col("int_col") > 0),
        ),
        (
            {"int_col": {"nullable": True, "checks": [pa.Check.isin([1])]}},
            {"int_col": pl.Series([None, 1, 2])},
            pl.col("int_col").is_null() | (pl.col("int_col") == 1),
        ),
    ],
)
def test_drop_invalid_rows_nullable(
    schema_column_updates,
    column_mod,
    filter_expr,
    ldf_basic,
    ldf_schema_with_check,
):
    """With drop_invalid_rows=True and nullable columns, validation drops
    only rows that fail checks — nulls in nullable columns are kept.

    ``filter_expr`` encodes the expected keep-condition (null OR passing),
    so the validated output must equal a plain filter of the input.
    """
    # Turn on row-dropping, then relax the column(s) to allow nulls
    # (and optionally attach extra checks) via update_columns.
    ldf_schema_with_check.drop_invalid_rows = True
    nullable_schema = ldf_schema_with_check.update_columns(
        schema_column_updates
    )

    # Inject null/invalid values into the base lazyframe.
    data_with_nulls = ldf_basic.with_columns(**column_mod)

    # Direct call is equivalent to modified_data.pipe(validate, lazy=True).
    validated = nullable_schema.validate(data_with_nulls, lazy=True)
    expected = data_with_nulls.filter(filter_expr)
    assert validated.collect().equals(expected.collect())


def test_set_defaults(ldf_basic, ldf_schema_basic):
ldf_schema_basic.columns["int_col"].default = 1
ldf_schema_basic.columns["string_col"].default = "a"
Expand Down

0 comments on commit 813420d

Please sign in to comment.