Use LogicalType for TypeSignature Numeric and String, Coercible #13240

Merged
merged 14 commits into from
Nov 6, 2024
6 changes: 6 additions & 0 deletions datafusion/common/src/types/logical.rs
@@ -98,6 +98,12 @@ impl fmt::Debug for dyn LogicalType {
}
}

impl std::fmt::Display for dyn LogicalType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
Contributor

Maybe we can make a distinction between Debug and Display. Currently, we print something like

        let int: Box<dyn LogicalType> = Box::new(Int32);
        println!("{}", int);

------- output ------
LogicalType(Native(Int32), Int32)

I imagine we could print Int32 as `Int32`, JSON as `JSON`, and so on 🤔
However, I think that may not be the point of this PR. We can do it in another PR.
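
A minimal sketch of that distinction, using hypothetical stand-in types (`ToyLogicalType`, `ToyNative`) rather than DataFusion's actual `LogicalType` trait: `Display` prints only the short user-facing name, while the derived `Debug` keeps the full structure.

```rust
use std::fmt;

// Hypothetical stand-ins for illustration only; not DataFusion's real types.
#[derive(Debug)]
enum ToyNative {
    Int32,
    String,
}

#[derive(Debug)]
struct ToyLogicalType {
    name: &'static str,
    native: ToyNative,
}

// Display shows only the short, user-facing name.
impl fmt::Display for ToyLogicalType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name)
    }
}
```

With this split, `{}` would suit error messages shown to users and `{:?}` would stay available for debugging output.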

}
}

impl PartialEq for dyn LogicalType {
fn eq(&self, other: &Self) -> bool {
self.signature().eq(&other.signature())
42 changes: 39 additions & 3 deletions datafusion/common/src/types/native.rs
@@ -348,6 +348,12 @@ impl LogicalType for NativeType {
// mapping solutions to provide backwards compatibility while transitioning from
// the purely physical system to a logical / physical system.

impl From<&DataType> for NativeType {
fn from(value: &DataType) -> Self {
value.clone().into()
}
}

impl From<DataType> for NativeType {
fn from(value: DataType) -> Self {
use NativeType::*;
@@ -392,8 +398,38 @@ impl From<DataType> for NativeType {
}
}

impl From<&DataType> for NativeType {
fn from(value: &DataType) -> Self {
value.clone().into()
impl NativeType {
#[inline]
pub fn is_numeric(&self) -> bool {
use NativeType::*;
matches!(
self,
UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float16
| Float32
| Float64
| Decimal(_, _)
)
}

/// This function is the NativeType version of `can_cast_types`.
/// It handles general coercion rules that are widely applicable.
/// Avoid adding specific coercion cases here.
/// Aim to keep this logic as SIMPLE as possible!
pub fn can_cast_to(&self, target_type: &Self) -> bool {
// In Postgres, most functions coerce numeric strings to numeric inputs,
// but they do not accept numeric inputs as strings.
if self.is_numeric() && target_type == &NativeType::String {
Member

This talks about "can cast" and "can coerce" without making a clear distinction between them.
Can we make it less ambiguous and clarify whether we mean "can cast" (i.e. does CAST(type_from AS type_to) exist?) or "can coerce" (does a cast exist, and is it applied implicitly in various contexts?)
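
To make the two notions concrete, here is a rough sketch with a hypothetical `ToyType` enum. The specific rules are assumptions chosen for illustration, not DataFusion's or Postgres's actual tables; the point is only that the implicit relation is a subset of the explicit one.

```rust
// Hypothetical type universe for illustration only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ToyType {
    Int32,
    Float64,
    Utf8,
    Map,
}

/// "Can cast": does an explicit `CAST(from AS to)` exist at all?
/// Assumed rule: every scalar pair has a cast; Map only casts to itself.
fn can_cast(from: ToyType, to: ToyType) -> bool {
    from == to || (from != ToyType::Map && to != ToyType::Map)
}

/// "Can coerce": is that cast also applied implicitly to function arguments?
/// Assumed rules: identity, Int32 widening to Float64, and string to numeric.
fn can_coerce(from: ToyType, to: ToyType) -> bool {
    use ToyType::*;
    can_cast(from, to)
        && (from == to
            || matches!(
                (from, to),
                (Int32, Float64) | (Utf8, Int32) | (Utf8, Float64)
            ))
}
```

Under these assumed rules `CAST(1 AS Utf8)` is allowed explicitly, but an integer argument is never silently turned into a string.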

return false;
}

true
Member

The default for both "can cast" and "can coerce" should be false.
Maps cannot coerce to numbers or lists.

Let's have

Suggested change
true
false

here and define what can be converted in the rules above. This will make the code simpler to reason about.
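
A sketch of the deny-by-default shape suggested here, on a hypothetical `ToyNative` enum. The allowed conversions listed are assumptions for illustration (identity, numeric to numeric, string to numeric), not the rule set this PR ships.

```rust
// Hypothetical cut-down version of a native type enum, for illustration only.
#[derive(Debug, Clone, PartialEq)]
enum ToyNative {
    Int32,
    Float64,
    Decimal(u8, i8),
    String,
    Map,
}

impl ToyNative {
    fn is_numeric(&self) -> bool {
        matches!(self, Self::Int32 | Self::Float64 | Self::Decimal(_, _))
    }

    /// Every permitted conversion is listed explicitly; anything else is rejected.
    fn can_coerce_to(&self, target: &Self) -> bool {
        match (self, target) {
            // A type always coerces to itself.
            (a, b) if a == b => true,
            // Assumed rule: numeric types coerce among themselves.
            (a, b) if a.is_numeric() && b.is_numeric() => true,
            // Assumed rule: strings may coerce to numeric inputs, not the reverse.
            (Self::String, b) if b.is_numeric() => true,
            // Default: no implicit conversion (e.g. Map to Int32).
            _ => false,
        }
    }
}
```

The trade-off is that every new type pair needs an explicit arm, but an unlisted pair such as map to number can no longer slip through.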

}
}
10 changes: 7 additions & 3 deletions datafusion/expr-common/src/signature.rs
@@ -19,6 +19,7 @@
//! and return types of functions in DataFusion.

use arrow::datatypes::DataType;
use datafusion_common::types::LogicalTypeRef;

/// Constant that is used as a placeholder for any valid timezone.
/// This is used where a function can accept a timestamp type with any
@@ -109,7 +110,7 @@ pub enum TypeSignature {
/// For example, `Coercible(vec![DataType::Float64])` accepts
/// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]`
/// since i32 and f32 can be casted to f64
goldmedal marked this conversation as resolved.
Coercible(Vec<DataType>),
Coercible(Vec<LogicalTypeRef>),
/// Fixed number of arguments of arbitrary types
/// If a function takes 0 argument, its `TypeSignature` should be `Any(0)`
Any(usize),
@@ -201,7 +202,10 @@ impl TypeSignature {
TypeSignature::Numeric(num) => {
vec![format!("Numeric({num})")]
}
TypeSignature::Exact(types) | TypeSignature::Coercible(types) => {
TypeSignature::Coercible(types) => {
vec![Self::join_types(types, ", ")]
}
TypeSignature::Exact(types) => {
vec![Self::join_types(types, ", ")]
}
TypeSignature::Any(arg_count) => {
@@ -322,7 +326,7 @@ impl Signature {
}
}
/// Target coerce types in order
pub fn coercible(target_types: Vec<DataType>, volatility: Volatility) -> Self {
pub fn coercible(target_types: Vec<LogicalTypeRef>, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::Coercible(target_types),
volatility,
80 changes: 62 additions & 18 deletions datafusion/expr/src/type_coercion/functions.rs
@@ -23,6 +23,7 @@ use arrow::{
};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, plan_err,
types::NativeType,
utils::{coerced_fixed_size_list_to_list, list_ndims},
Result,
};
@@ -401,6 +402,10 @@ fn get_valid_types(
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::String(number) => {
// TODO: we can switch to coercible after all the string functions support utf8view since it is choosen as the default string type.
//
// let data_types = get_valid_types(&TypeSignature::Coercible(vec![logical_string(); *number]), current_types)?.swap_remove(0);

if *number < 1 {
return plan_err!(
"The signature expected at least one argument but received {}",
@@ -415,20 +420,38 @@ fn get_valid_types(
);
}

fn coercion_rule(
let mut new_types = Vec::with_capacity(current_types.len());
for data_type in current_types.iter() {
let logical_data_type: NativeType = data_type.into();

match logical_data_type {
NativeType::String => {
new_types.push(data_type.to_owned());
}
NativeType::Null => {
new_types.push(DataType::Utf8);
}
_ => {
return plan_err!(
"The signature expected NativeType::String but received {data_type}"
);
}
}
}

let data_types = new_types;

// Find the common string type for the given types
fn find_common_type(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Result<DataType> {
match (lhs_type, rhs_type) {
(DataType::Null, DataType::Null) => Ok(DataType::Utf8),
(DataType::Null, data_type) | (data_type, DataType::Null) => {
coercion_rule(data_type, &DataType::Utf8)
}
(DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
coercion_rule(lhs, rhs)
find_common_type(lhs, rhs)
}
(DataType::Dictionary(_, v), other)
| (other, DataType::Dictionary(_, v)) => coercion_rule(v, other),
| (other, DataType::Dictionary(_, v)) => find_common_type(v, other),
_ => {
if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
Ok(coerced_type)
@@ -444,15 +467,13 @@ fn get_valid_types(
}

// Length checked above, safe to unwrap
let mut coerced_type = current_types.first().unwrap().to_owned();
for t in current_types.iter().skip(1) {
coerced_type = coercion_rule(&coerced_type, t)?;
let mut coerced_type = data_types.first().unwrap().to_owned();
for t in data_types.iter().skip(1) {
coerced_type = find_common_type(&coerced_type, t)?;
}

fn base_type_or_default_type(data_type: &DataType) -> DataType {
if data_type.is_null() {
DataType::Utf8
} else if let DataType::Dictionary(_, v) = data_type {
if let DataType::Dictionary(_, v) = data_type {
base_type_or_default_type(v)
} else {
data_type.to_owned()
@@ -476,8 +497,20 @@ fn get_valid_types(
);
}

let mut valid_type = current_types.first().unwrap().clone();
for t in current_types.iter().skip(1) {
let mut new_types = Vec::with_capacity(current_types.len());
for data_type in current_types.iter() {
let logical_data_type: NativeType = data_type.into();
if logical_data_type.is_numeric() {
new_types.push(data_type.to_owned());
} else {
return plan_err!(
"The signature expected NativeType::Numeric but received {data_type}"
);
}
}

let mut valid_type = new_types.first().unwrap().clone();
for t in new_types.iter().skip(1) {
if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
valid_type = coerced_type;
} else {
@@ -506,14 +539,25 @@ fn get_valid_types(
);
}

let mut new_types = Vec::with_capacity(current_types.len());
for (data_type, target_type) in current_types.iter().zip(target_types.iter())
{
if !can_cast_types(data_type, target_type) {
return plan_err!("{data_type} is not coercible to {target_type}");
let logical_data_type: NativeType = data_type.into();
if logical_data_type == *target_type.native() {
new_types.push(data_type.to_owned());
} else if logical_data_type.can_cast_to(target_type.native()) {
let casted_type = target_type.default_cast_for(data_type)?;
new_types.push(casted_type);
} else {
return plan_err!(
"The signature expected {:?} but received {:?}",
target_type.native(),
logical_data_type
);
}
}

vec![target_types.to_owned()]
vec![new_types]
}
TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
24 changes: 4 additions & 20 deletions datafusion/functions-aggregate/src/first_last.rs
@@ -33,8 +33,8 @@ use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Documentation, Expr,
ExprFunctionExt, Signature, SortExpr, TypeSignature, Volatility,
Accumulator, AggregateUDFImpl, Documentation, Expr, ExprFunctionExt, Signature,
SortExpr, Volatility,
};
use datafusion_functions_aggregate_common::utils::get_sort_options;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexOrderingRef};
@@ -79,15 +79,7 @@ impl Default for FirstValue {
impl FirstValue {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![
// TODO: we can introduce more strict signature that only numeric of array types are allowed
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
TypeSignature::Numeric(1),
TypeSignature::Uniform(1, vec![DataType::Utf8]),
],
Volatility::Immutable,
),
signature: Signature::any(1, Volatility::Immutable),
Contributor Author

Ideally, first/last value returns the first/last item in the column, so the result should have the column's type.

I forget why we didn't use Any originally.
Without this change, we would need to add another type signature for boolean, so I changed it to Any in this PR.

requirement_satisfied: false,
}
}
@@ -406,15 +398,7 @@ impl Default for LastValue {
impl LastValue {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![
// TODO: we can introduce more strict signature that only numeric of array types are allowed
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
TypeSignature::Numeric(1),
TypeSignature::Uniform(1, vec![DataType::Utf8]),
],
Volatility::Immutable,
),
signature: Signature::any(1, Volatility::Immutable),
requirement_satisfied: false,
}
}
3 changes: 2 additions & 1 deletion datafusion/functions-aggregate/src/stddev.rs
@@ -25,6 +25,7 @@ use std::sync::{Arc, OnceLock};
use arrow::array::Float64Array;
use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};

use datafusion_common::types::logical_float64;
use datafusion_common::{internal_err, not_impl_err, Result};
use datafusion_common::{plan_err, ScalarValue};
use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
@@ -72,7 +73,7 @@ impl Stddev {
pub fn new() -> Self {
Self {
signature: Signature::coercible(
vec![DataType::Float64],
vec![logical_float64()],
Volatility::Immutable,
),
alias: vec!["stddev_samp".to_string()],
5 changes: 3 additions & 2 deletions datafusion/functions-aggregate/src/variance.rs
@@ -29,7 +29,8 @@ use std::sync::OnceLock;
use std::{fmt::Debug, sync::Arc};

use datafusion_common::{
downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue,
downcast_value, not_impl_err, plan_err, types::logical_float64, DataFusionError,
Result, ScalarValue,
};
use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
use datafusion_expr::{
@@ -83,7 +84,7 @@ impl VarianceSample {
Self {
aliases: vec![String::from("var_sample"), String::from("var_samp")],
signature: Signature::coercible(
vec![DataType::Float64],
vec![logical_float64()],
Volatility::Immutable,
),
}
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/bit_length.rs
@@ -79,7 +79,7 @@ impl ScalarUDFImpl for BitLengthFunc {
ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar(
ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)),
)),
_ => unreachable!(),
_ => unreachable!("bit length"),
},
}
}
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/concat.rs
@@ -186,7 +186,7 @@ impl ScalarUDFImpl for ConcatFunc {
}
};
}
_ => unreachable!(),
_ => unreachable!("concat"),
}
}

2 changes: 1 addition & 1 deletion datafusion/functions/src/string/concat_ws.rs
@@ -164,7 +164,7 @@ impl ScalarUDFImpl for ConcatWsFunc {
ColumnarValueRef::NonNullableArray(string_array)
}
}
_ => unreachable!(),
_ => unreachable!("concat ws"),
};

let mut columns = Vec::with_capacity(args.len() - 1);
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/lower.rs
@@ -107,7 +107,7 @@ mod tests {
let args = vec![ColumnarValue::Array(input)];
let result = match func.invoke(&args)? {
ColumnarValue::Array(result) => result,
_ => unreachable!(),
_ => unreachable!("lower"),
};
assert_eq!(&expected, &result);
Ok(())
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/octet_length.rs
@@ -82,7 +82,7 @@ impl ScalarUDFImpl for OctetLengthFunc {
ScalarValue::Utf8View(v) => Ok(ColumnarValue::Scalar(
ScalarValue::Int32(v.as_ref().map(|x| x.len() as i32)),
)),
_ => unreachable!(),
_ => unreachable!("OctetLengthFunc"),
},
}
}
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/upper.rs
@@ -107,7 +107,7 @@ mod tests {
let args = vec![ColumnarValue::Array(input)];
let result = match func.invoke(&args)? {
ColumnarValue::Array(result) => result,
_ => unreachable!(),
_ => unreachable!("upper"),
};
assert_eq!(&expected, &result);
Ok(())
2 changes: 1 addition & 1 deletion datafusion/functions/src/unicode/character_length.rs
@@ -128,7 +128,7 @@ fn character_length(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = args[0].as_string_view();
character_length_general::<Int32Type, _>(string_array)
}
_ => unreachable!(),
_ => unreachable!("CharacterLengthFunc"),
}
}

2 changes: 1 addition & 1 deletion datafusion/functions/src/unicode/lpad.rs
@@ -162,7 +162,7 @@ pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
length_array,
&args[2],
),
(_, _) => unreachable!(),
(_, _) => unreachable!("lpad"),
}
}
