From 40868860580d09c4672e922a368597b4ff425b02 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 21 Apr 2024 08:49:33 +0800 Subject: [PATCH 1/5] support primitive Signed-off-by: jayzhan211 --- arrow-ord/src/ord.rs | 242 +++++++++++++++++++++++++++--------------- arrow-ord/src/sort.rs | 29 ++--- arrow-row/src/lib.rs | 3 +- parquet-testing | 2 +- testing | 2 +- 5 files changed, 174 insertions(+), 104 deletions(-) diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs index e793038de929..304786c5524d 100644 --- a/arrow-ord/src/ord.rs +++ b/arrow-ord/src/ord.rs @@ -24,8 +24,68 @@ use arrow_buffer::ArrowNativeType; use arrow_schema::ArrowError; use std::cmp::Ordering; +#[derive(Debug, PartialEq, Eq)] +pub enum Compare { + Less, + Greater, + Equal, + LeftNull, + RightNull, + BothNull, +} + +impl Compare { + pub fn ordering(&self, null_first: bool) -> Ordering { + match self { + Self::Less => Ordering::Less, + Self::Greater => Ordering::Greater, + Self::Equal => Ordering::Equal, + Self::LeftNull => { + if null_first { + Ordering::Less + } else { + Ordering::Greater + } + } + Self::RightNull => { + if null_first { + Ordering::Greater + } else { + Ordering::Less + } + } + Self::BothNull => Ordering::Equal, + } + } + + #[inline] + pub fn is_null(&self) -> bool { + matches!(self, Self::LeftNull | Self::RightNull | Self::BothNull) + } + + #[inline] + pub const fn reverse(self) -> Self { + match self { + Self::Less => Self::Greater, + Self::Greater => Self::Less, + _ => self, + } + } +} + +impl From for Compare { + fn from(ordering: Ordering) -> Self { + match ordering { + Ordering::Less => Self::Less, + Ordering::Greater => Self::Greater, + Ordering::Equal => Self::Equal, + } + } +} + /// Compare the values at two arbitrary indices in two arrays. -pub type DynComparator = Box Ordering + Send + Sync>; +pub type DynComparator = Box Compare + Send + Sync>; + fn compare_primitive(left: &dyn Array, right: &dyn Array) -> DynComparator where @@ -33,14 +93,23 @@ where { let left = left.as_primitive::().clone(); let right = right.as_primitive::().clone(); - Box::new(move |i, j| left.value(i).compare(right.value(j))) + Box::new(move |i, j| { + match (left.is_null(i), right.is_null(j)) { + (true, true) => Compare::BothNull, + (true, false) => Compare::LeftNull, + (false, true) => Compare::RightNull, + (false, false) => { + left.value(i).compare(right.value(j)).into() + } + } + }) } fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator { let left: BooleanArray = left.as_boolean().clone(); let right: BooleanArray = right.as_boolean().clone(); - Box::new(move |i, j| left.value(i).cmp(&right.value(j))) + Box::new(move |i, j| left.value(i).cmp(&right.value(j)).into()) } fn compare_bytes(left: &dyn Array, right: &dyn Array) -> DynComparator { @@ -50,7 +119,7 @@ fn compare_bytes(left: &dyn Array, right: &dyn Array) -> DynCo Box::new(move |i, j| { let l: &[u8] = left.value(i).as_ref(); let r: &[u8] = right.value(j).as_ref(); - l.cmp(r) + l.cmp(r).into() }) } @@ -69,7 +138,7 @@ fn compare_dict( Ok(Box::new(move |i, j| { let l = left_keys.value(i).as_usize(); let r = right_keys.value(j).as_usize(); - cmp(l, r) + cmp(l, r).into() })) } @@ -108,7 +177,7 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result { let left = left.as_fixed_size_binary().clone(); let right = right.as_fixed_size_binary().clone(); - Ok(Box::new(move |i, j| left.value(i).cmp(right.value(j)))) + Ok(Box::new(move |i, j| left.value(i).cmp(right.value(j)).into())) }, (Dictionary(l_key, _), Dictionary(r_key, _)) => { macro_rules! dict_helper { @@ -135,6 +204,8 @@ pub mod tests { use half::f16; use std::sync::Arc; + + #[test] fn test_fixed_size_binary() { let items = vec![vec![1u8], vec![2u8]]; @@ -142,7 +213,7 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 1)); + assert_eq!(Compare::Less, cmp(0, 1)); } #[test] @@ -154,26 +225,23 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 0)); + assert_eq!(Compare::Less, cmp(0, 0)); } #[test] fn test_i32() { - let array = Int32Array::from(vec![1, 2]); - - let cmp = build_compare(&array, &array).unwrap(); + let a1 = Int32Array::from(vec![Some(1), None, Some(5)]); - assert_eq!(Ordering::Less, (cmp)(0, 1)); - } - - #[test] - fn test_i32_i32() { - let array1 = Int32Array::from(vec![1]); - let array2 = Int32Array::from(vec![2]); + let cmp = build_compare(&a1, &a1).unwrap(); + assert_eq!(Compare::Less, cmp(0, 2)); + assert_eq!(Compare::BothNull, cmp(1, 1)); - let cmp = build_compare(&array1, &array2).unwrap(); - - assert_eq!(Ordering::Less, cmp(0, 0)); + let a2 = Int32Array::from(vec![Some(3), Some(4), None]); + let cmp = build_compare(&a1, &a2).unwrap(); + assert_eq!(Compare::Less, cmp(0, 0)); + assert_eq!(Compare::LeftNull, cmp(1, 1)); + assert_eq!(Compare::RightNull, cmp(2, 2)); + assert_eq!(Compare::Greater, cmp(2, 0)); } #[test] @@ -182,7 +250,7 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 1)); + assert_eq!(Compare::Less, cmp(0, 1)); } #[test] @@ -191,7 +259,7 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 1)); + assert_eq!(Compare::Less, cmp(0, 1)); } #[test] @@ -200,8 +268,8 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 1)); - assert_eq!(Ordering::Equal, cmp(1, 1)); + assert_eq!(Compare::Less, cmp(0, 1)); + assert_eq!(Compare::Equal, cmp(1, 1)); } #[test] @@ -210,8 +278,8 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 1)); - assert_eq!(Ordering::Greater, cmp(1, 0)); + assert_eq!(Compare::Less, cmp(0, 1)); + assert_eq!(Compare::Greater, cmp(1, 0)); } #[test] @@ -227,14 +295,14 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 1)); - assert_eq!(Ordering::Greater, cmp(1, 0)); + assert_eq!(Compare::Less, cmp(0, 1)); + assert_eq!(Compare::Greater, cmp(1, 0)); // somewhat confusingly, while 90M milliseconds is more than 1 day, // it will compare less as the comparison is done on the underlying // values not field by field - assert_eq!(Ordering::Greater, cmp(1, 2)); - assert_eq!(Ordering::Less, cmp(2, 1)); + assert_eq!(Compare::Greater, cmp(1, 2)); + assert_eq!(Compare::Less, cmp(2, 1)); } #[test] @@ -250,12 +318,12 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 1)); - assert_eq!(Ordering::Greater, cmp(1, 0)); + assert_eq!(Compare::Less, cmp(0, 1)); + assert_eq!(Compare::Greater, cmp(1, 0)); // the underlying representation is months, so both quantities are the same - assert_eq!(Ordering::Equal, cmp(1, 2)); - assert_eq!(Ordering::Equal, cmp(2, 1)); + assert_eq!(Compare::Equal, cmp(1, 2)); + assert_eq!(Compare::Equal, cmp(2, 1)); } #[test] @@ -271,14 +339,14 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 1)); - assert_eq!(Ordering::Greater, cmp(1, 0)); + assert_eq!(Compare::Less, cmp(0, 1)); + assert_eq!(Compare::Greater, cmp(1, 0)); // somewhat confusingly, while 100 days is more than 1 month in all cases // it will compare less as the comparison is done on the underlying // values not field by field - assert_eq!(Ordering::Greater, cmp(1, 2)); - assert_eq!(Ordering::Less, cmp(2, 1)); + assert_eq!(Compare::Greater, cmp(1, 2)); + assert_eq!(Compare::Less, cmp(2, 1)); } #[test] @@ -290,8 +358,8 @@ pub mod tests { .unwrap(); let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, cmp(1, 0)); - assert_eq!(Ordering::Greater, cmp(0, 2)); + assert_eq!(Compare::Less, cmp(1, 0)); + assert_eq!(Compare::Greater, cmp(0, 2)); } #[test] @@ -307,8 +375,8 @@ pub mod tests { .unwrap(); let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, cmp(1, 0)); - assert_eq!(Ordering::Greater, cmp(0, 2)); + assert_eq!(Compare::Less, cmp(1, 0)); + assert_eq!(Compare::Greater, cmp(0, 2)); } #[test] @@ -318,9 +386,9 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 1)); - assert_eq!(Ordering::Equal, cmp(3, 4)); - assert_eq!(Ordering::Greater, cmp(2, 3)); + assert_eq!(Compare::Less, cmp(0, 1)); + assert_eq!(Compare::Equal, cmp(3, 4)); + assert_eq!(Compare::Greater, cmp(2, 3)); } #[test] @@ -332,9 +400,9 @@ pub mod tests { let cmp = build_compare(&a1, &a2).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 0)); - assert_eq!(Ordering::Equal, cmp(0, 3)); - assert_eq!(Ordering::Greater, cmp(1, 3)); + assert_eq!(Compare::Less, cmp(0, 0)); + assert_eq!(Compare::Equal, cmp(0, 3)); + assert_eq!(Compare::Greater, cmp(1, 3)); } #[test] @@ -349,11 +417,11 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 0)); - assert_eq!(Ordering::Less, cmp(0, 3)); - assert_eq!(Ordering::Equal, cmp(3, 3)); - assert_eq!(Ordering::Greater, cmp(3, 1)); - assert_eq!(Ordering::Greater, cmp(3, 2)); + assert_eq!(Compare::Less, cmp(0, 0)); + assert_eq!(Compare::Less, cmp(0, 3)); + assert_eq!(Compare::Equal, cmp(3, 3)); + assert_eq!(Compare::Greater, cmp(3, 1)); + assert_eq!(Compare::Greater, cmp(3, 2)); } #[test] @@ -368,11 +436,11 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 0)); - assert_eq!(Ordering::Less, cmp(0, 3)); - assert_eq!(Ordering::Equal, cmp(3, 3)); - assert_eq!(Ordering::Greater, cmp(3, 1)); - assert_eq!(Ordering::Greater, cmp(3, 2)); + assert_eq!(Compare::Less, cmp(0, 0)); + assert_eq!(Compare::Less, cmp(0, 3)); + assert_eq!(Compare::Equal, cmp(3, 3)); + assert_eq!(Compare::Greater, cmp(3, 1)); + assert_eq!(Compare::Greater, cmp(3, 2)); } #[test] @@ -387,11 +455,11 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 0)); - assert_eq!(Ordering::Less, cmp(0, 3)); - assert_eq!(Ordering::Equal, cmp(3, 3)); - assert_eq!(Ordering::Greater, cmp(3, 1)); - assert_eq!(Ordering::Greater, cmp(3, 2)); + assert_eq!(Compare::Less, cmp(0, 0)); + assert_eq!(Compare::Less, cmp(0, 3)); + assert_eq!(Compare::Equal, cmp(3, 3)); + assert_eq!(Compare::Greater, cmp(3, 1)); + assert_eq!(Compare::Greater, cmp(3, 2)); } #[test] @@ -406,11 +474,11 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 0)); - assert_eq!(Ordering::Less, cmp(0, 3)); - assert_eq!(Ordering::Equal, cmp(3, 3)); - assert_eq!(Ordering::Greater, cmp(3, 1)); - assert_eq!(Ordering::Greater, cmp(3, 2)); + assert_eq!(Compare::Less, cmp(0, 0)); + assert_eq!(Compare::Less, cmp(0, 3)); + assert_eq!(Compare::Equal, cmp(3, 3)); + assert_eq!(Compare::Greater, cmp(3, 1)); + assert_eq!(Compare::Greater, cmp(3, 2)); } #[test] @@ -425,11 +493,11 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 0)); - assert_eq!(Ordering::Less, cmp(0, 3)); - assert_eq!(Ordering::Equal, cmp(3, 3)); - assert_eq!(Ordering::Greater, cmp(3, 1)); - assert_eq!(Ordering::Greater, cmp(3, 2)); + assert_eq!(Compare::Less, cmp(0, 0)); + assert_eq!(Compare::Less, cmp(0, 3)); + assert_eq!(Compare::Equal, cmp(3, 3)); + assert_eq!(Compare::Greater, cmp(3, 1)); + assert_eq!(Compare::Greater, cmp(3, 2)); } #[test] @@ -444,11 +512,11 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 0)); - assert_eq!(Ordering::Less, cmp(0, 3)); - assert_eq!(Ordering::Equal, cmp(3, 3)); - assert_eq!(Ordering::Greater, cmp(3, 1)); - assert_eq!(Ordering::Greater, cmp(3, 2)); + assert_eq!(Compare::Less, cmp(0, 0)); + assert_eq!(Compare::Less, cmp(0, 3)); + assert_eq!(Compare::Equal, cmp(3, 3)); + assert_eq!(Compare::Greater, cmp(3, 1)); + assert_eq!(Compare::Greater, cmp(3, 2)); } #[test] @@ -473,11 +541,11 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 0)); - assert_eq!(Ordering::Less, cmp(0, 3)); - assert_eq!(Ordering::Equal, cmp(3, 3)); - assert_eq!(Ordering::Greater, cmp(3, 1)); - assert_eq!(Ordering::Greater, cmp(3, 2)); + assert_eq!(Compare::Less, cmp(0, 0)); + assert_eq!(Compare::Less, cmp(0, 3)); + assert_eq!(Compare::Equal, cmp(3, 3)); + assert_eq!(Compare::Greater, cmp(3, 1)); + assert_eq!(Compare::Greater, cmp(3, 2)); } fn test_bytes_impl() { @@ -485,9 +553,9 @@ pub mod tests { let a = GenericByteArray::::new(offsets, b"abcdefa".into(), None); let cmp = build_compare(&a, &a).unwrap(); - assert_eq!(Ordering::Less, cmp(0, 1)); - assert_eq!(Ordering::Greater, cmp(0, 2)); - assert_eq!(Ordering::Equal, cmp(1, 1)); + assert_eq!(Compare::Less, cmp(0, 1)); + assert_eq!(Compare::Greater, cmp(0, 2)); + assert_eq!(Compare::Equal, cmp(1, 1)); } #[test] diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index fe3a1f86ac00..681ae8e819c5 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -17,7 +17,8 @@ //! Defines sort kernel for `ArrayRef` -use crate::ord::{build_compare, DynComparator}; +use crate::ord::{Compare, DynComparator}; +use crate::ord::build_compare; use arrow_array::builder::BufferBuilder; use arrow_array::cast::*; use arrow_array::types::*; @@ -685,7 +686,7 @@ pub fn lexsort_to_indices( let lexicographical_comparator = LexicographicalComparator::try_new(columns)?; // uint32 can be sorted unstably sort_unstable_by(&mut value_indices, len, |a, b| { - lexicographical_comparator.compare(*a, *b) + lexicographical_comparator.compare(*a, *b).ordering(true) }); Ok(UInt32Array::from_iter_values( @@ -718,7 +719,7 @@ pub struct LexicographicalComparator { impl LexicographicalComparator { /// lexicographically compare values at the wrapped columns with given indices. - pub fn compare(&self, a_idx: usize, b_idx: usize) -> Ordering { + pub fn compare(&self, a_idx: usize, b_idx: usize) -> Compare { for (nulls, comparator, sort_option) in &self.compare_items { let (lhs_valid, rhs_valid) = match nulls { Some(n) => (n.is_valid(a_idx), n.is_valid(b_idx)), @@ -729,7 +730,7 @@ impl LexicographicalComparator { (true, true) => { match (comparator)(a_idx, b_idx) { // equal, move on to next column - Ordering::Equal => continue, + Compare::Equal => continue, order => { if sort_option.descending { return order.reverse(); @@ -741,16 +742,16 @@ impl LexicographicalComparator { } (false, true) => { return if sort_option.nulls_first { - Ordering::Less + Compare::Less } else { - Ordering::Greater + Compare::Greater }; } (true, false) => { return if sort_option.nulls_first { - Ordering::Greater + Compare::Greater } else { - Ordering::Less + Compare::Less }; } // equal, move on to next column @@ -758,7 +759,7 @@ impl LexicographicalComparator { } } - Ordering::Equal + Compare::Equal } /// Create a new lex comparator that will wrap the given sort columns and give comparison @@ -799,7 +800,7 @@ impl LexicographicalComparator { &rank[start..end] }}; } - Ord::cmp(nth_value!(i), nth_value!(j)) + Ord::cmp(nth_value!(i), nth_value!(j)).into() }); Ok(cmp) } @@ -817,7 +818,7 @@ impl LexicographicalComparator { &rank[start..start + size] }}; } - Ord::cmp(nth_value!(i), nth_value!(j)) + Ord::cmp(nth_value!(i), nth_value!(j)).into() }); Ok(cmp) } @@ -4202,11 +4203,11 @@ mod tests { }]) .unwrap(); // 1.cmp(NULL) - assert_eq!(comparator.compare(0, 1), Ordering::Greater); + assert_eq!(comparator.compare(0, 1), Compare::Greater); // NULL.cmp(NULL) - assert_eq!(comparator.compare(2, 1), Ordering::Equal); + assert_eq!(comparator.compare(2, 1), Compare::Equal); // NULL.cmp(4) - assert_eq!(comparator.compare(2, 3), Ordering::Less); + assert_eq!(comparator.compare(2, 3), Compare::Less); } #[test] diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 037ed404adca..76efbfb8f354 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -1303,6 +1303,7 @@ unsafe fn decode_column( #[cfg(test)] mod tests { + use arrow_ord::ord::Compare; use rand::distributions::uniform::SampleUniform; use rand::distributions::{Distribution, Standard}; use rand::{thread_rng, Rng}; @@ -2193,7 +2194,7 @@ mod tests { for j in 0..len { let row_i = rows.row(i); let row_j = rows.row(j); - let row_cmp = row_i.cmp(&row_j); + let row_cmp: Compare = row_i.cmp(&row_j).into(); let lex_cmp = comparator.compare(i, j); assert_eq!( row_cmp, diff --git a/parquet-testing b/parquet-testing index 4cb3cff24c96..89b685a64c31 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit 4cb3cff24c965fb329cdae763eabce47395a68a0 +Subproject commit 89b685a64c3117b3023d8684af1f41400841db71 diff --git a/testing b/testing index e270341fb5f3..d315f7985207 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit e270341fb5f3ff785410e6286cc42898e9d6a99c +Subproject commit d315f7985207d2d67fc2c8e41053e9d97d573f4b From 4b8807b1281e996e44a329e054d55026c41b6a25 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 21 Apr 2024 08:56:04 +0800 Subject: [PATCH 2/5] bool Signed-off-by: jayzhan211 --- arrow-ord/src/ord.rs | 36 +++++++++++++++++++++++------------- arrow-ord/src/sort.rs | 2 +- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs index 304786c5524d..274ddf0b6a37 100644 --- a/arrow-ord/src/ord.rs +++ b/arrow-ord/src/ord.rs @@ -86,22 +86,17 @@ impl From for Compare { /// Compare the values at two arbitrary indices in two arrays. pub type DynComparator = Box Compare + Send + Sync>; - fn compare_primitive(left: &dyn Array, right: &dyn Array) -> DynComparator where T::Native: ArrowNativeTypeOp, { let left = left.as_primitive::().clone(); let right = right.as_primitive::().clone(); - Box::new(move |i, j| { - match (left.is_null(i), right.is_null(j)) { - (true, true) => Compare::BothNull, - (true, false) => Compare::LeftNull, - (false, true) => Compare::RightNull, - (false, false) => { - left.value(i).compare(right.value(j)).into() - } - } + Box::new(move |i, j| match (left.is_null(i), right.is_null(j)) { + (true, true) => Compare::BothNull, + (true, false) => Compare::LeftNull, + (false, true) => Compare::RightNull, + (false, false) => left.value(i).compare(right.value(j)).into(), }) } @@ -109,7 +104,12 @@ fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator { let left: BooleanArray = left.as_boolean().clone(); let right: BooleanArray = right.as_boolean().clone(); - Box::new(move |i, j| left.value(i).cmp(&right.value(j)).into()) + Box::new(move |i, j| match (left.is_null(i), right.is_null(j)) { + (true, true) => Compare::BothNull, + (true, false) => Compare::LeftNull, + (false, true) => Compare::RightNull, + (false, false) => left.value(i).cmp(&right.value(j)).into(), + }) } fn compare_bytes(left: &dyn Array, right: &dyn Array) -> DynComparator { @@ -204,8 +204,6 @@ pub mod tests { use half::f16; use std::sync::Arc; - - #[test] fn test_fixed_size_binary() { let items = vec![vec![1u8], vec![2u8]]; @@ -244,6 +242,18 @@ pub mod tests { assert_eq!(Compare::Greater, cmp(2, 0)); } + #[test] + fn test_bool() { + let a1 = BooleanArray::from(vec![Some(true), None, Some(false)]); + let a2 = BooleanArray::from(vec![Some(false), Some(true), None]); + let cmp = build_compare(&a1, &a2).unwrap(); + assert_eq!(Compare::Greater, cmp(0, 0)); + assert_eq!(Compare::Equal, cmp(0, 1)); + assert_eq!(Compare::LeftNull, cmp(1, 0)); + assert_eq!(Compare::BothNull, cmp(1, 2)); + assert_eq!(Compare::RightNull, cmp(2, 2)); + } + #[test] fn test_f16() { let array = Float16Array::from(vec![f16::from_f32(1.0), f16::from_f32(2.0)]); diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index 681ae8e819c5..865f1364612c 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -17,8 +17,8 @@ //! Defines sort kernel for `ArrayRef` -use crate::ord::{Compare, DynComparator}; use crate::ord::build_compare; +use crate::ord::{Compare, DynComparator}; use arrow_array::builder::BufferBuilder; use arrow_array::cast::*; use arrow_array::types::*; From cb16bce0e664cef3c35df72011f7c92bf74eda81 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 21 Apr 2024 09:22:52 +0800 Subject: [PATCH 3/5] dict and fixed size binary Signed-off-by: jayzhan211 --- arrow-ord/src/ord.rs | 116 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 97 insertions(+), 19 deletions(-) diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs index 274ddf0b6a37..f97267d1fb14 100644 --- a/arrow-ord/src/ord.rs +++ b/arrow-ord/src/ord.rs @@ -116,10 +116,15 @@ fn compare_bytes(left: &dyn Array, right: &dyn Array) -> DynCo let left = left.as_bytes::().clone(); let right = right.as_bytes::().clone(); - Box::new(move |i, j| { - let l: &[u8] = left.value(i).as_ref(); - let r: &[u8] = right.value(j).as_ref(); - l.cmp(r).into() + Box::new(move |i, j| match (left.is_null(i), right.is_null(j)) { + (true, true) => Compare::BothNull, + (true, false) => Compare::LeftNull, + (false, true) => Compare::RightNull, + (false, false) => { + let l: &[u8] = left.value(i).as_ref(); + let r: &[u8] = right.value(j).as_ref(); + l.cmp(r).into() + } }) } @@ -134,11 +139,17 @@ fn compare_dict( let left_keys = left.keys().clone(); let right_keys = right.keys().clone(); - // TODO: Handle value nulls (#2687) Ok(Box::new(move |i, j| { - let l = left_keys.value(i).as_usize(); - let r = right_keys.value(j).as_usize(); - cmp(l, r).into() + match (left_keys.is_null(i), right_keys.is_null(j)) { + (true, true) => Compare::BothNull, + (true, false) => Compare::LeftNull, + (false, true) => Compare::RightNull, + (false, false) => { + let l = left_keys.value(i).as_usize(); + let r = right_keys.value(j).as_usize(); + cmp(l, r).into() + } + } })) } @@ -177,7 +188,18 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result { let left = left.as_fixed_size_binary().clone(); let right = right.as_fixed_size_binary().clone(); - Ok(Box::new(move |i, j| left.value(i).cmp(right.value(j)).into())) + Ok(Box::new(move |i, j| { + match (left.is_null(i), right.is_null(j)) { + (true, true) => Compare::BothNull, + (true, false) => Compare::LeftNull, + (false, true) => Compare::RightNull, + (false, false) => { + let l = left.value(i).as_ref(); + let r = right.value(j).as_ref(); + l.cmp(r).into() + } + } + })) }, (Dictionary(l_key, _), Dictionary(r_key, _)) => { macro_rules! dict_helper { @@ -200,18 +222,41 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result>(); + let mut builder = StringDictionaryBuilder::::new(); + builder.append_value("a"); + builder.append_value("b"); + builder.append_null(); + builder.append_value("b"); + builder.append_null(); + builder.append_null(); + builder.append_value("c"); + let a1 = builder.finish(); + + let mut builder = StringDictionaryBuilder::::new(); + builder.append_null(); + builder.append_value("a"); + builder.append_value("b"); + builder.append_value("b"); + builder.append_value("c"); + builder.append_null(); + builder.append_null(); + let a2 = builder.finish(); - let cmp = build_compare(&array, &array).unwrap(); + let cmp = build_compare(&a1, &a2).unwrap(); - assert_eq!(Compare::Less, cmp(0, 1)); - assert_eq!(Compare::Equal, cmp(3, 4)); - assert_eq!(Compare::Greater, cmp(2, 3)); + assert_eq!(Compare::RightNull, cmp(0, 0)); + assert_eq!(Compare::LeftNull, cmp(2, 1)); + assert_eq!(Compare::Equal, cmp(1, 2)); + assert_eq!(Compare::LeftNull, cmp(2, 3)); + assert_eq!(Compare::BothNull, cmp(2, 0)); + assert_eq!(Compare::Greater, cmp(6, 1)); } #[test] @@ -568,6 +633,19 @@ pub mod tests { assert_eq!(Compare::Equal, cmp(1, 1)); } + #[test] + fn test_string() { + let a1 = StringArray::from(vec![Some("a"), None, Some("abcd")]); + let a2 = StringArray::from(vec![Some("ab"), Some("abcd"), None]); + let cmp = build_compare(&a1, &a2).unwrap(); + assert_eq!(Compare::Less, cmp(0, 0)); + assert_eq!(Compare::Equal, cmp(2, 1)); + assert_eq!(Compare::Greater, cmp(2, 0)); + assert_eq!(Compare::RightNull, cmp(0, 2)); + assert_eq!(Compare::BothNull, cmp(1, 2)); + assert_eq!(Compare::LeftNull, cmp(1, 1)); + } + #[test] fn test_bytes() { test_bytes_impl::(); From 040df2cfb6db8dd47fe0d6d3ab7296a1f88face4 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 21 Apr 2024 10:03:58 +0800 Subject: [PATCH 4/5] fix ordering Signed-off-by: jayzhan211 --- arrow-ord/src/sort.rs | 69 +++++++++++++++++++++++++++++++++++-------- arrow-row/src/lib.rs | 2 +- 2 files changed, 58 insertions(+), 13 deletions(-) diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index 865f1364612c..5d3800663f26 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -686,7 +686,7 @@ pub fn lexsort_to_indices( let lexicographical_comparator = LexicographicalComparator::try_new(columns)?; // uint32 can be sorted unstably sort_unstable_by(&mut value_indices, len, |a, b| { - lexicographical_comparator.compare(*a, *b).ordering(true) + lexicographical_comparator.compare(*a, *b) }); Ok(UInt32Array::from_iter_values( @@ -719,13 +719,16 @@ pub struct LexicographicalComparator { impl LexicographicalComparator { /// lexicographically compare values at the wrapped columns with given indices. - pub fn compare(&self, a_idx: usize, b_idx: usize) -> Compare { + pub fn compare(&self, a_idx: usize, b_idx: usize) -> Ordering { for (nulls, comparator, sort_option) in &self.compare_items { let (lhs_valid, rhs_valid) = match nulls { Some(n) => (n.is_valid(a_idx), n.is_valid(b_idx)), None => (true, true), }; + // TODO: after dict values with nulls are supported, we can compare with + // (comparator)(a_idx, b_idx) directly + match (lhs_valid, rhs_valid) { (true, true) => { match (comparator)(a_idx, b_idx) { @@ -733,25 +736,25 @@ impl LexicographicalComparator { Compare::Equal => continue, order => { if sort_option.descending { - return order.reverse(); + return order.reverse().ordering(sort_option.nulls_first); } else { - return order; + return order.ordering(sort_option.nulls_first); } } } } (false, true) => { return if sort_option.nulls_first { - Compare::Less + Ordering::Less } else { - Compare::Greater + Ordering::Greater }; } (true, false) => { return if sort_option.nulls_first { - Compare::Greater + Ordering::Greater } else { - Compare::Less + Ordering::Less }; } // equal, move on to next column @@ -759,7 +762,7 @@ impl LexicographicalComparator { } } - Compare::Equal + Ordering::Equal } /// Create a new lex comparator that will wrap the given sort columns and give comparison @@ -800,6 +803,7 @@ impl LexicographicalComparator { &rank[start..end] }}; } + Ord::cmp(nth_value!(i), nth_value!(j)).into() }); Ok(cmp) @@ -4203,11 +4207,11 @@ mod tests { }]) .unwrap(); // 1.cmp(NULL) - assert_eq!(comparator.compare(0, 1), Compare::Greater); + assert_eq!(comparator.compare(0, 1), Ordering::Greater); // NULL.cmp(NULL) - assert_eq!(comparator.compare(2, 1), Compare::Equal); + assert_eq!(comparator.compare(2, 1), Ordering::Equal); // NULL.cmp(4) - assert_eq!(comparator.compare(2, 3), Compare::Less); + assert_eq!(comparator.compare(2, 3), Ordering::Less); } #[test] @@ -4236,4 +4240,45 @@ mod tests { let sort_indices = sort_to_indices(&a, None, None).unwrap(); assert_eq!(sort_indices.values(), &[1, 2, 0]); } + + #[test] + fn test_build_list_compare() { + let a1 = Arc::new(Int64Array::from(vec![ + Some(1), + None, + Some(2), + Some(3), + None, + ])) as ArrayRef; + let comparator = LexicographicalComparator::try_new(&[SortColumn { + values: a1.clone(), + options: None, + }]) + .unwrap(); + + // default is null first + // 1 vs null + assert_eq!(Ordering::Greater, comparator.compare(0, 1)); + // null vs 2 + assert_eq!(Ordering::Less, comparator.compare(1, 2)); + // null vs null + assert_eq!(Ordering::Equal, comparator.compare(1, 4)); + + let comparator = LexicographicalComparator::try_new(&[SortColumn { + values: a1, + options: Some(SortOptions { + descending: false, + nulls_first: false, + }), + }]) + .unwrap(); + + // nulls last + // 1 vs null + assert_eq!(Ordering::Less, comparator.compare(0, 1)); + // null vs 2 + assert_eq!(Ordering::Greater, comparator.compare(1, 2)); + // null vs null + assert_eq!(Ordering::Equal, comparator.compare(1, 4)); + } } diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 76efbfb8f354..0f97f778e9c9 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -2194,7 +2194,7 @@ mod tests { for j in 0..len { let row_i = rows.row(i); let row_j = rows.row(j); - let row_cmp: Compare = row_i.cmp(&row_j).into(); + let row_cmp = row_i.cmp(&row_j); let lex_cmp = comparator.compare(i, j); assert_eq!( row_cmp, From 85bfbce4c3cf1af52affe7960b88e1abedd09198 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 21 Apr 2024 10:09:34 +0800 Subject: [PATCH 5/5] revert submodule Signed-off-by: jayzhan211 --- parquet-testing | 2 +- testing | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/parquet-testing b/parquet-testing index 89b685a64c31..4cb3cff24c96 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit 89b685a64c3117b3023d8684af1f41400841db71 +Subproject commit 4cb3cff24c965fb329cdae763eabce47395a68a0 diff --git a/testing b/testing index d315f7985207..e270341fb5f3 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit d315f7985207d2d67fc2c8e41053e9d97d573f4b +Subproject commit e270341fb5f3ff785410e6286cc42898e9d6a99c