Skip to content

Commit

Permalink
fix: Incorrect aggregation of empty groups after slice (#20127)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Dec 3, 2024
1 parent d86e44b commit c5a8efa
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
3 changes: 2 additions & 1 deletion crates/polars-core/src/frame/column/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,8 @@ impl Column {
let mut s = scalar_col.take_materialized_series().rechunk();
// SAFETY: We perform a compute_len afterwards.
let chunks = unsafe { s.chunks_mut() };
chunks[0].with_validity(Some(validity));
let arr = &mut chunks[0];
*arr = arr.with_validity(Some(validity));
s.compute_len();

s.into_column()
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-expr/src/expressions/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ impl PhysicalExpr for SliceExpr {
.collect::<PolarsResult<Vec<_>>>()
})?;
let mut ac = results.pop().unwrap();

if let AggState::AggregatedScalar(_) = ac.agg_state() {
polars_bail!(InvalidOperation: "cannot slice() an aggregated scalar value")
}

let mut ac_length = results.pop().unwrap();
let mut ac_offset = results.pop().unwrap();

Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/operations/aggregation/test_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,3 +742,21 @@ def test_sort_by_over_multiple_nulls_last() -> None:
}
)
assert_frame_equal(out, expected)


def test_slice_after_agg_raises() -> None:
with pytest.raises(
InvalidOperationError, match=r"cannot slice\(\) an aggregated scalar value"
):
pl.select(a=1, b=1).group_by("a").agg(pl.col("b").first().slice(99, 0))


def test_agg_scalar_empty_groups_20115() -> None:
assert_frame_equal(
(
pl.DataFrame({"key": [123], "value": [456]})
.group_by("key")
.agg(pl.col("value").slice(1, 1).first())
),
pl.select(key=pl.lit(123, pl.Int64), value=pl.lit(None, pl.Int64)),
)

0 comments on commit c5a8efa

Please sign in to comment.