From b2aeb392bffae7f6a1d5321bf0ddef33c18a2ee2 Mon Sep 17 00:00:00 2001 From: berkaysynnada Date: Tue, 28 Nov 2023 14:27:04 +0300 Subject: [PATCH 1/8] Interval cmp handling --- arrow-ord/src/cmp.rs | 114 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 112 insertions(+), 2 deletions(-) diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index 96f5aafd8697..1fa116693a14 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -24,10 +24,12 @@ //! use arrow_array::cast::AsArray; -use arrow_array::types::ByteArrayType; +use arrow_array::types::{ + ByteArrayType, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, +}; use arrow_array::{ downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, - Datum, FixedSizeBinaryArray, GenericByteArray, + Datum, FixedSizeBinaryArray, GenericByteArray, PrimitiveArray, }; use arrow_buffer::bit_util::ceil; use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; @@ -35,6 +37,10 @@ use arrow_schema::ArrowError; use arrow_select::take::take; use std::ops::Not; +const MILLIS_PER_DAY: i64 = 86_400_000; + +const NANOS_PER_DAY: i64 = 86_400_000_000_000; + #[derive(Debug, Copy, Clone)] enum Op { Equal, @@ -211,6 +217,13 @@ fn compare_op( ))); } + let interval_cmp = safe_cmp_for_intervals(l, r, op); + let (l, r) = if let Some((l, r)) = interval_cmp.as_ref() { + (l as &dyn Array, r as &dyn Array) + } else { + (l, r) + }; + // Defer computation as may not be necessary let values = || -> BooleanBuffer { let d = downcast_primitive_array! { @@ -551,6 +564,103 @@ impl<'a> ArrayOrd for &'a FixedSizeBinaryArray { } } +#[inline] +fn safe_cmp_for_intervals( + l: &dyn Array, + r: &dyn Array, + op: Op, +) -> Option<(PrimitiveArray, PrimitiveArray)> { + match ( + l.as_primitive_opt::(), + r.as_primitive_opt::(), + l.as_primitive_opt::(), + r.as_primitive_opt::(), + ) { + (Some(l_dt), Some(r_dt), _, _) => match op { + Op::Less | Op::LessEqual => { + let l_max = PrimitiveArray::::from( + l_dt.iter() + .map(|dt| dt.map(|dt| dt_in_millis_max(dt))) + .collect::>(), + ); + let r_min = PrimitiveArray::::from( + r_dt.iter() + .map(|dt| dt.map(|dt| dt_in_millis_min(dt))) + .collect::>(), + ); + Some((l_max, r_min)) + } + Op::Greater | Op::GreaterEqual => { + let l_min = PrimitiveArray::::from( + l_dt.iter() + .map(|dt| dt.map(|dt| dt_in_millis_min(dt))) + .collect::>(), + ); + let r_max = PrimitiveArray::::from( + r_dt.iter() + .map(|dt| dt.map(|dt| dt_in_millis_max(dt))) + .collect::>(), + ); + Some((l_min, r_max)) + } + _ => None, + }, + (_, _, Some(l_mdn), Some(r_mdn)) => match op { + Op::Less | Op::LessEqual => { + let l_max = PrimitiveArray::::from_iter( + l_mdn.iter().map(|mdn| mdn.map(|mdn| mdn_in_nanos_max(mdn))), + ); + let r_min = PrimitiveArray::::from_iter( + r_mdn.iter().map(|mdn| mdn.map(|mdn| mdn_in_nanos_min(mdn))), + ); + Some((l_max, r_min)) + } + Op::Greater | Op::GreaterEqual => { + let l_min = PrimitiveArray::::from_iter( + l_mdn.iter().map(|mdn| mdn.map(|mdn| mdn_in_nanos_min(mdn))), + ); + let r_max = PrimitiveArray::::from_iter( + r_mdn.iter().map(|mdn| mdn.map(|mdn| mdn_in_nanos_max(mdn))), + ); + Some((l_min, r_max)) + } + _ => None, + }, + + _ => None, + } +} + +#[inline] +fn dt_in_millis_max(dt: i64) -> i64 { + let d = dt >> 32; + let m = dt as i32 as i64; + d * 31 * (MILLIS_PER_DAY + 1_000) + m +} + +#[inline] +fn dt_in_millis_min(dt: i64) -> i64 { + let d = dt >> 32; + let m = dt as i32 as i64; + d * 28 * (MILLIS_PER_DAY) + m +} + +#[inline] +fn mdn_in_nanos_max(mdn: i128) -> i64 { + let m = (mdn 
>> 96) as i32; + let d = (mdn >> 64) as i32; + let n = mdn as i64; + ((m * 31) + d) as i64 * (NANOS_PER_DAY + 1_000_000_000) as i64 + n +} + +#[inline] +fn mdn_in_nanos_min(mdn: i128) -> i64 { + let m = (mdn >> 96) as i32; + let d = (mdn >> 64) as i32; + let n = mdn as i64; + ((m * 28) + d) as i64 * (NANOS_PER_DAY) as i64 + n +} + #[cfg(test)] mod tests { use std::sync::Arc; From 76fe5ec398157bb03e09a7f845371eab1b4efd68 Mon Sep 17 00:00:00 2001 From: berkaysynnada Date: Tue, 28 Nov 2023 15:14:18 +0300 Subject: [PATCH 2/8] test added --- arrow-ord/src/cmp.rs | 4 +-- arrow-ord/src/comparison.rs | 51 +++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index 1fa116693a14..19a16fe63a96 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -635,14 +635,14 @@ fn safe_cmp_for_intervals( fn dt_in_millis_max(dt: i64) -> i64 { let d = dt >> 32; let m = dt as i32 as i64; - d * 31 * (MILLIS_PER_DAY + 1_000) + m + d * (MILLIS_PER_DAY + 1_000) + m } #[inline] fn dt_in_millis_min(dt: i64) -> i64 { let d = dt >> 32; let m = dt as i32 as i64; - d * 28 * (MILLIS_PER_DAY) + m + d * (MILLIS_PER_DAY) + m } #[inline] diff --git a/arrow-ord/src/comparison.rs b/arrow-ord/src/comparison.rs index 4e475d8fd572..35a75444de9b 100644 --- a/arrow-ord/src/comparison.rs +++ b/arrow-ord/src/comparison.rs @@ -2183,6 +2183,57 @@ mod tests { ); } + #[test] + fn test_interval_array_unit_aware() { + let a = + IntervalDayTimeArray::from(vec![Some(IntervalDayTimeType::make_value(0, -5)),Some(IntervalDayTimeType::make_value(3, -1_000_000)),Some(IntervalDayTimeType::make_value(4, -1000)),Some(IntervalDayTimeType::make_value(10, 20)),Some(IntervalDayTimeType::make_value(1, 2))]); + let b = + IntervalDayTimeArray::from(vec![Some(IntervalDayTimeType::make_value(0, -10)),Some(IntervalDayTimeType::make_value(3, -2_000_000)),Some(IntervalDayTimeType::make_value(2, 1000)),Some(IntervalDayTimeType::make_value(5, 6)),Some(IntervalDayTimeType::make_value(1, 1))]); + let res = gt(&a, &b).unwrap(); + let res_eq = gt_eq(&a, &b).unwrap(); + assert_eq!(res, res_eq); + assert_eq!( + &res, + &BooleanArray::from( + vec![ Some(true), Some(true), Some(true), Some(true), Some(false)] + ) + ); + let res = lt(&b, &a).unwrap(); + let res_eq = lt_eq(&b, &a).unwrap(); + assert_eq!(res, res_eq); + assert_eq!( + &res, + &BooleanArray::from( + vec![ Some(true), Some(true), Some(true), Some(true), Some(false)] + ) + ); + + let a = IntervalMonthDayNanoArray::from( + vec![Some(IntervalMonthDayNanoType::make_value(0, 0, 1)),Some(IntervalMonthDayNanoType::make_value(0, 1, -1_000_000_000)),Some(IntervalMonthDayNanoType::make_value(3, 2, -100_000_000_000)),Some(IntervalMonthDayNanoType::make_value(0, 1, 1)),Some(IntervalMonthDayNanoType::make_value(1, 28, 0)), Some(IntervalMonthDayNanoType::make_value(10, 0, -1_000_000_000_000))], + ); + let b = IntervalMonthDayNanoArray::from( + vec![Some(IntervalMonthDayNanoType::make_value(0, 0,0)),Some(IntervalMonthDayNanoType::make_value(0, 1, -8_000_000_000)),Some(IntervalMonthDayNanoType::make_value(1, 25, 100_000_000_000)),Some(IntervalMonthDayNanoType::make_value(0, 1, 0)),Some(IntervalMonthDayNanoType::make_value(2, 0, 0)), Some(IntervalMonthDayNanoType::make_value(5, 150, 1_000_000_000_000))], + ); + let res = gt(&a, &b).unwrap(); + let res_eq = gt_eq(&a, &b).unwrap(); + assert_eq!(res, res_eq); + assert_eq!( + &res, + &BooleanArray::from( + vec![ Some(true), Some(true),Some(true),Some(false),Some(false), Some(false)] + ) 
+ ); + let res = lt(&b, &a).unwrap(); + let res_eq = lt_eq(&b, &a).unwrap(); + assert_eq!(res, res_eq); + assert_eq!( + &res, + &BooleanArray::from( + vec![ Some(true), Some(true),Some(true),Some(false),Some(false), Some(false)] + ) + ); + } + macro_rules! test_binary { ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { #[test] From b161947ad66f187ac78c9903fa0b9cff0b56d984 Mon Sep 17 00:00:00 2001 From: berkaysynnada Date: Tue, 28 Nov 2023 15:34:58 +0300 Subject: [PATCH 3/8] clippy fix --- arrow-ord/src/cmp.rs | 24 ++++++++++++------------ arrow-ord/src/comparison.rs | 8 ++------ 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index 94a488c5aec2..e53729b04898 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -26,8 +26,8 @@ use arrow_array::cast::AsArray; use arrow_array::types::{ByteArrayType, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType}; use arrow_array::{ - downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, - BooleanArray, Datum, FixedSizeBinaryArray, GenericByteArray, PrimitiveArray, + downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum, + FixedSizeBinaryArray, GenericByteArray, PrimitiveArray, }; use arrow_buffer::bit_util::ceil; use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; @@ -567,12 +567,12 @@ fn safe_cmp_for_intervals( Op::Less | Op::LessEqual => { let l_max = PrimitiveArray::::from( l_dt.iter() - .map(|dt| dt.map(|dt| dt_in_millis_max(dt))) + .map(|dt| dt.map(dt_in_millis_max)) .collect::>(), ); let r_min = PrimitiveArray::::from( r_dt.iter() - .map(|dt| dt.map(|dt| dt_in_millis_min(dt))) + .map(|dt| dt.map(dt_in_millis_min)) .collect::>(), ); Some((l_max, r_min)) @@ -580,12 +580,12 @@ fn safe_cmp_for_intervals( Op::Greater | Op::GreaterEqual => { let l_min = PrimitiveArray::::from( l_dt.iter() - .map(|dt| dt.map(|dt| dt_in_millis_min(dt))) + .map(|dt| dt.map(dt_in_millis_min)) .collect::>(), ); let r_max = PrimitiveArray::::from( r_dt.iter() - .map(|dt| dt.map(|dt| dt_in_millis_max(dt))) + .map(|dt| dt.map(dt_in_millis_max)) .collect::>(), ); Some((l_min, r_max)) @@ -595,19 +595,19 @@ fn safe_cmp_for_intervals( (_, _, Some(l_mdn), Some(r_mdn)) => match op { Op::Less | Op::LessEqual => { let l_max = PrimitiveArray::::from_iter( - l_mdn.iter().map(|mdn| mdn.map(|mdn| mdn_in_nanos_max(mdn))), + l_mdn.iter().map(|mdn| mdn.map(mdn_in_nanos_max)), ); let r_min = PrimitiveArray::::from_iter( - r_mdn.iter().map(|mdn| mdn.map(|mdn| mdn_in_nanos_min(mdn))), + r_mdn.iter().map(|mdn| mdn.map(mdn_in_nanos_min)), ); Some((l_max, r_min)) } Op::Greater | Op::GreaterEqual => { let l_min = PrimitiveArray::::from_iter( - l_mdn.iter().map(|mdn| mdn.map(|mdn| mdn_in_nanos_min(mdn))), + l_mdn.iter().map(|mdn| mdn.map(mdn_in_nanos_min)), ); let r_max = PrimitiveArray::::from_iter( - r_mdn.iter().map(|mdn| mdn.map(|mdn| mdn_in_nanos_max(mdn))), + r_mdn.iter().map(|mdn| mdn.map(mdn_in_nanos_max)), ); Some((l_min, r_max)) } @@ -637,7 +637,7 @@ fn mdn_in_nanos_max(mdn: i128) -> i64 { let m = (mdn >> 96) as i32; let d = (mdn >> 64) as i32; let n = mdn as i64; - ((m * 31) + d) as i64 * (NANOS_PER_DAY + 1_000_000_000) as i64 + n + ((m * 31) + d) as i64 * (NANOS_PER_DAY + 1_000_000_000) + n } #[inline] @@ -645,7 +645,7 @@ fn mdn_in_nanos_min(mdn: i128) -> i64 { let m = (mdn >> 96) as i32; let d = (mdn >> 64) as i32; let n = mdn as i64; - ((m * 28) + d) as i64 * (NANOS_PER_DAY) as i64 + n + ((m * 28) + d) as i64 
* (NANOS_PER_DAY) + n } #[cfg(test)] diff --git a/arrow-ord/src/comparison.rs b/arrow-ord/src/comparison.rs index 7268156dea95..3c2b67a18fbd 100644 --- a/arrow-ord/src/comparison.rs +++ b/arrow-ord/src/comparison.rs @@ -2047,18 +2047,14 @@ mod tests { assert_eq!(res, res_eq); assert_eq!( &res, - &BooleanArray::from( - vec![ Some(true), Some(true), Some(true), Some(true), Some(false)] - ) + &BooleanArray::from(vec![ Some(true), Some(true), Some(true), Some(true), Some(false)]) ); let res = lt(&b, &a).unwrap(); let res_eq = lt_eq(&b, &a).unwrap(); assert_eq!(res, res_eq); assert_eq!( &res, - &BooleanArray::from( - vec![ Some(true), Some(true), Some(true), Some(true), Some(false)] - ) + &BooleanArray::from(vec![ Some(true), Some(true), Some(true), Some(true), Some(false)]) ); let a = IntervalMonthDayNanoArray::from( From 248942bd5c592f4069ece19be073a6492ad6fe2f Mon Sep 17 00:00:00 2001 From: berkaysynnada Date: Wed, 6 Dec 2023 17:44:32 +0300 Subject: [PATCH 4/8] dirty commit to check diff --- arrow-array/src/array/primitive_array.rs | 1 + arrow-array/src/cast.rs | 64 ++++++++ arrow-array/src/types.rs | 104 +++++++++---- arrow-data/src/data.rs | 2 + arrow-data/src/equal/mod.rs | 1 + arrow-data/src/transform/mod.rs | 3 + arrow-integration-test/src/datatype.rs | 1 + arrow-ipc/src/convert.rs | 3 +- arrow-ord/src/cmp.rs | 131 +++++++--------- arrow-schema/src/datatype.rs | 4 + arrow-schema/src/field.rs | 1 + parquet/src/arrow/schema/mod.rs | 184 +++++++---------------- 12 files changed, 264 insertions(+), 235 deletions(-) diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 1112acacfcd9..a4192f2ad2f5 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -1180,6 +1180,7 @@ def_from_for_primitive!(Int8Type, i8); def_from_for_primitive!(Int16Type, i16); def_from_for_primitive!(Int32Type, i32); def_from_for_primitive!(Int64Type, i64); +def_from_for_primitive!(Int128Type, i128); def_from_for_primitive!(UInt8Type, u8); def_from_for_primitive!(UInt16Type, u16); def_from_for_primitive!(UInt32Type, u32); diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs index 2e21f3e7e640..11bf0090c29a 100644 --- a/arrow-array/src/cast.rs +++ b/arrow-array/src/cast.rs @@ -330,6 +330,51 @@ macro_rules! downcast_primitive { }; } +#[macro_export] +macro_rules! downcast_primitive_cmp { + ($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_integer! 
{ + $($data_type),+ => ($m $(, $args)*), + $crate::repeat_pat!(arrow_schema::DataType::Float16, $($data_type),+) => { + $m!($crate::types::Float16Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Float32, $($data_type),+) => { + $m!($crate::types::Float32Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Float64, $($data_type),+) => { + $m!($crate::types::Float64Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Decimal128(_, _), $($data_type),+) => { + $m!($crate::types::Decimal128Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Decimal256(_, _), $($data_type),+) => { + $m!($crate::types::Decimal256Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Interval(arrow_schema::IntervalUnit::YearMonth), $($data_type),+) => { + $m!($crate::types::IntervalYearMonthType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Second), $($data_type),+) => { + $m!($crate::types::DurationSecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Millisecond), $($data_type),+) => { + $m!($crate::types::DurationMillisecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Microsecond), $($data_type),+) => { + $m!($crate::types::DurationMicrosecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Nanosecond), $($data_type),+) => { + $m!($crate::types::DurationNanosecondType $(, $args)*) + } + _ => { + $crate::downcast_temporal! { + $($data_type),+ => ($m $(, $args)*), + $($p => $fallback,)* + } + } + } + }; +} + #[macro_export] #[doc(hidden)] macro_rules! downcast_primitive_array_helper { @@ -383,6 +428,25 @@ macro_rules! downcast_primitive_array { }; } +#[macro_export] +macro_rules! downcast_primitive_array_cmp { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_primitive_array_cmp!($values => {$e} $($p => $fallback)*) + }; + (($($values:ident),+) => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_primitive_array_cmp!($($values),+ => {$e} $($p => $fallback)*) + }; + ($($values:ident),+ => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_primitive_array_cmp!(($($values),+) => $e $($p => $fallback)*) + }; + (($($values:ident),+) => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_primitive_cmp!{ + $($values.data_type()),+ => ($crate::downcast_primitive_array_helper, $($values),+, $e), + $($p => $fallback,)* + } + }; +} + /// Force downcast of an [`Array`], such as an [`ArrayRef`], to /// [`PrimitiveArray`], panic'ing on failure. /// diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs index 16d0e822d052..db465f53505b 100644 --- a/arrow-array/src/types.rs +++ b/arrow-array/src/types.rs @@ -49,7 +49,9 @@ impl BooleanType { /// static-typed nature of rust types ([`ArrowNativeType`]) for all types that implement [`ArrowNativeType`]. /// /// [`ArrowNativeType`]: arrow_buffer::ArrowNativeType -pub trait ArrowPrimitiveType: primitive::PrimitiveTypeSealed + 'static { +pub trait ArrowPrimitiveType: + primitive::PrimitiveTypeSealed + 'static +{ /// Corresponding Rust native type for the primitive type. type Native: ArrowNativeTypeOp; @@ -74,12 +76,12 @@ mod primitive { } macro_rules! 
make_type { - ($name:ident, $native_ty:ty, $data_ty:expr, $doc_string: literal) => { + ($name:ident, $native_ty:ty, $data_ty:expr, $doc_string: literal, $custom_cmp:expr) => { #[derive(Debug)] #[doc = $doc_string] pub struct $name {} - impl ArrowPrimitiveType for $name { + impl ArrowPrimitiveType<$custom_cmp> for $name { type Native = $native_ty; const DATA_TYPE: DataType = $data_ty; } @@ -88,168 +90,208 @@ macro_rules! make_type { }; } -make_type!(Int8Type, i8, DataType::Int8, "A signed 8-bit integer type."); +make_type!( + Int8Type, + i8, + DataType::Int8, + "A signed 8-bit integer type.", + false +); make_type!( Int16Type, i16, DataType::Int16, - "A signed 16-bit integer type." + "A signed 16-bit integer type.", + false ); make_type!( Int32Type, i32, DataType::Int32, - "A signed 32-bit integer type." + "A signed 32-bit integer type.", + false ); make_type!( Int64Type, i64, DataType::Int64, - "A signed 64-bit integer type." + "A signed 64-bit integer type.", + false +); +make_type!( + Int128Type, + i128, + DataType::Int128, + "A signed 128-bit integer type.", + false ); make_type!( UInt8Type, u8, DataType::UInt8, - "An unsigned 8-bit integer type." + "An unsigned 8-bit integer type.", + false ); make_type!( UInt16Type, u16, DataType::UInt16, - "An unsigned 16-bit integer type." + "An unsigned 16-bit integer type.", + false ); make_type!( UInt32Type, u32, DataType::UInt32, - "An unsigned 32-bit integer type." + "An unsigned 32-bit integer type.", + false ); make_type!( UInt64Type, u64, DataType::UInt64, - "An unsigned 64-bit integer type." + "An unsigned 64-bit integer type.", + false ); make_type!( Float16Type, f16, DataType::Float16, - "A 16-bit floating point number type." + "A 16-bit floating point number type.", + false ); make_type!( Float32Type, f32, DataType::Float32, - "A 32-bit floating point number type." + "A 32-bit floating point number type.", + false ); make_type!( Float64Type, f64, DataType::Float64, - "A 64-bit floating point number type." + "A 64-bit floating point number type.", + false ); make_type!( TimestampSecondType, i64, DataType::Timestamp(TimeUnit::Second, None), - "A timestamp second type with an optional timezone." + "A timestamp second type with an optional timezone.", + false ); make_type!( TimestampMillisecondType, i64, DataType::Timestamp(TimeUnit::Millisecond, None), - "A timestamp millisecond type with an optional timezone." + "A timestamp millisecond type with an optional timezone.", + false ); make_type!( TimestampMicrosecondType, i64, DataType::Timestamp(TimeUnit::Microsecond, None), - "A timestamp microsecond type with an optional timezone." + "A timestamp microsecond type with an optional timezone.", + false ); make_type!( TimestampNanosecondType, i64, DataType::Timestamp(TimeUnit::Nanosecond, None), - "A timestamp nanosecond type with an optional timezone." + "A timestamp nanosecond type with an optional timezone.", + false ); make_type!( Date32Type, i32, DataType::Date32, - "A 32-bit date type representing the elapsed time since UNIX epoch in days(32 bits)." + "A 32-bit date type representing the elapsed time since UNIX epoch in days(32 bits).", + false ); make_type!( Date64Type, i64, DataType::Date64, - "A 64-bit date type representing the elapsed time since UNIX epoch in milliseconds(64 bits)." 
+ "A 64-bit date type representing the elapsed time since UNIX epoch in milliseconds(64 bits).", + false ); make_type!( Time32SecondType, i32, DataType::Time32(TimeUnit::Second), - "A 32-bit time type representing the elapsed time since midnight in seconds." + "A 32-bit time type representing the elapsed time since midnight in seconds.", + false ); make_type!( Time32MillisecondType, i32, DataType::Time32(TimeUnit::Millisecond), - "A 32-bit time type representing the elapsed time since midnight in milliseconds." + "A 32-bit time type representing the elapsed time since midnight in milliseconds.", + false ); make_type!( Time64MicrosecondType, i64, DataType::Time64(TimeUnit::Microsecond), - "A 64-bit time type representing the elapsed time since midnight in microseconds." + "A 64-bit time type representing the elapsed time since midnight in microseconds.", + false ); make_type!( Time64NanosecondType, i64, DataType::Time64(TimeUnit::Nanosecond), - "A 64-bit time type representing the elapsed time since midnight in nanoseconds." + "A 64-bit time type representing the elapsed time since midnight in nanoseconds.", + false ); make_type!( IntervalYearMonthType, i32, DataType::Interval(IntervalUnit::YearMonth), - "A “calendar” interval type in months." + "A “calendar” interval type in months.", + false ); make_type!( IntervalDayTimeType, i64, DataType::Interval(IntervalUnit::DayTime), - "A “calendar” interval type in days and milliseconds." + "A “calendar” interval type in days and milliseconds.", + false ); make_type!( IntervalMonthDayNanoType, i128, DataType::Interval(IntervalUnit::MonthDayNano), - "A “calendar” interval type in months, days, and nanoseconds." + "A “calendar” interval type in months, days, and nanoseconds.", + false ); make_type!( DurationSecondType, i64, DataType::Duration(TimeUnit::Second), - "An elapsed time type in seconds." + "An elapsed time type in seconds.", + false ); make_type!( DurationMillisecondType, i64, DataType::Duration(TimeUnit::Millisecond), - "An elapsed time type in milliseconds." + "An elapsed time type in milliseconds.", + false ); make_type!( DurationMicrosecondType, i64, DataType::Duration(TimeUnit::Microsecond), - "An elapsed time type in microseconds." + "An elapsed time type in microseconds.", + false ); make_type!( DurationNanosecondType, i64, DataType::Duration(TimeUnit::Nanosecond), - "An elapsed time type in nanoseconds." + "An elapsed time type in nanoseconds.", + false ); /// A subtype of primitive type that represents legal dictionary keys. 
diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index 10c53c549e2b..d10ac4cd5635 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -84,6 +84,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff | DataType::Int16 | DataType::Int32 | DataType::Int64 + | DataType::Int128 | DataType::Float16 | DataType::Float32 | DataType::Float64 @@ -1509,6 +1510,7 @@ pub fn layout(data_type: &DataType) -> DataTypeLayout { DataType::Int16 => DataTypeLayout::new_fixed_width::(), DataType::Int32 => DataTypeLayout::new_fixed_width::(), DataType::Int64 => DataTypeLayout::new_fixed_width::(), + DataType::Int128 => DataTypeLayout::new_fixed_width::(), DataType::UInt8 => DataTypeLayout::new_fixed_width::(), DataType::UInt16 => DataTypeLayout::new_fixed_width::(), DataType::UInt32 => DataTypeLayout::new_fixed_width::(), diff --git a/arrow-data/src/equal/mod.rs b/arrow-data/src/equal/mod.rs index b279546474a0..ba56c42704bd 100644 --- a/arrow-data/src/equal/mod.rs +++ b/arrow-data/src/equal/mod.rs @@ -74,6 +74,7 @@ fn equal_values( DataType::Int16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Int32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Int64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int128 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Float32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Float64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Decimal128(_, _) => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), diff --git a/arrow-data/src/transform/mod.rs b/arrow-data/src/transform/mod.rs index 268cf10f2326..2b243fab4ccc 100644 --- a/arrow-data/src/transform/mod.rs +++ b/arrow-data/src/transform/mod.rs @@ -209,6 +209,7 @@ fn build_extend(array: &ArrayData) -> Extend { DataType::Int16 => primitive::build_extend::(array), DataType::Int32 => primitive::build_extend::(array), DataType::Int64 => primitive::build_extend::(array), + DataType::Int128 => primitive::build_extend::(array), DataType::Float32 => primitive::build_extend::(array), DataType::Float64 => primitive::build_extend::(array), DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { @@ -251,6 +252,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { DataType::Int16 => primitive::extend_nulls::, DataType::Int32 => primitive::extend_nulls::, DataType::Int64 => primitive::extend_nulls::, + DataType::Int128 => primitive::extend_nulls::, DataType::Float32 => primitive::extend_nulls::, DataType::Float64 => primitive::extend_nulls::, DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { @@ -404,6 +406,7 @@ impl<'a> MutableArrayData<'a> { | DataType::Int16 | DataType::Int32 | DataType::Int64 + | DataType::Int128 | DataType::Float16 | DataType::Float32 | DataType::Float64 diff --git a/arrow-integration-test/src/datatype.rs b/arrow-integration-test/src/datatype.rs index 42ac71fbbd7e..3382ecfaa193 100644 --- a/arrow-integration-test/src/datatype.rs +++ b/arrow-integration-test/src/datatype.rs @@ -260,6 +260,7 @@ pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value { DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}), DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}), DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}), + DataType::Int128 => json!({"name": "int", "bitWidth": 128, 
"isSigned": true}), DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}), DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}), DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}), diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs index b290a09acf5d..42067a4ece9b 100644 --- a/arrow-ipc/src/convert.rs +++ b/arrow-ipc/src/convert.rs @@ -499,7 +499,7 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(children), } } - Int8 | Int16 | Int32 | Int64 => { + Int8 | Int16 | Int32 | Int64 | Int128 => { let children = fbb.create_vector(&empty_fields[..]); let mut builder = crate::IntBuilder::new(fbb); builder.add_is_signed(true); @@ -508,6 +508,7 @@ pub(crate) fn get_fb_field_type<'a>( Int16 => builder.add_bitWidth(16), Int32 => builder.add_bitWidth(32), Int64 => builder.add_bitWidth(64), + Int128 => builder.add_bitWidth(128), _ => {} }; FBFieldType { diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index e53729b04898..f22d69574af1 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -24,20 +24,23 @@ //! use arrow_array::cast::AsArray; -use arrow_array::types::{ByteArrayType, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType}; +use arrow_array::types::{ + ArrowPrimitiveType, ByteArrayType, Int128Type, Int64Type, IntervalDayTimeType, + IntervalMonthDayNanoType, +}; use arrow_array::{ - downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum, - FixedSizeBinaryArray, GenericByteArray, PrimitiveArray, + downcast_primitive_array, downcast_primitive_array_cmp, AnyDictionaryArray, Array, + ArrowNativeTypeOp, BooleanArray, Datum, FixedSizeBinaryArray, GenericByteArray, PrimitiveArray, }; use arrow_buffer::bit_util::ceil; use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; use arrow_schema::ArrowError; +use arrow_schema::IntervalUnit; use arrow_select::take::take; use std::ops::Not; const MILLIS_PER_DAY: i64 = 86_400_000; - -const NANOS_PER_DAY: i64 = 86_400_000_000_000; +const NANOS_PER_DAY: i128 = 86_400_000_000_000; #[derive(Debug, Copy, Clone)] enum Op { @@ -208,16 +211,12 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result BooleanBuffer { - let d = downcast_primitive_array! { + let d = downcast_primitive_array_cmp! 
{ (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v), (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v), (Utf8, Utf8) => apply(op, l.as_string::(), l_s, l_v, r.as_string::(), r_s, r_v), @@ -225,6 +224,8 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), (LargeBinary, LargeBinary) => apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v), + (Interval(IntervalUnit::DayTime), Interval(IntervalUnit::DayTime)) => apply(op, safer_interval_dt(l, op, true).values().as_ref(), l_s, l_v, safer_interval_dt(r, op, false).values().as_ref(), r_s, r_v), + (Interval(IntervalUnit::MonthDayNano), Interval(IntervalUnit::MonthDayNano)) => apply(op, safer_interval_mdn(l, op, true).values().as_ref(), l_s, l_v, safer_interval_mdn(r, op, false).values().as_ref(), r_s, r_v), (Null, Null) => None, _ => unreachable!(), }; @@ -552,69 +553,53 @@ impl<'a> ArrayOrd for &'a FixedSizeBinaryArray { } #[inline] -fn safe_cmp_for_intervals( - l: &dyn Array, - r: &dyn Array, - op: Op, -) -> Option<(PrimitiveArray, PrimitiveArray)> { - match ( - l.as_primitive_opt::(), - r.as_primitive_opt::(), - l.as_primitive_opt::(), - r.as_primitive_opt::(), - ) { - (Some(l_dt), Some(r_dt), _, _) => match op { - Op::Less | Op::LessEqual => { - let l_max = PrimitiveArray::::from( - l_dt.iter() - .map(|dt| dt.map(dt_in_millis_max)) - .collect::>(), - ); - let r_min = PrimitiveArray::::from( - r_dt.iter() - .map(|dt| dt.map(dt_in_millis_min)) - .collect::>(), - ); - Some((l_max, r_min)) +fn safer_interval_dt(dt: &dyn Array, op: Op, lhs: bool) -> PrimitiveArray { + match dt.as_primitive_opt::() { + Some(dt) => match (op, lhs) { + (Op::Less | Op::LessEqual, true) | (Op::Greater | Op::GreaterEqual, false) => { + PrimitiveArray::::from_iter(dt.iter().map(|dt| dt.map(dt_in_millis_max))) + } + (Op::Greater | Op::GreaterEqual, true) | (Op::Less | Op::LessEqual, false) => { + PrimitiveArray::::from_iter(dt.iter().map(|dt| dt.map(dt_in_millis_min))) } - Op::Greater | Op::GreaterEqual => { - let l_min = PrimitiveArray::::from( - l_dt.iter() - .map(|dt| dt.map(dt_in_millis_min)) - .collect::>(), - ); - let r_max = PrimitiveArray::::from( - r_dt.iter() - .map(|dt| dt.map(dt_in_millis_max)) - .collect::>(), - ); - Some((l_min, r_max)) + (Op::Equal, _) | (Op::NotEqual, _) => PrimitiveArray::::from_iter(dt.iter()), + _ => { + panic!( + "Invalid operator {:?} for Interval(IntervalDayTime) comparison", + op + ) } - _ => None, }, - (_, _, Some(l_mdn), Some(r_mdn)) => match op { - Op::Less | Op::LessEqual => { - let l_max = PrimitiveArray::::from_iter( - l_mdn.iter().map(|mdn| mdn.map(mdn_in_nanos_max)), - ); - let r_min = PrimitiveArray::::from_iter( - r_mdn.iter().map(|mdn| mdn.map(mdn_in_nanos_min)), - ); - Some((l_max, r_min)) + _ => { + panic!("Invalid datatype for Interval(IntervalDayTime) comparison") + } + } +} + +#[inline] +fn safer_interval_mdn(mdn: &dyn Array, op: Op, lhs: bool) -> PrimitiveArray { + match mdn.as_primitive_opt::() { + Some(mdn) => match (op, lhs) { + (Op::Less | Op::LessEqual, true) | (Op::Greater | Op::GreaterEqual, false) => { + PrimitiveArray::::from_iter( + mdn.iter().map(|mdn| mdn.map(mdn_in_nanos_max)), + ) + } + (Op::Greater | Op::GreaterEqual, true) | (Op::Less | Op::LessEqual, false) => { + PrimitiveArray::::from_iter( + mdn.iter().map(|mdn| 
mdn.map(mdn_in_nanos_min)), + ) } - Op::Greater | Op::GreaterEqual => { - let l_min = PrimitiveArray::::from_iter( - l_mdn.iter().map(|mdn| mdn.map(mdn_in_nanos_min)), - ); - let r_max = PrimitiveArray::::from_iter( - r_mdn.iter().map(|mdn| mdn.map(mdn_in_nanos_max)), - ); - Some((l_min, r_max)) + (Op::Equal, _) | (Op::NotEqual, _) => { + PrimitiveArray::::from_iter(mdn.iter()) + } + _ => { + panic!("Invalid operator for Interval(IntervalMonthDayNano) comparison") } - _ => None, }, - - _ => None, + _ => { + panic!("Invalid datatype for Interval(IntervalMonthDayNano) comparison") + } } } @@ -633,19 +618,19 @@ fn dt_in_millis_min(dt: i64) -> i64 { } #[inline] -fn mdn_in_nanos_max(mdn: i128) -> i64 { +fn mdn_in_nanos_max(mdn: i128) -> i128 { let m = (mdn >> 96) as i32; let d = (mdn >> 64) as i32; let n = mdn as i64; - ((m * 31) + d) as i64 * (NANOS_PER_DAY + 1_000_000_000) + n + ((m as i128 * 31) + d as i128) * (NANOS_PER_DAY + 1_000_000_000) + n as i128 } #[inline] -fn mdn_in_nanos_min(mdn: i128) -> i64 { +fn mdn_in_nanos_min(mdn: i128) -> i128 { let m = (mdn >> 96) as i32; let d = (mdn >> 64) as i32; let n = mdn as i64; - ((m * 28) + d) as i64 * (NANOS_PER_DAY) + n + ((m as i128 * 28) + d as i128) * (NANOS_PER_DAY) + n as i128 } #[cfg(test)] diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index b78c785ae279..f0c4e62fd753 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -50,6 +50,8 @@ pub enum DataType { Int32, /// A signed 64-bit integer. Int64, + /// A signed 128-bit integer. + Int128, /// An unsigned 8-bit integer. UInt8, /// An unsigned 16-bit integer. @@ -467,6 +469,7 @@ impl DataType { DataType::Int16 | DataType::UInt16 | DataType::Float16 => Some(2), DataType::Int32 | DataType::UInt32 | DataType::Float32 => Some(4), DataType::Int64 | DataType::UInt64 | DataType::Float64 => Some(8), + DataType::Int128 => Some(16), DataType::Timestamp(_, _) => Some(8), DataType::Date32 | DataType::Time32(_) => Some(4), DataType::Date64 | DataType::Time64(_) => Some(8), @@ -500,6 +503,7 @@ impl DataType { | DataType::Int16 | DataType::Int32 | DataType::Int64 + | DataType::Int128 | DataType::UInt8 | DataType::UInt16 | DataType::UInt32 diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index 574c024bb9b9..0bac28e6c7c0 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -459,6 +459,7 @@ impl Field { | DataType::Int16 | DataType::Int32 | DataType::Int64 + | DataType::Int128 | DataType::UInt8 | DataType::UInt16 | DataType::UInt32 diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 4c350c4b1d8c..3f619b563ff7 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -32,8 +32,7 @@ use arrow_ipc::writer; use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit}; use crate::basic::{ - ConvertedType, LogicalType, Repetition, TimeUnit as ParquetTimeUnit, - Type as PhysicalType, + ConvertedType, LogicalType, Repetition, TimeUnit as ParquetTimeUnit, Type as PhysicalType, }; use crate::errors::{ParquetError, Result}; use crate::file::{metadata::KeyValue, properties::WriterProperties}; @@ -55,11 +54,7 @@ pub fn parquet_to_arrow_schema( parquet_schema: &SchemaDescriptor, key_value_metadata: Option<&Vec>, ) -> Result { - parquet_to_arrow_schema_by_columns( - parquet_schema, - ProjectionMask::all(), - key_value_metadata, - ) + parquet_to_arrow_schema_by_columns(parquet_schema, ProjectionMask::all(), key_value_metadata) } /// Convert parquet schema to arrow schema 
including optional metadata, @@ -199,10 +194,7 @@ fn encode_arrow_schema(schema: &Schema) -> String { /// Mutates writer metadata by storing the encoded Arrow schema. /// If there is an existing Arrow schema metadata, it is replaced. -pub(crate) fn add_encoded_arrow_schema_to_metadata( - schema: &Schema, - props: &mut WriterProperties, -) { +pub(crate) fn add_encoded_arrow_schema_to_metadata(schema: &Schema, props: &mut WriterProperties) { let encoded = encode_arrow_schema(schema); let schema_kv = KeyValue { @@ -270,16 +262,15 @@ fn parse_key_value_metadata( /// Convert parquet column schema to arrow field. pub fn parquet_to_arrow_field(parquet_column: &ColumnDescriptor) -> Result { let field = complex::convert_type(&parquet_column.self_type_ptr())?; - let mut ret = Field::new( - parquet_column.name(), - field.arrow_type, - field.nullable, - ); + let mut ret = Field::new(parquet_column.name(), field.arrow_type, field.nullable); let basic_info = parquet_column.self_type().get_basic_info(); if basic_info.has_id() { let mut meta = HashMap::with_capacity(1); - meta.insert(PARQUET_FIELD_ID_META_KEY.to_string(), basic_info.id().to_string()); + meta.insert( + PARQUET_FIELD_ID_META_KEY.to_string(), + basic_info.id().to_string(), + ); ret.set_metadata(meta); } @@ -341,6 +332,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_repetition(repetition) .with_id(id) .build(), + DataType::Int128 => unreachable!(), DataType::UInt8 => Type::primitive_type_builder(name, PhysicalType::INT32) .with_logical_type(Some(LogicalType::Integer { bit_width: 8, @@ -401,15 +393,9 @@ fn arrow_to_parquet_type(field: &Field) -> Result { is_adjusted_to_u_t_c: matches!(tz, Some(z) if !z.as_ref().is_empty()), unit: match time_unit { TimeUnit::Second => unreachable!(), - TimeUnit::Millisecond => { - ParquetTimeUnit::MILLIS(Default::default()) - } - TimeUnit::Microsecond => { - ParquetTimeUnit::MICROS(Default::default()) - } - TimeUnit::Nanosecond => { - ParquetTimeUnit::NANOS(Default::default()) - } + TimeUnit::Millisecond => ParquetTimeUnit::MILLIS(Default::default()), + TimeUnit::Microsecond => ParquetTimeUnit::MICROS(Default::default()), + TimeUnit::Nanosecond => ParquetTimeUnit::NANOS(Default::default()), }, })) .with_repetition(repetition) @@ -457,9 +443,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_repetition(repetition) .with_id(id) .build(), - DataType::Duration(_) => { - Err(arrow_err!("Converting Duration to parquet not supported",)) - } + DataType::Duration(_) => Err(arrow_err!("Converting Duration to parquet not supported",)), DataType::Interval(_) => { Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_converted_type(ConvertedType::INTERVAL) @@ -481,8 +465,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_length(*length) .build() } - DataType::Decimal128(precision, scale) - | DataType::Decimal256(precision, scale) => { + DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => { // Decimal precision determines the Parquet physical type to use. 
// Following the: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#decimal let (physical_type, length) = if *precision > 1 && *precision <= 9 { @@ -529,9 +512,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { } DataType::Struct(fields) => { if fields.is_empty() { - return Err( - arrow_err!("Parquet does not support writing empty structs",), - ); + return Err(arrow_err!("Parquet does not support writing empty structs",)); } // recursively convert children to types/nodes let fields = fields @@ -621,8 +602,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("boolean", DataType::Boolean, false), @@ -660,8 +640,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("decimal1", DataType::Decimal128(4, 2), false), @@ -687,8 +666,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("binary", DataType::Binary, false), @@ -709,8 +687,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("boolean", DataType::Boolean, false), @@ -718,12 +695,9 @@ mod tests { ]); assert_eq!(&arrow_fields, converted_arrow_schema.fields()); - let converted_arrow_schema = parquet_to_arrow_schema_by_columns( - &parquet_schema, - ProjectionMask::all(), - None, - ) - .unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema_by_columns(&parquet_schema, ProjectionMask::all(), None) + .unwrap(); assert_eq!(&arrow_fields, converted_arrow_schema.fields()); } @@ -921,8 +895,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1000,8 +973,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = 
converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1095,8 +1067,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1113,8 +1084,7 @@ mod tests { Field::new("leaf1", DataType::Boolean, false), Field::new("leaf2", DataType::Int32, false), ]); - let group1_struct = - Field::new("group1", DataType::Struct(group1_fields), false); + let group1_struct = Field::new("group1", DataType::Struct(group1_fields), false); arrow_fields.push(group1_struct); let leaf3_field = Field::new("leaf3", DataType::Int64, false); @@ -1133,8 +1103,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1287,8 +1256,7 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1513,20 +1481,11 @@ mod tests { vec![ Field::new("bools", DataType::Boolean, false), Field::new("uint32", DataType::UInt32, false), - Field::new_list( - "int32", - Field::new("element", DataType::Int32, true), - false, - ), + Field::new_list("int32", Field::new("element", DataType::Int32, true), false), ], false, ), - Field::new_dictionary( - "dictionary_strings", - DataType::Int32, - DataType::Utf8, - false, - ), + Field::new_dictionary("dictionary_strings", DataType::Int32, DataType::Utf8, false), Field::new("decimal_int32", DataType::Decimal128(8, 2), false), Field::new("decimal_int64", DataType::Decimal128(16, 2), false), Field::new("decimal_fix_length", DataType::Decimal128(30, 2), false), @@ -1611,10 +1570,8 @@ mod tests { let schema = Schema::new_with_metadata( vec![ - Field::new("c1", DataType::Utf8, false).with_metadata(meta(&[ - ("Key", "Foo"), - (PARQUET_FIELD_ID_META_KEY, "2"), - ])), + Field::new("c1", DataType::Utf8, false) + .with_metadata(meta(&[("Key", "Foo"), (PARQUET_FIELD_ID_META_KEY, "2")])), Field::new("c2", DataType::Binary, false), Field::new("c3", DataType::FixedSizeBinary(3), false), Field::new("c4", DataType::Boolean, false), @@ -1632,10 +1589,7 @@ mod tests { ), Field::new( "c17", - DataType::Timestamp( - TimeUnit::Microsecond, - Some("Africa/Johannesburg".into()), - ), + DataType::Timestamp(TimeUnit::Microsecond, Some("Africa/Johannesburg".into())), false, ), Field::new( @@ -1647,10 +1601,8 @@ mod tests { Field::new("c20", DataType::Interval(IntervalUnit::YearMonth), false), Field::new_list( "c21", - Field::new("item", DataType::Boolean, true).with_metadata(meta(&[ - ("Key", "Bar"), - 
(PARQUET_FIELD_ID_META_KEY, "5"), - ])), + Field::new("item", DataType::Boolean, true) + .with_metadata(meta(&[("Key", "Bar"), (PARQUET_FIELD_ID_META_KEY, "5")])), false, ) .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "4")])), @@ -1700,10 +1652,7 @@ mod tests { // Field::new("c30", DataType::Duration(TimeUnit::Nanosecond), false), Field::new_dict( "c31", - DataType::Dictionary( - Box::new(DataType::Int32), - Box::new(DataType::Utf8), - ), + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), true, 123, true, @@ -1738,11 +1687,7 @@ mod tests { "c39", "key_value", Field::new("key", DataType::Utf8, false), - Field::new_list( - "value", - Field::new("element", DataType::Utf8, true), - true, - ), + Field::new_list("value", Field::new("element", DataType::Utf8, true), true), false, // fails to roundtrip keys_sorted true, ), @@ -1781,11 +1726,8 @@ mod tests { // write to an empty parquet file so that schema is serialized let file = tempfile::tempfile().unwrap(); - let writer = ArrowWriter::try_new( - file.try_clone().unwrap(), - Arc::new(schema.clone()), - None, - )?; + let writer = + ArrowWriter::try_new(file.try_clone().unwrap(), Arc::new(schema.clone()), None)?; writer.close()?; // read file back @@ -1844,33 +1786,23 @@ mod tests { }; let schema = Schema::new_with_metadata( vec![ - Field::new("c1", DataType::Utf8, true).with_metadata(meta(&[ - (PARQUET_FIELD_ID_META_KEY, "1"), - ])), - Field::new("c2", DataType::Utf8, true).with_metadata(meta(&[ - (PARQUET_FIELD_ID_META_KEY, "2"), - ])), + Field::new("c1", DataType::Utf8, true) + .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "1")])), + Field::new("c2", DataType::Utf8, true) + .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "2")])), ], HashMap::new(), ); - let writer = ArrowWriter::try_new( - vec![], - Arc::new(schema.clone()), - None, - )?; + let writer = ArrowWriter::try_new(vec![], Arc::new(schema.clone()), None)?; let parquet_bytes = writer.into_inner()?; - let reader = crate::file::reader::SerializedFileReader::new( - bytes::Bytes::from(parquet_bytes), - )?; + let reader = + crate::file::reader::SerializedFileReader::new(bytes::Bytes::from(parquet_bytes))?; let schema_descriptor = reader.metadata().file_metadata().schema_descr_ptr(); // don't pass metadata so field ids are read from Parquet and not from serialized Arrow schema - let arrow_schema = crate::arrow::parquet_to_arrow_schema( - &schema_descriptor, - None, - )?; + let arrow_schema = crate::arrow::parquet_to_arrow_schema(&schema_descriptor, None)?; let parq_schema_descr = crate::arrow::arrow_to_parquet_schema(&arrow_schema)?; let parq_fields = parq_schema_descr.root_schema().get_fields(); @@ -1883,19 +1815,14 @@ mod tests { #[test] fn test_arrow_schema_roundtrip_lists() -> Result<()> { - let metadata: HashMap = - [("Key".to_string(), "Value".to_string())] - .iter() - .cloned() - .collect(); + let metadata: HashMap = [("Key".to_string(), "Value".to_string())] + .iter() + .cloned() + .collect(); let schema = Schema::new_with_metadata( vec![ - Field::new_list( - "c21", - Field::new("array", DataType::Boolean, true), - false, - ), + Field::new_list("c21", Field::new("array", DataType::Boolean, true), false), Field::new( "c22", DataType::FixedSizeList( @@ -1926,11 +1853,8 @@ mod tests { // write to an empty parquet file so that schema is serialized let file = tempfile::tempfile().unwrap(); - let writer = ArrowWriter::try_new( - file.try_clone().unwrap(), - Arc::new(schema.clone()), - None, - )?; + let writer = + 
ArrowWriter::try_new(file.try_clone().unwrap(), Arc::new(schema.clone()), None)?; writer.close()?; // read file back From 73433db794fd7c5bdeaa8765e622c72cd73275e3 Mon Sep 17 00:00:00 2001 From: berkaysynnada Date: Wed, 6 Dec 2023 22:35:35 +0300 Subject: [PATCH 5/8] Clean code --- arrow-array/src/cast.rs | 6 ++ arrow-array/src/types.rs | 100 ++++++++++---------------------- arrow-ord/src/cmp.rs | 66 +++++++++++++++------ parquet/src/arrow/schema/mod.rs | 2 +- 4 files changed, 86 insertions(+), 88 deletions(-) diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs index 11bf0090c29a..99ac5df5152e 100644 --- a/arrow-array/src/cast.rs +++ b/arrow-array/src/cast.rs @@ -330,6 +330,9 @@ macro_rules! downcast_primitive { }; } +/// This macro functions similarly to [`downcast_primitive`], but it excludes +/// [`arrow_schema::IntervalUnit::DayTime`] and [`arrow_schema::IntervalUnit::MonthDayNano`] +/// because they cannot be simply cast to primitive types during a comparison operation. #[macro_export] macro_rules! downcast_primitive_cmp { ($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => { @@ -428,6 +431,9 @@ macro_rules! downcast_primitive_array { }; } +/// This macro serves a similar function to [`downcast_primitive_array`], but it +/// incorporates [`downcast_primitive_cmp`]. [`downcast_primitive_cmp`] is a specialized +/// version of [`downcast_primitive`] designed specifically for comparison operations. #[macro_export] macro_rules! downcast_primitive_array_cmp { ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs index db465f53505b..2b7c6c527d48 100644 --- a/arrow-array/src/types.rs +++ b/arrow-array/src/types.rs @@ -49,9 +49,7 @@ impl BooleanType { /// static-typed nature of rust types ([`ArrowNativeType`]) for all types that implement [`ArrowNativeType`]. /// /// [`ArrowNativeType`]: arrow_buffer::ArrowNativeType -pub trait ArrowPrimitiveType: - primitive::PrimitiveTypeSealed + 'static -{ +pub trait ArrowPrimitiveType: primitive::PrimitiveTypeSealed + 'static { /// Corresponding Rust native type for the primitive type. type Native: ArrowNativeTypeOp; @@ -76,12 +74,12 @@ mod primitive { } macro_rules! make_type { - ($name:ident, $native_ty:ty, $data_ty:expr, $doc_string: literal, $custom_cmp:expr) => { + ($name:ident, $native_ty:ty, $data_ty:expr, $doc_string: literal) => { #[derive(Debug)] #[doc = $doc_string] pub struct $name {} - impl ArrowPrimitiveType<$custom_cmp> for $name { + impl ArrowPrimitiveType for $name { type Native = $native_ty; const DATA_TYPE: DataType = $data_ty; } @@ -90,208 +88,174 @@ macro_rules! make_type { }; } -make_type!( - Int8Type, - i8, - DataType::Int8, - "A signed 8-bit integer type.", - false -); +make_type!(Int8Type, i8, DataType::Int8, "A signed 8-bit integer type."); make_type!( Int16Type, i16, DataType::Int16, - "A signed 16-bit integer type.", - false + "A signed 16-bit integer type." ); make_type!( Int32Type, i32, DataType::Int32, - "A signed 32-bit integer type.", - false + "A signed 32-bit integer type." ); make_type!( Int64Type, i64, DataType::Int64, - "A signed 64-bit integer type.", - false + "A signed 64-bit integer type." ); make_type!( Int128Type, i128, DataType::Int128, - "A signed 128-bit integer type.", - false + "A signed 128-bit integer type." ); make_type!( UInt8Type, u8, DataType::UInt8, - "An unsigned 8-bit integer type.", - false + "An unsigned 8-bit integer type." 
); make_type!( UInt16Type, u16, DataType::UInt16, - "An unsigned 16-bit integer type.", - false + "An unsigned 16-bit integer type." ); make_type!( UInt32Type, u32, DataType::UInt32, - "An unsigned 32-bit integer type.", - false + "An unsigned 32-bit integer type." ); make_type!( UInt64Type, u64, DataType::UInt64, - "An unsigned 64-bit integer type.", - false + "An unsigned 64-bit integer type." ); make_type!( Float16Type, f16, DataType::Float16, - "A 16-bit floating point number type.", - false + "A 16-bit floating point number type." ); make_type!( Float32Type, f32, DataType::Float32, - "A 32-bit floating point number type.", - false + "A 32-bit floating point number type." ); make_type!( Float64Type, f64, DataType::Float64, - "A 64-bit floating point number type.", - false + "A 64-bit floating point number type." ); make_type!( TimestampSecondType, i64, DataType::Timestamp(TimeUnit::Second, None), - "A timestamp second type with an optional timezone.", - false + "A timestamp second type with an optional timezone." ); make_type!( TimestampMillisecondType, i64, DataType::Timestamp(TimeUnit::Millisecond, None), - "A timestamp millisecond type with an optional timezone.", - false + "A timestamp millisecond type with an optional timezone." ); make_type!( TimestampMicrosecondType, i64, DataType::Timestamp(TimeUnit::Microsecond, None), - "A timestamp microsecond type with an optional timezone.", - false + "A timestamp microsecond type with an optional timezone." ); make_type!( TimestampNanosecondType, i64, DataType::Timestamp(TimeUnit::Nanosecond, None), - "A timestamp nanosecond type with an optional timezone.", - false + "A timestamp nanosecond type with an optional timezone." ); make_type!( Date32Type, i32, DataType::Date32, - "A 32-bit date type representing the elapsed time since UNIX epoch in days(32 bits).", - false + "A 32-bit date type representing the elapsed time since UNIX epoch in days(32 bits)." ); make_type!( Date64Type, i64, DataType::Date64, - "A 64-bit date type representing the elapsed time since UNIX epoch in milliseconds(64 bits).", - false + "A 64-bit date type representing the elapsed time since UNIX epoch in milliseconds(64 bits)." ); make_type!( Time32SecondType, i32, DataType::Time32(TimeUnit::Second), - "A 32-bit time type representing the elapsed time since midnight in seconds.", - false + "A 32-bit time type representing the elapsed time since midnight in seconds." ); make_type!( Time32MillisecondType, i32, DataType::Time32(TimeUnit::Millisecond), - "A 32-bit time type representing the elapsed time since midnight in milliseconds.", - false + "A 32-bit time type representing the elapsed time since midnight in milliseconds." ); make_type!( Time64MicrosecondType, i64, DataType::Time64(TimeUnit::Microsecond), - "A 64-bit time type representing the elapsed time since midnight in microseconds.", - false + "A 64-bit time type representing the elapsed time since midnight in microseconds." ); make_type!( Time64NanosecondType, i64, DataType::Time64(TimeUnit::Nanosecond), - "A 64-bit time type representing the elapsed time since midnight in nanoseconds.", - false + "A 64-bit time type representing the elapsed time since midnight in nanoseconds." ); make_type!( IntervalYearMonthType, i32, DataType::Interval(IntervalUnit::YearMonth), - "A “calendar” interval type in months.", - false + "A “calendar” interval type in months." 
); make_type!( IntervalDayTimeType, i64, DataType::Interval(IntervalUnit::DayTime), - "A “calendar” interval type in days and milliseconds.", - false + "A “calendar” interval type in days and milliseconds." ); make_type!( IntervalMonthDayNanoType, i128, DataType::Interval(IntervalUnit::MonthDayNano), - "A “calendar” interval type in months, days, and nanoseconds.", - false + "A “calendar” interval type in months, days, and nanoseconds." ); make_type!( DurationSecondType, i64, DataType::Duration(TimeUnit::Second), - "An elapsed time type in seconds.", - false + "An elapsed time type in seconds." ); make_type!( DurationMillisecondType, i64, DataType::Duration(TimeUnit::Millisecond), - "An elapsed time type in milliseconds.", - false + "An elapsed time type in milliseconds." ); make_type!( DurationMicrosecondType, i64, DataType::Duration(TimeUnit::Microsecond), - "An elapsed time type in microseconds.", - false + "An elapsed time type in microseconds." ); make_type!( DurationNanosecondType, i64, DataType::Duration(TimeUnit::Nanosecond), - "An elapsed time type in nanoseconds.", - false + "An elapsed time type in nanoseconds." ); /// A subtype of primitive type that represents legal dictionary keys. diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index f22d69574af1..ad41ba67042a 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -23,24 +23,21 @@ //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. //! +use std::ops::Not; + use arrow_array::cast::AsArray; use arrow_array::types::{ - ArrowPrimitiveType, ByteArrayType, Int128Type, Int64Type, IntervalDayTimeType, - IntervalMonthDayNanoType, + ByteArrayType, Int128Type, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, }; use arrow_array::{ - downcast_primitive_array, downcast_primitive_array_cmp, AnyDictionaryArray, Array, - ArrowNativeTypeOp, BooleanArray, Datum, FixedSizeBinaryArray, GenericByteArray, PrimitiveArray, + downcast_primitive_array_cmp, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, + Datum, FixedSizeBinaryArray, GenericByteArray, PrimitiveArray, }; use arrow_buffer::bit_util::ceil; use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; use arrow_schema::ArrowError; use arrow_schema::IntervalUnit; use arrow_select::take::take; -use std::ops::Not; - -const MILLIS_PER_DAY: i64 = 86_400_000; -const NANOS_PER_DAY: i128 = 86_400_000_000_000; #[derive(Debug, Copy, Clone)] enum Op { @@ -211,9 +208,6 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result BooleanBuffer { let d = downcast_primitive_array_cmp! { @@ -552,6 +546,20 @@ impl<'a> ArrayOrd for &'a FixedSizeBinaryArray { } } +/// Computes max or min milliseconds from a `PrimitiveArray` based on +/// the comparison operator (`op`) and operand side (`lhs`). This function is essential for +/// accurate interval comparison operations by considering the leap seconds. +/// +/// # Arguments +/// * `dt` - Reference to an array, expected to be `PrimitiveArray`. +/// * `op` - Comparison operator. +/// * `lhs` - Boolean indicating if the array is on the left-hand side of the operator. +/// +/// # Returns +/// A `PrimitiveArray` with computed milliseconds values. +/// +/// # Panics +/// If `dt` is not a `PrimitiveArray` or if an invalid operator is used. 
#[inline] fn safer_interval_dt(dt: &dyn Array, op: Op, lhs: bool) -> PrimitiveArray { match dt.as_primitive_opt::() { @@ -562,7 +570,7 @@ fn safer_interval_dt(dt: &dyn Array, op: Op, lhs: bool) -> PrimitiveArray { PrimitiveArray::::from_iter(dt.iter().map(|dt| dt.map(dt_in_millis_min))) } - (Op::Equal, _) | (Op::NotEqual, _) => PrimitiveArray::::from_iter(dt.iter()), + (Op::Equal | Op::NotEqual, _) => PrimitiveArray::::from_iter(dt.iter()), _ => { panic!( "Invalid operator {:?} for Interval(IntervalDayTime) comparison", @@ -576,6 +584,21 @@ fn safer_interval_dt(dt: &dyn Array, op: Op, lhs: bool) -> PrimitiveArray` based on +/// the comparison operator (`op`) and operand side (`lhs`). This function is crucial for +/// precise interval comparison operations involving months and days, which can result in different +/// number of nanoseconds depending on the timestamp. +/// +/// # Arguments +/// * `mdn` - Reference to an array, expected to be `PrimitiveArray`. +/// * `op` - Comparison operator. +/// * `lhs` - Boolean indicating if the array is on the left-hand side of the operator. +/// +/// # Returns +/// A `PrimitiveArray` with computed nanoseconds values. +/// +/// # Panics +/// If `mdn` is not a `PrimitiveArray` or if an invalid operator is used. #[inline] fn safer_interval_mdn(mdn: &dyn Array, op: Op, lhs: bool) -> PrimitiveArray { match mdn.as_primitive_opt::() { @@ -590,9 +613,7 @@ fn safer_interval_mdn(mdn: &dyn Array, op: Op, lhs: bool) -> PrimitiveArray { - PrimitiveArray::::from_iter(mdn.iter()) - } + (Op::Equal | Op::NotEqual, _) => PrimitiveArray::::from_iter(mdn.iter()), _ => { panic!("Invalid operator for Interval(IntervalMonthDayNano) comparison") } @@ -603,34 +624,41 @@ fn safer_interval_mdn(mdn: &dyn Array, op: Op, lhs: bool) -> PrimitiveArray i64 { let d = dt >> 32; let m = dt as i32 as i64; - d * (MILLIS_PER_DAY + 1_000) + m + d * (86_400_000 + 1_000) + m } +/// Calculates the minimum milliseconds for an `IntervalDayTimeType` interval, excluding leap seconds. #[inline] fn dt_in_millis_min(dt: i64) -> i64 { let d = dt >> 32; let m = dt as i32 as i64; - d * (MILLIS_PER_DAY) + m + d * (86_400_000) + m } +/// Calculates the maximum nanoseconds for an `IntervalMonthDayNanoType` interval, assuming +/// 31 days per month and adding extra nanoseconds for longer days. #[inline] fn mdn_in_nanos_max(mdn: i128) -> i128 { let m = (mdn >> 96) as i32; let d = (mdn >> 64) as i32; let n = mdn as i64; - ((m as i128 * 31) + d as i128) * (NANOS_PER_DAY + 1_000_000_000) + n as i128 + ((m as i128 * 31) + d as i128) * (86_400_000_000_000 + 1_000_000_000) + n as i128 } +/// Calculates the minimum nanoseconds for an `IntervalMonthDayNanoType` interval, assuming +/// 28 days per month and excluding additional nanoseconds for longer days. 
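The nanosecond bounds (the minimum variant's body follows below) apply the same idea to the i128 representation: months in the top 32 bits, days in the next 32, nanoseconds in the low 64, with each month bounded below by 28 days and above by 31, plus one leap second of slack per day in the upper bound. A standalone sketch of that arithmetic, with the packing in `main` written out explicitly as an assumption:

const NANOS_PER_DAY: i128 = 86_400_000_000_000;

// Lower bound: 28-day months, no leap seconds.
fn mdn_in_nanos_min(mdn: i128) -> i128 {
    let m = (mdn >> 96) as i32;
    let d = (mdn >> 64) as i32;
    let n = mdn as i64;
    (m as i128 * 28 + d as i128) * NANOS_PER_DAY + n as i128
}

// Upper bound: 31-day months, plus one leap second (1e9 ns) per day.
fn mdn_in_nanos_max(mdn: i128) -> i128 {
    let m = (mdn >> 96) as i32;
    let d = (mdn >> 64) as i32;
    let n = mdn as i64;
    (m as i128 * 31 + d as i128) * (NANOS_PER_DAY + 1_000_000_000) + n as i128
}

fn main() {
    // 1 month, 2 days, 5 ns packed months-high / days-middle / nanos-low.
    let mdn: i128 = (1_i128 << 96) | (2_i128 << 64) | 5;
    assert_eq!(mdn_in_nanos_min(mdn), 30 * NANOS_PER_DAY + 5);
    assert_eq!(mdn_in_nanos_max(mdn), 33 * (NANOS_PER_DAY + 1_000_000_000) + 5);
}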
#[inline] fn mdn_in_nanos_min(mdn: i128) -> i128 { let m = (mdn >> 96) as i32; let d = (mdn >> 64) as i32; let n = mdn as i64; - ((m as i128 * 28) + d as i128) * (NANOS_PER_DAY) + n as i128 + ((m as i128 * 28) + d as i128) * (86_400_000_000_000) + n as i128 } #[cfg(test)] diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 3f619b563ff7..0027f1bf314c 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -1868,4 +1868,4 @@ mod tests { fn test_get_arrow_schema_from_metadata() { assert!(get_arrow_schema_from_metadata("").is_err()); } -} +} \ No newline at end of file From 6b901fbfddb1a4dcbf55a6d592b1c5a6f5f11a9a Mon Sep 17 00:00:00 2001 From: berkaysynnada Date: Wed, 6 Dec 2023 22:42:19 +0300 Subject: [PATCH 6/8] Remove fmt diff --- parquet/src/arrow/schema/mod.rs | 185 ++++++++++++++++++++++---------- 1 file changed, 131 insertions(+), 54 deletions(-) diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 0027f1bf314c..6de81162f1ef 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -32,7 +32,8 @@ use arrow_ipc::writer; use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit}; use crate::basic::{ - ConvertedType, LogicalType, Repetition, TimeUnit as ParquetTimeUnit, Type as PhysicalType, + ConvertedType, LogicalType, Repetition, TimeUnit as ParquetTimeUnit, + Type as PhysicalType, }; use crate::errors::{ParquetError, Result}; use crate::file::{metadata::KeyValue, properties::WriterProperties}; @@ -54,7 +55,11 @@ pub fn parquet_to_arrow_schema( parquet_schema: &SchemaDescriptor, key_value_metadata: Option<&Vec>, ) -> Result { - parquet_to_arrow_schema_by_columns(parquet_schema, ProjectionMask::all(), key_value_metadata) + parquet_to_arrow_schema_by_columns( + parquet_schema, + ProjectionMask::all(), + key_value_metadata, + ) } /// Convert parquet schema to arrow schema including optional metadata, @@ -194,7 +199,10 @@ fn encode_arrow_schema(schema: &Schema) -> String { /// Mutates writer metadata by storing the encoded Arrow schema. /// If there is an existing Arrow schema metadata, it is replaced. -pub(crate) fn add_encoded_arrow_schema_to_metadata(schema: &Schema, props: &mut WriterProperties) { +pub(crate) fn add_encoded_arrow_schema_to_metadata( + schema: &Schema, + props: &mut WriterProperties, +) { let encoded = encode_arrow_schema(schema); let schema_kv = KeyValue { @@ -262,15 +270,16 @@ fn parse_key_value_metadata( /// Convert parquet column schema to arrow field. 
pub fn parquet_to_arrow_field(parquet_column: &ColumnDescriptor) -> Result { let field = complex::convert_type(&parquet_column.self_type_ptr())?; - let mut ret = Field::new(parquet_column.name(), field.arrow_type, field.nullable); + let mut ret = Field::new( + parquet_column.name(), + field.arrow_type, + field.nullable, + ); let basic_info = parquet_column.self_type().get_basic_info(); if basic_info.has_id() { let mut meta = HashMap::with_capacity(1); - meta.insert( - PARQUET_FIELD_ID_META_KEY.to_string(), - basic_info.id().to_string(), - ); + meta.insert(PARQUET_FIELD_ID_META_KEY.to_string(), basic_info.id().to_string()); ret.set_metadata(meta); } @@ -332,7 +341,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_repetition(repetition) .with_id(id) .build(), - DataType::Int128 => unreachable!(), + DataType::Int128 => unimplemented!(), DataType::UInt8 => Type::primitive_type_builder(name, PhysicalType::INT32) .with_logical_type(Some(LogicalType::Integer { bit_width: 8, @@ -393,9 +402,15 @@ fn arrow_to_parquet_type(field: &Field) -> Result { is_adjusted_to_u_t_c: matches!(tz, Some(z) if !z.as_ref().is_empty()), unit: match time_unit { TimeUnit::Second => unreachable!(), - TimeUnit::Millisecond => ParquetTimeUnit::MILLIS(Default::default()), - TimeUnit::Microsecond => ParquetTimeUnit::MICROS(Default::default()), - TimeUnit::Nanosecond => ParquetTimeUnit::NANOS(Default::default()), + TimeUnit::Millisecond => { + ParquetTimeUnit::MILLIS(Default::default()) + } + TimeUnit::Microsecond => { + ParquetTimeUnit::MICROS(Default::default()) + } + TimeUnit::Nanosecond => { + ParquetTimeUnit::NANOS(Default::default()) + } }, })) .with_repetition(repetition) @@ -443,7 +458,9 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_repetition(repetition) .with_id(id) .build(), - DataType::Duration(_) => Err(arrow_err!("Converting Duration to parquet not supported",)), + DataType::Duration(_) => { + Err(arrow_err!("Converting Duration to parquet not supported",)) + } DataType::Interval(_) => { Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_converted_type(ConvertedType::INTERVAL) @@ -465,7 +482,8 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_length(*length) .build() } - DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => { + DataType::Decimal128(precision, scale) + | DataType::Decimal256(precision, scale) => { // Decimal precision determines the Parquet physical type to use. 
// Following the: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#decimal let (physical_type, length) = if *precision > 1 && *precision <= 9 { @@ -512,7 +530,9 @@ fn arrow_to_parquet_type(field: &Field) -> Result { } DataType::Struct(fields) => { if fields.is_empty() { - return Err(arrow_err!("Parquet does not support writing empty structs",)); + return Err( + arrow_err!("Parquet does not support writing empty structs",), + ); } // recursively convert children to types/nodes let fields = fields @@ -602,7 +622,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("boolean", DataType::Boolean, false), @@ -640,7 +661,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("decimal1", DataType::Decimal128(4, 2), false), @@ -666,7 +688,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("binary", DataType::Binary, false), @@ -687,7 +710,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = Fields::from(vec![ Field::new("boolean", DataType::Boolean, false), @@ -695,9 +719,12 @@ mod tests { ]); assert_eq!(&arrow_fields, converted_arrow_schema.fields()); - let converted_arrow_schema = - parquet_to_arrow_schema_by_columns(&parquet_schema, ProjectionMask::all(), None) - .unwrap(); + let converted_arrow_schema = parquet_to_arrow_schema_by_columns( + &parquet_schema, + ProjectionMask::all(), + None, + ) + .unwrap(); assert_eq!(&arrow_fields, converted_arrow_schema.fields()); } @@ -895,7 +922,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -973,7 +1001,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = 
converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1067,7 +1096,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1084,7 +1114,8 @@ mod tests { Field::new("leaf1", DataType::Boolean, false), Field::new("leaf2", DataType::Int32, false), ]); - let group1_struct = Field::new("group1", DataType::Struct(group1_fields), false); + let group1_struct = + Field::new("group1", DataType::Struct(group1_fields), false); arrow_fields.push(group1_struct); let leaf3_field = Field::new("leaf3", DataType::Int64, false); @@ -1103,7 +1134,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1256,7 +1288,8 @@ mod tests { let parquet_group_type = parse_message_type(message_type).unwrap(); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); - let converted_arrow_schema = parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1481,11 +1514,20 @@ mod tests { vec![ Field::new("bools", DataType::Boolean, false), Field::new("uint32", DataType::UInt32, false), - Field::new_list("int32", Field::new("element", DataType::Int32, true), false), + Field::new_list( + "int32", + Field::new("element", DataType::Int32, true), + false, + ), ], false, ), - Field::new_dictionary("dictionary_strings", DataType::Int32, DataType::Utf8, false), + Field::new_dictionary( + "dictionary_strings", + DataType::Int32, + DataType::Utf8, + false, + ), Field::new("decimal_int32", DataType::Decimal128(8, 2), false), Field::new("decimal_int64", DataType::Decimal128(16, 2), false), Field::new("decimal_fix_length", DataType::Decimal128(30, 2), false), @@ -1570,8 +1612,10 @@ mod tests { let schema = Schema::new_with_metadata( vec![ - Field::new("c1", DataType::Utf8, false) - .with_metadata(meta(&[("Key", "Foo"), (PARQUET_FIELD_ID_META_KEY, "2")])), + Field::new("c1", DataType::Utf8, false).with_metadata(meta(&[ + ("Key", "Foo"), + (PARQUET_FIELD_ID_META_KEY, "2"), + ])), Field::new("c2", DataType::Binary, false), Field::new("c3", DataType::FixedSizeBinary(3), false), Field::new("c4", DataType::Boolean, false), @@ -1589,7 +1633,10 @@ mod tests { ), Field::new( "c17", - DataType::Timestamp(TimeUnit::Microsecond, Some("Africa/Johannesburg".into())), + DataType::Timestamp( + TimeUnit::Microsecond, + Some("Africa/Johannesburg".into()), + ), false, ), Field::new( @@ -1601,8 +1648,10 @@ mod tests { Field::new("c20", DataType::Interval(IntervalUnit::YearMonth), false), Field::new_list( "c21", - Field::new("item", DataType::Boolean, true) - .with_metadata(meta(&[("Key", "Bar"), (PARQUET_FIELD_ID_META_KEY, 
"5")])), + Field::new("item", DataType::Boolean, true).with_metadata(meta(&[ + ("Key", "Bar"), + (PARQUET_FIELD_ID_META_KEY, "5"), + ])), false, ) .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "4")])), @@ -1652,7 +1701,10 @@ mod tests { // Field::new("c30", DataType::Duration(TimeUnit::Nanosecond), false), Field::new_dict( "c31", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), true, 123, true, @@ -1687,7 +1739,11 @@ mod tests { "c39", "key_value", Field::new("key", DataType::Utf8, false), - Field::new_list("value", Field::new("element", DataType::Utf8, true), true), + Field::new_list( + "value", + Field::new("element", DataType::Utf8, true), + true, + ), false, // fails to roundtrip keys_sorted true, ), @@ -1726,8 +1782,11 @@ mod tests { // write to an empty parquet file so that schema is serialized let file = tempfile::tempfile().unwrap(); - let writer = - ArrowWriter::try_new(file.try_clone().unwrap(), Arc::new(schema.clone()), None)?; + let writer = ArrowWriter::try_new( + file.try_clone().unwrap(), + Arc::new(schema.clone()), + None, + )?; writer.close()?; // read file back @@ -1786,23 +1845,33 @@ mod tests { }; let schema = Schema::new_with_metadata( vec![ - Field::new("c1", DataType::Utf8, true) - .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "1")])), - Field::new("c2", DataType::Utf8, true) - .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "2")])), + Field::new("c1", DataType::Utf8, true).with_metadata(meta(&[ + (PARQUET_FIELD_ID_META_KEY, "1"), + ])), + Field::new("c2", DataType::Utf8, true).with_metadata(meta(&[ + (PARQUET_FIELD_ID_META_KEY, "2"), + ])), ], HashMap::new(), ); - let writer = ArrowWriter::try_new(vec![], Arc::new(schema.clone()), None)?; + let writer = ArrowWriter::try_new( + vec![], + Arc::new(schema.clone()), + None, + )?; let parquet_bytes = writer.into_inner()?; - let reader = - crate::file::reader::SerializedFileReader::new(bytes::Bytes::from(parquet_bytes))?; + let reader = crate::file::reader::SerializedFileReader::new( + bytes::Bytes::from(parquet_bytes), + )?; let schema_descriptor = reader.metadata().file_metadata().schema_descr_ptr(); // don't pass metadata so field ids are read from Parquet and not from serialized Arrow schema - let arrow_schema = crate::arrow::parquet_to_arrow_schema(&schema_descriptor, None)?; + let arrow_schema = crate::arrow::parquet_to_arrow_schema( + &schema_descriptor, + None, + )?; let parq_schema_descr = crate::arrow::arrow_to_parquet_schema(&arrow_schema)?; let parq_fields = parq_schema_descr.root_schema().get_fields(); @@ -1815,14 +1884,19 @@ mod tests { #[test] fn test_arrow_schema_roundtrip_lists() -> Result<()> { - let metadata: HashMap = [("Key".to_string(), "Value".to_string())] - .iter() - .cloned() - .collect(); + let metadata: HashMap = + [("Key".to_string(), "Value".to_string())] + .iter() + .cloned() + .collect(); let schema = Schema::new_with_metadata( vec![ - Field::new_list("c21", Field::new("array", DataType::Boolean, true), false), + Field::new_list( + "c21", + Field::new("array", DataType::Boolean, true), + false, + ), Field::new( "c22", DataType::FixedSizeList( @@ -1853,8 +1927,11 @@ mod tests { // write to an empty parquet file so that schema is serialized let file = tempfile::tempfile().unwrap(); - let writer = - ArrowWriter::try_new(file.try_clone().unwrap(), Arc::new(schema.clone()), None)?; + let writer = ArrowWriter::try_new( + file.try_clone().unwrap(), + Arc::new(schema.clone()), 
+ None, + )?; writer.close()?; // read file back From 48bb07df81bba71c9e1219063711df3fa4a35552 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Thu, 7 Dec 2023 09:28:01 +0300 Subject: [PATCH 7/8] Update arrow-ord/src/cmp.rs Co-authored-by: Mehmet Ozan Kabak --- arrow-ord/src/cmp.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index ad41ba67042a..dea069e141c7 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -35,8 +35,7 @@ use arrow_array::{ }; use arrow_buffer::bit_util::ceil; use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; -use arrow_schema::ArrowError; -use arrow_schema::IntervalUnit; +use arrow_schema::{ArrowError, IntervalUnit}; use arrow_select::take::take; #[derive(Debug, Copy, Clone)] From 9f9382e769e2c294e99d3a354ff765b49214a126 Mon Sep 17 00:00:00 2001 From: berkaysynnada Date: Thu, 7 Dec 2023 16:06:20 +0300 Subject: [PATCH 8/8] Intervals account for a reference timestamp --- arrow-ord/src/cmp.rs | 229 ++++++++++++++++++++---------------- arrow-ord/src/comparison.rs | 8 +- 2 files changed, 130 insertions(+), 107 deletions(-) diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs index dea069e141c7..4b279feedf49 100644 --- a/arrow-ord/src/cmp.rs +++ b/arrow-ord/src/cmp.rs @@ -217,8 +217,8 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), (LargeBinary, LargeBinary) => apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v), - (Interval(IntervalUnit::DayTime), Interval(IntervalUnit::DayTime)) => apply(op, safer_interval_dt(l, op, true).values().as_ref(), l_s, l_v, safer_interval_dt(r, op, false).values().as_ref(), r_s, r_v), - (Interval(IntervalUnit::MonthDayNano), Interval(IntervalUnit::MonthDayNano)) => apply(op, safer_interval_mdn(l, op, true).values().as_ref(), l_s, l_v, safer_interval_mdn(r, op, false).values().as_ref(), r_s, r_v), + (Interval(IntervalUnit::DayTime), Interval(IntervalUnit::DayTime)) => apply_interval_dt(op, l, l_s, l_v, r, r_s, r_v), + (Interval(IntervalUnit::MonthDayNano), Interval(IntervalUnit::MonthDayNano)) => apply_interval_mdn(op, l, l_s, l_v, r, r_s, r_v), (Null, Null) => None, _ => unreachable!(), }; @@ -346,6 +346,82 @@ fn apply( } } +fn apply_interval_dt( + op: Op, + l: &dyn Array, + l_s: bool, + l_v: Option<&dyn AnyDictionaryArray>, + r: &dyn Array, + r_s: bool, + r_v: Option<&dyn AnyDictionaryArray>, +) -> Option { + let evaluate_min = apply( + op, + interval_dt_min(l).values().as_ref(), + l_s, + l_v, + interval_dt_min(r).values().as_ref(), + r_s, + r_v, + ); + let evaluate_max = apply( + op, + interval_dt_max(l).values().as_ref(), + l_s, + l_v, + interval_dt_max(r).values().as_ref(), + r_s, + r_v, + ); + definite_comparison(evaluate_min, evaluate_max) +} + +fn apply_interval_mdn( + op: Op, + l: &dyn Array, + l_s: bool, + l_v: Option<&dyn AnyDictionaryArray>, + r: &dyn Array, + r_s: bool, + r_v: Option<&dyn AnyDictionaryArray>, +) -> Option { + let evaluate_min = apply( + op, + interval_mdn_min(l).values().as_ref(), + l_s, + l_v, + interval_mdn_min(r).values().as_ref(), + r_s, + r_v, + ); + let evaluate_max = apply( + op, + interval_mdn_max(l).values().as_ref(), + l_s, + l_v, + interval_mdn_max(r).values().as_ref(), + r_s, + r_v, + ); + definite_comparison(evaluate_min, 
evaluate_max) +} + +fn definite_comparison( + min: Option, + max: Option, +) -> Option { + min.and_then(|min_values| { + max.map(|max_values| { + BooleanBuffer::from_iter( + min_values + .into_iter() + .zip(&max_values) + .map(|(min, max)| min & max), + ) + }) + }) +} + /// Perform a take operation on `buffer` with the given dictionary fn take_bits(v: &dyn AnyDictionaryArray, buffer: BooleanBuffer) -> BooleanBuffer { let array = take(&BooleanArray::new(buffer, None), v.keys(), None).unwrap(); @@ -545,119 +621,66 @@ impl<'a> ArrayOrd for &'a FixedSizeBinaryArray { } } -/// Computes max or min milliseconds from a `PrimitiveArray` based on -/// the comparison operator (`op`) and operand side (`lhs`). This function is essential for -/// accurate interval comparison operations by considering the leap seconds. -/// -/// # Arguments -/// * `dt` - Reference to an array, expected to be `PrimitiveArray`. -/// * `op` - Comparison operator. -/// * `lhs` - Boolean indicating if the array is on the left-hand side of the operator. -/// -/// # Returns -/// A `PrimitiveArray` with computed milliseconds values. -/// -/// # Panics -/// If `dt` is not a `PrimitiveArray` or if an invalid operator is used. #[inline] -fn safer_interval_dt(dt: &dyn Array, op: Op, lhs: bool) -> PrimitiveArray { - match dt.as_primitive_opt::() { - Some(dt) => match (op, lhs) { - (Op::Less | Op::LessEqual, true) | (Op::Greater | Op::GreaterEqual, false) => { - PrimitiveArray::::from_iter(dt.iter().map(|dt| dt.map(dt_in_millis_max))) - } - (Op::Greater | Op::GreaterEqual, true) | (Op::Less | Op::LessEqual, false) => { - PrimitiveArray::::from_iter(dt.iter().map(|dt| dt.map(dt_in_millis_min))) - } - (Op::Equal | Op::NotEqual, _) => PrimitiveArray::::from_iter(dt.iter()), - _ => { - panic!( - "Invalid operator {:?} for Interval(IntervalDayTime) comparison", - op - ) - } - }, - _ => { - panic!("Invalid datatype for Interval(IntervalDayTime) comparison") - } +fn interval_dt_min(dt: &dyn Array) -> PrimitiveArray { + if let Some(dt) = dt.as_primitive_opt::() { + PrimitiveArray::::from_iter(dt.iter().map(|dt| { + dt.map(|dt| { + let d = dt >> 32; + let m = dt as i32 as i64; + d * (86_400_000) + m + }) + })) + } else { + panic!("Invalid datatype for Interval(IntervalDayTime) comparison") } } -/// Computes max or min nanoseconds from a `PrimitiveArray` based on -/// the comparison operator (`op`) and operand side (`lhs`). This function is crucial for -/// precise interval comparison operations involving months and days, which can result in different -/// number of nanoseconds depending on the timestamp. -/// -/// # Arguments -/// * `mdn` - Reference to an array, expected to be `PrimitiveArray`. -/// * `op` - Comparison operator. -/// * `lhs` - Boolean indicating if the array is on the left-hand side of the operator. -/// -/// # Returns -/// A `PrimitiveArray` with computed nanoseconds values. -/// -/// # Panics -/// If `mdn` is not a `PrimitiveArray` or if an invalid operator is used. 
#[inline] -fn safer_interval_mdn(mdn: &dyn Array, op: Op, lhs: bool) -> PrimitiveArray { - match mdn.as_primitive_opt::() { - Some(mdn) => match (op, lhs) { - (Op::Less | Op::LessEqual, true) | (Op::Greater | Op::GreaterEqual, false) => { - PrimitiveArray::::from_iter( - mdn.iter().map(|mdn| mdn.map(mdn_in_nanos_max)), - ) - } - (Op::Greater | Op::GreaterEqual, true) | (Op::Less | Op::LessEqual, false) => { - PrimitiveArray::::from_iter( - mdn.iter().map(|mdn| mdn.map(mdn_in_nanos_min)), - ) - } - (Op::Equal | Op::NotEqual, _) => PrimitiveArray::::from_iter(mdn.iter()), - _ => { - panic!("Invalid operator for Interval(IntervalMonthDayNano) comparison") - } - }, - _ => { - panic!("Invalid datatype for Interval(IntervalMonthDayNano) comparison") - } +fn interval_dt_max(dt: &dyn Array) -> PrimitiveArray { + if let Some(dt) = dt.as_primitive_opt::() { + PrimitiveArray::::from_iter(dt.iter().map(|dt| { + dt.map(|dt| { + let d = dt >> 32; + let m = dt as i32 as i64; + d * (86_400_000 + 1_000) + m + }) + })) + } else { + panic!("Invalid datatype for Interval(IntervalDayTime) comparison") } } -/// Calculates the maximum milliseconds for an `IntervalDayTimeType` interval, accounting -/// for leap seconds by adding an extra 1000 milliseconds for each day. #[inline] -fn dt_in_millis_max(dt: i64) -> i64 { - let d = dt >> 32; - let m = dt as i32 as i64; - d * (86_400_000 + 1_000) + m -} - -/// Calculates the minimum milliseconds for an `IntervalDayTimeType` interval, excluding leap seconds. -#[inline] -fn dt_in_millis_min(dt: i64) -> i64 { - let d = dt >> 32; - let m = dt as i32 as i64; - d * (86_400_000) + m -} - -/// Calculates the maximum nanoseconds for an `IntervalMonthDayNanoType` interval, assuming -/// 31 days per month and adding extra nanoseconds for longer days. -#[inline] -fn mdn_in_nanos_max(mdn: i128) -> i128 { - let m = (mdn >> 96) as i32; - let d = (mdn >> 64) as i32; - let n = mdn as i64; - ((m as i128 * 31) + d as i128) * (86_400_000_000_000 + 1_000_000_000) + n as i128 +fn interval_mdn_min(mdn: &dyn Array) -> PrimitiveArray { + if let Some(mdn) = mdn.as_primitive_opt::() { + PrimitiveArray::::from_iter(mdn.iter().map(|mdn| { + mdn.map(|mdn| { + let m = (mdn >> 96) as i32; + let d = (mdn >> 64) as i32; + let n = mdn as i64; + ((m as i128 * 28) + d as i128) * (86_400_000_000_000) + n as i128 + }) + })) + } else { + panic!("Invalid datatype for Interval(IntervalMonthDayNano) comparison") + } } -/// Calculates the minimum nanoseconds for an `IntervalMonthDayNanoType` interval, assuming -/// 28 days per month and excluding additional nanoseconds for longer days. 
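A consequence of requiring agreement at both bounds is that orderings which depend on the reference timestamp come out false in both directions. For example, 1 month is neither definitely greater nor definitely less than 30 days under the bounds computed by the helpers above; a small check of that arithmetic (constants as in those helpers):

const NANOS_PER_DAY: i128 = 86_400_000_000_000;
const LEAP_SLACK: i128 = 1_000_000_000; // up to one leap second per day

fn main() {
    // 1 month can last anywhere from 28 to 31 days (plus leap-second slack).
    let one_month_min = 28 * NANOS_PER_DAY;
    let one_month_max = 31 * (NANOS_PER_DAY + LEAP_SLACK);
    // 30 days has a fixed day count, but its upper bound still allows leap seconds.
    let thirty_days_min = 30 * NANOS_PER_DAY;
    let thirty_days_max = 30 * (NANOS_PER_DAY + LEAP_SLACK);

    // "1 month > 30 days" fails at the minimum reading...
    assert!(!(one_month_min > thirty_days_min));
    // ...and "1 month < 30 days" fails at the maximum reading,
    assert!(!(one_month_max < thirty_days_max));
    // so after this patch both gt and lt should report false for this pair:
    // the ordering is indeterminate without a reference timestamp.
}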
#[inline] -fn mdn_in_nanos_min(mdn: i128) -> i128 { - let m = (mdn >> 96) as i32; - let d = (mdn >> 64) as i32; - let n = mdn as i64; - ((m as i128 * 28) + d as i128) * (86_400_000_000_000) + n as i128 +fn interval_mdn_max(mdn: &dyn Array) -> PrimitiveArray { + if let Some(mdn) = mdn.as_primitive_opt::() { + PrimitiveArray::::from_iter(mdn.iter().map(|mdn| { + mdn.map(|mdn| { + let m = (mdn >> 96) as i32; + let d = (mdn >> 64) as i32; + let n = mdn as i64; + ((m as i128 * 31) + d as i128) * (86_400_000_000_000 + 1_000_000_000) + n as i128 + }) + })) + } else { + panic!("Invalid datatype for Interval(IntervalMonthDayNano) comparison") + } } #[cfg(test)] diff --git a/arrow-ord/src/comparison.rs b/arrow-ord/src/comparison.rs index 3c2b67a18fbd..de8b769ea9ca 100644 --- a/arrow-ord/src/comparison.rs +++ b/arrow-ord/src/comparison.rs @@ -2047,21 +2047,21 @@ mod tests { assert_eq!(res, res_eq); assert_eq!( &res, - &BooleanArray::from(vec![ Some(true), Some(true), Some(true), Some(true), Some(false)]) + &BooleanArray::from(vec![ Some(true), Some(true), Some(true), Some(true), Some(true)]) ); let res = lt(&b, &a).unwrap(); let res_eq = lt_eq(&b, &a).unwrap(); assert_eq!(res, res_eq); assert_eq!( &res, - &BooleanArray::from(vec![ Some(true), Some(true), Some(true), Some(true), Some(false)]) + &BooleanArray::from(vec![ Some(true), Some(true), Some(true), Some(true), Some(true)]) ); let a = IntervalMonthDayNanoArray::from( - vec![Some(IntervalMonthDayNanoType::make_value(0, 0, 1)),Some(IntervalMonthDayNanoType::make_value(0, 1, -1_000_000_000)),Some(IntervalMonthDayNanoType::make_value(3, 2, -100_000_000_000)),Some(IntervalMonthDayNanoType::make_value(0, 1, 1)),Some(IntervalMonthDayNanoType::make_value(1, 28, 0)), Some(IntervalMonthDayNanoType::make_value(10, 0, -1_000_000_000_000))], + vec![Some(IntervalMonthDayNanoType::make_value(0, 0, 1)),Some(IntervalMonthDayNanoType::make_value(0, 1, -1_000_000_000)),Some(IntervalMonthDayNanoType::make_value(3, 2, -100_000_000_000)),Some(IntervalMonthDayNanoType::make_value(0, 1, 86_400_000_000_999)),Some(IntervalMonthDayNanoType::make_value(1, 28, 0)), Some(IntervalMonthDayNanoType::make_value(10, 0, -1_000_000_000_000))], ); let b = IntervalMonthDayNanoArray::from( - vec![Some(IntervalMonthDayNanoType::make_value(0, 0,0)),Some(IntervalMonthDayNanoType::make_value(0, 1, -8_000_000_000)),Some(IntervalMonthDayNanoType::make_value(1, 25, 100_000_000_000)),Some(IntervalMonthDayNanoType::make_value(0, 1, 0)),Some(IntervalMonthDayNanoType::make_value(2, 0, 0)), Some(IntervalMonthDayNanoType::make_value(5, 150, 1_000_000_000_000))], + vec![Some(IntervalMonthDayNanoType::make_value(0, 0,0)),Some(IntervalMonthDayNanoType::make_value(0, 1, -8_000_000_000)),Some(IntervalMonthDayNanoType::make_value(1, 25, 100_000_000_000)),Some(IntervalMonthDayNanoType::make_value(0, 2, 0)),Some(IntervalMonthDayNanoType::make_value(2, 0, 0)), Some(IntervalMonthDayNanoType::make_value(5, 150, 1_000_000_000_000))], ); let res = gt(&a, &b).unwrap(); let res_eq = gt_eq(&a, &b).unwrap();
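To round out the updated tests, a usage sketch in the same style, exercising the public kernels on month/day/nano intervals (assuming this patch series is applied; the `arrow_array` and `arrow_ord` items are the ones the tests above already use):

use arrow_array::types::IntervalMonthDayNanoType;
use arrow_array::{BooleanArray, IntervalMonthDayNanoArray};
use arrow_ord::cmp::{gt, lt};

fn main() {
    // 1 month vs. 30 days: indeterminate, so both orderings come out false.
    let a = IntervalMonthDayNanoArray::from(vec![Some(
        IntervalMonthDayNanoType::make_value(1, 0, 0),
    )]);
    let b = IntervalMonthDayNanoArray::from(vec![Some(
        IntervalMonthDayNanoType::make_value(0, 30, 0),
    )]);
    assert_eq!(gt(&a, &b).unwrap(), BooleanArray::from(vec![Some(false)]));
    assert_eq!(lt(&a, &b).unwrap(), BooleanArray::from(vec![Some(false)]));

    // 2 months vs. 30 days: even the shortest two months (56 days) exceed
    // 30 days at both bounds, so gt is definitely true.
    let c = IntervalMonthDayNanoArray::from(vec![Some(
        IntervalMonthDayNanoType::make_value(2, 0, 0),
    )]);
    assert_eq!(gt(&c, &b).unwrap(), BooleanArray::from(vec![Some(true)]));
}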