Skip to content

Commit

Permalink
Support equality of StructArray
Browse files Browse the repository at this point in the history
  • Loading branch information
my-vegetable-has-exploded committed Dec 17, 2023
1 parent 9a1e8b5 commit d9783dc
Showing 1 changed file with 238 additions and 47 deletions.
285 changes: 238 additions & 47 deletions arrow-ord/src/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use arrow_array::cast::AsArray;
use arrow_array::types::ByteArrayType;
use arrow_array::{
downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum,
FixedSizeBinaryArray, GenericByteArray,
FixedSizeBinaryArray, GenericByteArray, StructArray,
};
use arrow_buffer::bit_util::ceil;
use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer};
Expand Down Expand Up @@ -169,12 +169,14 @@ pub fn not_distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, Ar
/// Perform `op` on the provided `Datum`
#[inline(never)]
fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
use arrow_schema::DataType::*;
let (l, l_s) = lhs.get();
let (r, r_s) = rhs.get();
let (l_array, l_s) = lhs.get();
let (r_array, r_s) = rhs.get();

let l_nulls = l_array.logical_nulls();
let r_nulls = r_array.logical_nulls();

let l_len = l.len();
let r_len = r.len();
let l_len = l_array.len();
let r_len = r_array.len();

if l_len != r_len && !l_s && !r_s {
return Err(ArrowError::InvalidArgumentError(format!(
Expand All @@ -187,47 +189,14 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
false => l_len,
};

let l_nulls = l.logical_nulls();
let r_nulls = r.logical_nulls();

let l_v = l.as_any_dictionary_opt();
let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
let l_t = l.data_type();

let r_v = r.as_any_dictionary_opt();
let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
let r_t = r.data_type();

if l_t != r_t || l_t.is_nested() {
return Err(ArrowError::InvalidArgumentError(format!(
"Invalid comparison operation: {l_t} {op} {r_t}"
)));
}

// Defer computation as may not be necessary
let values = || -> BooleanBuffer {
let d = downcast_primitive_array! {
(l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v),
(Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v),
(Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v, r.as_string::<i32>(), r_s, r_v),
(LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v, r.as_string::<i64>(), r_s, r_v),
(Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
(LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
(FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
(Null, Null) => None,
_ => unreachable!(),
};
d.unwrap_or_else(|| BooleanBuffer::new_unset(len))
};

let l_nulls = l_nulls.filter(|n| n.null_count() > 0);
let r_nulls = r_nulls.filter(|n| n.null_count() > 0);
Ok(match (l_nulls, l_s, r_nulls, r_s) {
(Some(l), true, Some(r), true) | (Some(l), false, Some(r), false) => {
// Either both sides are scalar or neither side is scalar
match op {
Op::Distinct => {
let values = values();
let values = compare_op_values(op, l_array, l_s, r_array, r_s, len)?;
let l = l.inner().bit_chunks().iter_padded();
let r = r.inner().bit_chunks().iter_padded();
let ne = values.bit_chunks().iter_padded();
Expand All @@ -237,7 +206,7 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
BooleanBuffer::new(buffer, 0, len).into()
}
Op::NotDistinct => {
let values = values();
let values = compare_op_values(op, l_array, l_s, r_array, r_s, len)?;
let l = l.inner().bit_chunks().iter_padded();
let r = r.inner().bit_chunks().iter_padded();
let e = values.bit_chunks().iter_padded();
Expand All @@ -246,7 +215,10 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
let buffer = l.zip(r).zip(e).map(c).collect();
BooleanBuffer::new(buffer, 0, len).into()
}
_ => BooleanArray::new(values(), NullBuffer::union(Some(&l), Some(&r))),
_ => BooleanArray::new(
compare_op_values(op, l_array, l_s, r_array, r_s, len)?,
NullBuffer::union(Some(&l), Some(&r)),
),
}
}
(Some(_), true, Some(a), false) | (Some(a), false, Some(_), true) => {
Expand All @@ -268,23 +240,122 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
},
false => match op {
Op::Distinct => {
let values = values();
let values = compare_op_values(op, l_array, l_s, r_array, r_s, len)?;
let l = nulls.inner().bit_chunks().iter_padded();
let ne = values.bit_chunks().iter_padded();
let c = |(l, n)| u64::not(l) | n;
let buffer = l.zip(ne).map(c).collect();
BooleanBuffer::new(buffer, 0, len).into()
}
Op::NotDistinct => (nulls.inner() & &values()).into(),
_ => BooleanArray::new(values(), Some(nulls)),
Op::NotDistinct => (nulls.inner()
& &compare_op_values(op, l_array, l_s, r_array, r_s, len)?)
.into(),
_ => BooleanArray::new(
compare_op_values(op, l_array, l_s, r_array, r_s, len)?,
Some(nulls),
),
},
}
}
// Neither side is nullable
(None, _, None, _) => BooleanArray::new(values(), None),
(None, _, None, _) => BooleanArray::new(
compare_op_values(op, l_array, l_s, r_array, r_s, len)?,
None,
),
})
}

/// Computes the raw result bits of `op` applied element-wise to `l` and `r`.
///
/// Only the values are compared here: the returned [`BooleanBuffer`] carries no
/// null information, so combining it with validity is left to the caller.
/// Dictionary arrays are unwrapped to their values, and the dictionary itself is
/// forwarded to `apply` so it can resolve the keys.
///
/// * `l_s` / `r_s` - whether the corresponding side is a scalar
/// * `len` - number of bits in the all-unset buffer returned when no values
///   are produced (e.g. `Null` vs `Null`)
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] when
/// * the (dictionary-unwrapped) data types of both sides do not match,
/// * the type is nested and is not a struct compared with `Equal` / `NotEqual`,
/// * a struct comparison is requested on dictionary-encoded struct values.
fn compare_op_values(
    op: Op,
    l: &dyn Array,
    l_s: bool,
    r: &dyn Array,
    r_s: bool,
    len: usize,
) -> Result<BooleanBuffer, ArrowError> {
    use arrow_schema::DataType::*;
    let l_v = l.as_any_dictionary_opt();
    let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
    let l_t = l.data_type();

    let r_v = r.as_any_dictionary_opt();
    let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
    let r_t = r.data_type();

    // Single place that builds the "invalid comparison" error.
    let invalid = || {
        ArrowError::InvalidArgumentError(format!(
            "Invalid comparison operation: {l_t} {op} {r_t}"
        ))
    };

    if l_t.is_nested() {
        // Of the nested types, only struct (in)equality is supported.
        let supported = matches!((l_t, op), (Struct(_), Op::Equal | Op::NotEqual));
        if !supported || !l_t.equals_datatype(r_t) {
            return Err(invalid());
        }
        // The struct path below compares child columns row by row and never
        // consults dictionary keys, so comparing the dictionary *values* would
        // yield a buffer of the wrong length / wrong bits. Reject it explicitly.
        if l_v.is_some() || r_v.is_some() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Comparison of dictionary-encoded struct values is not supported: {l_t} {op} {r_t}"
            )));
        }
    } else if r_t != l_t {
        return Err(invalid());
    }

    let d = downcast_primitive_array! {
        (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v),
        (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v),
        (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v, r.as_string::<i32>(), r_s, r_v),
        (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v, r.as_string::<i64>(), r_s, r_v),
        (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
        (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
        (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
        (Null, Null) => None,
        (Struct(_), Struct(_)) => Some(compare_op_struct_values(op, l, l_s, r, r_s, len)?),
        // Every remaining nested or mismatched combination was rejected above.
        _ => unreachable!(),
    };
    // `None` means no values were computed: report every bit as unset.
    Ok(d.unwrap_or_else(|| BooleanBuffer::new_unset(len)))
}

/// recursively compare fields of struct arrays
fn compare_op_struct_values(
op: Op,
l: &dyn Array,
l_s: bool,
r: &dyn Array,
r_s: bool,
len: usize,
) -> Result<BooleanBuffer, ArrowError> {
// when one of field is equal, the result is false for not equal
// so we use neg to reverse the result of equal when handle not equal
let neg = match op {
Op::Equal => false,
Op::NotEqual => true,
_ => unreachable!(),
};

let l = l.as_any().downcast_ref::<StructArray>().unwrap();
let r = r.as_any().downcast_ref::<StructArray>().unwrap();

let mut child_res: Vec<BooleanBuffer> = Vec::with_capacity(len);
// compare each field of struct
for item in l
.columns()
.to_vec()
.iter()
.zip(r.columns().to_vec().iter())
.map(|(col_l, col_r)| compare_op_values(Op::Equal, col_l, l_s, col_r, r_s, len))
{
child_res.push(item?);
}
// combine the result of each field
let equality = child_res
.iter()
.fold(BooleanBuffer::new_set(len), |acc, x| &acc & x);
Ok(if neg { !&equality } else { equality })
}

/// Perform a potentially vectored `op` on the provided `ArrayOrd`
fn apply<T: ArrayOrd>(
op: Op,
Expand Down Expand Up @@ -544,7 +615,9 @@ impl<'a> ArrayOrd for &'a FixedSizeBinaryArray {
mod tests {
use std::sync::Arc;

use arrow_array::{DictionaryArray, Int32Array, Scalar, StringArray};
use arrow_array::{ArrayRef, DictionaryArray, Int32Array, Scalar, StringArray, StructArray};
use arrow_buffer::Buffer;
use arrow_schema::{DataType, Field};

use super::*;

Expand Down Expand Up @@ -702,4 +775,122 @@ mod tests {

neq(&col.slice(0, col.len() - 1), &col.slice(1, col.len() - 1)).unwrap();
}

#[test]
fn test_struct_equality() {
    // Builds a nullable Int32 column from raw values and a per-row validity mask.
    let int32 = |values: Vec<i32>, validity: Vec<bool>| -> ArrayRef {
        Arc::new(Int32Array::new(values.into(), Some(validity.into()))) as ArrayRef
    };

    let field_a = Arc::new(Field::new("a", DataType::Int32, true));
    let field_b = Arc::new(Field::new("b", DataType::Int32, true));

    // Builds a two-field struct whose first field is always `a`; the struct's
    // own validity bitmap is 0b0111 (rows 0..=2 valid, row 3 null).
    let struct_of = |a: &ArrayRef, second: (Arc<Field>, ArrayRef)| -> StructArray {
        StructArray::from((
            vec![(field_a.clone(), a.clone()), second],
            Buffer::from([0b0111]),
        ))
    };

    // Checks `eq` and `neq` in both argument orders. `equal` holds the expected
    // `eq` values; `neq` is expected to be its element-wise negation. Row 3 is
    // null in every expected result.
    let assert_cmp = |l: &StructArray, r: &StructArray, equal: Vec<bool>| {
        let not_equal: Vec<bool> = equal.iter().map(|v| !v).collect();
        let expected_eq = BooleanArray::new(
            equal.into(),
            Some(vec![true, true, true, false].into()),
        );
        let expected_neq = BooleanArray::new(
            not_equal.into(),
            Some(vec![true, true, true, false].into()),
        );
        assert_eq!(eq(l, r).unwrap(), expected_eq);
        assert_eq!(eq(r, l).unwrap(), expected_eq);
        assert_eq!(neq(l, r).unwrap(), expected_neq);
        assert_eq!(neq(r, l).unwrap(), expected_neq);
    };

    // Scenario 1: struct(a, b) against an identical struct(a, b).
    let left_a = int32(vec![0, 1, 2, 3], vec![true, false, true, false]);
    let right_a = int32(vec![0, 1, 2, 3], vec![true, false, true, false]);
    let left_b = int32(vec![0, 1, 2, 3], vec![true, true, true, false]);
    let right_b = int32(vec![0, 1, 2, 3], vec![true, true, true, false]);

    let left_struct = struct_of(&left_a, (field_b.clone(), left_b.clone()));
    let right_struct = struct_of(&right_a, (field_b.clone(), right_b.clone()));
    assert_cmp(&left_struct, &right_struct, vec![true, true, true, true]);

    let sub_struct_fields = left_struct.fields().clone();

    // Scenario 2: the right `a` column differs from the left `a` column at row 1.
    let right_a2 = int32(vec![0, 2, 2, 3], vec![true, true, true, false]);
    let right_struct = struct_of(&right_a2, (field_b.clone(), right_b.clone()));
    assert_cmp(&left_struct, &right_struct, vec![true, false, true, true]);

    // Scenario 3: struct(a, SubStruct(a, b)) on both sides, reusing the structs
    // from scenario 2 as the nested field, so the nested `a` differs at row 1.
    let sub_field = || {
        Arc::new(Field::new(
            "SubStruct",
            DataType::Struct(sub_struct_fields.clone()),
            true,
        ))
    };
    let left_struct = struct_of(&left_a, (sub_field(), Arc::new(left_struct) as ArrayRef));
    let right_struct = struct_of(&right_a, (sub_field(), Arc::new(right_struct) as ArrayRef));
    assert_cmp(&left_struct, &right_struct, vec![true, false, true, true]);
}
}

0 comments on commit d9783dc

Please sign in to comment.