Skip to content

Commit

Permalink
feat(query): add decimal round/truncate/ceil/floor function (databend…
Browse files Browse the repository at this point in the history
…labs#14040)

* feat(query): add decimal round/truncate function

* chore(query): add tests

* chore(query): add tests

* chore(query): add tests

* chore(query): add tests

* fix(function): fix ceil function

* chore(query): fix tests
  • Loading branch information
sundy-li authored Dec 18, 2023
1 parent 188426e commit e19d4c1
Show file tree
Hide file tree
Showing 9 changed files with 810 additions and 495 deletions.
29 changes: 29 additions & 0 deletions src/query/expression/src/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@ use crate::function::FunctionSignature;
use crate::types::decimal::DecimalSize;
use crate::types::decimal::MAX_DECIMAL128_PRECISION;
use crate::types::decimal::MAX_DECIMAL256_PRECISION;
use crate::types::ArgType;
use crate::types::DataType;
use crate::types::DecimalDataType;
use crate::types::Int64Type;
use crate::types::Number;
use crate::types::NumberScalar;
use crate::AutoCastRules;
use crate::ColumnIndex;
use crate::ConstantFolder;
Expand Down Expand Up @@ -124,6 +127,32 @@ pub fn check<Index: ColumnIndex>(
}
}

// inject the params
if ["round", "truncate"].contains(&name.as_str()) && params.is_empty() {
let mut scale = 0;
let mut new_args = args_expr.clone();

if args_expr.len() == 2 {
let scalar_expr = &args_expr[1];
scale = check_number::<_, i64>(
scalar_expr.span(),
&FunctionContext::default(),
scalar_expr,
fn_registry,
)?;
} else {
new_args.push(Expr::Constant {
span: None,
scalar: Scalar::Number(NumberScalar::Int64(scale)),
data_type: Int64Type::data_type(),
})
}
scale = scale.clamp(-76, 76);
let add_on_scale = (scale + 76) as usize;
let params = vec![add_on_scale];
return check_function(*span, name, &params, &args_expr, fn_registry);
}

check_function(*span, name, params, &args_expr, fn_registry)
}
RawExpr::LambdaFunctionCall {
Expand Down
270 changes: 270 additions & 0 deletions src/query/functions/src/scalars/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1639,3 +1639,273 @@ fn decimal_to_int<T: Number>(
Value::Column(result)
}
}

pub fn register_decimal_math(registry: &mut FunctionRegistry) {
let factory = |params: &[usize], args_type: &[DataType], round_mode: RoundMode| {
if args_type.is_empty() {
return None;
}

let from_type = args_type[0].remove_nullable();
if !matches!(from_type, DataType::Decimal(_)) {
return None;
}

let from_decimal_type = from_type.as_decimal().unwrap();

let scale = if params.is_empty() {
0
} else {
params[0] as i64 - 76
};

let decimal_size = DecimalSize {
precision: from_decimal_type.precision(),
scale: scale.clamp(0, from_decimal_type.scale() as i64) as u8,
};

let dest_decimal_type = DecimalDataType::from_size(decimal_size).ok()?;
let name = format!("{:?}", round_mode).to_lowercase();

let mut sig_args_type = args_type.to_owned();
sig_args_type[0] = from_type.clone();
let f = Function {
signature: FunctionSignature {
name,
args_type: sig_args_type,
return_type: DataType::Decimal(dest_decimal_type),
},
eval: FunctionEval::Scalar {
calc_domain: Box::new(move |_ctx, _d| FunctionDomain::Full),
eval: Box::new(move |args, _ctx| {
decimal_round_truncate(
&args[0],
from_type.clone(),
dest_decimal_type,
scale,
round_mode,
)
}),
},
};

if args_type[0].is_nullable() {
Some(f.passthrough_nullable())
} else {
Some(f)
}
};

for m in [
RoundMode::Round,
RoundMode::Truncate,
RoundMode::Ceil,
RoundMode::Floor,
] {
let name = format!("{:?}", m).to_lowercase();
registry.register_function_factory(&name, move |params, args_type| {
Some(Arc::new(factory(params, args_type, m)?))
});
}
}

#[derive(Copy, Clone, Debug)]
enum RoundMode {
Round,
Truncate,
Floor,
Ceil,
}

fn decimal_round_positive<T>(values: &[T], source_scale: i64, target_scale: i64) -> Vec<T>
where T: Decimal + From<i8> + DivAssign + Div<Output = T> + Add<Output = T> + Sub<Output = T> {
let power_of_ten = T::e((source_scale - target_scale) as u32);
let addition = power_of_ten / T::from(2);

values
.iter()
.map(|input| {
let input = if input < &T::zero() {
*input - addition
} else {
*input + addition
};
input / power_of_ten
})
.collect()
}

fn decimal_round_negative<T>(values: &[T], source_scale: i64, target_scale: i64) -> Vec<T>
where T: Decimal
+ From<i8>
+ DivAssign
+ Div<Output = T>
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T> {
let divide_power_of_ten = T::e((source_scale - target_scale) as u32);
let addition = divide_power_of_ten / T::from(2);
let multiply_power_of_ten = T::e((-target_scale) as u32);

values
.iter()
.map(|input| {
let input = if input < &T::zero() {
*input - addition
} else {
*input + addition
};
input / divide_power_of_ten * multiply_power_of_ten
})
.collect()
}

// if round mode is ceil, truncate should add one value
fn decimal_truncate_positive<T>(values: &[T], source_scale: i64, target_scale: i64) -> Vec<T>
where T: Decimal + From<i8> + DivAssign + Div<Output = T> + Add<Output = T> + Sub<Output = T> {
let power_of_ten = T::e((source_scale - target_scale) as u32);

values.iter().map(|input| *input / power_of_ten).collect()
}

fn decimal_truncate_negative<T>(values: &[T], source_scale: i64, target_scale: i64) -> Vec<T>
where T: Decimal
+ From<i8>
+ DivAssign
+ Div<Output = T>
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T> {
let divide_power_of_ten = T::e((source_scale - target_scale) as u32);
let multiply_power_of_ten = T::e((-target_scale) as u32);

values
.iter()
.map(|input| *input / divide_power_of_ten * multiply_power_of_ten)
.collect()
}

fn decimal_floor<T>(values: &[T], source_scale: i64) -> Vec<T>
where T: Decimal
+ From<i8>
+ DivAssign
+ Div<Output = T>
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T> {
let power_of_ten = T::e(source_scale as u32);

values
.iter()
.map(|input| {
if input < &T::zero() {
// below 0 we ceil the number (e.g. -10.5 -> -11)
((*input + T::one()) / power_of_ten) - T::one()
} else {
*input / power_of_ten
}
})
.collect()
}

fn decimal_ceil<T>(values: &[T], source_scale: i64) -> Vec<T>
where T: Decimal
+ From<i8>
+ DivAssign
+ Div<Output = T>
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T> {
let power_of_ten = T::e(source_scale as u32);

values
.iter()
.map(|input| {
if input <= &T::zero() {
*input / power_of_ten
} else {
((*input - T::one()) / power_of_ten) + T::one()
}
})
.collect()
}

fn decimal_round_truncate(
arg: &ValueRef<AnyType>,
from_type: DataType,
dest_type: DecimalDataType,
target_scale: i64,
mode: RoundMode,
) -> Value<AnyType> {
let from_decimal_type = from_type.as_decimal().unwrap();
let source_scale = from_decimal_type.scale() as i64;

if source_scale < target_scale {
return arg.clone().to_owned();
}

let mut is_scalar = false;
let column = match arg {
ValueRef::Column(column) => column.clone(),
ValueRef::Scalar(s) => {
is_scalar = true;
let builder = ColumnBuilder::repeat(s, 1, &from_type);
builder.build()
}
};

let none_negative = target_scale >= 0;

let result = match from_decimal_type {
DecimalDataType::Decimal128(_) => {
let (buffer, _) = i128::try_downcast_column(&column).unwrap();

let result = match (none_negative, mode) {
(true, RoundMode::Round) => {
decimal_round_positive::<_>(&buffer, source_scale, target_scale)
}
(true, RoundMode::Truncate) => {
decimal_truncate_positive::<_>(&buffer, source_scale, target_scale)
}
(false, RoundMode::Round) => {
decimal_round_negative::<_>(&buffer, source_scale, target_scale)
}
(false, RoundMode::Truncate) => {
decimal_truncate_negative::<_>(&buffer, source_scale, target_scale)
}
(_, RoundMode::Floor) => decimal_floor::<_>(&buffer, source_scale),
(_, RoundMode::Ceil) => decimal_ceil::<_>(&buffer, source_scale),
};
i128::to_column(result, dest_type.size())
}

DecimalDataType::Decimal256(_) => {
let (buffer, _) = i256::try_downcast_column(&column).unwrap();
let result = match (none_negative, mode) {
(true, RoundMode::Round) => {
decimal_round_positive::<_>(&buffer, source_scale, target_scale)
}
(true, RoundMode::Truncate) => {
decimal_truncate_positive::<_>(&buffer, source_scale, target_scale)
}
(false, RoundMode::Round) => {
decimal_round_negative::<_>(&buffer, source_scale, target_scale)
}
(false, RoundMode::Truncate) => {
decimal_truncate_negative::<_>(&buffer, source_scale, target_scale)
}
(_, RoundMode::Floor) => decimal_floor::<_>(&buffer, source_scale),
(_, RoundMode::Ceil) => decimal_ceil::<_>(&buffer, source_scale),
};
i256::to_column(result, dest_type.size())
}
};

let result = Column::Decimal(result);
if is_scalar {
let scalar = result.index(0).unwrap();
Value::Scalar(scalar.to_owned())
} else {
Value::Column(result)
}
}
4 changes: 4 additions & 0 deletions src/query/functions/src/scalars/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ use num_traits::Float;
use num_traits::Pow;
use ordered_float::OrderedFloat;

use crate::scalars::decimal::register_decimal_math;

pub fn register(registry: &mut FunctionRegistry) {
register_decimal_math(registry);

registry.register_1_arg::<NumberType<F64>, NumberType<F64>, _, _>(
"sin",
|_, _| {
Expand Down
Loading

0 comments on commit e19d4c1

Please sign in to comment.