diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index 972233e5a0a..3dee987e818 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -231,6 +231,7 @@ where )?)) } +#[allow(dead_code)] /// Parses given string to specified decimal native (i128/i256) based on given /// scale. Returns an `Err` if it cannot parse given string. pub(crate) fn parse_string_to_decimal_native( @@ -343,10 +344,9 @@ where &'a S: StringArrayType<'a>, { if cast_options.safe { - let iter = from.iter().map(|v| { - v.and_then(|v| parse_decimal::(v, precision, scale).ok()) - .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v)) - }); + let iter = from + .iter() + .map(|v| v.and_then(|v| parse_decimal::(v, precision, scale).ok())); // Benefit: // 20% performance improvement // Soundness: @@ -360,15 +360,12 @@ where .iter() .map(|v| { v.map(|v| { - parse_decimal::(v, precision, scale) - .map_err(|_| { - ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {:?} type", - v, - T::DATA_TYPE, - )) - }) - .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v)) + parse_decimal::(v, precision, scale).map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to decimal type of precision {} and scale {}", + v, precision, scale + )) + }) }) .transpose() }) diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index 8d0d097979d..5c0da189914 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -8849,16 +8849,16 @@ mod tests { format_options: FormatOptions::default(), }; let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err(); - assert!(casted_err - .to_string() - .contains("Cannot cast string '4.4.5' to value of Decimal128(38, 10) type")); + assert!(casted_err.to_string().contains( + "Cast error: Cannot cast string '4.4.5' to decimal type of precision 38 and scale 2" + )); let str_array = StringArray::from(vec![". 0.123"]); let array = Arc::new(str_array) as ArrayRef; let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err(); - assert!(casted_err - .to_string() - .contains("Cannot cast string '. 0.123' to value of Decimal128(38, 10) type")); + assert!(casted_err.to_string().contains( + "Cast error: Cannot cast string '. 0.123' to decimal type of precision 38 and scale 2" + )); } fn test_cast_string_to_decimal128_overflow(overflow_array: ArrayRef) { @@ -8902,7 +8902,10 @@ mod tests { format_options: FormatOptions::default(), }, ); - assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal128 of precision 10. Max is 9999999999", err.unwrap_err().to_string()); + assert_eq!( + "Cast error: Cannot cast string '1000' to decimal type of precision 10 and scale 8", + err.unwrap_err().to_string() + ); } #[test] @@ -8985,7 +8988,10 @@ mod tests { format_options: FormatOptions::default(), }, ); - assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal256 of precision 10. Max is 9999999999", err.unwrap_err().to_string()); + assert_eq!( + "Cast error: Cannot cast string '1000' to decimal type of precision 10 and scale 8", + err.unwrap_err().to_string() + ); } #[test] diff --git a/arrow-cast/src/parse.rs b/arrow-cast/src/parse.rs index fd2e078d897..ea020a27ca3 100644 --- a/arrow-cast/src/parse.rs +++ b/arrow-cast/src/parse.rs @@ -842,6 +842,7 @@ pub fn parse_decimal( let mut result = T::Native::usize_as(0); let mut fractionals: i8 = 0; let mut digits: u8 = 0; + let mut rounding_digit = -1; // to store digit after the scale for rounding let base = T::Native::usize_as(10); let bs = s.as_bytes(); @@ -871,6 +872,13 @@ pub fn parse_decimal( // Ignore leading zeros. continue; } + if fractionals == scale && scale != 0 && rounding_digit < 0 { + // Capture the rounding digit once + if rounding_digit < 0 { + rounding_digit = (b - b'0') as i8; + } + continue; + } digits += 1; result = result.mul_wrapping(base); result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize)); @@ -903,9 +911,10 @@ pub fn parse_decimal( ))); } if fractionals == scale && scale != 0 { - // We have processed all the digits that we need. All that - // is left is to validate that the rest of the string contains - // valid digits. + // Capture the rounding digit once + if rounding_digit < 0 { + rounding_digit = (b - b'0') as i8; + } continue; } fractionals += 1; @@ -966,6 +975,10 @@ pub fn parse_decimal( "parse decimal overflow ({s})" ))); } + //add one if >=5 + if rounding_digit >= 5 { + result = result.add_wrapping(T::Native::usize_as(1)); + } } Ok(if negative {