Skip to content

Commit

Permalink
refactor: simplify cast_string_to_interval
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwener committed Dec 9, 2023
1 parent 9630aaf commit 0d01612
Showing 1 changed file with 33 additions and 60 deletions.
93 changes: 33 additions & 60 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2444,103 +2444,76 @@ fn cast_string_to_timestamp_impl<O: OffsetSizeTrait, T: ArrowTimestampType, Tz:
}
}

fn cast_string_to_year_month_interval<Offset: OffsetSizeTrait>(
fn cast_string_to_interval<Offset, F, ArrowType>(
array: &dyn Array,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
parse_function: F,
) -> Result<ArrayRef, ArrowError>
where
Offset: OffsetSizeTrait,
ArrowType: ArrowPrimitiveType,
F: Fn(&str) -> Result<ArrowType::Native, ArrowError> + Copy,
{
let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<Offset>>()
.unwrap();
let interval_array = if cast_options.safe {
let iter = string_array
.iter()
.map(|v| v.and_then(|v| parse_interval_year_month(v).ok()));
.map(|v| v.and_then(|v| parse_function(v).ok()));

// Benefit:
// 20% performance improvement
// Soundness:
// The iterator is trustedLen because it comes from an `StringArray`.
unsafe { IntervalYearMonthArray::from_trusted_len_iter(iter) }
unsafe { PrimitiveArray::<ArrowType>::from_trusted_len_iter(iter) }
} else {
let vec = string_array
.iter()
.map(|v| v.map(parse_interval_year_month).transpose())
.map(|v| v.map(parse_function).transpose())
.collect::<Result<Vec<_>, ArrowError>>()?;

// Benefit:
// 20% performance improvement
// Soundness:
// The iterator is trustedLen because it comes from an `StringArray`.
unsafe { IntervalYearMonthArray::from_trusted_len_iter(vec) }
unsafe { PrimitiveArray::<ArrowType>::from_trusted_len_iter(vec) }
};
Ok(Arc::new(interval_array) as ArrayRef)
}

fn cast_string_to_day_time_interval<Offset: OffsetSizeTrait>(
fn cast_string_to_year_month_interval<Offset: OffsetSizeTrait>(
array: &dyn Array,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<Offset>>()
.unwrap();
let interval_array = if cast_options.safe {
let iter = string_array
.iter()
.map(|v| v.and_then(|v| parse_interval_day_time(v).ok()));

// Benefit:
// 20% performance improvement
// Soundness:
// The iterator is trustedLen because it comes from an `StringArray`.
unsafe { IntervalDayTimeArray::from_trusted_len_iter(iter) }
} else {
let vec = string_array
.iter()
.map(|v| v.map(parse_interval_day_time).transpose())
.collect::<Result<Vec<_>, ArrowError>>()?;
cast_string_to_interval::<Offset, _, IntervalYearMonthType>(
array,
cast_options,
parse_interval_year_month,
)
}

// Benefit:
// 20% performance improvement
// Soundness:
// The iterator is trustedLen because it comes from an `StringArray`.
unsafe { IntervalDayTimeArray::from_trusted_len_iter(vec) }
};
Ok(Arc::new(interval_array) as ArrayRef)
fn cast_string_to_day_time_interval<Offset: OffsetSizeTrait>(
array: &dyn Array,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
cast_string_to_interval::<Offset, _, IntervalDayTimeType>(
array,
cast_options,
parse_interval_day_time,
)
}

fn cast_string_to_month_day_nano_interval<Offset: OffsetSizeTrait>(
array: &dyn Array,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<Offset>>()
.unwrap();
let interval_array = if cast_options.safe {
let iter = string_array
.iter()
.map(|v| v.and_then(|v| parse_interval_month_day_nano(v).ok()));

// Benefit:
// 20% performance improvement
// Soundness:
// The iterator is trustedLen because it comes from an `StringArray`.
unsafe { IntervalMonthDayNanoArray::from_trusted_len_iter(iter) }
} else {
let vec = string_array
.iter()
.map(|v| v.map(parse_interval_month_day_nano).transpose())
.collect::<Result<Vec<_>, ArrowError>>()?;

// Benefit:
// 20% performance improvement
// Soundness:
// The iterator is trustedLen because it comes from an `StringArray`.
unsafe { IntervalMonthDayNanoArray::from_trusted_len_iter(vec) }
};
Ok(Arc::new(interval_array) as ArrayRef)
cast_string_to_interval::<Offset, _, IntervalMonthDayNanoType>(
array,
cast_options,
parse_interval_month_day_nano,
)
}

fn adjust_timestamp_to_timezone<T: ArrowTimestampType>(
Expand Down

0 comments on commit 0d01612

Please sign in to comment.