Skip to content

Commit

Permalink
fix ordering
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Apr 21, 2024
1 parent cb16bce commit 040df2c
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 13 deletions.
69 changes: 57 additions & 12 deletions arrow-ord/src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ pub fn lexsort_to_indices(
let lexicographical_comparator = LexicographicalComparator::try_new(columns)?;
// uint32 can be sorted unstably
sort_unstable_by(&mut value_indices, len, |a, b| {
lexicographical_comparator.compare(*a, *b).ordering(true)
lexicographical_comparator.compare(*a, *b)
});

Ok(UInt32Array::from_iter_values(
Expand Down Expand Up @@ -719,47 +719,50 @@ pub struct LexicographicalComparator {

impl LexicographicalComparator {
// NOTE(review): this listing looks like a rendered diff with the +/- markers stripped.
// Where two near-identical lines follow each other, the first (`Compare`-based) line is
// presumably the removed version and the second (`Ordering`-based) line the added one.
// Confirm against the real file before editing.
/// Lexicographically compare the rows at `a_idx` and `b_idx` across the wrapped columns.
///
/// Columns are visited in the order stored in `self.compare_items`. The first column
/// that produces a non-equal result decides the outcome; a column in which both rows
/// are null, or in which the column comparator reports equality, defers to the next
/// column. If every column ties, the rows compare as equal.
///
/// When exactly one of the two rows is null in a column, the result depends only on
/// that column's `nulls_first` flag (the `descending` flag is not consulted for that
/// case): the null row orders before the non-null row when `nulls_first` is set, and
/// after it otherwise.
pub fn compare(&self, a_idx: usize, b_idx: usize) -> Compare {
pub fn compare(&self, a_idx: usize, b_idx: usize) -> Ordering {
for (nulls, comparator, sort_option) in &self.compare_items {
// A column without a null buffer is treated as valid at every index.
let (lhs_valid, rhs_valid) = match nulls {
Some(n) => (n.is_valid(a_idx), n.is_valid(b_idx)),
None => (true, true),
};

// TODO: after dict values with nulls are supported, we can compare with
// (comparator)(a_idx, b_idx) directly

match (lhs_valid, rhs_valid) {
// Both values present: delegate to the per-column comparator.
(true, true) => {
match (comparator)(a_idx, b_idx) {
// equal, move on to next column
Compare::Equal => continue,
// Non-equal: flip the result for descending columns, then return it.
// NOTE(review): `ordering(..)` is defined outside this view; confirm what its
// boolean argument means for a non-`Equal` `Compare` value.
order => {
if sort_option.descending {
return order.reverse();
return order.reverse().ordering(sort_option.nulls_first);
} else {
return order;
return order.ordering(sort_option.nulls_first);
}
}
}
}
// Left is null, right is not: null orders first only when `nulls_first` is set.
(false, true) => {
return if sort_option.nulls_first {
Compare::Less
Ordering::Less
} else {
Compare::Greater
Ordering::Greater
};
}
// Left is non-null, right is null: mirror image of the arm above.
(true, false) => {
return if sort_option.nulls_first {
Compare::Greater
Ordering::Greater
} else {
Compare::Less
Ordering::Less
};
}
// equal, move on to next column
(false, false) => continue,
}
}

// No column distinguished the two rows.
Compare::Equal
Ordering::Equal
}

/// Create a new lex comparator that will wrap the given sort columns and give comparison
Expand Down Expand Up @@ -800,6 +803,7 @@ impl LexicographicalComparator {
&rank[start..end]
}};
}

Ord::cmp(nth_value!(i), nth_value!(j)).into()
});
Ok(cmp)
Expand Down Expand Up @@ -4203,11 +4207,11 @@ mod tests {
}])
.unwrap();
// 1.cmp(NULL)
assert_eq!(comparator.compare(0, 1), Compare::Greater);
assert_eq!(comparator.compare(0, 1), Ordering::Greater);
// NULL.cmp(NULL)
assert_eq!(comparator.compare(2, 1), Compare::Equal);
assert_eq!(comparator.compare(2, 1), Ordering::Equal);
// NULL.cmp(4)
assert_eq!(comparator.compare(2, 3), Compare::Less);
assert_eq!(comparator.compare(2, 3), Ordering::Less);
}

#[test]
Expand Down Expand Up @@ -4236,4 +4240,45 @@ mod tests {
let sort_indices = sort_to_indices(&a, None, None).unwrap();
assert_eq!(sort_indices.values(), &[1, 2, 0]);
}

#[test]
fn test_build_list_compare() {
    // Int64 column whose rows 1 and 4 are null.
    let column: ArrayRef = Arc::new(Int64Array::from(vec![
        Some(1),
        None,
        Some(2),
        Some(3),
        None,
    ]));

    // No explicit sort options: the expectations below have nulls ordered first.
    let default_cmp = LexicographicalComparator::try_new(&[SortColumn {
        values: Arc::clone(&column),
        options: None,
    }])
    .unwrap();
    let nulls_first_cases = [
        (0, 1, Ordering::Greater), // 1 vs null
        (1, 2, Ordering::Less),    // null vs 2
        (1, 4, Ordering::Equal),   // null vs null
    ];
    for (lhs, rhs, expected) in nulls_first_cases {
        assert_eq!(expected, default_cmp.compare(lhs, rhs));
    }

    // Ascending with nulls last: the null/non-null outcomes flip, null/null stays equal.
    let nulls_last = SortOptions {
        descending: false,
        nulls_first: false,
    };
    let nulls_last_cmp = LexicographicalComparator::try_new(&[SortColumn {
        values: column,
        options: Some(nulls_last),
    }])
    .unwrap();
    let nulls_last_cases = [
        (0, 1, Ordering::Less),    // 1 vs null
        (1, 2, Ordering::Greater), // null vs 2
        (1, 4, Ordering::Equal),   // null vs null
    ];
    for (lhs, rhs, expected) in nulls_last_cases {
        assert_eq!(expected, nulls_last_cmp.compare(lhs, rhs));
    }
}
}
2 changes: 1 addition & 1 deletion arrow-row/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2194,7 +2194,7 @@ mod tests {
for j in 0..len {
let row_i = rows.row(i);
let row_j = rows.row(j);
let row_cmp: Compare = row_i.cmp(&row_j).into();
let row_cmp = row_i.cmp(&row_j);
let lex_cmp = comparator.compare(i, j);
assert_eq!(
row_cmp,
Expand Down

0 comments on commit 040df2c

Please sign in to comment.