From 3d76aa25e4830ef8da42fae17453d8d1b8e66d4e Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Tue, 6 Aug 2024 03:23:59 -0700 Subject: [PATCH] feat: support `Utf8View` type in `starts_with` function (#11787) * feat: support `Utf8View` for `starts_with` * style: clippy * simplify string view handling * fix: allow utf8 and largeutf8 to be cast into utf8view * fix: fix test * Apply suggestions from code review Co-authored-by: Yongting You <2010youy01@gmail.com> Co-authored-by: Andrew Lamb * style: fix format * feat: add addiontal tests * tests: improve tests * fix: fix null case * tests: one more null test * Test comments and execution tests --------- Co-authored-by: Yongting You <2010youy01@gmail.com> Co-authored-by: Andrew Lamb --- datafusion/expr/src/expr_schema.rs | 1 + .../expr/src/type_coercion/functions.rs | 16 ++++ .../functions/src/string/starts_with.rs | 92 +++++++++++++++---- .../sqllogictest/test_files/string_view.slt | 70 +++++++++++++- 4 files changed, 158 insertions(+), 21 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 676903d59a07..9faeb8aed506 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -148,6 +148,7 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; + // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| { plan_datafusion_err!( diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 66807c3f446c..4f2776516d3e 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -583,6 +583,10 @@ fn coerced_from<'a>( (Interval(_), _) if matches!(type_from, Utf8 | LargeUtf8) => { Some(type_into.clone()) } + // We can go into a Utf8View from a Utf8 or LargeUtf8 + (Utf8View, _) if matches!(type_from, Utf8 | LargeUtf8 | Null) => { + Some(type_into.clone()) + } // Any type can be coerced into strings (Utf8 | LargeUtf8, _) => Some(type_into.clone()), (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()), @@ -646,6 +650,18 @@ mod tests { use super::*; use arrow::datatypes::Field; + #[test] + fn test_string_conversion() { + let cases = vec![ + (DataType::Utf8View, DataType::Utf8, true), + (DataType::Utf8View, DataType::LargeUtf8, true), + ]; + + for case in cases { + assert_eq!(can_coerce_from(&case.0, &case.1), case.2); + } + } + #[test] fn test_maybe_data_types() { // this vec contains: arg1, arg2, expected result diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 05bd960ff14b..8450697cbf30 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -18,10 +18,10 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::array::ArrayRef; use arrow::datatypes::DataType; -use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; +use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; @@ -30,12 +30,8 @@ use crate::utils::make_scalar_function; /// Returns true if string starts with prefix. /// starts_with('alphabet', 'alph') = 't' -pub fn starts_with(args: &[ArrayRef]) -> Result { - let left = as_generic_string_array::(&args[0])?; - let right = as_generic_string_array::(&args[1])?; - - let result = arrow::compute::kernels::comparison::starts_with(left, right)?; - +pub fn starts_with(args: &[ArrayRef]) -> Result { + let result = arrow::compute::kernels::comparison::starts_with(&args[0], &args[1])?; Ok(Arc::new(result) as ArrayRef) } @@ -52,14 +48,15 @@ impl Default for StartsWithFunc { impl StartsWithFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), + // Planner attempts coercion to the target type starting with the most preferred candidate. + // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`. + // If that fails, it proceeds to `(Utf8, Utf8)`. + Exact(vec![DataType::Utf8View, DataType::Utf8View]), + Exact(vec![DataType::Utf8, DataType::Utf8]), + Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), ], Volatility::Immutable, ), @@ -81,18 +78,73 @@ impl ScalarUDFImpl for StartsWithFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - use DataType::*; - - Ok(Boolean) + Ok(DataType::Boolean) } fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(starts_with::, vec![])(args), - DataType::LargeUtf8 => { - return make_scalar_function(starts_with::, vec![])(args); + DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => { + make_scalar_function(starts_with, vec![])(args) } - _ => internal_err!("Unsupported data type"), + _ => internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")?, } } } + +#[cfg(test)] +mod tests { + use crate::utils::test::test_function; + use arrow::array::{Array, BooleanArray}; + use arrow::datatypes::DataType::Boolean; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use super::*; + + #[test] + fn test_functions() -> Result<()> { + // Generate test cases for starts_with + let test_cases = vec![ + (Some("alphabet"), Some("alph"), Some(true)), + (Some("alphabet"), Some("bet"), Some(false)), + ( + Some("somewhat large string"), + Some("somewhat large"), + Some(true), + ), + (Some("somewhat large string"), Some("large"), Some(false)), + ] + .into_iter() + .flat_map(|(a, b, c)| { + let utf_8_args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(a.map(|s| s.to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(b.map(|s| s.to_string()))), + ]; + + let large_utf_8_args = vec![ + ColumnarValue::Scalar(ScalarValue::LargeUtf8(a.map(|s| s.to_string()))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(b.map(|s| s.to_string()))), + ]; + + let utf_8_view_args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(a.map(|s| s.to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8View(b.map(|s| s.to_string()))), + ]; + + vec![(utf_8_args, c), (large_utf_8_args, c), (utf_8_view_args, c)] + }); + + for (args, expected) in test_cases { + test_function!( + StartsWithFunc::new(), + &args, + Ok(expected), + bool, + Boolean, + BooleanArray + ); + } + + Ok(()) + } +} diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 763b4e99c614..584d3b330690 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -355,6 +355,75 @@ logical_plan 01)Aggregate: groupBy=[[]], aggr=[[count(DISTINCT test.column1_utf8), count(DISTINCT test.column1_utf8view), count(DISTINCT test.column1_dict)]] 02)--TableScan: test projection=[column1_utf8, column1_utf8view, column1_dict] +### `STARTS_WITH` + +# Test STARTS_WITH with utf8view against utf8view, utf8, and largeutf8 +# (should be no casts) +query TT +EXPLAIN SELECT + STARTS_WITH(column1_utf8view, column2_utf8view) as c1, + STARTS_WITH(column1_utf8view, column2_utf8) as c2, + STARTS_WITH(column1_utf8view, column2_large_utf8) as c3 +FROM test; +---- +logical_plan +01)Projection: starts_with(test.column1_utf8view, test.column2_utf8view) AS c1, starts_with(test.column1_utf8view, CAST(test.column2_utf8 AS Utf8View)) AS c2, starts_with(test.column1_utf8view, CAST(test.column2_large_utf8 AS Utf8View)) AS c3 +02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view, column2_utf8view] + +query BBB +SELECT + STARTS_WITH(column1_utf8view, column2_utf8view) as c1, + STARTS_WITH(column1_utf8view, column2_utf8) as c2, + STARTS_WITH(column1_utf8view, column2_large_utf8) as c3 +FROM test; +---- +false false false +true true true +true true true +NULL NULL NULL + +# Test STARTS_WITH with utf8 against utf8view, utf8, and largeutf8 +# Should work, but will have to cast to common types +# should cast utf8 -> utf8view and largeutf8 -> utf8view +query TT +EXPLAIN SELECT + STARTS_WITH(column1_utf8, column2_utf8view) as c1, + STARTS_WITH(column1_utf8, column2_utf8) as c3, + STARTS_WITH(column1_utf8, column2_large_utf8) as c4 +FROM test; +---- +logical_plan +01)Projection: starts_with(__common_expr_1, test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(__common_expr_1, CAST(test.column2_large_utf8 AS Utf8View)) AS c4 +02)--Projection: CAST(test.column1_utf8 AS Utf8View) AS __common_expr_1, test.column1_utf8, test.column2_utf8, test.column2_large_utf8, test.column2_utf8view +03)----TableScan: test projection=[column1_utf8, column2_utf8, column2_large_utf8, column2_utf8view] + +query BBB + SELECT + STARTS_WITH(column1_utf8, column2_utf8view) as c1, + STARTS_WITH(column1_utf8, column2_utf8) as c3, + STARTS_WITH(column1_utf8, column2_large_utf8) as c4 +FROM test; +---- +false false false +true true true +true true true +NULL NULL NULL + + +# Test STARTS_WITH with utf8view against literals +# In this case, the literals should be cast to utf8view. The columns +# should not be cast to utf8. +query TT +EXPLAIN SELECT + STARTS_WITH(column1_utf8view, 'äöüß') as c1, + STARTS_WITH(column1_utf8view, '') as c2, + STARTS_WITH(column1_utf8view, NULL) as c3, + STARTS_WITH(NULL, column1_utf8view) as c4 +FROM test; +---- +logical_plan +01)Projection: starts_with(test.column1_utf8view, Utf8View("äöüß")) AS c1, starts_with(test.column1_utf8view, Utf8View("")) AS c2, starts_with(test.column1_utf8view, Utf8View(NULL)) AS c3, starts_with(Utf8View(NULL), test.column1_utf8view) AS c4 +02)--TableScan: test projection=[column1_utf8view] statement ok drop table test; @@ -376,6 +445,5 @@ select t.dt from dates t where arrow_cast('2024-01-01', 'Utf8View') < t.dt; ---- 2024-01-23 - statement ok drop table dates;