Skip to content

Commit c826009

Browse files
authored
Refactor log() signature to use coercion API + fixes (#18519)
## Which issue does this PR close?

<!--
We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123.
-->

Part of #14763 and #14760

## Rationale for this change

<!--
Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes.
-->

Current `log()` signature has some drawbacks:

https://github.com/apache/datafusion/blob/a5eb9121ccf802dda547897155403b08a4fbf774/datafusion/functions/src/math/log.rs#L78-L105

- A bit nasty to look at: mixes numeric with exact float/int with exact decimal (of exact precision and scale)
- Can't accommodate arbitrary decimals of any precision/scale (this is true for other functions too)

Aim of this PR is to refactor it to use the coercion API, uplifting the API where necessary to make this possible. This simplifies the signature in code, whilst not losing flexibility.

Other minor refactors to `log()` are also included.

## What changes are included in this PR?

<!--
There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR.
-->

- New `TypeSignatureClass` variants: Float, Decimal & Numeric
- Refactor `log()` signature to be more in line with its supported implementations.
- Fix issue in `log()` where `ColumnarValue::Scalar`s were being lost as `ColumnarValue::Array`s for the base.
- Support null propagation in `simplify()` for `log()`.
- ~~Fix issue with `calculate_binary_math` where it wasn't casting scalars.~~

## Are these changes tested?

<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example, are they covered by existing tests)?
-->

Added new tests.

- Tests for float16, decimal32, decimal64, decimals with different scales/precisions
- Test for null propagation (ensures array input is used to avoid function inlining)

## Are there any user-facing changes?

<!--
If there are user-facing changes then we may require documentation to be updated before approving the PR.
-->

No.

<!--
If there are any breaking changes to public APIs, please add the `api change` label.
-->
1 parent 2233796 commit c826009

File tree

7 files changed

+216
-122
lines changed

7 files changed

+216
-122
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,7 +1734,7 @@ impl ScalarValue {
17341734
) {
17351735
return _internal_err!("Invalid precision and scale {err}");
17361736
}
1737-
if *scale <= 0 {
1737+
if *scale < 0 {
17381738
return _internal_err!("Negative scale is not supported");
17391739
}
17401740
match 10_i32.checked_pow((*scale + 1) as u32) {
@@ -1750,7 +1750,7 @@ impl ScalarValue {
17501750
) {
17511751
return _internal_err!("Invalid precision and scale {err}");
17521752
}
1753-
if *scale <= 0 {
1753+
if *scale < 0 {
17541754
return _internal_err!("Negative scale is not supported");
17551755
}
17561756
match i64::from(10).checked_pow((*scale + 1) as u32) {
@@ -4407,6 +4407,7 @@ macro_rules! impl_scalar {
44074407

44084408
impl_scalar!(f64, Float64);
44094409
impl_scalar!(f32, Float32);
4410+
impl_scalar!(f16, Float16);
44104411
impl_scalar!(i8, Int8);
44114412
impl_scalar!(i16, Int16);
44124413
impl_scalar!(i32, Int32);
@@ -4563,6 +4564,7 @@ impl_try_from!(UInt8, u8);
45634564
impl_try_from!(UInt16, u16);
45644565
impl_try_from!(UInt32, u32);
45654566
impl_try_from!(UInt64, u64);
4567+
impl_try_from!(Float16, f16);
45664568
impl_try_from!(Float32, f32);
45674569
impl_try_from!(Float64, f64);
45684570
impl_try_from!(Boolean, bool);

datafusion/common/src/types/native.rs

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -430,22 +430,7 @@ impl From<DataType> for NativeType {
430430
impl NativeType {
431431
#[inline]
432432
pub fn is_numeric(&self) -> bool {
433-
use NativeType::*;
434-
matches!(
435-
self,
436-
UInt8
437-
| UInt16
438-
| UInt32
439-
| UInt64
440-
| Int8
441-
| Int16
442-
| Int32
443-
| Int64
444-
| Float16
445-
| Float32
446-
| Float64
447-
| Decimal(_, _)
448-
)
433+
self.is_integer() || self.is_float() || self.is_decimal()
449434
}
450435

451436
#[inline]
@@ -491,4 +476,14 @@ impl NativeType {
491476
pub fn is_null(&self) -> bool {
492477
matches!(self, NativeType::Null)
493478
}
479+
480+
#[inline]
481+
pub fn is_decimal(&self) -> bool {
482+
matches!(self, Self::Decimal(_, _))
483+
}
484+
485+
#[inline]
486+
pub fn is_float(&self) -> bool {
487+
matches!(self, Self::Float16 | Self::Float32 | Self::Float64)
488+
}
494489
}

datafusion/expr-common/src/signature.rs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use std::fmt::Display;
2121
use std::hash::Hash;
2222

2323
use crate::type_coercion::aggregates::NUMERICS;
24-
use arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
24+
use arrow::datatypes::{DataType, Decimal128Type, DecimalType, IntervalUnit, TimeUnit};
2525
use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType};
2626
use datafusion_common::utils::ListCoercion;
2727
use datafusion_common::{internal_err, plan_err, Result};
@@ -333,9 +333,10 @@ pub enum TypeSignatureClass {
333333
Interval,
334334
Duration,
335335
Native(LogicalTypeRef),
336-
// TODO:
337-
// Numeric
338336
Integer,
337+
Float,
338+
Decimal,
339+
Numeric,
339340
/// Encompasses both the native Binary as well as arbitrarily sized FixedSizeBinary types
340341
Binary,
341342
}
@@ -378,6 +379,13 @@ impl TypeSignatureClass {
378379
TypeSignatureClass::Binary => {
379380
vec![DataType::Binary]
380381
}
382+
TypeSignatureClass::Decimal => vec![Decimal128Type::DEFAULT_TYPE],
383+
TypeSignatureClass::Float => vec![DataType::Float64],
384+
TypeSignatureClass::Numeric => vec![
385+
DataType::Float64,
386+
DataType::Int64,
387+
Decimal128Type::DEFAULT_TYPE,
388+
],
381389
}
382390
}
383391

@@ -395,6 +403,9 @@ impl TypeSignatureClass {
395403
TypeSignatureClass::Duration if logical_type.is_duration() => true,
396404
TypeSignatureClass::Integer if logical_type.is_integer() => true,
397405
TypeSignatureClass::Binary if logical_type.is_binary() => true,
406+
TypeSignatureClass::Decimal if logical_type.is_decimal() => true,
407+
TypeSignatureClass::Float if logical_type.is_float() => true,
408+
TypeSignatureClass::Numeric if logical_type.is_numeric() => true,
398409
_ => false,
399410
}
400411
}
@@ -428,6 +439,15 @@ impl TypeSignatureClass {
428439
TypeSignatureClass::Binary if native_type.is_binary() => {
429440
Ok(origin_type.to_owned())
430441
}
442+
TypeSignatureClass::Decimal if native_type.is_decimal() => {
443+
Ok(origin_type.to_owned())
444+
}
445+
TypeSignatureClass::Float if native_type.is_float() => {
446+
Ok(origin_type.to_owned())
447+
}
448+
TypeSignatureClass::Numeric if native_type.is_numeric() => {
449+
Ok(origin_type.to_owned())
450+
}
431451
_ if native_type.is_null() => Ok(origin_type.to_owned()),
432452
_ => internal_err!("May miss the matching logic in `matches_native_type`"),
433453
}

0 commit comments

Comments
 (0)