From 08bdb294e975388d89b940f6320a4132b63133c2 Mon Sep 17 00:00:00 2001 From: Tai Le Manh Date: Mon, 26 Aug 2024 15:13:05 +0700 Subject: [PATCH 1/7] Implement native support StringView for contains function Signed-off-by: Tai Le Manh --- datafusion/functions/src/string/common.rs | 99 ++++++++- datafusion/functions/src/string/contains.rs | 203 ++++++++++++++++-- .../sqllogictest/test_files/string_view.slt | 42 +++- .../source/user-guide/sql/scalar_functions.md | 2 +- 4 files changed, 309 insertions(+), 37 deletions(-) diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 6f23a5ddd236..7df1e485cae9 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -17,21 +17,24 @@ //! Common utilities for implementing string functions +use std::collections::HashMap; use std::fmt::{Display, Formatter}; use std::sync::Arc; use arrow::array::{ new_null_array, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, ArrayRef, - GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray, + BooleanArray, GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray, StringViewArray, }; use arrow::buffer::{Buffer, MutableBuffer, NullBuffer}; use arrow::datatypes::DataType; - +use arrow::error::ArrowError; +use arrow_buffer::BooleanBufferBuilder; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; -use datafusion_common::Result; use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; +use regex::Regex; pub(crate) enum TrimType { Left, @@ -458,3 +461,93 @@ where GenericStringArray::::new_unchecked(offsets, values, nulls) })) } + +/// Perform SQL `array ~ regex_array` operation on +/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. +/// If `regex_array` element has an empty value, the corresponding result value is always true. +/// +/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag, +/// which allow special search modes, such as case-insensitive and multi-line mode. +/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) +/// for more information. +/// +/// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs]. +/// +/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37 +pub fn regexp_is_match<'a, ArrayType1, ArrayType2, ArrayType3>( + array: &'a ArrayType1, + regex_array: &'a ArrayType2, + flags_array: Option<&'a ArrayType3>, +) -> Result +where + &'a ArrayType1: StringArrayType<'a>, + &'a ArrayType2: StringArrayType<'a>, + &'a ArrayType3: StringArrayType<'a>, +{ + if array.len() != regex_array.len() { + return Err(DataFusionError::Execution( + "Cannot perform comparison operation on arrays of different length" + .to_string(), + )); + } + + let nulls = NullBuffer::union(array.nulls(), regex_array.nulls()); + + let mut patterns: HashMap = HashMap::new(); + let mut result = BooleanBufferBuilder::new(array.len()); + + let complete_pattern = match flags_array { + Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( + |(pattern, flags)| { + pattern.map(|pattern| match flags { + Some(flag) => format!("(?{flag}){pattern}"), + None => pattern.to_string(), + }) + }, + )) as Box>>, + None => Box::new( + regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; + + array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + (Some(_), Some(pattern)) if pattern == *"" => { + result.append(true); + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re, + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + DataFusionError::Execution(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + patterns.entry(pattern).or_insert(re) + } + }; + result.append(re.is_match(value)); + } + _ => result.append(false), + } + Ok(()) + }) + .collect::, ArrowError>>()?; + + let data = unsafe { + ArrayDataBuilder::new(DataType::Boolean) + .len(array.len()) + .buffers(vec![result.into()]) + .nulls(nulls) + .build_unchecked() + }; + + Ok(BooleanArray::from(data)) +} diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index faf979f80614..73c68cbf2115 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -15,19 +15,22 @@ // specific language governing permissions and limitations // under the License. +use crate::string::common::regexp_is_match; use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, OffsetSizeTrait}; + +use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringViewArray}; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::Boolean; -use datafusion_common::cast::as_generic_string_array; +use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View}; +use datafusion_common::exec_err; use datafusion_common::DataFusionError; use datafusion_common::Result; -use datafusion_common::{arrow_datafusion_err, exec_err}; use datafusion_expr::ScalarUDFImpl; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, Signature, Volatility}; + use std::any::Any; use std::sync::Arc; + #[derive(Debug)] pub struct ContainsFunc { signature: Signature, @@ -44,7 +47,17 @@ impl ContainsFunc { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + vec![ + Exact(vec![Utf8View, Utf8View]), + Exact(vec![Utf8View, Utf8]), + Exact(vec![Utf8View, LargeUtf8]), + Exact(vec![Utf8, Utf8View]), + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8View]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], Volatility::Immutable, ), } @@ -69,28 +82,125 @@ impl ScalarUDFImpl for ContainsFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(contains::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(contains::, vec![])(args), - other => { - exec_err!("unsupported data type {other:?} for function contains") - } - } + make_scalar_function(contains, vec![])(args) } } /// use regexp_is_match_utf8_scalar to do the calculation for contains -pub fn contains( - args: &[ArrayRef], -) -> Result { - let mod_str = as_generic_string_array::(&args[0])?; - let match_str = as_generic_string_array::(&args[1])?; - let res = arrow::compute::kernels::comparison::regexp_is_match_utf8( - mod_str, match_str, None, - ) - .map_err(|e| arrow_datafusion_err!(e))?; - - Ok(Arc::new(res) as ArrayRef) +pub fn contains(args: &[ArrayRef]) -> Result { + match (args[0].data_type(), args[1].data_type()) { + (Utf8View, Utf8View) => { + let mod_str = args[0].as_string_view(); + let match_str = args[1].as_string_view(); + let res = regexp_is_match::< + StringViewArray, + StringViewArray, + StringViewArray + >(mod_str, match_str, None) + .map_err(|error| error)?; + + Ok(Arc::new(res) as ArrayRef) + } + (Utf8View, Utf8) => { + let mod_str = args[0].as_string_view(); + let match_str = args[1].as_string::(); + let res = regexp_is_match::< + StringViewArray, + GenericStringArray, + StringViewArray, + >(mod_str, match_str, None) + .map_err(|error| error)?; + + Ok(Arc::new(res) as ArrayRef) + } + (Utf8View, LargeUtf8) => { + let mod_str = args[0].as_string_view(); + let match_str = args[1].as_string::(); + let res = regexp_is_match::< + StringViewArray, + GenericStringArray, + StringViewArray, + >(mod_str, match_str, None) + .map_err(|error| error)?; + + Ok(Arc::new(res) as ArrayRef) + } + (Utf8, Utf8View) => { + let mod_str = args[0].as_string::(); + let match_str = args[1].as_string_view(); + let res = regexp_is_match::< + GenericStringArray, + StringViewArray, + StringViewArray, + >(mod_str, match_str, None) + .map_err(|error| error)?; + + Ok(Arc::new(res) as ArrayRef) + } + (Utf8, Utf8) => { + let mod_str = args[0].as_string::(); + let match_str = args[1].as_string::(); + let res = regexp_is_match::< + GenericStringArray, + GenericStringArray, + StringViewArray, + >(mod_str, match_str, None) + .map_err(|error| error)?; + + Ok(Arc::new(res) as ArrayRef) + } + (Utf8, LargeUtf8) => { + let mod_str = args[0].as_string::(); + let match_str = args[1].as_string::(); + let res = regexp_is_match::< + GenericStringArray, + GenericStringArray, + StringViewArray, + >(mod_str, match_str, None) + .map_err(|error| error)?; + + Ok(Arc::new(res) as ArrayRef) + } + (LargeUtf8, Utf8View) => { + let mod_str = args[0].as_string::(); + let match_str = args[1].as_string_view(); + let res = regexp_is_match::< + GenericStringArray, + StringViewArray, + StringViewArray, + >(mod_str, match_str, None) + .map_err(|error| error)?; + + Ok(Arc::new(res) as ArrayRef) + } + (LargeUtf8, Utf8) => { + let mod_str = args[0].as_string::(); + let match_str = args[1].as_string::(); + let res = regexp_is_match::< + GenericStringArray, + GenericStringArray, + StringViewArray, + >(mod_str, match_str, None) + .map_err(|error| error)?; + + Ok(Arc::new(res) as ArrayRef) + } + (LargeUtf8, LargeUtf8) => { + let mod_str = args[0].as_string::(); + let match_str = args[1].as_string::(); + let res = regexp_is_match::< + GenericStringArray, + GenericStringArray, + StringViewArray, + >(mod_str, match_str, None) + .map_err(|error| error)?; + + Ok(Arc::new(res) as ArrayRef) + } + other => { + exec_err!("Unsupported data type {other:?} for function `contains`.") + } + } } #[cfg(test)] @@ -138,6 +248,53 @@ mod tests { Boolean, BooleanArray ); + + test_function!( + ContainsFunc::new(), + &[ + ColumnarValue::Scalar( + ScalarValue::Utf8View(Some(String::from("Apache"))) + ), + ColumnarValue::Scalar( + ScalarValue::Utf8View(Some(String::from("pac"))) + ), + ], + Ok(Some(true)), + bool, + Boolean, + BooleanArray + ); + test_function!( + ContainsFunc::new(), + &[ + ColumnarValue::Scalar( + ScalarValue::Utf8View(Some(String::from("Apache"))) + ), + ColumnarValue::Scalar( + ScalarValue::Utf8(Some(String::from("ap"))) + ), + ], + Ok(Some(false)), + bool, + Boolean, + BooleanArray + ); + test_function!( + ContainsFunc::new(), + &[ + ColumnarValue::Scalar( + ScalarValue::Utf8View(Some(String::from("Apache"))) + ), + ColumnarValue::Scalar( + ScalarValue::LargeUtf8(Some(String::from("DataFusion"))) + ), + ], + Ok(Some(false)), + bool, + Boolean, + BooleanArray + ); + Ok(()) } } diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 3b3d7b88a4a1..738e50405cf4 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -21,11 +21,10 @@ statement ok create table test_source as values - ('Andrew', 'X'), - ('Xiangpeng', 'Xiangpeng'), - ('Raphael', 'R'), - (NULL, 'R') -; + ('Andrew', 'X'), + ('Xiangpeng', 'Xiangpeng'), + ('Raphael', 'R'), + (NULL, 'R'); # Table with the different combination of column types statement ok @@ -793,17 +792,40 @@ logical_plan 02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for CONTAINS -## TODO https://github.com/apache/datafusion/issues/11838 query TT EXPLAIN SELECT CONTAINS(column1_utf8view, 'foo') as c1, - CONTAINS(column2_utf8view, column2_utf8view) as c2 + CONTAINS(column1_utf8view, column2_utf8view) as c2, + CONTAINS(column1_utf8view, column2_large_utf8) as c3, + CONTAINS(column1_utf8, column2_utf8view) as c4, + CONTAINS(column1_utf8, column2_utf8) as c5, + CONTAINS(column1_utf8, column2_large_utf8) as c6, + CONTAINS(column1_large_utf8, column1_utf8view) as c7, + CONTAINS(column1_large_utf8, column2_utf8) as c8, + CONTAINS(column1_large_utf8, column2_large_utf8) as c9 FROM test; ---- logical_plan -01)Projection: contains(CAST(test.column1_utf8view AS Utf8), Utf8("foo")) AS c1, contains(__common_expr_1, __common_expr_1) AS c2 -02)--Projection: CAST(test.column2_utf8view AS Utf8) AS __common_expr_1, test.column1_utf8view -03)----TableScan: test projection=[column1_utf8view, column2_utf8view] +01)Projection: contains(test.column1_utf8view, Utf8("foo")) AS c1, contains(test.column1_utf8view, test.column2_utf8view) AS c2, contains(test.column1_utf8view, test.column2_large_utf8) AS c3, contains(test.column1_utf8, test.column2_utf8view) AS c4, contains(test.column1_utf8, test.column2_utf8) AS c5, contains(test.column1_utf8, test.column2_large_utf8) AS c6, contains(test.column1_large_utf8, test.column1_utf8view) AS c7, contains(test.column1_large_utf8, test.column2_utf8) AS c8, contains(test.column1_large_utf8, test.column2_large_utf8) AS c9 +02)--TableScan: test projection=[column1_utf8, column2_utf8, column1_large_utf8, column2_large_utf8, column1_utf8view, column2_utf8view] + +query BBBBBBBBB +SELECT + CONTAINS(column1_utf8view, 'foo') as c1, + CONTAINS(column1_utf8view, column2_utf8view) as c2, + CONTAINS(column1_utf8view, column2_large_utf8) as c3, + CONTAINS(column1_utf8, column2_utf8view) as c4, + CONTAINS(column1_utf8, column2_utf8) as c5, + CONTAINS(column1_utf8, column2_large_utf8) as c6, + CONTAINS(column1_large_utf8, column1_utf8view) as c7, + CONTAINS(column1_large_utf8, column2_utf8) as c8, + CONTAINS(column1_large_utf8, column2_large_utf8) as c9 +FROM test; +---- +false false false false false false true false false +false true true true true true true true true +false true true true true true true true true +NULL NULL NULL NULL NULL NULL NULL NULL NULL ## Ensure no casts for ENDS_WITH query TT diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index c7b3409ba7cd..fb0ef1087545 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1454,7 +1454,7 @@ position(substr in origstr) ### `contains` -Return true if search_string is found within string. +Return true if search_string is found within string (case-sensitive). ``` contains(string, search_string) From 31305cc598668dc4e985f8cb5d455e032f45fe26 Mon Sep 17 00:00:00 2001 From: Tai Le Manh Date: Mon, 26 Aug 2024 15:32:33 +0700 Subject: [PATCH 2/7] Fix cargo fmt --- datafusion/functions/src/string/contains.rs | 43 +++++++++------------ 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 73c68cbf2115..47df963d2147 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -92,12 +92,11 @@ pub fn contains(args: &[ArrayRef]) -> Result { (Utf8View, Utf8View) => { let mod_str = args[0].as_string_view(); let match_str = args[1].as_string_view(); - let res = regexp_is_match::< - StringViewArray, - StringViewArray, - StringViewArray - >(mod_str, match_str, None) - .map_err(|error| error)?; + let res = + regexp_is_match::( + mod_str, match_str, None, + ) + .map_err(|error| error)?; Ok(Arc::new(res) as ArrayRef) } @@ -252,12 +251,10 @@ mod tests { test_function!( ContainsFunc::new(), &[ - ColumnarValue::Scalar( - ScalarValue::Utf8View(Some(String::from("Apache"))) - ), - ColumnarValue::Scalar( - ScalarValue::Utf8View(Some(String::from("pac"))) - ), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "Apache" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("pac")))), ], Ok(Some(true)), bool, @@ -267,12 +264,10 @@ mod tests { test_function!( ContainsFunc::new(), &[ - ColumnarValue::Scalar( - ScalarValue::Utf8View(Some(String::from("Apache"))) - ), - ColumnarValue::Scalar( - ScalarValue::Utf8(Some(String::from("ap"))) - ), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "Apache" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ap")))), ], Ok(Some(false)), bool, @@ -282,12 +277,12 @@ mod tests { test_function!( ContainsFunc::new(), &[ - ColumnarValue::Scalar( - ScalarValue::Utf8View(Some(String::from("Apache"))) - ), - ColumnarValue::Scalar( - ScalarValue::LargeUtf8(Some(String::from("DataFusion"))) - ), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "Apache" + )))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from( + "DataFusion" + )))), ], Ok(Some(false)), bool, From a78d95967cda7087a6aa0ade1058a457bca1c676 Mon Sep 17 00:00:00 2001 From: Tai Le Manh Date: Mon, 26 Aug 2024 16:37:07 +0700 Subject: [PATCH 3/7] Implement native support StringView for contains function Signed-off-by: Tai Le Manh --- datafusion/functions/src/regex/common.rs | 119 ++++++++++++++++++++ datafusion/functions/src/regex/mod.rs | 4 +- datafusion/functions/src/string/common.rs | 98 +--------------- datafusion/functions/src/string/contains.rs | 2 +- 4 files changed, 125 insertions(+), 98 deletions(-) create mode 100644 datafusion/functions/src/regex/common.rs diff --git a/datafusion/functions/src/regex/common.rs b/datafusion/functions/src/regex/common.rs new file mode 100644 index 000000000000..17753eea0cd6 --- /dev/null +++ b/datafusion/functions/src/regex/common.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common utilities for implementing regex functions + +use std::collections::HashMap; + +use arrow::array::{Array, ArrayDataBuilder, BooleanArray}; +use arrow::datatypes::DataType; +use arrow::error::ArrowError; +use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; +use datafusion_common::DataFusionError; +use regex::Regex; + +use crate::string::common::StringArrayType; + +/// Perform SQL `array ~ regex_array` operation on +/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. +/// If `regex_array` element has an empty value, the corresponding result value is always true. +/// +/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag, +/// which allow special search modes, such as case-insensitive and multi-line mode. +/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) +/// for more information. +/// +/// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs]. +/// +/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37 +pub fn regexp_is_match<'a, ArrayType1, ArrayType2, ArrayType3>( + array: &'a ArrayType1, + regex_array: &'a ArrayType2, + flags_array: Option<&'a ArrayType3>, +) -> datafusion_common::Result +where + &'a ArrayType1: StringArrayType<'a>, + &'a ArrayType2: StringArrayType<'a>, + &'a ArrayType3: StringArrayType<'a>, +{ + if array.len() != regex_array.len() { + return Err(DataFusionError::Execution( + "Cannot perform comparison operation on arrays of different length" + .to_string(), + )); + } + + let nulls = NullBuffer::union(array.nulls(), regex_array.nulls()); + + let mut patterns: HashMap = HashMap::new(); + let mut result = BooleanBufferBuilder::new(array.len()); + + let complete_pattern = match flags_array { + Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( + |(pattern, flags)| { + pattern.map(|pattern| match flags { + Some(flag) => format!("(?{flag}){pattern}"), + None => pattern.to_string(), + }) + }, + )) as Box>>, + None => Box::new( + regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; + + array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + (Some(_), Some(pattern)) if pattern == *"" => { + result.append(true); + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re, + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + DataFusionError::Execution(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + patterns.entry(pattern).or_insert(re) + } + }; + result.append(re.is_match(value)); + } + _ => result.append(false), + } + Ok(()) + }) + .collect::, ArrowError>>()?; + + let data = unsafe { + ArrayDataBuilder::new(DataType::Boolean) + .len(array.len()) + .buffers(vec![result.into()]) + .nulls(nulls) + .build_unchecked() + }; + + Ok(BooleanArray::from(data)) +} diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 4ac162290ddb..bfb7e30fab6f 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -15,11 +15,13 @@ // specific language governing permissions and limitations // under the License. -//! "regx" DataFusion functions +//! "regex" DataFusion functions +pub mod common; pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; + // create UDFs make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match); make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like); diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 9cd61ec4dacd..c8afd73163c3 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -17,24 +17,20 @@ //! Common utilities for implementing string functions -use std::collections::HashMap; use std::fmt::{Display, Formatter}; use std::sync::Arc; use arrow::array::{ new_null_array, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, ArrayRef, - BooleanArray, GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray, + GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray, StringBuilder, StringViewArray, }; use arrow::buffer::{Buffer, MutableBuffer, NullBuffer}; use arrow::datatypes::DataType; -use arrow::error::ArrowError; -use arrow_buffer::BooleanBufferBuilder; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; +use datafusion_common::Result; use datafusion_common::{exec_err, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; -use regex::Regex; pub(crate) enum TrimType { Left, @@ -482,93 +478,3 @@ where GenericStringArray::::new_unchecked(offsets, values, nulls) })) } - -/// Perform SQL `array ~ regex_array` operation on -/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. -/// If `regex_array` element has an empty value, the corresponding result value is always true. -/// -/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag, -/// which allow special search modes, such as case-insensitive and multi-line mode. -/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) -/// for more information. -/// -/// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs]. -/// -/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37 -pub fn regexp_is_match<'a, ArrayType1, ArrayType2, ArrayType3>( - array: &'a ArrayType1, - regex_array: &'a ArrayType2, - flags_array: Option<&'a ArrayType3>, -) -> Result -where - &'a ArrayType1: StringArrayType<'a>, - &'a ArrayType2: StringArrayType<'a>, - &'a ArrayType3: StringArrayType<'a>, -{ - if array.len() != regex_array.len() { - return Err(DataFusionError::Execution( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } - - let nulls = NullBuffer::union(array.nulls(), regex_array.nulls()); - - let mut patterns: HashMap = HashMap::new(); - let mut result = BooleanBufferBuilder::new(array.len()); - - let complete_pattern = match flags_array { - Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( - |(pattern, flags)| { - pattern.map(|pattern| match flags { - Some(flag) => format!("(?{flag}){pattern}"), - None => pattern.to_string(), - }) - }, - )) as Box>>, - None => Box::new( - regex_array - .iter() - .map(|pattern| pattern.map(|pattern| pattern.to_string())), - ), - }; - - array - .iter() - .zip(complete_pattern) - .map(|(value, pattern)| { - match (value, pattern) { - (Some(_), Some(pattern)) if pattern == *"" => { - result.append(true); - } - (Some(value), Some(pattern)) => { - let existing_pattern = patterns.get(&pattern); - let re = match existing_pattern { - Some(re) => re, - None => { - let re = Regex::new(pattern.as_str()).map_err(|e| { - DataFusionError::Execution(format!( - "Regular expression did not compile: {e:?}" - )) - })?; - patterns.entry(pattern).or_insert(re) - } - }; - result.append(re.is_match(value)); - } - _ => result.append(false), - } - Ok(()) - }) - .collect::, ArrowError>>()?; - - let data = unsafe { - ArrayDataBuilder::new(DataType::Boolean) - .len(array.len()) - .buffers(vec![result.into()]) - .nulls(nulls) - .build_unchecked() - }; - - Ok(BooleanArray::from(data)) -} diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 47df963d2147..1a3688308017 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::regexp_is_match; +use crate::regex::common::regexp_is_match; use crate::utils::make_scalar_function; use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringViewArray}; From 88aab468d283d8b0411f3d5ea2f80b0fb7b81361 Mon Sep 17 00:00:00 2001 From: Tai Le Manh Date: Mon, 26 Aug 2024 17:13:02 +0700 Subject: [PATCH 4/7] Fix cargo check --- datafusion/functions/Cargo.toml | 2 +- datafusion/functions/src/regex/common.rs | 119 -------------------- datafusion/functions/src/regex/mod.rs | 1 - datafusion/functions/src/string/common.rs | 97 +++++++++++++++- datafusion/functions/src/string/contains.rs | 2 +- 5 files changed, 97 insertions(+), 124 deletions(-) delete mode 100644 datafusion/functions/src/regex/common.rs diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 337379a74670..c201cff9d67e 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -54,7 +54,7 @@ math_expressions = [] # enable regular expressions regex_expressions = ["regex"] # enable string functions -string_expressions = ["uuid"] +string_expressions = ["regex", "uuid"] # enable unicode functions unicode_expressions = ["hashbrown", "unicode-segmentation"] diff --git a/datafusion/functions/src/regex/common.rs b/datafusion/functions/src/regex/common.rs deleted file mode 100644 index 17753eea0cd6..000000000000 --- a/datafusion/functions/src/regex/common.rs +++ /dev/null @@ -1,119 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Common utilities for implementing regex functions - -use std::collections::HashMap; - -use arrow::array::{Array, ArrayDataBuilder, BooleanArray}; -use arrow::datatypes::DataType; -use arrow::error::ArrowError; -use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; -use datafusion_common::DataFusionError; -use regex::Regex; - -use crate::string::common::StringArrayType; - -/// Perform SQL `array ~ regex_array` operation on -/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. -/// If `regex_array` element has an empty value, the corresponding result value is always true. -/// -/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag, -/// which allow special search modes, such as case-insensitive and multi-line mode. -/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) -/// for more information. -/// -/// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs]. -/// -/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37 -pub fn regexp_is_match<'a, ArrayType1, ArrayType2, ArrayType3>( - array: &'a ArrayType1, - regex_array: &'a ArrayType2, - flags_array: Option<&'a ArrayType3>, -) -> datafusion_common::Result -where - &'a ArrayType1: StringArrayType<'a>, - &'a ArrayType2: StringArrayType<'a>, - &'a ArrayType3: StringArrayType<'a>, -{ - if array.len() != regex_array.len() { - return Err(DataFusionError::Execution( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } - - let nulls = NullBuffer::union(array.nulls(), regex_array.nulls()); - - let mut patterns: HashMap = HashMap::new(); - let mut result = BooleanBufferBuilder::new(array.len()); - - let complete_pattern = match flags_array { - Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( - |(pattern, flags)| { - pattern.map(|pattern| match flags { - Some(flag) => format!("(?{flag}){pattern}"), - None => pattern.to_string(), - }) - }, - )) as Box>>, - None => Box::new( - regex_array - .iter() - .map(|pattern| pattern.map(|pattern| pattern.to_string())), - ), - }; - - array - .iter() - .zip(complete_pattern) - .map(|(value, pattern)| { - match (value, pattern) { - (Some(_), Some(pattern)) if pattern == *"" => { - result.append(true); - } - (Some(value), Some(pattern)) => { - let existing_pattern = patterns.get(&pattern); - let re = match existing_pattern { - Some(re) => re, - None => { - let re = Regex::new(pattern.as_str()).map_err(|e| { - DataFusionError::Execution(format!( - "Regular expression did not compile: {e:?}" - )) - })?; - patterns.entry(pattern).or_insert(re) - } - }; - result.append(re.is_match(value)); - } - _ => result.append(false), - } - Ok(()) - }) - .collect::, ArrowError>>()?; - - let data = unsafe { - ArrayDataBuilder::new(DataType::Boolean) - .len(array.len()) - .buffers(vec![result.into()]) - .nulls(nulls) - .build_unchecked() - }; - - Ok(BooleanArray::from(data)) -} diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index bfb7e30fab6f..4afbe6cbbb89 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -17,7 +17,6 @@ //! "regex" DataFusion functions -pub mod common; pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index c8afd73163c3..64ce523c249a 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -17,20 +17,23 @@ //! Common utilities for implementing string functions +use std::collections::HashMap; use std::fmt::{Display, Formatter}; use std::sync::Arc; use arrow::array::{ new_null_array, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, ArrayRef, - GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray, + BooleanArray, GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray, StringBuilder, StringViewArray, }; use arrow::buffer::{Buffer, MutableBuffer, NullBuffer}; use arrow::datatypes::DataType; +use arrow_buffer::BooleanBufferBuilder; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; -use datafusion_common::Result; use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; +use regex::Regex; pub(crate) enum TrimType { Left, @@ -478,3 +481,93 @@ where GenericStringArray::::new_unchecked(offsets, values, nulls) })) } + +/// Perform SQL `array ~ regex_array` operation on +/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. +/// If `regex_array` element has an empty value, the corresponding result value is always true. +/// +/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag, +/// which allow special search modes, such as case-insensitive and multi-line mode. +/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) +/// for more information. +/// +/// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs]. +/// +/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37 +pub fn regexp_is_match<'a, ArrayType1, ArrayType2, ArrayType3>( + array: &'a ArrayType1, + regex_array: &'a ArrayType2, + flags_array: Option<&'a ArrayType3>, +) -> datafusion_common::Result +where + &'a ArrayType1: StringArrayType<'a>, + &'a ArrayType2: StringArrayType<'a>, + &'a ArrayType3: StringArrayType<'a>, +{ + if array.len() != regex_array.len() { + return Err(DataFusionError::Execution( + "Cannot perform comparison operation on arrays of different length" + .to_string(), + )); + } + + let nulls = NullBuffer::union(array.nulls(), regex_array.nulls()); + + let mut patterns: HashMap = HashMap::new(); + let mut result = BooleanBufferBuilder::new(array.len()); + + let complete_pattern = match flags_array { + Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( + |(pattern, flags)| { + pattern.map(|pattern| match flags { + Some(flag) => format!("(?{flag}){pattern}"), + None => pattern.to_string(), + }) + }, + )) as Box>>, + None => Box::new( + regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; + + array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + (Some(_), Some(pattern)) if pattern == *"" => { + result.append(true); + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re, + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + DataFusionError::Execution(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + patterns.entry(pattern).or_insert(re) + } + }; + result.append(re.is_match(value)); + } + _ => result.append(false), + } + Ok(()) + }) + .collect::, DataFusionError>>()?; + + let data = unsafe { + ArrayDataBuilder::new(DataType::Boolean) + .len(array.len()) + .buffers(vec![result.into()]) + .nulls(nulls) + .build_unchecked() + }; + + Ok(BooleanArray::from(data)) +} diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 1a3688308017..47df963d2147 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::regex::common::regexp_is_match; +use crate::string::common::regexp_is_match; use crate::utils::make_scalar_function; use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringViewArray}; From 45dd1416b214e9581c9adbdd7b53afdccf127e76 Mon Sep 17 00:00:00 2001 From: Tai Le Manh Date: Tue, 27 Aug 2024 00:34:32 +0700 Subject: [PATCH 5/7] Fix unresolved doc link --- datafusion/functions/src/string/common.rs | 21 +++++---- datafusion/functions/src/string/contains.rs | 50 +++++++++------------ 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 64ce523c249a..805e5a7f30b9 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -482,8 +482,11 @@ where })) } +#[cfg(doc)] +use arrow::array::LargeStringArray; /// Perform SQL `array ~ regex_array` operation on /// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. +/// /// If `regex_array` element has an empty value, the corresponding result value is always true. /// /// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag, @@ -494,15 +497,15 @@ where /// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs]. /// /// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37 -pub fn regexp_is_match<'a, ArrayType1, ArrayType2, ArrayType3>( - array: &'a ArrayType1, - regex_array: &'a ArrayType2, - flags_array: Option<&'a ArrayType3>, -) -> datafusion_common::Result +pub fn regexp_is_match<'a, S1, S2, S3>( + array: &'a S1, + regex_array: &'a S2, + flags_array: Option<&'a S3>, +) -> Result where - &'a ArrayType1: StringArrayType<'a>, - &'a ArrayType2: StringArrayType<'a>, - &'a ArrayType3: StringArrayType<'a>, + &'a S1: StringArrayType<'a>, + &'a S2: StringArrayType<'a>, + &'a S3: StringArrayType<'a>, { if array.len() != regex_array.len() { return Err(DataFusionError::Execution( @@ -559,7 +562,7 @@ where } Ok(()) }) - .collect::, DataFusionError>>()?; + .collect::, DataFusionError>>()?; let data = unsafe { ArrayDataBuilder::new(DataType::Boolean) diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 47df963d2147..8b80317696e9 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -92,11 +92,11 @@ pub fn contains(args: &[ArrayRef]) -> Result { (Utf8View, Utf8View) => { let mod_str = args[0].as_string_view(); let match_str = args[1].as_string_view(); - let res = - regexp_is_match::( - mod_str, match_str, None, - ) - .map_err(|error| error)?; + let res = regexp_is_match::< + StringViewArray, + StringViewArray, + GenericStringArray, + >(mod_str, match_str, None)?; Ok(Arc::new(res) as ArrayRef) } @@ -106,9 +106,8 @@ pub fn contains(args: &[ArrayRef]) -> Result { let res = regexp_is_match::< StringViewArray, GenericStringArray, - StringViewArray, - >(mod_str, match_str, None) - .map_err(|error| error)?; + GenericStringArray, + >(mod_str, match_str, None)?; Ok(Arc::new(res) as ArrayRef) } @@ -118,9 +117,8 @@ pub fn contains(args: &[ArrayRef]) -> Result { let res = regexp_is_match::< StringViewArray, GenericStringArray, - StringViewArray, - >(mod_str, match_str, None) - .map_err(|error| error)?; + GenericStringArray, + >(mod_str, match_str, None)?; Ok(Arc::new(res) as ArrayRef) } @@ -130,9 +128,8 @@ pub fn contains(args: &[ArrayRef]) -> Result { let res = regexp_is_match::< GenericStringArray, StringViewArray, - StringViewArray, - >(mod_str, match_str, None) - .map_err(|error| error)?; + GenericStringArray, + >(mod_str, match_str, None)?; Ok(Arc::new(res) as ArrayRef) } @@ -142,9 +139,8 @@ pub fn contains(args: &[ArrayRef]) -> Result { let res = regexp_is_match::< GenericStringArray, GenericStringArray, - StringViewArray, - >(mod_str, match_str, None) - .map_err(|error| error)?; + GenericStringArray, + >(mod_str, match_str, None)?; Ok(Arc::new(res) as ArrayRef) } @@ -154,9 +150,8 @@ pub fn contains(args: &[ArrayRef]) -> Result { let res = regexp_is_match::< GenericStringArray, GenericStringArray, - StringViewArray, - >(mod_str, match_str, None) - .map_err(|error| error)?; + GenericStringArray, + >(mod_str, match_str, None)?; Ok(Arc::new(res) as ArrayRef) } @@ -166,9 +161,8 @@ pub fn contains(args: &[ArrayRef]) -> Result { let res = regexp_is_match::< GenericStringArray, StringViewArray, - StringViewArray, - >(mod_str, match_str, None) - .map_err(|error| error)?; + GenericStringArray, + >(mod_str, match_str, None)?; Ok(Arc::new(res) as ArrayRef) } @@ -178,9 +172,8 @@ pub fn contains(args: &[ArrayRef]) -> Result { let res = regexp_is_match::< GenericStringArray, GenericStringArray, - StringViewArray, - >(mod_str, match_str, None) - .map_err(|error| error)?; + GenericStringArray, + >(mod_str, match_str, None)?; Ok(Arc::new(res) as ArrayRef) } @@ -190,9 +183,8 @@ pub fn contains(args: &[ArrayRef]) -> Result { let res = regexp_is_match::< GenericStringArray, GenericStringArray, - StringViewArray, - >(mod_str, match_str, None) - .map_err(|error| error)?; + GenericStringArray, + >(mod_str, match_str, None)?; Ok(Arc::new(res) as ArrayRef) } From 4344bc85d9bf4e9754b7b57c52cb65bd2d9544b8 Mon Sep 17 00:00:00 2001 From: Tai Le Manh Date: Fri, 6 Sep 2024 09:39:36 +0700 Subject: [PATCH 6/7] Implement native support StringView for contains function Signed-off-by: Tai Le Manh --- datafusion/functions/Cargo.toml | 2 +- datafusion/functions/src/lib.rs | 3 + datafusion/functions/src/regexp_common.rs | 121 ++++++++++++++++++++ datafusion/functions/src/string/common.rs | 100 +--------------- datafusion/functions/src/string/contains.rs | 20 ++-- 5 files changed, 137 insertions(+), 109 deletions(-) create mode 100644 datafusion/functions/src/regexp_common.rs diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index c201cff9d67e..7888d72e7d67 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -54,7 +54,7 @@ math_expressions = [] # enable regular expressions regex_expressions = ["regex"] # enable string functions -string_expressions = ["regex", "uuid"] +string_expressions = ["regex_expressions", "uuid"] # enable unicode functions unicode_expressions = ["hashbrown", "unicode-segmentation"] diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 81be5552666d..bb680f3c67de 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -92,6 +92,9 @@ pub mod macros; pub mod string; make_stub_package!(string, "string_expressions"); +#[cfg(feature = "string_expressions")] +mod regexp_common; + /// Core datafusion expressions /// Enabled via feature flag `core_expressions` #[cfg(feature = "core_expressions")] diff --git a/datafusion/functions/src/regexp_common.rs b/datafusion/functions/src/regexp_common.rs new file mode 100644 index 000000000000..582ef639173e --- /dev/null +++ b/datafusion/functions/src/regexp_common.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common utilities for implementing regex functions + +use crate::string::common::StringArrayType; + +use arrow::array::{Array, ArrayDataBuilder, BooleanArray}; +use arrow::datatypes::DataType; +use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; +use datafusion_common::DataFusionError; +use regex::Regex; + +use std::collections::HashMap; + +#[cfg(doc)] +use arrow::array::{LargeStringArray, StringArray, StringViewArray}; +/// Perform SQL `array ~ regex_array` operation on +/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. +/// +/// If `regex_array` element has an empty value, the corresponding result value is always true. +/// +/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag, +/// which allow special search modes, such as case-insensitive and multi-line mode. +/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) +/// for more information. +/// +/// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs]. +/// +/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37 +pub fn regexp_is_match_utf8<'a, S1, S2, S3>( + array: &'a S1, + regex_array: &'a S2, + flags_array: Option<&'a S3>, +) -> datafusion_common::Result +where + &'a S1: StringArrayType<'a>, + &'a S2: StringArrayType<'a>, + &'a S3: StringArrayType<'a>, +{ + if array.len() != regex_array.len() { + return Err(DataFusionError::Execution( + "Cannot perform comparison operation on arrays of different length" + .to_string(), + )); + } + + let nulls = NullBuffer::union(array.nulls(), regex_array.nulls()); + + let mut patterns: HashMap = HashMap::new(); + let mut result = BooleanBufferBuilder::new(array.len()); + + let complete_pattern = match flags_array { + Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( + |(pattern, flags)| { + pattern.map(|pattern| match flags { + Some(flag) => format!("(?{flag}){pattern}"), + None => pattern.to_string(), + }) + }, + )) as Box>>, + None => Box::new( + regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; + + array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + (Some(_), Some(pattern)) if pattern == *"" => { + result.append(true); + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re, + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + DataFusionError::Execution(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + patterns.entry(pattern).or_insert(re) + } + }; + result.append(re.is_match(value)); + } + _ => result.append(false), + } + Ok(()) + }) + .collect::, DataFusionError>>()?; + + let data = unsafe { + ArrayDataBuilder::new(DataType::Boolean) + .len(array.len()) + .buffers(vec![result.into()]) + .nulls(nulls) + .build_unchecked() + }; + + Ok(BooleanArray::from(data)) +} diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 805e5a7f30b9..c8afd73163c3 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -17,23 +17,20 @@ //! Common utilities for implementing string functions -use std::collections::HashMap; use std::fmt::{Display, Formatter}; use std::sync::Arc; use arrow::array::{ new_null_array, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, ArrayRef, - BooleanArray, GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray, + GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray, StringBuilder, StringViewArray, }; use arrow::buffer::{Buffer, MutableBuffer, NullBuffer}; use arrow::datatypes::DataType; -use arrow_buffer::BooleanBufferBuilder; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; +use datafusion_common::Result; use datafusion_common::{exec_err, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; -use regex::Regex; pub(crate) enum TrimType { Left, @@ -481,96 +478,3 @@ where GenericStringArray::::new_unchecked(offsets, values, nulls) })) } - -#[cfg(doc)] -use arrow::array::LargeStringArray; -/// Perform SQL `array ~ regex_array` operation on -/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. -/// -/// If `regex_array` element has an empty value, the corresponding result value is always true. -/// -/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag, -/// which allow special search modes, such as case-insensitive and multi-line mode. -/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) -/// for more information. -/// -/// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs]. -/// -/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37 -pub fn regexp_is_match<'a, S1, S2, S3>( - array: &'a S1, - regex_array: &'a S2, - flags_array: Option<&'a S3>, -) -> Result -where - &'a S1: StringArrayType<'a>, - &'a S2: StringArrayType<'a>, - &'a S3: StringArrayType<'a>, -{ - if array.len() != regex_array.len() { - return Err(DataFusionError::Execution( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } - - let nulls = NullBuffer::union(array.nulls(), regex_array.nulls()); - - let mut patterns: HashMap = HashMap::new(); - let mut result = BooleanBufferBuilder::new(array.len()); - - let complete_pattern = match flags_array { - Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( - |(pattern, flags)| { - pattern.map(|pattern| match flags { - Some(flag) => format!("(?{flag}){pattern}"), - None => pattern.to_string(), - }) - }, - )) as Box>>, - None => Box::new( - regex_array - .iter() - .map(|pattern| pattern.map(|pattern| pattern.to_string())), - ), - }; - - array - .iter() - .zip(complete_pattern) - .map(|(value, pattern)| { - match (value, pattern) { - (Some(_), Some(pattern)) if pattern == *"" => { - result.append(true); - } - (Some(value), Some(pattern)) => { - let existing_pattern = patterns.get(&pattern); - let re = match existing_pattern { - Some(re) => re, - None => { - let re = Regex::new(pattern.as_str()).map_err(|e| { - DataFusionError::Execution(format!( - "Regular expression did not compile: {e:?}" - )) - })?; - patterns.entry(pattern).or_insert(re) - } - }; - result.append(re.is_match(value)); - } - _ => result.append(false), - } - Ok(()) - }) - .collect::, DataFusionError>>()?; - - let data = unsafe { - ArrayDataBuilder::new(DataType::Boolean) - .len(array.len()) - .buffers(vec![result.into()]) - .nulls(nulls) - .build_unchecked() - }; - - Ok(BooleanArray::from(data)) -} diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 8b80317696e9..c319f80661c3 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::regexp_is_match; +use crate::regexp_common::regexp_is_match_utf8; use crate::utils::make_scalar_function; use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringViewArray}; @@ -92,7 +92,7 @@ pub fn contains(args: &[ArrayRef]) -> Result { (Utf8View, Utf8View) => { let mod_str = args[0].as_string_view(); let match_str = args[1].as_string_view(); - let res = regexp_is_match::< + let res = regexp_is_match_utf8::< StringViewArray, StringViewArray, GenericStringArray, @@ -103,7 +103,7 @@ pub fn contains(args: &[ArrayRef]) -> Result { (Utf8View, Utf8) => { let mod_str = args[0].as_string_view(); let match_str = args[1].as_string::(); - let res = regexp_is_match::< + let res = regexp_is_match_utf8::< StringViewArray, GenericStringArray, GenericStringArray, @@ -114,7 +114,7 @@ pub fn contains(args: &[ArrayRef]) -> Result { (Utf8View, LargeUtf8) => { let mod_str = args[0].as_string_view(); let match_str = args[1].as_string::(); - let res = regexp_is_match::< + let res = regexp_is_match_utf8::< StringViewArray, GenericStringArray, GenericStringArray, @@ -125,7 +125,7 @@ pub fn contains(args: &[ArrayRef]) -> Result { (Utf8, Utf8View) => { let mod_str = args[0].as_string::(); let match_str = args[1].as_string_view(); - let res = regexp_is_match::< + let res = regexp_is_match_utf8::< GenericStringArray, StringViewArray, GenericStringArray, @@ -136,7 +136,7 @@ pub fn contains(args: &[ArrayRef]) -> Result { (Utf8, Utf8) => { let mod_str = args[0].as_string::(); let match_str = args[1].as_string::(); - let res = regexp_is_match::< + let res = regexp_is_match_utf8::< GenericStringArray, GenericStringArray, GenericStringArray, @@ -147,7 +147,7 @@ pub fn contains(args: &[ArrayRef]) -> Result { (Utf8, LargeUtf8) => { let mod_str = args[0].as_string::(); let match_str = args[1].as_string::(); - let res = regexp_is_match::< + let res = regexp_is_match_utf8::< GenericStringArray, GenericStringArray, GenericStringArray, @@ -158,7 +158,7 @@ pub fn contains(args: &[ArrayRef]) -> Result { (LargeUtf8, Utf8View) => { let mod_str = args[0].as_string::(); let match_str = args[1].as_string_view(); - let res = regexp_is_match::< + let res = regexp_is_match_utf8::< GenericStringArray, StringViewArray, GenericStringArray, @@ -169,7 +169,7 @@ pub fn contains(args: &[ArrayRef]) -> Result { (LargeUtf8, Utf8) => { let mod_str = args[0].as_string::(); let match_str = args[1].as_string::(); - let res = regexp_is_match::< + let res = regexp_is_match_utf8::< GenericStringArray, GenericStringArray, GenericStringArray, @@ -180,7 +180,7 @@ pub fn contains(args: &[ArrayRef]) -> Result { (LargeUtf8, LargeUtf8) => { let mod_str = args[0].as_string::(); let match_str = args[1].as_string::(); - let res = regexp_is_match::< + let res = regexp_is_match_utf8::< GenericStringArray, GenericStringArray, GenericStringArray, From 29eb13268a4f599821ad5ab781787df6d26ca32e Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 9 Sep 2024 08:47:43 -0400 Subject: [PATCH 7/7] Update datafusion/functions/src/regexp_common.rs --- datafusion/functions/src/regexp_common.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/functions/src/regexp_common.rs b/datafusion/functions/src/regexp_common.rs index 582ef639173e..748c1a294f97 100644 --- a/datafusion/functions/src/regexp_common.rs +++ b/datafusion/functions/src/regexp_common.rs @@ -41,6 +41,8 @@ use arrow::array::{LargeStringArray, StringArray, StringViewArray}; /// /// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs]. /// +/// Can remove when is implemented upstream +/// /// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37 pub fn regexp_is_match_utf8<'a, S1, S2, S3>( array: &'a S1,