Skip to content

Commit

Permalink
feat: eagerly compute IsConstant stat (#1838)
Browse files Browse the repository at this point in the history
re-attempt of #1492

fixes a bug where compute functions that fall back to arrow can
erroneously return an array of length 1 if both inputs are constant
  • Loading branch information
lwwmanning authored Jan 9, 2025
1 parent bcea32a commit 93f8cb5
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 116 deletions.
10 changes: 5 additions & 5 deletions vortex-array/src/array/varbin/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use vortex_dtype::DType;
use vortex_error::{vortex_bail, VortexResult};

use crate::array::{VarBinArray, VarBinEncoding};
use crate::arrow::{Datum, FromArrowArray};
use crate::arrow::{from_arrow_array_with_len, Datum};
use crate::compute::{CompareFn, Operator};
use crate::{ArrayDType, ArrayData, IntoArrayData};
use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData};

// This implementation exists so we can have custom translation of RHS to arrow that's not the same as IntoCanonical
impl CompareFn<VarBinArray> for VarBinEncoding {
Expand All @@ -18,8 +18,8 @@ impl CompareFn<VarBinArray> for VarBinEncoding {
) -> VortexResult<Option<ArrayData>> {
if let Some(rhs_const) = rhs.as_constant() {
let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();

let lhs = Datum::try_from(lhs.clone().into_array())?;
let len = lhs.len();
let lhs = unsafe { Datum::try_new(lhs.clone().into_array())? };

// TODO(robert): Handle LargeString/Binary arrays
let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
Expand Down Expand Up @@ -48,7 +48,7 @@ impl CompareFn<VarBinArray> for VarBinEncoding {
Operator::Lte => cmp::lt_eq(&lhs, arrow_rhs)?,
};

Ok(Some(ArrayData::from_arrow(&array, nullable)))
Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
} else {
Ok(None)
}
Expand Down
50 changes: 43 additions & 7 deletions vortex-array/src/arrow/datum.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use arrow_array::{Array, ArrayRef, Datum as ArrowDatum};
use vortex_error::VortexError;
use vortex_error::{vortex_panic, VortexResult};

use crate::compute::slice;
use crate::{ArrayData, IntoCanonical};
use crate::array::ConstantArray;
use crate::arrow::FromArrowArray;
use crate::compute::{scalar_at, slice};
use crate::{ArrayData, IntoArrayData, IntoCanonical};

/// A wrapper around a generic Arrow array that can be used as a Datum in Arrow compute.
#[derive(Debug)]
Expand All @@ -11,10 +13,18 @@ pub struct Datum {
is_scalar: bool,
}

impl TryFrom<ArrayData> for Datum {
type Error = VortexError;

fn try_from(array: ArrayData) -> Result<Self, Self::Error> {
impl Datum {
/// Create a new [`Datum`] from an [`ArrayData`], which can then be passed to Arrow compute.
/// This is unsafe because it does not preserve the length of the array.
///
/// # Safety
/// The caller must ensure that the length of the array is preserved, and when processing
/// the result of the Arrow compute, must check whether the result is a scalar (Arrow array of length 1),
/// in which case it likely must be expanded to match the length of the original array.
///
/// The utility function [`from_arrow_array_with_len`] can be used to ensure that the length of the
/// result of the Arrow compute matches the length of the original array.
pub unsafe fn try_new(array: ArrayData) -> VortexResult<Self> {
if array.is_constant() {
Ok(Self {
array: slice(array, 0, 1)?.into_arrow()?,
Expand All @@ -34,3 +44,29 @@ impl ArrowDatum for Datum {
(&self.array, self.is_scalar)
}
}

/// Convert an Arrow array to an ArrayData with a specific length.
/// This is useful for compute functions that delegate to Arrow using [Datum],
/// which will return a scalar (length 1 Arrow array) if the input array is constant.
///
/// Panics if the length of the array is not 1 and also not equal to the expected length.
pub fn from_arrow_array_with_len<A>(array: A, len: usize, nullable: bool) -> VortexResult<ArrayData>
where
ArrayData: FromArrowArray<A>,
{
let array = ArrayData::from_arrow(array, nullable);
if array.len() == len {
return Ok(array);
}

if array.len() != 1 {
vortex_panic!(
"Array length mismatch, expected {} got {} for encoding {}",
len,
array.len(),
array.encoding().id()
);
}

Ok(ConstantArray::new(scalar_at(&array, 0)?, len).into_array())
}
80 changes: 36 additions & 44 deletions vortex-array/src/compute/binary_numeric.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
use std::sync::Arc;

use arrow_array::ArrayRef;
use vortex_dtype::{DType, PType};
use vortex_error::{vortex_bail, VortexError, VortexResult};
use vortex_error::{vortex_bail, VortexError, VortexExpect, VortexResult};
use vortex_scalar::{BinaryNumericOperator, Scalar};

use crate::array::ConstantArray;
use crate::arrow::{Datum, FromArrowArray};
use crate::arrow::{from_arrow_array_with_len, Datum};
use crate::encoding::Encoding;
use crate::{ArrayDType, ArrayData, IntoArrayData as _};

Expand Down Expand Up @@ -121,43 +118,15 @@ pub fn binary_numeric(
// Check if LHS supports the operation directly.
if let Some(fun) = lhs.encoding().binary_numeric_fn() {
if let Some(result) = fun.binary_numeric(lhs, rhs, op)? {
debug_assert_eq!(
result.len(),
lhs.len(),
"Numeric operation length mismatch {}",
lhs.encoding().id()
);
debug_assert_eq!(
result.dtype(),
&DType::Primitive(
PType::try_from(lhs.dtype())?,
(lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()
),
"Numeric operation dtype mismatch {}",
lhs.encoding().id()
);
check_numeric_result(&result, lhs, rhs);
return Ok(result);
}
}

// Check if RHS supports the operation directly.
if let Some(fun) = rhs.encoding().binary_numeric_fn() {
if let Some(result) = fun.binary_numeric(rhs, lhs, op.swap())? {
debug_assert_eq!(
result.len(),
lhs.len(),
"Numeric operation length mismatch {}",
rhs.encoding().id()
);
debug_assert_eq!(
result.dtype(),
&DType::Primitive(
PType::try_from(lhs.dtype())?,
(lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()
),
"Numeric operation dtype mismatch {}",
rhs.encoding().id()
);
check_numeric_result(&result, lhs, rhs);
return Ok(result);
}
}
Expand All @@ -183,20 +152,43 @@ fn arrow_numeric(
operator: BinaryNumericOperator,
) -> VortexResult<ArrayData> {
let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
let len = lhs.len();

let lhs = Datum::try_from(lhs)?;
let rhs = Datum::try_from(rhs)?;
let left = unsafe { Datum::try_new(lhs.clone())? };
let right = unsafe { Datum::try_new(rhs.clone())? };

let array = match operator {
BinaryNumericOperator::Add => arrow_arith::numeric::add(&lhs, &rhs)?,
BinaryNumericOperator::Sub => arrow_arith::numeric::sub(&lhs, &rhs)?,
BinaryNumericOperator::RSub => arrow_arith::numeric::sub(&rhs, &lhs)?,
BinaryNumericOperator::Mul => arrow_arith::numeric::mul(&lhs, &rhs)?,
BinaryNumericOperator::Div => arrow_arith::numeric::div(&lhs, &rhs)?,
BinaryNumericOperator::RDiv => arrow_arith::numeric::div(&rhs, &lhs)?,
BinaryNumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
BinaryNumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
BinaryNumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
BinaryNumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
BinaryNumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
BinaryNumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
};

Ok(ArrayData::from_arrow(Arc::new(array) as ArrayRef, nullable))
let result = from_arrow_array_with_len(array, len, nullable)?;
check_numeric_result(&result, &lhs, &rhs);
Ok(result)
}

#[inline(always)]
fn check_numeric_result(result: &ArrayData, lhs: &ArrayData, rhs: &ArrayData) {
debug_assert_eq!(
result.len(),
lhs.len(),
"Numeric operation length mismatch {}",
rhs.encoding().id()
);
debug_assert_eq!(
result.dtype(),
&DType::Primitive(
PType::try_from(lhs.dtype())
.vortex_expect("Numeric operation DType failed to convert to PType"),
(lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()
),
"Numeric operation dtype mismatch {}",
rhs.encoding().id()
);
}

#[cfg(feature = "test-harness")]
Expand Down
74 changes: 40 additions & 34 deletions vortex-array/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use vortex_dtype::{DType, Nullability};
use vortex_error::{vortex_bail, VortexError, VortexResult};
use vortex_scalar::Scalar;

use crate::arrow::{Datum, FromArrowArray};
use crate::arrow::{from_arrow_array_with_len, Datum};
use crate::encoding::Encoding;
use crate::{ArrayDType, ArrayData, Canonical, IntoArrayData};

Expand Down Expand Up @@ -130,18 +130,7 @@ pub fn compare(
.and_then(|f| f.compare(left, right, operator).transpose())
.transpose()?
{
debug_assert_eq!(
result.len(),
left.len(),
"Compare length mismatch {}",
left.encoding().id()
);
debug_assert_eq!(
result.dtype(),
&DType::Bool((left.dtype().is_nullable() || right.dtype().is_nullable()).into()),
"Compare dtype mismatch {}",
left.encoding().id()
);
check_compare_result(&result, left, right);
return Ok(result);
}

Expand All @@ -151,18 +140,7 @@ pub fn compare(
.and_then(|f| f.compare(right, left, operator.swap()).transpose())
.transpose()?
{
debug_assert_eq!(
result.len(),
left.len(),
"Compare length mismatch {}",
right.encoding().id()
);
debug_assert_eq!(
result.dtype(),
&result_dtype,
"Compare dtype mismatch {}",
right.encoding().id()
);
check_compare_result(&result, left, right);
return Ok(result);
}

Expand All @@ -178,18 +156,20 @@ pub fn compare(
}

// Fallback to arrow on canonical types
arrow_compare(left, right, operator)
let result = arrow_compare(left, right, operator)?;
check_compare_result(&result, left, right);
Ok(result)
}

/// Implementation of `CompareFn` using the Arrow crate.
pub(crate) fn arrow_compare(
lhs: &ArrayData,
rhs: &ArrayData,
fn arrow_compare(
left: &ArrayData,
right: &ArrayData,
operator: Operator,
) -> VortexResult<ArrayData> {
let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
let lhs = Datum::try_from(lhs.clone())?;
let rhs = Datum::try_from(rhs.clone())?;
let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
let lhs = unsafe { Datum::try_new(left.clone())? };
let rhs = unsafe { Datum::try_new(right.clone())? };

let array = match operator {
Operator::Eq => cmp::eq(&lhs, &rhs)?,
Expand All @@ -199,8 +179,29 @@ pub(crate) fn arrow_compare(
Operator::Lt => cmp::lt(&lhs, &rhs)?,
Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
};
from_arrow_array_with_len(&array, left.len(), nullable)
}

Ok(ArrayData::from_arrow(&array, nullable))
#[inline(always)]
fn check_compare_result(result: &ArrayData, lhs: &ArrayData, rhs: &ArrayData) {
debug_assert_eq!(
result.len(),
lhs.len(),
"CompareFn result length ({}) mismatch for left encoding {}, left len {}, right encoding {}, right len {}",
result.len(),
lhs.encoding().id(),
lhs.len(),
rhs.encoding().id(),
rhs.len()
);
debug_assert_eq!(
result.dtype(),
&DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()),
"CompareFn result dtype ({}) mismatch for left encoding {}, right encoding {}",
result.dtype(),
lhs.encoding().id(),
rhs.encoding().id(),
);
}

pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
Expand Down Expand Up @@ -312,7 +313,12 @@ mod tests {
let left = ConstantArray::new(Scalar::from(2u32), 10);
let right = ConstantArray::new(Scalar::from(10u32), 10);

let compare = compare(left, right, Operator::Gt).unwrap();
let compare = compare(left.clone(), right.clone(), Operator::Gt).unwrap();
let res = compare.as_constant().unwrap();
assert_eq!(res.as_bool().value(), Some(false));
assert_eq!(compare.len(), 10);

let compare = arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
let res = compare.as_constant().unwrap();
assert_eq!(res.as_bool().value(), Some(false));
assert_eq!(compare.len(), 10);
Expand Down
Loading

0 comments on commit 93f8cb5

Please sign in to comment.