Skip to content

Commit

Permalink
Merge branch 'add-sqlparser' of https://github.com/varshith257/sxt-pr…
Browse files Browse the repository at this point in the history
…oof-of-sql into add-sqlparser
  • Loading branch information
varshith257 committed Oct 21, 2024
2 parents 9a07d3a + 92c4cf4 commit da03f0c
Show file tree
Hide file tree
Showing 11 changed files with 257 additions and 204 deletions.
6 changes: 4 additions & 2 deletions crates/proof-of-sql/src/base/database/column_operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ pub fn try_multiply_column_types(
let scale = left_scale.checked_add(right_scale).ok_or(
ColumnOperationError::DecimalConversionError {
source: DecimalError::InvalidScale {
scale: i16::from(left_scale) + i16::from(right_scale),
scale: (i16::from(left_scale) + i16::from(right_scale)).to_string(),
},
},
)?;
Expand Down Expand Up @@ -160,7 +160,9 @@ pub fn try_divide_column_types(
let precision_value: i16 = left_precision_value - left_scale + right_scale + raw_scale;
let scale =
i8::try_from(raw_scale).map_err(|_| ColumnOperationError::DecimalConversionError {
source: DecimalError::InvalidScale { scale: raw_scale },
source: DecimalError::InvalidScale {
scale: raw_scale.to_string(),
},
})?;
let precision = u8::try_from(precision_value)
.map_err(|_| ColumnOperationError::DecimalConversionError {
Expand Down
16 changes: 12 additions & 4 deletions crates/proof-of-sql/src/base/database/expression_evaluation.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use super::{ExpressionEvaluationError, ExpressionEvaluationResult};
use crate::base::{
database::{OwnedColumn, OwnedTable},
math::decimal::{try_into_to_scalar, Precision},
math::{
decimal::{try_convert_intermediate_decimal_to_scalar, DecimalError, Precision},
BigDecimalExt,
},
scalar::Scalar,
};
use alloc::{format, string::ToString, vec};
Expand Down Expand Up @@ -44,9 +47,14 @@ impl<S: Scalar> OwnedTable<S> {
Literal::BigInt(i) => Ok(OwnedColumn::BigInt(vec![*i; len])),
Literal::Int128(i) => Ok(OwnedColumn::Int128(vec![*i; len])),
Literal::Decimal(d) => {
let scale = d.scale();
let precision = Precision::new(d.precision())?;
let scalar = try_into_to_scalar(d, precision, scale)?;
let raw_scale = d.scale();
let scale = raw_scale
.try_into()
.map_err(|_| DecimalError::InvalidScale {
scale: raw_scale.to_string(),
})?;
let precision = Precision::try_from(d.precision())?;
let scalar = try_convert_intermediate_decimal_to_scalar(d, precision, scale)?;
Ok(OwnedColumn::Decimal75(precision, scale, vec![scalar; len]))
}
Literal::VarChar(s) => Ok(OwnedColumn::VarChar(vec![s.clone(); len])),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use crate::base::{
math::decimal::Precision,
scalar::Curve25519Scalar,
};
use bigdecimal::BigDecimal;
use proof_of_sql_parser::{
intermediate_ast::Literal,
intermediate_decimal::IntermediateDecimal,
posql_time::{PoSQLTimeUnit, PoSQLTimeZone, PoSQLTimestamp},
utility::*,
};
Expand Down Expand Up @@ -46,7 +46,7 @@ fn we_can_evaluate_a_simple_literal() {
assert_eq!(actual_column, expected_column);

// A group of people has about 0.67 cats per person
let expr = lit("0.67".parse::<IntermediateDecimal>().unwrap());
let expr = lit("0.67".parse::<BigDecimal>().unwrap());
let actual_column = table.evaluate(&expr).unwrap();
let expected_column = OwnedColumn::Decimal75(Precision::new(2).unwrap(), 2, vec![67.into(); 5]);
assert_eq!(actual_column, expected_column);
Expand Down Expand Up @@ -165,10 +165,7 @@ fn we_can_evaluate_an_arithmetic_expression() {
// Multiply decimals with 0.75 and add smallints to the product
let expr = add(
col("smallints"),
mul(
col("decimals"),
lit("0.75".parse::<IntermediateDecimal>().unwrap()),
),
mul(col("decimals"), lit("0.75".parse::<BigDecimal>().unwrap())),
);
let actual_column = table.evaluate(&expr).unwrap();
let expected_scalars = [-2000, -925, 150, 1225, 2300]
Expand All @@ -180,10 +177,7 @@ fn we_can_evaluate_an_arithmetic_expression() {

// Decimals over 2.5 plus int128s
let expr = add(
div(
col("decimals"),
lit("2.5".parse::<IntermediateDecimal>().unwrap()),
),
div(col("decimals"), lit("2.5".parse::<BigDecimal>().unwrap())),
col("int128s"),
);
let actual_column = table.evaluate(&expr).unwrap();
Expand Down
81 changes: 81 additions & 0 deletions crates/proof-of-sql/src/base/math/big_decimal_ext.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use super::decimal::{IntermediateDecimalError, IntermediateDecimalError::LossyCast};
use bigdecimal::BigDecimal;
use num_bigint::BigInt;

pub trait BigDecimalExt {
fn precision(&self) -> u64;
fn scale(&self) -> i64;
fn try_into_bigint_with_precision_and_scale(
&self,
precision: u8,
scale: i8,
) -> Result<BigInt, IntermediateDecimalError>;
}
impl BigDecimalExt for BigDecimal {
/// Get the precision of the fixed-point representation of this intermediate decimal.
#[must_use]
fn precision(&self) -> u64 {
self.normalized().digits()
}

/// Get the scale of the fixed-point representation of this intermediate decimal.
#[must_use]
fn scale(&self) -> i64 {
self.normalized().fractional_digit_count()
}

/// Attempts to convert the decimal to `BigInt` while adjusting it to the specified precision and scale.
/// Returns an error if the conversion cannot be performed due to precision or scale constraints.
///
/// # Errors
/// Returns an `IntermediateDecimalError::LossyCast` error if the number of digits in the scaled decimal exceeds the specified precision.
fn try_into_bigint_with_precision_and_scale(
&self,
precision: u8,
scale: i8,
) -> Result<BigInt, IntermediateDecimalError> {
if self.scale() > scale.into() {
Err(IntermediateDecimalError::ConversionFailure)?;
}
let scaled_decimal = self.normalized().with_scale(scale.into());
if scaled_decimal.digits() > precision.into() {
return Err(LossyCast);
}
let (d, _) = scaled_decimal.into_bigint_and_exponent();
Ok(d)
}
}

#[cfg(test)]
mod tests {
use super::*;
use alloc::string::ToString;

#[test]
fn test_valid_decimal_simple() {
let decimal = "123.45".parse::<BigDecimal>();
assert!(decimal.is_ok());
let unwrapped_decimal: BigDecimal = decimal.unwrap().normalized();
assert_eq!(unwrapped_decimal.to_string(), "123.45");
assert_eq!(unwrapped_decimal.precision(), 5);
assert_eq!(unwrapped_decimal.scale(), 2);
}

#[test]
fn test_valid_decimal_with_leading_and_trailing_zeros() {
let decimal = "000123.45000".parse::<BigDecimal>();
assert!(decimal.is_ok());
let unwrapped_decimal: BigDecimal = decimal.unwrap().normalized();
assert_eq!(unwrapped_decimal.to_string(), "123.45");
assert_eq!(unwrapped_decimal.precision(), 5);
assert_eq!(unwrapped_decimal.scale(), 2);
}

#[test]
fn test_accessors() {
let decimal: BigDecimal = "123.456".parse::<BigDecimal>().unwrap().normalized();
assert_eq!(decimal.to_string(), "123.456");
assert_eq!(decimal.precision(), 6);
assert_eq!(decimal.scale(), 3);
}
}
Loading

0 comments on commit da03f0c

Please sign in to comment.