Skip to content

Commit

Permalink
UDAF sum workaround (#741)
Browse files Browse the repository at this point in the history
* provides workaround for half-migrated UDAF `sum`

Ref #730

* provide compatibility for sqlparser::ast::NullTreatment

This is now exposed as part of the API to `first_value` and `last_value` functions.

If there's a more elegant way to achieve this, please let me know.
  • Loading branch information
Michael-J-Ward authored Jun 26, 2024
1 parent 32d6975 commit ec835ab
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 14 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ parking_lot = "0.12"
regex-syntax = "0.8.1"
syn = "2.0.67"
url = "2.2"
sqlparser = "0.47.0"

[build-dependencies]
pyo3-build-config = "0.21"
Expand Down
1 change: 0 additions & 1 deletion examples/tpch/_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def check_q17(df):
("q10_returned_item_reporting", "q10"),
pytest.param(
"q11_important_stock_identification", "q11",
marks=pytest.mark.xfail # https://github.com/apache/datafusion-python/issues/730
),
("q12_ship_mode_order_priority", "q12"),
("q13_customer_distribution", "q13"),
Expand Down
1 change: 1 addition & 0 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<data_type::DataTypeMap>()?;
m.add_class::<data_type::PythonType>()?;
m.add_class::<data_type::SqlType>()?;
m.add_class::<data_type::NullTreatment>()?;
m.add_class::<schema::SqlTable>()?;
m.add_class::<schema::SqlSchema>()?;
m.add_class::<schema::SqlView>()?;
Expand Down
30 changes: 30 additions & 0 deletions src/common/data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -757,3 +757,33 @@ pub enum SqlType {
VARBINARY,
VARCHAR,
}

/// Specifies whether NULL values are ignored or respected within window functions.
///
/// For example:
/// `FIRST_VALUE(column2) IGNORE NULLS OVER (PARTITION BY column1)`
///
/// This enum is exposed to Python as `datafusion.common.NullTreatment` and
/// converts to and from `sqlparser::ast::NullTreatment` via `From`.
// The variants deliberately use SCREAMING_SNAKE_CASE, so the naming lints are
// silenced for this type.
#[allow(non_camel_case_types)]
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
// The Python-visible class name must be unique within the module: it was
// previously "PythonType", which belongs to a different class.
#[pyclass(name = "NullTreatment", module = "datafusion.common")]
pub enum NullTreatment {
    /// Corresponds to `sqlparser::ast::NullTreatment::IgnoreNulls`.
    IGNORE_NULLS,
    /// Corresponds to `sqlparser::ast::NullTreatment::RespectNulls`.
    RESPECT_NULLS,
}

/// Maps the Python-facing [`NullTreatment`] onto its `sqlparser` AST counterpart.
impl From<NullTreatment> for sqlparser::ast::NullTreatment {
    fn from(value: NullTreatment) -> Self {
        // Local alias keeps the two identically named enums apart in the match arms.
        use sqlparser::ast::NullTreatment as SqlNullTreatment;

        match value {
            NullTreatment::IGNORE_NULLS => SqlNullTreatment::IgnoreNulls,
            NullTreatment::RESPECT_NULLS => SqlNullTreatment::RespectNulls,
        }
    }
}

/// Maps a `sqlparser` AST null-treatment value back onto the Python-facing
/// [`NullTreatment`].
impl From<sqlparser::ast::NullTreatment> for NullTreatment {
    fn from(value: sqlparser::ast::NullTreatment) -> Self {
        // Local alias keeps the two identically named enums apart in the match arms.
        use sqlparser::ast::NullTreatment as SqlNullTreatment;

        match value {
            SqlNullTreatment::IgnoreNulls => Self::IGNORE_NULLS,
            SqlNullTreatment::RespectNulls => Self::RESPECT_NULLS,
        }
    }
}
33 changes: 20 additions & 13 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use pyo3::{prelude::*, wrap_pyfunction};

use crate::common::data_type::NullTreatment;
use crate::context::PySessionContext;
use crate::errors::DataFusionError;
use crate::expr::conditional_expr::PyCaseBuilder;
Expand Down Expand Up @@ -73,15 +74,15 @@ pub fn var(y: PyExpr) -> PyExpr {
}

#[pyfunction]
#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None))]
#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None, null_treatment = None))]
pub fn first_value(
args: Vec<PyExpr>,
distinct: bool,
filter: Option<PyExpr>,
order_by: Option<Vec<PyExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyExpr {
// TODO: allow user to select null_treatment
let null_treatment = None;
let null_treatment = null_treatment.map(Into::into);
let args = args.into_iter().map(|x| x.expr).collect::<Vec<_>>();
let order_by = order_by.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>());
functions_aggregate::expr_fn::first_value(
Expand All @@ -95,15 +96,15 @@ pub fn first_value(
}

#[pyfunction]
#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None))]
#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None, null_treatment = None))]
pub fn last_value(
args: Vec<PyExpr>,
distinct: bool,
filter: Option<PyExpr>,
order_by: Option<Vec<PyExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyExpr {
// TODO: allow user to select null_treatment
let null_treatment = None;
let null_treatment = null_treatment.map(Into::into);
let args = args.into_iter().map(|x| x.expr).collect::<Vec<_>>();
let order_by = order_by.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>());
functions_aggregate::expr_fn::last_value(
Expand Down Expand Up @@ -320,14 +321,20 @@ fn window(
window_frame: Option<PyWindowFrame>,
ctx: Option<PySessionContext>,
) -> PyResult<PyExpr> {
let fun = find_df_window_func(name).or_else(|| {
ctx.and_then(|ctx| {
ctx.ctx
.udaf(name)
.map(WindowFunctionDefinition::AggregateUDF)
.ok()
// workaround for https://github.com/apache/datafusion-python/issues/730
let fun = if name == "sum" {
let sum_udf = functions_aggregate::sum::sum_udaf();
Some(WindowFunctionDefinition::AggregateUDF(sum_udf))
} else {
find_df_window_func(name).or_else(|| {
ctx.and_then(|ctx| {
ctx.ctx
.udaf(name)
.map(WindowFunctionDefinition::AggregateUDF)
.ok()
})
})
});
};
if fun.is_none() {
return Err(DataFusionError::Common("window function not found".to_string()).into());
}
Expand Down

0 comments on commit ec835ab

Please sign in to comment.