diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index bfb1f64e2eb8..b235548c17d9 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -27,7 +27,7 @@ use arrow_array::cast::AsArray; use arrow_array::types::ByteArrayType; use arrow_array::{ downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum, - FixedSizeBinaryArray, GenericByteArray, + FixedSizeBinaryArray, GenericByteArray, StructArray, }; use arrow_buffer::bit_util::ceil; use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; @@ -169,12 +169,14 @@ pub fn not_distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result Result { - use arrow_schema::DataType::*; - let (l, l_s) = lhs.get(); - let (r, r_s) = rhs.get(); + let (l_array, l_s) = lhs.get(); + let (r_array, r_s) = rhs.get(); + + let l_nulls = l_array.logical_nulls(); + let r_nulls = r_array.logical_nulls(); - let l_len = l.len(); - let r_len = r.len(); + let l_len = l_array.len(); + let r_len = r_array.len(); if l_len != r_len && !l_s && !r_s { return Err(ArrowError::InvalidArgumentError(format!( @@ -187,39 +189,6 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result l_len, }; - let l_nulls = l.logical_nulls(); - let r_nulls = r.logical_nulls(); - - let l_v = l.as_any_dictionary_opt(); - let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l); - let l_t = l.data_type(); - - let r_v = r.as_any_dictionary_opt(); - let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r); - let r_t = r.data_type(); - - if l_t != r_t || l_t.is_nested() { - return Err(ArrowError::InvalidArgumentError(format!( - "Invalid comparison operation: {l_t} {op} {r_t}" - ))); - } - - // Defer computation as may not be necessary - let values = || -> BooleanBuffer { - let d = downcast_primitive_array! { - (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v), - (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v), - (Utf8, Utf8) => apply(op, l.as_string::(), l_s, l_v, r.as_string::(), r_s, r_v), - (LargeUtf8, LargeUtf8) => apply(op, l.as_string::(), l_s, l_v, r.as_string::(), r_s, r_v), - (Binary, Binary) => apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), - (LargeBinary, LargeBinary) => apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), - (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v), - (Null, Null) => None, - _ => unreachable!(), - }; - d.unwrap_or_else(|| BooleanBuffer::new_unset(len)) - }; - let l_nulls = l_nulls.filter(|n| n.null_count() > 0); let r_nulls = r_nulls.filter(|n| n.null_count() > 0); Ok(match (l_nulls, l_s, r_nulls, r_s) { @@ -227,7 +196,7 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result { - let values = values(); + let values = compare_op_values(op, l_array, l_s, r_array, r_s, len)?; let l = l.inner().bit_chunks().iter_padded(); let r = r.inner().bit_chunks().iter_padded(); let ne = values.bit_chunks().iter_padded(); @@ -237,7 +206,7 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result { - let values = values(); + let values = compare_op_values(op, l_array, l_s, r_array, r_s, len)?; let l = l.inner().bit_chunks().iter_padded(); let r = r.inner().bit_chunks().iter_padded(); let e = values.bit_chunks().iter_padded(); @@ -246,7 +215,10 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result BooleanArray::new(values(), NullBuffer::union(Some(&l), Some(&r))), + _ => BooleanArray::new( + compare_op_values(op, l_array, l_s, r_array, r_s, len)?, + NullBuffer::union(Some(&l), Some(&r)), + ), } } (Some(_), true, Some(a), false) | (Some(a), false, Some(_), true) => { @@ -268,23 +240,122 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result match op { Op::Distinct => { - let values = values(); + let values = compare_op_values(op, l_array, l_s, r_array, r_s, len)?; let l = nulls.inner().bit_chunks().iter_padded(); let ne = values.bit_chunks().iter_padded(); let c = |(l, n)| u64::not(l) | n; let buffer = l.zip(ne).map(c).collect(); BooleanBuffer::new(buffer, 0, len).into() } - Op::NotDistinct => (nulls.inner() & &values()).into(), - _ => BooleanArray::new(values(), Some(nulls)), + Op::NotDistinct => (nulls.inner() + & &compare_op_values(op, l_array, l_s, r_array, r_s, len)?) + .into(), + _ => BooleanArray::new( + compare_op_values(op, l_array, l_s, r_array, r_s, len)?, + Some(nulls), + ), }, } } // Neither side is nullable - (None, _, None, _) => BooleanArray::new(values(), None), + (None, _, None, _) => BooleanArray::new( + compare_op_values(op, l_array, l_s, r_array, r_s, len)?, + None, + ), }) } +/// Defer computation as may not be necessary +/// get the BooleanBuffer result of the comparison +fn compare_op_values( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, + len: usize, +) -> Result { + use arrow_schema::DataType::*; + let l_v = l.as_any_dictionary_opt(); + let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l); + let l_t = l.data_type(); + + let r_v = r.as_any_dictionary_opt(); + let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r); + let r_t = r.data_type(); + + if l_t.is_nested() { + if !l_t.equals_datatype(r_t) { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid comparison operation: {l_t} {op} {r_t}" + ))); + } + match (l_t, op) { + (Struct(_), Op::Equal | Op::NotEqual) => {} + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid comparison operation: {l_t} {op} {r_t}" + ))); + } + } + } else if r_t != l_t { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid comparison operation: {l_t} {op} {r_t}" + ))); + } + let d = downcast_primitive_array! { + (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v), + (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v), + (Utf8, Utf8) => apply(op, l.as_string::(), l_s, l_v, r.as_string::(), r_s, r_v), + (LargeUtf8, LargeUtf8) => apply(op, l.as_string::(), l_s, l_v, r.as_string::(), r_s, r_v), + (Binary, Binary) => apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), + (LargeBinary, LargeBinary) => apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), + (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v), + (Null, Null) => None, + (Struct(_), Struct(_)) => Some(compare_op_struct_values(op, l, l_s, r, r_s, len)?), + _ => unreachable!(), + }; + Ok(d.unwrap_or_else(|| BooleanBuffer::new_unset(len))) +} + +/// recursively compare fields of struct arrays +fn compare_op_struct_values( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, + len: usize, +) -> Result { + // when one of field is equal, the result is false for not equal + // so we use neg to reverse the result of equal when handle not equal + let neg = match op { + Op::Equal => false, + Op::NotEqual => true, + _ => unreachable!(), + }; + + let l = l.as_any().downcast_ref::().unwrap(); + let r = r.as_any().downcast_ref::().unwrap(); + + let mut child_res: Vec = Vec::with_capacity(len); + // compare each field of struct + for item in l + .columns() + .to_vec() + .iter() + .zip(r.columns().to_vec().iter()) + .map(|(col_l, col_r)| compare_op_values(Op::Equal, col_l, l_s, col_r, r_s, len)) + { + child_res.push(item?); + } + // combine the result of each field + let equality = child_res + .iter() + .fold(BooleanBuffer::new_set(len), |acc, x| &acc & x); + Ok(if neg { !&equality } else { equality }) +} + /// Perform a potentially vectored `op` on the provided `ArrayOrd` fn apply( op: Op, @@ -544,7 +615,9 @@ impl<'a> ArrayOrd for &'a FixedSizeBinaryArray { mod tests { use std::sync::Arc; - use arrow_array::{DictionaryArray, Int32Array, Scalar, StringArray}; + use arrow_array::{ArrayRef, DictionaryArray, Int32Array, Scalar, StringArray, StructArray}; + use arrow_buffer::Buffer; + use arrow_schema::{DataType, Field}; use super::*; @@ -702,4 +775,122 @@ mod tests { neq(&col.slice(0, col.len() - 1), &col.slice(1, col.len() - 1)).unwrap(); } + + #[test] + fn test_struct_equality() { + // test struct('a', 'b') = struct('a', 'b'), the null buffer is 0b0111 + let left_a = Arc::new(Int32Array::new( + vec![0, 1, 2, 3].into(), + Some(vec![true, false, true, false].into()), + )); + let right_a = Arc::new(Int32Array::new( + vec![0, 1, 2, 3].into(), + Some(vec![true, false, true, false].into()), + )); + let left_b = Arc::new(Int32Array::new( + vec![0, 1, 2, 3].into(), + Some(vec![true, true, true, false].into()), + )); + let right_b = Arc::new(Int32Array::new( + vec![0, 1, 2, 3].into(), + Some(vec![true, true, true, false].into()), + )); + let field_a = Arc::new(Field::new("a", DataType::Int32, true)); + let field_b = Arc::new(Field::new("b", DataType::Int32, true)); + let left_struct = StructArray::from(( + vec![ + (field_a.clone(), left_a.clone() as ArrayRef), + (field_b.clone(), left_b.clone() as ArrayRef), + ], + Buffer::from([0b0111]), + )); + let right_struct = StructArray::from(( + vec![ + (field_a.clone(), right_a.clone() as ArrayRef), + (field_b.clone(), right_b.clone() as ArrayRef), + ], + Buffer::from([0b0111]), + )); + let expected = BooleanArray::new( + vec![true, true, true, true].into(), + Some(vec![true, true, true, false].into()), + ); + assert_eq!(eq(&left_struct, &right_struct).unwrap(), expected); + assert_eq!(eq(&right_struct, &left_struct).unwrap(), expected); + let expected = BooleanArray::new( + vec![false, false, false, false].into(), + Some(vec![true, true, true, false].into()), + ); + assert_eq!(neq(&left_struct, &right_struct).unwrap(), expected); + assert_eq!(neq(&right_struct, &left_struct).unwrap(), expected); + + let sub_struct_fields = left_struct.fields().clone(); + + // test struct('a', 'b') = struct('a', 'b'), right a[1] is different from left a[2],the null buffer is 0b0111 + let right_a2 = Arc::new(Int32Array::new( + vec![0, 2, 2, 3].into(), + Some(vec![true, true, true, false].into()), + )); + let right_struct = StructArray::from(( + vec![ + (field_a.clone(), right_a2.clone() as ArrayRef), + (field_b.clone(), right_b.clone() as ArrayRef), + ], + Buffer::from([0b0111]), + )); + let expected = BooleanArray::new( + vec![true, false, true, true].into(), + Some(vec![true, true, true, false].into()), + ); + assert_eq!(eq(&left_struct, &right_struct).unwrap(), expected); + assert_eq!(eq(&right_struct, &left_struct).unwrap(), expected); + let expected = BooleanArray::new( + vec![false, true, false, false].into(), + Some(vec![true, true, true, false].into()), + ); + assert_eq!(neq(&left_struct, &right_struct).unwrap(), expected); + assert_eq!(neq(&right_struct, &left_struct).unwrap(), expected); + + // test struct('a' , struct('suba', 'subb')) = struct('a', struct('suba', 'subb')), where the right suba1[1] different from left suba[1],the null buffer is 0b0111 + let left_struct = StructArray::from(( + vec![ + (field_a.clone(), left_a.clone() as ArrayRef), + ( + Arc::new(Field::new( + "SubStruct", + DataType::Struct(sub_struct_fields.clone()), + true, + )), + Arc::new(left_struct) as ArrayRef, + ), + ], + Buffer::from([0b0111]), + )); + let right_struct = StructArray::from(( + vec![ + (field_a.clone(), right_a.clone() as ArrayRef), + ( + Arc::new(Field::new( + "SubStruct", + DataType::Struct(sub_struct_fields.clone()), + true, + )), + Arc::new(right_struct) as ArrayRef, + ), + ], + Buffer::from([0b0111]), + )); + let expected = BooleanArray::new( + vec![true, false, true, true].into(), + Some(vec![true, true, true, false].into()), + ); + assert_eq!(eq(&left_struct, &right_struct).unwrap(), expected); + assert_eq!(eq(&right_struct, &left_struct).unwrap(), expected); + let expected = BooleanArray::new( + vec![false, true, false, false].into(), + Some(vec![true, true, true, false].into()), + ); + assert_eq!(neq(&left_struct, &right_struct).unwrap(), expected); + assert_eq!(neq(&right_struct, &left_struct).unwrap(), expected); + } }