diff --git a/Cargo.toml b/Cargo.toml index 12993eda0f72a..9eba7801a5889 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,7 +48,9 @@ rust-version = "1.64" arrow = { version = "40.0.0", features = ["prettyprint"] } arrow-flight = { version = "40.0.0", features = ["flight-sql-experimental"] } arrow-buffer = { version = "40.0.0", default-features = false } +arrow-ord = { version = "40.0.0", default-features = false } arrow-schema = { version = "40.0.0", default-features = false } +arrow-select = { version = "40.0.0", default-features = false } arrow-array = { version = "40.0.0", default-features = false, features = ["chrono-tz"] } parquet = { version = "40.0.0", features = ["arrow", "async", "object_store"] } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index db1ead15e855a..c54231091f1c2 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1087,7 +1087,9 @@ dependencies = [ "arrow", "arrow-array", "arrow-buffer", + "arrow-ord", "arrow-schema", + "arrow-select", "blake2", "blake3", "chrono", diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 6783670545c3c..03bc11b1c05fb 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -200,6 +200,41 @@ async fn binary_bitwise_shift() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_comparison_func_expressions() -> Result<()> { + test_expression!("greatest(1,2,3)", "3"); + test_expression!("least(1,2,3)", "1"); + + Ok(()) +} + +#[tokio::test] +async fn test_comparison_func_array_scalar_expression() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int64Array::from(vec![1, 2, 3]))], + )?; + let table = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t1", Arc::new(table))?; + let sql = "SELECT greatest(a, 2), least(a, 2) from t1"; + let actual = execute_to_batches(&ctx, sql).await; + assert_batches_eq!( + &[ + "+-------------------------+----------------------+", + "| greatest(t1.a,Int64(2)) | least(t1.a,Int64(2)) |", + "+-------------------------+----------------------+", + "| 2 | 1 |", + "| 2 | 2 |", + "| 3 | 2 |", + "+-------------------------+----------------------+", + ], + &actual + ); + Ok(()) +} + #[tokio::test] async fn test_interval_expressions() -> Result<()> { // day nano intervals diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 3911939b4ca6e..d4ca93ba24e5c 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -205,6 +205,10 @@ pub enum BuiltinScalarFunction { Struct, /// arrow_typeof ArrowTypeof, + /// greatest + Greatest, + /// least + Least, } lazy_static! { @@ -328,6 +332,8 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Struct => Volatility::Immutable, BuiltinScalarFunction::FromUnixtime => Volatility::Immutable, BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable, + BuiltinScalarFunction::Greatest => Volatility::Immutable, + BuiltinScalarFunction::Least => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -414,6 +420,10 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Upper => &["upper"], BuiltinScalarFunction::Uuid => &["uuid"], + // comparison functions + BuiltinScalarFunction::Greatest => &["greatest"], + BuiltinScalarFunction::Least => &["least"], + // regex functions BuiltinScalarFunction::RegexpMatch => &["regexp_match"], BuiltinScalarFunction::RegexpReplace => &["regexp_replace"], diff --git a/datafusion/expr/src/comparison_expressions.rs b/datafusion/expr/src/comparison_expressions.rs new file mode 100644 index 0000000000000..c7f13f04f08dc --- /dev/null +++ b/datafusion/expr/src/comparison_expressions.rs @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; + +/// Currently supported types by the comparison function. +pub static SUPPORTED_COMPARISON_TYPES: &[DataType] = &[ + DataType::Boolean, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + DataType::Utf8, + DataType::LargeUtf8, +]; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 6b0a09baf945e..4f236f1720589 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -237,6 +237,22 @@ pub fn concat_ws(sep: Expr, values: Vec) -> Expr { )) } +/// Returns the greatest value of all arguments. +pub fn greatest(args: &[Expr]) -> Expr { + Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::Greatest, + args.to_vec(), + )) +} + +/// Returns the least value of all arguments. +pub fn least(args: &[Expr]) -> Expr { + Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::Least, + args.to_vec(), + )) +} + /// Returns an approximate value of π pub fn pi() -> Expr { Expr::ScalarFunction(ScalarFunction::new(BuiltinScalarFunction::Pi, vec![])) @@ -620,9 +636,15 @@ nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evalua nary_scalar_expr!( ConcatWithSeparator, concat_ws_expr, - "concatenates several strings, placing a seperator between each one" + "concatenates several strings, placing a separator between each one" ); nary_scalar_expr!(Concat, concat_expr, "concatenates several strings"); +nary_scalar_expr!( + Greatest, + greatest_expr, + "gets the largest value of the list" +); +nary_scalar_expr!(Least, least_expr, "gets the smallest value of the list"); // date functions scalar_expr!(DatePart, date_part, part date, "extracts a subfield from the date"); diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 5ba6852248572..70fe0c0e55dfd 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -22,8 +22,8 @@ use crate::nullif::SUPPORTED_NULLIF_TYPES; use crate::type_coercion::functions::data_types; use crate::ColumnarValue; use crate::{ - array_expressions, conditional_expressions, struct_expressions, Accumulator, - BuiltinScalarFunction, Signature, TypeSignature, + array_expressions, comparison_expressions, conditional_expressions, + struct_expressions, Accumulator, BuiltinScalarFunction, Signature, TypeSignature, }; use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; use datafusion_common::{DataFusionError, Result}; @@ -168,6 +168,11 @@ pub fn return_type( let coerced_types = data_types(input_expr_types, &signature(fun)); coerced_types.map(|typs| typs[0].clone()) } + BuiltinScalarFunction::Greatest | BuiltinScalarFunction::Least => { + // GREATEST and LEAST have multiple args and they might get coerced, get a preview of this + let coerced_types = data_types(input_expr_types, &signature(fun)); + coerced_types.map(|typs| typs[0].clone()) + } BuiltinScalarFunction::OctetLength => { utf8_to_int_type(&input_expr_types[0], "octet_length") } @@ -376,6 +381,12 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::Chr | BuiltinScalarFunction::ToHex => { Signature::uniform(1, vec![DataType::Int64], fun.volatility()) } + BuiltinScalarFunction::Greatest | BuiltinScalarFunction::Least => { + Signature::variadic_equal( + comparison_expressions::SUPPORTED_COMPARISON_TYPES.to_vec(), + fun.volatility(), + ) + } BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => Signature::one_of( vec![ TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]), diff --git a/datafusion/expr/src/function_err.rs b/datafusion/expr/src/function_err.rs index 39ac4ef8039a7..2d804868ec662 100644 --- a/datafusion/expr/src/function_err.rs +++ b/datafusion/expr/src/function_err.rs @@ -53,7 +53,7 @@ impl TypeSignature { .collect::>() .join(", ")] } - TypeSignature::VariadicEqual => vec!["T, .., T".to_string()], + TypeSignature::VariadicEqual(_) => vec!["T, .., T".to_string()], TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], TypeSignature::OneOf(sigs) => { sigs.iter().flat_map(|s| s.to_string_repr()).collect() diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 9f3841841b4c8..97ff84fcafac9 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -30,6 +30,7 @@ pub mod aggregate_function; pub mod array_expressions; mod built_in_function; mod columnar_value; +pub mod comparison_expressions; pub mod conditional_expressions; pub mod expr; pub mod expr_fn; diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index a2caba4fb8bbd..d0ca994456b71 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -42,10 +42,10 @@ pub enum TypeSignature { /// arbitrary number of arguments of an common type out of a list of valid types // A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` Variadic(Vec), - /// arbitrary number of arguments of an arbitrary but equal type + /// arbitrary number of arguments of an equal type // A function such as `array` is `VariadicEqual` // The first argument decides the type used for coercion - VariadicEqual, + VariadicEqual(Vec), /// arbitrary number of arguments with arbitrary types VariadicAny, /// fixed number of arguments of an arbitrary but equal type out of a list of valid types @@ -85,10 +85,11 @@ impl Signature { volatility, } } - /// variadic_equal - Creates a variadic signature that represents an arbitrary number of arguments of the same type. - pub fn variadic_equal(volatility: Volatility) -> Self { + /// variadic_equal - Creates a variadic signature that represents an arbitrary number of arguments of the same type in + /// the allowed_types. + pub fn variadic_equal(allowed_types: Vec, volatility: Volatility) -> Self { Self { - type_signature: TypeSignature::VariadicEqual, + type_signature: TypeSignature::VariadicEqual(allowed_types), volatility, } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index d86914325fc98..f4480a1e754cc 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -71,11 +71,15 @@ fn get_valid_types( .iter() .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) .collect(), - TypeSignature::VariadicEqual => { - // one entry with the same len as current_types, whose type is `current_types[0]`. + TypeSignature::VariadicEqual(allowed_types) => { + if allowed_types.is_empty() { + return Err(DataFusionError::Plan( + "allowed types cannot be empty".to_string(), + )); + } vec![current_types .iter() - .map(|_| current_types[0].clone()) + .map(|_| allowed_types[0].clone()) .collect()] } TypeSignature::VariadicAny => { diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index bf0b0a9f19c74..b9d7b36bb373d 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -46,7 +46,9 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } +arrow-ord = { workspace = true } arrow-schema = { workspace = true } +arrow-select = { workspace = true } blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } chrono = { version = "0.4.23", default-features = false } diff --git a/datafusion/physical-expr/src/comparison_expressions.rs b/datafusion/physical-expr/src/comparison_expressions.rs new file mode 100644 index 0000000000000..895fa8016f1e4 --- /dev/null +++ b/datafusion/physical-expr/src/comparison_expressions.rs @@ -0,0 +1,268 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Some of these functions reference the Postgres documentation +// or implementation to ensure compatibility and are subject to +// the Postgres license. + +//! Comparison expressions + +use arrow::array::Array; +use arrow::datatypes::DataType; +use arrow_ord::comparison::{gt_dyn, lt_dyn}; +use arrow_select::zip::zip; +use datafusion_common::scalar::ScalarValue; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::ColumnarValue; + +#[derive(Debug, Clone, PartialEq)] +enum ComparisonOperator { + Greatest, + Least, +} + +macro_rules! compare_scalar_typed { + ($op:expr, $args:expr, $data_type:ident) => {{ + let value = $args + .iter() + .filter_map(|scalar| match scalar { + ScalarValue::$data_type(v) => v.clone(), + _ => panic!("Impossibly got non-scalar values"), + }) + .reduce(|a, b| match $op { + ComparisonOperator::Greatest => a.max(b), + ComparisonOperator::Least => a.min(b), + }); + ScalarValue::$data_type(value) + }}; +} + +/// Evaluate a greatest or least function for the case when all arguments are scalars +fn compare_scalars( + data_type: DataType, + op: ComparisonOperator, + args: &[ScalarValue], +) -> ScalarValue { + match data_type { + DataType::Boolean => compare_scalar_typed!(op, args, Boolean), + DataType::Int8 => compare_scalar_typed!(op, args, Int8), + DataType::Int16 => compare_scalar_typed!(op, args, Int16), + DataType::Int32 => compare_scalar_typed!(op, args, Int32), + DataType::Int64 => compare_scalar_typed!(op, args, Int64), + DataType::UInt8 => compare_scalar_typed!(op, args, UInt8), + DataType::UInt16 => compare_scalar_typed!(op, args, UInt16), + DataType::UInt32 => compare_scalar_typed!(op, args, UInt32), + DataType::UInt64 => compare_scalar_typed!(op, args, UInt64), + DataType::Float32 => compare_scalar_typed!(op, args, Float32), + DataType::Float64 => compare_scalar_typed!(op, args, Float64), + DataType::Utf8 => compare_scalar_typed!(op, args, Utf8), + DataType::LargeUtf8 => compare_scalar_typed!(op, args, LargeUtf8), + _ => panic!("Unsupported data type for comparison: {:?}", data_type), + } +} + +/// Evaluate a greatest or least function +fn compare(op: ComparisonOperator, args: &[ColumnarValue]) -> Result { + if args.is_empty() { + return Err(DataFusionError::Internal(format!( + "{:?} expressions require at least one argument", + op + ))); + } else if args.len() == 1 { + return Ok(args[0].clone()); + } + + let args_types = args + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => array.data_type().clone(), + ColumnarValue::Scalar(scalar) => scalar.get_datatype(), + }) + .collect::>(); + + if args_types.iter().any(|t| t != &args_types[0]) { + return Err(DataFusionError::Internal(format!( + "{:?} expressions require all arguments to be of the same type", + op + ))); + } + + let mut arg_lengths = args + .iter() + .filter_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + ColumnarValue::Scalar(_) => None, + }) + .collect::>(); + arg_lengths.dedup(); + + if arg_lengths.len() > 1 { + return Err(DataFusionError::Internal(format!( + "{:?} expressions require all arguments to be of the same length", + op + ))); + } + + // scalars have no lengths, so if there are no lengths, all arguments are scalars + let all_scalars = arg_lengths.is_empty(); + + if all_scalars { + let args: Vec<_> = args + .iter() + .map(|arg| match arg { + ColumnarValue::Array(_) => { + panic!("Internal error: all arguments should be scalars") + } + ColumnarValue::Scalar(scalar) => scalar.clone(), + }) + .collect(); + Ok(ColumnarValue::Scalar(compare_scalars( + args_types[0].clone(), + op, + &args, + ))) + } else { + let cmp = match op { + ComparisonOperator::Greatest => gt_dyn, + ComparisonOperator::Least => lt_dyn, + }; + let length = arg_lengths[0]; + let first_arg = match &args[0] { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(length), + }; + args[1..] + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(length), + }) + // we cannot use try_reduce as it is still nightly + .try_fold(first_arg, |a, b| { + // mask will be true if cmp holds for a to be otherwise false + let mask = cmp(&a, &b)?; + // then the zip can pluck values accordingly from a and b + let value = zip(&mask, &a, &b)?; + Ok(value) + }) + .map(ColumnarValue::Array) + } +} + +pub fn greatest(args: &[ColumnarValue]) -> Result { + compare(ComparisonOperator::Greatest, args) +} + +pub fn least(args: &[ColumnarValue]) -> Result { + compare(ComparisonOperator::Least, args) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::Int32Array; + use datafusion_expr::ColumnarValue; + use std::sync::Arc; + + #[test] + fn test_compare_scalars() { + let args = vec![ + ScalarValue::Int32(Some(1)), + ScalarValue::Int32(Some(3)), + ScalarValue::Int32(Some(2)), + ]; + let result = + compare_scalars(DataType::Int32, ComparisonOperator::Greatest, &args); + assert_eq!(result, ScalarValue::Int32(Some(3))); + } + + #[test] + #[should_panic] + fn test_compare_scalars_unsupported_types() { + let args = vec![ + ScalarValue::Int32(Some(1)), + ScalarValue::Utf8(Some("foo".to_string())), + ]; + let _ = compare_scalars(DataType::Int32, ComparisonOperator::Greatest, &args); + } + + #[test] + fn test_compare_i32_arrays() { + let args = vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])), + Arc::new(Int32Array::from(vec![Some(3), Some(2), Some(1)])), + Arc::new(Int32Array::from(vec![Some(2), Some(3), Some(1)])), + ]; + let args = args + .iter() + .map(|array| ColumnarValue::Array(array.clone())) + .collect::>(); + // compare to greatest + let result = compare(ComparisonOperator::Greatest, &args).unwrap(); + let array_value = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Internal error: expected array"), + }; + let primitive_array = array_value.as_any().downcast_ref::().unwrap(); + let value_vec = primitive_array.values().to_vec(); + assert_eq!(value_vec, vec![3, 3, 3]); + // compare to least + let result = compare(ComparisonOperator::Least, &args).unwrap(); + let array_value = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Internal error: expected array"), + }; + let primitive_array = array_value.as_any().downcast_ref::().unwrap(); + let value_vec = primitive_array.values().to_vec(); + assert_eq!(value_vec, vec![1, 2, 1]); + } + + #[test] + fn test_compare_i32_array_scalar() { + let args = vec![ + ColumnarValue::Array(Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + ]))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(3))), + ColumnarValue::Array(Arc::new(Int32Array::from(vec![ + Some(2), + Some(3), + Some(1), + ]))), + ]; + // compare to greatest + let result = compare(ComparisonOperator::Greatest, &args).unwrap(); + let array_value = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Internal error: expected array"), + }; + let primitive_array = array_value.as_any().downcast_ref::().unwrap(); + let value_vec = primitive_array.values().to_vec(); + assert_eq!(value_vec, vec![3, 3, 3]); + // compare to least + let result = compare(ComparisonOperator::Least, &args).unwrap(); + let array_value = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Internal error: expected array"), + }; + let primitive_array = array_value.as_any().downcast_ref::().unwrap(); + let value_vec = primitive_array.values().to_vec(); + assert_eq!(value_vec, vec![1, 2, 1]); + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 7020dda8b1225..438ebc26875fc 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -32,7 +32,8 @@ use crate::execution_props::ExecutionProps; use crate::{ - array_expressions, conditional_expressions, datetime_expressions, + array_expressions, comparison_expressions, conditional_expressions, + datetime_expressions, expressions::{cast_column, nullif_func}, math_expressions, string_expressions, struct_expressions, PhysicalExpr, ScalarFunctionExpr, @@ -452,6 +453,8 @@ pub fn create_physical_fun( BuiltinScalarFunction::ConcatWithSeparator => { Arc::new(|args| make_scalar_function(string_expressions::concat_ws)(args)) } + BuiltinScalarFunction::Greatest => Arc::new(comparison_expressions::greatest), + BuiltinScalarFunction::Least => Arc::new(comparison_expressions::least), BuiltinScalarFunction::DatePart => Arc::new(datetime_expressions::date_part), BuiltinScalarFunction::DateTrunc => Arc::new(datetime_expressions::date_trunc), BuiltinScalarFunction::DateBin => Arc::new(datetime_expressions::date_bin), diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index b54bcda601c74..06fa0d64d16f5 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -17,6 +17,7 @@ pub mod aggregate; pub mod array_expressions; +pub mod comparison_expressions; pub mod conditional_expressions; #[cfg(feature = "crypto_expressions")] pub mod crypto_expressions; diff --git a/datafusion/physical-expr/src/type_coercion.rs b/datafusion/physical-expr/src/type_coercion.rs index 399dcc0899000..bb84da69b0947 100644 --- a/datafusion/physical-expr/src/type_coercion.rs +++ b/datafusion/physical-expr/src/type_coercion.rs @@ -134,7 +134,7 @@ mod tests { // u32 -> f32 case( vec![DataType::Float32, DataType::UInt32], - Signature::variadic_equal(Volatility::Immutable), + Signature::variadic_equal(vec![DataType::Float32], Volatility::Immutable), vec![DataType::Float32, DataType::Float32], )?, // common type is u64 @@ -171,7 +171,7 @@ mod tests { // u32 and bool are not uniform case( vec![DataType::UInt32, DataType::Boolean], - Signature::variadic_equal(Volatility::Immutable), + Signature::variadic_equal(vec![DataType::UInt32], Volatility::Immutable), vec![], )?, // bool is not castable to u32 diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index c23d585e61d72..f98d2dc2846b4 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -549,6 +549,8 @@ enum ScalarFunction { Factorial = 83; Lcm = 84; Gcd = 85; + Greatest = 86; + Least = 87; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 369cc0b24e711..fd5657f4cb72a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -17851,6 +17851,8 @@ impl serde::Serialize for ScalarFunction { Self::Factorial => "Factorial", Self::Lcm => "Lcm", Self::Gcd => "Gcd", + Self::Greatest => "Greatest", + Self::Least => "Least", }; serializer.serialize_str(variant) } @@ -17948,6 +17950,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Factorial", "Lcm", "Gcd", + "Greatest", + "Least", ]; struct GeneratedVisitor; @@ -18076,6 +18080,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Factorial" => Ok(ScalarFunction::Factorial), "Lcm" => Ok(ScalarFunction::Lcm), "Gcd" => Ok(ScalarFunction::Gcd), + "Greatest" => Ok(ScalarFunction::Greatest), + "Least" => Ok(ScalarFunction::Least), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 4cf50d70bf0e6..acf85d5ab0e51 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2208,6 +2208,8 @@ pub enum ScalarFunction { Factorial = 83, Lcm = 84, Gcd = 85, + Greatest = 86, + Least = 87, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2302,6 +2304,8 @@ impl ScalarFunction { ScalarFunction::Factorial => "Factorial", ScalarFunction::Lcm => "Lcm", ScalarFunction::Gcd => "Gcd", + ScalarFunction::Greatest => "Greatest", + ScalarFunction::Least => "Least", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2393,6 +2397,8 @@ impl ScalarFunction { "Factorial" => Some(Self::Factorial), "Lcm" => Some(Self::Lcm), "Gcd" => Some(Self::Gcd), + "Greatest" => Some(Self::Greatest), + "Least" => Some(Self::Least), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 1150220bef4ad..1e28e890fdfde 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -40,7 +40,8 @@ use datafusion_expr::{ cbrt, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, date_bin, date_part, date_trunc, degrees, digest, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, floor, from_unixtime, gcd, lcm, left, ln, log, log10, log2, + factorial, floor, from_unixtime, gcd, greatest, lcm, least, left, ln, log, log10, + log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, now, nullif, octet_length, pi, power, radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, @@ -494,6 +495,8 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::FromUnixtime => Self::FromUnixtime, ScalarFunction::Atan2 => Self::Atan2, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, + ScalarFunction::Greatest => Self::Greatest, + ScalarFunction::Least => Self::Least, } } } @@ -1273,6 +1276,20 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), + ScalarFunction::Greatest => Ok(greatest( + &args + .to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), + ScalarFunction::Least => Ok(least( + &args + .to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::Lpad => Ok(lpad( args.to_owned() .iter() diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 191c49194407f..6e27cbc4adec3 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1386,6 +1386,8 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::FromUnixtime => Self::FromUnixtime, BuiltinScalarFunction::Atan2 => Self::Atan2, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, + BuiltinScalarFunction::Greatest => Self::Greatest, + BuiltinScalarFunction::Least => Self::Least, }; Ok(scalar_function) diff --git a/docs/source/user-guide/sql/sql_status.md b/docs/source/user-guide/sql/sql_status.md index 6075a23330a8f..b7f4a0bc11bc2 100644 --- a/docs/source/user-guide/sql/sql_status.md +++ b/docs/source/user-guide/sql/sql_status.md @@ -76,6 +76,9 @@ - [x] nullif - [x] case - [x] coalesce +- Comparison functions + - [x] greatest + - [x] least - Approximation functions - [x] approx_distinct - [x] approx_median