From 8c39bb9f4fd1dd0baba596a60fbdaa7c109432fb Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Tue, 10 Dec 2024 20:23:33 +1100 Subject: [PATCH] fix: Incorrect comparison in some cases with filtered list/array columns (#20243) --- .../src/chunked_array/comparison/mod.rs | 20 +++++++++++++++---- .../tests/unit/operations/test_comparison.py | 15 +++++++++++++- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index e9b9efd2cae0..5d5845133986 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -652,7 +652,10 @@ where { match (lhs.len(), rhs.len()) { (_, 1) => { - let right = rhs.chunks()[0] + let right = rhs + .downcast_iter() + .find(|x| !x.is_empty()) + .unwrap() .as_any() .downcast_ref::<ListArray<i64>>() .unwrap(); @@ -681,7 +684,10 @@ where } }, (1, _) => { - let left = lhs.chunks()[0] + let left = lhs + .downcast_iter() + .find(|x| !x.is_empty()) + .unwrap() .as_any() .downcast_ref::<ListArray<i64>>() .unwrap(); @@ -898,7 +904,10 @@ where { match (lhs.len(), rhs.len()) { (_, 1) => { - let right = rhs.chunks()[0] + let right = rhs + .downcast_iter() + .find(|x| !x.is_empty()) + .unwrap() .as_any() .downcast_ref::<FixedSizeListArray>() .unwrap(); @@ -922,7 +931,10 @@ where } }, (1, _) => { - let left = lhs.chunks()[0] + let left = lhs + .downcast_iter() + .find(|x| !x.is_empty()) + .unwrap() .as_any() .downcast_ref::<FixedSizeListArray>() .unwrap(); diff --git a/py-polars/tests/unit/operations/test_comparison.py b/py-polars/tests/unit/operations/test_comparison.py index 523f2050ca36..a58f718e8ee9 100644 --- a/py-polars/tests/unit/operations/test_comparison.py +++ b/py-polars/tests/unit/operations/test_comparison.py @@ -8,7 +8,7 @@ import polars as pl from polars.exceptions import ComputeError -from polars.testing import assert_frame_equal +from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: from
contextlib import AbstractContextManager as ContextManager @@ -444,3 +444,16 @@ def test_struct_broadcasting_comparison() -> None: assert df.select(eq=pl.col.foo == pl.col.foo.last()).to_dict(as_series=False) == { "eq": [True, False, True] } + + +@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 1)]) +def test_compare_list_broadcast_empty_first_chunk_20165(dtype: pl.DataType) -> None: + s = pl.concat(2 * [pl.Series([[1]], dtype=dtype)]).filter([False, True]) + + assert s.len() == 1 + assert s.n_chunks() == 2 + + assert_series_equal( + pl.select(pl.lit(pl.Series([[1], [2]]), dtype=dtype) == pl.lit(s)).to_series(), + pl.Series([True, False]), + )