Skip to content

Commit

Permalink
Support casting of Float16 with other numeric types (#5139)
Browse files (browse the repository at this point in the history)
* Support casting of Float16 with other numeric types

* Add Float16 test cases
  • Loading branch information
viirya authored Nov 28, 2023
1 parent 093a10e commit c161456
Showing 1 changed file with 155 additions and 4 deletions.
159 changes: 155 additions & 4 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {

// start numeric casts
(
- UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
- UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
+ UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 | Float64,
+ UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 | Float64,
) => true,
// end numeric casts

Expand All @@ -220,8 +220,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Time64(_), Time32(to_unit)) => {
matches!(to_unit, Second | Millisecond)
}
- (Timestamp(_, _), _) if to_type.is_numeric() && to_type != &Float16 => true,
- (_, Timestamp(_, _)) if from_type.is_numeric() && from_type != &Float16 => true,
+ (Timestamp(_, _), _) if to_type.is_numeric() => true,
+ (_, Timestamp(_, _)) if from_type.is_numeric() => true,
(Date64, Timestamp(_, None)) => true,
(Date32, Timestamp(_, None)) => true,
(
Expand Down Expand Up @@ -1367,6 +1367,7 @@ pub fn cast_with_options(
(UInt8, Int16) => cast_numeric_arrays::<UInt8Type, Int16Type>(array, cast_options),
(UInt8, Int32) => cast_numeric_arrays::<UInt8Type, Int32Type>(array, cast_options),
(UInt8, Int64) => cast_numeric_arrays::<UInt8Type, Int64Type>(array, cast_options),
(UInt8, Float16) => cast_numeric_arrays::<UInt8Type, Float16Type>(array, cast_options),
(UInt8, Float32) => cast_numeric_arrays::<UInt8Type, Float32Type>(array, cast_options),
(UInt8, Float64) => cast_numeric_arrays::<UInt8Type, Float64Type>(array, cast_options),

Expand All @@ -1377,6 +1378,7 @@ pub fn cast_with_options(
(UInt16, Int16) => cast_numeric_arrays::<UInt16Type, Int16Type>(array, cast_options),
(UInt16, Int32) => cast_numeric_arrays::<UInt16Type, Int32Type>(array, cast_options),
(UInt16, Int64) => cast_numeric_arrays::<UInt16Type, Int64Type>(array, cast_options),
(UInt16, Float16) => cast_numeric_arrays::<UInt16Type, Float16Type>(array, cast_options),
(UInt16, Float32) => cast_numeric_arrays::<UInt16Type, Float32Type>(array, cast_options),
(UInt16, Float64) => cast_numeric_arrays::<UInt16Type, Float64Type>(array, cast_options),

Expand All @@ -1387,6 +1389,7 @@ pub fn cast_with_options(
(UInt32, Int16) => cast_numeric_arrays::<UInt32Type, Int16Type>(array, cast_options),
(UInt32, Int32) => cast_numeric_arrays::<UInt32Type, Int32Type>(array, cast_options),
(UInt32, Int64) => cast_numeric_arrays::<UInt32Type, Int64Type>(array, cast_options),
(UInt32, Float16) => cast_numeric_arrays::<UInt32Type, Float16Type>(array, cast_options),
(UInt32, Float32) => cast_numeric_arrays::<UInt32Type, Float32Type>(array, cast_options),
(UInt32, Float64) => cast_numeric_arrays::<UInt32Type, Float64Type>(array, cast_options),

Expand All @@ -1397,6 +1400,7 @@ pub fn cast_with_options(
(UInt64, Int16) => cast_numeric_arrays::<UInt64Type, Int16Type>(array, cast_options),
(UInt64, Int32) => cast_numeric_arrays::<UInt64Type, Int32Type>(array, cast_options),
(UInt64, Int64) => cast_numeric_arrays::<UInt64Type, Int64Type>(array, cast_options),
(UInt64, Float16) => cast_numeric_arrays::<UInt64Type, Float16Type>(array, cast_options),
(UInt64, Float32) => cast_numeric_arrays::<UInt64Type, Float32Type>(array, cast_options),
(UInt64, Float64) => cast_numeric_arrays::<UInt64Type, Float64Type>(array, cast_options),

Expand All @@ -1407,6 +1411,7 @@ pub fn cast_with_options(
(Int8, Int16) => cast_numeric_arrays::<Int8Type, Int16Type>(array, cast_options),
(Int8, Int32) => cast_numeric_arrays::<Int8Type, Int32Type>(array, cast_options),
(Int8, Int64) => cast_numeric_arrays::<Int8Type, Int64Type>(array, cast_options),
(Int8, Float16) => cast_numeric_arrays::<Int8Type, Float16Type>(array, cast_options),
(Int8, Float32) => cast_numeric_arrays::<Int8Type, Float32Type>(array, cast_options),
(Int8, Float64) => cast_numeric_arrays::<Int8Type, Float64Type>(array, cast_options),

Expand All @@ -1417,6 +1422,7 @@ pub fn cast_with_options(
(Int16, Int8) => cast_numeric_arrays::<Int16Type, Int8Type>(array, cast_options),
(Int16, Int32) => cast_numeric_arrays::<Int16Type, Int32Type>(array, cast_options),
(Int16, Int64) => cast_numeric_arrays::<Int16Type, Int64Type>(array, cast_options),
(Int16, Float16) => cast_numeric_arrays::<Int16Type, Float16Type>(array, cast_options),
(Int16, Float32) => cast_numeric_arrays::<Int16Type, Float32Type>(array, cast_options),
(Int16, Float64) => cast_numeric_arrays::<Int16Type, Float64Type>(array, cast_options),

Expand All @@ -1427,6 +1433,7 @@ pub fn cast_with_options(
(Int32, Int8) => cast_numeric_arrays::<Int32Type, Int8Type>(array, cast_options),
(Int32, Int16) => cast_numeric_arrays::<Int32Type, Int16Type>(array, cast_options),
(Int32, Int64) => cast_numeric_arrays::<Int32Type, Int64Type>(array, cast_options),
(Int32, Float16) => cast_numeric_arrays::<Int32Type, Float16Type>(array, cast_options),
(Int32, Float32) => cast_numeric_arrays::<Int32Type, Float32Type>(array, cast_options),
(Int32, Float64) => cast_numeric_arrays::<Int32Type, Float64Type>(array, cast_options),

Expand All @@ -1437,9 +1444,21 @@ pub fn cast_with_options(
(Int64, Int8) => cast_numeric_arrays::<Int64Type, Int8Type>(array, cast_options),
(Int64, Int16) => cast_numeric_arrays::<Int64Type, Int16Type>(array, cast_options),
(Int64, Int32) => cast_numeric_arrays::<Int64Type, Int32Type>(array, cast_options),
(Int64, Float16) => cast_numeric_arrays::<Int64Type, Float16Type>(array, cast_options),
(Int64, Float32) => cast_numeric_arrays::<Int64Type, Float32Type>(array, cast_options),
(Int64, Float64) => cast_numeric_arrays::<Int64Type, Float64Type>(array, cast_options),

(Float16, UInt8) => cast_numeric_arrays::<Float16Type, UInt8Type>(array, cast_options),
(Float16, UInt16) => cast_numeric_arrays::<Float16Type, UInt16Type>(array, cast_options),
(Float16, UInt32) => cast_numeric_arrays::<Float16Type, UInt32Type>(array, cast_options),
(Float16, UInt64) => cast_numeric_arrays::<Float16Type, UInt64Type>(array, cast_options),
(Float16, Int8) => cast_numeric_arrays::<Float16Type, Int8Type>(array, cast_options),
(Float16, Int16) => cast_numeric_arrays::<Float16Type, Int16Type>(array, cast_options),
(Float16, Int32) => cast_numeric_arrays::<Float16Type, Int32Type>(array, cast_options),
(Float16, Int64) => cast_numeric_arrays::<Float16Type, Int64Type>(array, cast_options),
(Float16, Float32) => cast_numeric_arrays::<Float16Type, Float32Type>(array, cast_options),
(Float16, Float64) => cast_numeric_arrays::<Float16Type, Float64Type>(array, cast_options),

(Float32, UInt8) => cast_numeric_arrays::<Float32Type, UInt8Type>(array, cast_options),
(Float32, UInt16) => cast_numeric_arrays::<Float32Type, UInt16Type>(array, cast_options),
(Float32, UInt32) => cast_numeric_arrays::<Float32Type, UInt32Type>(array, cast_options),
Expand All @@ -1448,6 +1467,7 @@ pub fn cast_with_options(
(Float32, Int16) => cast_numeric_arrays::<Float32Type, Int16Type>(array, cast_options),
(Float32, Int32) => cast_numeric_arrays::<Float32Type, Int32Type>(array, cast_options),
(Float32, Int64) => cast_numeric_arrays::<Float32Type, Int64Type>(array, cast_options),
(Float32, Float16) => cast_numeric_arrays::<Float32Type, Float16Type>(array, cast_options),
(Float32, Float64) => cast_numeric_arrays::<Float32Type, Float64Type>(array, cast_options),

(Float64, UInt8) => cast_numeric_arrays::<Float64Type, UInt8Type>(array, cast_options),
Expand All @@ -1458,6 +1478,7 @@ pub fn cast_with_options(
(Float64, Int16) => cast_numeric_arrays::<Float64Type, Int16Type>(array, cast_options),
(Float64, Int32) => cast_numeric_arrays::<Float64Type, Int32Type>(array, cast_options),
(Float64, Int64) => cast_numeric_arrays::<Float64Type, Int64Type>(array, cast_options),
(Float64, Float16) => cast_numeric_arrays::<Float64Type, Float16Type>(array, cast_options),
(Float64, Float32) => cast_numeric_arrays::<Float64Type, Float32Type>(array, cast_options),
// end numeric casts

Expand Down Expand Up @@ -3299,6 +3320,7 @@ fn cast_list<I: OffsetSizeTrait, O: OffsetSizeTrait>(
#[cfg(test)]
mod tests {
use arrow_buffer::{Buffer, NullBuffer};
use half::f16;

use super::*;

Expand Down Expand Up @@ -4665,6 +4687,15 @@ mod tests {
let array = Int64Array::from(vec![Some(2), Some(10), None]);
let expected = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap();

let array = Float16Array::from(vec![
Some(f16::from_f32(2.0)),
Some(f16::from_f32(10.6)),
None,
]);
let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap();

assert_eq!(&actual, &expected);

let array = Float32Array::from(vec![Some(2.0), Some(10.6), None]);
let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap();

Expand All @@ -4682,6 +4713,9 @@ mod tests {
.with_timezone("UTC".to_string());
let expected = cast(&array, &DataType::Int64).unwrap();

let actual = cast(&cast(&array, &DataType::Float16).unwrap(), &DataType::Int64).unwrap();
assert_eq!(&actual, &expected);

let actual = cast(&cast(&array, &DataType::Float32).unwrap(), &DataType::Int64).unwrap();
assert_eq!(&actual, &expected);

Expand Down Expand Up @@ -6103,6 +6137,25 @@ mod tests {
.collect::<Vec<f32>>()
);

let f16_expected = vec![
f16::from_f64(-9223372000000000000.0),
f16::from_f64(-2147483600.0),
f16::from_f64(-32768.0),
f16::from_f64(-128.0),
f16::from_f64(0.0),
f16::from_f64(255.0),
f16::from_f64(65535.0),
f16::from_f64(4294967300.0),
f16::from_f64(18446744000000000000.0),
];
assert_eq!(
f16_expected,
get_cast_values::<Float16Type>(&f64_array, &DataType::Float16)
.iter()
.map(|i| i.parse::<f16>().unwrap())
.collect::<Vec<f16>>()
);

let i64_expected = vec![
"-9223372036854775808",
"-2147483648",
Expand Down Expand Up @@ -6247,6 +6300,14 @@ mod tests {
get_cast_values::<Float32Type>(&f32_array, &DataType::Float32)
);

let f16_expected = vec![
"-inf", "-inf", "-32768.0", "-128.0", "0.0", "255.0", "inf", "inf", "inf",
];
assert_eq!(
f16_expected,
get_cast_values::<Float16Type>(&f32_array, &DataType::Float16)
);

let i64_expected = vec![
"-2147483648",
"-2147483648",
Expand Down Expand Up @@ -6365,6 +6426,21 @@ mod tests {
.collect::<Vec<f32>>()
);

let f16_expected = vec![
f16::from_f64(0.0),
f16::from_f64(255.0),
f16::from_f64(65535.0),
f16::from_f64(4294967300.0),
f16::from_f64(18446744000000000000.0),
];
assert_eq!(
f16_expected,
get_cast_values::<Float16Type>(&u64_array, &DataType::Float16)
.iter()
.map(|i| i.parse::<f16>().unwrap())
.collect::<Vec<f16>>()
);

let i64_expected = vec!["0", "255", "65535", "4294967295", "null"];
assert_eq!(
i64_expected,
Expand Down Expand Up @@ -6431,6 +6507,12 @@ mod tests {
get_cast_values::<Float32Type>(&u32_array, &DataType::Float32)
);

let f16_expected = vec!["0.0", "255.0", "inf", "inf"];
assert_eq!(
f16_expected,
get_cast_values::<Float16Type>(&u32_array, &DataType::Float16)
);

let i64_expected = vec!["0", "255", "65535", "4294967295"];
assert_eq!(
i64_expected,
Expand Down Expand Up @@ -6497,6 +6579,12 @@ mod tests {
get_cast_values::<Float32Type>(&u16_array, &DataType::Float32)
);

let f16_expected = vec!["0.0", "255.0", "inf"];
assert_eq!(
f16_expected,
get_cast_values::<Float16Type>(&u16_array, &DataType::Float16)
);

let i64_expected = vec!["0", "255", "65535"];
assert_eq!(
i64_expected,
Expand Down Expand Up @@ -6563,6 +6651,12 @@ mod tests {
get_cast_values::<Float32Type>(&u8_array, &DataType::Float32)
);

let f16_expected = vec!["0.0", "255.0"];
assert_eq!(
f16_expected,
get_cast_values::<Float16Type>(&u8_array, &DataType::Float16)
);

let i64_expected = vec!["0", "255"];
assert_eq!(
i64_expected,
Expand Down Expand Up @@ -6665,6 +6759,25 @@ mod tests {
.collect::<Vec<f32>>()
);

let f16_expected = vec![
f16::from_f64(-9223372000000000000.0),
f16::from_f64(-2147483600.0),
f16::from_f64(-32768.0),
f16::from_f64(-128.0),
f16::from_f64(0.0),
f16::from_f64(127.0),
f16::from_f64(32767.0),
f16::from_f64(2147483600.0),
f16::from_f64(9223372000000000000.0),
];
assert_eq!(
f16_expected,
get_cast_values::<Float16Type>(&i64_array, &DataType::Float16)
.iter()
.map(|i| i.parse::<f16>().unwrap())
.collect::<Vec<f16>>()
);

let i64_expected = vec![
"-9223372036854775808",
"-2147483648",
Expand Down Expand Up @@ -6808,6 +6921,23 @@ mod tests {
get_cast_values::<Float32Type>(&i32_array, &DataType::Float32)
);

let f16_expected = vec![
f16::from_f64(-2147483600.0),
f16::from_f64(-32768.0),
f16::from_f64(-128.0),
f16::from_f64(0.0),
f16::from_f64(127.0),
f16::from_f64(32767.0),
f16::from_f64(2147483600.0),
];
assert_eq!(
f16_expected,
get_cast_values::<Float16Type>(&i32_array, &DataType::Float16)
.iter()
.map(|i| i.parse::<f16>().unwrap())
.collect::<Vec<f16>>()
);

let i16_expected = vec!["null", "-32768", "-128", "0", "127", "32767", "null"];
assert_eq!(
i16_expected,
Expand Down Expand Up @@ -6877,6 +7007,21 @@ mod tests {
get_cast_values::<Float32Type>(&i16_array, &DataType::Float32)
);

let f16_expected = vec![
f16::from_f64(-32768.0),
f16::from_f64(-128.0),
f16::from_f64(0.0),
f16::from_f64(127.0),
f16::from_f64(32767.0),
];
assert_eq!(
f16_expected,
get_cast_values::<Float16Type>(&i16_array, &DataType::Float16)
.iter()
.map(|i| i.parse::<f16>().unwrap())
.collect::<Vec<f16>>()
);

let i64_expected = vec!["-32768", "-128", "0", "127", "32767"];
assert_eq!(
i64_expected,
Expand Down Expand Up @@ -6971,6 +7116,12 @@ mod tests {
get_cast_values::<Float32Type>(&i8_array, &DataType::Float32)
);

let f16_expected = vec!["-128.0", "0.0", "127.0"];
assert_eq!(
f16_expected,
get_cast_values::<Float16Type>(&i8_array, &DataType::Float16)
);

let i64_expected = vec!["-128", "0", "127"];
assert_eq!(
i64_expected,
Expand Down

0 comments on commit c161456

Please sign in to comment.