Skip to content

Commit

Permalink
remove unneeded sorts
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 21, 2024
1 parent 0fc4ca0 commit 65d7deb
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 16 deletions.
13 changes: 7 additions & 6 deletions crates/polars-plan/src/plans/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ pub fn optimize(
members.collect(lp_top, lp_arena, expr_arena)
}

// Run before slice pushdown
if opt_state.contains(OptFlags::CHECK_ORDER_OBSERVE)
&& members.has_group_by | members.has_sort | members.has_distinct
{
set_order_flags(lp_top, lp_arena, expr_arena, scratch);
}

if simplify_expr {
#[cfg(feature = "fused")]
rules.push(Box::new(fused::FusedArithmetic {}));
Expand Down Expand Up @@ -210,12 +217,6 @@ pub fn optimize(
cache_states::set_cache_states(lp_top, lp_arena, expr_arena, scratch, expr_eval, verbose)?;
}

if opt_state.contains(OptFlags::CHECK_ORDER_OBSERVE)
&& members.has_group_by | members.has_sort | members.has_distinct
{
set_order_flags(lp_top, lp_arena, expr_arena, scratch);
}

// This one should run (nearly) last as this modifies the projections
#[cfg(feature = "cse")]
if comm_subexpr_elim && !members.has_ext_context {
Expand Down
36 changes: 27 additions & 9 deletions crates/polars-plan/src/plans/optimizer/set_order.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn is_order_dependent_top_level(ae: &AExpr, ctx: Context) -> bool {
IRAggExpr::First(_) => true,
IRAggExpr::Last(_) => true,
IRAggExpr::Mean(_) => false,
IRAggExpr::Implode(_) => false,
IRAggExpr::Implode(_) => true,
IRAggExpr::Quantile { .. } => false,
IRAggExpr::Sum(_) => false,
IRAggExpr::Count(_, _) => false,
Expand All @@ -31,8 +31,8 @@ fn is_order_dependent<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena<AExpr>, ctx:
let mut stack = unitvec![];

loop {
if !is_order_dependent_top_level(ae, ctx) {
return false;
if is_order_dependent_top_level(ae, ctx) {
return true;
}

let Some(node) = stack.pop() else {
Expand All @@ -42,7 +42,7 @@ fn is_order_dependent<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena<AExpr>, ctx:
ae = expr_arena.get(node);
}

true
false
}

// Can give false negatives.
Expand All @@ -54,11 +54,12 @@ pub(crate) fn all_order_independent<'a, N>(
where
Node: From<&'a N>,
{
nodes
!nodes
.iter()
.all(|n| !is_order_dependent(expr_arena.get(n.into()), expr_arena, ctx))
.any(|n| is_order_dependent(expr_arena.get(n.into()), expr_arena, ctx))
}

// Should run before slice pushdown.
pub(super) fn set_order_flags(
root: Node,
ir_arena: &mut Arena<IR>,
Expand All @@ -75,8 +76,23 @@ pub(super) fn set_order_flags(
ir.copy_inputs(scratch);

match ir {
IR::Sort { .. } => {
maintain_order_above = false;
IR::Sort {
input,
sort_options,
..
} => {
// This sort can be removed
if !maintain_order_above && sort_options.limit.is_none() {
scratch.pop();
scratch.push(node);
let input = *input;
ir_arena.swap(node, input);
continue;
}

if !sort_options.maintain_order {
maintain_order_above = false; // `maintain_order=True` is influenced by result of earlier sorts
}
},
IR::Distinct { options, .. } => {
if !maintain_order_above {
Expand Down Expand Up @@ -112,7 +128,7 @@ pub(super) fn set_order_flags(
continue;
}
if all_elementwise(keys, expr_arena)
&& !all_order_independent(aggs, expr_arena, Context::Aggregation)
&& all_order_independent(aggs, expr_arena, Context::Aggregation)
{
maintain_order_above = false;
continue;
Expand All @@ -127,6 +143,8 @@ pub(super) fn set_order_flags(
maintain_order_above = true;
},
_ => {
// If we don't know, maintain order.
// Known: slice
maintain_order_above = true;
},
}
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/lazyframe/optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,13 @@ def test_fast_count_alias_18581() -> None:
df = pl.scan_csv(f).select(pl.len().alias("weird_name")).collect()

assert_frame_equal(pl.DataFrame({"weird_name": 2}), df)


def test_order_observability() -> None:
    """Check which group-by aggregations keep an upstream sort in the plan.

    A LazyFrame sorted on ``a`` is grouped by ``a`` and aggregated; the plan
    rendered by ``explain(_check_order=True)`` is then searched for ``SORT``.
    """
    sorted_lf = pl.LazyFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).sort("a")

    # For these aggregations the test expects the sort to be gone from the plan.
    for agg_name in ("sum", "min", "max"):
        plan = getattr(sorted_lf.group_by("a"), agg_name)().explain(_check_order=True)
        assert "SORT" not in plan

    # For these aggregations the test expects the sort to still be present.
    for agg_name in ("last", "first"):
        plan = getattr(sorted_lf.group_by("a"), agg_name)().explain(_check_order=True)
        assert "SORT" in plan
2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def test_group_by_sorted_empty_dataframe_3680() -> None:
.sort("key")
.group_by("key")
.tail(1)
.collect()
.collect(_check_order=False)
)
assert df.rows() == []
assert df.shape == (0, 2)
Expand Down

0 comments on commit 65d7deb

Please sign in to comment.