Skip to content

Commit

Permalink
feat: optimize FoRArray compare (#1656)
Browse files Browse the repository at this point in the history
Fixes #1570 
Fixes #1420

---------

Co-authored-by: Will Manning <[email protected]>
  • Loading branch information
gatesn and lwwmanning authored Jan 7, 2025
1 parent 28a867e commit 116de09
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 2 deletions.
123 changes: 123 additions & 0 deletions encodings/fastlanes/src/for/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
use num_traits::{CheckedShr, WrappingSub};
use vortex_array::array::ConstantArray;
use vortex_array::compute::{compare, CompareFn, Operator};
use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData};
use vortex_dtype::{match_each_integer_ptype, NativePType};
use vortex_error::{vortex_err, VortexError, VortexResult};
use vortex_scalar::{PValue, PrimitiveScalar, Scalar};

use crate::{FoRArray, FoREncoding};

impl CompareFn<FoRArray> for FoREncoding {
fn compare(
&self,
lhs: &FoRArray,
rhs: &ArrayData,
operator: Operator,
) -> VortexResult<Option<ArrayData>> {
if let Some(constant) = rhs.as_constant() {
if let Ok(constant) = PrimitiveScalar::try_from(&constant) {
match_each_integer_ptype!(constant.ptype(), |$T| {
return compare_constant(lhs, constant.typed_value::<$T>(), operator);
})
}
}

Ok(None)
}
}

fn compare_constant<T>(
lhs: &FoRArray,
rhs: Option<T>,
operator: Operator,
) -> VortexResult<Option<ArrayData>>
where
T: NativePType + WrappingSub + CheckedShr,
T: TryFrom<PValue, Error = VortexError>,
Scalar: From<Option<T>>,
{
// For now, we only support equals and not equals. Comparisons are a little more fiddly to
// get right regarding how to handle overflow and the wrapping subtraction.
if !matches!(operator, Operator::Eq | Operator::NotEq) {
return Ok(None);
}

let reference = lhs.reference_scalar();
let reference = reference.as_primitive().typed_value::<T>();

// We encode the RHS into the FoR domain.
let rhs = rhs
.map(|mut rhs| {
if let Some(reference) = reference {
rhs = rhs.wrapping_sub(&reference);
}
if lhs.shift() > 0 {
rhs = rhs
.checked_shr(lhs.shift() as u32)
.ok_or_else(|| vortex_err!("Shift overflow"))?;
}
Ok::<_, VortexError>(rhs)
})
.transpose()?;

// Wrap up the RHS into a scalar and cast to the encoded DType (this will be the equivalent
// unsigned integer type).
let rhs = Scalar::from(rhs).cast(lhs.encoded().dtype())?;

compare(
lhs.encoded(),
ConstantArray::new(rhs, lhs.len()).into_array(),
operator,
)
.map(Some)
}

#[cfg(test)]
mod tests {
use arrow_buffer::BooleanBuffer;
use vortex_array::array::PrimitiveArray;
use vortex_array::validity::Validity;
use vortex_array::IntoCanonical;
use vortex_buffer::buffer;

use super::*;

#[test]
fn test_compare_constant() {
let reference = Scalar::from(10);
// 10, 30, 12
let lhs = FoRArray::try_new(
PrimitiveArray::new(buffer!(0u32, 10, 1), Validity::AllValid).into_array(),
reference,
1,
)
.unwrap();

assert_result(
compare_constant(&lhs, Some(30i32), Operator::Eq),
[false, true, false],
);
assert_result(
compare_constant(&lhs, Some(12i32), Operator::NotEq),
[true, true, false],
);
for op in [Operator::Lt, Operator::Lte, Operator::Gt, Operator::Gte] {
assert!(compare_constant(&lhs, Some(30i32), op).unwrap().is_none());
}
}

fn assert_result<T: IntoIterator<Item = bool>>(
result: VortexResult<Option<ArrayData>>,
expected: T,
) {
let result = result
.unwrap()
.unwrap()
.into_canonical()
.unwrap()
.into_bool()
.unwrap();
assert_eq!(result.boolean_buffer(), BooleanBuffer::from_iter(expected));
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
mod compare;

use std::ops::AddAssign;

use num_traits::{CheckedShl, CheckedShr, WrappingAdd, WrappingSub};
use vortex_array::compute::{
filter, scalar_at, search_sorted, slice, take, ComputeVTable, FilterFn, FilterMask, ScalarAtFn,
SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn,
filter, scalar_at, search_sorted, slice, take, CompareFn, ComputeVTable, FilterFn, FilterMask,
ScalarAtFn, SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn,
};
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{ArrayDType, ArrayData, IntoArrayData};
Expand All @@ -14,6 +16,10 @@ use vortex_scalar::{PValue, Scalar};
use crate::{FoRArray, FoREncoding};

impl ComputeVTable for FoREncoding {
fn compare_fn(&self) -> Option<&dyn CompareFn<ArrayData>> {
Some(self)
}

fn filter_fn(&self) -> Option<&dyn FilterFn<ArrayData>> {
Some(self)
}
Expand Down

0 comments on commit 116de09

Please sign in to comment.