Skip to content

Commit

Permalink
Update ASCII scalar function to support Utf8View apache#11834
Browse files Browse the repository at this point in the history
  • Loading branch information
demetribu committed Aug 8, 2024
1 parent 0bbce5d commit ea14dbc
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 28 deletions.
118 changes: 90 additions & 28 deletions datafusion/functions/src/string/ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,15 @@
// under the License.

use crate::utils::make_scalar_function;
use arrow::array::Int32Array;
use arrow::array::{ArrayRef, OffsetSizeTrait};
use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array};
use arrow::datatypes::DataType;
use datafusion_common::{cast::as_generic_string_array, internal_err, Result};
use arrow::error::ArrowError;
use datafusion_common::{internal_err, Result};
use datafusion_expr::ColumnarValue;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;

/// Returns the numeric code of the first character of the argument.
/// ascii('x') = 120
pub fn ascii<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = as_generic_string_array::<T>(&args[0])?;

let result = string_array
.iter()
.map(|string| {
string.map(|string: &str| {
let mut chars = string.chars();
chars.next().map_or(0, |v| v as i32)
})
})
.collect::<Int32Array>();

Ok(Arc::new(result) as ArrayRef)
}

#[derive(Debug)]
pub struct AsciiFunc {
signature: Signature,
Expand All @@ -60,7 +42,7 @@ impl AsciiFunc {
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8],
vec![Utf8, LargeUtf8, Utf8View],
Volatility::Immutable,
),
}
Expand All @@ -87,12 +69,92 @@ impl ScalarUDFImpl for AsciiFunc {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => make_scalar_function(ascii::<i32>, vec![])(args),
DataType::LargeUtf8 => {
return make_scalar_function(ascii::<i64>, vec![])(args);
}
_ => internal_err!("Unsupported data type"),
make_scalar_function(ascii, vec![])(args)
}
}

fn calculate_ascii<'a, V>(array: V) -> Result<ArrayRef, ArrowError>
where
V: ArrayAccessor<Item = &'a str>,
{
let iter = ArrayIter::new(array);
let result = iter
.map(|string| {
string.map(|s| {
let mut chars = s.chars();
chars.next().map_or(0, |v| v as i32)
})
})
.collect::<Int32Array>();

Ok(Arc::new(result) as ArrayRef)
}

/// Returns the numeric code of the first character of the argument.
pub fn ascii(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::Utf8 => {
let string_array = args[0].as_string::<i32>();
Ok(calculate_ascii(string_array)?)
}
DataType::LargeUtf8 => {
let string_array = args[0].as_string::<i64>();
Ok(calculate_ascii(string_array)?)
}
DataType::Utf8View => {
let string_array = args[0].as_string_view();
Ok(calculate_ascii(string_array)?)
}
_ => internal_err!("Unsupported data type"),
}
}

#[cfg(test)]
mod tests {
use crate::string::ascii::AsciiFunc;
use crate::utils::test::test_function;
use arrow::array::{Array, Int32Array};
use arrow::datatypes::DataType::Int32;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

macro_rules! test_ascii {
($INPUT:expr, $EXPECTED:expr) => {
test_function!(
AsciiFunc::new(),
&[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))],
$EXPECTED,
i32,
Int32,
Int32Array
);

test_function!(
AsciiFunc::new(),
&[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))],
$EXPECTED,
i32,
Int32,
Int32Array
);

test_function!(
AsciiFunc::new(),
&[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
$EXPECTED,
i32,
Int32,
Int32Array
);
};
}

#[test]
fn test_functions() -> Result<()> {
test_ascii!(Some(String::from("x")), Ok(Some(120)));
test_ascii!(Some(String::from("a")), Ok(Some(97)));
test_ascii!(Some(String::from("")), Ok(Some(0)));
test_ascii!(None, Ok(None));
Ok(())
}
}
99 changes: 99 additions & 0 deletions datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -500,3 +500,102 @@ select column2|| ' ' ||column3 from temp;
----
rust fast
datafusion cool

### ASCII
# Setup the initial test data
statement ok
create table test_source as values
('Andrew', 'X'),
('Xiangpeng', 'Xiangpeng'),
('Raphael', 'R'),
(NULL, 'R');

# Table with the different combination of column types
statement ok
create table test as
SELECT
arrow_cast(column1, 'Utf8') as column1_utf8,
arrow_cast(column2, 'Utf8') as column2_utf8,
arrow_cast(column1, 'LargeUtf8') as column1_large_utf8,
arrow_cast(column2, 'LargeUtf8') as column2_large_utf8,
arrow_cast(column1, 'Utf8View') as column1_utf8view,
arrow_cast(column2, 'Utf8View') as column2_utf8view
FROM test_source;

# Test ASCII with utf8view against utf8view, utf8, and largeutf8
# (should be no casts)
query TT
EXPLAIN SELECT
ASCII(column1_utf8view) as c1,
ASCII(column2_utf8) as c2,
ASCII(column2_large_utf8) as c3
FROM test;
----
logical_plan
01)Projection: ascii(test.column1_utf8view) AS c1, ascii(test.column2_utf8) AS c2, ascii(test.column2_large_utf8) AS c3
02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view]

query III
SELECT
ASCII(column1_utf8view) as c1,
ASCII(column2_utf8) as c2,
ASCII(column2_large_utf8) as c3
FROM test;
----
65 88 88
88 88 88
82 82 82
NULL 82 82

query TT
EXPLAIN SELECT
ASCII(column1_utf8) as c1,
ASCII(column1_large_utf8) as c2,
ASCII(column2_utf8view) as c3,
ASCII('hello') as c4,
ASCII(arrow_cast('world', 'Utf8View')) as c5
FROM test;
----
logical_plan
01)Projection: ascii(test.column1_utf8) AS c1, ascii(test.column1_large_utf8) AS c2, ascii(test.column2_utf8view) AS c3, Int32(104) AS c4, Int32(119) AS c5
02)--TableScan: test projection=[column1_utf8, column1_large_utf8, column2_utf8view]

query IIIII
SELECT
ASCII(column1_utf8) as c1,
ASCII(column1_large_utf8) as c2,
ASCII(column2_utf8view) as c3,
ASCII('hello') as c4,
ASCII(arrow_cast('world', 'Utf8View')) as c5
FROM test;
----
65 65 88 104 119
88 88 88 104 119
82 82 82 104 119
NULL NULL 82 104 119

# Test ASCII with literals cast to Utf8View
query TT
EXPLAIN SELECT
ASCII(arrow_cast('äöüß', 'Utf8View')) as c1,
ASCII(arrow_cast('', 'Utf8View')) as c2,
ASCII(arrow_cast(NULL, 'Utf8View')) as c3
FROM test;
----
logical_plan
01)Projection: Int32(228) AS c1, Int32(0) AS c2, Int32(NULL) AS c3
02)--TableScan: test projection=[]

query III
SELECT
ASCII(arrow_cast('äöüß', 'Utf8View')) as c1,
ASCII(arrow_cast('', 'Utf8View')) as c2,
ASCII(arrow_cast(NULL, 'Utf8View')) as c3
----
228 0 NULL

statement ok
drop table test;

statement ok
drop table test_source;

0 comments on commit ea14dbc

Please sign in to comment.