Skip to content

Commit

Permalink
refactor: conversion error decimal variant (#26)
Browse files Browse the repository at this point in the history
# Rationale for this change

The ```ConversionError``` error type has grown quite large with the
addition of new features and types. This PR categorizes some timestamp
and decimal errors into different modules and reduces the size and
complexity of ```ConversionError```.

# What changes are included in this PR?

IntermediateDecimal, Decimal errors are moved to respective modules.
Native thiserr ```#from``` is used in place of manual ```From```
conversions.

# Are these changes tested?

yes
  • Loading branch information
Dustin-Ray authored Jun 25, 2024
1 parent 55816c3 commit 88f3d40
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 95 deletions.
43 changes: 20 additions & 23 deletions crates/proof-of-sql-parser/src/intermediate_decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
//!
//! A decimal must have a decimal point. The lexer does not route
//! whole integers to this contructor.
use crate::intermediate_decimal::IntermediateDecimalError::{LossyCast, OutOfRange, ParseError};
use bigdecimal::{num_bigint::BigInt, BigDecimal, ParseBigDecimalError, ToPrimitive};
use serde::{Deserialize, Serialize};
use std::{fmt, str::FromStr};
use thiserror::Error;

/// Errors related to the processing of decimal values in proof-of-sql
#[derive(Error, Debug, PartialEq)]
pub enum DecimalError {
pub enum IntermediateDecimalError {
/// Represents an error encountered during the parsing of a decimal string.
#[error(transparent)]
ParseError(#[from] ParseBigDecimalError),
Expand All @@ -27,6 +28,8 @@ pub enum DecimalError {
ConversionFailure,
}

impl Eq for IntermediateDecimalError {}

/// An intermediate placeholder for a decimal
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct IntermediateDecimal {
Expand Down Expand Up @@ -55,10 +58,10 @@ impl IntermediateDecimal {
&self,
precision: u8,
scale: i8,
) -> Result<BigInt, DecimalError> {
) -> Result<BigInt, IntermediateDecimalError> {
let scaled_decimal = self.value.with_scale(scale.into());
if scaled_decimal.digits() > precision.into() {
return Err(DecimalError::LossyCast);
return Err(LossyCast);
}
let (d, _) = scaled_decimal.into_bigint_and_exponent();
Ok(d)
Expand All @@ -72,14 +75,14 @@ impl fmt::Display for IntermediateDecimal {
}

impl FromStr for IntermediateDecimal {
type Err = DecimalError;
type Err = IntermediateDecimalError;

fn from_str(decimal_string: &str) -> Result<Self, Self::Err> {
BigDecimal::from_str(decimal_string)
.map(|value| IntermediateDecimal {
value: value.normalized(),
})
.map_err(DecimalError::ParseError)
.map_err(ParseError)
}
}

Expand All @@ -100,47 +103,47 @@ impl From<i64> for IntermediateDecimal {
}

impl TryFrom<&str> for IntermediateDecimal {
type Error = DecimalError;
type Error = IntermediateDecimalError;

fn try_from(s: &str) -> Result<Self, Self::Error> {
IntermediateDecimal::from_str(s)
}
}

impl TryFrom<String> for IntermediateDecimal {
type Error = DecimalError;
type Error = IntermediateDecimalError;

fn try_from(s: String) -> Result<Self, Self::Error> {
IntermediateDecimal::from_str(&s)
}
}

impl TryFrom<IntermediateDecimal> for i128 {
type Error = DecimalError;
type Error = IntermediateDecimalError;

fn try_from(decimal: IntermediateDecimal) -> Result<Self, Self::Error> {
if !decimal.value.is_integer() {
return Err(DecimalError::LossyCast);
return Err(LossyCast);
}

match decimal.value.to_i128() {
Some(value) if (i128::MIN..=i128::MAX).contains(&value) => Ok(value),
_ => Err(DecimalError::OutOfRange),
_ => Err(OutOfRange),
}
}
}

impl TryFrom<IntermediateDecimal> for i64 {
type Error = DecimalError;
type Error = IntermediateDecimalError;

fn try_from(decimal: IntermediateDecimal) -> Result<Self, Self::Error> {
if !decimal.value.is_integer() {
return Err(DecimalError::LossyCast);
return Err(LossyCast);
}

match decimal.value.to_i64() {
Some(value) if (i64::MIN..=i64::MAX).contains(&value) => Ok(value),
_ => Err(DecimalError::OutOfRange),
_ => Err(OutOfRange),
}
}
}
Expand Down Expand Up @@ -195,10 +198,7 @@ mod tests {
let overflow_decimal = IntermediateDecimal {
value: BigDecimal::from_str("170141183460469231731687303715884105728").unwrap(),
};
assert_eq!(
i128::try_from(overflow_decimal),
Err(DecimalError::OutOfRange)
);
assert_eq!(i128::try_from(overflow_decimal), Err(OutOfRange));

let valid_decimal_negative = IntermediateDecimal {
value: BigDecimal::from_str("-170141183460469231731687303715884105728").unwrap(),
Expand All @@ -211,7 +211,7 @@ mod tests {
let non_integer = IntermediateDecimal {
value: BigDecimal::from_str("100.5").unwrap(),
};
assert_eq!(i128::try_from(non_integer), Err(DecimalError::LossyCast));
assert_eq!(i128::try_from(non_integer), Err(LossyCast));
}

#[test]
Expand All @@ -229,10 +229,7 @@ mod tests {
let overflow_decimal = IntermediateDecimal {
value: BigDecimal::from_str("9223372036854775808").unwrap(),
};
assert_eq!(
i64::try_from(overflow_decimal),
Err(DecimalError::OutOfRange)
);
assert_eq!(i64::try_from(overflow_decimal), Err(OutOfRange));

let valid_decimal_negative = IntermediateDecimal {
value: BigDecimal::from_str("-9223372036854775808").unwrap(),
Expand All @@ -245,6 +242,6 @@ mod tests {
let non_integer = IntermediateDecimal {
value: BigDecimal::from_str("100.5").unwrap(),
};
assert_eq!(i64::try_from(non_integer), Err(DecimalError::LossyCast));
assert_eq!(i64::try_from(non_integer), Err(LossyCast));
}
}
80 changes: 61 additions & 19 deletions crates/proof-of-sql/src/base/math/decimal.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,51 @@
//! Module for parsing an `IntermediateDecimal` into a `Decimal75`.
use crate::{
base::scalar::Scalar,
sql::parse::{ConversionError, ConversionResult},
base::{
math::decimal::DecimalError::{
IntermediateDecimalConversionError, InvalidPrecision, RoundingError,
},
scalar::Scalar,
},
sql::parse::{
ConversionError::{self, DecimalConversionError},
ConversionResult,
},
};
use proof_of_sql_parser::intermediate_decimal::IntermediateDecimal;
use proof_of_sql_parser::intermediate_decimal::{IntermediateDecimal, IntermediateDecimalError};
use serde::{Deserialize, Deserializer, Serialize};
use thiserror::Error;

/// Errors related to decimal operations.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum DecimalError {
#[error("Invalid decimal format or value: {0}")]
/// Error when a decimal format or value is incorrect,
/// the string isn't even a decimal e.g. "notastring",
/// "-21.233.122" etc aka InvalidDecimal
InvalidDecimal(String),

#[error("Decimal precision is not valid: {0}")]
/// Decimal precision exceeds the allowed limit,
/// e.g. precision above 75/76/whatever set by Scalar
/// or non-positive aka InvalidPrecision
InvalidPrecision(String),

#[error("Unsupported operation: cannot round decimal: {0}")]
/// This error occurs when attempting to scale a
/// decimal in such a way that a loss of precision occurs.
RoundingError(String),

/// Errors that may occur when parsing an intermediate decimal
/// into a posql decimal
#[error("Intermediate decimal conversion error: {0}")]
IntermediateDecimalConversionError(IntermediateDecimalError),
}

impl From<IntermediateDecimalError> for ConversionError {
fn from(err: IntermediateDecimalError) -> ConversionError {
DecimalConversionError(IntermediateDecimalConversionError(err))
}
}

#[derive(Eq, PartialEq, Debug, Clone, Hash, Serialize, Copy)]
/// limit-enforced precision
Expand All @@ -15,10 +56,10 @@ impl Precision {
/// Constructor for creating a Precision instance
pub fn new(value: u8) -> Result<Self, ConversionError> {
if value > MAX_SUPPORTED_PRECISION || value == 0 {
Err(ConversionError::PrecisionParseError(format!(
Err(DecimalConversionError(InvalidPrecision(format!(
"Failed to parse precision. Value of {} exceeds max supported precision of {}",
value, MAX_SUPPORTED_PRECISION
)))
))))
} else {
Ok(Precision(value))
}
Expand Down Expand Up @@ -73,9 +114,9 @@ impl<S: Scalar> Decimal<S> {
) -> ConversionResult<Decimal<S>> {
let scale_factor = new_scale - self.scale;
if scale_factor < 0 || new_precision.value() < self.precision.value() + scale_factor as u8 {
return Err(ConversionError::DecimalRoundingError(
return Err(DecimalConversionError(RoundingError(
"Scale factor must be non-negative".to_string(),
));
)));
}
let scaled_value = scale_scalar(self.value, scale_factor)?;
Ok(Decimal::new(scaled_value, new_precision, new_scale))
Expand All @@ -86,14 +127,14 @@ impl<S: Scalar> Decimal<S> {
const MINIMAL_PRECISION: u8 = 19;
let raw_precision = precision.value();
if raw_precision < MINIMAL_PRECISION {
return Err(ConversionError::DecimalRoundingError(
return Err(DecimalConversionError(RoundingError(
"Precision must be at least 19".to_string(),
));
)));
}
if scale < 0 || raw_precision < MINIMAL_PRECISION + scale as u8 {
return Err(ConversionError::DecimalRoundingError(
return Err(DecimalConversionError(RoundingError(
"Can not scale down a decimal".to_string(),
));
)));
}
let scaled_value = scale_scalar(S::from(&value), scale)?;
Ok(Decimal::new(scaled_value, precision, scale))
Expand All @@ -104,14 +145,14 @@ impl<S: Scalar> Decimal<S> {
const MINIMAL_PRECISION: u8 = 39;
let raw_precision = precision.value();
if raw_precision < MINIMAL_PRECISION {
return Err(ConversionError::DecimalRoundingError(
return Err(DecimalConversionError(RoundingError(
"Precision must be at least 19".to_string(),
));
)));
}
if scale < 0 || raw_precision < MINIMAL_PRECISION + scale as u8 {
return Err(ConversionError::DecimalRoundingError(
return Err(DecimalConversionError(RoundingError(
"Can not scale down a decimal".to_string(),
));
)));
}
let scaled_value = scale_scalar(S::from(&value), scale)?;
Ok(Decimal::new(scaled_value, precision, scale))
Expand All @@ -132,8 +173,9 @@ impl<S: Scalar> Decimal<S> {
/// * `target_scale` - The scale (number of decimal places) to use in the scalar.
///
/// ## Errors
/// Returns `ConversionError::PrecisionParseError` if the number of digits in
/// the decimal exceeds the `target_precision` after adjusting for `target_scale`.
/// Returns `InvalidPrecision` error if the number of digits in
/// the decimal exceeds the `target_precision` before or after adjusting for
/// `target_scale`, or if the target precision is zero.
pub(crate) fn try_into_to_scalar<S: Scalar>(
d: &IntermediateDecimal,
target_precision: Precision,
Expand All @@ -147,9 +189,9 @@ pub(crate) fn try_into_to_scalar<S: Scalar>(
/// Note that we do not check for overflow.
pub(crate) fn scale_scalar<S: Scalar>(s: S, scale: i8) -> ConversionResult<S> {
if scale < 0 {
return Err(ConversionError::DecimalRoundingError(
return Err(DecimalConversionError(RoundingError(
"Scale factor must be non-negative".to_string(),
));
)));
}
let ten = S::from(10);
let mut res = s;
Expand Down
12 changes: 9 additions & 3 deletions crates/proof-of-sql/src/base/scalar/mont_scalar.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
use super::{scalar_conversion_to_int, Scalar, ScalarConversionError};
use crate::{base::math::decimal::MAX_SUPPORTED_PRECISION, sql::parse::ConversionError};
use crate::{
base::{
math::decimal::{DecimalError, MAX_SUPPORTED_PRECISION},
scalar::mont_scalar::DecimalError::InvalidDecimal,
},
sql::parse::{ConversionError, ConversionError::DecimalConversionError},
};
use ark_ff::{BigInteger, Field, Fp, Fp256, MontBackend, MontConfig, PrimeField};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use bytemuck::TransparentWrapper;
Expand Down Expand Up @@ -163,11 +169,11 @@ impl<T: MontConfig<4>> TryFrom<num_bigint::BigInt> for MontScalar<T> {

// Check if the number of digits exceeds the maximum precision allowed
if digits.len() > MAX_SUPPORTED_PRECISION.into() {
return Err(ConversionError::InvalidDecimal(format!(
return Err(DecimalConversionError(InvalidDecimal(format!(
"Attempted to parse a number with {} digits, which exceeds the max supported precision of {}",
digits.len(),
MAX_SUPPORTED_PRECISION
)));
))));
}

// Continue with the previous logic
Expand Down
19 changes: 15 additions & 4 deletions crates/proof-of-sql/src/sql/ast/comparison_util.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
use crate::{
base::{database::Column, math::decimal::Precision, scalar::Scalar},
sql::parse::{type_check_binary_operation, ConversionError, ConversionResult},
base::{
database::Column,
math::decimal::{DecimalError, Precision},
scalar::Scalar,
},
sql::{
ast::comparison_util::DecimalError::InvalidPrecision,
parse::{
type_check_binary_operation, ConversionError, ConversionError::DecimalConversionError,
ConversionResult,
},
},
};
use bumpalo::Bump;
use proof_of_sql_parser::intermediate_ast::BinaryOperator;
Expand Down Expand Up @@ -67,8 +77,9 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
rhs_precision_value + (max_scale - rhs_scale) as u8,
);
// Check if the precision is valid
let _max_precision = Precision::new(max_precision_value)
.map_err(|_| ConversionError::InvalidPrecision(max_precision_value as i16))?;
let _max_precision = Precision::new(max_precision_value).map_err(|_| {
DecimalConversionError(InvalidPrecision(max_precision_value.to_string()))
})?;
}
unchecked_subtract_impl(
alloc,
Expand Down
12 changes: 8 additions & 4 deletions crates/proof-of-sql/src/sql/ast/numerical_util.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use crate::{
base::{
database::{Column, ColumnType},
math::decimal::{scale_scalar, Precision},
math::decimal::{scale_scalar, DecimalError, Precision},
scalar::Scalar,
},
sql::parse::{ConversionError, ConversionResult},
sql::{
ast::numerical_util::DecimalError::InvalidPrecision,
parse::{ConversionError, ConversionError::DecimalConversionError, ConversionResult},
},
};
use bumpalo::Bump;

Expand Down Expand Up @@ -41,9 +44,10 @@ pub(crate) fn try_add_subtract_column_types(
.max(right_precision_value - right_scale as i16)
+ 1_i16;
let precision = u8::try_from(precision_value)
.map_err(|_| ConversionError::InvalidPrecision(precision_value))
.map_err(|_| DecimalConversionError(InvalidPrecision(precision_value.to_string())))
.and_then(|p| {
Precision::new(p).map_err(|_| ConversionError::InvalidPrecision(p as i16))
Precision::new(p)
.map_err(|_| DecimalConversionError(InvalidPrecision(p.to_string())))
})?;
Ok(ColumnType::Decimal75(precision, scale))
}
Expand Down
Loading

0 comments on commit 88f3d40

Please sign in to comment.