Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Dec 26, 2023
1 parent 583fc77 commit a456e2c
Showing 1 changed file with 63 additions and 11 deletions.
74 changes: 63 additions & 11 deletions arrow-string/src/regexp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use arrow_array::builder::{BooleanBufferBuilder, GenericStringBuilder, ListBuild
use arrow_array::*;
use arrow_buffer::NullBuffer;
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType};
use arrow_schema::{ArrowError, DataType, Field};
use regex::Regex;
use std::collections::HashMap;
use std::sync::Arc;
Expand Down Expand Up @@ -230,12 +230,16 @@ fn regexp_array_match<OffsetSize: OffsetSizeTrait>(
fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>(
regex_array: &'a dyn Array,
flag_array: Option<&'a dyn Array>,
) -> (&'a str, Option<&'a str>) {
) -> (Option<&'a str>, Option<&'a str>) {
let regex = regex_array
.as_any()
.downcast_ref::<GenericStringArray<OffsetSize>>()
.expect("Unable to downcast to StringArray/LargeStringArray");
let regex = regex.value(0);
let regex = if regex.is_valid(0) {
Some(regex.value(0))
} else {
None
};

if flag_array.is_some() {
let flag = flag_array
Expand All @@ -257,10 +261,8 @@ fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>(

fn regexp_scalar_match<OffsetSize: OffsetSizeTrait>(
array: &dyn Array,
regex: Option<&Regex>,
regex: &Regex,
) -> std::result::Result<ArrayRef, ArrowError> {
if regex.is_none() {}

let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
let mut list_builder = ListBuilder::new(builder);

Expand All @@ -269,8 +271,6 @@ fn regexp_scalar_match<OffsetSize: OffsetSizeTrait>(
.downcast_ref::<GenericStringArray<OffsetSize>>()
.expect("Unable to downcast to StringArray/LargeStringArray");

let regex = regex.unwrap();

array
.iter()
.map(|value| {
Expand Down Expand Up @@ -371,8 +371,17 @@ pub fn regexp_match(
}
};

if regex.is_none() {
return Ok(new_null_array(
&DataType::List(Arc::new(Field::new("item", lhs.data_type().clone(), true))),
lhs.len(),
));
}

let regex = regex.unwrap();

let pattern = if let Some(flag) = flag {
format!("(?{regex}){flag}")
format!("(?{flag}){regex}")
} else {
regex.to_string()
};
Expand All @@ -382,8 +391,8 @@ pub fn regexp_match(
})?;

match lhs.data_type() {
DataType::Utf8 => regexp_scalar_match::<i32>(lhs, Some(&re)),
DataType::LargeUtf8 => regexp_scalar_match::<i64>(lhs, Some(&re)),
DataType::Utf8 => regexp_scalar_match::<i32>(lhs, &re),
DataType::LargeUtf8 => regexp_scalar_match::<i64>(lhs, &re),
_ => {
return Err(ArrowError::ComputeError(format!(
"regexp_match() requires array to be either Utf8 or LargeUtf8"
Expand Down Expand Up @@ -491,6 +500,49 @@ mod tests {
assert_eq!(&expected, result);
}

#[test]
fn match_scalar_pattern() {
    // One-element pattern and flag arrays wrapped in `Scalar`.
    let pattern = Scalar::new(StringArray::from(vec![r"x.*-(\d*)-.*"; 1]));
    let flags = Scalar::new(StringArray::from(vec!["i"; 1]));

    // Both scenarios below expect the same output: only the second row
    // produces a capture ("7"); the remaining three rows are null entries.
    let mut expected_builder = ListBuilder::new(GenericStringBuilder::<i32>::with_capacity(0, 0));
    expected_builder.append(false);
    expected_builder.values().append_value("7");
    expected_builder.append(true);
    expected_builder.append(false);
    expected_builder.append(false);
    let expected = expected_builder.finish();

    // With the "i" flag: the upper-case "X-7-5" row is the one expected to match.
    let with_flag_input =
        StringArray::from(vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]);
    let with_flag = regexp_match(&with_flag_input, &pattern, Some(&flags)).unwrap();
    assert_eq!(
        &expected,
        with_flag.as_any().downcast_ref::<ListArray>().unwrap()
    );

    // No flag: the second row is lower-case here so the pattern matches it as written.
    let no_flag_input =
        StringArray::from(vec![Some("abc-005-def"), Some("x-7-5"), Some("X545"), None]);
    let no_flag = regexp_match(&no_flag_input, &pattern, None).unwrap();
    assert_eq!(
        &expected,
        no_flag.as_any().downcast_ref::<ListArray>().unwrap()
    );
}

#[test]
fn match_scalar_no_pattern() {
    let input = StringArray::from(vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]);
    // The scalar pattern is a single null Utf8 value.
    let null_pattern = Scalar::new(new_null_array(&DataType::Utf8, 1));

    // Expect one null list entry per input row (four rows in total).
    let mut expected_builder = ListBuilder::new(GenericStringBuilder::<i32>::with_capacity(0, 0));
    for _ in 0..4 {
        expected_builder.append(false);
    }
    let expected = expected_builder.finish();

    let matched = regexp_match(&input, &null_pattern, None).unwrap();
    assert_eq!(
        &expected,
        matched.as_any().downcast_ref::<ListArray>().unwrap()
    );
}

#[test]
fn test_single_group_not_skip_match() {
let array = StringArray::from(vec![Some("foo"), Some("bar")]);
Expand Down

0 comments on commit a456e2c

Please sign in to comment.