From ec835abb4da5e34556a08f72cf30a27ddeadbc95 Mon Sep 17 00:00:00 2001
From: Michael J Ward
Date: Tue, 25 Jun 2024 19:20:39 -0500
Subject: [PATCH] UDAF `sum` workaround (#741)

* provides workaround for half-migrated UDAF `sum`

Ref #730

* provide compatibility for sqlparser::ast::NullTreatment

This is now exposed as part of the API to `first_value` and `last_value`
functions. If there's a more elegant way to achieve this, please let me know.
---
 Cargo.lock              |  1 +
 Cargo.toml              |  1 +
 examples/tpch/_tests.py |  1 -
 src/common.rs           |  1 +
 src/common/data_type.rs | 30 ++++++++++++++++++++++++++++++
 src/functions.rs        | 33 ++++++++++++++++++++-------------
 6 files changed, 53 insertions(+), 14 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 41742da47..f05c62e97 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1051,6 +1051,7 @@ dependencies = [
  "pyo3-build-config",
  "rand",
  "regex-syntax",
+ "sqlparser",
  "syn 2.0.67",
  "tokio",
  "url",
diff --git a/Cargo.toml b/Cargo.toml
index 4e3821127..e518449a8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -56,6 +56,7 @@ parking_lot = "0.12"
 regex-syntax = "0.8.1"
 syn = "2.0.67"
 url = "2.2"
+sqlparser = "0.47.0"
 
 [build-dependencies]
 pyo3-build-config = "0.21"
diff --git a/examples/tpch/_tests.py b/examples/tpch/_tests.py
index aa9491bfd..3f973d9f2 100644
--- a/examples/tpch/_tests.py
+++ b/examples/tpch/_tests.py
@@ -74,7 +74,6 @@ def check_q17(df):
         ("q10_returned_item_reporting", "q10"),
         pytest.param(
             "q11_important_stock_identification",
             "q11",
-            marks=pytest.mark.xfail # https://github.com/apache/datafusion-python/issues/730
         ),
         ("q12_ship_mode_order_priority", "q12"),
         ("q13_customer_distribution", "q13"),
diff --git a/src/common.rs b/src/common.rs
index 44c557ce7..094e70c01 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -29,6 +29,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::<data_type::NullTreatment>()?;
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
diff --git a/src/common/data_type.rs b/src/common/data_type.rs
index cd4f864bc..313318fc9 100644
--- a/src/common/data_type.rs
+++ b/src/common/data_type.rs
@@ -757,3 +757,33 @@ pub enum SqlType {
     VARBINARY,
     VARCHAR,
 }
+
+/// Specifies Ignore / Respect NULL within window functions.
+/// For example
+/// `FIRST_VALUE(column2) IGNORE NULLS OVER (PARTITION BY column1)`
+#[allow(non_camel_case_types)]
+#[allow(clippy::upper_case_acronyms)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+#[pyclass(name = "NullTreatment", module = "datafusion.common")]
+pub enum NullTreatment {
+    IGNORE_NULLS,
+    RESPECT_NULLS,
+}
+
+impl From<NullTreatment> for sqlparser::ast::NullTreatment {
+    fn from(null_treatment: NullTreatment) -> sqlparser::ast::NullTreatment {
+        match null_treatment {
+            NullTreatment::IGNORE_NULLS => sqlparser::ast::NullTreatment::IgnoreNulls,
+            NullTreatment::RESPECT_NULLS => sqlparser::ast::NullTreatment::RespectNulls,
+        }
+    }
+}
+
+impl From<sqlparser::ast::NullTreatment> for NullTreatment {
+    fn from(null_treatment: sqlparser::ast::NullTreatment) -> NullTreatment {
+        match null_treatment {
+            sqlparser::ast::NullTreatment::IgnoreNulls => NullTreatment::IGNORE_NULLS,
+            sqlparser::ast::NullTreatment::RespectNulls => NullTreatment::RESPECT_NULLS,
+        }
+    }
+}
diff --git a/src/functions.rs b/src/functions.rs
index 8e395ae4f..b39d98b35 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -17,6 +17,7 @@
 
 use pyo3::{prelude::*, wrap_pyfunction};
 
+use crate::common::data_type::NullTreatment;
 use crate::context::PySessionContext;
 use crate::errors::DataFusionError;
 use crate::expr::conditional_expr::PyCaseBuilder;
@@ -73,15 +74,15 @@ pub fn var(y: PyExpr) -> PyExpr {
 }
 
 #[pyfunction]
-#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None))]
+#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None, null_treatment = None))]
 pub fn first_value(
     args: Vec<PyExpr>,
     distinct: bool,
     filter: Option<PyExpr>,
     order_by: Option<Vec<PyExpr>>,
+    null_treatment: Option<NullTreatment>,
 ) -> PyExpr {
-    // TODO: allow user to select null_treatment
-    let null_treatment = None;
+    let null_treatment = null_treatment.map(Into::into);
     let args = args.into_iter().map(|x| x.expr).collect::<Vec<_>>();
     let order_by = order_by.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>());
     functions_aggregate::expr_fn::first_value(
@@ -95,15 +96,15 @@
 }
 
 #[pyfunction]
-#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None))]
+#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None, null_treatment = None))]
 pub fn last_value(
     args: Vec<PyExpr>,
     distinct: bool,
     filter: Option<PyExpr>,
     order_by: Option<Vec<PyExpr>>,
+    null_treatment: Option<NullTreatment>,
 ) -> PyExpr {
-    // TODO: allow user to select null_treatment
-    let null_treatment = None;
+    let null_treatment = null_treatment.map(Into::into);
     let args = args.into_iter().map(|x| x.expr).collect::<Vec<_>>();
     let order_by = order_by.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>());
     functions_aggregate::expr_fn::last_value(
@@ -320,14 +321,20 @@ fn window(
     window_frame: Option<PyWindowFrame>,
     ctx: Option<PySessionContext>,
 ) -> PyResult<PyExpr> {
-    let fun = find_df_window_func(name).or_else(|| {
-        ctx.and_then(|ctx| {
-            ctx.ctx
-                .udaf(name)
-                .map(WindowFunctionDefinition::AggregateUDF)
-                .ok()
+    // workaround for https://github.com/apache/datafusion-python/issues/730
+    let fun = if name == "sum" {
+        let sum_udf = functions_aggregate::sum::sum_udaf();
+        Some(WindowFunctionDefinition::AggregateUDF(sum_udf))
+    } else {
+        find_df_window_func(name).or_else(|| {
+            ctx.and_then(|ctx| {
+                ctx.ctx
+                    .udaf(name)
+                    .map(WindowFunctionDefinition::AggregateUDF)
+                    .ok()
+            })
         })
-    });
+    };
     if fun.is_none() {
        return Err(DataFusionError::Common("window function not found".to_string()).into());
     }
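
A quick sketch of how the surface touched by this patch might be driven from Python. The module paths and keyword names below are assumptions inferred from the diff (NullTreatment registered under "datafusion.common", the new null_treatment keyword on first_value / last_value, and the special-cased "sum" name inside window); this is not code shipped with the patch.

# Hypothetical usage sketch, not part of the patch.
from datafusion import SessionContext, col
from datafusion import functions as f
from datafusion.common import NullTreatment  # assumed re-export of the new pyclass

ctx = SessionContext()
df = ctx.from_pydict({"a": [None, 20, 30], "b": [1, 2, 3]}, name="t")

# first_value / last_value now accept an optional null_treatment argument.
first_non_null = f.first_value(
    col("a"),
    order_by=[col("b")],
    null_treatment=NullTreatment.IGNORE_NULLS,
)
df.aggregate([], [first_non_null]).show()

# "sum" is special-cased inside window() until the UDAF migration tracked in
# https://github.com/apache/datafusion-python/issues/730 is finished.
df.select(f.window("sum", [col("a")], order_by=[col("b")])).show()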