Skip to content

Commit

Permalink
provides workaround for half-migrated UDAF sum
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-J-Ward committed Jun 25, 2024
1 parent faa26b2 commit 50cac4e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
1 change: 0 additions & 1 deletion examples/tpch/_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def check_q17(df):
("q10_returned_item_reporting", "q10"),
pytest.param(
"q11_important_stock_identification", "q11",
marks=pytest.mark.xfail # https://github.com/apache/datafusion-python/issues/730
),
("q12_ship_mode_order_priority", "q12"),
("q13_customer_distribution", "q13"),
Expand Down
20 changes: 13 additions & 7 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,20 @@ fn window(
window_frame: Option<PyWindowFrame>,
ctx: Option<PySessionContext>,
) -> PyResult<PyExpr> {
let fun = find_df_window_func(name).or_else(|| {
ctx.and_then(|ctx| {
ctx.ctx
.udaf(name)
.map(WindowFunctionDefinition::AggregateUDF)
.ok()
// workaround for https://github.com/apache/datafusion-python/issues/730
let fun = if name == "sum" {
let sum_udf = functions_aggregate::sum::sum_udaf();
Some(WindowFunctionDefinition::AggregateUDF(sum_udf))
} else {
find_df_window_func(name).or_else(|| {
ctx.and_then(|ctx| {
ctx.ctx
.udaf(name)
.map(WindowFunctionDefinition::AggregateUDF)
.ok()
})
})
});
};
if fun.is_none() {
return Err(DataFusionError::Common("window function not found".to_string()).into());
}
Expand Down

0 comments on commit 50cac4e

Please sign in to comment.