Skip to content

Commit

Permalink
Check precision overflow for casting floating to decimal
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Sep 27, 2023
1 parent 4ef7917 commit fcfc957
Showing 1 changed file with 104 additions and 20 deletions.
124 changes: 104 additions & 20 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,21 +364,33 @@ where

if cast_options.safe {
array
.unary_opt::<_, Decimal128Type>(|v| (mul * v.as_()).round().to_i128())
.unary_opt::<_, Decimal128Type>(|v| {
(mul * v.as_()).round().to_i128().and_then(|v| {
(Decimal128Type::validate_decimal_precision(v, precision).is_ok())
.then_some(v)
})
})
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
array
.try_unary::<_, Decimal128Type, _>(|v| {
(mul * v.as_()).round().to_i128().ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {}({}, {}). Overflowing on {:?}",
Decimal128Type::PREFIX,
precision,
scale,
v
))
})
(mul * v.as_())
.round()
.to_i128()
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {}({}, {}). Overflowing on {:?}",
Decimal128Type::PREFIX,
precision,
scale,
v
))
})
.and_then(|v| {
Decimal128Type::validate_decimal_precision(v, precision)
.map(|_| v)
})
})?
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
Expand All @@ -398,21 +410,31 @@ where

if cast_options.safe {
array
.unary_opt::<_, Decimal256Type>(|v| i256::from_f64((v.as_() * mul).round()))
.unary_opt::<_, Decimal256Type>(|v| {
i256::from_f64((v.as_() * mul).round()).and_then(|v| {
(Decimal256Type::validate_decimal_precision(v, precision).is_ok())
.then_some(v)
})
})
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
array
.try_unary::<_, Decimal256Type, _>(|v| {
i256::from_f64((v.as_() * mul).round()).ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {}({}, {}). Overflowing on {:?}",
Decimal256Type::PREFIX,
precision,
scale,
v
))
})
i256::from_f64((v.as_() * mul).round())
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {}({}, {}). Overflowing on {:?}",
Decimal256Type::PREFIX,
precision,
scale,
v
))
})
.and_then(|v| {
Decimal256Type::validate_decimal_precision(v, precision)
.map(|_| v)
})
})?
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
Expand Down Expand Up @@ -7748,6 +7770,68 @@ mod tests {
assert!(casted_array.is_err());
}

#[test]
fn test_cast_floating_point_to_decimal128_precision_overflow() {
let array = Float64Array::from(vec![1.1]);
let array = Arc::new(array) as ArrayRef;
let casted_array = cast_with_options(
&array,
&DataType::Decimal128(2, 2),
&CastOptions {
safe: true,
format_options: FormatOptions::default(),
},
);
assert!(casted_array.is_ok());
assert!(casted_array.unwrap().is_null(0));

let casted_array = cast_with_options(
&array,
&DataType::Decimal128(2, 2),
&CastOptions {
safe: false,
format_options: FormatOptions::default(),
},
);
let err = casted_array.unwrap_err().to_string();
let expected_error = "Invalid argument error: 110 is too large to store in a Decimal128 of precision 2. Max is 99";
assert!(
err.contains(expected_error),
"did not find expected error '{expected_error}' in actual error '{err}'"
);
}

#[test]
fn test_cast_floating_point_to_decimal256_precision_overflow() {
let array = Float64Array::from(vec![1.1]);
let array = Arc::new(array) as ArrayRef;
let casted_array = cast_with_options(
&array,
&DataType::Decimal256(2, 2),
&CastOptions {
safe: true,
format_options: FormatOptions::default(),
},
);
assert!(casted_array.is_ok());
assert!(casted_array.unwrap().is_null(0));

let casted_array = cast_with_options(
&array,
&DataType::Decimal256(2, 2),
&CastOptions {
safe: false,
format_options: FormatOptions::default(),
},
);
let err = casted_array.unwrap_err().to_string();
let expected_error = "Invalid argument error: 110 is too large to store in a Decimal256 of precision 2. Max is 99";
assert!(
err.contains(expected_error),
"did not find expected error '{expected_error}' in actual error '{err}'"
);
}

#[test]
fn test_cast_floating_point_to_decimal128_overflow() {
let array = Float64Array::from(vec![f64::MAX]);
Expand Down

0 comments on commit fcfc957

Please sign in to comment.