diff --git a/datafusion/functions/src/string/initcap.rs b/datafusion/functions/src/string/initcap.rs index 864179d130fd..4e1eb213ef57 100644 --- a/datafusion/functions/src/string/initcap.rs +++ b/datafusion/functions/src/string/initcap.rs @@ -18,10 +18,10 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray}; use arrow::datatypes::DataType; -use datafusion_common::cast::as_generic_string_array; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; @@ -45,7 +45,7 @@ impl InitcapFunc { Self { signature: Signature::uniform( 1, - vec![Utf8, LargeUtf8], + vec![Utf8, LargeUtf8, Utf8View], Volatility::Immutable, ), } @@ -73,6 +73,7 @@ impl ScalarUDFImpl for InitcapFunc { match args[0].data_type() { DataType::Utf8 => make_scalar_function(initcap::, vec![])(args), DataType::LargeUtf8 => make_scalar_function(initcap::, vec![])(args), + DataType::Utf8View => make_scalar_function(initcap_utf8view, vec![])(args), other => { exec_err!("Unsupported data type {other:?} for function initcap") } @@ -88,28 +89,41 @@ fn initcap(args: &[ArrayRef]) -> Result { // first map is the iterator, second is for the `Option<_>` let result = string_array .iter() - .map(|string| { - string.map(|string: &str| { - let mut char_vector = Vec::::new(); - let mut previous_character_letter_or_number = false; - for c in string.chars() { - if previous_character_letter_or_number { - char_vector.push(c.to_ascii_lowercase()); - } else { - char_vector.push(c.to_ascii_uppercase()); - } - previous_character_letter_or_number = c.is_ascii_uppercase() - || c.is_ascii_lowercase() - || c.is_ascii_digit(); - } - char_vector.iter().collect::() - }) - }) + .map(initcap_string) .collect::>(); Ok(Arc::new(result) as ArrayRef) } +fn initcap_utf8view(args: &[ArrayRef]) -> Result { + let string_view_array = as_string_view_array(&args[0])?; + + let result = string_view_array + .iter() + .map(initcap_string) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +fn initcap_string(string: Option<&str>) -> Option { + let mut char_vector = Vec::::new(); + string.map(|string: &str| { + char_vector.clear(); + let mut previous_character_letter_or_number = false; + for c in string.chars() { + if previous_character_letter_or_number { + char_vector.push(c.to_ascii_lowercase()); + } else { + char_vector.push(c.to_ascii_uppercase()); + } + previous_character_letter_or_number = + c.is_ascii_uppercase() || c.is_ascii_lowercase() || c.is_ascii_digit(); + } + char_vector.iter().collect::() + }) +} + #[cfg(test)] mod tests { use crate::string::initcap::InitcapFunc; @@ -153,6 +167,44 @@ mod tests { Utf8, StringArray ); + test_function!( + InitcapFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "hi THOMAS".to_string() + )))], + Ok(Some("Hi Thomas")), + &str, + Utf8, + StringArray + ); + test_function!( + InitcapFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "hi THOMAS wIth M0re ThAN 12 ChaRs".to_string() + )))], + Ok(Some("Hi Thomas With M0re Than 12 Chars")), + &str, + Utf8, + StringArray + ); + test_function!( + InitcapFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "".to_string() + )))], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + InitcapFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(None))], + Ok(None), + &str, + Utf8, + StringArray + ); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index e7166690580f..a61e3830fd08 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -425,6 +425,50 @@ logical_plan 01)Projection: starts_with(test.column1_utf8view, Utf8View("äöüß")) AS c1, starts_with(test.column1_utf8view, Utf8View("")) AS c2, starts_with(test.column1_utf8view, Utf8View(NULL)) AS c3, starts_with(Utf8View(NULL), test.column1_utf8view) AS c4 02)--TableScan: test projection=[column1_utf8view] +### Initcap + +query TT +EXPLAIN SELECT + INITCAP(column1_utf8view) as c +FROM test; +---- +logical_plan +01)Projection: initcap(test.column1_utf8view) AS c +02)--TableScan: test projection=[column1_utf8view] + +# Create a table with lowercase strings +statement ok +CREATE TABLE test_lowercase AS SELECT + lower(column1_utf8) as column1_utf8_lower, + lower(column1_large_utf8) as column1_large_utf8_lower, + lower(column1_utf8view) as column1_utf8view_lower +FROM test; + +# Test INITCAP with utf8view, utf8, and largeutf8 +# Should not cast anything +query TT +EXPLAIN SELECT + INITCAP(column1_utf8view_lower) as c1, + INITCAP(column1_utf8_lower) as c2, + INITCAP(column1_large_utf8_lower) as c3 +FROM test_lowercase; +---- +logical_plan +01)Projection: initcap(test_lowercase.column1_utf8view_lower) AS c1, initcap(test_lowercase.column1_utf8_lower) AS c2, initcap(test_lowercase.column1_large_utf8_lower) AS c3 +02)--TableScan: test_lowercase projection=[column1_utf8_lower, column1_large_utf8_lower, column1_utf8view_lower] + +query TTT +SELECT + INITCAP(column1_utf8view_lower) as c1, + INITCAP(column1_utf8_lower) as c2, + INITCAP(column1_large_utf8_lower) as c3 +FROM test_lowercase; +---- +Andrew Andrew Andrew +Xiangpeng Xiangpeng Xiangpeng +Raphael Raphael Raphael +NULL NULL NULL + # Ensure string functions use native StringView implementation # and do not fall back to Utf8 or LargeUtf8 # Should see no casts to Utf8 in the plans below @@ -586,18 +630,6 @@ logical_plan 02)--Projection: CAST(test.column2_utf8view AS Utf8) AS __common_expr_1, test.column1_utf8view 03)----TableScan: test projection=[column1_utf8view, column2_utf8view] - -## Ensure no casts for INITCAP -## TODO https://github.com/apache/datafusion/issues/11853 -query TT -EXPLAIN SELECT - INITCAP(column1_utf8view) as c -FROM test; ----- -logical_plan -01)Projection: initcap(CAST(test.column1_utf8view AS Utf8)) AS c -02)--TableScan: test projection=[column1_utf8view] - ## Ensure no casts for LEVENSHTEIN ## TODO https://github.com/apache/datafusion/issues/11854 query TT