-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update the CONCAT scalar function to support Utf8View #12224
Changes from 11 commits
723ceb7
f7abdd5
503d5b9
b30330d
76d6b5f
9798ed3
91d04ff
6d28927
769d99d
ac30a83
7ea6e0a
dd3ad39
504459c
e081934
0069c1a
0929a4b
f16de44
d0bf3ba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,14 +15,13 @@ | |
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
use arrow::array::{as_largestring_array, Array, StringViewArray}; | ||
use arrow::datatypes::DataType; | ||
use std::any::Any; | ||
use std::sync::Arc; | ||
|
||
use arrow::datatypes::DataType; | ||
use arrow::datatypes::DataType::Utf8; | ||
|
||
use datafusion_common::cast::as_string_array; | ||
use datafusion_common::{internal_err, Result, ScalarValue}; | ||
use datafusion_common::cast::{as_string_array, as_string_view_array}; | ||
use datafusion_common::{internal_err, plan_err, Result, ScalarValue}; | ||
use datafusion_expr::expr::ScalarFunction; | ||
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; | ||
use datafusion_expr::{lit, ColumnarValue, Expr, Volatility}; | ||
|
@@ -46,7 +45,10 @@ impl ConcatFunc { | |
pub fn new() -> Self { | ||
use DataType::*; | ||
Self { | ||
signature: Signature::variadic(vec![Utf8], Volatility::Immutable), | ||
signature: Signature::variadic( | ||
vec![Utf8, Utf8View, LargeUtf8], | ||
Volatility::Immutable, | ||
), | ||
} | ||
} | ||
} | ||
|
@@ -64,13 +66,19 @@ impl ScalarUDFImpl for ConcatFunc { | |
&self.signature | ||
} | ||
|
||
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
Ok(Utf8) | ||
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { | ||
use DataType::*; | ||
Ok(match &arg_types[0] { | ||
Utf8View => Utf8View, | ||
LargeUtf8 => LargeUtf8, | ||
_ => Utf8, | ||
}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the logic seems to assume all arguments are of the same type? also, why not always return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah yeah @findepi I think the logic is "Whatever the first argument type is the output should be of that type" so if the received values were: Utf8, Utf8View the output would be Utf8. I'm taking the logic from other UDFs and applying it here. It may not be the best way of doing this though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
i understand this is what's implemented. but not sure why it is so. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah okay. So what you're saying is that here: https://github.com/apache/datafusion/pull/12224/files#diff-71970189679c6dd5b3b677bb21603234b488e68d1601be9c4d400d40e430a909R204 I'm building a Utf8 string anyways? So I suspect I should change that bit of code to use a StringViewArrayBuilder? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, that's my intuition from the issue #11836
this is about inputs to the function, not the return type Side note: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @findepi what about for LargeUtf8? I suspect that if a LargeUtf8 is the input then the output should also be that since its an i64 datatype vs the i32 datatype for Utf8? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i don't know the exact rules for how we handle LargeUtf8. in fact, what does the binary concat operator do? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @findepi when you say binary concat operator are you talking about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
yes
the question still holds (why exactly we bias towards the first param type), but i am no longer convinced about my suggestion to use Utf8 always. i think we should "just" make sure
cc @alamb There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good @findepi i like that logic. I can adjust to make it so. |
||
} | ||
|
||
/// Concatenates the text representations of all the arguments. NULL arguments are ignored. | ||
/// concat('abcde', 2, NULL, 22) = 'abcde222' | ||
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
let args_datatype = args[0].data_type(); | ||
devanbenz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let array_len = args | ||
.iter() | ||
.filter_map(|x| match x { | ||
|
@@ -87,7 +95,21 @@ impl ScalarUDFImpl for ConcatFunc { | |
result.push_str(v); | ||
} | ||
} | ||
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); | ||
|
||
return match args_datatype { | ||
DataType::Utf8View => { | ||
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result)))) | ||
} | ||
DataType::Utf8 => { | ||
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))) | ||
} | ||
DataType::LargeUtf8 => { | ||
Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result)))) | ||
} | ||
other => { | ||
plan_err!("Concat function does not support datatype of {other}") | ||
} | ||
}; | ||
} | ||
|
||
// Array | ||
|
@@ -103,28 +125,98 @@ impl ScalarUDFImpl for ConcatFunc { | |
columns.push(ColumnarValueRef::Scalar(s.as_bytes())); | ||
} | ||
} | ||
ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => { | ||
if let Some(s) = maybe_value { | ||
data_size += s.len() * len; | ||
columns.push(ColumnarValueRef::Scalar(s.as_bytes())); | ||
} | ||
} | ||
ColumnarValue::Array(array) => { | ||
let string_array = as_string_array(array)?; | ||
data_size += string_array.values().len(); | ||
let column = if array.is_nullable() { | ||
ColumnarValueRef::NullableArray(string_array) | ||
} else { | ||
ColumnarValueRef::NonNullableArray(string_array) | ||
match array.data_type() { | ||
DataType::Utf8 => { | ||
let string_array = as_string_array(array)?; | ||
|
||
data_size += string_array.values().len(); | ||
let column = if array.is_nullable() { | ||
ColumnarValueRef::NullableArray(string_array) | ||
} else { | ||
ColumnarValueRef::NonNullableArray(string_array) | ||
}; | ||
columns.push(column); | ||
}, | ||
DataType::LargeUtf8 => { | ||
let string_array = as_largestring_array(array); | ||
|
||
data_size += string_array.values().len(); | ||
let column = if array.is_nullable() { | ||
ColumnarValueRef::NullableLargeStringArray(string_array) | ||
} else { | ||
ColumnarValueRef::NonNullableLargeStringArray(string_array) | ||
}; | ||
columns.push(column); | ||
}, | ||
DataType::Utf8View => { | ||
let string_array = as_string_view_array(array)?; | ||
|
||
data_size += string_array.len(); | ||
let column = if array.is_nullable() { | ||
ColumnarValueRef::NullableStringViewArray(string_array) | ||
} else { | ||
ColumnarValueRef::NonNullableStringViewArray(string_array) | ||
}; | ||
columns.push(column); | ||
}, | ||
other => { | ||
return plan_err!("Input was {other} which is not a supported datatype for concat function") | ||
} | ||
}; | ||
columns.push(column); | ||
} | ||
_ => unreachable!(), | ||
} | ||
} | ||
|
||
let mut builder = StringArrayBuilder::with_capacity(len, data_size); | ||
for i in 0..len { | ||
columns | ||
.iter() | ||
.for_each(|column| builder.write::<true>(column, i)); | ||
builder.append_offset(); | ||
match args_datatype { | ||
DataType::Utf8 => { | ||
let mut builder = StringArrayBuilder::with_capacity(len, data_size); | ||
for i in 0..len { | ||
columns | ||
.iter() | ||
.for_each(|column| builder.write::<true>(column, i)); | ||
builder.append_offset(); | ||
} | ||
|
||
let string_array = builder.finish(None); | ||
Ok(ColumnarValue::Array(Arc::new(string_array))) | ||
} | ||
DataType::LargeUtf8 => { | ||
let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size); | ||
for i in 0..len { | ||
columns | ||
.iter() | ||
.for_each(|column| builder.write::<true>(column, i)); | ||
builder.append_offset(); | ||
} | ||
|
||
let string_array = builder.finish(None); | ||
Ok(ColumnarValue::Array(Arc::new(string_array))) | ||
} | ||
DataType::Utf8View => { | ||
let mut builder = StringArrayBuilder::with_capacity(len, data_size); | ||
for i in 0..len { | ||
columns | ||
.iter() | ||
.for_each(|column| builder.write::<true>(column, i)); | ||
builder.append_offset(); | ||
} | ||
|
||
let string_array = builder.finish(None); | ||
let string_array_iter = string_array.into_iter(); | ||
Ok(ColumnarValue::Array(Arc::new(StringViewArray::from_iter( | ||
string_array_iter, | ||
)))) | ||
} | ||
_ => unreachable!(), | ||
} | ||
Ok(ColumnarValue::Array(Arc::new(builder.finish(None)))) | ||
} | ||
|
||
/// Simplify the `concat` function by | ||
|
@@ -151,11 +243,11 @@ pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> { | |
for arg in args.clone() { | ||
match arg { | ||
// filter out `null` args | ||
Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} | ||
Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {} | ||
// All literals have been converted to Utf8 or LargeUtf8 in type_coercion. | ||
// Concatenate it with the `contiguous_scalar`. | ||
Expr::Literal( | ||
ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)), | ||
ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), | ||
) => contiguous_scalar += &v, | ||
Expr::Literal(x) => { | ||
return internal_err!( | ||
|
@@ -197,6 +289,7 @@ mod tests { | |
use crate::utils::test::test_function; | ||
use arrow::array::Array; | ||
use arrow::array::{ArrayRef, StringArray}; | ||
use DataType::*; | ||
|
||
#[test] | ||
fn test_functions() -> Result<()> { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wanted to make the already existing
StringArrayBuilder
generic but was having issues 😢