diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index c8025fb2d895..db3e6838f6a5 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -218,3 +218,8 @@ required-features = ["math_expressions"] harness = false name = "initcap" required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "find_in_set" +required-features = ["unicode_expressions"] diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs new file mode 100644 index 000000000000..9307525482c2 --- /dev/null +++ b/datafusion/functions/benches/find_in_set.rs @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::{StringArray, StringViewArray}; +use arrow::datatypes::DataType; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_common::ScalarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use rand::distributions::Alphanumeric; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::sync::Arc; +use std::time::Duration; + +/// gen_arr(4096, 128, 0.1, 0.1, true) will generate a StringViewArray with +/// 4096 rows, each row containing a string with 128 random characters. +/// around 10% of the rows are null, around 10% of the rows are non-ASCII. +fn gen_args_array( + n_rows: usize, + str_len_chars: usize, + null_density: f32, + utf8_density: f32, + is_string_view: bool, // false -> StringArray, true -> StringViewArray +) -> Vec { + let mut rng = StdRng::seed_from_u64(42); + let rng_ref = &mut rng; + + let num_elements = 5; // 5 elements separated by comma + let utf8 = "DataFusionДатаФусион数据融合📊🔥"; // includes utf8 encoding with 1~4 bytes + let corpus_char_count = utf8.chars().count(); + + let mut output_set_vec: Vec> = Vec::with_capacity(n_rows); + let mut output_element_vec: Vec> = Vec::with_capacity(n_rows); + for _ in 0..n_rows { + let rand_num = rng_ref.gen::(); // [0.0, 1.0) + if rand_num < null_density { + output_element_vec.push(None); + output_set_vec.push(None); + } else if rand_num < null_density + utf8_density { + // Generate random UTF-8 string with comma separators + let mut generated_string = String::with_capacity(str_len_chars); + for i in 0..num_elements { + for _ in 0..str_len_chars { + let idx = rng_ref.gen_range(0..corpus_char_count); + let char = utf8.chars().nth(idx).unwrap(); + generated_string.push(char); + } + if i < num_elements - 1 { + generated_string.push(','); + } + } + output_element_vec.push(Some(random_element_in_set(&generated_string))); + output_set_vec.push(Some(generated_string)); + } else { + // Generate random ASCII-only string with comma separators + let mut generated_string = String::with_capacity(str_len_chars); + for i in 0..num_elements { + for _ in 0..str_len_chars { + let c = rng_ref.sample(Alphanumeric); + generated_string.push(c as char); + } + if i < num_elements - 1 { + generated_string.push(','); + } + } + output_element_vec.push(Some(random_element_in_set(&generated_string))); + output_set_vec.push(Some(generated_string)); + } + } + + if is_string_view { + let set_array: StringViewArray = output_set_vec.into_iter().collect(); + let element_array: StringViewArray = output_element_vec.into_iter().collect(); + vec![ + ColumnarValue::Array(Arc::new(element_array)), + ColumnarValue::Array(Arc::new(set_array)), + ] + } else { + let set_array: StringArray = output_set_vec.clone().into_iter().collect(); + let element_array: StringArray = output_element_vec.into_iter().collect(); + vec![ + ColumnarValue::Array(Arc::new(element_array)), + ColumnarValue::Array(Arc::new(set_array)), + ] + } +} + +fn random_element_in_set(string: &str) -> String { + let elements: Vec<&str> = string.split(',').collect(); + + if elements.is_empty() || (elements.len() == 1 && elements[0].is_empty()) { + return String::new(); + } + + let mut rng = StdRng::seed_from_u64(44); + let random_index = rng.gen_range(0..elements.len()); + + elements[random_index].to_string() +} + +fn gen_args_scalar( + n_rows: usize, + str_len_chars: usize, + null_density: f32, + is_string_view: bool, // false -> StringArray, true -> StringViewArray +) -> Vec { + let str_list = "Apache,DataFusion,SQL,Query,Engine".to_string(); + if is_string_view { + let string = + create_string_view_array_with_len(n_rows, null_density, str_len_chars, false); + vec![ + ColumnarValue::Array(Arc::new(string)), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(str_list))), + ] + } else { + let string = + create_string_array_with_len::(n_rows, null_density, str_len_chars); + vec![ + ColumnarValue::Array(Arc::new(string)), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(str_list))), + ] + } +} + +fn criterion_benchmark(c: &mut Criterion) { + // All benches are single batch run with 8192 rows + let find_in_set = datafusion_functions::unicode::find_in_set(); + + let n_rows = 8192; + for str_len in [8, 32, 1024] { + let mut group = c.benchmark_group("find_in_set"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(50); + group.measurement_time(Duration::from_secs(10)); + + let args = gen_args_array(n_rows, str_len, 0.1, 0.5, false); + group.bench_function(format!("string_len_{}", str_len), |b| { + b.iter(|| { + black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: n_rows, + return_type: &DataType::Int32, + })) + }) + }); + + let args = gen_args_array(n_rows, str_len, 0.1, 0.5, true); + group.bench_function(format!("string_view_len_{}", str_len), |b| { + b.iter(|| { + black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: n_rows, + return_type: &DataType::Int32, + })) + }) + }); + + group.finish(); + + let mut group = c.benchmark_group("find_in_set_scalar"); + + let args = gen_args_scalar(n_rows, str_len, 0.1, false); + group.bench_function(format!("string_len_{}", str_len), |b| { + b.iter(|| { + black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: n_rows, + return_type: &DataType::Int32, + })) + }) + }); + + let args = gen_args_scalar(n_rows, str_len, 0.1, true); + group.bench_function(format!("string_view_len_{}", str_len), |b| { + b.iter(|| { + black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: n_rows, + return_type: &DataType::Int32, + })) + }) + }); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index c4d9b51f6032..12f213a827cf 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -19,16 +19,17 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ - ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, - PrimitiveArray, + new_null_array, ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, + OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; -use crate::utils::{make_scalar_function, utf8_to_int_type}; -use datafusion_common::{exec_err, Result}; +use crate::utils::utf8_to_int_type; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -42,7 +43,7 @@ use datafusion_macros::user_doc; | find_in_set(Utf8("b"),Utf8("a,b,c,d")) | +----------------------------------------+ | 2 | -+----------------------------------------+ ++----------------------------------------+ ```"#, argument(name = "str", description = "String expression to find in strlist."), argument( @@ -94,12 +95,141 @@ impl ScalarUDFImpl for FindInSetFunc { utf8_to_int_type(&arg_types[0], "find_in_set") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(find_in_set, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { mut args, .. } = args; + + if args.len() != 2 { + return exec_err!( + "find_in_set was called with {} arguments. It requires 2.", + args.len() + ); + } + + let str_list = args.pop().unwrap(); + let string = args.pop().unwrap(); + + match (string, str_list) { + // both inputs are scalars + ( + ColumnarValue::Scalar( + ScalarValue::Utf8View(string) + | ScalarValue::Utf8(string) + | ScalarValue::LargeUtf8(string), + ), + ColumnarValue::Scalar( + ScalarValue::Utf8View(str_list) + | ScalarValue::Utf8(str_list) + | ScalarValue::LargeUtf8(str_list), + ), + ) => { + let res = match (string, str_list) { + (Some(string), Some(str_list)) => { + let position = str_list + .split(',') + .position(|s| s == string) + .map_or(0, |idx| idx + 1); + + Some(position as i32) + } + _ => None, + }; + Ok(ColumnarValue::Scalar(ScalarValue::from(res))) + } + + // `string` is an array, `str_list` is scalar + ( + ColumnarValue::Array(str_array), + ColumnarValue::Scalar( + ScalarValue::Utf8View(str_list_literal) + | ScalarValue::Utf8(str_list_literal) + | ScalarValue::LargeUtf8(str_list_literal), + ), + ) => { + let result_array = match str_list_literal { + // find_in_set(column_a, null) = null + None => new_null_array(str_array.data_type(), str_array.len()), + Some(str_list_literal) => { + let str_list = str_list_literal.split(',').collect::>(); + let result = match str_array.data_type() { + DataType::Utf8 => { + let string_array = str_array.as_string::(); + find_in_set_right_literal::( + string_array, + str_list, + ) + } + DataType::LargeUtf8 => { + let string_array = str_array.as_string::(); + find_in_set_right_literal::( + string_array, + str_list, + ) + } + DataType::Utf8View => { + let string_array = str_array.as_string_view(); + find_in_set_right_literal::( + string_array, + str_list, + ) + } + other => { + exec_err!("Unsupported data type {other:?} for function find_in_set") + } + }; + Arc::new(result?) + } + }; + Ok(ColumnarValue::Array(result_array)) + } + + // `string` is scalar, `str_list` is an array + ( + ColumnarValue::Scalar( + ScalarValue::Utf8View(string_literal) + | ScalarValue::Utf8(string_literal) + | ScalarValue::LargeUtf8(string_literal), + ), + ColumnarValue::Array(str_list_array), + ) => { + let res = match string_literal { + // find_in_set(null, column_b) = null + None => { + new_null_array(str_list_array.data_type(), str_list_array.len()) + } + Some(string) => { + let result = match str_list_array.data_type() { + DataType::Utf8 => { + let str_list = str_list_array.as_string::(); + find_in_set_left_literal::(string, str_list) + } + DataType::LargeUtf8 => { + let str_list = str_list_array.as_string::(); + find_in_set_left_literal::(string, str_list) + } + DataType::Utf8View => { + let str_list = str_list_array.as_string_view(); + find_in_set_left_literal::(string, str_list) + } + other => { + exec_err!("Unsupported data type {other:?} for function find_in_set") + } + }; + Arc::new(result?) + } + }; + Ok(ColumnarValue::Array(res)) + } + + // both inputs are arrays + (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => { + let res = find_in_set(base_array, exp_array)?; + + Ok(ColumnarValue::Array(res)) + } + _ => { + internal_err!("Invalid argument types for `find_in_set` function") + } + } } fn documentation(&self) -> Option<&Documentation> { @@ -107,29 +237,24 @@ impl ScalarUDFImpl for FindInSetFunc { } } -///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings -///A string list is a string composed of substrings separated by , characters. -fn find_in_set(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!( - "find_in_set was called with {} arguments. It requires 2.", - args.len() - ); - } - match args[0].data_type() { +/// Returns a value in the range of 1 to N if the string `str` is in the string list `strlist` +/// consisting of N substrings. A string list is a string composed of substrings separated by `,` +/// characters. +fn find_in_set(str: ArrayRef, str_list: ArrayRef) -> Result { + match str.data_type() { DataType::Utf8 => { - let string_array = args[0].as_string::(); - let str_list_array = args[1].as_string::(); + let string_array = str.as_string::(); + let str_list_array = str_list.as_string::(); find_in_set_general::(string_array, str_list_array) } DataType::LargeUtf8 => { - let string_array = args[0].as_string::(); - let str_list_array = args[1].as_string::(); + let string_array = str.as_string::(); + let str_list_array = str_list.as_string::(); find_in_set_general::(string_array, str_list_array) } DataType::Utf8View => { - let string_array = args[0].as_string_view(); - let str_list_array = args[1].as_string_view(); + let string_array = str.as_string_view(); + let str_list_array = str_list.as_string_view(); find_in_set_general::(string_array, str_list_array) } other => { @@ -138,31 +263,279 @@ fn find_in_set(args: &[ArrayRef]) -> Result { } } -pub fn find_in_set_general<'a, T: ArrowPrimitiveType, V: ArrayAccessor>( +pub fn find_in_set_general<'a, T, V>( string_array: V, str_list_array: V, ) -> Result where + T: ArrowPrimitiveType, T::Native: OffsetSizeTrait, + V: ArrayAccessor, { let string_iter = ArrayIter::new(string_array); let str_list_iter = ArrayIter::new(str_list_array); - let result = string_iter + + let mut builder = PrimitiveArray::::builder(string_iter.len()); + + string_iter .zip(str_list_iter) - .map(|(string, str_list)| match (string, str_list) { - (Some(string), Some(str_list)) => { - let mut res = 0; - let str_set: Vec<&str> = str_list.split(',').collect(); - for (idx, str) in str_set.iter().enumerate() { - if str == &string { - res = idx + 1; - break; - } + .for_each( + |(string_opt, str_list_opt)| match (string_opt, str_list_opt) { + (Some(string), Some(str_list)) => { + let position = str_list + .split(',') + .position(|s| s == string) + .map_or(0, |idx| idx + 1); + builder.append_value(T::Native::from_usize(position).unwrap()); } - T::Native::from_usize(res) + _ => builder.append_null(), + }, + ); + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +fn find_in_set_left_literal<'a, T, V>( + string: String, + str_list_array: V, +) -> Result +where + T: ArrowPrimitiveType, + T::Native: OffsetSizeTrait, + V: ArrayAccessor, +{ + let mut builder = PrimitiveArray::::builder(str_list_array.len()); + + let str_list_iter = ArrayIter::new(str_list_array); + + str_list_iter.for_each(|str_list_opt| match str_list_opt { + Some(str_list) => { + let position = str_list + .split(',') + .position(|s| s == string) + .map_or(0, |idx| idx + 1); + builder.append_value(T::Native::from_usize(position).unwrap()); + } + None => builder.append_null(), + }); + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +fn find_in_set_right_literal<'a, T, V>( + string_array: V, + str_list: Vec<&str>, +) -> Result +where + T: ArrowPrimitiveType, + T::Native: OffsetSizeTrait, + V: ArrayAccessor, +{ + let mut builder = PrimitiveArray::::builder(string_array.len()); + + let string_iter = ArrayIter::new(string_array); + + string_iter.for_each(|string_opt| match string_opt { + Some(string) => { + let position = str_list + .iter() + .position(|s| *s == string) + .map_or(0, |idx| idx + 1); + builder.append_value(T::Native::from_usize(position).unwrap()); + } + None => builder.append_null(), + }); + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use crate::unicode::find_in_set::FindInSetFunc; + use crate::utils::test::test_function; + use arrow::array::{Array, Int32Array, StringArray}; + use arrow::datatypes::DataType::Int32; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use std::sync::Arc; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))), + ], + Ok(Some(1)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("🔥")))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "a,Д,🔥" + )))), + ], + Ok(Some(3)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("d")))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))), + ], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "Apache Software Foundation" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "Github,Apache Software Foundation,DataFusion" + )))), + ], + Ok(Some(2)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))), + ], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a")))), + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))), + ], + Ok(None), + i32, + Int32, + Int32Array + ); + + Ok(()) + } + + macro_rules! test_find_in_set { + ($test_name:ident, $args:expr, $expected:expr) => { + #[test] + fn $test_name() -> Result<()> { + let fis = crate::unicode::find_in_set(); + + let args = $args; + let expected = $expected; + + let type_array = args.iter().map(|a| a.data_type()).collect::>(); + let cardinality = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }) + .unwrap_or(1); + let return_type = fis.return_type(&type_array)?; + let result = fis.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: cardinality, + return_type: &return_type, + }); + assert!(result.is_ok()); + + let result = result? + .to_array(cardinality) + .expect("Failed to convert to array"); + let result = result + .as_any() + .downcast_ref::() + .expect("Failed to convert to type"); + assert_eq!(*result, expected); + + Ok(()) } - _ => None, - }) - .collect::>(); - Ok(Arc::new(result) as ArrayRef) + }; + } + + test_find_in_set!( + test_find_in_set_with_scalar_args, + vec![ + ColumnarValue::Array(Arc::new(StringArray::from(vec![ + "", "a", "b", "c", "d" + ]))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("b,c,d".to_string()))), + ], + Int32Array::from(vec![0, 0, 1, 2, 3]) + ); + test_find_in_set!( + test_find_in_set_with_scalar_args_2, + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "ApacheSoftware".to_string() + ))), + ColumnarValue::Array(Arc::new(StringArray::from(vec![ + "a,b,c", + "ApacheSoftware,Github,DataFusion", + "" + ]))), + ], + Int32Array::from(vec![0, 1, 0]) + ); + test_find_in_set!( + test_find_in_set_with_scalar_args_3, + vec![ + ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a,b,c".to_string()))), + ], + Int32Array::from(vec![None::; 3]) + ); + test_find_in_set!( + test_find_in_set_with_scalar_args_4, + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a".to_string()))), + ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))), + ], + Int32Array::from(vec![None::; 3]) + ); } diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index e8e3eb3f4e75..3c5cde3789ea 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -102,7 +102,7 @@ pub mod expr_fn { string ),( find_in_set, - "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings", + "Returns a value in the range of 1 to N if the string `str` is in the string list `strlist` consisting of N substrings", string strlist ));