Skip to content

Commit

Permalink
refactor: use a generic consistent total ordering, also for floats (#…
Browse files Browse the repository at this point in the history
…11468)

Co-authored-by: ritchie <[email protected]>
  • Loading branch information
orlp and ritchie46 authored Oct 6, 2023
1 parent f090cab commit 5dfc006
Show file tree
Hide file tree
Showing 29 changed files with 686 additions and 328 deletions.
79 changes: 8 additions & 71 deletions crates/nano-arrow/src/array/ord.rs
Original file line number Diff line number Diff line change
@@ -1,55 +1,20 @@
//! Contains functions and function factories to order values within arrays.
use std::cmp::Ordering;

use crate::array::*;
use crate::datatypes::*;
use crate::error::{Error, Result};
use crate::offset::Offset;
use crate::types::NativeType;
use crate::util::total_ord::TotalOrd;

/// Compare the values at two arbitrary indices in two arrays.
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;

/// Compares two `f32` values using the IEEE 754 `totalOrder` predicate.
///
/// Unlike `partial_cmp`, this never fails: every bit pattern has a defined
/// position. The resulting order is: negative NaN, negative infinity,
/// negative finite numbers, `-0.0`, `+0.0`, positive finite numbers,
/// positive infinity, positive NaN.
///
/// Returns the [`std::cmp::Ordering`] of `l` relative to `r`.
// Delegates to the standard library's `f32::total_cmp`, which is the
// implementation this function used to copy by hand.
#[inline]
pub fn total_cmp_f32(l: &f32, r: &f32) -> std::cmp::Ordering {
    l.total_cmp(r)
}

/// Compares two `f64` values using the IEEE 754 `totalOrder` predicate.
///
/// Unlike `partial_cmp`, this never fails: every bit pattern has a defined
/// position. The resulting order is: negative NaN, negative infinity,
/// negative finite numbers, `-0.0`, `+0.0`, positive finite numbers,
/// positive infinity, positive NaN.
///
/// Returns the [`std::cmp::Ordering`] of `l` relative to `r`.
// Delegates to the standard library's `f64::total_cmp`, which is the
// implementation this function used to copy by hand.
#[inline]
pub fn total_cmp_f64(l: &f64, r: &f64) -> std::cmp::Ordering {
    l.total_cmp(r)
}

/// Total order for every native type whose Rust implementation already
/// provides one through [`Ord`].
///
/// Returns the [`std::cmp::Ordering`] of `l` relative to `r`, exactly as
/// `Ord::cmp` reports it.
#[inline]
pub fn total_cmp<T>(l: &T, r: &T) -> std::cmp::Ordering
where
    T: NativeType + Ord,
{
    Ord::cmp(l, r)
}

fn compare_primitives<T: NativeType + Ord>(left: &dyn Array, right: &dyn Array) -> DynComparator {
fn compare_primitives<T: NativeType + TotalOrd>(
left: &dyn Array,
right: &dyn Array,
) -> DynComparator {
let left = left
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
Expand All @@ -60,7 +25,7 @@ fn compare_primitives<T: NativeType + Ord>(left: &dyn Array, right: &dyn Array)
.downcast_ref::<PrimitiveArray<T>>()
.unwrap()
.clone();
Box::new(move |i, j| total_cmp(&left.value(i), &right.value(j)))
Box::new(move |i, j| left.value(i).tot_cmp(&right.value(j)))
}

fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator {
Expand All @@ -77,34 +42,6 @@ fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator {
Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
}

/// Builds a comparator over two `f32` primitive arrays.
///
/// The returned closure takes an index into `left` and an index into `right`
/// and orders the two values with `total_cmp_f32`. Both arrays are cloned
/// into the closure, so it does not borrow from the arguments.
///
/// # Panics
/// Panics if either array cannot be downcast to `PrimitiveArray<f32>`.
fn compare_f32(left: &dyn Array, right: &dyn Array) -> DynComparator {
    // Downcast `left` first, then `right`, keeping an owned copy of each.
    let [lhs, rhs] = [left, right].map(|array| {
        array
            .as_any()
            .downcast_ref::<PrimitiveArray<f32>>()
            .unwrap()
            .clone()
    });
    Box::new(move |lhs_idx, rhs_idx| {
        let a = lhs.value(lhs_idx);
        let b = rhs.value(rhs_idx);
        total_cmp_f32(&a, &b)
    })
}

/// Builds a comparator over two `f64` primitive arrays.
///
/// The returned closure takes an index into `left` and an index into `right`
/// and orders the two values with `total_cmp_f64`. Both arrays are cloned
/// into the closure, so it does not borrow from the arguments.
///
/// # Panics
/// Panics if either array cannot be downcast to `PrimitiveArray<f64>`.
fn compare_f64(left: &dyn Array, right: &dyn Array) -> DynComparator {
    // Downcast `left` first, then `right`, keeping an owned copy of each.
    let [lhs, rhs] = [left, right].map(|array| {
        array
            .as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .clone()
    });
    Box::new(move |lhs_idx, rhs_idx| {
        let a = lhs.value(lhs_idx);
        let b = rhs.value(rhs_idx);
        total_cmp_f64(&a, &b)
    })
}

fn compare_string<O: Offset>(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left
.as_any()
Expand Down Expand Up @@ -212,8 +149,8 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result<DynComparato
| (Duration(Millisecond), Duration(Millisecond))
| (Duration(Microsecond), Duration(Microsecond))
| (Duration(Nanosecond), Duration(Nanosecond)) => compare_primitives::<i64>(left, right),
(Float32, Float32) => compare_f32(left, right),
(Float64, Float64) => compare_f64(left, right),
(Float32, Float32) => compare_primitives::<f32>(left, right),
(Float64, Float64) => compare_primitives::<f64>(left, right),
(Decimal(_, _), Decimal(_, _)) => compare_primitives::<i128>(left, right),
(Utf8, Utf8) => compare_string::<i32>(left, right),
(LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
Expand Down
40 changes: 40 additions & 0 deletions crates/nano-arrow/src/bitmap/bitmask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,46 @@ impl<'a> BitMask<'a> {
None
}

/// Finds the index of the `n`th set bit strictly before `end`, scanning
/// from `end` towards the start of the mask.
///
/// Both `n` and the result are zero-indexed: `nth_set_bit_idx_rev(0, len)`
/// yields the index of the last set bit (which may be index 0). The result
/// is an absolute index from the beginning of the mask, not an offset
/// relative to `end`. Returns `None` when fewer than `n + 1` bits are set
/// before `end`.
pub fn nth_set_bit_idx_rev(&self, mut n: usize, mut end: usize) -> Option<usize> {
    while end > 0 {
        // Look at the (up to) 32 bits that finish right before `end`. When
        // fewer than 32 bits remain, the window starts at 0 and the bits at
        // positions >= end are masked out so they are never counted.
        let window_start = end.saturating_sub(32);
        let keep_bits = if end >= 32 {
            u32::MAX
        } else {
            (1 << end) - 1
        };
        let window = self.get_u32(window_start) & keep_bits;

        let consumed = if window == u32::MAX {
            // Fast path: all 32 bits are set, so the answer is pure arithmetic.
            if n < 32 {
                return Some(end - 1 - n);
            }
            32
        } else {
            let ones = window.count_ones() as usize;
            if n < ones {
                // Counting from the high end, the nth set bit is the
                // (ones - 1 - n)th set bit counted from the low end.
                let from_low = ones - 1 - n;
                let bit = unsafe {
                    // SAFETY: from_low < ones, so the window holds that many set bits.
                    nth_set_bit_u32(window, from_low as u32).unwrap_unchecked() as usize
                };
                return Some(window_start + bit);
            }
            ones
        };

        // Skip the set bits of this window and move on to the one before it.
        n -= consumed;
        end = window_start;
    }

    None
}

#[inline]
pub fn get(&self, idx: usize) -> bool {
let byte_idx = (self.offset + idx) / 8;
Expand Down
22 changes: 14 additions & 8 deletions crates/nano-arrow/src/compute/arithmetics/basic/div.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Definition of basic div operations with primitive arrays
use std::ops::Div;
use std::ops::{Add, Div};

use num_traits::{CheckedDiv, NumCast};
use strength_reduce::{
Expand Down Expand Up @@ -29,14 +29,18 @@ use crate::datatypes::PrimitiveType;
/// ```
pub fn div<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
where
T: NativeArithmetics + Div<Output = T>,
T: NativeArithmetics + Add<Output = T> + Div<Output = T>,
{
// Adding zero to the divisor turns a -0.0 divisor into +0.0, so for floats
// x/0 gives the same result (+infinity for positive x) whatever the sign of the zero.
if rhs.null_count() == 0 {
binary(lhs, rhs, lhs.data_type().clone(), |a, b| a / b)
binary(lhs, rhs, lhs.data_type().clone(), |a, b| {
a / (b + T::zeroed())
})
} else {
check_same_len(lhs, rhs).unwrap();
let values = lhs.iter().zip(rhs.iter()).map(|(l, r)| match (l, r) {
(Some(l), Some(r)) => Some(*l / *r),
(Some(l), Some(r)) => Some(*l / (*r + T::zeroed())),
_ => None,
});

Expand All @@ -61,17 +65,19 @@ where
/// ```
pub fn checked_div<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
where
T: NativeArithmetics + CheckedDiv<Output = T>,
T: NativeArithmetics + Add<Output = T> + CheckedDiv<Output = T>,
{
let op = move |a: T, b: T| a.checked_div(&b);
// Adding zero to the divisor turns a -0.0 divisor into +0.0, so for floats
// x/0 gives the same result (+infinity for positive x) whatever the sign of the zero.
let op = move |a: T, b: T| a.checked_div(&(b + T::zeroed()));

binary_checked(lhs, rhs, lhs.data_type().clone(), op)
}

// Implementation of ArrayDiv trait for PrimitiveArrays
impl<T> ArrayDiv<PrimitiveArray<T>> for PrimitiveArray<T>
where
T: NativeArithmetics + Div<Output = T>,
T: NativeArithmetics + Add<Output = T> + Div<Output = T>,
{
fn div(&self, rhs: &PrimitiveArray<T>) -> Self {
div(self, rhs)
Expand All @@ -81,7 +87,7 @@ where
// Implementation of ArrayCheckedDiv trait for PrimitiveArrays
impl<T> ArrayCheckedDiv<PrimitiveArray<T>> for PrimitiveArray<T>
where
T: NativeArithmetics + CheckedDiv<Output = T>,
T: NativeArithmetics + Add<Output = T> + CheckedDiv<Output = T>,
{
fn checked_div(&self, rhs: &PrimitiveArray<T>) -> Self {
checked_div(self, rhs)
Expand Down
2 changes: 2 additions & 0 deletions crates/nano-arrow/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ pub use lexical::*;
#[cfg(feature = "benchmarks")]
#[cfg_attr(docsrs, doc(cfg(feature = "benchmarks")))]
pub mod bench_util;

pub mod total_ord;
Loading

0 comments on commit 5dfc006

Please sign in to comment.