From 2cdf18bfebf4965533fff37fc74006403ba61db9 Mon Sep 17 00:00:00 2001 From: Tai Le Manh Date: Tue, 13 Aug 2024 17:51:26 +0700 Subject: [PATCH 1/2] Implement native support StringView for REPEAT Signed-off-by: Tai Le Manh --- datafusion/functions/src/string/repeat.rs | 82 ++++++++++++++++--- .../sqllogictest/test_files/string_view.slt | 3 +- 2 files changed, 71 insertions(+), 14 deletions(-) diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 9d122f6101a7..713b371e3fda 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -18,10 +18,10 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray}; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_string_view_array}; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; @@ -45,7 +45,14 @@ impl RepeatFunc { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + vec![ + // Planner attempts coercion to the target type starting with the most preferred candidate. + // For example, given input `(Utf8View, Int64)`, it first tries coercing to `(Utf8View, Int64)`. + // If that fails, it proceeds to `(Utf8, Int64)`. + Exact(vec![Utf8View, Int64]), + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + ], Volatility::Immutable, ), } @@ -71,9 +78,10 @@ impl ScalarUDFImpl for RepeatFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { + DataType::Utf8View => make_scalar_function(repeat_utf8view, vec![])(args), DataType::Utf8 => make_scalar_function(repeat::, vec![])(args), DataType::LargeUtf8 => make_scalar_function(repeat::, vec![])(args), - other => exec_err!("Unsupported data type {other:?} for function repeat"), + other => exec_err!("Unsupported data type {other:?} for function repeat. Expected Utf8, Utf8View or LargeUtf8"), } } } @@ -87,18 +95,35 @@ fn repeat(args: &[ArrayRef]) -> Result { let result = string_array .iter() .zip(number_array.iter()) - .map(|(string, number)| match (string, number) { - (Some(string), Some(number)) if number >= 0 => { - Some(string.repeat(number as usize)) - } - (Some(_), Some(_)) => Some("".to_string()), - _ => None, - }) + .map(|(string, number)| repeat_common(string, number)) .collect::>(); Ok(Arc::new(result) as ArrayRef) } +fn repeat_utf8view(args: &[ArrayRef]) -> Result { + let string_view_array = as_string_view_array(&args[0])?; + let number_array = as_int64_array(&args[1])?; + + let result = string_view_array + .iter() + .zip(number_array.iter()) + .map(|(string, number)| repeat_common(string, number)) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +fn repeat_common(string: Option<&str>, number: Option) -> Option { + match (string, number) { + (Some(string), Some(number)) if number >= 0 => { + Some(string.repeat(number as usize)) + } + (Some(_), Some(_)) => Some("".to_string()), + _ => None, + } +} + #[cfg(test)] mod tests { use arrow::array::{Array, StringArray}; @@ -124,7 +149,6 @@ mod tests { Utf8, StringArray ); - test_function!( RepeatFunc::new(), &[ @@ -148,6 +172,40 @@ mod tests { StringArray ); + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ], + Ok(Some("PgPgPgPg")), + &str, + Utf8, + StringArray + ); + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + Ok(()) } } diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index fcd71b7f7e94..26a5fbd3fd5a 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -761,14 +761,13 @@ logical_plan ## Ensure no casts for REPEAT -## TODO file ticket query TT EXPLAIN SELECT REPEAT(column1_utf8view, 2) as c1 FROM test; ---- logical_plan -01)Projection: repeat(CAST(test.column1_utf8view AS Utf8), Int64(2)) AS c1 +01)Projection: repeat(test.column1_utf8view, Int64(2)) AS c1 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for REPLACE From 6f23f5fde2509965d054956068af1bedac852d78 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 14 Aug 2024 07:22:23 -0400 Subject: [PATCH 2/2] cargo fmt --- datafusion/functions/src/string/repeat.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 713b371e3fda..a377dee06f41 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -21,7 +21,9 @@ use std::sync::Arc; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray}; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_string_view_array}; +use datafusion_common::cast::{ + as_generic_string_array, as_int64_array, as_string_view_array, +}; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility};