diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index d14662be728..12a95aa78f1 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -20,7 +20,9 @@ use crate::like::StringArrayType; -use arrow_array::builder::{BooleanBufferBuilder, GenericStringBuilder, ListBuilder}; +use arrow_array::builder::{ + BooleanBufferBuilder, GenericStringBuilder, ListBuilder, StringViewBuilder, +}; use arrow_array::cast::AsArray; use arrow_array::*; use arrow_buffer::NullBuffer; @@ -243,78 +245,96 @@ where Ok(BooleanArray::from(data)) } -fn regexp_array_match( - array: &GenericStringArray, - regex_array: &GenericStringArray, - flags_array: Option<&GenericStringArray>, -) -> Result { - let mut patterns: HashMap = HashMap::new(); - let builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); - let mut list_builder = ListBuilder::new(builder); +macro_rules! process_regexp_array_match { + ($array:expr, $regex_array:expr, $flags_array:expr, $list_builder:expr) => { + let mut patterns: HashMap = HashMap::new(); - let complete_pattern = match flags_array { - Some(flags) => Box::new( - regex_array - .iter() - .zip(flags.iter()) - .map(|(pattern, flags)| { + let complete_pattern = match $flags_array { + Some(flags) => Box::new($regex_array.iter().zip(flags.iter()).map( + |(pattern, flags)| { pattern.map(|pattern| match flags { Some(value) => format!("(?{value}){pattern}"), None => pattern.to_string(), }) - }), - ) as Box>>, - None => Box::new( - regex_array - .iter() - .map(|pattern| pattern.map(|pattern| pattern.to_string())), - ), - }; + }, + )) as Box>>, + None => Box::new( + $regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; - array - .iter() - .zip(complete_pattern) - .map(|(value, pattern)| { - match (value, pattern) { - // Required for Postgres compatibility: - // SELECT regexp_match('foobarbequebaz', ''); = {""} - (Some(_), Some(pattern)) if pattern == *"" => { - list_builder.values().append_value(""); - list_builder.append(true); - } - (Some(value), Some(pattern)) => { - let existing_pattern = patterns.get(&pattern); - let re = match existing_pattern { - Some(re) => re, - None => { - let re = Regex::new(pattern.as_str()).map_err(|e| { - ArrowError::ComputeError(format!( - "Regular expression did not compile: {e:?}" - )) - })?; - patterns.entry(pattern).or_insert(re) - } - }; - match re.captures(value) { - Some(caps) => { - let mut iter = caps.iter(); - if caps.len() > 1 { - iter.next(); - } - for m in iter.flatten() { - list_builder.values().append_value(m.as_str()); + $array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + // Required for Postgres compatibility: + // SELECT regexp_match('foobarbequebaz', ''); = {""} + (Some(_), Some(pattern)) if pattern == *"" => { + $list_builder.values().append_value(""); + $list_builder.append(true); + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re, + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + patterns.entry(pattern).or_insert(re) } + }; + match re.captures(value) { + Some(caps) => { + let mut iter = caps.iter(); + if caps.len() > 1 { + iter.next(); + } + for m in iter.flatten() { + $list_builder.values().append_value(m.as_str()); + } - list_builder.append(true); + $list_builder.append(true); + } + None => $list_builder.append(false), } - None => list_builder.append(false), } + _ => $list_builder.append(false), } - _ => list_builder.append(false), - } - Ok(()) - }) - .collect::, ArrowError>>()?; + Ok(()) + }) + .collect::, ArrowError>>()?; + }; +} + +fn regexp_array_match( + array: &GenericStringArray, + regex_array: &GenericStringArray, + flags_array: Option<&GenericStringArray>, +) -> Result { + let builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); + let mut list_builder = ListBuilder::new(builder); + + process_regexp_array_match!(array, regex_array, flags_array, list_builder); + + Ok(Arc::new(list_builder.finish())) +} + +fn regexp_array_match_utf8view( + array: &StringViewArray, + regex_array: &StringViewArray, + flags_array: Option<&StringViewArray>, +) -> Result { + let builder = StringViewBuilder::with_capacity(0); + let mut list_builder = ListBuilder::new(builder); + + process_regexp_array_match!(array, regex_array, flags_array, list_builder); + Ok(Arc::new(list_builder.finish())) } @@ -333,6 +353,54 @@ fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>( } } +fn get_scalar_pattern_flag_utf8view<'a>( + regex_array: &'a dyn Array, + flag_array: Option<&'a dyn Array>, +) -> (Option<&'a str>, Option<&'a str>) { + let regex = regex_array.as_string_view(); + let regex = regex.is_valid(0).then(|| regex.value(0)); + + if let Some(flag_array) = flag_array { + let flag = flag_array.as_string_view(); + (regex, flag.is_valid(0).then(|| flag.value(0))) + } else { + (regex, None) + } +} + +macro_rules! process_regexp_match { + ($array:expr, $regex:expr, $list_builder:expr) => { + $array + .iter() + .map(|value| { + match value { + // Required for Postgres compatibility: + // SELECT regexp_match('foobarbequebaz', ''); = {""} + Some(_) if $regex.as_str().is_empty() => { + $list_builder.values().append_value(""); + $list_builder.append(true); + } + Some(value) => match $regex.captures(value) { + Some(caps) => { + let mut iter = caps.iter(); + if caps.len() > 1 { + iter.next(); + } + for m in iter.flatten() { + $list_builder.values().append_value(m.as_str()); + } + $list_builder.append(true); + } + None => $list_builder.append(false), + }, + None => $list_builder.append(false), + } + Ok(()) + }) + .collect::, ArrowError>>()? + }; +} + fn regexp_scalar_match( array: &GenericStringArray, regex: &Regex, @@ -340,35 +408,19 @@ fn regexp_scalar_match( let builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); let mut list_builder = ListBuilder::new(builder); - array - .iter() - .map(|value| { - match value { - // Required for Postgres compatibility: - // SELECT regexp_match('foobarbequebaz', ''); = {""} - Some(_) if regex.as_str() == "" => { - list_builder.values().append_value(""); - list_builder.append(true); - } - Some(value) => match regex.captures(value) { - Some(caps) => { - let mut iter = caps.iter(); - if caps.len() > 1 { - iter.next(); - } - for m in iter.flatten() { - list_builder.values().append_value(m.as_str()); - } + process_regexp_match!(array, regex, list_builder); - list_builder.append(true); - } - None => list_builder.append(false), - }, - _ => list_builder.append(false), - } - Ok(()) - }) - .collect::, ArrowError>>()?; + Ok(Arc::new(list_builder.finish())) +} + +fn regexp_scalar_match_utf8view( + array: &StringViewArray, + regex: &Regex, +) -> Result { + let builder = StringViewBuilder::with_capacity(0); + let mut list_builder = ListBuilder::new(builder); + + process_regexp_match!(array, regex, list_builder); Ok(Arc::new(list_builder.finish())) } @@ -406,7 +458,7 @@ pub fn regexp_match( if array.data_type() != rhs.data_type() { return Err(ArrowError::ComputeError( - "regexp_match() requires both array and pattern to be either Utf8 or LargeUtf8" + "regexp_match() requires both array and pattern to be either Utf8, Utf8View or LargeUtf8" .to_string(), )); } @@ -428,7 +480,7 @@ pub fn regexp_match( if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() { return Err(ArrowError::ComputeError( - "regexp_match() requires both pattern and flags to be either string or largestring" + "regexp_match() requires both pattern and flags to be either Utf8, Utf8View or LargeUtf8" .to_string(), )); } @@ -436,11 +488,13 @@ pub fn regexp_match( if is_rhs_scalar { // Regex and flag is scalars let (regex, flag) = match rhs.data_type() { + DataType::Utf8View => get_scalar_pattern_flag_utf8view(rhs, flags), DataType::Utf8 => get_scalar_pattern_flag::(rhs, flags), DataType::LargeUtf8 => get_scalar_pattern_flag::(rhs, flags), _ => { return Err(ArrowError::ComputeError( - "regexp_match() requires pattern to be either Utf8 or LargeUtf8".to_string(), + "regexp_match() requires pattern to be either Utf8, Utf8View or LargeUtf8" + .to_string(), )); } }; @@ -468,14 +522,21 @@ pub fn regexp_match( })?; match array.data_type() { + DataType::Utf8View => regexp_scalar_match_utf8view(array.as_string_view(), &re), DataType::Utf8 => regexp_scalar_match(array.as_string::(), &re), DataType::LargeUtf8 => regexp_scalar_match(array.as_string::(), &re), _ => Err(ArrowError::ComputeError( - "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(), + "regexp_match() requires array to be either Utf8, Utf8View or LargeUtf8" + .to_string(), )), } } else { match array.data_type() { + DataType::Utf8View => { + let regex_array = rhs.as_string_view(); + let flags_array = flags.map(|flags| flags.as_string_view()); + regexp_array_match_utf8view(array.as_string_view(), regex_array, flags_array) + } DataType::Utf8 => { let regex_array = rhs.as_string(); let flags_array = flags.map(|flags| flags.as_string()); @@ -487,7 +548,8 @@ pub fn regexp_match( regexp_array_match(array.as_string::(), regex_array, flags_array) } _ => Err(ArrowError::ComputeError( - "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(), + "regexp_match() requires array to be either Utf8, Utf8View or LargeUtf8" + .to_string(), )), } } @@ -497,114 +559,292 @@ pub fn regexp_match( mod tests { use super::*; - #[test] - fn match_single_group() { - let values = vec![ + macro_rules! test_match_single_group { + ($test_name:ident, $values:expr, $patterns:expr, $arr_type:ty, $builder_type:ty, $expected:expr) => { + #[test] + fn $test_name() { + let array: $arr_type = <$arr_type>::from($values); + let pattern: $arr_type = <$arr_type>::from($patterns); + + let actual = regexp_match(&array, &pattern, None).unwrap(); + + let elem_builder: $builder_type = <$builder_type>::new(); + let mut expected_builder = ListBuilder::new(elem_builder); + + for val in $expected { + match val { + Some(v) => { + expected_builder.values().append_value(v); + expected_builder.append(true); + } + None => expected_builder.append(false), + } + } + + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + } + }; + } + + test_match_single_group!( + match_single_group_string, + vec![ Some("abc-005-def"), Some("X-7-5"), Some("X545"), None, Some("foobarbequebaz"), Some("foobarbequebaz"), - ]; - let array = StringArray::from(values); - let mut pattern_values = vec![r".*-(\d*)-.*"; 4]; - pattern_values.push(r"(bar)(bequ1e)"); - pattern_values.push(""); - let pattern = GenericStringArray::::from(pattern_values); - let actual = regexp_match(&array, &pattern, None).unwrap(); - let elem_builder: GenericStringBuilder = GenericStringBuilder::new(); - let mut expected_builder = ListBuilder::new(elem_builder); - expected_builder.values().append_value("005"); - expected_builder.append(true); - expected_builder.values().append_value("7"); - expected_builder.append(true); - expected_builder.append(false); - expected_builder.append(false); - expected_builder.append(false); - expected_builder.values().append_value(""); - expected_builder.append(true); - let expected = expected_builder.finish(); - let result = actual.as_any().downcast_ref::().unwrap(); - assert_eq!(&expected, result); - } + ], + vec![ + r".*-(\d*)-.*", + r".*-(\d*)-.*", + r".*-(\d*)-.*", + r".*-(\d*)-.*", + r"(bar)(bequ1e)", + "" + ], + StringArray, + GenericStringBuilder, + [Some("005"), Some("7"), None, None, None, Some("")] + ); + test_match_single_group!( + match_single_group_string_view, + vec![ + Some("abc-005-def"), + Some("X-7-5"), + Some("X545"), + None, + Some("foobarbequebaz"), + Some("foobarbequebaz"), + ], + vec![ + r".*-(\d*)-.*", + r".*-(\d*)-.*", + r".*-(\d*)-.*", + r".*-(\d*)-.*", + r"(bar)(bequ1e)", + "" + ], + StringViewArray, + StringViewBuilder, + [Some("005"), Some("7"), None, None, None, Some("")] + ); + + macro_rules! test_match_single_group_with_flags { + ($test_name:ident, $values:expr, $patterns:expr, $flags:expr, $array_type:ty, $builder_type:ty, $expected:expr) => { + #[test] + fn $test_name() { + let array: $array_type = <$array_type>::from($values); + let pattern: $array_type = <$array_type>::from($patterns); + let flags: $array_type = <$array_type>::from($flags); + + let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap(); + + let elem_builder: $builder_type = <$builder_type>::new(); + let mut expected_builder = ListBuilder::new(elem_builder); + + for val in $expected { + match val { + Some(v) => { + expected_builder.values().append_value(v); + expected_builder.append(true); + } + None => { + expected_builder.append(false); + } + } + } - #[test] - fn match_single_group_with_flags() { - let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; - let array = StringArray::from(values); - let pattern = StringArray::from(vec![r"x.*-(\d*)-.*"; 4]); - let flags = StringArray::from(vec!["i"; 4]); - let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap(); - let elem_builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); - let mut expected_builder = ListBuilder::new(elem_builder); - expected_builder.append(false); - expected_builder.values().append_value("7"); - expected_builder.append(true); - expected_builder.append(false); - expected_builder.append(false); - let expected = expected_builder.finish(); - let result = actual.as_any().downcast_ref::().unwrap(); - assert_eq!(&expected, result); + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + } + }; } - #[test] - fn match_scalar_pattern() { - let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; - let array = StringArray::from(values); - let pattern = Scalar::new(StringArray::from(vec![r"x.*-(\d*)-.*"; 1])); - let flags = Scalar::new(StringArray::from(vec!["i"; 1])); - let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap(); - let elem_builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); - let mut expected_builder = ListBuilder::new(elem_builder); - expected_builder.append(false); - expected_builder.values().append_value("7"); - expected_builder.append(true); - expected_builder.append(false); - expected_builder.append(false); - let expected = expected_builder.finish(); - let result = actual.as_any().downcast_ref::().unwrap(); - assert_eq!(&expected, result); - - // No flag - let values = vec![Some("abc-005-def"), Some("x-7-5"), Some("X545"), None]; - let array = StringArray::from(values); - let actual = regexp_match(&array, &pattern, None).unwrap(); - let result = actual.as_any().downcast_ref::().unwrap(); - assert_eq!(&expected, result); + test_match_single_group_with_flags!( + match_single_group_with_flags_string, + vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None], + vec![r"x.*-(\d*)-.*"; 4], + vec!["i"; 4], + StringArray, + GenericStringBuilder, + [None, Some("7"), None, None] + ); + test_match_single_group_with_flags!( + match_single_group_with_flags_stringview, + vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None], + vec![r"x.*-(\d*)-.*"; 4], + vec!["i"; 4], + StringViewArray, + StringViewBuilder, + [None, Some("7"), None, None] + ); + + macro_rules! test_match_scalar_pattern { + ($test_name:ident, $values:expr, $pattern:expr, $flag:expr, $array_type:ty, $builder_type:ty, $expected:expr) => { + #[test] + fn $test_name() { + let array: $array_type = <$array_type>::from($values); + + let pattern_scalar = Scalar::new(<$array_type>::from(vec![$pattern; 1])); + let flag_scalar = Scalar::new(<$array_type>::from(vec![$flag; 1])); + + let actual = regexp_match(&array, &pattern_scalar, Some(&flag_scalar)).unwrap(); + + let elem_builder: $builder_type = <$builder_type>::new(); + let mut expected_builder = ListBuilder::new(elem_builder); + + for val in $expected { + match val { + Some(v) => { + expected_builder.values().append_value(v); + expected_builder.append(true); + } + None => expected_builder.append(false), + } + } + + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + } + }; } - #[test] - fn match_scalar_no_pattern() { - let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; - let array = StringArray::from(values); - let pattern = Scalar::new(new_null_array(&DataType::Utf8, 1)); - let actual = regexp_match(&array, &pattern, None).unwrap(); - let elem_builder: GenericStringBuilder = GenericStringBuilder::with_capacity(0, 0); - let mut expected_builder = ListBuilder::new(elem_builder); - expected_builder.append(false); - expected_builder.append(false); - expected_builder.append(false); - expected_builder.append(false); - let expected = expected_builder.finish(); - let result = actual.as_any().downcast_ref::().unwrap(); - assert_eq!(&expected, result); + test_match_scalar_pattern!( + match_scalar_pattern_string_with_flags, + vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None], + r"x.*-(\d*)-.*", + Some("i"), + StringArray, + GenericStringBuilder, + [None, Some("7"), None, None] + ); + test_match_scalar_pattern!( + match_scalar_pattern_stringview_with_flags, + vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None], + r"x.*-(\d*)-.*", + Some("i"), + StringViewArray, + StringViewBuilder, + [None, Some("7"), None, None] + ); + + test_match_scalar_pattern!( + match_scalar_pattern_string_no_flags, + vec![Some("abc-005-def"), Some("x-7-5"), Some("X545"), None], + r"x.*-(\d*)-.*", + None::<&str>, + StringArray, + GenericStringBuilder, + [None, Some("7"), None, None] + ); + test_match_scalar_pattern!( + match_scalar_pattern_stringview_no_flags, + vec![Some("abc-005-def"), Some("x-7-5"), Some("X545"), None], + r"x.*-(\d*)-.*", + None::<&str>, + StringViewArray, + StringViewBuilder, + [None, Some("7"), None, None] + ); + + macro_rules! test_match_scalar_no_pattern { + ($test_name:ident, $values:expr, $array_type:ty, $pattern_type:expr, $builder_type:ty, $expected:expr) => { + #[test] + fn $test_name() { + let array: $array_type = <$array_type>::from($values); + let pattern = Scalar::new(new_null_array(&$pattern_type, 1)); + + let actual = regexp_match(&array, &pattern, None).unwrap(); + + let elem_builder: $builder_type = <$builder_type>::new(); + let mut expected_builder = ListBuilder::new(elem_builder); + + for val in $expected { + match val { + Some(v) => { + expected_builder.values().append_value(v); + expected_builder.append(true); + } + None => expected_builder.append(false), + } + } + + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + } + }; } - #[test] - fn test_single_group_not_skip_match() { - let array = StringArray::from(vec![Some("foo"), Some("bar")]); - let pattern = GenericStringArray::::from(vec![r"foo"]); - let actual = regexp_match(&array, &pattern, None).unwrap(); - let result = actual.as_any().downcast_ref::().unwrap(); - let elem_builder: GenericStringBuilder = GenericStringBuilder::new(); - let mut expected_builder = ListBuilder::new(elem_builder); - expected_builder.values().append_value("foo"); - expected_builder.append(true); - let expected = expected_builder.finish(); - assert_eq!(&expected, result); + test_match_scalar_no_pattern!( + match_scalar_no_pattern_string, + vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None], + StringArray, + DataType::Utf8, + GenericStringBuilder, + [None::<&str>, None, None, None] + ); + test_match_scalar_no_pattern!( + match_scalar_no_pattern_stringview, + vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None], + StringViewArray, + DataType::Utf8View, + StringViewBuilder, + [None::<&str>, None, None, None] + ); + + macro_rules! test_match_single_group_not_skip { + ($test_name:ident, $values:expr, $pattern:expr, $array_type:ty, $builder_type:ty, $expected:expr) => { + #[test] + fn $test_name() { + let array: $array_type = <$array_type>::from($values); + let pattern: $array_type = <$array_type>::from(vec![$pattern]); + + let actual = regexp_match(&array, &pattern, None).unwrap(); + + let elem_builder: $builder_type = <$builder_type>::new(); + let mut expected_builder = ListBuilder::new(elem_builder); + + for val in $expected { + match val { + Some(v) => { + expected_builder.values().append_value(v); + expected_builder.append(true); + } + None => expected_builder.append(false), + } + } + + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + } + }; } + test_match_single_group_not_skip!( + match_single_group_not_skip_string, + vec![Some("foo"), Some("bar")], + r"foo", + StringArray, + GenericStringBuilder, + [Some("foo")] + ); + test_match_single_group_not_skip!( + match_single_group_not_skip_stringview, + vec![Some("foo"), Some("bar")], + r"foo", + StringViewArray, + StringViewBuilder, + [Some("foo")] + ); + macro_rules! test_flag_utf8 { ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { #[test]