diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index b751c81ee440..e3fad3da19f8 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -271,8 +271,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { | Time64(Microsecond) | Time64(Nanosecond), ) => true, - (Int64, Duration(_)) => true, - (Duration(_), Int64) => true, + (_, Duration(_)) if from_type.is_numeric() => true, + (Duration(_), _) if to_type.is_numeric() => true, + (Duration(_), Duration(_)) => true, (Interval(from_type), Int64) => { match from_type { YearMonth => true, @@ -518,6 +519,15 @@ fn make_timestamp_array( } } +fn make_duration_array(array: &PrimitiveArray, unit: TimeUnit) -> ArrayRef { + match unit { + TimeUnit::Second => Arc::new(array.reinterpret_cast::()), + TimeUnit::Millisecond => Arc::new(array.reinterpret_cast::()), + TimeUnit::Microsecond => Arc::new(array.reinterpret_cast::()), + TimeUnit::Nanosecond => Arc::new(array.reinterpret_cast::()), + } +} + fn as_time_res_with_timezone( v: i64, tz: Option, @@ -2074,31 +2084,53 @@ pub fn cast_with_options( .as_primitive::() .unary::<_, TimestampNanosecondType>(|x| (x as i64) * NANOSECONDS_IN_DAY), )), - (Int64, Duration(TimeUnit::Second)) => { - cast_reinterpret_arrays::(array) - } - (Int64, Duration(TimeUnit::Millisecond)) => { - cast_reinterpret_arrays::(array) - } - (Int64, Duration(TimeUnit::Microsecond)) => { - cast_reinterpret_arrays::(array) + + (_, Duration(unit)) if from_type.is_numeric() => { + let array = cast_with_options(array, &Int64, cast_options)?; + Ok(make_duration_array(array.as_primitive(), *unit)) } - (Int64, Duration(TimeUnit::Nanosecond)) => { - cast_reinterpret_arrays::(array) + (Duration(TimeUnit::Second), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) } - - (Duration(TimeUnit::Second), Int64) => { - cast_reinterpret_arrays::(array) + (Duration(TimeUnit::Millisecond), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) } - (Duration(TimeUnit::Millisecond), Int64) => { - cast_reinterpret_arrays::(array) + (Duration(TimeUnit::Microsecond), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) } - (Duration(TimeUnit::Microsecond), Int64) => { - cast_reinterpret_arrays::(array) + (Duration(TimeUnit::Nanosecond), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) } - (Duration(TimeUnit::Nanosecond), Int64) => { - cast_reinterpret_arrays::(array) + + (Duration(from_unit), Duration(to_unit)) => { + let array = cast_with_options(array, &Int64, cast_options)?; + let time_array = array.as_primitive::(); + let from_size = time_unit_multiple(from_unit); + let to_size = time_unit_multiple(to_unit); + // we either divide or multiply, depending on size of each unit + // units are never the same when the types are the same + let converted = match from_size.cmp(&to_size) { + Ordering::Greater => { + let divisor = from_size / to_size; + time_array.unary::<_, Int64Type>(|o| o / divisor) + } + Ordering::Equal => time_array.clone(), + Ordering::Less => { + let mul = to_size / from_size; + if cast_options.safe { + time_array.unary_opt::<_, Int64Type>(|o| o.checked_mul(mul)) + } else { + time_array.try_unary::<_, Int64Type, _>(|o| o.mul_checked(mul))? + } + } + }; + Ok(make_duration_array(&converted, *to_unit)) } + (Duration(TimeUnit::Second), Interval(IntervalUnit::MonthDayNano)) => { cast_duration_to_interval::(array, cast_options) } @@ -5254,6 +5286,106 @@ mod tests { } } + #[test] + fn test_cast_between_durations_and_numerics() { + fn test_cast_between_durations() + where + FromType: ArrowPrimitiveType, + ToType: ArrowPrimitiveType, + PrimitiveArray: From>>, + { + let from_unit = match FromType::DATA_TYPE { + DataType::Duration(unit) => unit, + _ => panic!("Expected a duration type"), + }; + let to_unit = match ToType::DATA_TYPE { + DataType::Duration(unit) => unit, + _ => panic!("Expected a duration type"), + }; + let from_size = time_unit_multiple(&from_unit); + let to_size = time_unit_multiple(&to_unit); + + let (v1_before, v2_before) = (8640003005, 1696002001); + let (v1_after, v2_after) = if from_size >= to_size { + ( + v1_before / (from_size / to_size), + v2_before / (from_size / to_size), + ) + } else { + ( + v1_before * (to_size / from_size), + v2_before * (to_size / from_size), + ) + }; + + let array = + PrimitiveArray::::from(vec![Some(v1_before), Some(v2_before), None]); + let b = cast(&array, &ToType::DATA_TYPE).unwrap(); + let c = b.as_primitive::(); + assert_eq!(v1_after, c.value(0)); + assert_eq!(v2_after, c.value(1)); + assert!(c.is_null(2)); + } + + // between each individual duration type + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + test_cast_between_durations::(); + + // cast failed + let array = DurationSecondArray::from(vec![ + Some(i64::MAX), + Some(8640203410378005), + Some(10241096), + None, + ]); + let b = cast(&array, &DataType::Duration(TimeUnit::Nanosecond)).unwrap(); + let c = b.as_primitive::(); + assert!(c.is_null(0)); + assert!(c.is_null(1)); + assert_eq!(10241096000000000, c.value(2)); + assert!(c.is_null(3)); + + // durations to numerics + let array = DurationSecondArray::from(vec![ + Some(i64::MAX), + Some(8640203410378005), + Some(10241096), + None, + ]); + let b = cast(&array, &DataType::Int64).unwrap(); + let c = b.as_primitive::(); + assert_eq!(i64::MAX, c.value(0)); + assert_eq!(8640203410378005, c.value(1)); + assert_eq!(10241096, c.value(2)); + assert!(c.is_null(3)); + + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(0, c.value(0)); + assert_eq!(0, c.value(1)); + assert_eq!(10241096, c.value(2)); + assert!(c.is_null(3)); + + // numerics to durations + let array = Int32Array::from(vec![Some(i32::MAX), Some(802034103), Some(10241096), None]); + let b = cast(&array, &DataType::Duration(TimeUnit::Second)).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(i32::MAX as i64, c.value(0)); + assert_eq!(802034103, c.value(1)); + assert_eq!(10241096, c.value(2)); + assert!(c.is_null(3)); + } + #[test] fn test_cast_to_strings() { let a = Int32Array::from(vec![1, 2, 3]);