Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interval Comparison #5180

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,7 @@ def_from_for_primitive!(Int8Type, i8);
def_from_for_primitive!(Int16Type, i16);
def_from_for_primitive!(Int32Type, i32);
def_from_for_primitive!(Int64Type, i64);
def_from_for_primitive!(Int128Type, i128);
def_from_for_primitive!(UInt8Type, u8);
def_from_for_primitive!(UInt16Type, u16);
def_from_for_primitive!(UInt32Type, u32);
Expand Down
70 changes: 70 additions & 0 deletions arrow-array/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,54 @@ macro_rules! downcast_primitive {
};
}

/// Dispatches on one or more [`arrow_schema::DataType`] expressions for comparison kernels.
///
/// This macro functions similarly to [`downcast_primitive`], but it has no arm for
/// [`arrow_schema::IntervalUnit::DayTime`] and [`arrow_schema::IntervalUnit::MonthDayNano`]
/// because they cannot be simply cast to primitive types during a comparison operation.
///
/// The arms spelled out here cover the float, decimal, `Interval(YearMonth)` and `Duration`
/// data types. The outer dispatch is delegated to [`downcast_integer`], and every data type
/// that matches none of these arms is forwarded to [`downcast_temporal`] together with the
/// caller-supplied `$p => $fallback` arms. Callers that need the two excluded interval units
/// must therefore handle them in their own fallback arms.
#[macro_export]
macro_rules! downcast_primitive_cmp {
($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => {
// `downcast_integer!` receives the arms below as its own fallback arms.
$crate::downcast_integer! {
$($data_type),+ => ($m $(, $args)*),
// `repeat_pat!` is given the pattern once plus the list of data-type expressions.
$crate::repeat_pat!(arrow_schema::DataType::Float16, $($data_type),+) => {
$m!($crate::types::Float16Type $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Float32, $($data_type),+) => {
$m!($crate::types::Float32Type $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Float64, $($data_type),+) => {
$m!($crate::types::Float64Type $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Decimal128(_, _), $($data_type),+) => {
$m!($crate::types::Decimal128Type $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Decimal256(_, _), $($data_type),+) => {
$m!($crate::types::Decimal256Type $(, $args)*)
}
// Only the YearMonth interval unit has an arm; DayTime and MonthDayNano are
// intentionally absent and end up in the catch-all arm at the bottom.
$crate::repeat_pat!(arrow_schema::DataType::Interval(arrow_schema::IntervalUnit::YearMonth), $($data_type),+) => {
$m!($crate::types::IntervalYearMonthType $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Second), $($data_type),+) => {
$m!($crate::types::DurationSecondType $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Millisecond), $($data_type),+) => {
$m!($crate::types::DurationMillisecondType $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Microsecond), $($data_type),+) => {
$m!($crate::types::DurationMicrosecondType $(, $args)*)
}
$crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Nanosecond), $($data_type),+) => {
$m!($crate::types::DurationNanosecondType $(, $args)*)
}
// Catch-all: hand the data types to `downcast_temporal!`, passing along the
// caller's fallback arms so they are tried for whatever it does not match.
_ => {
$crate::downcast_temporal! {
$($data_type),+ => ($m $(, $args)*),
$($p => $fallback,)*
}
}
}
};
}

#[macro_export]
#[doc(hidden)]
macro_rules! downcast_primitive_array_helper {
Expand Down Expand Up @@ -383,6 +431,28 @@ macro_rules! downcast_primitive_array {
};
}

/// Downcasts one or more arrays for use inside a comparison kernel.
///
/// This macro serves a similar function to [`downcast_primitive_array`], but it
/// incorporates [`downcast_primitive_cmp`]. [`downcast_primitive_cmp`] is a specialized
/// version of [`downcast_primitive`] designed specifically for comparison operations.
///
/// The first three arms only normalise the accepted input shapes (a single identifier or
/// a list of identifiers, followed by either an expression or a block) and re-invoke the
/// macro; the last arm performs the actual dispatch. The trailing `$p => $fallback` arms
/// are passed through unchanged.
#[macro_export]
macro_rules! downcast_primitive_array_cmp {
// Single array identifier with a plain expression: wrap the expression in a block.
($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => {
$crate::downcast_primitive_array_cmp!($values => {$e} $($p => $fallback)*)
};
// Parenthesised identifier list with a plain expression: wrap the expression in a block
// and re-invoke with the un-parenthesised list.
(($($values:ident),+) => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => {
$crate::downcast_primitive_array_cmp!($($values),+ => {$e} $($p => $fallback)*)
};
// Un-parenthesised identifier list with a block body: add the parentheses.
($($values:ident),+ => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
$crate::downcast_primitive_array_cmp!(($($values),+) => $e $($p => $fallback)*)
};
// Canonical form: dispatch on each array's `data_type()` through
// `downcast_primitive_cmp!`, handing `downcast_primitive_array_helper` the identifiers
// and the body as its extra arguments.
(($($values:ident),+) => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
$crate::downcast_primitive_cmp!{
$($values.data_type()),+ => ($crate::downcast_primitive_array_helper, $($values),+, $e),
$($p => $fallback,)*
}
};
}

/// Force downcast of an [`Array`], such as an [`ArrayRef`], to
/// [`PrimitiveArray<T>`], panic'ing on failure.
///
Expand Down
6 changes: 6 additions & 0 deletions arrow-array/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ make_type!(
DataType::Int64,
"A signed 64-bit integer type."
);
make_type!(
Int128Type,
i128,
DataType::Int128,
"A signed 128-bit integer type."
);
make_type!(
UInt8Type,
u8,
Expand Down
2 changes: 2 additions & 0 deletions arrow-data/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Int128
| DataType::Float16
| DataType::Float32
| DataType::Float64
Expand Down Expand Up @@ -1509,6 +1510,7 @@ pub fn layout(data_type: &DataType) -> DataTypeLayout {
DataType::Int16 => DataTypeLayout::new_fixed_width::<i16>(),
DataType::Int32 => DataTypeLayout::new_fixed_width::<i32>(),
DataType::Int64 => DataTypeLayout::new_fixed_width::<i64>(),
DataType::Int128 => DataTypeLayout::new_fixed_width::<i128>(),
DataType::UInt8 => DataTypeLayout::new_fixed_width::<u8>(),
DataType::UInt16 => DataTypeLayout::new_fixed_width::<u16>(),
DataType::UInt32 => DataTypeLayout::new_fixed_width::<u32>(),
Expand Down
1 change: 1 addition & 0 deletions arrow-data/src/equal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ fn equal_values(
DataType::Int16 => primitive_equal::<i16>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Int32 => primitive_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Int64 => primitive_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Int128 => primitive_equal::<i128>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Float32 => primitive_equal::<f32>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Float64 => primitive_equal::<f64>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Decimal128(_, _) => primitive_equal::<i128>(lhs, rhs, lhs_start, rhs_start, len),
Expand Down
3 changes: 3 additions & 0 deletions arrow-data/src/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ fn build_extend(array: &ArrayData) -> Extend {
DataType::Int16 => primitive::build_extend::<i16>(array),
DataType::Int32 => primitive::build_extend::<i32>(array),
DataType::Int64 => primitive::build_extend::<i64>(array),
DataType::Int128 => primitive::build_extend::<i128>(array),
DataType::Float32 => primitive::build_extend::<f32>(array),
DataType::Float64 => primitive::build_extend::<f64>(array),
DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => {
Expand Down Expand Up @@ -251,6 +252,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
DataType::Int16 => primitive::extend_nulls::<i16>,
DataType::Int32 => primitive::extend_nulls::<i32>,
DataType::Int64 => primitive::extend_nulls::<i64>,
DataType::Int128 => primitive::extend_nulls::<i128>,
DataType::Float32 => primitive::extend_nulls::<f32>,
DataType::Float64 => primitive::extend_nulls::<f64>,
DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => {
Expand Down Expand Up @@ -404,6 +406,7 @@ impl<'a> MutableArrayData<'a> {
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Int128
| DataType::Float16
| DataType::Float32
| DataType::Float64
Expand Down
1 change: 1 addition & 0 deletions arrow-integration-test/src/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value {
DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}),
DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}),
DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}),
DataType::Int128 => json!({"name": "int", "bitWidth": 128, "isSigned": true}),
DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}),
DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}),
DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}),
Expand Down
3 changes: 2 additions & 1 deletion arrow-ipc/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ pub(crate) fn get_fb_field_type<'a>(
children: Some(children),
}
}
Int8 | Int16 | Int32 | Int64 => {
Int8 | Int16 | Int32 | Int64 | Int128 => {
let children = fbb.create_vector(&empty_fields[..]);
let mut builder = crate::IntBuilder::new(fbb);
builder.add_is_signed(true);
Expand All @@ -508,6 +508,7 @@ pub(crate) fn get_fb_field_type<'a>(
Int16 => builder.add_bitWidth(16),
Int32 => builder.add_bitWidth(32),
Int64 => builder.add_bitWidth(64),
Int128 => builder.add_bitWidth(128),
_ => {}
};
FBFieldType {
Expand Down
155 changes: 149 additions & 6 deletions arrow-ord/src/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,20 @@
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
//!

use std::ops::Not;

use arrow_array::cast::AsArray;
use arrow_array::types::ByteArrayType;
use arrow_array::types::{
ByteArrayType, Int128Type, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType,
};
use arrow_array::{
downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum,
FixedSizeBinaryArray, GenericByteArray,
downcast_primitive_array_cmp, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray,
Datum, FixedSizeBinaryArray, GenericByteArray, PrimitiveArray,
};
use arrow_buffer::bit_util::ceil;
use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer};
use arrow_schema::ArrowError;
use arrow_schema::{ArrowError, IntervalUnit};
use arrow_select::take::take;
use std::ops::Not;

#[derive(Debug, Copy, Clone)]
enum Op {
Expand Down Expand Up @@ -206,14 +209,16 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,

// Defer computation as may not be necessary
let values = || -> BooleanBuffer {
let d = downcast_primitive_array! {
let d = downcast_primitive_array_cmp! {
(l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v),
(Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v),
(Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v, r.as_string::<i32>(), r_s, r_v),
(LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v, r.as_string::<i64>(), r_s, r_v),
(Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
(LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
(FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
(Interval(IntervalUnit::DayTime), Interval(IntervalUnit::DayTime)) => apply_interval_dt(op, l, l_s, l_v, r, r_s, r_v),
(Interval(IntervalUnit::MonthDayNano), Interval(IntervalUnit::MonthDayNano)) => apply_interval_mdn(op, l, l_s, l_v, r, r_s, r_v),
(Null, Null) => None,
_ => unreachable!(),
};
Expand Down Expand Up @@ -341,6 +346,82 @@ fn apply<T: ArrayOrd>(
}
}

/// Evaluates `op` for two `Interval(DayTime)` arrays.
///
/// Both inputs are converted twice: once with [`interval_dt_min`] and once with
/// [`interval_dt_max`]. `op` is applied to the two "min" conversions and to the two "max"
/// conversions, and the two results are merged by [`definite_comparison`].
///
/// `l_s`, `l_v`, `r_s` and `r_v` are forwarded to `apply` unchanged.
///
/// # Panics
///
/// Panics (inside the conversion helpers) if `l` or `r` is not an
/// `IntervalDayTimeType` primitive array.
fn apply_interval_dt(
    op: Op,
    l: &dyn Array,
    l_s: bool,
    l_v: Option<&dyn AnyDictionaryArray>,
    r: &dyn Array,
    r_s: bool,
    r_v: Option<&dyn AnyDictionaryArray>,
) -> Option<BooleanBuffer> {
    // Comparison of the "min" conversions.
    let lhs_lower = interval_dt_min(l);
    let rhs_lower = interval_dt_min(r);
    let lower_outcome = apply(
        op,
        lhs_lower.values().as_ref(),
        l_s,
        l_v,
        rhs_lower.values().as_ref(),
        r_s,
        r_v,
    );

    // Comparison of the "max" conversions.
    let lhs_upper = interval_dt_max(l);
    let rhs_upper = interval_dt_max(r);
    let upper_outcome = apply(
        op,
        lhs_upper.values().as_ref(),
        l_s,
        l_v,
        rhs_upper.values().as_ref(),
        r_s,
        r_v,
    );

    definite_comparison(lower_outcome, upper_outcome)
}

/// Evaluates `op` for two `Interval(MonthDayNano)` arrays.
///
/// Both inputs are converted twice: once with [`interval_mdn_min`] and once with
/// [`interval_mdn_max`]. `op` is applied to the two "min" conversions and to the two "max"
/// conversions, and the two results are merged by [`definite_comparison`].
///
/// `l_s`, `l_v`, `r_s` and `r_v` are forwarded to `apply` unchanged.
///
/// # Panics
///
/// Panics (inside the conversion helpers) if `l` or `r` is not an
/// `IntervalMonthDayNanoType` primitive array.
fn apply_interval_mdn(
    op: Op,
    l: &dyn Array,
    l_s: bool,
    l_v: Option<&dyn AnyDictionaryArray>,
    r: &dyn Array,
    r_s: bool,
    r_v: Option<&dyn AnyDictionaryArray>,
) -> Option<BooleanBuffer> {
    // Comparison of the "min" conversions.
    let lhs_lower = interval_mdn_min(l);
    let rhs_lower = interval_mdn_min(r);
    let lower_outcome = apply(
        op,
        lhs_lower.values().as_ref(),
        l_s,
        l_v,
        rhs_lower.values().as_ref(),
        r_s,
        r_v,
    );

    // Comparison of the "max" conversions.
    let lhs_upper = interval_mdn_max(l);
    let rhs_upper = interval_mdn_max(r);
    let upper_outcome = apply(
        op,
        lhs_upper.values().as_ref(),
        l_s,
        l_v,
        rhs_upper.values().as_ref(),
        r_s,
        r_v,
    );

    definite_comparison(lower_outcome, upper_outcome)
}

/// Merges the results of the "min" and "max" comparisons.
///
/// Returns `None` when either input is `None`. Otherwise returns a buffer whose bit `i`
/// is set only when bit `i` is set in both inputs; the inputs are paired with `zip`, so
/// the output has as many bits as the shorter of the two.
fn definite_comparison(
    min: Option<BooleanBuffer>,
    max: Option<BooleanBuffer>,
) -> Option<BooleanBuffer> {
    match (min, max) {
        (Some(lower), Some(upper)) => {
            // A position counts as true only if both comparisons agree on it.
            let agreed = lower.into_iter().zip(&upper).map(|(lo, hi)| lo & hi);
            Some(BooleanBuffer::from_iter(agreed))
        }
        _ => None,
    }
}

/// Perform a take operation on `buffer` with the given dictionary
fn take_bits(v: &dyn AnyDictionaryArray, buffer: BooleanBuffer) -> BooleanBuffer {
let array = take(&BooleanArray::new(buffer, None), v.keys(), None).unwrap();
Expand Down Expand Up @@ -540,6 +621,68 @@ impl<'a> ArrayOrd for &'a FixedSizeBinaryArray {
}
}

/// Converts an `Interval(DayTime)` array into an `Int64` array of "min" totals.
///
/// Each non-null value is split into its upper 32 bits (used as the day count) and its
/// lower 32 bits (used as the millisecond count); the result is
/// `days * 86_400_000 + millis`. Null slots stay null.
///
/// # Panics
///
/// Panics if `dt` is not an `IntervalDayTimeType` primitive array.
#[inline]
fn interval_dt_min(dt: &dyn Array) -> PrimitiveArray<Int64Type> {
    const MILLIS_PER_DAY: i64 = 86_400_000;

    match dt.as_primitive_opt::<IntervalDayTimeType>() {
        Some(intervals) => {
            let totals = intervals.iter().map(|slot| {
                slot.map(|packed| {
                    let days = packed >> 32;
                    // Truncate to the low 32 bits, then sign-extend back to i64.
                    let millis = i64::from(packed as i32);
                    days * MILLIS_PER_DAY + millis
                })
            });
            PrimitiveArray::<Int64Type>::from_iter(totals)
        }
        None => panic!("Invalid datatype for Interval(IntervalDayTime) comparison"),
    }
}

/// Converts an `Interval(DayTime)` array into an `Int64` array of "max" totals.
///
/// Each non-null value is split into its upper 32 bits (used as the day count) and its
/// lower 32 bits (used as the millisecond count); the result is
/// `days * (86_400_000 + 1_000) + millis`, i.e. every day is weighted with one extra
/// second compared to [`interval_dt_min`]. Null slots stay null.
///
/// NOTE(review): for a negative day count this "max" value is smaller than the
/// corresponding "min" value — confirm this is intended.
///
/// # Panics
///
/// Panics if `dt` is not an `IntervalDayTimeType` primitive array.
#[inline]
fn interval_dt_max(dt: &dyn Array) -> PrimitiveArray<Int64Type> {
    const MAX_MILLIS_PER_DAY: i64 = 86_400_000 + 1_000;

    match dt.as_primitive_opt::<IntervalDayTimeType>() {
        Some(intervals) => {
            let totals = intervals.iter().map(|slot| {
                slot.map(|packed| {
                    let days = packed >> 32;
                    // Truncate to the low 32 bits, then sign-extend back to i64.
                    let millis = i64::from(packed as i32);
                    days * MAX_MILLIS_PER_DAY + millis
                })
            });
            PrimitiveArray::<Int64Type>::from_iter(totals)
        }
        None => panic!("Invalid datatype for Interval(IntervalDayTime) comparison"),
    }
}

/// Converts an `Interval(MonthDayNano)` array into an `Int128` array of "min" totals.
///
/// Each non-null value is split into bits 96..128 (used as the month count), bits 64..96
/// (used as the day count) and the low 64 bits (used as the nanosecond count). The result
/// is `(months * 28 + days) * 86_400_000_000_000 + nanos`, computed in `i128`.
/// Null slots stay null.
///
/// # Panics
///
/// Panics if `mdn` is not an `IntervalMonthDayNanoType` primitive array.
#[inline]
fn interval_mdn_min(mdn: &dyn Array) -> PrimitiveArray<Int128Type> {
    const MIN_DAYS_PER_MONTH: i128 = 28;
    const NANOS_PER_DAY: i128 = 86_400_000_000_000;

    match mdn.as_primitive_opt::<IntervalMonthDayNanoType>() {
        Some(intervals) => {
            let totals = intervals.iter().map(|slot| {
                slot.map(|packed| {
                    // Each `as` cast keeps only the low bits of the shifted value.
                    let months = i128::from((packed >> 96) as i32);
                    let days = i128::from((packed >> 64) as i32);
                    let nanos = i128::from(packed as i64);
                    (months * MIN_DAYS_PER_MONTH + days) * NANOS_PER_DAY + nanos
                })
            });
            PrimitiveArray::<Int128Type>::from_iter(totals)
        }
        None => panic!("Invalid datatype for Interval(IntervalMonthDayNano) comparison"),
    }
}

/// Converts an `Interval(MonthDayNano)` array into an `Int128` array of "max" totals.
///
/// Each non-null value is split into bits 96..128 (used as the month count), bits 64..96
/// (used as the day count) and the low 64 bits (used as the nanosecond count). The result
/// is `(months * 31 + days) * (86_400_000_000_000 + 1_000_000_000) + nanos`, computed in
/// `i128`. Null slots stay null.
///
/// NOTE(review): for a negative month/day total this "max" value is smaller than the
/// value produced by [`interval_mdn_min`] — confirm this is intended.
///
/// # Panics
///
/// Panics if `mdn` is not an `IntervalMonthDayNanoType` primitive array.
#[inline]
fn interval_mdn_max(mdn: &dyn Array) -> PrimitiveArray<Int128Type> {
    const MAX_DAYS_PER_MONTH: i128 = 31;
    const MAX_NANOS_PER_DAY: i128 = 86_400_000_000_000 + 1_000_000_000;

    match mdn.as_primitive_opt::<IntervalMonthDayNanoType>() {
        Some(intervals) => {
            let totals = intervals.iter().map(|slot| {
                slot.map(|packed| {
                    // Each `as` cast keeps only the low bits of the shifted value.
                    let months = i128::from((packed >> 96) as i32);
                    let days = i128::from((packed >> 64) as i32);
                    let nanos = i128::from(packed as i64);
                    (months * MAX_DAYS_PER_MONTH + days) * MAX_NANOS_PER_DAY + nanos
                })
            });
            PrimitiveArray::<Int128Type>::from_iter(totals)
        }
        None => panic!("Invalid datatype for Interval(IntervalMonthDayNano) comparison"),
    }
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
Expand Down
Loading
Loading