From a2695ff64e520c887f506ece5a215833bf820744 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Wed, 20 Nov 2024 11:39:11 +0000 Subject: [PATCH] fx --- datafusion/functions/src/string/concat.rs | 87 ++++++++++++++++++++--- 1 file changed, 76 insertions(+), 11 deletions(-) diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 8395eab52e78..0b77dd5b5157 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -48,7 +48,7 @@ impl ConcatFunc { use DataType::*; Self { signature: Signature::variadic( - vec![Utf8, Utf8View, LargeUtf8], + vec![Utf8View, Utf8, LargeUtf8], Volatility::Immutable, ), } @@ -114,8 +114,19 @@ impl ScalarUDFImpl for ConcatFunc { if array_len.is_none() { let mut result = String::new(); for arg in args { - if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = arg { - result.push_str(v); + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) + | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) => { + result.push_str(v); + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} + other => plan_err!( + "Concat function does not support scalar type {:?}", + other + )?, } } @@ -286,15 +297,37 @@ pub fn simplify_concat(args: Vec) -> Result { let mut new_args = Vec::with_capacity(args.len()); let mut contiguous_scalar = "".to_string(); + let return_type = { + let data_types: Vec<_> = args + .iter() + .filter_map(|expr| match expr { + Expr::Literal(l) => Some(l.data_type()), + _ => None, + }) + .collect(); + ConcatFunc::new().return_type(&data_types) + }?; + for arg in args.clone() { match arg { + Expr::Literal(ScalarValue::Utf8(None)) => {} + Expr::Literal(ScalarValue::LargeUtf8(None)) => { + } + Expr::Literal(ScalarValue::Utf8View(None)) => { } + // filter out `null` args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {} // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. - Expr::Literal( - ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), - ) => contiguous_scalar += &v, + Expr::Literal(ScalarValue::Utf8(Some(v))) => { + contiguous_scalar += &v; + } + Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => { + contiguous_scalar += &v; + } + Expr::Literal(ScalarValue::Utf8View(Some(v))) => { + contiguous_scalar += &v; + } + Expr::Literal(x) => { return internal_err!( "The scalar {x} should be casted to string type during the type coercion." @@ -305,7 +338,12 @@ pub fn simplify_concat(args: Vec) -> Result { // Then pushing this arg to the `new_args`. arg => { if !contiguous_scalar.is_empty() { - new_args.push(lit(contiguous_scalar)); + match return_type { + DataType::Utf8 => new_args.push(lit(contiguous_scalar)), + DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))), + DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))), + _ => unreachable!(), + } contiguous_scalar = "".to_string(); } new_args.push(arg); @@ -314,7 +352,16 @@ pub fn simplify_concat(args: Vec) -> Result { } if !contiguous_scalar.is_empty() { - new_args.push(lit(contiguous_scalar)); + match return_type { + DataType::Utf8 => new_args.push(lit(contiguous_scalar)), + DataType::LargeUtf8 => { + new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))) + } + DataType::Utf8View => { + new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))) + } + _ => unreachable!(), + } } if !args.eq(&new_args) { @@ -396,6 +443,17 @@ mod tests { LargeUtf8, LargeStringArray ); + test_function!( + ConcatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aacc")), + &str, + Utf8View, + StringViewArray + ); Ok(()) } @@ -410,12 +468,19 @@ mod tests { None, Some("z"), ]))); - let args = &[c0, c1, c2]; + let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string()))); + let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![ + Some("a"), + None, + Some("b"), + ]))); + let args = &[c0, c1, c2, c3, c4]; #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ConcatFunc::new().invoke_batch(args, 3)?; let expected = - Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as ArrayRef; + Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"])) + as ArrayRef; match &result { ColumnarValue::Array(array) => { assert_eq!(&expected, array);