Skip to content

Commit

Permalink
[built-in function] add greatest and least
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist committed Jun 4, 2023
1 parent 9d22054 commit 30b575a
Show file tree
Hide file tree
Showing 22 changed files with 449 additions and 16 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ rust-version = "1.64"
arrow = { version = "40.0.0", features = ["prettyprint"] }
arrow-flight = { version = "40.0.0", features = ["flight-sql-experimental"] }
arrow-buffer = { version = "40.0.0", default-features = false }
arrow-ord = { version = "40.0.0", default-features = false }
arrow-schema = { version = "40.0.0", default-features = false }
arrow-select = { version = "40.0.0", default-features = false }
arrow-array = { version = "40.0.0", default-features = false, features = ["chrono-tz"] }
parquet = { version = "40.0.0", features = ["arrow", "async", "object_store"] }

Expand Down
2 changes: 2 additions & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 35 additions & 0 deletions datafusion/core/tests/sql/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,41 @@ async fn binary_bitwise_shift() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_comparison_func_expressions() -> Result<()> {
test_expression!("greatest(1,2,3)", "3");
test_expression!("least(1,2,3)", "1");

Ok(())
}

#[tokio::test]
async fn test_comparison_func_array_scalar_expression() -> Result<()> {
let ctx = SessionContext::new();
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int64Array::from(vec![1, 2, 3]))],
)?;
let table = MemTable::try_new(schema, vec![vec![batch]])?;
ctx.register_table("t1", Arc::new(table))?;
let sql = "SELECT greatest(a, 2), least(a, 2) from t1";
let actual = execute_to_batches(&ctx, sql).await;
assert_batches_eq!(
&[
"+-------------------------+----------------------+",
"| greatest(t1.a,Int64(2)) | least(t1.a,Int64(2)) |",
"+-------------------------+----------------------+",
"| 2 | 1 |",
"| 2 | 2 |",
"| 3 | 2 |",
"+-------------------------+----------------------+",
],
&actual
);
Ok(())
}

#[tokio::test]
async fn test_interval_expressions() -> Result<()> {
// day nano intervals
Expand Down
10 changes: 10 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ pub enum BuiltinScalarFunction {
Struct,
/// arrow_typeof
ArrowTypeof,
/// greatest
Greatest,
/// least
Least,
}

lazy_static! {
Expand Down Expand Up @@ -328,6 +332,8 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Struct => Volatility::Immutable,
BuiltinScalarFunction::FromUnixtime => Volatility::Immutable,
BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable,
BuiltinScalarFunction::Greatest => Volatility::Immutable,
BuiltinScalarFunction::Least => Volatility::Immutable,

// Stable builtin functions
BuiltinScalarFunction::Now => Volatility::Stable,
Expand Down Expand Up @@ -414,6 +420,10 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
BuiltinScalarFunction::Upper => &["upper"],
BuiltinScalarFunction::Uuid => &["uuid"],

// comparison functions
BuiltinScalarFunction::Greatest => &["greatest"],
BuiltinScalarFunction::Least => &["least"],

// regex functions
BuiltinScalarFunction::RegexpMatch => &["regexp_match"],
BuiltinScalarFunction::RegexpReplace => &["regexp_replace"],
Expand Down
35 changes: 35 additions & 0 deletions datafusion/expr/src/comparison_expressions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::DataType;

/// Currently supported types by the comparison function.
pub static SUPPORTED_COMPARISON_TYPES: &[DataType] = &[
DataType::Boolean,
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
DataType::Utf8,
DataType::LargeUtf8,
];
24 changes: 23 additions & 1 deletion datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,22 @@ pub fn concat_ws(sep: Expr, values: Vec<Expr>) -> Expr {
))
}

/// Returns the greatest value of all arguments.
pub fn greatest(args: &[Expr]) -> Expr {
Expr::ScalarFunction(ScalarFunction::new(
BuiltinScalarFunction::Greatest,
args.to_vec(),
))
}

/// Returns the least value of all arguments.
pub fn least(args: &[Expr]) -> Expr {
Expr::ScalarFunction(ScalarFunction::new(
BuiltinScalarFunction::Least,
args.to_vec(),
))
}

/// Returns an approximate value of π
pub fn pi() -> Expr {
Expr::ScalarFunction(ScalarFunction::new(BuiltinScalarFunction::Pi, vec![]))
Expand Down Expand Up @@ -620,9 +636,15 @@ nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evalua
nary_scalar_expr!(
ConcatWithSeparator,
concat_ws_expr,
"concatenates several strings, placing a seperator between each one"
"concatenates several strings, placing a separator between each one"
);
nary_scalar_expr!(Concat, concat_expr, "concatenates several strings");
nary_scalar_expr!(
Greatest,
greatest_expr,
"gets the largest value of the list"
);
nary_scalar_expr!(Least, least_expr, "gets the smallest value of the list");

// date functions
scalar_expr!(DatePart, date_part, part date, "extracts a subfield from the date");
Expand Down
15 changes: 13 additions & 2 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use crate::nullif::SUPPORTED_NULLIF_TYPES;
use crate::type_coercion::functions::data_types;
use crate::ColumnarValue;
use crate::{
array_expressions, conditional_expressions, struct_expressions, Accumulator,
BuiltinScalarFunction, Signature, TypeSignature,
array_expressions, comparison_expressions, conditional_expressions,
struct_expressions, Accumulator, BuiltinScalarFunction, Signature, TypeSignature,
};
use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit};
use datafusion_common::{DataFusionError, Result};
Expand Down Expand Up @@ -168,6 +168,11 @@ pub fn return_type(
let coerced_types = data_types(input_expr_types, &signature(fun));
coerced_types.map(|typs| typs[0].clone())
}
BuiltinScalarFunction::Greatest | BuiltinScalarFunction::Least => {
// GREATEST and LEAST have multiple args and they might get coerced, get a preview of this
let coerced_types = data_types(input_expr_types, &signature(fun));
coerced_types.map(|typs| typs[0].clone())
}
BuiltinScalarFunction::OctetLength => {
utf8_to_int_type(&input_expr_types[0], "octet_length")
}
Expand Down Expand Up @@ -376,6 +381,12 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
BuiltinScalarFunction::Chr | BuiltinScalarFunction::ToHex => {
Signature::uniform(1, vec![DataType::Int64], fun.volatility())
}
BuiltinScalarFunction::Greatest | BuiltinScalarFunction::Least => {
Signature::variadic_equal(
comparison_expressions::SUPPORTED_COMPARISON_TYPES.to_vec(),
fun.volatility(),
)
}
BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/function_err.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl TypeSignature {
.collect::<Vec<&str>>()
.join(", ")]
}
TypeSignature::VariadicEqual => vec!["T, .., T".to_string()],
TypeSignature::VariadicEqual(_) => vec!["T, .., T".to_string()],
TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()],
TypeSignature::OneOf(sigs) => {
sigs.iter().flat_map(|s| s.to_string_repr()).collect()
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub mod aggregate_function;
pub mod array_expressions;
mod built_in_function;
mod columnar_value;
pub mod comparison_expressions;
pub mod conditional_expressions;
pub mod expr;
pub mod expr_fn;
Expand Down
11 changes: 6 additions & 5 deletions datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ pub enum TypeSignature {
/// arbitrary number of arguments of an common type out of a list of valid types
// A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])`
Variadic(Vec<DataType>),
/// arbitrary number of arguments of an arbitrary but equal type
/// arbitrary number of arguments of an equal type
// A function such as `array` is `VariadicEqual`
// The first argument decides the type used for coercion
VariadicEqual,
VariadicEqual(Vec<DataType>),
/// arbitrary number of arguments with arbitrary types
VariadicAny,
/// fixed number of arguments of an arbitrary but equal type out of a list of valid types
Expand Down Expand Up @@ -85,10 +85,11 @@ impl Signature {
volatility,
}
}
/// variadic_equal - Creates a variadic signature that represents an arbitrary number of arguments of the same type.
pub fn variadic_equal(volatility: Volatility) -> Self {
/// variadic_equal - Creates a variadic signature that represents an arbitrary number of arguments of the same type in
/// the allowed_types.
pub fn variadic_equal(allowed_types: Vec<DataType>, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::VariadicEqual,
type_signature: TypeSignature::VariadicEqual(allowed_types),
volatility,
}
}
Expand Down
10 changes: 7 additions & 3 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,15 @@ fn get_valid_types(
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::VariadicEqual => {
// one entry with the same len as current_types, whose type is `current_types[0]`.
TypeSignature::VariadicEqual(allowed_types) => {
if allowed_types.is_empty() {
return Err(DataFusionError::Plan(
"allowed types cannot be empty".to_string(),
));
}
vec![current_types
.iter()
.map(|_| current_types[0].clone())
.map(|_| allowed_types[0].clone())
.collect()]
}
TypeSignature::VariadicAny => {
Expand Down
2 changes: 2 additions & 0 deletions datafusion/physical-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"]
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-buffer = { workspace = true }
arrow-ord = { workspace = true }
arrow-schema = { workspace = true }
arrow-select = { workspace = true }
blake2 = { version = "^0.10.2", optional = true }
blake3 = { version = "1.0", optional = true }
chrono = { version = "0.4.23", default-features = false }
Expand Down
Loading

0 comments on commit 30b575a

Please sign in to comment.