Skip to content

Commit

Permalink
added rounding logic for non e-notation,
Browse files Browse the repository at this point in the history
error message changed.
  • Loading branch information
himadripal committed Dec 19, 2024
1 parent 4b19083 commit 45ec17e
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 24 deletions.
23 changes: 10 additions & 13 deletions arrow-cast/src/cast/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ where
)?))
}

#[allow(dead_code)]
/// Parses given string to specified decimal native (i128/i256) based on given
/// scale. Returns an `Err` if it cannot parse given string.
pub(crate) fn parse_string_to_decimal_native<T: DecimalType>(
Expand Down Expand Up @@ -343,10 +344,9 @@ where
&'a S: StringArrayType<'a>,
{
if cast_options.safe {
let iter = from.iter().map(|v| {
v.and_then(|v| parse_decimal::<T>(v, precision, scale).ok())
.and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
});
let iter = from
.iter()
.map(|v| v.and_then(|v| parse_decimal::<T>(v, precision, scale).ok()));
// Benefit:
// 20% performance improvement
// Soundness:
Expand All @@ -360,15 +360,12 @@ where
.iter()
.map(|v| {
v.map(|v| {
parse_decimal::<T>(v, precision, scale)
.map_err(|_| {
ArrowError::CastError(format!(
"Cannot cast string '{}' to value of {:?} type",
v,
T::DATA_TYPE,
))
})
.and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v))
parse_decimal::<T>(v, precision, scale).map_err(|_| {
ArrowError::CastError(format!(
"Cannot cast string '{}' to decimal type of precision {} and scale {}",
v, precision, scale
))
})
})
.transpose()
})
Expand Down
22 changes: 14 additions & 8 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8849,16 +8849,16 @@ mod tests {
format_options: FormatOptions::default(),
};
let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err();
assert!(casted_err
.to_string()
.contains("Cannot cast string '4.4.5' to value of Decimal128(38, 10) type"));
assert!(casted_err.to_string().contains(
"Cast error: Cannot cast string '4.4.5' to decimal type of precision 38 and scale 2"
));

let str_array = StringArray::from(vec![". 0.123"]);
let array = Arc::new(str_array) as ArrayRef;
let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err();
assert!(casted_err
.to_string()
.contains("Cannot cast string '. 0.123' to value of Decimal128(38, 10) type"));
assert!(casted_err.to_string().contains(
"Cast error: Cannot cast string '. 0.123' to decimal type of precision 38 and scale 2"
));
}

fn test_cast_string_to_decimal128_overflow(overflow_array: ArrayRef) {
Expand Down Expand Up @@ -8902,7 +8902,10 @@ mod tests {
format_options: FormatOptions::default(),
},
);
assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal128 of precision 10. Max is 9999999999", err.unwrap_err().to_string());
assert_eq!(
"Cast error: Cannot cast string '1000' to decimal type of precision 10 and scale 8",
err.unwrap_err().to_string()
);
}

#[test]
Expand Down Expand Up @@ -8985,7 +8988,10 @@ mod tests {
format_options: FormatOptions::default(),
},
);
assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal256 of precision 10. Max is 9999999999", err.unwrap_err().to_string());
assert_eq!(
"Cast error: Cannot cast string '1000' to decimal type of precision 10 and scale 8",
err.unwrap_err().to_string()
);
}

#[test]
Expand Down
19 changes: 16 additions & 3 deletions arrow-cast/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,7 @@ pub fn parse_decimal<T: DecimalType>(
let mut result = T::Native::usize_as(0);
let mut fractionals: i8 = 0;
let mut digits: u8 = 0;
let mut rounding_digit = -1; // first digit past `scale`, used for rounding; stays negative until one is captured
let base = T::Native::usize_as(10);

let bs = s.as_bytes();
Expand Down Expand Up @@ -871,6 +872,13 @@ pub fn parse_decimal<T: DecimalType>(
// Ignore leading zeros.
continue;
}
if fractionals == scale && scale != 0 && rounding_digit < 0 {
// Capture the rounding digit once
if rounding_digit < 0 {
rounding_digit = (b - b'0') as i8;
}
continue;
}
digits += 1;
result = result.mul_wrapping(base);
result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize));
Expand Down Expand Up @@ -903,9 +911,10 @@ pub fn parse_decimal<T: DecimalType>(
)));
}
if fractionals == scale && scale != 0 {
// We have processed all the digits that we need. All that
// is left is to validate that the rest of the string contains
// valid digits.
// Capture the rounding digit once
if rounding_digit < 0 {
rounding_digit = (b - b'0') as i8;
}
continue;
}
fractionals += 1;
Expand Down Expand Up @@ -966,6 +975,10 @@ pub fn parse_decimal<T: DecimalType>(
"parse decimal overflow ({s})"
)));
}
// Round the magnitude up by one when the first digit past `scale` is >= 5.
if rounding_digit >= 5 {
result = result.add_wrapping(T::Native::usize_as(1));
}
}

Ok(if negative {
Expand Down

0 comments on commit 45ec17e

Please sign in to comment.