From 813420db5092a7ad4a4ce338b18cf56b1277fbba Mon Sep 17 00:00:00 2001
From: Jacob Baldwin <51560848+baldwinj30@users.noreply.github.com>
Date: Fri, 3 Jan 2025 16:11:35 -0500
Subject: [PATCH] bugfix/1835: Keep nulls in polars when dropping invalid rows
 and nullable=True (#1890)

* don't drop null values when dropping invalid rows in polars and nullable=True

Signed-off-by: Jacob Baldwin

* move the fix for keeping nulls in polars when nullable=True to the check backend

Signed-off-by: Jacob Baldwin

---------

Signed-off-by: Jacob Baldwin
---
 pandera/backends/polars/checks.py     |  4 +++
 tests/polars/test_polars_check.py     | 21 +++++++++---
 tests/polars/test_polars_container.py | 47 ++++++++++++++++++++++++---
 3 files changed, 64 insertions(+), 8 deletions(-)

diff --git a/pandera/backends/polars/checks.py b/pandera/backends/polars/checks.py
index 3a02d97af..1993bea02 100644
--- a/pandera/backends/polars/checks.py
+++ b/pandera/backends/polars/checks.py
@@ -94,6 +94,10 @@ def postprocess_lazyframe_output(
     ) -> CheckResult:
         """Postprocesses the result of applying the check function."""
         results = pl.LazyFrame(check_output.collect())
+        if self.check.ignore_na:
+            results = results.with_columns(
+                pl.col(CHECK_OUTPUT_KEY) | pl.col(CHECK_OUTPUT_KEY).is_null()
+            )
         passed = results.select([pl.col(CHECK_OUTPUT_KEY).all()])
         failure_cases = pl.concat(
             [check_obj.lazyframe, results], how="horizontal"
diff --git a/tests/polars/test_polars_check.py b/tests/polars/test_polars_check.py
index a3929ef4a..665c5f3dd 100644
--- a/tests/polars/test_polars_check.py
+++ b/tests/polars/test_polars_check.py
@@ -32,10 +32,22 @@ def _column_check_fn_scalar_out(data: pa.PolarsData) -> pl.LazyFrame:
 
 
 @pytest.mark.parametrize(
-    "check_fn, invalid_data, expected_output",
+    "check_fn, invalid_data, expected_output, ignore_na",
     [
-        [_column_check_fn_df_out, [-1, 2, 3, -2], [False, True, True, False]],
-        [_column_check_fn_scalar_out, [-1, 2, 3, -2], [False]],
+        [
+            _column_check_fn_df_out,
+            [-1, 2, 3, -2],
+            [False, True, True, False],
+            False,
+        ],
+        [_column_check_fn_scalar_out, [-1, 2, 3, -2], [False], False],
+        [
+            _column_check_fn_df_out,
+            [-1, 2, 3, None],
+            [False, True, True, True],
+            True,
+        ],
+        [_column_check_fn_scalar_out, [-1, 2, 3, None], [False], True],
     ],
 )
 def test_polars_column_check(
@@ -43,8 +55,9 @@ def test_polars_column_check(
     check_fn,
     invalid_data,
     expected_output,
+    ignore_na,
 ):
-    check = pa.Check(check_fn)
+    check = pa.Check(check_fn, ignore_na=ignore_na)
     check_result = check(column_lf, column="col")
 
     assert check_result.check_passed.collect().item()
diff --git a/tests/polars/test_polars_container.py b/tests/polars/test_polars_container.py
index c44a9448a..edcaf10d7 100644
--- a/tests/polars/test_polars_container.py
+++ b/tests/polars/test_polars_container.py
@@ -313,16 +313,20 @@ def custom_check(data: PolarsData):
     "column_mod,filter_expr",
     [
         ({"int_col": pl.Series([-1, 1, 1])}, pl.col("int_col").ge(0)),
-        ({"string_col": pl.Series([*"013"])}, pl.col("string_col").ne("d")),
+        ({"string_col": pl.Series([*"013"])}, pl.col("string_col").ne("3")),
         (
             {
                 "int_col": pl.Series([-1, 1, 1]),
                 "string_col": pl.Series([*"013"]),
             },
-            pl.col("int_col").ge(0) & pl.col("string_col").ne("d"),
+            pl.col("int_col").ge(0) & pl.col("string_col").ne("3"),
         ),
         ({"int_col": pl.lit(-1)}, pl.col("int_col").ge(0)),
-        ({"int_col": pl.lit("d")}, pl.col("string_col").ne("d")),
+        ({"string_col": pl.lit("d")}, pl.col("string_col").ne("d")),
+        (
+            {"int_col": pl.Series([None, 1, 1])},
+            pl.col("int_col").is_not_null(),
+        ),
     ],
 )
 @pytest.mark.parametrize("lazy", [False, True])
@@ -334,7 +338,7 @@ def test_drop_invalid_rows(
     ldf_schema_with_check,
 ):
     ldf_schema_with_check.drop_invalid_rows = True
-    modified_data = ldf_basic.with_columns(column_mod)
+    modified_data = ldf_basic.with_columns(**column_mod)
     if lazy:
         validated_data = modified_data.pipe(
             ldf_schema_with_check.validate,
@@ -350,6 +354,41 @@ def test_drop_invalid_rows(
     )
 
 
+@pytest.mark.parametrize(
+    "schema_column_updates,column_mod,filter_expr",
+    [
+        (
+            {"int_col": {"nullable": True}},
+            {"int_col": pl.Series([None, -1, 1])},
+            pl.col("int_col").is_null() | (pl.col("int_col") > 0),
+        ),
+        (
+            {"int_col": {"nullable": True, "checks": [pa.Check.isin([1])]}},
+            {"int_col": pl.Series([None, 1, 2])},
+            pl.col("int_col").is_null() | (pl.col("int_col") == 1),
+        ),
+    ],
+)
+def test_drop_invalid_rows_nullable(
+    schema_column_updates,
+    column_mod,
+    filter_expr,
+    ldf_basic,
+    ldf_schema_with_check,
+):
+    ldf_schema_with_check.drop_invalid_rows = True
+    nullable_schema = ldf_schema_with_check.update_columns(
+        schema_column_updates
+    )
+    modified_data = ldf_basic.with_columns(**column_mod)
+    validated_data = modified_data.pipe(
+        nullable_schema.validate,
+        lazy=True,
+    )
+    expected_valid_data = modified_data.filter(filter_expr)
+    assert validated_data.collect().equals(expected_valid_data.collect())
+
+
 def test_set_defaults(ldf_basic, ldf_schema_basic):
     ldf_schema_basic.columns["int_col"].default = 1
     ldf_schema_basic.columns["string_col"].default = "a"
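
The following is a minimal usage sketch, not part of the patch, illustrating the behavior this fix enables: with drop_invalid_rows=True on a schema whose column is nullable=True, null values survive validation instead of being dropped alongside genuinely invalid rows. It assumes a pandera release that includes this change; the schema, frame, and variable names are illustrative only.

# Minimal sketch (assumes a pandera version containing this fix).
# Names below (schema, lf, validated) are illustrative, not from the patch.
import polars as pl
import pandera.polars as pa

schema = pa.DataFrameSchema(
    {"int_col": pa.Column(pl.Int64, pa.Check.ge(0), nullable=True)},
    drop_invalid_rows=True,
)

lf = pl.LazyFrame({"int_col": [None, -1, 1]})

# Validate with lazy=True, as the PR's tests do. The -1 row fails
# Check.ge(0) and is dropped; the null row is kept because the column is
# nullable and the check ignores nulls (ignore_na defaults to True).
validated = schema.validate(lf, lazy=True).collect()
print(validated)  # expect two rows: null and 1

Under the hood, the backend change above ORs the boolean check output with is_null() when ignore_na is set, so null check results count as passing rather than marking the row for removal.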