Skip to content

Commit

Permalink
fx
Browse files Browse the repository at this point in the history
  • Loading branch information
joseph-isaacs committed Nov 20, 2024
1 parent 9877079 commit a2695ff
Showing 1 changed file with 76 additions and 11 deletions.
87 changes: 76 additions & 11 deletions datafusion/functions/src/string/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl ConcatFunc {
use DataType::*;
Self {
signature: Signature::variadic(
vec![Utf8, Utf8View, LargeUtf8],
vec![Utf8View, Utf8, LargeUtf8],
Volatility::Immutable,
),
}
Expand Down Expand Up @@ -114,8 +114,19 @@ impl ScalarUDFImpl for ConcatFunc {
if array_len.is_none() {
let mut result = String::new();
for arg in args {
if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = arg {
result.push_str(v);
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(v)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) => {
result.push_str(v);
}
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
other => plan_err!(
"Concat function does not support scalar type {:?}",
other
)?,
}
}

Expand Down Expand Up @@ -286,15 +297,37 @@ pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
let mut new_args = Vec::with_capacity(args.len());
let mut contiguous_scalar = "".to_string();

let return_type = {
let data_types: Vec<_> = args
.iter()
.filter_map(|expr| match expr {
Expr::Literal(l) => Some(l.data_type()),
_ => None,
})
.collect();
ConcatFunc::new().return_type(&data_types)
}?;

for arg in args.clone() {
match arg {
Expr::Literal(ScalarValue::Utf8(None)) => {}
Expr::Literal(ScalarValue::LargeUtf8(None)) => {
}
Expr::Literal(ScalarValue::Utf8View(None)) => { }

// filter out `null` args
Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {}
// All literals have been converted to Utf8 or LargeUtf8 in type_coercion.
// Concatenate it with the `contiguous_scalar`.
Expr::Literal(
ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)),
) => contiguous_scalar += &v,
Expr::Literal(ScalarValue::Utf8(Some(v))) => {
contiguous_scalar += &v;
}
Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => {
contiguous_scalar += &v;
}
Expr::Literal(ScalarValue::Utf8View(Some(v))) => {
contiguous_scalar += &v;
}

Expr::Literal(x) => {
return internal_err!(
"The scalar {x} should be casted to string type during the type coercion."
Expand All @@ -305,7 +338,12 @@ pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
// Then pushing this arg to the `new_args`.
arg => {
if !contiguous_scalar.is_empty() {
new_args.push(lit(contiguous_scalar));
match return_type {
DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))),
DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))),
_ => unreachable!(),
}
contiguous_scalar = "".to_string();
}
new_args.push(arg);
Expand All @@ -314,7 +352,16 @@ pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
}

if !contiguous_scalar.is_empty() {
new_args.push(lit(contiguous_scalar));
match return_type {
DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
DataType::LargeUtf8 => {
new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar))))
}
DataType::Utf8View => {
new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar))))
}
_ => unreachable!(),
}
}

if !args.eq(&new_args) {
Expand Down Expand Up @@ -396,6 +443,17 @@ mod tests {
LargeUtf8,
LargeStringArray
);
test_function!(
ConcatFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))),
ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
],
Ok(Some("aacc")),
&str,
Utf8View,
StringViewArray
);

Ok(())
}
Expand All @@ -410,12 +468,19 @@ mod tests {
None,
Some("z"),
])));
let args = &[c0, c1, c2];
let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
Some("a"),
None,
Some("b"),
])));
let args = &[c0, c1, c2, c3, c4];

#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let result = ConcatFunc::new().invoke_batch(args, 3)?;
let expected =
Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as ArrayRef;
Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"]))
as ArrayRef;
match &result {
ColumnarValue::Array(array) => {
assert_eq!(&expected, array);
Expand Down

0 comments on commit a2695ff

Please sign in to comment.