Skip to content

Commit

Permalink
feat: upper and lower for utf8view
Browse files Browse the repository at this point in the history
  • Loading branch information
tshauck committed Aug 22, 2024
1 parent cb1e3f0 commit 9cc02d6
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 77 deletions.
102 changes: 33 additions & 69 deletions datafusion/functions/src/string/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ use std::sync::Arc;

use arrow::array::{
new_null_array, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, ArrayRef,
GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray,
StringViewArray,
GenericStringArray, OffsetSizeTrait, StringArray, StringViewArray,
};
use arrow::buffer::{Buffer, MutableBuffer, NullBuffer};
use arrow::buffer::{MutableBuffer, NullBuffer};
use arrow::datatypes::DataType;

use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
Expand Down Expand Up @@ -187,12 +186,12 @@ fn string_trim<'a, T: OffsetSizeTrait>(
}
}

pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
case_conversion(args, |string| string.to_lowercase(), name)
pub(crate) fn to_lower(args: &[ColumnarValue]) -> Result<ColumnarValue> {
case_conversion(args, |string| string.to_lowercase(), "lower")
}

pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
case_conversion(args, |string| string.to_uppercase(), name)
pub(crate) fn to_upper(args: &[ColumnarValue]) -> Result<ColumnarValue> {
case_conversion(args, |string| string.to_uppercase(), "upper")
}

fn case_conversion<'a, F>(
Expand All @@ -205,20 +204,32 @@ where
{
match &args[0] {
ColumnarValue::Array(array) => match array.data_type() {
DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32, _>(
array, op,
)?)),
DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::<
i64,
_,
>(array, op)?)),
DataType::Utf8 => {
let array = as_generic_string_array::<i32>(array)?;
let array = case_conversion_array::<_, i32, _>(array, op)?;
Ok(ColumnarValue::Array(array))
}
DataType::LargeUtf8 => {
let array = as_generic_string_array::<i64>(array)?;
let array = case_conversion_array::<_, i64, _>(array, op)?;
Ok(ColumnarValue::Array(array))
}
DataType::Utf8View => {
let array = as_string_view_array(array)?;
let array = case_conversion_array::<_, i32, _>(array, op)?;
Ok(ColumnarValue::Array(array))
}
other => exec_err!("Unsupported data type {other:?} for function {name}"),
},
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(a) => {
let result = a.as_ref().map(|x| op(x));
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
}
ScalarValue::Utf8View(a) => {
let result = a.as_ref().map(|x| op(x));
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(result)))
}
ScalarValue::LargeUtf8(a) => {
let result = a.as_ref().map(|x| op(x));
Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
Expand Down Expand Up @@ -333,64 +344,17 @@ impl StringArrayBuilder {
}
}

fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
fn case_conversion_array<'a, V, T, F>(array: V, op: F) -> Result<ArrayRef>
where
O: OffsetSizeTrait,
V: ArrayAccessor<Item = &'a str>,
T: OffsetSizeTrait,
F: Fn(&'a str) -> String,
{
const PRE_ALLOC_BYTES: usize = 8;

let string_array = as_generic_string_array::<O>(array)?;
let value_data = string_array.value_data();

// All values are ASCII.
if value_data.is_ascii() {
return case_conversion_ascii_array::<O, _>(string_array, op);
}
let iter = ArrayIter::new(array);

// Values contain non-ASCII.
let item_len = string_array.len();
let capacity = string_array.value_data().len() + PRE_ALLOC_BYTES;
let mut builder = GenericStringBuilder::<O>::with_capacity(item_len, capacity);
let result = iter
.map(|string| string.map(|string| op(string)))
.collect::<GenericStringArray<T>>();

if string_array.null_count() == 0 {
let iter =
(0..item_len).map(|i| Some(op(unsafe { string_array.value_unchecked(i) })));
builder.extend(iter);
} else {
let iter = string_array.iter().map(|string| string.map(&op));
builder.extend(iter);
}
Ok(Arc::new(builder.finish()))
}

/// All values of string_array are ASCII, and converting case does not change the length of the
/// byte array. Therefore, the StringArray can be treated as a single complete ASCII string for
/// case conversion, and we can reuse the offsets buffer and the nulls buffer.
fn case_conversion_ascii_array<'a, O, F>(
string_array: &'a GenericStringArray<O>,
op: F,
) -> Result<ArrayRef>
where
O: OffsetSizeTrait,
F: Fn(&'a str) -> String,
{
let value_data = string_array.value_data();
// SAFETY: all items stored in value_data satisfy UTF8.
// ref: impl ByteArrayNativeType for str {...}
let str_values = unsafe { std::str::from_utf8_unchecked(value_data) };

// conversion
let converted_values = op(str_values);
assert_eq!(converted_values.len(), str_values.len());
let bytes = converted_values.into_bytes();

// build result
let values = Buffer::from_vec(bytes);
let offsets = string_array.offsets().clone();
let nulls = string_array.nulls().cloned();
// SAFETY: offsets and nulls are consistent with the input array.
Ok(Arc::new(unsafe {
GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
}))
Ok(Arc::new(result) as ArrayRef)
}
4 changes: 2 additions & 2 deletions datafusion/functions/src/string/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl LowerFunc {
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8],
vec![Utf8, LargeUtf8, Utf8View],
Volatility::Immutable,
),
}
Expand All @@ -68,7 +68,7 @@ impl ScalarUDFImpl for LowerFunc {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
to_lower(args, "lower")
to_lower(args)
}
}

Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/string/upper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl UpperFunc {
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8],
vec![Utf8, LargeUtf8, Utf8View],
Volatility::Immutable,
),
}
Expand All @@ -65,7 +65,7 @@ impl ScalarUDFImpl for UpperFunc {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
to_upper(args, "upper")
to_upper(args)
}
}

Expand Down
24 changes: 20 additions & 4 deletions datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ SELECT
INITCAP(column1_large_utf8_lower) as c3
FROM test_lowercase;
----
Andrew Andrew Andrew
Andrew Andrew Andrew
Xiangpeng Xiangpeng Xiangpeng
Raphael Raphael Raphael
NULL NULL NULL
Expand Down Expand Up @@ -827,17 +827,33 @@ logical_plan
01)Projection: levenshtein(test.column1_utf8view, Utf8View("foo")) AS c1, levenshtein(test.column1_utf8view, test.column2_utf8view) AS c2
02)--TableScan: test projection=[column1_utf8view, column2_utf8view]

## Ensure no casts for LOWER
## TODO https://github.com/apache/datafusion/issues/11855
query TT
EXPLAIN SELECT
LOWER(column1_utf8view) as c1
FROM test;
----
logical_plan
01)Projection: lower(CAST(test.column1_utf8view AS Utf8)) AS c1
01)Projection: lower(test.column1_utf8view) AS c1
02)--TableScan: test projection=[column1_utf8view]

query T
SELECT LOWER(column1_utf8view) as c1
FROM test;
----
andrew
xiangpeng
raphael
NULL

query T
SELECT UPPER(column1_utf8view) as c1
FROM test;
----
ANDREW
XIANGPENG
RAPHAEL
NULL


## Ensure no casts for LPAD
query TT
Expand Down

0 comments on commit 9cc02d6

Please sign in to comment.