Skip to content

Commit

Permalink
fix: Incorrect comparison in some cases with filtered list/array columns (#20243)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Dec 10, 2024
1 parent 296b4d4 commit 8c39bb9
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
20 changes: 16 additions & 4 deletions crates/polars-core/src/chunked_array/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,10 @@ where
{
match (lhs.len(), rhs.len()) {
(_, 1) => {
let right = rhs.chunks()[0]
let right = rhs
.downcast_iter()
.find(|x| !x.is_empty())
.unwrap()
.as_any()
.downcast_ref::<ListArray<i64>>()
.unwrap();
Expand Down Expand Up @@ -681,7 +684,10 @@ where
}
},
(1, _) => {
let left = lhs.chunks()[0]
let left = lhs
.downcast_iter()
.find(|x| !x.is_empty())
.unwrap()
.as_any()
.downcast_ref::<ListArray<i64>>()
.unwrap();
Expand Down Expand Up @@ -898,7 +904,10 @@ where
{
match (lhs.len(), rhs.len()) {
(_, 1) => {
let right = rhs.chunks()[0]
let right = rhs
.downcast_iter()
.find(|x| !x.is_empty())
.unwrap()
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap();
Expand All @@ -922,7 +931,10 @@ where
}
},
(1, _) => {
let left = lhs.chunks()[0]
let left = lhs
.downcast_iter()
.find(|x| !x.is_empty())
.unwrap()
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap();
Expand Down
15 changes: 14 additions & 1 deletion py-polars/tests/unit/operations/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
from contextlib import AbstractContextManager as ContextManager
Expand Down Expand Up @@ -444,3 +444,16 @@ def test_struct_broadcasting_comparison() -> None:
assert df.select(eq=pl.col.foo == pl.col.foo.last()).to_dict(as_series=False) == {
"eq": [True, False, True]
}


@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 1)])
def test_compare_list_broadcast_empty_first_chunk_20165(dtype: pl.DataType) -> None:
    """Compare a list/array column against a length-1 series spread over two chunks.

    The right-hand operand is built by concatenating two single-row series and
    filtering away the first row. The assertions below pin that it still holds
    one value across two chunks, which is the layout the test name describes
    as an "empty first chunk". The comparison is expected to broadcast that
    single value over both rows of the left-hand side.
    """
    # Concatenate the same one-row series twice, then mask out the first row.
    single_row = pl.Series([[1]], dtype=dtype)
    s = pl.concat([single_row, single_row]).filter([False, True])

    # Guard the precondition: one remaining value, but still two chunks.
    assert s.len() == 1
    assert s.n_chunks() == 2

    lhs = pl.lit(pl.Series([[1], [2]]), dtype=dtype)
    result = pl.select(lhs == pl.lit(s)).to_series()
    expected = pl.Series([True, False])

    assert_series_equal(result, expected)

0 comments on commit 8c39bb9

Please sign in to comment.