Skip to content

Commit

Permalink
Validate ScalarUDF output rows and fix nulls for array_has and `get…
Browse files Browse the repository at this point in the history
…_field` for `Map` (apache#10148)

* validate input/output of udf

* clip

* fmt

* clean garbage

* don't check if output is scalar

* lint

* fix array_has

* rm debug

* chore: temp code for demonstration

* getfield retains number of rows

* rust fmt

* minor comments

* fmt

* refactor

* compile err

* fmt again

* fmt

* add validate_number_of_rows for UDF

* only check for columnarvalue::array
  • Loading branch information
duongcongtoai authored Apr 29, 2024
1 parent acd9865 commit 0f2a68e
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, cast::as_float64_array,
cast::as_int32_array, not_impl_err, plan_err, ExprSchema, Result, ScalarValue,
};
use datafusion_common::{exec_err, internal_err, DataFusionError};
use datafusion_common::{assert_contains, exec_err, internal_err, DataFusionError};
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
Expand Down Expand Up @@ -205,6 +205,44 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
}
}

#[tokio::test]
async fn test_row_mismatch_error_in_scalar_udf() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(Int32Array::from(vec![1, 2]))],
)?;

let ctx = SessionContext::new();

ctx.register_batch("t", batch)?;

// udf that always return 1 row
let buggy_udf = Arc::new(|_: &[ColumnarValue]| {
Ok(ColumnarValue::Array(Arc::new(Int32Array::from(vec![0]))))
});

ctx.register_udf(create_udf(
"buggy_func",
vec![DataType::Int32],
Arc::new(DataType::Int32),
Volatility::Immutable,
buggy_udf,
));
assert_contains!(
ctx.sql("select buggy_func(a) from t")
.await?
.show()
.await
.err()
.unwrap()
.to_string(),
"UDF returned a different number of rows than expected"
);
Ok(())
}

#[tokio::test]
async fn scalar_udf_zero_params() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand Down
60 changes: 32 additions & 28 deletions datafusion/functions-array/src/array_has.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,36 +288,40 @@ fn general_array_has_dispatch<O: OffsetSizeTrait>(
} else {
array
};

for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() {
if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
let arr_values = converter.convert_columns(&[arr])?;
let sub_arr_values = if comparison_type != ComparisonType::Single {
converter.convert_columns(&[sub_arr])?
} else {
converter.convert_columns(&[element.clone()])?
};

let mut res = match comparison_type {
ComparisonType::All => sub_arr_values
.iter()
.dedup()
.all(|elem| arr_values.iter().dedup().any(|x| x == elem)),
ComparisonType::Any => sub_arr_values
.iter()
.dedup()
.any(|elem| arr_values.iter().dedup().any(|x| x == elem)),
ComparisonType::Single => arr_values
.iter()
.dedup()
.any(|x| x == sub_arr_values.row(row_idx)),
};

if comparison_type == ComparisonType::Any {
res |= res;
match (arr, sub_arr) {
(Some(arr), Some(sub_arr)) => {
let arr_values = converter.convert_columns(&[arr])?;
let sub_arr_values = if comparison_type != ComparisonType::Single {
converter.convert_columns(&[sub_arr])?
} else {
converter.convert_columns(&[element.clone()])?
};

let mut res = match comparison_type {
ComparisonType::All => sub_arr_values
.iter()
.dedup()
.all(|elem| arr_values.iter().dedup().any(|x| x == elem)),
ComparisonType::Any => sub_arr_values
.iter()
.dedup()
.any(|elem| arr_values.iter().dedup().any(|x| x == elem)),
ComparisonType::Single => arr_values
.iter()
.dedup()
.any(|x| x == sub_arr_values.row(row_idx)),
};

if comparison_type == ComparisonType::Any {
res |= res;
}
boolean_builder.append_value(res);
}
// respect null input
(_, _) => {
boolean_builder.append_null();
}

boolean_builder.append_value(res);
}
}
Ok(Arc::new(boolean_builder.finish()))
Expand Down
70 changes: 49 additions & 21 deletions datafusion/functions/src/core/getfield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Scalar, StringArray};
use arrow::array::{
make_array, Array, Capacities, MutableArrayData, Scalar, StringArray,
};
use arrow::datatypes::DataType;
use datafusion_common::cast::{as_map_array, as_struct_array};
use datafusion_common::{exec_err, ExprSchema, Result, ScalarValue};
Expand Down Expand Up @@ -107,29 +109,55 @@ impl ScalarUDFImpl for GetFieldFunc {
);
}
};

match (array.data_type(), name) {
(DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => {
let map_array = as_map_array(array.as_ref())?;
let key_scalar = Scalar::new(StringArray::from(vec![k.clone()]));
let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?;
let entries = arrow::compute::filter(map_array.entries(), &keys)?;
let entries_struct_array = as_struct_array(entries.as_ref())?;
Ok(ColumnarValue::Array(entries_struct_array.column(1).clone()))
}
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
let as_struct_array = as_struct_array(&array)?;
match as_struct_array.column_by_name(k) {
None => exec_err!(
"get indexed field {k} not found in struct"),
Some(col) => Ok(ColumnarValue::Array(col.clone()))
(DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => {
let map_array = as_map_array(array.as_ref())?;
let key_scalar: Scalar<arrow::array::GenericByteArray<arrow::datatypes::GenericStringType<i32>>> = Scalar::new(StringArray::from(vec![k.clone()]));
let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?;

// note that this array has more entries than the expected output/input size
// because maparray is flatten
let original_data = map_array.entries().column(1).to_data();
let capacity = Capacities::Array(original_data.len());
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], true,
capacity);

for entry in 0..map_array.len(){
let start = map_array.value_offsets()[entry] as usize;
let end = map_array.value_offsets()[entry + 1] as usize;

let maybe_matched =
keys.slice(start, end-start).
iter().enumerate().
find(|(_, t)| t.unwrap());
if maybe_matched.is_none(){
mutable.extend_nulls(1);
continue
}
let (match_offset,_) = maybe_matched.unwrap();
mutable.extend(0, start + match_offset, start + match_offset + 1);
}
let data = mutable.freeze();
let data = make_array(data);
Ok(ColumnarValue::Array(data))
}
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
let as_struct_array = as_struct_array(&array)?;
match as_struct_array.column_by_name(k) {
None => exec_err!("get indexed field {k} not found in struct"),
Some(col) => Ok(ColumnarValue::Array(col.clone())),
}
(DataType::Struct(_), name) => exec_err!(
"get indexed field is only possible on struct with utf8 indexes. \
Tried with {name:?} index"),
(dt, name) => exec_err!(
"get indexed field is only possible on lists with int64 indexes or struct \
with utf8 indexes. Tried {dt:?} with {name:?} index"),
}
(DataType::Struct(_), name) => exec_err!(
"get indexed field is only possible on struct with utf8 indexes. \
Tried with {name:?} index"
),
(dt, name) => exec_err!(
"get indexed field is only possible on lists with int64 indexes or struct \
with utf8 indexes. Tried {dt:?} with {name:?} index"
),
}
}
}
15 changes: 11 additions & 4 deletions datafusion/physical-expr/src/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,18 @@ impl PhysicalExpr for ScalarFunctionExpr {
// evaluate the function
match self.fun {
ScalarFunctionDefinition::UDF(ref fun) => {
if self.args.is_empty() {
fun.invoke_no_args(batch.num_rows())
} else {
fun.invoke(&inputs)
let output = match self.args.is_empty() {
true => fun.invoke_no_args(batch.num_rows()),
false => fun.invoke(&inputs),
}?;

if let ColumnarValue::Array(array) = &output {
if array.len() != batch.num_rows() {
return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}",
batch.num_rows(), array.len());
}
}
Ok(output)
}
ScalarFunctionDefinition::Name(_) => {
internal_err!(
Expand Down
15 changes: 9 additions & 6 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5169,8 +5169,9 @@ false false false true
true false true false
true false false true
false true false false
false false false false
false false false false
NULL NULL false false
false false NULL false
false false false NULL

query BBBB
select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)),
Expand All @@ -5183,8 +5184,9 @@ false false false true
true false true false
true false false true
false true false false
false false false false
false false false false
NULL NULL false false
false false NULL false
false false false NULL

query BBBB
select array_has(column1, make_array(5, 6)),
Expand All @@ -5197,8 +5199,9 @@ false false false true
true false true false
true false false true
false true false false
false false false false
false false false false
NULL NULL false false
false false NULL false
false false false NULL

query BBBBBBBBBBBBB
select array_has_all(make_array(1,2,3), make_array(1,3)),
Expand Down
1 change: 1 addition & 0 deletions datafusion/sqllogictest/test_files/map.slt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ DELETE 24
query T
SELECT strings['not_found'] FROM data LIMIT 1;
----
NULL

statement ok
drop table data;
Expand Down

0 comments on commit 0f2a68e

Please sign in to comment.