diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 9d15920bb655..9fd8c75eab23 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -19,10 +19,12 @@ use std::any::Any; use std::cmp::max; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ + ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait, +}; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -51,6 +53,8 @@ impl SubstrFunc { Exact(vec![LargeUtf8, Int64]), Exact(vec![Utf8, Int64, Int64]), Exact(vec![LargeUtf8, Int64, Int64]), + Exact(vec![Utf8View, Int64]), + Exact(vec![Utf8View, Int64, Int64]), ], Volatility::Immutable, ), @@ -77,11 +81,7 @@ impl ScalarUDFImpl for SubstrFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(substr::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(substr::, vec![])(args), - other => exec_err!("Unsupported data type {other:?} for function substr"), - } + make_scalar_function(substr, vec![])(args) } fn aliases(&self) -> &[String] { @@ -89,18 +89,39 @@ impl ScalarUDFImpl for SubstrFunc { } } +pub fn substr(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Utf8 => { + let string_array = args[0].as_string::(); + calculate_substr::<_, i32>(string_array, &args[1..]) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + calculate_substr::<_, i64>(string_array, &args[1..]) + } + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + calculate_substr::<_, i32>(string_array, &args[1..]) + } + other => exec_err!("Unsupported data type {other:?} for function substr"), + } +} + /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) /// substr('alphabet', 3) = 'phabet' /// substr('alphabet', 3, 2) = 'ph' /// The implementation uses UTF-8 code points as characters -pub fn substr(args: &[ArrayRef]) -> Result { +fn calculate_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result +where + V: ArrayAccessor, + T: OffsetSizeTrait, +{ match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let start_array = as_int64_array(&args[1])?; + 1 => { + let iter = ArrayIter::new(string_array); + let start_array = as_int64_array(&args[0])?; - let result = string_array - .iter() + let result = iter .zip(start_array.iter()) .map(|(string, start)| match (string, start) { (Some(string), Some(start)) => { @@ -113,16 +134,14 @@ pub fn substr(args: &[ArrayRef]) -> Result { _ => None, }) .collect::>(); - Ok(Arc::new(result) as ArrayRef) } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let start_array = as_int64_array(&args[1])?; - let count_array = as_int64_array(&args[2])?; + 2 => { + let iter = ArrayIter::new(string_array); + let start_array = as_int64_array(&args[0])?; + let count_array = as_int64_array(&args[1])?; - let result = string_array - .iter() + let result = iter .zip(start_array.iter()) .zip(count_array.iter()) .map(|((string, start), count)| match (string, start, count) { @@ -162,6 +181,71 @@ mod tests { #[test] fn test_functions() -> Result<()> { + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "joséésoj" + )))), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some("ésoj")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::from(3i64)), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("ph")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::from(3i64)), + ColumnarValue::Scalar(ScalarValue::from(20i64)), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); test_function!( SubstrFunc::new(), &[ diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index e094bcaf1b5d..82a714a432ba 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -521,7 +521,30 @@ logical_plan 01)Projection: test.column1_utf8view LIKE Utf8View("foo") AS like, test.column1_utf8view ILIKE Utf8View("foo") AS ilike 02)--TableScan: test projection=[column1_utf8view] +## Ensure no casts for SUBSTR +query TT +EXPLAIN SELECT + SUBSTR(column1_utf8view, 1, 3) as c1, + SUBSTR(column2_utf8, 1, 3) as c2, + SUBSTR(column2_large_utf8, 1, 3) as c3 +FROM test; +---- +logical_plan +01)Projection: substr(test.column1_utf8view, Int64(1), Int64(3)) AS c1, substr(test.column2_utf8, Int64(1), Int64(3)) AS c2, substr(test.column2_large_utf8, Int64(1), Int64(3)) AS c3 +02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view] + +query TTT +SELECT + SUBSTR(column1_utf8view, 1, 3) as c1, + SUBSTR(column2_utf8, 1, 3) as c2, + SUBSTR(column2_large_utf8, 1, 3) as c3 +FROM test; +---- +And X X +Xia Xia Xia +Rap R R +NULL R R ## Ensure no casts for ASCII @@ -1047,9 +1070,8 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: substr(__common_expr_1, Int64(1)) AS c, substr(__common_expr_1, Int64(1), Int64(2)) AS c2 -02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1 -03)----TableScan: test projection=[column1_utf8view] +01)Projection: substr(test.column1_utf8view, Int64(1)) AS c, substr(test.column1_utf8view, Int64(1), Int64(2)) AS c2 +02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for SUBSTRINDEX query TT