diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index 7e2058a64409..e1697570e600 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -169,14 +169,12 @@ pub fn not_distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result Result { - let (l_array, l_s) = lhs.get(); - let (r_array, r_s) = rhs.get(); - - let l_nulls = l_array.logical_nulls(); - let r_nulls = r_array.logical_nulls(); + use arrow_schema::DataType::*; + let (l, l_s) = lhs.get(); + let (r, r_s) = rhs.get(); - let l_len = l_array.len(); - let r_len = r_array.len(); + let l_len = l.len(); + let r_len = r.len(); if l_len != r_len && !l_s && !r_s { return Err(ArrowError::InvalidArgumentError(format!( @@ -184,49 +182,166 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result {} + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid comparison operation: {l_t} {op} {r_t}" + ))); + } + } + } else if r_t != l_t { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid comparison operation: {l_t} {op} {r_t}" + ))); + } + let len = match l_s { true => r_len, false => l_len, }; + Ok(BooleanArray::new( + compare_op_values(op, l, l_s, r, r_s, len)?, + compare_op_nulls(op, l, l_s, r, r_s, len)?, + )) +} + +/// get the NullBuffer result of the comparison +fn compare_op_nulls( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, + len: usize, +) -> Result, ArrowError> { + use arrow_schema::DataType::*; + let l_t = l.data_type(); + let r_t = r.data_type(); + let l_nulls = l.logical_nulls().filter(|n| n.null_count() > 0); + let r_nulls = r.logical_nulls().filter(|n| n.null_count() > 0); + // for [not]Distinct, the result is never null + match op { + Op::Distinct | Op::NotDistinct => { + return Ok(None); + } + _ => {} + } + let nulls = match (l_nulls, l_s, r_nulls, r_s) { + // Either both sides are scalar or neither side is scalar + (Some(l_nulls), true, Some(r_nulls), true) + | (Some(l_nulls), false, Some(r_nulls), false) => { + NullBuffer::union(Some(&l_nulls), Some(&r_nulls)) + } + // Scalar is null, other side is non-scalar and nullable + (Some(_), true, Some(_), false) | (Some(_), false, Some(_), true) => { + Some(NullBuffer::new_null(len)) + } + // Only one side is nullable + (Some(nulls), is_scalar, None, _) | (None, _, Some(nulls), is_scalar) => match is_scalar { + true => Some(NullBuffer::new_null(len)), + false => Some(nulls), + }, + // Neither side is nullable + (None, _, None, _) => None, + }; + match (l_t, r_t) { + (Struct(_), Struct(_)) => { + // union all nulls from children, because any child in certain slot is null, the struct in the slot is uncomparable + let child_nulls = l + .as_struct() + .columns() + .iter() + .zip(r.as_struct().columns().iter()) + .map(|(l, r)| compare_op_nulls(op, l, l_s, r, r_s, len)) + .collect::, _>>()?; + Ok(child_nulls.iter().fold(nulls, |nulls, child_null| { + NullBuffer::union(nulls.as_ref(), child_null.as_ref()) + })) + } + _ => Ok(nulls), + } +} + +/// Defer computation as may not be necessary +/// get the BooleanBuffer result of the comparison +fn compare_op_values( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, + len: usize, +) -> Result { + use arrow_schema::DataType::*; + let l_v = l.as_any_dictionary_opt(); + let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l); - let l_nulls = l_nulls.filter(|n| n.null_count() > 0); - let r_nulls = r_nulls.filter(|n| n.null_count() > 0); + let r_v = r.as_any_dictionary_opt(); + let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r); + + let l_nulls = l.logical_nulls().filter(|n| n.null_count() > 0); + let r_nulls = r.logical_nulls().filter(|n| n.null_count() > 0); + let values = || -> Result { + let values = downcast_primitive_array! { + (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v), + (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v), + (Utf8, Utf8) => apply(op, l.as_string::(), l_s, l_v, r.as_string::(), r_s, r_v), + (LargeUtf8, LargeUtf8) => apply(op, l.as_string::(), l_s, l_v, r.as_string::(), r_s, r_v), + (Binary, Binary) => apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), + (LargeBinary, LargeBinary) => apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), + (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v), + (Null, Null) => None, + (Struct(_), Struct(_)) => Some(compare_op_struct_values(op, l, l_s, r, r_s, len)?), + _ => unreachable!(), + }; + Ok(values.unwrap_or_else(|| BooleanBuffer::new_unset(len))) + }; Ok(match (l_nulls, l_s, r_nulls, r_s) { - (Some(l), true, Some(r), true) | (Some(l), false, Some(r), false) => { + (Some(l_nulls), true, Some(r_nulls), true) + | (Some(l_nulls), false, Some(r_nulls), false) => { // Either both sides are scalar or neither side is scalar match op { Op::Distinct => { - let values = compare_op_values(op, l_array, l_s, r_array, r_s, len)?; - let l = l.inner().bit_chunks().iter_padded(); - let r = r.inner().bit_chunks().iter_padded(); + let values = values()?; + let l_nulls = l_nulls.inner().bit_chunks().iter_padded(); + let r_nulls = r_nulls.inner().bit_chunks().iter_padded(); let ne = values.bit_chunks().iter_padded(); - let c = |((l, r), n)| ((l ^ r) | (l & r & n)); - let buffer = l.zip(r).zip(ne).map(c).collect(); - BooleanBuffer::new(buffer, 0, len).into() + let c = + |((l_nulls, r_nulls), n)| ((l_nulls ^ r_nulls) | (l_nulls & r_nulls & n)); + let buffer = l_nulls.zip(r_nulls).zip(ne).map(c).collect(); + BooleanBuffer::new(buffer, 0, len) } Op::NotDistinct => { - let values = compare_op_values(op, l_array, l_s, r_array, r_s, len)?; - let l = l.inner().bit_chunks().iter_padded(); - let r = r.inner().bit_chunks().iter_padded(); + let values = values()?; + let l_nulls = l_nulls.inner().bit_chunks().iter_padded(); + let r_nulls = r_nulls.inner().bit_chunks().iter_padded(); let e = values.bit_chunks().iter_padded(); - let c = |((l, r), e)| u64::not(l | r) | (l & r & e); - let buffer = l.zip(r).zip(e).map(c).collect(); - BooleanBuffer::new(buffer, 0, len).into() + let c = |((l_nulls, r_nulls), e)| { + u64::not(l_nulls | r_nulls) | (l_nulls & r_nulls & e) + }; + let buffer = l_nulls.zip(r_nulls).zip(e).map(c).collect(); + BooleanBuffer::new(buffer, 0, len) } - _ => BooleanArray::new( - compare_op_values(op, l_array, l_s, r_array, r_s, len)?, - NullBuffer::union(Some(&l), Some(&r)), - ), + _ => values()?, } } (Some(_), true, Some(a), false) | (Some(a), false, Some(_), true) => { // Scalar is null, other side is non-scalar and nullable match op { - Op::Distinct => a.into_inner().into(), - Op::NotDistinct => a.into_inner().not().into(), - _ => BooleanArray::new_null(len), + Op::Distinct => a.into_inner(), + Op::NotDistinct => a.into_inner().not(), + _ => BooleanBuffer::new_unset(len), } } (Some(nulls), is_scalar, None, _) | (None, _, Some(nulls), is_scalar) => { @@ -234,90 +349,29 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result match op { // Scalar is null, other side is not nullable - Op::Distinct => BooleanBuffer::new_set(len).into(), - Op::NotDistinct => BooleanBuffer::new_unset(len).into(), - _ => BooleanArray::new_null(len), + Op::Distinct => BooleanBuffer::new_set(len), + Op::NotDistinct => BooleanBuffer::new_unset(len), + _ => BooleanBuffer::new_unset(len), }, false => match op { Op::Distinct => { - let values = compare_op_values(op, l_array, l_s, r_array, r_s, len)?; - let l = nulls.inner().bit_chunks().iter_padded(); + let values = values()?; + let l_nulls = nulls.inner().bit_chunks().iter_padded(); let ne = values.bit_chunks().iter_padded(); - let c = |(l, n)| u64::not(l) | n; - let buffer = l.zip(ne).map(c).collect(); - BooleanBuffer::new(buffer, 0, len).into() + let c = |(l_nulls, n)| u64::not(l_nulls) | n; + let buffer = l_nulls.zip(ne).map(c).collect(); + BooleanBuffer::new(buffer, 0, len) } - Op::NotDistinct => (nulls.inner() - & &compare_op_values(op, l_array, l_s, r_array, r_s, len)?) - .into(), - _ => BooleanArray::new( - compare_op_values(op, l_array, l_s, r_array, r_s, len)?, - Some(nulls), - ), + Op::NotDistinct => nulls.inner() & &values()?, + _ => values()?, }, } } // Neither side is nullable - (None, _, None, _) => BooleanArray::new( - compare_op_values(op, l_array, l_s, r_array, r_s, len)?, - None, - ), + (None, _, None, _) => values()?, }) } -/// Defer computation as may not be necessary -/// get the BooleanBuffer result of the comparison -fn compare_op_values( - op: Op, - l: &dyn Array, - l_s: bool, - r: &dyn Array, - r_s: bool, - len: usize, -) -> Result { - use arrow_schema::DataType::*; - let l_v = l.as_any_dictionary_opt(); - let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l); - let l_t = l.data_type(); - - let r_v = r.as_any_dictionary_opt(); - let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r); - let r_t = r.data_type(); - - if l_t.is_nested() { - if !l_t.equals_datatype(r_t) { - return Err(ArrowError::InvalidArgumentError(format!( - "Invalid comparison operation: {l_t} {op} {r_t}" - ))); - } - match (l_t, op) { - (Struct(_), Op::Equal | Op::NotEqual) => {} - _ => { - return Err(ArrowError::InvalidArgumentError(format!( - "Invalid comparison operation: {l_t} {op} {r_t}" - ))); - } - } - } else if r_t != l_t { - return Err(ArrowError::InvalidArgumentError(format!( - "Invalid comparison operation: {l_t} {op} {r_t}" - ))); - } - let d = downcast_primitive_array! { - (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v), - (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v), - (Utf8, Utf8) => apply(op, l.as_string::(), l_s, l_v, r.as_string::(), r_s, r_v), - (LargeUtf8, LargeUtf8) => apply(op, l.as_string::(), l_s, l_v, r.as_string::(), r_s, r_v), - (Binary, Binary) => apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), - (LargeBinary, LargeBinary) => apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), - (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v), - (Null, Null) => None, - (Struct(_), Struct(_)) => Some(compare_op_struct_values(op, l, l_s, r, r_s, len)?), - _ => unreachable!(), - }; - Ok(d.unwrap_or_else(|| BooleanBuffer::new_unset(len))) -} - /// recursively compare fields of struct arrays fn compare_op_struct_values( op: Op, @@ -330,8 +384,8 @@ fn compare_op_struct_values( // when one of field is equal, the result is false for not equal // so we use neg to reverse the result of equal when handle not equal let neg = match op { - Op::Equal => false, - Op::NotEqual => true, + Op::Equal | Op::NotDistinct => false, + Op::NotEqual | Op::Distinct => true, _ => unreachable!(), }; @@ -339,16 +393,18 @@ fn compare_op_struct_values( let r = r.as_struct(); // compare each field of struct - let child_res = l + let child_values = l .columns() .iter() .zip(r.columns().iter()) .map(|(col_l, col_r)| compare_op_values(Op::Equal, col_l, l_s, col_r, r_s, len)) .collect::, ArrowError>>()?; // combine the result of each field - let equality = child_res + let equality = child_values .iter() - .fold(BooleanBuffer::new_set(len), |acc, x| &acc & x); + .fold(BooleanBuffer::new_set(len), |values, child_value| { + &values & child_value + }); Ok(if neg { !&equality } else { equality }) } @@ -773,8 +829,8 @@ mod tests { } #[test] - fn test_struct_equality() { - // test struct('a', 'b') = struct('a', 'b'), the null buffer is 0b0111 + fn test_struct_uncomparable() { + // test struct('a') == struct('a','b') let left_a = Arc::new(Int32Array::new( vec![0, 1, 2, 3].into(), Some(vec![true, false, true, false].into()), @@ -783,13 +839,42 @@ mod tests { vec![0, 1, 2, 3].into(), Some(vec![true, false, true, false].into()), )); - let left_b = Arc::new(Int32Array::new( + let right_b = Arc::new(Int32Array::new( vec![0, 1, 2, 3].into(), Some(vec![true, true, true, false].into()), )); + let field_a = Arc::new(Field::new("a", DataType::Int32, true)); + let field_b = Arc::new(Field::new("b", DataType::Int32, true)); + let left = StructArray::from(vec![(field_a.clone(), left_a.clone() as ArrayRef)]); + let right = StructArray::from(vec![ + (field_a.clone(), right_a.clone() as ArrayRef), + (field_b.clone(), right_b.clone() as ArrayRef), + ]); + assert_eq!(eq(&left, &right).unwrap_err().to_string(), "Invalid argument error: Invalid comparison operation: Struct([Field { name: \"a\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) == Struct([Field { name: \"a\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"b\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])"); + + // test struct('a') <= struct('a') + assert_eq!(lt(&left, &left).unwrap_err().to_string(), "Invalid argument error: Invalid comparison operation: Struct([Field { name: \"a\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) < Struct([Field { name: \"a\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])"); + } + + #[test] + fn test_struct_compare() { + // test struct('a', 'b')、struct('a', 'b'), the null buffer is 0b0111 + // left b[2] is different from right b[2] + let left_a = Arc::new(Int32Array::new( + vec![0, 1, 2, 3].into(), + Some(vec![true, false, true, true].into()), + )); + let right_a = Arc::new(Int32Array::new( + vec![0, 1, 2, 3].into(), + Some(vec![true, false, true, true].into()), + )); + let left_b = Arc::new(Int32Array::new( + vec![0, 1, 20, 3].into(), + Some(vec![true, true, true, true].into()), + )); let right_b = Arc::new(Int32Array::new( vec![0, 1, 2, 3].into(), - Some(vec![true, true, true, false].into()), + Some(vec![true, true, true, true].into()), )); let field_a = Arc::new(Field::new("a", DataType::Int32, true)); let field_b = Arc::new(Field::new("b", DataType::Int32, true)); @@ -808,46 +893,32 @@ mod tests { Buffer::from([0b0111]), )); let expected = BooleanArray::new( - vec![true, true, true, true].into(), - Some(vec![true, true, true, false].into()), + vec![true, true, false, true].into(), + // a[1] is none in child, struct[3] is none in parent + Some(vec![true, false, true, false].into()), ); assert_eq!(eq(&left_struct, &right_struct).unwrap(), expected); assert_eq!(eq(&right_struct, &left_struct).unwrap(), expected); let expected = BooleanArray::new( - vec![false, false, false, false].into(), - Some(vec![true, true, true, false].into()), + vec![false, false, true, false].into(), + Some(vec![true, false, true, false].into()), ); assert_eq!(neq(&left_struct, &right_struct).unwrap(), expected); assert_eq!(neq(&right_struct, &left_struct).unwrap(), expected); - - let sub_struct_fields = left_struct.fields().clone(); - - // test struct('a', 'b') = struct('a', 'b'), right a[1] is different from left a[2],the null buffer is 0b0111 - let right_a2 = Arc::new(Int32Array::new( - vec![0, 2, 2, 3].into(), - Some(vec![true, true, true, false].into()), - )); - let right_struct = StructArray::from(( - vec![ - (field_a.clone(), right_a2.clone() as ArrayRef), - (field_b.clone(), right_b.clone() as ArrayRef), - ], - Buffer::from([0b0111]), - )); - let expected = BooleanArray::new( - vec![true, false, true, true].into(), - Some(vec![true, true, true, false].into()), - ); - assert_eq!(eq(&left_struct, &right_struct).unwrap(), expected); - assert_eq!(eq(&right_struct, &left_struct).unwrap(), expected); let expected = BooleanArray::new( - vec![false, true, false, false].into(), - Some(vec![true, true, true, false].into()), + // left[0] equals to right[0], left b[1] is not distinct from right b[1], left b[2] is distinct from right b[2], struct[3] is none in parent + vec![false, false, true, false].into(), + None, ); - assert_eq!(neq(&left_struct, &right_struct).unwrap(), expected); - assert_eq!(neq(&right_struct, &left_struct).unwrap(), expected); + assert_eq!(distinct(&left_struct, &right_struct).unwrap(), expected); + assert_eq!(distinct(&right_struct, &left_struct).unwrap(), expected); + let expected = BooleanArray::new(vec![true, true, false, true].into(), None); + assert_eq!(not_distinct(&left_struct, &right_struct).unwrap(), expected); + assert_eq!(not_distinct(&right_struct, &left_struct).unwrap(), expected); + + let sub_struct_fields = left_struct.fields().clone(); - // test struct('a' , struct('suba', 'subb')) = struct('a', struct('suba', 'subb')), where the right suba1[1] different from left suba[1],the null buffer is 0b0111 + // test struct('a' , struct('suba', 'subb')) 、 struct('a', struct('suba', 'subb')), where the right subb1[2] different from left subb[2],the null buffer is 0b0111 let left_struct = StructArray::from(( vec![ (field_a.clone(), left_a.clone() as ArrayRef), @@ -877,14 +948,14 @@ mod tests { Buffer::from([0b0111]), )); let expected = BooleanArray::new( - vec![true, false, true, true].into(), - Some(vec![true, true, true, false].into()), + vec![true, false, false, true].into(), + Some(vec![true, false, true, false].into()), ); assert_eq!(eq(&left_struct, &right_struct).unwrap(), expected); assert_eq!(eq(&right_struct, &left_struct).unwrap(), expected); let expected = BooleanArray::new( - vec![false, true, false, false].into(), - Some(vec![true, true, true, false].into()), + vec![false, true, true, false].into(), + Some(vec![true, false, true, false].into()), ); assert_eq!(neq(&left_struct, &right_struct).unwrap(), expected); assert_eq!(neq(&right_struct, &left_struct).unwrap(), expected);