diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs
index b6af40f7d7c..848c2121343 100644
--- a/arrow-arith/src/numeric.rs
+++ b/arrow-arith/src/numeric.rs
@@ -230,6 +230,24 @@ fn arithmetic_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
         (Interval(YearMonth), Interval(YearMonth)) => interval_op::<IntervalYearMonthType>(op, l, l_scalar, r, r_scalar),
         (Interval(DayTime), Interval(DayTime)) => interval_op::<IntervalDayTimeType>(op, l, l_scalar, r, r_scalar),
         (Interval(MonthDayNano), Interval(MonthDayNano)) => interval_op::<IntervalMonthDayNanoType>(op, l, l_scalar, r, r_scalar),
+        (Interval(unit), rhs) if rhs.is_numeric() && matches!(op, Op::Mul | Op::MulWrapping) =>
+            match unit {
+                YearMonth => interval_mul_op::<IntervalYearMonthType>(op, l, l_scalar, r, r_scalar),
+                DayTime => interval_mul_op::<IntervalDayTimeType>(op, l, l_scalar, r, r_scalar),
+                MonthDayNano => interval_mul_op::<IntervalMonthDayNanoType>(op, l, l_scalar, r, r_scalar),
+            },
+        (lhs, Interval(unit)) if lhs.is_numeric() && matches!(op, Op::Mul | Op::MulWrapping) =>
+            match unit {
+                YearMonth => interval_mul_op::<IntervalYearMonthType>(op, r, r_scalar, l, l_scalar),
+                DayTime => interval_mul_op::<IntervalDayTimeType>(op, r, r_scalar, l, l_scalar),
+                MonthDayNano => interval_mul_op::<IntervalMonthDayNanoType>(op, r, r_scalar, l, l_scalar),
+            },
+        (Interval(unit), rhs) if rhs.is_numeric() && matches!(op, Op::Div) =>
+            match unit {
+                YearMonth => interval_div_op::<IntervalYearMonthType>(op, l, l_scalar, r, r_scalar),
+                DayTime => interval_div_op::<IntervalDayTimeType>(op, l, l_scalar, r, r_scalar),
+                MonthDayNano => interval_div_op::<IntervalMonthDayNanoType>(op, l, l_scalar, r, r_scalar),
+            },
         (Date32, _) => date_op::<Date32Type>(op, l, l_scalar, r, r_scalar),
         (Date64, _) => date_op::<Date64Type>(op, l, l_scalar, r, r_scalar),
         (Decimal128(_, _), Decimal128(_, _)) => decimal_op::<Decimal128Type>(op, l, l_scalar, r, r_scalar),
@@ -550,6 +568,28 @@ date!(Date64Type);
 trait IntervalOp: ArrowPrimitiveType {
     fn add(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError>;
     fn sub(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError>;
+    /// Multiplies the interval `left` by the integer `right`
+    fn mul_int(left: Self::Native, right: i32) -> Result<Self::Native, ArrowError>;
+    /// Multiplies the interval `left` by the floating point factor `right`
+    fn mul_float(left: Self::Native, right: f64) -> Result<Self::Native, ArrowError>;
+    /// Divides the interval `left` by the integer `right`
+    fn div_int(left: Self::Native, right: i32) -> Result<Self::Native, ArrowError>;
+    /// Divides the interval `left` by the floating point divisor `right`
+    fn div_float(left: Self::Native, right: f64) -> Result<Self::Native, ArrowError>;
+}
+
+/// Converts an `f64` to `i32`, truncating any fractional part toward zero
+///
+/// Returns an error for NaN, infinities and values outside the `i32` range, where a
+/// bare `as` cast would silently saturate (or yield zero for NaN)
+fn f64_to_i32(value: f64) -> Result<i32, ArrowError> {
+    if !value.is_finite() || value > i32::MAX as f64 || value < i32::MIN as f64 {
+        Err(ArrowError::ComputeError(format!(
+            "Interval arithmetic result {value} is out of i32 range"
+        )))
+    } else {
+        Ok(value as i32)
+    }
 }

 impl IntervalOp for IntervalYearMonthType {
@@ -560,6 +600,27 @@ impl IntervalOp for IntervalYearMonthType {
     fn sub(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError> {
         left.sub_checked(right)
     }
+
+    fn mul_int(left: Self::Native, right: i32) -> Result<Self::Native, ArrowError> {
+        left.mul_checked(right)
+    }
+
+    fn mul_float(left: Self::Native, right: f64) -> Result<Self::Native, ArrowError> {
+        // A bare `as i32` cast would silently saturate on overflow and map NaN to 0
+        f64_to_i32(left as f64 * right)
+    }
+
+    fn div_int(left: Self::Native, right: i32) -> Result<Self::Native, ArrowError> {
+        // `div_checked` reports both division by zero and the `i32::MIN / -1` overflow
+        left.div_checked(right)
+    }
+
+    fn div_float(left: Self::Native, right: f64) -> Result<Self::Native, ArrowError> {
+        if right == 0.0 {
+            return Err(ArrowError::DivideByZero);
+        }
+        f64_to_i32(left as f64 / right)
+    }
 }

 impl IntervalOp for IntervalDayTimeType {
@@ -578,6 +639,66 @@ impl IntervalOp for IntervalDayTimeType {
         let ms = l_ms.sub_checked(r_ms)?;
         Ok(Self::make_value(days, ms))
     }
+
+    fn mul_int(left: Self::Native, right: i32) -> Result<Self::Native, ArrowError> {
+        let (days, ms) = Self::to_parts(left);
+        Ok(IntervalDayTimeType::make_value(
+            days.mul_checked(right)?,
+            ms.mul_checked(right)?,
+        ))
+    }
+
+    fn mul_float(left: Self::Native, right: f64) -> Result<Self::Native, ArrowError> {
+        let (days, ms) = Self::to_parts(left);
+
+        // Scale the day component; its fractional part is carried into milliseconds
+        let total_days = days as f64 * right;
+        let whole_days = f64_to_i32(total_days.trunc())?;
+        // 24 * 60 * 60 * 1000 = 86_400_000 ms per day
+        let frac_ms = f64_to_i32(total_days.fract() * 86_400_000.0)?;
+
+        // Checked add: the scaled milliseconds plus the carry can exceed the i32 range
+        let total_ms = f64_to_i32(ms as f64 * right)?.add_checked(frac_ms)?;
+
+        Ok(Self::make_value(whole_days, total_ms))
+    }
+
+    fn div_int(left: Self::Native, right: i32) -> Result<Self::Native, ArrowError> {
+        if right == 0 {
+            return Err(ArrowError::DivideByZero);
+        }
+        let (days, ms) = Self::to_parts(left);
+
+        // Divide the total in milliseconds so the remainder of the day component
+        // carries into milliseconds; the total always fits in an i64
+        let total_ms = ms as i64 + days as i64 * 86_400_000; // 24 * 60 * 60 * 1000
+        let quotient_ms = total_ms / right as i64;
+
+        // Split with integer arithmetic: converting the full quotient to f64 first
+        // loses precision above 2^53 and can shift the day count by one
+        let result_days = f64_to_i32((quotient_ms / 86_400_000) as f64)?;
+        // The remainder's magnitude is below one day, so it always fits in an i32
+        let result_ms = (quotient_ms % 86_400_000) as i32;
+
+        Ok(Self::make_value(result_days, result_ms))
+    }
+
+    fn div_float(left: Self::Native, right: f64) -> Result<Self::Native, ArrowError> {
+        if right == 0.0 {
+            return Err(ArrowError::DivideByZero);
+        }
+        let (days, ms) = Self::to_parts(left);
+
+        // Divide the total in milliseconds so fractional days carry into milliseconds
+        let total_ms = (ms as f64 + days as f64 * 86_400_000.0) / right;
+
+        // `trunc` (not `floor`) pairs with `%`, whose result takes the sign of the
+        // dividend; `floor` would push negative results a whole day too far
+        let result_days = f64_to_i32((total_ms / 86_400_000.0).trunc())?;
+        let result_ms = f64_to_i32(total_ms % 86_400_000.0)?;
+
+        Ok(Self::make_value(result_days, result_ms))
+    }
 }

 impl IntervalOp for IntervalMonthDayNanoType {
@@ -598,6 +719,33 @@ impl IntervalOp for IntervalMonthDayNanoType {
         let nanos = l_nanos.sub_checked(r_nanos)?;
         Ok(Self::make_value(months, days, nanos))
     }
+
+    fn mul_int(left: Self::Native, right: i32) -> Result<Self::Native, ArrowError> {
+        let (months, days, nanos) = Self::to_parts(left);
+        Ok(Self::make_value(
+            months.mul_checked(right)?,
+            days.mul_checked(right)?,
+            nanos.mul_checked(i64::from(right))?,
+        ))
+    }
+
+    fn mul_float(_left: Self::Native, _right: f64) -> Result<Self::Native, ArrowError> {
+        Err(ArrowError::InvalidArgumentError(
+            "Floating point multiplication not supported for MonthDayNano intervals".to_string(),
+        ))
+    }
+
+    fn div_int(_left: Self::Native, _right: i32) -> Result<Self::Native, ArrowError> {
+        Err(ArrowError::InvalidArgumentError(
+            "Integer division not supported for MonthDayNano intervals".to_string(),
+        ))
+    }
+
+    fn div_float(_left: Self::Native, _right: f64) -> Result<Self::Native, ArrowError> {
+        Err(ArrowError::InvalidArgumentError(
+            "Floating point division not supported for MonthDayNano intervals".to_string(),
+        ))
+    }
 }

 /// Perform arithmetic operation on an interval array
@@ -621,6 +769,105 @@ fn interval_op<T: IntervalOp>(
     }
 }

+/// Perform multiplication between an interval array and a numeric array
+///
+/// `l` must be the interval operand: multiplication is commutative, so callers with the
+/// numeric operand on the left pass the arguments swapped. Only `Int32` and `Float64`
+/// numeric operands are supported; any other numeric type is reported as an error
+fn interval_mul_op<T: IntervalOp>(
+    op: Op,
+    l: &dyn Array,
+    l_s: bool,
+    r: &dyn Array,
+    r_s: bool,
+) -> Result<ArrayRef, ArrowError> {
+    if let Some(l_interval) = l.as_primitive_opt::<T>() {
+        match r.data_type() {
+            DataType::Int32 => {
+                let r_int = r.as_primitive::<Int32Type>();
+                Ok(try_op_ref!(
+                    T,
+                    l_interval,
+                    l_s,
+                    r_int,
+                    r_s,
+                    T::mul_int(l_interval, r_int)
+                ))
+            }
+            DataType::Float64 => {
+                let r_float = r.as_primitive::<Float64Type>();
+                Ok(try_op_ref!(
+                    T,
+                    l_interval,
+                    l_s,
+                    r_float,
+                    r_s,
+                    T::mul_float(l_interval, r_float)
+                ))
+            }
+            _ => Err(ArrowError::InvalidArgumentError(format!(
+                "Invalid numeric type for interval multiplication: {}",
+                r.data_type()
+            ))),
+        }
+    } else {
+        Err(ArrowError::InvalidArgumentError(format!(
+            "Invalid interval multiplication: {} {op} {}",
+            l.data_type(),
+            r.data_type()
+        )))
+    }
+}
+
+/// Perform division of an interval array by a numeric array
+///
+/// `l` must be the interval operand and `r` the numeric divisor. Only `Int32` and
+/// `Float64` divisors are supported; any other numeric type is reported as an error
+fn interval_div_op<T: IntervalOp>(
+    op: Op,
+    l: &dyn Array,
+    l_s: bool,
+    r: &dyn Array,
+    r_s: bool,
+) -> Result<ArrayRef, ArrowError> {
+    if let Some(l_interval) = l.as_primitive_opt::<T>() {
+        match r.data_type() {
+            DataType::Int32 => {
+                let r_int = r.as_primitive::<Int32Type>();
+                Ok(try_op_ref!(
+                    T,
+                    l_interval,
+                    l_s,
+                    r_int,
+                    r_s,
+                    T::div_int(l_interval, r_int)
+                ))
+            }
+            DataType::Float64 => {
+                let r_float = r.as_primitive::<Float64Type>();
+                Ok(try_op_ref!(
+                    T,
+                    l_interval,
+                    l_s,
+                    r_float,
+                    r_s,
+                    T::div_float(l_interval, r_float)
+                ))
+            }
+            _ => Err(ArrowError::InvalidArgumentError(format!(
+                "Invalid numeric type for interval division: {}",
+                r.data_type()
+            ))),
+        }
+    } else {
+        Err(ArrowError::InvalidArgumentError(format!(
+            "Invalid interval division: {} {op} {}",
+            l.data_type(),
+            r.data_type()
+        )))
+    }
+}
+
 fn duration_op<T: ArrowPrimitiveType<Native = i64>>(
     op: Op,
     l: &dyn Array,
@@ -1356,6 +1603,99 @@ mod tests {
             err,
             "Arithmetic overflow: Overflow happened on: 2147483647 + 1"
         );
+
+        // Test interval multiplication
+        let a = IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(2, 4)]);
+        let b = PrimitiveArray::<Int32Type>::from(vec![5]);
+        let result = mul(&a, &b).unwrap();
+        assert_eq!(
+            result.as_ref(),
+            &IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(11, 8),])
+        );
+
+        // swap a and b
+        let result = mul(&b, &a).unwrap();
+        assert_eq!(
+            result.as_ref(),
+            &IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(11, 8),])
+        );
+
+        let a = IntervalDayTimeArray::from(vec![
+            IntervalDayTimeType::make_value(10, 7200000), // 10 days, 2 hours
+        ]);
+        let b = PrimitiveArray::<Int32Type>::from(vec![3]);
+        let result = mul(&a, &b).unwrap();
+        assert_eq!(
+            result.as_ref(),
+            &IntervalDayTimeArray::from(vec![
+                IntervalDayTimeType::make_value(30, 21600000), // 30 days, 6 hours
+            ])
+        );
+
+        let a = IntervalMonthDayNanoArray::from(vec![
+            IntervalMonthDayNanoType::make_value(12, 15, 5_000_000_000), // 12 months, 15 days, 5 seconds
+        ]);
+        let b = PrimitiveArray::<Int32Type>::from(vec![2]);
+        let result = mul(&a, &b).unwrap();
+        assert_eq!(
+            result.as_ref(),
+            &IntervalMonthDayNanoArray::from(vec![
+                IntervalMonthDayNanoType::make_value(24, 30, 10_000_000_000), // 24 months, 30 days, 10 seconds
+            ])
+        );
+
+        let a = IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(1, 6)]); // 1 year, 6 months
+        let b = PrimitiveArray::<Float64Type>::from(vec![2.5]);
+        let result = mul(&a, &b).unwrap();
+        assert_eq!(
+            result.as_ref(),
+            &IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(3, 9)]) // 3 years, 9 months = 45 months
+        );
+
+        let a = IntervalDayTimeArray::from(vec![
+            IntervalDayTimeType::make_value(5, 3600000), // 5 days, 1 hour
+        ]);
+        let b = PrimitiveArray::<Int32Type>::from(vec![-2]);
+        let result = mul(&a, &b).unwrap();
+        assert_eq!(
+            result.as_ref(),
+            &IntervalDayTimeArray::from(vec![
+                IntervalDayTimeType::make_value(-10, -7200000), // -10 days, -2 hours
+            ])
+        );
+
+        // Test interval division
+        let a = IntervalDayTimeArray::from(vec![
+            IntervalDayTimeType::make_value(15, 3600000), // 15 days, 1 hour
+        ]);
+        let b = PrimitiveArray::<Int32Type>::from(vec![2]);
+        let result = div(&a, &b).unwrap();
+        assert_eq!(
+            result.as_ref(),
+            &IntervalDayTimeArray::from(vec![
+                IntervalDayTimeType::make_value(7, 45000000), // 7 days, 12.5 hours (half of 15 days, 1 hour)
+            ])
+        );
+
+        // Float division truncates toward zero, also for negative intervals
+        let a = IntervalDayTimeArray::from(vec![IntervalDayTimeType::make_value(-3, 0)]);
+        let b = PrimitiveArray::<Float64Type>::from(vec![2.0]);
+        let result = div(&a, &b).unwrap();
+        assert_eq!(
+            result.as_ref(),
+            &IntervalDayTimeArray::from(vec![
+                IntervalDayTimeType::make_value(-1, -43200000), // -1 day, -12 hours
+            ])
+        );
+
+        // Results outside the i32 range are reported instead of saturating or panicking
+        let a = IntervalYearMonthArray::from(vec![i32::MAX]);
+        let b = PrimitiveArray::<Float64Type>::from(vec![2.0]);
+        assert!(mul(&a, &b).is_err());
+
+        let a = IntervalYearMonthArray::from(vec![i32::MIN]);
+        let b = PrimitiveArray::<Int32Type>::from(vec![-1]);
+        assert!(div(&a, &b).is_err());
     }

     fn test_duration_impl<T: ArrowPrimitiveType<Native = i64>>() {
diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs
index fc657f94c6a..c464cdf4b81 100644
--- a/arrow-array/src/cast.rs
+++ b/arrow-array/src/cast.rs
@@ -72,7 +72,7 @@ macro_rules! repeat_pat {
 /// [`DataType`]: arrow_schema::DataType
 #[macro_export]
 macro_rules! downcast_integer {
-    ($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => {
+    ($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat $( if $guard:expr )? => $fallback:expr $(,)*)*) => {
         match ($($data_type),+) {
             $crate::repeat_pat!($crate::cast::__private::DataType::Int8, $($data_type),+) => {
                 $m!($crate::types::Int8Type $(, $args)*)
@@ -98,7 +98,7 @@ macro_rules! downcast_integer {
             $crate::repeat_pat!($crate::cast::__private::DataType::UInt64, $($data_type),+) => {
                 $m!($crate::types::UInt64Type $(, $args)*)
             }
-            $($p => $fallback,)*
+            $($p $( if $guard )? => $fallback,)*
         }
     };
 }