Skip to content

Commit

Permalink
Issue-9660 - Extract array_to_string and string_to_array from kernels…
Browse files Browse the repository at this point in the history
… and udf containers
  • Loading branch information
erenavsarogullari committed Mar 20, 2024
1 parent ad8d552 commit b69054b
Show file tree
Hide file tree
Showing 5 changed files with 502 additions and 464 deletions.
329 changes: 5 additions & 324 deletions datafusion/functions-array/src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@
//! implementation kernels for array functions
use arrow::array::{
Array, ArrayRef, BooleanArray, Capacities, Date32Array, Float32Array, Float64Array,
GenericListArray, Int16Array, Int32Array, Int64Array, Int8Array, LargeListArray,
LargeStringArray, ListArray, ListBuilder, MutableArrayData, OffsetSizeTrait,
StringArray, StringBuilder, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
Array, ArrayRef, BooleanArray, Capacities, Date32Array, GenericListArray, Int64Array,
LargeListArray, ListArray, MutableArrayData, OffsetSizeTrait, UInt64Array,
};
use arrow::compute;
use arrow::datatypes::{
Expand All @@ -33,335 +31,18 @@ use arrow_schema::FieldRef;
use arrow_schema::SortOptions;

use datafusion_common::cast::{
as_date32_array, as_generic_list_array, as_generic_string_array, as_int64_array,
as_interval_mdn_array, as_large_list_array, as_list_array, as_null_array,
as_string_array,
as_date32_array, as_generic_list_array, as_int64_array, as_interval_mdn_array,
as_large_list_array, as_list_array, as_null_array, as_string_array,
};
use datafusion_common::{
exec_err, internal_datafusion_err, not_impl_datafusion_err, DataFusionError, Result,
ScalarValue,
};

use crate::utils::downcast_arg;
use std::any::type_name;
use std::sync::Arc;

/// Downcasts `$ARG` (anything exposing `as_any()`) to a reference of the
/// concrete type `$ARRAY_TYPE`.
///
/// On a type mismatch the enclosing function returns early (via `?`) with a
/// `DataFusionError::Internal` naming the requested type, so this macro can
/// only be used inside functions whose error type accepts that error.
macro_rules! downcast_arg {
    ($ARG:expr, $ARRAY_TYPE:ident) => {{
        match $ARG.as_any().downcast_ref::<$ARRAY_TYPE>() {
            Some(typed) => typed,
            // `Err(..)?` keeps the same early-return and error-conversion
            // behaviour as propagating a failed `Result` with `?`.
            None => Err(DataFusionError::Internal(format!(
                "could not cast to {}",
                type_name::<$ARRAY_TYPE>()
            )))?,
        }
    }};
}

/// Appends the textual form of every element of `$ARRAY` (downcast to
/// `$ARRAY_TYPE`) to the string buffer `$ARG`.
///
/// Each written element is followed by `$DELIMITER`, so the buffer ends with
/// a trailing delimiter that the caller is expected to strip. Null elements
/// are written as `$NULL_STRING` when `$WITH_NULL_STRING` is true and are
/// skipped otherwise. Evaluates to `Ok($ARG)`.
macro_rules! to_string {
    ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{
        let typed = downcast_arg!($ARRAY, $ARRAY_TYPE);
        for element in typed {
            if let Some(value) = element {
                $ARG.push_str(&value.to_string());
                $ARG.push_str($DELIMITER);
            } else if $WITH_NULL_STRING {
                // Null element: emit the replacement text instead.
                $ARG.push_str($NULL_STRING);
                $ARG.push_str($DELIMITER);
            }
        }
        Ok($ARG)
    }};
}

/// Dispatches on an arrow `DataType` and invokes the caller-defined
/// `array_function!` macro with the matching concrete array type.
///
/// The second argument selects the dispatch table: the literal `false`
/// covers string, boolean, float and integer types only; any other
/// expression additionally accepts `DataType::List`. Types outside the table
/// hit `unreachable!()`.
macro_rules! call_array_function {
    ($DATATYPE:expr, false) => {
        match $DATATYPE {
            // Text
            DataType::Utf8 => array_function!(StringArray),
            DataType::LargeUtf8 => array_function!(LargeStringArray),
            // Boolean
            DataType::Boolean => array_function!(BooleanArray),
            // Floating point
            DataType::Float32 => array_function!(Float32Array),
            DataType::Float64 => array_function!(Float64Array),
            // Signed integers
            DataType::Int8 => array_function!(Int8Array),
            DataType::Int16 => array_function!(Int16Array),
            DataType::Int32 => array_function!(Int32Array),
            DataType::Int64 => array_function!(Int64Array),
            // Unsigned integers
            DataType::UInt8 => array_function!(UInt8Array),
            DataType::UInt16 => array_function!(UInt16Array),
            DataType::UInt32 => array_function!(UInt32Array),
            DataType::UInt64 => array_function!(UInt64Array),
            _ => unreachable!(),
        }
    };
    ($DATATYPE:expr, $INCLUDE_LIST:expr) => {{
        match $DATATYPE {
            DataType::List(_) => array_function!(ListArray),
            // Everything else shares the list-free table above.
            other => call_array_function!(other, false),
        }
    }};
}

/// Array_to_string SQL function
///
/// Converts each list in `args[0]` to a single string by writing its
/// elements (recursing into nested lists) separated by a delimiter.
///
/// # Arguments
///
/// * `args[0]` - the array to stringify. `List` and `LargeList` inputs yield
///   one string per row; any other type is treated as a single flat array
///   and yields exactly one string.
/// * `args[1]` - Utf8 array of delimiters, paired row by row with `args[0]`.
///   A null list or a null delimiter yields a null output row.
/// * `args[2]` - optional Utf8 array. Only its first value is read; it
///   replaces null elements in the output. When this argument is absent,
///   null elements are skipped.
///
/// # Errors
///
/// Returns an error when the argument count is not two or three, when an
/// argument cannot be cast to the expected array type, or when a non-list
/// input is not accompanied by exactly one non-null delimiter.
pub(super) fn array_to_string(args: &[ArrayRef]) -> Result<ArrayRef> {
    if args.len() < 2 || args.len() > 3 {
        return exec_err!("array_to_string expects two or three arguments");
    }

    let arr = &args[0];

    let delimiters = as_string_array(&args[1])?;
    let delimiters: Vec<Option<&str>> = delimiters.iter().collect();

    let mut null_string = String::new();
    let mut with_null_string = false;
    if args.len() == 3 {
        let null_strings = as_string_array(&args[2])?;
        // Only the first value is consulted; guard against an empty array so
        // that a zero-row input cannot trigger an out-of-bounds panic.
        if !null_strings.is_empty() {
            null_string = null_strings.value(0).to_string();
        }
        with_null_string = true;
    }

    /// Appends every element of `arr` to `arg`, each followed by `delimiter`,
    /// descending into nested lists. The caller strips the trailing delimiter.
    fn compute_array_to_string<'a>(
        arg: &'a mut String,
        arr: ArrayRef,
        delimiter: &str,
        null_string: &str,
        with_null_string: bool,
    ) -> datafusion_common::Result<&'a mut String> {
        match arr.data_type() {
            DataType::List(..) => {
                let list_array = as_list_array(&arr)?;
                for i in 0..list_array.len() {
                    compute_array_to_string(
                        arg,
                        list_array.value(i),
                        delimiter,
                        null_string,
                        with_null_string,
                    )?;
                }

                Ok(arg)
            }
            DataType::LargeList(..) => {
                let list_array = as_large_list_array(&arr)?;
                for i in 0..list_array.len() {
                    compute_array_to_string(
                        arg,
                        list_array.value(i),
                        delimiter,
                        null_string,
                        with_null_string,
                    )?;
                }

                Ok(arg)
            }
            // A Null-typed array contributes nothing to the output.
            DataType::Null => Ok(arg),
            data_type => {
                macro_rules! array_function {
                    ($ARRAY_TYPE:ident) => {
                        to_string!(
                            arg,
                            arr,
                            delimiter,
                            null_string,
                            with_null_string,
                            $ARRAY_TYPE
                        )
                    };
                }
                call_array_function!(data_type, false)
            }
        }
    }

    /// Drops one trailing `delimiter` from `joined`, if present.
    fn strip_trailing_delimiter(joined: &mut String, delimiter: &str) {
        let kept = joined
            .strip_suffix(delimiter)
            .map_or(joined.len(), str::len);
        joined.truncate(kept);
    }

    /// Builds one output string per list row, pairing each row with its
    /// delimiter; a null row or null delimiter produces a null output.
    fn generate_string_array<O: OffsetSizeTrait>(
        list_arr: &GenericListArray<O>,
        delimiters: Vec<Option<&str>>,
        null_string: String,
        with_null_string: bool,
    ) -> datafusion_common::Result<StringArray> {
        let mut res: Vec<Option<String>> = Vec::with_capacity(list_arr.len());
        for (arr, &delimiter) in list_arr.iter().zip(delimiters.iter()) {
            if let (Some(arr), Some(delimiter)) = (arr, delimiter) {
                let mut joined = String::new();
                compute_array_to_string(
                    &mut joined,
                    arr,
                    delimiter,
                    &null_string,
                    with_null_string,
                )?;
                strip_trailing_delimiter(&mut joined, delimiter);
                res.push(Some(joined));
            } else {
                res.push(None);
            }
        }

        Ok(StringArray::from(res))
    }

    let arr_type = arr.data_type();
    let string_arr = match arr_type {
        // NOTE(review): a FixedSizeList array is presumably not a ListArray,
        // so `as_list_array` looks like it would return an error for it -
        // confirm whether the input is coerced to List before reaching here.
        DataType::List(_) | DataType::FixedSizeList(_, _) => {
            let list_array = as_list_array(&arr)?;
            generate_string_array::<i32>(
                list_array,
                delimiters,
                null_string,
                with_null_string,
            )?
        }
        DataType::LargeList(_) => {
            let list_array = as_large_list_array(&arr)?;
            generate_string_array::<i64>(
                list_array,
                delimiters,
                null_string,
                with_null_string,
            )?
        }
        _ => {
            // A flat (non-list) input is joined into a single string, so it
            // needs exactly one usable delimiter. Report a bad delimiter as
            // an error instead of panicking.
            let delimiter = match delimiters.as_slice() {
                [Some(delimiter)] => *delimiter,
                _ => {
                    return exec_err!(
                        "array_to_string expects exactly one non-null delimiter for a non-list array"
                    )
                }
            };

            let mut joined = String::new();
            compute_array_to_string(
                &mut joined,
                arr.clone(),
                delimiter,
                &null_string,
                with_null_string,
            )?;
            strip_trailing_delimiter(&mut joined, delimiter);

            StringArray::from(vec![Some(joined)])
        }
    };

    Ok(Arc::new(string_arr))
}

/// Splits string at occurrences of delimiter and returns an array of parts
/// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]'
pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() < 2 || args.len() > 3 {
return exec_err!("string_to_array expects two or three arguments");
}
let string_array = as_generic_string_array::<T>(&args[0])?;
let delimiter_array = as_generic_string_array::<T>(&args[1])?;

let mut list_builder = ListBuilder::new(StringBuilder::with_capacity(
string_array.len(),
string_array.get_buffer_memory_size(),
));

match args.len() {
2 => {
string_array.iter().zip(delimiter_array.iter()).for_each(
|(string, delimiter)| {
match (string, delimiter) {
(Some(string), Some("")) => {
list_builder.values().append_value(string);
list_builder.append(true);
}
(Some(string), Some(delimiter)) => {
string.split(delimiter).for_each(|s| {
list_builder.values().append_value(s);
});
list_builder.append(true);
}
(Some(string), None) => {
string.chars().map(|c| c.to_string()).for_each(|c| {
list_builder.values().append_value(c);
});
list_builder.append(true);
}
_ => list_builder.append(false), // null value
}
},
);
}

3 => {
let null_value_array = as_generic_string_array::<T>(&args[2])?;
string_array
.iter()
.zip(delimiter_array.iter())
.zip(null_value_array.iter())
.for_each(|((string, delimiter), null_value)| {
match (string, delimiter) {
(Some(string), Some("")) => {
if Some(string) == null_value {
list_builder.values().append_null();
} else {
list_builder.values().append_value(string);
}
list_builder.append(true);
}
(Some(string), Some(delimiter)) => {
string.split(delimiter).for_each(|s| {
if Some(s) == null_value {
list_builder.values().append_null();
} else {
list_builder.values().append_value(s);
}
});
list_builder.append(true);
}
(Some(string), None) => {
string.chars().map(|c| c.to_string()).for_each(|c| {
if Some(c.as_str()) == null_value {
list_builder.values().append_null();
} else {
list_builder.values().append_value(c);
}
});
list_builder.append(true);
}
_ => list_builder.append(false), // null value
}
});
}
_ => {
return exec_err!(
"Expect string_to_array function to take two or three parameters"
)
}
}

let list_array = list_builder.finish();
Ok(Arc::new(list_array) as ArrayRef)
}

/// Generates an array of integers from start to stop with a given step.
///
/// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values.
Expand Down
9 changes: 5 additions & 4 deletions datafusion/functions-array/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ mod remove;
mod replace;
mod rewrite;
mod set_ops;
mod string;
mod udf;
mod utils;

Expand Down Expand Up @@ -73,6 +74,8 @@ pub mod expr_fn {
pub use super::set_ops::array_distinct;
pub use super::set_ops::array_intersect;
pub use super::set_ops::array_union;
pub use super::string::array_to_string;
pub use super::string::string_to_array;
pub use super::udf::array_dims;
pub use super::udf::array_empty;
pub use super::udf::array_length;
Expand All @@ -81,19 +84,17 @@ pub mod expr_fn {
pub use super::udf::array_resize;
pub use super::udf::array_reverse;
pub use super::udf::array_sort;
pub use super::udf::array_to_string;
pub use super::udf::cardinality;
pub use super::udf::flatten;
pub use super::udf::gen_series;
pub use super::udf::range;
pub use super::udf::string_to_array;
}

/// Registers all enabled packages with a [`FunctionRegistry`]
pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
let functions: Vec<Arc<ScalarUDF>> = vec![
udf::array_to_string_udf(),
udf::string_to_array_udf(),
string::array_to_string_udf(),
string::string_to_array_udf(),
udf::range_udf(),
udf::gen_series_udf(),
udf::array_dims_udf(),
Expand Down
Loading

0 comments on commit b69054b

Please sign in to comment.