From 2d303346c7d44a78d317e7b7b959cbd94edfb891 Mon Sep 17 00:00:00 2001 From: Landon Gingerich Date: Wed, 12 Feb 2025 16:28:48 -0600 Subject: [PATCH] Use ` take_function_args` in more places (#14525) * refactor: apply take_function_args() in functions crate * fix: handle plural vs. singular grammar for "argument(s)" * fix: run cargo clippy and fix errors * style: apply cargo fmt * refactor: move func to datafusion_common and update imports * refactor: apply take_function_args * fix: update test output language * fix: simplify doc test for take_function_args --------- Co-authored-by: Andrew Lamb --- datafusion/common/src/utils/mod.rs | 41 +++++++++++++++++- datafusion/expr/src/test/function_stub.rs | 8 ++-- datafusion/functions-aggregate/src/average.rs | 10 ++--- datafusion/functions-aggregate/src/sum.rs | 10 ++--- .../functions-nested/src/cardinality.rs | 14 +++--- datafusion/functions-nested/src/concat.rs | 16 +++---- datafusion/functions-nested/src/dimension.rs | 22 ++++------ datafusion/functions-nested/src/distance.rs | 14 +++--- datafusion/functions-nested/src/empty.rs | 12 +++--- datafusion/functions-nested/src/except.rs | 10 ++--- datafusion/functions-nested/src/extract.rs | 42 ++++++++---------- datafusion/functions-nested/src/flatten.rs | 17 +++----- datafusion/functions-nested/src/map.rs | 30 +++++-------- .../functions-nested/src/map_extract.rs | 31 ++++++------- datafusion/functions-nested/src/map_keys.rs | 14 +++--- datafusion/functions-nested/src/map_values.rs | 14 +++--- datafusion/functions-nested/src/position.rs | 14 +++--- datafusion/functions-nested/src/range.rs | 36 +++++++--------- datafusion/functions-nested/src/remove.rs | 26 +++++------ datafusion/functions-nested/src/repeat.rs | 15 ++----- datafusion/functions-nested/src/replace.rs | 35 ++++++--------- datafusion/functions-nested/src/reverse.rs | 14 +++--- datafusion/functions-nested/src/set_ops.rs | 31 ++++--------- datafusion/functions/src/core/arrow_cast.rs | 14 +++--- datafusion/functions/src/core/arrowtypeof.rs | 3 +- datafusion/functions/src/core/getfield.rs | 29 +++++-------- datafusion/functions/src/core/nullif.rs | 4 +- datafusion/functions/src/core/nvl.rs | 3 +- datafusion/functions/src/core/nvl2.rs | 3 +- datafusion/functions/src/core/version.rs | 3 +- datafusion/functions/src/crypto/basic.rs | 6 +-- .../functions/src/datetime/date_part.rs | 10 ++--- .../functions/src/datetime/make_date.rs | 3 +- datafusion/functions/src/datetime/to_char.rs | 3 +- .../functions/src/datetime/to_local_time.rs | 34 +++++---------- datafusion/functions/src/encoding/inner.rs | 37 ++++++---------- datafusion/functions/src/math/abs.rs | 5 ++- datafusion/functions/src/string/bit_length.rs | 11 ++--- .../functions/src/string/levenshtein.rs | 23 ++++------ .../functions/src/string/octet_length.rs | 11 ++--- .../functions/src/unicode/find_in_set.rs | 16 +++---- .../functions/src/unicode/substrindex.rs | 29 ++++++------- datafusion/functions/src/utils.rs | 43 +------------------ datafusion/sqllogictest/test_files/map.slt | 2 +- 44 files changed, 298 insertions(+), 470 deletions(-) diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 0bf7c03a0a19..cb77cc8e79b1 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -22,7 +22,7 @@ pub mod memory; pub mod proxy; pub mod string_utils; -use crate::error::{_internal_datafusion_err, _internal_err}; +use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err}; use crate::{DataFusionError, Result, ScalarValue}; use arrow::array::{ cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, @@ -905,6 +905,45 @@ pub fn get_available_parallelism() -> usize { .get() } +/// Converts a collection of function arguments into an fixed-size array of length N +/// producing a reasonable error message in case of unexpected number of arguments. +/// +/// # Example +/// ``` +/// # use datafusion_common::Result; +/// # use datafusion_common::utils::take_function_args; +/// # use datafusion_common::ScalarValue; +/// fn my_function(args: &[ScalarValue]) -> Result<()> { +/// // function expects 2 args, so create a 2-element array +/// let [arg1, arg2] = take_function_args("my_function", args)?; +/// // ... do stuff.. +/// Ok(()) +/// } +/// +/// // Calling the function with 1 argument produces an error: +/// let args = vec![ScalarValue::Int32(Some(10))]; +/// let err = my_function(&args).unwrap_err(); +/// assert_eq!(err.to_string(), "Execution error: my_function function requires 2 arguments, got 1"); +/// // Calling the function with 2 arguments works great +/// let args = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(20))]; +/// my_function(&args).unwrap(); +/// ``` +pub fn take_function_args( + function_name: &str, + args: impl IntoIterator, +) -> Result<[T; N]> { + let args = args.into_iter().collect::>(); + args.try_into().map_err(|v: Vec| { + _exec_datafusion_err!( + "{} function requires {} {}, got {}", + function_name, + N, + if N == 1 { "argument" } else { "arguments" }, + v.len() + ) + }) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index 71ab1ad6ef9b..a753f4c376c6 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -25,7 +25,7 @@ use arrow::datatypes::{ DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; -use datafusion_common::{exec_err, not_impl_err, Result}; +use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result}; use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS}; use crate::Volatility::Immutable; @@ -125,9 +125,7 @@ impl AggregateUDFImpl for Sum { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return exec_err!("SUM expects exactly one argument"); - } + let [array] = take_function_args(self.name(), arg_types)?; // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval. @@ -147,7 +145,7 @@ impl AggregateUDFImpl for Sum { } } - Ok(vec![coerced_type(&arg_types[0])?]) + Ok(vec![coerced_type(array)?]) } fn return_type(&self, arg_types: &[DataType]) -> Result { diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 18874f831e9d..141771b0412f 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -27,7 +27,9 @@ use arrow::datatypes::{ i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field, Float64Type, UInt64Type, }; -use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::{ + exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, +}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type}; use datafusion_expr::utils::format_state_name; @@ -247,10 +249,8 @@ impl AggregateUDFImpl for Avg { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return exec_err!("{} expects exactly one argument.", self.name()); - } - coerce_avg_type(self.name(), arg_types) + let [args] = take_function_args(self.name(), arg_types)?; + coerce_avg_type(self.name(), std::slice::from_ref(args)) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 9615ca33a5f3..76a1315c2d88 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -33,7 +33,9 @@ use arrow::datatypes::{ DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::{ + exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, +}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; @@ -125,9 +127,7 @@ impl AggregateUDFImpl for Sum { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return exec_err!("SUM expects exactly one argument"); - } + let [args] = take_function_args(self.name(), arg_types)?; // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval. @@ -147,7 +147,7 @@ impl AggregateUDFImpl for Sum { } } - Ok(vec![coerced_type(&arg_types[0])?]) + Ok(vec![coerced_type(args)?]) } fn return_type(&self, arg_types: &[DataType]) -> Result { diff --git a/datafusion/functions-nested/src/cardinality.rs b/datafusion/functions-nested/src/cardinality.rs index f38a2ab5b90a..ad30c0b540af 100644 --- a/datafusion/functions-nested/src/cardinality.rs +++ b/datafusion/functions-nested/src/cardinality.rs @@ -26,6 +26,7 @@ use arrow::datatypes::{ DataType::{FixedSizeList, LargeList, List, Map, UInt64}, }; use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array}; +use datafusion_common::utils::take_function_args; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; use datafusion_expr::{ @@ -127,21 +128,18 @@ impl ScalarUDFImpl for Cardinality { /// Cardinality SQL function pub fn cardinality_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("cardinality expects one argument"); - } - - match &args[0].data_type() { + let [array] = take_function_args("cardinality", args)?; + match &array.data_type() { List(_) => { - let list_array = as_list_array(&args[0])?; + let list_array = as_list_array(&array)?; generic_list_cardinality::(list_array) } LargeList(_) => { - let list_array = as_large_list_array(&args[0])?; + let list_array = as_large_list_array(&array)?; generic_list_cardinality::(list_array) } Map(_, _) => { - let map_array = as_map_array(&args[0])?; + let map_array = as_map_array(&array)?; generic_map_cardinality(map_array) } other => { diff --git a/datafusion/functions-nested/src/concat.rs b/datafusion/functions-nested/src/concat.rs index 2f1d9e9938a8..14d4b958867f 100644 --- a/datafusion/functions-nested/src/concat.rs +++ b/datafusion/functions-nested/src/concat.rs @@ -28,7 +28,9 @@ use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field}; use datafusion_common::Result; use datafusion_common::{ - cast::as_generic_list_array, exec_err, not_impl_err, plan_err, utils::list_ndims, + cast::as_generic_list_array, + exec_err, not_impl_err, plan_err, + utils::{list_ndims, take_function_args}, }; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -415,11 +417,9 @@ fn concat_internal(args: &[ArrayRef]) -> Result { /// Array_append SQL function pub(crate) fn array_append_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_append expects two arguments"); - } + let [array, _] = take_function_args("array_append", args)?; - match args[0].data_type() { + match array.data_type() { DataType::LargeList(_) => general_append_and_prepend::(args, true), _ => general_append_and_prepend::(args, true), } @@ -427,11 +427,9 @@ pub(crate) fn array_append_inner(args: &[ArrayRef]) -> Result { /// Array_prepend SQL function pub(crate) fn array_prepend_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_prepend expects two arguments"); - } + let [_, array] = take_function_args("array_prepend", args)?; - match args[1].data_type() { + match array.data_type() { DataType::LargeList(_) => general_append_and_prepend::(args, false), _ => general_append_and_prepend::(args, false), } diff --git a/datafusion/functions-nested/src/dimension.rs b/datafusion/functions-nested/src/dimension.rs index 30b2650bff38..dc1547b7b437 100644 --- a/datafusion/functions-nested/src/dimension.rs +++ b/datafusion/functions-nested/src/dimension.rs @@ -28,7 +28,7 @@ use arrow::datatypes::{ use std::any::Any; use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, plan_err, Result}; +use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result}; use crate::utils::{compute_array_dims, make_scalar_function}; use datafusion_expr::{ @@ -203,20 +203,18 @@ impl ScalarUDFImpl for ArrayNdims { /// Array_dims SQL function pub fn array_dims_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_dims needs one argument"); - } + let [array] = take_function_args("array_dims", args)?; - let data = match args[0].data_type() { + let data = match array.data_type() { List(_) => { - let array = as_list_array(&args[0])?; + let array = as_list_array(&array)?; array .iter() .map(compute_array_dims) .collect::>>()? } LargeList(_) => { - let array = as_large_list_array(&args[0])?; + let array = as_large_list_array(&array)?; array .iter() .map(compute_array_dims) @@ -234,9 +232,7 @@ pub fn array_dims_inner(args: &[ArrayRef]) -> Result { /// Array_ndims SQL function pub fn array_ndims_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_ndims needs one argument"); - } + let [array_dim] = take_function_args("array_ndims", args)?; fn general_list_ndims( array: &GenericListArray, @@ -254,13 +250,13 @@ pub fn array_ndims_inner(args: &[ArrayRef]) -> Result { Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) } - match args[0].data_type() { + match array_dim.data_type() { List(_) => { - let array = as_list_array(&args[0])?; + let array = as_list_array(&array_dim)?; general_list_ndims::(array) } LargeList(_) => { - let array = as_large_list_array(&args[0])?; + let array = as_large_list_array(&array_dim)?; general_list_ndims::(array) } array_type => exec_err!("array_ndims does not support type {array_type:?}"), diff --git a/datafusion/functions-nested/src/distance.rs b/datafusion/functions-nested/src/distance.rs index 805cfd69c01f..fc33828078c0 100644 --- a/datafusion/functions-nested/src/distance.rs +++ b/datafusion/functions-nested/src/distance.rs @@ -30,7 +30,9 @@ use datafusion_common::cast::{ as_int64_array, }; use datafusion_common::utils::coerced_fixed_size_list_to_list; -use datafusion_common::{exec_err, internal_datafusion_err, Result}; +use datafusion_common::{ + exec_err, internal_datafusion_err, utils::take_function_args, Result, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -110,9 +112,7 @@ impl ScalarUDFImpl for ArrayDistance { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 { - return exec_err!("array_distance expects exactly two arguments"); - } + let [_, _] = take_function_args(self.name(), arg_types)?; let mut result = Vec::new(); for arg_type in arg_types { match arg_type { @@ -142,11 +142,9 @@ impl ScalarUDFImpl for ArrayDistance { } pub fn array_distance_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_distance expects exactly two arguments"); - } + let [array1, array2] = take_function_args("array_distance", args)?; - match (&args[0].data_type(), &args[1].data_type()) { + match (&array1.data_type(), &array2.data_type()) { (List(_), List(_)) => general_array_distance::(args), (LargeList(_), LargeList(_)) => general_array_distance::(args), (array_type1, array_type2) => { diff --git a/datafusion/functions-nested/src/empty.rs b/datafusion/functions-nested/src/empty.rs index eab773819bf7..07e5d41b8023 100644 --- a/datafusion/functions-nested/src/empty.rs +++ b/datafusion/functions-nested/src/empty.rs @@ -24,7 +24,7 @@ use arrow::datatypes::{ DataType::{Boolean, FixedSizeList, LargeList, List}, }; use datafusion_common::cast::as_generic_list_array; -use datafusion_common::{exec_err, plan_err, Result}; +use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -117,14 +117,12 @@ impl ScalarUDFImpl for ArrayEmpty { /// Array_empty SQL function pub fn array_empty_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_empty expects one argument"); - } + let [array] = take_function_args("array_empty", args)?; - let array_type = args[0].data_type(); + let array_type = array.data_type(); match array_type { - List(_) => general_array_empty::(&args[0]), - LargeList(_) => general_array_empty::(&args[0]), + List(_) => general_array_empty::(array), + LargeList(_) => general_array_empty::(array), _ => exec_err!("array_empty does not support type '{array_type:?}'."), } } diff --git a/datafusion/functions-nested/src/except.rs b/datafusion/functions-nested/src/except.rs index fda76894507b..f7958caa6379 100644 --- a/datafusion/functions-nested/src/except.rs +++ b/datafusion/functions-nested/src/except.rs @@ -22,7 +22,8 @@ use arrow::array::{cast::AsArray, Array, ArrayRef, GenericListArray, OffsetSizeT use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, FieldRef}; use arrow::row::{RowConverter, SortField}; -use datafusion_common::{exec_err, internal_err, HashSet, Result}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{internal_err, HashSet, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -124,12 +125,7 @@ impl ScalarUDFImpl for ArrayExcept { /// Array_except SQL function pub fn array_except_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_except needs two arguments"); - } - - let array1 = &args[0]; - let array2 = &args[1]; + let [array1, array2] = take_function_args("array_except", args)?; match (array1.data_type(), array2.data_type()) { (DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()), diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index 93c2b45dde9b..697c868fdea1 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -31,7 +31,8 @@ use datafusion_common::cast::as_int64_array; use datafusion_common::cast::as_large_list_array; use datafusion_common::cast::as_list_array; use datafusion_common::{ - exec_err, internal_datafusion_err, plan_err, DataFusionError, Result, + exec_err, internal_datafusion_err, plan_err, utils::take_function_args, + DataFusionError, Result, }; use datafusion_expr::{ArrayFunctionSignature, Expr, TypeSignature}; use datafusion_expr::{ @@ -194,24 +195,22 @@ impl ScalarUDFImpl for ArrayElement { /// For example: /// > array_element(\[1, 2, 3], 2) -> 2 fn array_element_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_element needs two arguments"); - } + let [array, indexes] = take_function_args("array_element", args)?; - match &args[0].data_type() { + match &array.data_type() { List(_) => { - let array = as_list_array(&args[0])?; - let indexes = as_int64_array(&args[1])?; + let array = as_list_array(&array)?; + let indexes = as_int64_array(&indexes)?; general_array_element::(array, indexes) } LargeList(_) => { - let array = as_large_list_array(&args[0])?; - let indexes = as_int64_array(&args[1])?; + let array = as_large_list_array(&array)?; + let indexes = as_int64_array(&indexes)?; general_array_element::(array, indexes) } _ => exec_err!( "array_element does not support type: {:?}", - args[0].data_type() + array.data_type() ), } } @@ -807,23 +806,20 @@ impl ScalarUDFImpl for ArrayPopBack { /// array_pop_back SQL function fn array_pop_back_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_pop_back needs one argument"); - } + let [array] = take_function_args("array_pop_back", args)?; - let array_data_type = args[0].data_type(); - match array_data_type { + match array.data_type() { List(_) => { - let array = as_list_array(&args[0])?; + let array = as_list_array(&array)?; general_pop_back_list::(array) } LargeList(_) => { - let array = as_large_list_array(&args[0])?; + let array = as_large_list_array(&array)?; general_pop_back_list::(array) } _ => exec_err!( "array_pop_back does not support type: {:?}", - array_data_type + array.data_type() ), } } @@ -914,17 +910,15 @@ impl ScalarUDFImpl for ArrayAnyValue { } fn array_any_value_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_any_value expects one argument"); - } + let [array] = take_function_args("array_any_value", args)?; - match &args[0].data_type() { + match &array.data_type() { List(_) => { - let array = as_list_array(&args[0])?; + let array = as_list_array(&array)?; general_array_any_value::(array) } LargeList(_) => { - let array = as_large_list_array(&args[0])?; + let array = as_large_list_array(&array)?; general_array_any_value::(array) } data_type => exec_err!("array_any_value does not support type: {:?}", data_type), diff --git a/datafusion/functions-nested/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs index be5d80ecf4ad..0003db38e0e4 100644 --- a/datafusion/functions-nested/src/flatten.rs +++ b/datafusion/functions-nested/src/flatten.rs @@ -27,7 +27,7 @@ use arrow::datatypes::{ use datafusion_common::cast::{ as_generic_list_array, as_large_list_array, as_list_array, }; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -143,25 +143,22 @@ impl ScalarUDFImpl for Flatten { /// Flatten SQL function pub fn flatten_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("flatten expects one argument"); - } + let [array] = take_function_args("flatten", args)?; - let array_type = args[0].data_type(); - match array_type { + match array.data_type() { List(_) => { - let list_arr = as_list_array(&args[0])?; + let list_arr = as_list_array(&array)?; let flattened_array = flatten_internal::(list_arr.clone(), None)?; Ok(Arc::new(flattened_array) as ArrayRef) } LargeList(_) => { - let list_arr = as_large_list_array(&args[0])?; + let list_arr = as_large_list_array(&array)?; let flattened_array = flatten_internal::(list_arr.clone(), None)?; Ok(Arc::new(flattened_array) as ArrayRef) } - Null => Ok(Arc::clone(&args[0])), + Null => Ok(Arc::clone(array)), _ => { - exec_err!("flatten does not support type '{array_type:?}'") + exec_err!("flatten does not support type '{:?}'", array.data_type()) } } } diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index 26e7733d581b..67ff9182517e 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -24,7 +24,9 @@ use arrow::buffer::Buffer; use arrow::datatypes::{DataType, Field, SchemaBuilder, ToByteSlice}; use datafusion_common::utils::{fixed_size_list_to_arrays, list_to_arrays}; -use datafusion_common::{exec_err, HashSet, Result, ScalarValue}; +use datafusion_common::{ + exec_err, utils::take_function_args, HashSet, Result, ScalarValue, +}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, @@ -55,23 +57,18 @@ fn can_evaluate_to_const(args: &[ColumnarValue]) -> bool { } fn make_map_batch(args: &[ColumnarValue]) -> Result { - if args.len() != 2 { - return exec_err!( - "make_map requires exactly 2 arguments, got {} instead", - args.len() - ); - } + let [keys_arg, values_arg] = take_function_args("make_map", args)?; let can_evaluate_to_const = can_evaluate_to_const(args); // check the keys array is unique - let keys = get_first_array_ref(&args[0])?; + let keys = get_first_array_ref(keys_arg)?; if keys.null_count() > 0 { return exec_err!("map key cannot be null"); } let key_array = keys.as_ref(); - match &args[0] { + match keys_arg { ColumnarValue::Array(_) => { let row_keys = match key_array.data_type() { DataType::List(_) => list_to_arrays::(&keys), @@ -94,8 +91,8 @@ fn make_map_batch(args: &[ColumnarValue]) -> Result { } } - let values = get_first_array_ref(&args[1])?; - make_map_batch_internal(keys, values, can_evaluate_to_const, args[0].data_type()) + let values = get_first_array_ref(values_arg)?; + make_map_batch_internal(keys, values, can_evaluate_to_const, keys_arg.data_type()) } fn check_unique_keys(array: &dyn Array) -> Result<()> { @@ -257,21 +254,16 @@ impl ScalarUDFImpl for MapFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 2 { - return exec_err!( - "map requires exactly 2 arguments, got {} instead", - arg_types.len() - ); - } + let [keys_arg, values_arg] = take_function_args(self.name(), arg_types)?; let mut builder = SchemaBuilder::new(); builder.push(Field::new( "key", - get_element_type(&arg_types[0])?.clone(), + get_element_type(keys_arg)?.clone(), false, )); builder.push(Field::new( "value", - get_element_type(&arg_types[1])?.clone(), + get_element_type(values_arg)?.clone(), true, )); let fields = builder.finish().fields; diff --git a/datafusion/functions-nested/src/map_extract.rs b/datafusion/functions-nested/src/map_extract.rs index 98fb8440427b..ddc12482e380 100644 --- a/datafusion/functions-nested/src/map_extract.rs +++ b/datafusion/functions-nested/src/map_extract.rs @@ -17,11 +17,13 @@ //! [`ScalarUDFImpl`] definitions for map_extract functions. +use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{ make_array, Array, ArrayRef, Capacities, ListArray, MapArray, MutableArrayData, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field}; +use datafusion_common::utils::take_function_args; use datafusion_common::{cast::as_map_array, exec_err, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -31,8 +33,6 @@ use std::any::Any; use std::sync::Arc; use std::vec; -use crate::utils::{get_map_entry_field, make_scalar_function}; - // Create static instances of ScalarUDFs for each function make_udf_expr_and_func!( MapExtract, @@ -102,10 +102,7 @@ impl ScalarUDFImpl for MapExtract { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 2 { - return exec_err!("map_extract expects two arguments"); - } - let map_type = &arg_types[0]; + let [map_type, _] = take_function_args(self.name(), arg_types)?; let map_fields = get_map_entry_field(map_type)?; Ok(DataType::List(Arc::new(Field::new_list_field( map_fields.last().unwrap().data_type().clone(), @@ -126,13 +123,11 @@ impl ScalarUDFImpl for MapExtract { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 { - return exec_err!("map_extract expects two arguments"); - } + let [map_type, _] = take_function_args(self.name(), arg_types)?; - let field = get_map_entry_field(&arg_types[0])?; + let field = get_map_entry_field(map_type)?; Ok(vec![ - arg_types[0].clone(), + map_type.clone(), field.first().unwrap().data_type().clone(), ]) } @@ -188,24 +183,22 @@ fn general_map_extract_inner( } fn map_extract_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("map_extract expects two arguments"); - } + let [map_arg, key_arg] = take_function_args("map_extract", args)?; - let map_array = match args[0].data_type() { - DataType::Map(_, _) => as_map_array(&args[0])?, + let map_array = match map_arg.data_type() { + DataType::Map(_, _) => as_map_array(&map_arg)?, _ => return exec_err!("The first argument in map_extract must be a map"), }; let key_type = map_array.key_type(); - if key_type != args[1].data_type() { + if key_type != key_arg.data_type() { return exec_err!( "The key type {} does not match the map key type {}", - args[1].data_type(), + key_arg.data_type(), key_type ); } - general_map_extract_inner(map_array, &args[1]) + general_map_extract_inner(map_array, key_arg) } diff --git a/datafusion/functions-nested/src/map_keys.rs b/datafusion/functions-nested/src/map_keys.rs index 40a936208770..c58624e12c60 100644 --- a/datafusion/functions-nested/src/map_keys.rs +++ b/datafusion/functions-nested/src/map_keys.rs @@ -20,6 +20,7 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{Array, ArrayRef, ListArray}; use arrow::datatypes::{DataType, Field}; +use datafusion_common::utils::take_function_args; use datafusion_common::{cast::as_map_array, exec_err, Result}; use datafusion_expr::{ ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, @@ -91,10 +92,7 @@ impl ScalarUDFImpl for MapKeysFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 1 { - return exec_err!("map_keys expects single argument"); - } - let map_type = &arg_types[0]; + let [map_type] = take_function_args(self.name(), arg_types)?; let map_fields = get_map_entry_field(map_type)?; Ok(DataType::List(Arc::new(Field::new_list_field( map_fields.first().unwrap().data_type().clone(), @@ -116,12 +114,10 @@ impl ScalarUDFImpl for MapKeysFunc { } fn map_keys_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("map_keys expects single argument"); - } + let [map_arg] = take_function_args("map_keys", args)?; - let map_array = match args[0].data_type() { - DataType::Map(_, _) => as_map_array(&args[0])?, + let map_array = match map_arg.data_type() { + DataType::Map(_, _) => as_map_array(&map_arg)?, _ => return exec_err!("Argument for map_keys should be a map"), }; diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index 48da2cfb68a4..d4a67b7f67a7 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -20,6 +20,7 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{Array, ArrayRef, ListArray}; use arrow::datatypes::{DataType, Field}; +use datafusion_common::utils::take_function_args; use datafusion_common::{cast::as_map_array, exec_err, Result}; use datafusion_expr::{ ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, @@ -91,10 +92,7 @@ impl ScalarUDFImpl for MapValuesFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 1 { - return exec_err!("map_values expects single argument"); - } - let map_type = &arg_types[0]; + let [map_type] = take_function_args(self.name(), arg_types)?; let map_fields = get_map_entry_field(map_type)?; Ok(DataType::List(Arc::new(Field::new_list_field( map_fields.last().unwrap().data_type().clone(), @@ -116,12 +114,10 @@ impl ScalarUDFImpl for MapValuesFunc { } fn map_values_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("map_values expects single argument"); - } + let [map_arg] = take_function_args("map_values", args)?; - let map_array = match args[0].data_type() { - DataType::Map(_, _) => as_map_array(&args[0])?, + let map_array = match map_arg.data_type() { + DataType::Map(_, _) => as_map_array(&map_arg)?, _ => return exec_err!("Argument for map_values should be a map"), }; diff --git a/datafusion/functions-nested/src/position.rs b/datafusion/functions-nested/src/position.rs index 95c5fdf9a59d..9adb174c4f2f 100644 --- a/datafusion/functions-nested/src/position.rs +++ b/datafusion/functions-nested/src/position.rs @@ -37,7 +37,7 @@ use arrow::array::{ use datafusion_common::cast::{ as_generic_list_array, as_int64_array, as_large_list_array, as_list_array, }; -use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_common::{exec_err, internal_err, utils::take_function_args, Result}; use itertools::Itertools; use crate::utils::{compare_element_to_list, make_scalar_function}; @@ -293,20 +293,16 @@ impl ScalarUDFImpl for ArrayPositions { /// Array_positions SQL function pub fn array_positions_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_positions expects two arguments"); - } - - let element = &args[1]; + let [array, element] = take_function_args("array_positions", args)?; - match &args[0].data_type() { + match &array.data_type() { List(_) => { - let arr = as_list_array(&args[0])?; + let arr = as_list_array(&array)?; crate::utils::check_datatypes("array_positions", &[arr.values(), element])?; general_positions::(arr, element) } LargeList(_) => { - let arr = as_large_list_array(&args[0])?; + let arr = as_large_list_array(&array)?; crate::utils::check_datatypes("array_positions", &[arr.values(), element])?; general_positions::(arr, element) } diff --git a/datafusion/functions-nested/src/range.rs b/datafusion/functions-nested/src/range.rs index 3636babe7d5c..dcf5f33ea2c2 100644 --- a/datafusion/functions-nested/src/range.rs +++ b/datafusion/functions-nested/src/range.rs @@ -34,7 +34,8 @@ use datafusion_common::cast::{ as_date32_array, as_int64_array, as_interval_mdn_array, as_timestamp_nanosecond_array, }; use datafusion_common::{ - exec_datafusion_err, exec_err, internal_err, not_impl_datafusion_err, Result, + exec_datafusion_err, exec_err, internal_err, not_impl_datafusion_err, + utils::take_function_args, Result, }; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -434,13 +435,12 @@ fn gen_range_iter( } fn gen_range_date(args: &[ArrayRef], include_upper_bound: bool) -> Result { - if args.len() != 3 { - return exec_err!("arguments length does not match"); - } + let [start, stop, step] = take_function_args("range", args)?; + let (start_array, stop_array, step_array) = ( - Some(as_date32_array(&args[0])?), - as_date32_array(&args[1])?, - Some(as_interval_mdn_array(&args[2])?), + Some(as_date32_array(start)?), + as_date32_array(stop)?, + Some(as_interval_mdn_array(step)?), ); // values are date32s @@ -507,21 +507,17 @@ fn gen_range_date(args: &[ArrayRef], include_upper_bound: bool) -> Result Result { - if args.len() != 3 { - return exec_err!( - "Arguments length must be 3 for {}", - if include_upper_bound { - "GENERATE_SERIES" - } else { - "RANGE" - } - ); - } + let func_name = if include_upper_bound { + "GENERATE_SERIES" + } else { + "RANGE" + }; + let [start, stop, step] = take_function_args(func_name, args)?; // coerce_types fn should coerce all types to Timestamp(Nanosecond, tz) - let (start_arr, start_tz_opt) = cast_timestamp_arg(&args[0], include_upper_bound)?; - let (stop_arr, stop_tz_opt) = cast_timestamp_arg(&args[1], include_upper_bound)?; - let step_arr = as_interval_mdn_array(&args[2])?; + let (start_arr, start_tz_opt) = cast_timestamp_arg(start, include_upper_bound)?; + let (stop_arr, stop_tz_opt) = cast_timestamp_arg(stop, include_upper_bound)?; + let step_arr = as_interval_mdn_array(step)?; let start_tz = parse_tz(start_tz_opt)?; let stop_tz = parse_tz(stop_tz_opt)?; diff --git a/datafusion/functions-nested/src/remove.rs b/datafusion/functions-nested/src/remove.rs index e196e244646f..f9539dbc1621 100644 --- a/datafusion/functions-nested/src/remove.rs +++ b/datafusion/functions-nested/src/remove.rs @@ -26,7 +26,7 @@ use arrow::array::{ use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field}; use datafusion_common::cast::as_int64_array; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -277,32 +277,26 @@ impl ScalarUDFImpl for ArrayRemoveAll { /// Array_remove SQL function pub fn array_remove_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_remove expects two arguments"); - } + let [array, element] = take_function_args("array_remove", args)?; - let arr_n = vec![1; args[0].len()]; - array_remove_internal(&args[0], &args[1], arr_n) + let arr_n = vec![1; array.len()]; + array_remove_internal(array, element, arr_n) } /// Array_remove_n SQL function pub fn array_remove_n_inner(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!("array_remove_n expects three arguments"); - } + let [array, element, max] = take_function_args("array_remove_n", args)?; - let arr_n = as_int64_array(&args[2])?.values().to_vec(); - array_remove_internal(&args[0], &args[1], arr_n) + let arr_n = as_int64_array(max)?.values().to_vec(); + array_remove_internal(array, element, arr_n) } /// Array_remove_all SQL function pub fn array_remove_all_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_remove_all expects two arguments"); - } + let [array, element] = take_function_args("array_remove_all", args)?; - let arr_n = vec![i64::MAX; args[0].len()]; - array_remove_internal(&args[0], &args[1], arr_n) + let arr_n = vec![i64::MAX; array.len()]; + array_remove_internal(array, element, arr_n) } fn array_remove_internal( diff --git a/datafusion/functions-nested/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs index e26c79ab45c6..16d7c1912f6d 100644 --- a/datafusion/functions-nested/src/repeat.rs +++ b/datafusion/functions-nested/src/repeat.rs @@ -31,7 +31,7 @@ use arrow::datatypes::{ Field, }; use datafusion_common::cast::{as_large_list_array, as_list_array, as_uint64_array}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -128,17 +128,10 @@ impl ScalarUDFImpl for ArrayRepeat { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 { - return exec_err!("array_repeat expects two arguments"); - } - - let element_type = &arg_types[0]; - let first = element_type.clone(); - - let count_type = &arg_types[1]; + let [first_type, second_type] = take_function_args(self.name(), arg_types)?; // Coerce the second argument to Int64/UInt64 if it's a numeric type - let second = match count_type { + let second = match second_type { DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { DataType::Int64 } @@ -148,7 +141,7 @@ impl ScalarUDFImpl for ArrayRepeat { _ => return exec_err!("count must be an integer type"), }; - Ok(vec![first, second]) + Ok(vec![first_type.clone(), second]) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index 8213685f71b6..53f43de4108d 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -25,7 +25,7 @@ use arrow::datatypes::{DataType, Field}; use arrow::buffer::OffsetBuffer; use datafusion_common::cast::as_int64_array; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -380,42 +380,36 @@ fn general_replace( } pub(crate) fn array_replace_inner(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!("array_replace expects three arguments"); - } + let [array, from, to] = take_function_args("array_replace", args)?; // replace at most one occurrence for each element - let arr_n = vec![1; args[0].len()]; - let array = &args[0]; + let arr_n = vec![1; array.len()]; match array.data_type() { DataType::List(_) => { let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) + general_replace::(list_array, from, to, arr_n) } DataType::LargeList(_) => { let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) + general_replace::(list_array, from, to, arr_n) } array_type => exec_err!("array_replace does not support type '{array_type:?}'."), } } pub(crate) fn array_replace_n_inner(args: &[ArrayRef]) -> Result { - if args.len() != 4 { - return exec_err!("array_replace_n expects four arguments"); - } + let [array, from, to, max] = take_function_args("array_replace_n", args)?; // replace the specified number of occurrences - let arr_n = as_int64_array(&args[3])?.values().to_vec(); - let array = &args[0]; + let arr_n = as_int64_array(max)?.values().to_vec(); match array.data_type() { DataType::List(_) => { let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) + general_replace::(list_array, from, to, arr_n) } DataType::LargeList(_) => { let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) + general_replace::(list_array, from, to, arr_n) } array_type => { exec_err!("array_replace_n does not support type '{array_type:?}'.") @@ -424,21 +418,18 @@ pub(crate) fn array_replace_n_inner(args: &[ArrayRef]) -> Result { } pub(crate) fn array_replace_all_inner(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!("array_replace_all expects three arguments"); - } + let [array, from, to] = take_function_args("array_replace_all", args)?; // replace all occurrences (up to "i64::MAX") - let arr_n = vec![i64::MAX; args[0].len()]; - let array = &args[0]; + let arr_n = vec![i64::MAX; array.len()]; match array.data_type() { DataType::List(_) => { let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) + general_replace::(list_array, from, to, arr_n) } DataType::LargeList(_) => { let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) + general_replace::(list_array, from, to, arr_n) } array_type => { exec_err!("array_replace_all does not support type '{array_type:?}'.") diff --git a/datafusion/functions-nested/src/reverse.rs b/datafusion/functions-nested/src/reverse.rs index d2969cff74ce..b7c4274ca436 100644 --- a/datafusion/functions-nested/src/reverse.rs +++ b/datafusion/functions-nested/src/reverse.rs @@ -25,7 +25,7 @@ use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType::{LargeList, List, Null}; use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -115,20 +115,18 @@ impl ScalarUDFImpl for ArrayReverse { /// array_reverse SQL function pub fn array_reverse_inner(arg: &[ArrayRef]) -> Result { - if arg.len() != 1 { - return exec_err!("array_reverse needs one argument"); - } + let [input_array] = take_function_args("array_reverse", arg)?; - match &arg[0].data_type() { + match &input_array.data_type() { List(field) => { - let array = as_list_array(&arg[0])?; + let array = as_list_array(input_array)?; general_array_reverse::(array, field) } LargeList(field) => { - let array = as_large_list_array(&arg[0])?; + let array = as_large_list_array(input_array)?; general_array_reverse::(array, field) } - Null => Ok(Arc::clone(&arg[0])), + Null => Ok(Arc::clone(input_array)), array_type => exec_err!("array_reverse does not support type '{array_type:?}'."), } } diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index afb877d4cf9a..97ccc035a046 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -26,7 +26,7 @@ use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null}; use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::row::{RowConverter, SortField}; use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_common::{exec_err, internal_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -308,23 +308,21 @@ impl ScalarUDFImpl for ArrayDistinct { /// array_distinct SQL function /// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4] fn array_distinct_inner(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_distinct needs one argument"); - } + let [input_array] = take_function_args("array_distinct", args)?; // handle null - if args[0].data_type() == &Null { - return Ok(Arc::clone(&args[0])); + if input_array.data_type() == &Null { + return Ok(Arc::clone(input_array)); } // handle for list & largelist - match args[0].data_type() { + match input_array.data_type() { List(field) => { - let array = as_list_array(&args[0])?; + let array = as_list_array(&input_array)?; general_array_distinct(array, field) } LargeList(field) => { - let array = as_large_list_array(&args[0])?; + let array = as_large_list_array(&input_array)?; general_array_distinct(array, field) } array_type => exec_err!("array_distinct does not support type '{array_type:?}'"), @@ -488,24 +486,13 @@ fn general_set_op( /// Array_union SQL function fn array_union_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_union needs two arguments"); - } - let array1 = &args[0]; - let array2 = &args[1]; - + let [array1, array2] = take_function_args("array_union", args)?; general_set_op(array1, array2, SetOp::Union) } /// array_intersect SQL function fn array_intersect_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_intersect needs two arguments"); - } - - let array1 = &args[0]; - let array2 = &args[1]; - + let [array1, array2] = take_function_args("array_intersect", args)?; general_set_op(array1, array2, SetOp::Intersect) } diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 0f9f11b4eff0..1ba5197fe2fb 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -22,10 +22,11 @@ use arrow::error::ArrowError; use datafusion_common::{ arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue, }; -use datafusion_common::{exec_datafusion_err, DataFusionError}; +use datafusion_common::{ + exec_datafusion_err, utils::take_function_args, DataFusionError, +}; use std::any::Any; -use crate::utils::take_function_args; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, @@ -178,13 +179,12 @@ impl ScalarUDFImpl for ArrowCastFunc { /// Returns the requested type from the arguments fn data_type_from_args(args: &[Expr]) -> Result { - if args.len() != 2 { - return exec_err!("arrow_cast needs 2 arguments, {} provided", args.len()); - } - let Expr::Literal(ScalarValue::Utf8(Some(val))) = &args[1] else { + let [_, type_arg] = take_function_args("arrow_cast", args)?; + + let Expr::Literal(ScalarValue::Utf8(Some(val))) = type_arg else { return exec_err!( "arrow_cast requires its second argument to be a constant string, got {:?}", - &args[1] + type_arg ); }; diff --git a/datafusion/functions/src/core/arrowtypeof.rs b/datafusion/functions/src/core/arrowtypeof.rs index 3c672384ffa1..653ca6569896 100644 --- a/datafusion/functions/src/core/arrowtypeof.rs +++ b/datafusion/functions/src/core/arrowtypeof.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::take_function_args; use arrow::datatypes::DataType; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{utils::take_function_args, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 8533b3123d51..d971001dbf78 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::take_function_args; use arrow::array::{ make_array, Array, Capacities, MutableArrayData, Scalar, StringArray, }; use arrow::datatypes::DataType; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, Result, ScalarValue, + exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result, + ScalarValue, }; use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; @@ -115,14 +115,9 @@ impl ScalarUDFImpl for GetFieldFunc { } fn schema_name(&self, args: &[Expr]) -> Result { - if args.len() != 2 { - return exec_err!( - "get_field function requires 2 arguments, got {}", - args.len() - ); - } + let [base, field_name] = take_function_args(self.name(), args)?; - let name = match &args[1] { + let name = match field_name { Expr::Literal(name) => name, _ => { return exec_err!( @@ -131,7 +126,7 @@ impl ScalarUDFImpl for GetFieldFunc { } }; - Ok(format!("{}[{}]", args[0].schema_name(), name)) + Ok(format!("{}[{}]", base.schema_name(), name)) } fn signature(&self) -> &Signature { @@ -180,21 +175,17 @@ impl ScalarUDFImpl for GetFieldFunc { args: &[ColumnarValue], _number_rows: usize, ) -> Result { - if args.len() != 2 { - return exec_err!( - "get_field function requires 2 arguments, got {}", - args.len() - ); - } + let [base, field_name] = take_function_args(self.name(), args)?; - if args[0].data_type().is_null() { + if base.data_type().is_null() { return Ok(ColumnarValue::Scalar(ScalarValue::Null)); } - let arrays = ColumnarValue::values_to_arrays(args)?; + let arrays = + ColumnarValue::values_to_arrays(&[base.clone(), field_name.clone()])?; let array = Arc::clone(&arrays[0]); - let name = match &args[1] { + let name = match field_name { ColumnarValue::Scalar(name) => name, _ => { return exec_err!( diff --git a/datafusion/functions/src/core/nullif.rs b/datafusion/functions/src/core/nullif.rs index a0f3c8b8a452..14366767523f 100644 --- a/datafusion/functions/src/core/nullif.rs +++ b/datafusion/functions/src/core/nullif.rs @@ -16,13 +16,11 @@ // under the License. use arrow::datatypes::DataType; -use datafusion_common::Result; use datafusion_expr::{ColumnarValue, Documentation}; -use crate::utils::take_function_args; use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::nullif::nullif; -use datafusion_common::ScalarValue; +use datafusion_common::{utils::take_function_args, Result, ScalarValue}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs index 5b306c8093cb..1e261a9bc055 100644 --- a/datafusion/functions/src/core/nvl.rs +++ b/datafusion/functions/src/core/nvl.rs @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::take_function_args; use arrow::array::Array; use arrow::compute::is_not_null; use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; -use datafusion_common::Result; +use datafusion_common::{utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; diff --git a/datafusion/functions/src/core/nvl2.rs b/datafusion/functions/src/core/nvl2.rs index b1f8e4e5c213..71188441043a 100644 --- a/datafusion/functions/src/core/nvl2.rs +++ b/datafusion/functions/src/core/nvl2.rs @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::take_function_args; use arrow::array::Array; use arrow::compute::is_not_null; use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{internal_err, utils::take_function_args, Result}; use datafusion_expr::{ type_coercion::binary::comparison_coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 139763af7b38..5fa8347c8787 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -17,9 +17,8 @@ //! [`VersionFunc`]: Implementation of the `version` function. -use crate::utils::take_function_args; use arrow::datatypes::DataType; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{utils::take_function_args, Result, ScalarValue}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; diff --git a/datafusion/functions/src/crypto/basic.rs b/datafusion/functions/src/crypto/basic.rs index a15b9b57cff6..191154b8f8ff 100644 --- a/datafusion/functions/src/crypto/basic.rs +++ b/datafusion/functions/src/crypto/basic.rs @@ -24,12 +24,10 @@ use blake2::{Blake2b512, Blake2s256, Digest}; use blake3::Hasher as Blake3; use datafusion_common::cast::as_binary_array; -use crate::utils::take_function_args; use arrow::compute::StringArrayType; -use datafusion_common::plan_err; use datafusion_common::{ - cast::as_generic_binary_array, exec_err, internal_err, DataFusionError, Result, - ScalarValue, + cast::as_generic_binary_array, exec_err, internal_err, plan_err, + utils::take_function_args, DataFusionError, Result, ScalarValue, }; use datafusion_expr::ColumnarValue; use md5::Md5; diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index c7dbf089e530..9df91da67f39 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -28,8 +28,6 @@ use arrow::datatypes::DataType::{ use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{DataType, TimeUnit}; -use crate::utils::take_function_args; -use datafusion_common::not_impl_err; use datafusion_common::{ cast::{ as_date32_array, as_date64_array, as_int32_array, as_time32_millisecond_array, @@ -37,8 +35,9 @@ use datafusion_common::{ as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, as_timestamp_second_array, }, - exec_err, internal_err, + exec_err, internal_err, not_impl_err, types::logical_string, + utils::take_function_args, Result, ScalarValue, }; use datafusion_expr::{ @@ -167,10 +166,7 @@ impl ScalarUDFImpl for DatePartFunc { args: &[ColumnarValue], _number_rows: usize, ) -> Result { - if args.len() != 2 { - return exec_err!("Expected two arguments in DATE_PART"); - } - let (part, array) = (&args[0], &args[1]); + let [part, array] = take_function_args(self.name(), args)?; let part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = part { v diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 2d4db56cc788..f081dfd11ecf 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -26,8 +26,7 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Date32, Int32, Int64, UInt32, UInt64, Utf8, Utf8View}; use chrono::prelude::*; -use crate::utils::take_function_args; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{exec_err, utils::take_function_args, Result, ScalarValue}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 485fdc7a3384..b049ca01ac97 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -28,8 +28,7 @@ use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::error::ArrowError; use arrow::util::display::{ArrayFormatter, DurationFormat, FormatOptions}; -use crate::utils::take_function_args; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{exec_err, utils::take_function_args, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index b350819a55ec..0e235735e29f 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -27,10 +27,12 @@ use arrow::datatypes::{ ArrowTimestampType, DataType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; - use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc}; + use datafusion_common::cast::as_primitive_array; -use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + exec_err, plan_err, utils::take_function_args, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -113,14 +115,8 @@ impl ToLocalTimeFunc { } fn to_local_time(&self, args: &[ColumnarValue]) -> Result { - if args.len() != 1 { - return exec_err!( - "to_local_time function requires 1 argument, got {}", - args.len() - ); - } + let [time_value] = take_function_args(self.name(), args)?; - let time_value = &args[0]; let arg_type = time_value.data_type(); match arg_type { Timestamp(_, None) => { @@ -360,17 +356,12 @@ impl ScalarUDFImpl for ToLocalTimeFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 1 { - return exec_err!( - "to_local_time function requires 1 argument, got {:?}", - arg_types.len() - ); - } + let [time_value] = take_function_args(self.name(), arg_types)?; - match &arg_types[0] { + match time_value { Timestamp(timeunit, _) => Ok(Timestamp(*timeunit, None)), _ => exec_err!( - "The to_local_time function can only accept timestamp as the arg, got {:?}", arg_types[0] + "The to_local_time function can only accept timestamp as the arg, got {:?}", time_value ) } } @@ -380,14 +371,9 @@ impl ScalarUDFImpl for ToLocalTimeFunc { args: &[ColumnarValue], _number_rows: usize, ) -> Result { - if args.len() != 1 { - return exec_err!( - "to_local_time function requires 1 argument, got {:?}", - args.len() - ); - } + let [time_value] = take_function_args(self.name(), args)?; - self.to_local_time(args) + self.to_local_time(&[time_value.clone()]) } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index a5338ff76592..68a6d1006052 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -28,6 +28,7 @@ use base64::{engine::general_purpose, Engine as _}; use datafusion_common::{ cast::{as_generic_binary_array, as_generic_string_array}, not_impl_err, plan_err, + utils::take_function_args, }; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; @@ -111,19 +112,13 @@ impl ScalarUDFImpl for EncodeFunc { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 { - return plan_err!( - "{} expects to get 2 arguments, but got {}", - self.name(), - arg_types.len() - ); - } + let [expression, format] = take_function_args(self.name(), arg_types)?; - if arg_types[1] != DataType::Utf8 { + if format != &DataType::Utf8 { return Err(DataFusionError::Plan("2nd argument should be Utf8".into())); } - match arg_types[0] { + match expression { DataType::Utf8 | DataType::Utf8View | DataType::Null => { Ok(vec![DataType::Utf8; 2]) } @@ -539,13 +534,9 @@ impl FromStr for Encoding { /// Second argument is the encoding to use. /// Standard encodings are base64 and hex. fn encode(args: &[ColumnarValue]) -> Result { - if args.len() != 2 { - return exec_err!( - "{:?} args were supplied but encode takes exactly two arguments", - args.len() - ); - } - let encoding = match &args[1] { + let [expression, format] = take_function_args("encode", args)?; + + let encoding = match format { ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { Some(Some(method)) => method.parse::(), _ => not_impl_err!( @@ -556,20 +547,16 @@ fn encode(args: &[ColumnarValue]) -> Result { "Second argument to encode must be a constant: Encode using dynamically decided method is not yet supported" ), }?; - encode_process(&args[0], encoding) + encode_process(expression, encoding) } /// Decodes the given data, accepts Binary, LargeBinary, Utf8, Utf8View or LargeUtf8 and returns a [`ColumnarValue`]. /// Second argument is the encoding to use. /// Standard encodings are base64 and hex. fn decode(args: &[ColumnarValue]) -> Result { - if args.len() != 2 { - return exec_err!( - "{:?} args were supplied but decode takes exactly two arguments", - args.len() - ); - } - let encoding = match &args[1] { + let [expression, format] = take_function_args("decode", args)?; + + let encoding = match format { ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { Some(Some(method))=> method.parse::(), _ => not_impl_err!( @@ -580,5 +567,5 @@ fn decode(args: &[ColumnarValue]) -> Result { "Second argument to decode must be a utf8 constant: Decode using dynamically decided method is not yet supported" ), }?; - decode_process(&args[0], encoding) + decode_process(expression, encoding) } diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index a375af2ad29e..ff6a82113262 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -20,14 +20,15 @@ use std::any::Any; use std::sync::Arc; -use crate::utils::take_function_args; use arrow::array::{ ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, }; use arrow::datatypes::DataType; use arrow::error::ArrowError; -use datafusion_common::{internal_datafusion_err, not_impl_err, Result}; +use datafusion_common::{ + internal_datafusion_err, not_impl_err, utils::take_function_args, Result, +}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 623fb2ba03f0..f7e9fce960fe 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -20,7 +20,7 @@ use arrow::datatypes::DataType; use std::any::Any; use crate::utils::utf8_to_int_type; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{utils::take_function_args, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; @@ -82,14 +82,9 @@ impl ScalarUDFImpl for BitLengthFunc { args: &[ColumnarValue], _number_rows: usize, ) -> Result { - if args.len() != 1 { - return exec_err!( - "bit_length function requires 1 argument, got {}", - args.len() - ); - } + let [array] = take_function_args(self.name(), args)?; - match &args[0] { + match array { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs index 57392c114d79..c2e5dc52f82f 100644 --- a/datafusion/functions/src/string/levenshtein.rs +++ b/datafusion/functions/src/string/levenshtein.rs @@ -24,7 +24,7 @@ use arrow::datatypes::DataType; use crate::utils::{make_scalar_function, utf8_to_int_type}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::utils::datafusion_strsim; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -110,17 +110,12 @@ impl ScalarUDFImpl for LevenshteinFunc { ///Returns the Levenshtein distance between the two given strings. /// LEVENSHTEIN('kitten', 'sitting') = 3 pub fn levenshtein(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!( - "levenshtein function requires two arguments, got {}", - args.len() - ); - } + let [str1, str2] = take_function_args("levenshtein", args)?; - match args[0].data_type() { + match str1.data_type() { DataType::Utf8View => { - let str1_array = as_string_view_array(&args[0])?; - let str2_array = as_string_view_array(&args[1])?; + let str1_array = as_string_view_array(&str1)?; + let str2_array = as_string_view_array(&str2)?; let result = str1_array .iter() .zip(str2_array.iter()) @@ -134,8 +129,8 @@ pub fn levenshtein(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } DataType::Utf8 => { - let str1_array = as_generic_string_array::(&args[0])?; - let str2_array = as_generic_string_array::(&args[1])?; + let str1_array = as_generic_string_array::(&str1)?; + let str2_array = as_generic_string_array::(&str2)?; let result = str1_array .iter() .zip(str2_array.iter()) @@ -149,8 +144,8 @@ pub fn levenshtein(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } DataType::LargeUtf8 => { - let str1_array = as_generic_string_array::(&args[0])?; - let str2_array = as_generic_string_array::(&args[1])?; + let str1_array = as_generic_string_array::(&str1)?; + let str2_array = as_generic_string_array::(&str2)?; let result = str1_array .iter() .zip(str2_array.iter()) diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index f443571112e7..7e0187c0b1be 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -20,7 +20,7 @@ use arrow::datatypes::DataType; use std::any::Any; use crate::utils::utf8_to_int_type; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{utils::take_function_args, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; @@ -82,14 +82,9 @@ impl ScalarUDFImpl for OctetLengthFunc { args: &[ColumnarValue], _number_rows: usize, ) -> Result { - if args.len() != 1 { - return exec_err!( - "octet_length function requires 1 argument, got {}", - args.len() - ); - } + let [array] = take_function_args(self.name(), args)?; - match &args[0] { + match array { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index 12f213a827cf..c4a9f067e9f4 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -25,7 +25,9 @@ use arrow::array::{ use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; use crate::utils::utf8_to_int_type; -use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_common::{ + exec_err, internal_err, utils::take_function_args, Result, ScalarValue, +}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -96,17 +98,9 @@ impl ScalarUDFImpl for FindInSetFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let ScalarFunctionArgs { mut args, .. } = args; - - if args.len() != 2 { - return exec_err!( - "find_in_set was called with {} arguments. It requires 2.", - args.len() - ); - } + let ScalarFunctionArgs { args, .. } = args; - let str_list = args.pop().unwrap(); - let string = args.pop().unwrap(); + let [string, str_list] = take_function_args(self.name(), args)?; match (string, str_list) { // both inputs are scalars diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index 60ccd2204788..20ad33b3cfe3 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -25,7 +25,7 @@ use arrow::array::{ use arrow::datatypes::{DataType, Int32Type, Int64Type}; use crate::utils::{make_scalar_function, utf8_to_str_type}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -131,18 +131,13 @@ impl ScalarUDFImpl for SubstrIndexFunc { /// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org /// SUBSTRING_INDEX('www.apache.org', '.', -1) = org fn substr_index(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!( - "substr_index was called with {} arguments. It requires 3.", - args.len() - ); - } + let [str, delim, count] = take_function_args("substr_index", args)?; - match args[0].data_type() { + match str.data_type() { DataType::Utf8 => { - let string_array = args[0].as_string::(); - let delimiter_array = args[1].as_string::(); - let count_array: &PrimitiveArray = args[2].as_primitive(); + let string_array = str.as_string::(); + let delimiter_array = delim.as_string::(); + let count_array: &PrimitiveArray = count.as_primitive(); substr_index_general::( string_array, delimiter_array, @@ -150,9 +145,9 @@ fn substr_index(args: &[ArrayRef]) -> Result { ) } DataType::LargeUtf8 => { - let string_array = args[0].as_string::(); - let delimiter_array = args[1].as_string::(); - let count_array: &PrimitiveArray = args[2].as_primitive(); + let string_array = str.as_string::(); + let delimiter_array = delim.as_string::(); + let count_array: &PrimitiveArray = count.as_primitive(); substr_index_general::( string_array, delimiter_array, @@ -160,9 +155,9 @@ fn substr_index(args: &[ArrayRef]) -> Result { ) } DataType::Utf8View => { - let string_array = args[0].as_string_view(); - let delimiter_array = args[1].as_string_view(); - let count_array: &PrimitiveArray = args[2].as_primitive(); + let string_array = str.as_string_view(); + let delimiter_array = delim.as_string_view(); + let count_array: &PrimitiveArray = count.as_primitive(); substr_index_general::( string_array, delimiter_array, diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 966fd8209a04..39d8aeeda460 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -18,51 +18,10 @@ use arrow::array::ArrayRef; use arrow::datatypes::DataType; -use datafusion_common::{exec_datafusion_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::function::Hint; use datafusion_expr::ColumnarValue; -/// Converts a collection of function arguments into an fixed-size array of length N -/// producing a reasonable error message in case of unexpected number of arguments. -/// -/// # Example -/// ``` -/// # use datafusion_common::ScalarValue; -/// # use datafusion_common::Result; -/// # use datafusion_expr_common::columnar_value::ColumnarValue; -/// # use datafusion_functions::utils::take_function_args; -/// fn my_function(args: &[ColumnarValue]) -> Result<()> { -/// // function expects 2 args, so create a 2-element array -/// let [arg1, arg2] = take_function_args("my_function", args)?; -/// // ... do stuff.. -/// Ok(()) -/// } -/// -/// // Calling the function with 1 argument produces an error: -/// let ten = ColumnarValue::from(ScalarValue::from(10i32)); -/// let twenty = ColumnarValue::from(ScalarValue::from(20i32)); -/// let args = vec![ten.clone()]; -/// let err = my_function(&args).unwrap_err(); -/// assert_eq!(err.to_string(), "Execution error: my_function function requires 2 arguments, got 1"); -/// // Calling the function with 2 arguments works great -/// let args = vec![ten, twenty]; -/// my_function(&args).unwrap(); -/// ``` -pub fn take_function_args( - function_name: &str, - args: impl IntoIterator, -) -> Result<[T; N]> { - let args = args.into_iter().collect::>(); - args.try_into().map_err(|v: Vec| { - exec_datafusion_err!( - "{} function requires {} arguments, got {}", - function_name, - N, - v.len() - ) - }) -} - /// Creates a function to identify the optimal return type of a string function given /// the type of its first argument. /// diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 29ef506aa070..71296b6f6474 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -188,7 +188,7 @@ SELECT MAP([[1,2], [3,4]], ['a', 'b']); query error SELECT MAP() -query error DataFusion error: Execution error: map requires exactly 2 arguments, got 1 instead +query error DataFusion error: Execution error: map function requires 2 arguments, got 1 SELECT MAP(['POST', 'HEAD']) query error DataFusion error: Execution error: Expected list, large_list or fixed_size_list, got Null