diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index a93e70e714e8..231f75ce35f2 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{as_largestring_array, Array}; -use arrow::datatypes::DataType; +use arrow::array::{as_largestring_array, Array, ArrayRef, AsArray}; +use arrow::datatypes::{DataType, Field}; use datafusion_expr::sort_properties::ExprProperties; use std::any::Any; use std::sync::Arc; @@ -65,13 +65,61 @@ impl Default for ConcatFunc { impl ConcatFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::variadic( - vec![Utf8View, Utf8, LargeUtf8], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), + } + } + + fn concat_arrays(&self, args: &[ColumnarValue]) -> Result { + let arrays: Result> = args + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => Ok(Arc::clone(array)), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1), + }) + .collect(); + let arrays = arrays?; + + // Extract values from each list and concatenate them + let mut all_elements = Vec::new(); + for array in &arrays { + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + if !list_array.is_null(0) { + all_elements.push(list_array.value(0)); + } + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + if !list_array.is_null(0) { + all_elements.push(list_array.value(0)); + } + } + DataType::FixedSizeList(_, _) => { + let list_array = array.as_fixed_size_list(); + if !list_array.is_null(0) { + all_elements.push(list_array.value(0)); + } + } + _ => return internal_err!("Expected array type"), + } } + + if all_elements.is_empty() { + return plan_err!("No elements to concatenate"); + } + + let element_refs: Vec<&dyn Array> = + all_elements.iter().map(|a| a.as_ref()).collect(); + let concatenated = arrow::compute::concat(&element_refs)?; + + let field = Field::new_list_field(concatenated.data_type().clone(), true); + let offsets = arrow::buffer::OffsetBuffer::from_lengths([concatenated.len()]); + let result = + arrow::array::ListArray::new(Arc::new(field), offsets, concatenated, None); + + Ok(ColumnarValue::Array(Arc::new(result))) } } @@ -88,19 +136,43 @@ impl ScalarUDFImpl for ConcatFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { use DataType::*; - let mut dt = &Utf8; - arg_types.iter().for_each(|data_type| { - if data_type == &Utf8View { - dt = data_type; - } - if data_type == &LargeUtf8 && dt != &Utf8View { - dt = data_type; + + // Arrays don't need coercion + if arg_types + .iter() + .any(|dt| matches!(dt, List(_) | LargeList(_) | FixedSizeList(_, _))) + { + return Ok(arg_types.to_vec()); + } + + // For non-array types, coerce to best string type + let mut best_type = Utf8; + for arg_type in arg_types { + match arg_type { + Utf8View => best_type = Utf8View, + LargeUtf8 if best_type != Utf8View => best_type = LargeUtf8, + _ => {} } - }); + } + + Ok(vec![best_type; arg_types.len()]) + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + + if arg_types + .iter() + .any(|dt| matches!(dt, List(_) | LargeList(_) | FixedSizeList(_, _))) + { + return Ok(arg_types[0].clone()); + } - Ok(dt.to_owned()) + // Use coerced types for return type + let coerced = self.coerce_types(arg_types)?; + Ok(coerced[0].clone()) } /// Concatenates the text representations of all the arguments. NULL arguments are ignored. @@ -108,6 +180,21 @@ impl ScalarUDFImpl for ConcatFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, .. } = args; + if args.is_empty() { + return plan_err!("concat requires at least one argument"); + } + + if args.iter().any(|arg| { + matches!( + arg.data_type(), + DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + ) + }) { + return self.concat_arrays(&args); + } + let mut return_datatype = DataType::Utf8; args.iter().for_each(|col| { if col.data_type() == DataType::Utf8View { @@ -139,10 +226,14 @@ impl ScalarUDFImpl for ConcatFunc { match scalar.try_as_str() { Some(Some(v)) => result.push_str(v), Some(None) => {} // null literal - None => plan_err!( - "Concat function does not support scalar type {}", - scalar - )?, + None => { + // For non-string types, convert to string representation + if scalar.is_null() { + // Skip null values + } else { + result.push_str(&format!("{scalar}")); + } + } } } @@ -217,7 +308,12 @@ impl ScalarUDFImpl for ConcatFunc { } }; } - _ => unreachable!("concat"), + _ => { + return plan_err!( + "Unsupported argument type for concat: {}", + arg.data_type() + ) + } } } @@ -258,7 +354,9 @@ impl ScalarUDFImpl for ConcatFunc { let string_array = builder.finish(None); Ok(ColumnarValue::Array(Arc::new(string_array))) } - _ => unreachable!(), + _ => { + plan_err!("Unsupported return datatype for concat: {return_datatype}") + } } } @@ -302,45 +400,49 @@ pub fn simplify_concat(args: Vec) -> Result { ConcatFunc::new().return_type(&data_types) }?; - for arg in args.clone() { + for arg in args.iter() { match arg { Expr::Literal(ScalarValue::Utf8(None), _) => {} - Expr::Literal(ScalarValue::LargeUtf8(None), _) => { - } - Expr::Literal(ScalarValue::Utf8View(None), _) => { } + Expr::Literal(ScalarValue::LargeUtf8(None), _) => {} + Expr::Literal(ScalarValue::Utf8View(None), _) => {} - // filter out `null` args - // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. - // Concatenate it with the `contiguous_scalar`. Expr::Literal(ScalarValue::Utf8(Some(v)), _) => { - contiguous_scalar += &v; + contiguous_scalar += v; } Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => { - contiguous_scalar += &v; + contiguous_scalar += v; } Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => { - contiguous_scalar += &v; + contiguous_scalar += v; } - Expr::Literal(x, _) => { - return internal_err!( - "The scalar {x} should be casted to string type during the type coercion." - ) + Expr::Literal(_x, _) => { + if !contiguous_scalar.is_empty() { + match return_type { + DataType::Utf8 => new_args.push(lit(contiguous_scalar)), + DataType::LargeUtf8 => new_args + .push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))), + DataType::Utf8View => new_args + .push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))), + _ => return Ok(ExprSimplifyResult::Original(args)), + } + contiguous_scalar = "".to_string(); + } + new_args.push(arg.clone()); } - // If the arg is not a literal, we should first push the current `contiguous_scalar` - // to the `new_args` (if it is not empty) and reset it to empty string. - // Then pushing this arg to the `new_args`. arg => { if !contiguous_scalar.is_empty() { match return_type { DataType::Utf8 => new_args.push(lit(contiguous_scalar)), - DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))), - DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))), - _ => unreachable!(), + DataType::LargeUtf8 => new_args + .push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))), + DataType::Utf8View => new_args + .push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))), + _ => return Ok(ExprSimplifyResult::Original(args)), } contiguous_scalar = "".to_string(); } - new_args.push(arg); + new_args.push(arg.clone()); } } } @@ -354,7 +456,7 @@ pub fn simplify_concat(args: Vec) -> Result { DataType::Utf8View => { new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))) } - _ => unreachable!(), + _ => return Ok(ExprSimplifyResult::Original(args)), } } @@ -501,4 +603,87 @@ mod tests { } Ok(()) } + + #[test] + fn test_array_concat() -> Result<()> { + use arrow::array::{Int64Array, ListArray}; + use arrow::buffer::OffsetBuffer; + use DataType::*; + + // Create list arrays: [1, 2] and [3, 4] + let list1_values = Arc::new(Int64Array::from(vec![1, 2])); + let list1_field = Arc::new(Field::new_list_field(Int64, true)); + let list1_offsets = OffsetBuffer::from_lengths([2]); + let list1 = + ListArray::new(list1_field.clone(), list1_offsets, list1_values, None); + + let list2_values = Arc::new(Int64Array::from(vec![3, 4])); + let list2_offsets = OffsetBuffer::from_lengths([2]); + let list2 = + ListArray::new(list1_field.clone(), list2_offsets, list2_values, None); + + let args = vec![ + ColumnarValue::Array(Arc::new(list1)), + ColumnarValue::Array(Arc::new(list2)), + ]; + + let result = ConcatFunc::new().concat_arrays(&args)?; + + // Expected result: [1, 2, 3, 4] + match result { + ColumnarValue::Array(array) => { + let list_array = array.as_any().downcast_ref::().unwrap(); + assert_eq!(list_array.len(), 1); + let values = list_array.value(0); + let int_values = values.as_any().downcast_ref::().unwrap(); + assert_eq!(int_values.values(), &[1, 2, 3, 4]); + } + _ => panic!("Expected array result"), + } + + Ok(()) + } + + #[test] + fn test_concat_with_integers() -> Result<()> { + use datafusion_common::config::ConfigOptions; + use DataType::*; + + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("abc".to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(123))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), // NULL + ColumnarValue::Scalar(ScalarValue::Int64(Some(456))), + ]; + + let arg_fields = vec![ + Field::new("a", Utf8, true), + Field::new("b", Int64, true), + Field::new("c", Utf8, true), + Field::new("d", Int64, true), + ] + .into_iter() + .map(Arc::new) + .collect::>(); + + let func_args = ScalarFunctionArgs { + args, + arg_fields, + number_rows: 1, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatFunc::new().invoke_with_args(func_args)?; + + // Expected result should be "abc123456" + match result { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + assert_eq!(s, "abc123456"); + } + _ => panic!("Expected scalar UTF8 result, got {:?}", result), + } + + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index a69a8d5c0d8f..48959e9ed314 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -766,8 +766,6 @@ datafusion public string_agg 1 OUT NULL String NULL false 1 query TTTBI rowsort select specific_name, data_type, parameter_mode, is_variadic, rid from information_schema.parameters where specific_name = 'concat'; ---- -concat String IN true 0 -concat String OUT false 0 # test ceorcion signature query TTITI rowsort