From 0d01612488a0bca8c941467239af6ebf083da709 Mon Sep 17 00:00:00 2001 From: jackwener Date: Sun, 3 Sep 2023 23:27:34 +0800 Subject: [PATCH] refactor: simplify cast_string_to_interval --- arrow-cast/src/cast.rs | 93 +++++++++++++++--------------------------- 1 file changed, 33 insertions(+), 60 deletions(-) diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 7f8bd19e9291..a75354cf9b35 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -2444,10 +2444,16 @@ fn cast_string_to_timestamp_impl( +fn cast_string_to_interval( array: &dyn Array, cast_options: &CastOptions, -) -> Result { + parse_function: F, +) -> Result +where + Offset: OffsetSizeTrait, + ArrowType: ArrowPrimitiveType, + F: Fn(&str) -> Result + Copy, +{ let string_array = array .as_any() .downcast_ref::>() @@ -2455,92 +2461,59 @@ fn cast_string_to_year_month_interval( let interval_array = if cast_options.safe { let iter = string_array .iter() - .map(|v| v.and_then(|v| parse_interval_year_month(v).ok())); + .map(|v| v.and_then(|v| parse_function(v).ok())); // Benefit: // 20% performance improvement // Soundness: // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { IntervalYearMonthArray::from_trusted_len_iter(iter) } + unsafe { PrimitiveArray::::from_trusted_len_iter(iter) } } else { let vec = string_array .iter() - .map(|v| v.map(parse_interval_year_month).transpose()) + .map(|v| v.map(parse_function).transpose()) .collect::, ArrowError>>()?; // Benefit: // 20% performance improvement // Soundness: // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { IntervalYearMonthArray::from_trusted_len_iter(vec) } + unsafe { PrimitiveArray::::from_trusted_len_iter(vec) } }; Ok(Arc::new(interval_array) as ArrayRef) } -fn cast_string_to_day_time_interval( +fn cast_string_to_year_month_interval( array: &dyn Array, cast_options: &CastOptions, ) -> Result { - let string_array = array - .as_any() - .downcast_ref::>() - .unwrap(); - let interval_array = if cast_options.safe { - let iter = string_array - .iter() - .map(|v| v.and_then(|v| parse_interval_day_time(v).ok())); - - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { IntervalDayTimeArray::from_trusted_len_iter(iter) } - } else { - let vec = string_array - .iter() - .map(|v| v.map(parse_interval_day_time).transpose()) - .collect::, ArrowError>>()?; + cast_string_to_interval::( + array, + cast_options, + parse_interval_year_month, + ) +} - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { IntervalDayTimeArray::from_trusted_len_iter(vec) } - }; - Ok(Arc::new(interval_array) as ArrayRef) +fn cast_string_to_day_time_interval( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + cast_string_to_interval::( + array, + cast_options, + parse_interval_day_time, + ) } fn cast_string_to_month_day_nano_interval( array: &dyn Array, cast_options: &CastOptions, ) -> Result { - let string_array = array - .as_any() - .downcast_ref::>() - .unwrap(); - let interval_array = if cast_options.safe { - let iter = string_array - .iter() - .map(|v| v.and_then(|v| parse_interval_month_day_nano(v).ok())); - - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { IntervalMonthDayNanoArray::from_trusted_len_iter(iter) } - } else { - let vec = string_array - .iter() - .map(|v| v.map(parse_interval_month_day_nano).transpose()) - .collect::, ArrowError>>()?; - - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { IntervalMonthDayNanoArray::from_trusted_len_iter(vec) } - }; - Ok(Arc::new(interval_array) as ArrayRef) + cast_string_to_interval::( + array, + cast_options, + parse_interval_month_day_nano, + ) } fn adjust_timestamp_to_timezone(