Skip to content

Commit

Permalink
support cast signed numeric to decimal (apache#1044)
Browse files Browse the repository at this point in the history
* support cast signed numeric to decimal

* add test for i8,i16,i32,i64,f32,f64 casted to decimal

* change format of float64

* add none test; merge integer test together

Can drop this after rebase on commit 4b3d928, first released in 7.0.0
  • Loading branch information
liukun4515 authored and mcheshkov committed Aug 26, 2024
1 parent d257bf2 commit 311a519
Showing 1 changed file with 181 additions and 3 deletions.
184 changes: 181 additions & 3 deletions arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
}

match (from_type, to_type) {
(
// TODO now just support signed numeric to decimal, support decimal to numeric later
(Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal(_, _))
| (
Null,
Boolean
| Int8
Expand Down Expand Up @@ -870,6 +872,45 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS)
}

// cast the integer array to defined decimal data type array
macro_rules! cast_integer_to_decimal {
($ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => {{
let mut decimal_builder = DecimalBuilder::new($ARRAY.len(), *$PRECISION, *$SCALE);
let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap();
let mul: i128 = 10_i128.pow(*$SCALE as u32);
for i in 0..array.len() {
if array.is_null(i) {
decimal_builder.append_null()?;
} else {
// convert i128 first
let v = array.value(i) as i128;
// if the input value is overflow, it will throw an error.
decimal_builder.append_value(mul * v)?;
}
}
Ok(Arc::new(decimal_builder.finish()))
}};
}

// cast the floating-point array to defined decimal data type array
macro_rules! cast_floating_point_to_decimal {
($ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => {{
let mut decimal_builder = DecimalBuilder::new($ARRAY.len(), *$PRECISION, *$SCALE);
let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap();
let mul = 10_f64.powi(*$SCALE as i32);
for i in 0..array.len() {
if array.is_null(i) {
decimal_builder.append_null()?;
} else {
let v = ((array.value(i) as f64) * mul) as i128;
// if the input value is overflow, it will throw an error.
decimal_builder.append_value(v)?;
}
}
Ok(Arc::new(decimal_builder.finish()))
}};
}

/// Cast `array` to the provided data type and return a new Array with
/// type `to_type`, if possible. It accepts `CastOptions` to allow consumers
/// to configure cast behavior.
Expand Down Expand Up @@ -904,6 +945,34 @@ pub fn cast_with_options(
return Ok(array.clone());
}
match (from_type, to_type) {
(_, Decimal(precision, scale)) => {
// cast data to decimal
match from_type {
// TODO now just support signed numeric to decimal, support decimal to numeric later
Int8 => {
cast_integer_to_decimal!(array, Int8Array, precision, scale)
}
Int16 => {
cast_integer_to_decimal!(array, Int16Array, precision, scale)
}
Int32 => {
cast_integer_to_decimal!(array, Int32Array, precision, scale)
}
Int64 => {
cast_integer_to_decimal!(array, Int64Array, precision, scale)
}
Float32 => {
cast_floating_point_to_decimal!(array, Float32Array, precision, scale)
}
Float64 => {
cast_floating_point_to_decimal!(array, Float64Array, precision, scale)
}
_ => Err(ArrowError::CastError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type
))),
}
}
(
Null,
Boolean
Expand Down Expand Up @@ -2074,7 +2143,7 @@ fn cast_string_to_date64<Offset: StringOffsetSizeTrait>(
if string_array.is_null(i) {
Ok(None)
} else {
let string = string_array
let string = string_array
.value(i);

let result = string
Expand Down Expand Up @@ -2291,7 +2360,7 @@ fn dictionary_cast<K: ArrowDictionaryKeyType>(
return Err(ArrowError::CastError(format!(
"Unsupported type {:?} for dictionary index",
to_index_type
)))
)));
}
};

Expand Down Expand Up @@ -2655,6 +2724,115 @@ where
mod tests {
use super::*;
use crate::{buffer::Buffer, util::display::array_value_to_string};
use num::traits::Pow;

#[test]
fn test_cast_numeric_to_decimal() {
// test cast type
let data_types = vec![
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
];
let decimal_type = DataType::Decimal(38, 6);
for data_type in data_types {
assert!(can_cast_types(&data_type, &decimal_type))
}
assert!(!can_cast_types(&DataType::UInt64, &decimal_type));

// test cast data
let input_datas = vec![
Arc::new(Int8Array::from(vec![
Some(1),
Some(2),
Some(3),
None,
Some(5),
])) as ArrayRef, // i8
Arc::new(Int16Array::from(vec![
Some(1),
Some(2),
Some(3),
None,
Some(5),
])) as ArrayRef, // i16
Arc::new(Int32Array::from(vec![
Some(1),
Some(2),
Some(3),
None,
Some(5),
])) as ArrayRef, // i32
Arc::new(Int64Array::from(vec![
Some(1),
Some(2),
Some(3),
None,
Some(5),
])) as ArrayRef, // i64
];

// i8, i16, i32, i64
for array in input_datas {
let casted_array = cast(&array, &decimal_type).unwrap();
let decimal_array = casted_array
.as_any()
.downcast_ref::<DecimalArray>()
.unwrap();
assert_eq!(&decimal_type, decimal_array.data_type());
for i in 0..array.len() {
if i == 3 {
assert!(decimal_array.is_null(i as usize));
} else {
assert_eq!(
10_i128.pow(6) * (i as i128 + 1),
decimal_array.value(i as usize)
);
}
}
}

// test i8 to decimal type with overflow the result type
// the 100 will be converted to 1000_i128, but it is out of range for max value in the precision 3.
let array = Int8Array::from(vec![1, 2, 3, 4, 100]);
let array = Arc::new(array) as ArrayRef;
let casted_array = cast(&array, &DataType::Decimal(3, 1));
assert!(casted_array.is_err());
assert_eq!("Invalid argument error: The value of 1000 i128 is not compatible with Decimal(3,1)", casted_array.unwrap_err().to_string());

// test f32 to decimal type
let f_data: Vec<f32> = vec![1.1, 2.2, 4.4, 1.123_456_8];
let array = Float32Array::from(f_data.clone());
let array = Arc::new(array) as ArrayRef;
let casted_array = cast(&array, &decimal_type).unwrap();
let decimal_array = casted_array
.as_any()
.downcast_ref::<DecimalArray>()
.unwrap();
assert_eq!(&decimal_type, decimal_array.data_type());
for (i, item) in f_data.iter().enumerate().take(array.len()) {
let left = (*item as f64) * 10_f64.pow(6);
assert_eq!(left as i128, decimal_array.value(i as usize));
}

// test f64 to decimal type
let f_data: Vec<f64> = vec![1.1, 2.2, 4.4, 1.123_456_789_123_4];
let array = Float64Array::from(f_data.clone());
let array = Arc::new(array) as ArrayRef;
let casted_array = cast(&array, &decimal_type).unwrap();
let decimal_array = casted_array
.as_any()
.downcast_ref::<DecimalArray>()
.unwrap();
assert_eq!(&decimal_type, decimal_array.data_type());
for (i, item) in f_data.iter().enumerate().take(array.len()) {
let left = (*item as f64) * 10_f64.pow(6);
assert_eq!(left as i128, decimal_array.value(i as usize));
}
}

#[test]
fn test_cast_i32_to_f64() {
Expand Down

0 comments on commit 311a519

Please sign in to comment.