diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index 25c712d20f08..05f3f3d60b39 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -152,28 +152,7 @@ pub fn regexp_is_match_utf8_scalar( Ok(BooleanArray::from(data)) } -/// Extract all groups matched by a regular expression for a given String array. -/// -/// Modelled after the Postgres [regexp_match]. -/// -/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first -/// match of the corresponding index in `regex_array` to string in `array` -/// -/// If there is no match, the list element is NULL. -/// -/// If a match is found, and the pattern contains no capturing parenthesized subexpressions, -/// then the list element is a single-element [`GenericStringArray`] containing the substring -/// matching the whole pattern. -/// -/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the -/// list element is a [`GenericStringArray`] whose n'th element is the substring matching -/// the n'th capturing parenthesized subexpression of the pattern. -/// -/// The flags parameter is an optional text string containing zero or more single-letter flags -/// that change the function's behavior. -/// -/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP -pub fn regexp_match( +fn regexp_array_match( array: &GenericStringArray, regex_array: &GenericStringArray, flags_array: Option<&GenericStringArray>, @@ -248,6 +227,214 @@ pub fn regexp_match( Ok(Arc::new(list_builder.finish())) } +fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>( + regex_array: &'a dyn Array, + flag_array: Option<&'a dyn Array>, +) -> (&'a str, Option<&'a str>) { + let regex = regex_array + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + let regex = regex.value(0); + + if flag_array.is_some() { + let flag = flag_array + .unwrap() + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + + if flag.is_valid(0) { + let flag = flag.value(0); + (regex, Some(flag)) + } else { + (regex, None) + } + } else { + (regex, None) + } +} + +fn regexp_scalar_match( + array: &dyn Array, + regex: Option<&Regex>, +) -> std::result::Result { + if regex.is_none() {} + + let builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); + let mut list_builder = ListBuilder::new(builder); + + let array = array + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + + let regex = regex.unwrap(); + + array + .iter() + .map(|value| { + match value { + // Required for Postgres compatibility: + // SELECT regexp_match('foobarbequebaz', ''); = {""} + Some(_) if regex.as_str() == "" => { + list_builder.values().append_value(""); + list_builder.append(true); + } + Some(value) => match regex.captures(value) { + Some(caps) => { + let mut iter = caps.iter(); + if caps.len() > 1 { + iter.next(); + } + for m in iter.flatten() { + list_builder.values().append_value(m.as_str()); + } + + list_builder.append(true); + } + None => list_builder.append(false), + }, + _ => list_builder.append(false), + } + Ok(()) + }) + .collect::, ArrowError>>()?; + + Ok(Arc::new(list_builder.finish())) +} + +/// Extract all groups matched by a regular expression for a given String array. +/// +/// Modelled after the Postgres [regexp_match]. +/// +/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first +/// match of the corresponding index in `regex_array` to string in `array` +/// +/// If there is no match, the list element is NULL. +/// +/// If a match is found, and the pattern contains no capturing parenthesized subexpressions, +/// then the list element is a single-element [`GenericStringArray`] containing the substring +/// matching the whole pattern. +/// +/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the +/// list element is a [`GenericStringArray`] whose n'th element is the substring matching +/// the n'th capturing parenthesized subexpression of the pattern. +/// +/// The flags parameter is an optional text string containing zero or more single-letter flags +/// that change the function's behavior. +/// +/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP +pub fn regexp_match( + array: &dyn Datum, + regex_array: &dyn Datum, + flags_array: Option<&dyn Datum>, +) -> std::result::Result { + let (lhs, is_lhs_scalar) = array.get(); + let (rhs, is_rhs_scalar) = regex_array.get(); + + let (flags, is_flags_scalar) = match flags_array { + Some(flags) => { + let (flags, is_flags_scalar) = flags.get(); + (Some(flags), Some(is_flags_scalar)) + } + None => (None, None), + }; + + if is_lhs_scalar { + return Err(ArrowError::ComputeError(format!( + "regexp_match() requires array to be either Utf8 or LargeUtf8 array instead of scalar" + ))); + } + + if is_flags_scalar.is_some() && is_rhs_scalar != is_flags_scalar.unwrap() { + return Err(ArrowError::ComputeError(format!( + "regexp_match() requires both pattern and flags to be either scalar or array" + ))); + } + + if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() { + return Err(ArrowError::ComputeError(format!( + "regexp_match() requires both pattern and flags to be either string or largestring" + ))); + } + + if is_rhs_scalar { + // Regex and flag is scalars + let (regex, flag) = match rhs.data_type() { + DataType::Utf8 => get_scalar_pattern_flag::(rhs, flags), + DataType::LargeUtf8 => get_scalar_pattern_flag::(rhs, flags), + _ => { + return Err(ArrowError::ComputeError(format!( + "regexp_match() requires pattern to be either Utf8 or LargeUtf8" + ))); + } + }; + + let pattern = if let Some(flag) = flag { + format!("(?{regex}){flag}") + } else { + regex.to_string() + }; + + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}")) + })?; + + match lhs.data_type() { + DataType::Utf8 => regexp_scalar_match::(lhs, Some(&re)), + DataType::LargeUtf8 => regexp_scalar_match::(lhs, Some(&re)), + _ => { + return Err(ArrowError::ComputeError(format!( + "regexp_match() requires array to be either Utf8 or LargeUtf8" + ))); + } + } + } else { + match rhs.data_type() { + DataType::Utf8 => { + let array = lhs + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + let regex_array = rhs + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + let flags_array = flags.map(|flags| { + flags + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray") + }); + regexp_array_match(array, regex_array, flags_array) + } + DataType::LargeUtf8 => { + let array = lhs + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + let regex_array = rhs + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray"); + let flags_array = flags.map(|flags| { + flags + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to StringArray/LargeStringArray") + }); + regexp_array_match(array, regex_array, flags_array) + } + _ => { + return Err(ArrowError::ComputeError(format!( + "regexp_match() requires pattern to be either Utf8 or LargeUtf8" + ))); + } + } + } +} + #[cfg(test)] mod tests { use super::*;