From c161456158b122345788f86e9302fb4b5340a31e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 28 Nov 2023 12:51:30 -0800 Subject: [PATCH] Support casting of Float16 with other numeric types (#5139) * Support casting of Float16 with other numeric types * Add Float16 test cases --- arrow-cast/src/cast.rs | 159 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 155 insertions(+), 4 deletions(-) diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 8facb4f161f4..51acd36c3fe4 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -200,8 +200,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { // start numeric casts ( - UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, - UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 | Float64, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 | Float64, ) => true, // end numeric casts @@ -220,8 +220,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Time64(_), Time32(to_unit)) => { matches!(to_unit, Second | Millisecond) } - (Timestamp(_, _), _) if to_type.is_numeric() && to_type != &Float16 => true, - (_, Timestamp(_, _)) if from_type.is_numeric() && from_type != &Float16 => true, + (Timestamp(_, _), _) if to_type.is_numeric() => true, + (_, Timestamp(_, _)) if from_type.is_numeric() => true, (Date64, Timestamp(_, None)) => true, (Date32, Timestamp(_, None)) => true, ( @@ -1367,6 +1367,7 @@ pub fn cast_with_options( (UInt8, Int16) => cast_numeric_arrays::(array, cast_options), (UInt8, Int32) => cast_numeric_arrays::(array, cast_options), (UInt8, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt8, Float16) => cast_numeric_arrays::(array, cast_options), (UInt8, Float32) => cast_numeric_arrays::(array, cast_options), (UInt8, Float64) => cast_numeric_arrays::(array, cast_options), @@ -1377,6 +1378,7 @@ pub fn cast_with_options( (UInt16, Int16) => cast_numeric_arrays::(array, cast_options), (UInt16, Int32) => cast_numeric_arrays::(array, cast_options), (UInt16, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt16, Float16) => cast_numeric_arrays::(array, cast_options), (UInt16, Float32) => cast_numeric_arrays::(array, cast_options), (UInt16, Float64) => cast_numeric_arrays::(array, cast_options), @@ -1387,6 +1389,7 @@ pub fn cast_with_options( (UInt32, Int16) => cast_numeric_arrays::(array, cast_options), (UInt32, Int32) => cast_numeric_arrays::(array, cast_options), (UInt32, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt32, Float16) => cast_numeric_arrays::(array, cast_options), (UInt32, Float32) => cast_numeric_arrays::(array, cast_options), (UInt32, Float64) => cast_numeric_arrays::(array, cast_options), @@ -1397,6 +1400,7 @@ pub fn cast_with_options( (UInt64, Int16) => cast_numeric_arrays::(array, cast_options), (UInt64, Int32) => cast_numeric_arrays::(array, cast_options), (UInt64, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt64, Float16) => cast_numeric_arrays::(array, cast_options), (UInt64, Float32) => cast_numeric_arrays::(array, cast_options), (UInt64, Float64) => cast_numeric_arrays::(array, cast_options), @@ -1407,6 +1411,7 @@ pub fn cast_with_options( (Int8, Int16) => cast_numeric_arrays::(array, cast_options), (Int8, Int32) => cast_numeric_arrays::(array, cast_options), (Int8, Int64) => cast_numeric_arrays::(array, cast_options), + (Int8, Float16) => cast_numeric_arrays::(array, cast_options), (Int8, Float32) => cast_numeric_arrays::(array, cast_options), (Int8, Float64) => cast_numeric_arrays::(array, cast_options), @@ -1417,6 +1422,7 @@ pub fn cast_with_options( (Int16, Int8) => cast_numeric_arrays::(array, cast_options), (Int16, Int32) => cast_numeric_arrays::(array, cast_options), (Int16, Int64) => cast_numeric_arrays::(array, cast_options), + (Int16, Float16) => cast_numeric_arrays::(array, cast_options), (Int16, Float32) => cast_numeric_arrays::(array, cast_options), (Int16, Float64) => cast_numeric_arrays::(array, cast_options), @@ -1427,6 +1433,7 @@ pub fn cast_with_options( (Int32, Int8) => cast_numeric_arrays::(array, cast_options), (Int32, Int16) => cast_numeric_arrays::(array, cast_options), (Int32, Int64) => cast_numeric_arrays::(array, cast_options), + (Int32, Float16) => cast_numeric_arrays::(array, cast_options), (Int32, Float32) => cast_numeric_arrays::(array, cast_options), (Int32, Float64) => cast_numeric_arrays::(array, cast_options), @@ -1437,9 +1444,21 @@ pub fn cast_with_options( (Int64, Int8) => cast_numeric_arrays::(array, cast_options), (Int64, Int16) => cast_numeric_arrays::(array, cast_options), (Int64, Int32) => cast_numeric_arrays::(array, cast_options), + (Int64, Float16) => cast_numeric_arrays::(array, cast_options), (Int64, Float32) => cast_numeric_arrays::(array, cast_options), (Int64, Float64) => cast_numeric_arrays::(array, cast_options), + (Float16, UInt8) => cast_numeric_arrays::(array, cast_options), + (Float16, UInt16) => cast_numeric_arrays::(array, cast_options), + (Float16, UInt32) => cast_numeric_arrays::(array, cast_options), + (Float16, UInt64) => cast_numeric_arrays::(array, cast_options), + (Float16, Int8) => cast_numeric_arrays::(array, cast_options), + (Float16, Int16) => cast_numeric_arrays::(array, cast_options), + (Float16, Int32) => cast_numeric_arrays::(array, cast_options), + (Float16, Int64) => cast_numeric_arrays::(array, cast_options), + (Float16, Float32) => cast_numeric_arrays::(array, cast_options), + (Float16, Float64) => cast_numeric_arrays::(array, cast_options), + (Float32, UInt8) => cast_numeric_arrays::(array, cast_options), (Float32, UInt16) => cast_numeric_arrays::(array, cast_options), (Float32, UInt32) => cast_numeric_arrays::(array, cast_options), @@ -1448,6 +1467,7 @@ pub fn cast_with_options( (Float32, Int16) => cast_numeric_arrays::(array, cast_options), (Float32, Int32) => cast_numeric_arrays::(array, cast_options), (Float32, Int64) => cast_numeric_arrays::(array, cast_options), + (Float32, Float16) => cast_numeric_arrays::(array, cast_options), (Float32, Float64) => cast_numeric_arrays::(array, cast_options), (Float64, UInt8) => cast_numeric_arrays::(array, cast_options), @@ -1458,6 +1478,7 @@ pub fn cast_with_options( (Float64, Int16) => cast_numeric_arrays::(array, cast_options), (Float64, Int32) => cast_numeric_arrays::(array, cast_options), (Float64, Int64) => cast_numeric_arrays::(array, cast_options), + (Float64, Float16) => cast_numeric_arrays::(array, cast_options), (Float64, Float32) => cast_numeric_arrays::(array, cast_options), // end numeric casts @@ -3299,6 +3320,7 @@ fn cast_list( #[cfg(test)] mod tests { use arrow_buffer::{Buffer, NullBuffer}; + use half::f16; use super::*; @@ -4665,6 +4687,15 @@ mod tests { let array = Int64Array::from(vec![Some(2), Some(10), None]); let expected = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + let array = Float16Array::from(vec![ + Some(f16::from_f32(2.0)), + Some(f16::from_f32(10.6)), + None, + ]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + let array = Float32Array::from(vec![Some(2.0), Some(10.6), None]); let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); @@ -4682,6 +4713,9 @@ mod tests { .with_timezone("UTC".to_string()); let expected = cast(&array, &DataType::Int64).unwrap(); + let actual = cast(&cast(&array, &DataType::Float16).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + let actual = cast(&cast(&array, &DataType::Float32).unwrap(), &DataType::Int64).unwrap(); assert_eq!(&actual, &expected); @@ -6103,6 +6137,25 @@ mod tests { .collect::>() ); + let f16_expected = vec![ + f16::from_f64(-9223372000000000000.0), + f16::from_f64(-2147483600.0), + f16::from_f64(-32768.0), + f16::from_f64(-128.0), + f16::from_f64(0.0), + f16::from_f64(255.0), + f16::from_f64(65535.0), + f16::from_f64(4294967300.0), + f16::from_f64(18446744000000000000.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&f64_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + let i64_expected = vec![ "-9223372036854775808", "-2147483648", @@ -6247,6 +6300,14 @@ mod tests { get_cast_values::(&f32_array, &DataType::Float32) ); + let f16_expected = vec![ + "-inf", "-inf", "-32768.0", "-128.0", "0.0", "255.0", "inf", "inf", "inf", + ]; + assert_eq!( + f16_expected, + get_cast_values::(&f32_array, &DataType::Float16) + ); + let i64_expected = vec![ "-2147483648", "-2147483648", @@ -6365,6 +6426,21 @@ mod tests { .collect::>() ); + let f16_expected = vec![ + f16::from_f64(0.0), + f16::from_f64(255.0), + f16::from_f64(65535.0), + f16::from_f64(4294967300.0), + f16::from_f64(18446744000000000000.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&u64_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + let i64_expected = vec!["0", "255", "65535", "4294967295", "null"]; assert_eq!( i64_expected, @@ -6431,6 +6507,12 @@ mod tests { get_cast_values::(&u32_array, &DataType::Float32) ); + let f16_expected = vec!["0.0", "255.0", "inf", "inf"]; + assert_eq!( + f16_expected, + get_cast_values::(&u32_array, &DataType::Float16) + ); + let i64_expected = vec!["0", "255", "65535", "4294967295"]; assert_eq!( i64_expected, @@ -6497,6 +6579,12 @@ mod tests { get_cast_values::(&u16_array, &DataType::Float32) ); + let f16_expected = vec!["0.0", "255.0", "inf"]; + assert_eq!( + f16_expected, + get_cast_values::(&u16_array, &DataType::Float16) + ); + let i64_expected = vec!["0", "255", "65535"]; assert_eq!( i64_expected, @@ -6563,6 +6651,12 @@ mod tests { get_cast_values::(&u8_array, &DataType::Float32) ); + let f16_expected = vec!["0.0", "255.0"]; + assert_eq!( + f16_expected, + get_cast_values::(&u8_array, &DataType::Float16) + ); + let i64_expected = vec!["0", "255"]; assert_eq!( i64_expected, @@ -6665,6 +6759,25 @@ mod tests { .collect::>() ); + let f16_expected = vec![ + f16::from_f64(-9223372000000000000.0), + f16::from_f64(-2147483600.0), + f16::from_f64(-32768.0), + f16::from_f64(-128.0), + f16::from_f64(0.0), + f16::from_f64(127.0), + f16::from_f64(32767.0), + f16::from_f64(2147483600.0), + f16::from_f64(9223372000000000000.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&i64_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + let i64_expected = vec![ "-9223372036854775808", "-2147483648", @@ -6808,6 +6921,23 @@ mod tests { get_cast_values::(&i32_array, &DataType::Float32) ); + let f16_expected = vec![ + f16::from_f64(-2147483600.0), + f16::from_f64(-32768.0), + f16::from_f64(-128.0), + f16::from_f64(0.0), + f16::from_f64(127.0), + f16::from_f64(32767.0), + f16::from_f64(2147483600.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&i32_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + let i16_expected = vec!["null", "-32768", "-128", "0", "127", "32767", "null"]; assert_eq!( i16_expected, @@ -6877,6 +7007,21 @@ mod tests { get_cast_values::(&i16_array, &DataType::Float32) ); + let f16_expected = vec![ + f16::from_f64(-32768.0), + f16::from_f64(-128.0), + f16::from_f64(0.0), + f16::from_f64(127.0), + f16::from_f64(32767.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&i16_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + let i64_expected = vec!["-32768", "-128", "0", "127", "32767"]; assert_eq!( i64_expected, @@ -6971,6 +7116,12 @@ mod tests { get_cast_values::(&i8_array, &DataType::Float32) ); + let f16_expected = vec!["-128.0", "0.0", "127.0"]; + assert_eq!( + f16_expected, + get_cast_values::(&i8_array, &DataType::Float16) + ); + let i64_expected = vec!["-128", "0", "127"]; assert_eq!( i64_expected,