Skip to content

Commit

Permalink
Move abs to datafusion_functions (apache#9313)
Browse files Browse the repository at this point in the history
* feat: move abs to datafusion_functions

* fix proto

* fix proto

* fix CI vendored code

* Fix proto

* add support type

* fix signature

* fix typo

* fix test cases

* disable a test case

* remove old code from math_expressions

* feat: add test

* fix clippy

* use unknown for proto

* fix unknown proto
  • Loading branch information
yyy1000 authored Feb 27, 2024
1 parent b55d0ed commit 85f7a8e
Show file tree
Hide file tree
Showing 12 changed files with 198 additions and 123 deletions.
7 changes: 0 additions & 7 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ use strum_macros::EnumIter;
#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter, Copy)]
pub enum BuiltinScalarFunction {
// math functions
/// abs
Abs,
/// acos
Acos,
/// asin
Expand Down Expand Up @@ -364,7 +362,6 @@ impl BuiltinScalarFunction {
pub fn volatility(&self) -> Volatility {
match self {
// Immutable scalar builtins
BuiltinScalarFunction::Abs => Volatility::Immutable,
BuiltinScalarFunction::Acos => Volatility::Immutable,
BuiltinScalarFunction::Asin => Volatility::Immutable,
BuiltinScalarFunction::Atan => Volatility::Immutable,
Expand Down Expand Up @@ -868,8 +865,6 @@ impl BuiltinScalarFunction {

BuiltinScalarFunction::ArrowTypeof => Ok(Utf8),

BuiltinScalarFunction::Abs => Ok(input_expr_types[0].clone()),

BuiltinScalarFunction::OverLay => {
utf8_to_str_type(&input_expr_types[0], "overlay")
}
Expand Down Expand Up @@ -1338,7 +1333,6 @@ impl BuiltinScalarFunction {
Signature::uniform(2, vec![Int64], self.volatility())
}
BuiltinScalarFunction::ArrowTypeof => Signature::any(1, self.volatility()),
BuiltinScalarFunction::Abs => Signature::any(1, self.volatility()),
BuiltinScalarFunction::OverLay => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8, Int64, Int64]),
Expand Down Expand Up @@ -1444,7 +1438,6 @@ impl BuiltinScalarFunction {
/// Returns all names that can be used to call this function
pub fn aliases(&self) -> &'static [&'static str] {
match self {
BuiltinScalarFunction::Abs => &["abs"],
BuiltinScalarFunction::Acos => &["acos"],
BuiltinScalarFunction::Acosh => &["acosh"],
BuiltinScalarFunction::Asin => &["asin"],
Expand Down
5 changes: 0 additions & 5 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2033,11 +2033,6 @@ mod test {
.is_volatile()
.unwrap()
);
assert!(
!ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs)
.is_volatile()
.unwrap()
);

// UDF
#[derive(Debug)]
Expand Down
2 changes: 0 additions & 2 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,6 @@ nary_scalar_expr!(
trunc,
"truncate toward zero, with optional precision"
);
scalar_expr!(Abs, abs, num, "absolute value");
scalar_expr!(Signum, signum, num, "sign of the argument (-1, 0, +1) ");
scalar_expr!(Exp, exp, num, "exponential");
scalar_expr!(Gcd, gcd, arg_1 arg_2, "greatest common divisor");
Expand Down Expand Up @@ -1354,7 +1353,6 @@ mod test {
test_nary_scalar_expr!(Round, round, input, decimal_places);
test_nary_scalar_expr!(Trunc, trunc, num);
test_nary_scalar_expr!(Trunc, trunc, num, precision);
test_unary_scalar_expr!(Abs, abs);
test_unary_scalar_expr!(Signum, signum);
test_unary_scalar_expr!(Exp, exp);
test_unary_scalar_expr!(Log2, log2);
Expand Down
177 changes: 177 additions & 0 deletions datafusion/functions/src/math/abs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Implementation of the `abs` (absolute value) math function
use arrow::array::Decimal128Array;
use arrow::array::Decimal256Array;
use arrow::array::Int16Array;
use arrow::array::Int32Array;
use arrow::array::Int64Array;
use arrow::array::Int8Array;
use arrow::datatypes::DataType;
use datafusion_common::not_impl_err;
use datafusion_common::plan_datafusion_err;
use datafusion_common::{internal_err, Result, DataFusionError};
use datafusion_expr::utils;
use datafusion_expr::ColumnarValue;

use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;
use arrow::array::{ArrayRef, Float32Array, Float64Array};
use arrow::error::ArrowError;

/// Function pointer type for an array-level math kernel: it receives the
/// argument arrays and returns the resulting array (or an error).
type MathArrayFunction = fn(&Vec<ArrayRef>) -> Result<ArrayRef>;

/// Expands to a closure computing `abs` element-wise over the first argument.
///
/// The first argument is downcast to `$array_ty` via `downcast_arg!`, and
/// `abs()` is applied to every value through the array's `unary` method.
macro_rules! make_abs_function {
    ($array_ty:ident) => {{
        |args: &Vec<ArrayRef>| {
            let input = downcast_arg!(&args[0], "abs arg", $array_ty);
            let output: $array_ty = input.unary(|value| value.abs());
            Ok(Arc::new(output) as ArrayRef)
        }
    }};
}

/// Expands to a closure computing a checked `abs` over the first argument.
///
/// Each value goes through `checked_abs()`; when that yields `None` the
/// closure fails with an `ArrowError::ComputeError` naming the array type and
/// the offending value.
macro_rules! make_try_abs_function {
    ($array_ty:ident) => {{
        |args: &Vec<ArrayRef>| {
            let input = downcast_arg!(&args[0], "abs arg", $array_ty);
            let output: $array_ty = input.try_unary(|value| match value.checked_abs() {
                Some(magnitude) => Ok(magnitude),
                None => Err(ArrowError::ComputeError(format!(
                    "{} overflow on abs({})",
                    stringify!($array_ty),
                    value
                ))),
            })?;
            Ok(Arc::new(output) as ArrayRef)
        }
    }};
}

/// Expands to a closure computing `abs` for a decimal array.
///
/// Values are mapped with `wrapping_abs()`, and the result is re-tagged with
/// the data type of the input array via `with_data_type`.
macro_rules! make_decimal_abs_function {
    ($array_ty:ident) => {{
        |args: &Vec<ArrayRef>| {
            let input = downcast_arg!(&args[0], "abs arg", $array_ty);
            let input_type = args[0].data_type().clone();
            let output: $array_ty = input
                .unary(|value| value.wrapping_abs())
                .with_data_type(input_type);
            Ok(Arc::new(output) as ArrayRef)
        }
    }};
}

/// Selects the `abs` kernel matching `input_data_type`.
///
/// The data type is inspected once here, so the returned function pointer can
/// be applied to the argument arrays without further type dispatch.
///
/// * `Float32` / `Float64`: plain `abs`.
/// * `Int8` .. `Int64`: checked `abs`, which errors on overflow
///   (for example `abs(-128_i8)`).
/// * `Null` and unsigned integers: a clone of the input `ArrayRef` is returned.
/// * `Decimal128` / `Decimal256`: `wrapping_abs`, keeping the input data type.
///
/// # Errors
///
/// Returns a "not implemented" error for every other data type.
fn create_abs_function(input_data_type: &DataType) -> Result<MathArrayFunction> {
    let kernel: MathArrayFunction = match input_data_type {
        // Floating point values.
        DataType::Float32 => make_abs_function!(Float32Array),
        DataType::Float64 => make_abs_function!(Float64Array),

        // Signed integers: the minimum value has no positive counterpart.
        DataType::Int8 => make_try_abs_function!(Int8Array),
        DataType::Int16 => make_try_abs_function!(Int16Array),
        DataType::Int32 => make_try_abs_function!(Int32Array),
        DataType::Int64 => make_try_abs_function!(Int64Array),

        // Nothing to compute: the result is the input itself.
        DataType::Null
        | DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32
        | DataType::UInt64 => |args: &Vec<ArrayRef>| Ok(args[0].clone()),

        // Decimal values.
        DataType::Decimal128(_, _) => make_decimal_abs_function!(Decimal128Array),
        DataType::Decimal256(_, _) => make_decimal_abs_function!(Decimal256Array),

        other => {
            return not_impl_err!("Unsupported data type {other:?} for function abs")
        }
    };
    Ok(kernel)
}
/// Scalar UDF implementation of the `abs` function.
#[derive(Debug)]
pub(super) struct AbsFunc {
// Signature reported through `ScalarUDFImpl::signature`.
signature: Signature,
}

impl AbsFunc {
    /// Creates the `abs` function with a signature built from
    /// `Signature::any(1, Volatility::Immutable)`.
    pub fn new() -> Self {
        let signature = Signature::any(1, Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for AbsFunc {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"abs"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.len() != 1 {
return Err(plan_datafusion_err!(
"{}",
utils::generate_signature_error_msg(
self.name(),
self.signature().clone(),
arg_types,
)
));
}
match arg_types[0] {
DataType::Float32 => Ok(DataType::Float32),
DataType::Float64 => Ok(DataType::Float64),
DataType::Int8 => Ok(DataType::Int8),
DataType::Int16 => Ok(DataType::Int16),
DataType::Int32 => Ok(DataType::Int32),
DataType::Int64 => Ok(DataType::Int64),
DataType::Null => Ok(DataType::Null),
DataType::UInt8 => Ok(DataType::UInt8),
DataType::UInt16 => Ok(DataType::UInt16),
DataType::UInt32 => Ok(DataType::UInt32),
DataType::UInt64 => Ok(DataType::UInt64),
DataType::Decimal128(precision, scale) => Ok(DataType::Decimal128(precision, scale)),
DataType::Decimal256(precision, scale) => Ok(DataType::Decimal256(precision, scale)),
_ => not_impl_err!("Unsupported data type {} for function abs", arg_types[0].to_string()),
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;

if args.len() != 1 {
return internal_err!("abs function requires 1 argument, got {}", args.len());
}

let input_data_type = args[0].data_type();
let abs_fun = create_abs_function(input_data_type)?;

let arr = abs_fun(&args)?;
Ok(ColumnarValue::Array(arr))
}
}
8 changes: 5 additions & 3 deletions datafusion/functions/src/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
//! "math" DataFusion functions
mod nans;
mod abs;

// create UDFs
make_udf_function!(nans::IsNanFunc, ISNAN, isnan);
make_udf_function!(abs::AbsFunc, ABS, abs);

// Export the functions out of this package, both as expr_fn as well as a list of functions
export_functions!(
(isnan, num, "returns true if a given number is +NaN or -NaN otherwise returns false")
);

(isnan, num, "returns true if a given number is +NaN or -NaN otherwise returns false"),
(abs, num, "returns the absolute value of a given number")
);
4 changes: 0 additions & 4 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,6 @@ pub fn create_physical_fun(
) -> Result<ScalarFunctionImplementation> {
Ok(match fun {
// math functions
BuiltinScalarFunction::Abs => Arc::new(|args| {
make_scalar_function_inner(math_expressions::abs_invoke)(args)
}),
BuiltinScalarFunction::Acos => Arc::new(math_expressions::acos),
BuiltinScalarFunction::Asin => Arc::new(math_expressions::asin),
BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan),
Expand Down Expand Up @@ -3075,7 +3072,6 @@ mod tests {
let funs = [
BuiltinScalarFunction::Concat,
BuiltinScalarFunction::ToTimestamp,
BuiltinScalarFunction::Abs,
BuiltinScalarFunction::Repeat,
];

Expand Down
Loading

0 comments on commit 85f7a8e

Please sign in to comment.