From c5ab64cdc16b2b37943a865f2e4e8b832d24cdf3 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 2 Mar 2024 16:14:19 -0500 Subject: [PATCH] fix: lexsort_to_indices unsupported mixed types with list (#5455) * fix: lexsort_to_indices unsupported mixed types with list * chore: pass clippy --------- Co-authored-by: JasonLi --- arrow-ord/src/sort.rs | 313 +++++++++++++++++++++++++++++++++++++-- arrow/benches/lexsort.rs | 59 ++++++++ 2 files changed, 363 insertions(+), 9 deletions(-) diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index c25df3a480b3..2c06057a84e0 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -766,18 +766,61 @@ impl LexicographicalComparator { pub fn try_new(columns: &[SortColumn]) -> Result { let compare_items = columns .iter() - .map(|column| { - // flatten and convert build comparators - let values = column.values.as_ref(); - Ok(( - values.logical_nulls(), - build_compare(values, values)?, - column.options.unwrap_or_default(), - )) - }) + .map(Self::build_compare_item) .collect::, ArrowError>>()?; Ok(LexicographicalComparator { compare_items }) } + + fn build_compare_item(column: &SortColumn) -> Result { + let values = column.values.as_ref(); + let options = column.options.unwrap_or_default(); + let comparator = match values.data_type() { + DataType::List(_) => Self::build_list_compare(values.as_list::(), options)?, + DataType::LargeList(_) => Self::build_list_compare(values.as_list::(), options)?, + DataType::FixedSizeList(_, _) => { + Self::build_fixed_size_list_compare(values.as_fixed_size_list(), options)? + } + _ => build_compare(values, values)?, + }; + Ok((values.logical_nulls(), comparator, options)) + } + + fn build_list_compare( + array: &GenericListArray, + options: SortOptions, + ) -> Result { + let rank = child_rank(array.values().as_ref(), options)?; + let offsets = array.offsets().clone(); + let cmp = Box::new(move |i: usize, j: usize| { + macro_rules! nth_value { + ($INDEX:expr) => {{ + let end = offsets[$INDEX + 1].as_usize(); + let start = offsets[$INDEX].as_usize(); + &rank[start..end] + }}; + } + Ord::cmp(nth_value!(i), nth_value!(j)) + }); + Ok(cmp) + } + + fn build_fixed_size_list_compare( + array: &FixedSizeListArray, + options: SortOptions, + ) -> Result { + let rank = child_rank(array.values().as_ref(), options)?; + let size = array.value_length() as usize; + let cmp = Box::new(move |i: usize, j: usize| { + macro_rules! nth_value { + ($INDEX:expr) => {{ + let start = $INDEX * size; + &rank[start..start + size] + }}; + } + Ord::cmp(nth_value!(i), nth_value!(j)) + }); + Ok(cmp) + } } #[cfg(test)] @@ -3592,6 +3635,258 @@ mod tests { // Limiting by more rows than present is ok test_lex_sort_arrays(input, slice_arrays(expected, 0, 5), Some(10)); + + // test with FixedSizeListArray, arrays order: [UInt32, FixedSizeList(UInt32, 1)] + + // case1 + let primitive_array_data = vec![ + Some(2), + Some(3), + Some(2), + Some(0), + None, + Some(2), + Some(1), + Some(2), + ]; + let list_array_data = vec![ + None, + Some(vec![Some(4)]), + Some(vec![Some(3)]), + Some(vec![Some(1)]), + Some(vec![Some(5)]), + Some(vec![Some(0)]), + Some(vec![Some(2)]), + Some(vec![None]), + ]; + + let expected_primitive_array_data = vec![ + None, + Some(0), + Some(1), + Some(2), + Some(2), + Some(2), + Some(2), + Some(3), + ]; + let expected_list_array_data = vec![ + Some(vec![Some(5)]), + Some(vec![Some(1)]), + Some(vec![Some(2)]), + None, // <- + Some(vec![None]), + Some(vec![Some(0)]), + Some(vec![Some(3)]), // <- + Some(vec![Some(4)]), + ]; + test_lex_sort_mixed_types_with_fixed_size_list::( + primitive_array_data.clone(), + list_array_data.clone(), + expected_primitive_array_data.clone(), + expected_list_array_data, + None, + None, + ); + + // case2 + let primitive_array_options = SortOptions { + descending: false, + nulls_first: true, + }; + let list_array_options = SortOptions { + descending: false, + nulls_first: false, // has been modified + }; + let expected_list_array_data = vec![ + Some(vec![Some(5)]), + Some(vec![Some(1)]), + Some(vec![Some(2)]), + Some(vec![Some(0)]), // <- + Some(vec![Some(3)]), + Some(vec![None]), + None, // <- + Some(vec![Some(4)]), + ]; + test_lex_sort_mixed_types_with_fixed_size_list::( + primitive_array_data.clone(), + list_array_data.clone(), + expected_primitive_array_data.clone(), + expected_list_array_data, + Some(primitive_array_options), + Some(list_array_options), + ); + + // case3 + let primitive_array_options = SortOptions { + descending: false, + nulls_first: true, + }; + let list_array_options = SortOptions { + descending: true, // has been modified + nulls_first: true, + }; + let expected_list_array_data = vec![ + Some(vec![Some(5)]), + Some(vec![Some(1)]), + Some(vec![Some(2)]), + None, // <- + Some(vec![None]), + Some(vec![Some(3)]), + Some(vec![Some(0)]), // <- + Some(vec![Some(4)]), + ]; + test_lex_sort_mixed_types_with_fixed_size_list::( + primitive_array_data.clone(), + list_array_data.clone(), + expected_primitive_array_data, + expected_list_array_data, + Some(primitive_array_options), + Some(list_array_options), + ); + + // test with ListArray/LargeListArray, arrays order: [List/LargeList, UInt32] + + let list_array_data = vec![ + Some(vec![Some(2), Some(1)]), // 0 + None, // 10 + Some(vec![Some(3)]), // 1 + Some(vec![Some(2), Some(0)]), // 2 + Some(vec![None, Some(2)]), // 3 + Some(vec![Some(0)]), // none + None, // 11 + Some(vec![Some(2), None]), // 4 + Some(vec![None]), // 5 + Some(vec![Some(2), Some(1)]), // 6 + ]; + let primitive_array_data = vec![ + Some(0), + Some(10), + Some(1), + Some(2), + Some(3), + None, + Some(11), + Some(4), + Some(5), + Some(6), + ]; + let expected_list_array_data = vec![ + None, + None, + Some(vec![None]), + Some(vec![None, Some(2)]), + Some(vec![Some(0)]), + Some(vec![Some(2), None]), + Some(vec![Some(2), Some(0)]), + Some(vec![Some(2), Some(1)]), + Some(vec![Some(2), Some(1)]), + Some(vec![Some(3)]), + ]; + let expected_primitive_array_data = vec![ + Some(10), + Some(11), + Some(5), + Some(3), + None, + Some(4), + Some(2), + Some(0), + Some(6), + Some(1), + ]; + test_lex_sort_mixed_types_with_list::( + list_array_data.clone(), + primitive_array_data.clone(), + expected_list_array_data, + expected_primitive_array_data, + None, + None, + ); + } + + fn test_lex_sort_mixed_types_with_fixed_size_list( + primitive_array_data: Vec>, + list_array_data: Vec>>>, + expected_primitive_array_data: Vec>, + expected_list_array_data: Vec>>>, + primitive_array_options: Option, + list_array_options: Option, + ) where + T: ArrowPrimitiveType, + PrimitiveArray: From>>, + { + let input = vec![ + SortColumn { + values: Arc::new(PrimitiveArray::::from(primitive_array_data.clone())) + as ArrayRef, + options: primitive_array_options, + }, + SortColumn { + values: Arc::new(FixedSizeListArray::from_iter_primitive::( + list_array_data.clone(), + 1, + )) as ArrayRef, + options: list_array_options, + }, + ]; + + let expected = vec![ + Arc::new(PrimitiveArray::::from( + expected_primitive_array_data.clone(), + )) as ArrayRef, + Arc::new(FixedSizeListArray::from_iter_primitive::( + expected_list_array_data.clone(), + 1, + )) as ArrayRef, + ]; + + test_lex_sort_arrays(input.clone(), expected.clone(), None); + test_lex_sort_arrays(input.clone(), slice_arrays(expected.clone(), 0, 5), Some(5)); + } + + fn test_lex_sort_mixed_types_with_list( + list_array_data: Vec>>>, + primitive_array_data: Vec>, + expected_list_array_data: Vec>>>, + expected_primitive_array_data: Vec>, + list_array_options: Option, + primitive_array_options: Option, + ) where + T: ArrowPrimitiveType, + PrimitiveArray: From>>, + { + macro_rules! run_test { + ($ARRAY_TYPE:ident) => { + let input = vec![ + SortColumn { + values: Arc::new(<$ARRAY_TYPE>::from_iter_primitive::( + list_array_data.clone(), + )) as ArrayRef, + options: list_array_options.clone(), + }, + SortColumn { + values: Arc::new(PrimitiveArray::::from(primitive_array_data.clone())) + as ArrayRef, + options: primitive_array_options.clone(), + }, + ]; + + let expected = vec![ + Arc::new(<$ARRAY_TYPE>::from_iter_primitive::( + expected_list_array_data.clone(), + )) as ArrayRef, + Arc::new(PrimitiveArray::::from( + expected_primitive_array_data.clone(), + )) as ArrayRef, + ]; + + test_lex_sort_arrays(input.clone(), expected.clone(), None); + test_lex_sort_arrays(input.clone(), slice_arrays(expected.clone(), 0, 5), Some(5)); + }; + } + run_test!(ListArray); + run_test!(LargeListArray); } #[test] diff --git a/arrow/benches/lexsort.rs b/arrow/benches/lexsort.rs index bd2db1e5022d..cd952299df47 100644 --- a/arrow/benches/lexsort.rs +++ b/arrow/benches/lexsort.rs @@ -20,8 +20,10 @@ use arrow::row::{RowConverter, SortField}; use arrow::util::bench_util::{ create_dict_from_values, create_primitive_array, create_string_array_with_len, }; +use arrow::util::data_gen::create_random_array; use arrow_array::types::Int32Type; use arrow_array::{Array, ArrayRef, UInt32Array}; +use arrow_schema::{DataType, Field}; use criterion::{criterion_group, criterion_main, Criterion}; use std::sync::Arc; @@ -33,6 +35,10 @@ enum Column { Optional16CharString, Optional50CharString, Optional100Value50CharStringDict, + RequiredI32List, + OptionalI32List, + Required4CharStringList, + Optional4CharStringList, } impl std::fmt::Debug for Column { @@ -44,6 +50,10 @@ impl std::fmt::Debug for Column { Column::Optional16CharString => "str_opt(16)", Column::Optional50CharString => "str_opt(50)", Column::Optional100Value50CharStringDict => "dict(100,str_opt(50))", + Column::RequiredI32List => "i32_list", + Column::OptionalI32List => "i32_list_opt", + Column::Required4CharStringList => "str_list(4)", + Column::Optional4CharStringList => "str_list_opt(4)", }; f.write_str(s) } @@ -70,6 +80,38 @@ impl Column { &create_string_array_with_len::(100, 0., 50), )) } + Column::RequiredI32List => { + let field = Field::new( + "_1", + DataType::List(Arc::new(Field::new("item", DataType::Int32, false))), + true, + ); + create_random_array(&field, size, 0., 1.).unwrap() + } + Column::OptionalI32List => { + let field = Field::new( + "_1", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ); + create_random_array(&field, size, 0.2, 1.).unwrap() + } + Column::Required4CharStringList => { + let field = Field::new( + "_1", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, false))), + true, + ); + create_random_array(&field, size, 0., 1.).unwrap() + } + Column::Optional4CharStringList => { + let field = Field::new( + "_1", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + true, + ); + create_random_array(&field, size, 0.2, 1.).unwrap() + } } } } @@ -150,6 +192,23 @@ fn add_benchmark(c: &mut Criterion) { Column::Optional100Value50CharStringDict, Column::Optional50CharString, ], + &[Column::OptionalI32, Column::RequiredI32List], + &[Column::OptionalI32, Column::OptionalI32List], + &[Column::OptionalI32List, Column::OptionalI32], + &[Column::RequiredI32, Column::Required4CharStringList], + &[Column::Required4CharStringList, Column::RequiredI32], + &[Column::RequiredI32, Column::Optional4CharStringList], + &[Column::Optional4CharStringList, Column::RequiredI32], + &[ + Column::RequiredI32, + Column::RequiredI32List, + Column::Required16CharString, + ], + &[ + Column::OptionalI32, + Column::OptionalI32List, + Column::Optional50CharString, + ], ]; for case in cases {