Skip to content

Commit

Permalink
Finish prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
xinlifoobar committed Jul 16, 2024
1 parent a73167a commit f99fd05
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 36 deletions.
14 changes: 10 additions & 4 deletions datafusion/functions/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use std::sync::Arc;

use arrow::{
array::{Array, RecordBatch},
array::{Array, ArrayRef, RecordBatch},
datatypes::{Field, Schema, SchemaRef},
};
use arrow_udf::{function, sig::REGISTRY};
Expand Down Expand Up @@ -50,12 +50,18 @@ fn eq<T: Eq>(lhs: T, rhs: T) -> bool {
// Predicate::like(rhs).unwrap().matches(lhs);
// }

#[function("concat(string, string) -> string")]
#[function("concat(largestring, largestring) -> largestring")]
fn concat(lhs: &str, rhs: &str) -> String {
format!("{}{}", lhs, rhs)
}

pub fn apply_udf(
lhs: &ColumnarValue,
rhs: &ColumnarValue,
return_field: &Field,
udf_name: &str,
) -> Result<ColumnarValue> {
) -> Result<ArrayRef> {
let (record_batch, schema) = match (lhs, rhs) {
(ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
let schema = Arc::new(Schema::new(vec![
Expand Down Expand Up @@ -109,7 +115,7 @@ fn apply_udf_inner(
record_batch: &RecordBatch,
return_field: &Field,
udf_name: &str,
) -> Result<ColumnarValue> {
) -> Result<ArrayRef> {
println!("schema: {:?}", schema);

let Some(eval) = REGISTRY
Expand All @@ -132,5 +138,5 @@ fn apply_udf_inner(

let result_array = result.column_by_name(udf_name).unwrap();

Ok(ColumnarValue::Array(Arc::clone(result_array)))
Ok(result_array.to_owned())
}
49 changes: 17 additions & 32 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ use arrow::compute::kernels::cmp::*;
use arrow::compute::kernels::comparison::{
regexp_is_match_utf8, regexp_is_match_utf8_scalar,
};
use arrow::compute::kernels::concat_elements::concat_elements_utf8;
use arrow::compute::{cast, ilike, like, nilike, nlike};
use arrow::datatypes::*;
use datafusion_common::cast::as_boolean_array;
Expand Down Expand Up @@ -132,34 +131,6 @@ impl std::fmt::Display for BinaryExpr {
}
}

/// Invoke a compute kernel on a pair of binary data arrays
macro_rules! compute_utf8_op {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast left side array");
let rr = $RIGHT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast right side array");
Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&ll, &rr)?))
}};
}

macro_rules! binary_string_array_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
match $LEFT.data_type() {
DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
DataType::LargeUtf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, LargeStringArray),
other => internal_err!(
"Data type {:?} not supported for binary operation '{}' on string arrays",
other, stringify!($OP)
),
}
}};
}

/// Invoke a boolean kernel on a pair of arrays
macro_rules! boolean_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
Expand Down Expand Up @@ -303,12 +274,12 @@ impl PhysicalExpr for BinaryExpr {
#[cfg(not(feature = "arrow_udf"))]
return apply_cmp(&lhs, &rhs, eq);
#[cfg(feature = "arrow_udf")]
return apply_udf(
return Ok(ColumnarValue::Array(apply_udf(
&lhs,
&rhs,
&Field::new("", DataType::Boolean, true),
"eq",
);
)?));
}
Operator::NotEq => return apply_cmp(&lhs, &rhs, neq),
Operator::Lt => return apply_cmp(&lhs, &rhs, lt),
Expand Down Expand Up @@ -669,7 +640,21 @@ impl BinaryExpr {
BitwiseXor => bitwise_xor_dyn(left, right),
BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
StringConcat => binary_string_array_op!(left, right, concat_elements),
StringConcat => {
#[cfg(not(feature = "arrow_udf"))]
{
binary_string_array_op!(left, right, concat_elements)
}
#[cfg(feature = "arrow_udf")]
{
apply_udf(
&ColumnarValue::Array(left),
&ColumnarValue::Array(right),
&Field::new("", DataType::Utf8, true),
"concat",
)
}
}
AtArrow | ArrowAt => {
unreachable!("ArrowAt and AtArrow should be rewritten to function")
}
Expand Down

0 comments on commit f99fd05

Please sign in to comment.