diff --git a/crates/proof-of-sql-parser/src/identifier.rs b/crates/proof-of-sql-parser/src/identifier.rs index b11df6862..1b86bc702 100644 --- a/crates/proof-of-sql-parser/src/identifier.rs +++ b/crates/proof-of-sql-parser/src/identifier.rs @@ -37,6 +37,19 @@ impl Identifier { Self::from_str(string.as_ref()) } + /// An alias for [`Identifier::new`], provided for convenience. + pub fn new_valid>(string: S) -> Result { + IdentifierParser::new() + .parse(string.as_ref()) + .map(Identifier::new) // Use the internal new method for valid identifiers + .map_err(|e| ParseError::IdentifierParseError { + error: format!( + "Failed to parse identifier: {}. (reserved keyword or invalid format)", + e + ), + }) + } + /// The name of this [Identifier] /// It already implements [Deref] to [str], so this method is not necessary for most use cases. #[must_use] diff --git a/crates/proof-of-sql-parser/src/intermediate_ast.rs b/crates/proof-of-sql-parser/src/intermediate_ast.rs index d89696654..5856679e1 100644 --- a/crates/proof-of-sql-parser/src/intermediate_ast.rs +++ b/crates/proof-of-sql-parser/src/intermediate_ast.rs @@ -4,9 +4,8 @@ * https://docs.rs/vervolg/latest/vervolg/ast/enum.Statement.html ***/ -use crate::{posql_time::PoSQLTimestamp, Identifier}; +use crate::{intermediate_decimal::IntermediateDecimal, posql_time::PoSQLTimestamp, Identifier}; use alloc::{boxed::Box, string::String, vec::Vec}; -use bigdecimal::BigDecimal; use core::{ fmt, fmt::{Display, Formatter}, @@ -346,7 +345,7 @@ pub enum Literal { /// String Literal VarChar(String), /// Decimal Literal - Decimal(BigDecimal), + Decimal(IntermediateDecimal), /// Timestamp Literal Timestamp(PoSQLTimestamp), } @@ -396,8 +395,8 @@ macro_rules! impl_string_to_literal { impl_string_to_literal!(&str); impl_string_to_literal!(String); -impl From for Literal { - fn from(val: BigDecimal) -> Self { +impl From for Literal { + fn from(val: IntermediateDecimal) -> Self { Literal::Decimal(val) } } diff --git a/crates/proof-of-sql-parser/src/intermediate_ast_tests.rs b/crates/proof-of-sql-parser/src/intermediate_ast_tests.rs index 4574de18b..c3d1e570b 100644 --- a/crates/proof-of-sql-parser/src/intermediate_ast_tests.rs +++ b/crates/proof-of-sql-parser/src/intermediate_ast_tests.rs @@ -1,5 +1,6 @@ use crate::{ intermediate_ast::OrderByDirection::{Asc, Desc}, + intermediate_decimal::IntermediateDecimal, sql::*, utility::*, SelectStatement, @@ -9,7 +10,6 @@ use alloc::{ string::{String, ToString}, vec, }; -use bigdecimal::BigDecimal; // Sting parser tests #[test] @@ -143,7 +143,10 @@ fn we_can_parse_a_query_with_constants() { col_res(lit(3), "bigint"), col_res(lit(true), "boolean"), col_res(lit("proof"), "varchar"), - col_res(lit("-2.34".parse::().unwrap()), "decimal"), + col_res( + lit(IntermediateDecimal::try_from("-2.34").unwrap()), + "decimal", + ), ], tab(None, "sxt_tab"), vec![], @@ -217,7 +220,10 @@ fn we_can_parse_a_query_with_a_column_equals_a_decimal() { query( cols_res(&["a"]), tab(None, "sxt_tab"), - equal(col("a"), lit("-0.32".parse::().unwrap())), + equal( + col("a"), + lit(IntermediateDecimal::try_from("-0.32").unwrap()), + ), vec![], ), vec![], @@ -435,7 +441,10 @@ fn we_can_parse_a_query_with_one_logical_or_filter_expression() { tab(None, "sxt_tab"), or( equal(col("b"), lit(3)), - equal(col("c"), lit("-2.34".parse::().unwrap())), + equal( + col("c"), + lit(IntermediateDecimal::try_from("-2.34").unwrap()), + ), ), vec![], ), diff --git a/crates/proof-of-sql-parser/src/intermediate_decimal.rs b/crates/proof-of-sql-parser/src/intermediate_decimal.rs new file mode 100644 index 000000000..dd28f594d --- /dev/null +++ b/crates/proof-of-sql-parser/src/intermediate_decimal.rs @@ -0,0 +1,273 @@ +//! A parser conforming to standard postgreSQL to parse the precision and scale +//! from a decimal token obtained from the lalrpop lexer. This module +//! exists to resolve a cyclic dependency between proof-of-sql +//! and proof-of-sql-parser. +//! +//! A decimal must have a decimal point. The lexer does not route +//! whole integers to this contructor. +use crate::intermediate_decimal::IntermediateDecimalError::{LossyCast, OutOfRange, ParseError}; +use alloc::string::String; +use bigdecimal::{num_bigint::BigInt, BigDecimal, ParseBigDecimalError, ToPrimitive}; +use core::{fmt, hash::Hash, str::FromStr}; +use serde::{Deserialize, Serialize}; +use snafu::Snafu; + +/// Errors related to the processing of decimal values in proof-of-sql +#[allow(clippy::module_name_repetitions)] +#[derive(Snafu, Debug, PartialEq)] +pub enum IntermediateDecimalError { + /// Represents an error encountered during the parsing of a decimal string. + #[snafu(display("{error}"))] + ParseError { + /// The underlying error + error: ParseBigDecimalError, + }, + /// Error occurs when this decimal cannot fit in a primitive. + #[snafu(display("Value out of range for target type"))] + OutOfRange, + /// Error occurs when this decimal cannot be losslessly cast into a primitive. + #[snafu(display("Fractional part of decimal is non-zero"))] + LossyCast, + /// Cannot cast this decimal to a big integer + #[snafu(display("Conversion to integer failed"))] + ConversionFailure, +} +impl From for IntermediateDecimalError { + fn from(value: ParseBigDecimalError) -> Self { + IntermediateDecimalError::ParseError { error: value } + } +} + +impl Eq for IntermediateDecimalError {} + +/// An intermediate placeholder for a decimal +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Hash)] +pub struct IntermediateDecimal { + value: BigDecimal, +} + +impl IntermediateDecimal { + /// Get the integer part of the fixed-point representation of this intermediate decimal. + #[must_use] + pub fn value(&self) -> BigDecimal { + self.value.clone() + } + + /// Get the precision of the fixed-point representation of this intermediate decimal. + #[must_use] + pub fn precision(&self) -> u8 { + match u8::try_from(self.value.digits()) { + Ok(v) => v, + Err(_) => u8::MAX, // Returning u8::MAX on truncation + } + } + + /// Get the scale of the fixed-point representation of this intermediate decimal. + #[must_use] + pub fn scale(&self) -> i8 { + match i8::try_from(self.value.fractional_digit_count()) { + Ok(v) => v, + Err(_) => i8::MAX, // Returning i8::MAX on truncation + } + } + + /// Attempts to convert the decimal to `BigInt` while adjusting it to the specified precision and scale. + /// Returns an error if the conversion cannot be performed due to precision or scale constraints. + /// + /// # Errors + /// Returns an `IntermediateDecimalError::LossyCast` error if the number of digits in the scaled decimal exceeds the specified precision. + pub fn try_into_bigint_with_precision_and_scale( + &self, + precision: u8, + scale: i8, + ) -> Result { + let scaled_decimal = self.value.with_scale(scale.into()); + if scaled_decimal.digits() > precision.into() { + return Err(LossyCast); + } + let (d, _) = scaled_decimal.into_bigint_and_exponent(); + Ok(d) + } +} + +impl fmt::Display for IntermediateDecimal { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.value) + } +} + +impl FromStr for IntermediateDecimal { + type Err = IntermediateDecimalError; + + fn from_str(decimal_string: &str) -> Result { + BigDecimal::from_str(decimal_string) + .map(|value| IntermediateDecimal { + value: value.normalized(), + }) + .map_err(|err| ParseError { error: err }) + } +} + +impl From for IntermediateDecimal { + fn from(value: i128) -> Self { + IntermediateDecimal { + value: BigDecimal::from(value), + } + } +} + +impl From for IntermediateDecimal { + fn from(value: i64) -> Self { + IntermediateDecimal { + value: BigDecimal::from(value), + } + } +} + +impl TryFrom<&str> for IntermediateDecimal { + type Error = IntermediateDecimalError; + + fn try_from(s: &str) -> Result { + IntermediateDecimal::from_str(s) + } +} + +impl TryFrom for IntermediateDecimal { + type Error = IntermediateDecimalError; + + fn try_from(s: String) -> Result { + IntermediateDecimal::from_str(&s) + } +} + +impl TryFrom for i128 { + type Error = IntermediateDecimalError; + + fn try_from(decimal: IntermediateDecimal) -> Result { + if !decimal.value.is_integer() { + return Err(LossyCast); + } + + match decimal.value.to_i128() { + Some(value) if (i128::MIN..=i128::MAX).contains(&value) => Ok(value), + _ => Err(OutOfRange), + } + } +} + +impl TryFrom for i64 { + type Error = IntermediateDecimalError; + + fn try_from(decimal: IntermediateDecimal) -> Result { + if !decimal.value.is_integer() { + return Err(LossyCast); + } + + match decimal.value.to_i64() { + Some(value) if (i64::MIN..=i64::MAX).contains(&value) => Ok(value), + _ => Err(OutOfRange), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use alloc::string::ToString; + + #[test] + fn test_valid_decimal_simple() { + let decimal = "123.45".parse(); + assert!(decimal.is_ok()); + let unwrapped_decimal: IntermediateDecimal = decimal.unwrap(); + assert_eq!(unwrapped_decimal.to_string(), "123.45"); + assert_eq!(unwrapped_decimal.precision(), 5); + assert_eq!(unwrapped_decimal.scale(), 2); + } + + #[test] + fn test_valid_decimal_with_leading_and_trailing_zeros() { + let decimal = "000123.45000".parse(); + assert!(decimal.is_ok()); + let unwrapped_decimal: IntermediateDecimal = decimal.unwrap(); + assert_eq!(unwrapped_decimal.to_string(), "123.45"); + assert_eq!(unwrapped_decimal.precision(), 5); + assert_eq!(unwrapped_decimal.scale(), 2); + } + + #[test] + fn test_accessors() { + let decimal: IntermediateDecimal = "123.456".parse().unwrap(); + assert_eq!(decimal.to_string(), "123.456"); + assert_eq!(decimal.precision(), 6); + assert_eq!(decimal.scale(), 3); + } + + #[test] + fn test_conversion_to_i128() { + let valid_decimal = IntermediateDecimal { + value: BigDecimal::from_str("170141183460469231731687303715884105727").unwrap(), + }; + assert_eq!( + i128::try_from(valid_decimal), + Ok(170_141_183_460_469_231_731_687_303_715_884_105_727_i128) + ); + + let valid_decimal = IntermediateDecimal { + value: BigDecimal::from_str("123.000").unwrap(), + }; + assert_eq!(i128::try_from(valid_decimal), Ok(123)); + + let overflow_decimal = IntermediateDecimal { + value: BigDecimal::from_str("170141183460469231731687303715884105728").unwrap(), + }; + assert_eq!(i128::try_from(overflow_decimal), Err(OutOfRange)); + + let valid_decimal_negative = IntermediateDecimal { + value: BigDecimal::from_str("-170141183460469231731687303715884105728").unwrap(), + }; + assert_eq!( + i128::try_from(valid_decimal_negative), + Ok(-170_141_183_460_469_231_731_687_303_715_884_105_728_i128) + ); + + let non_integer = IntermediateDecimal { + value: BigDecimal::from_str("100.5").unwrap(), + }; + assert_eq!(i128::try_from(non_integer), Err(LossyCast)); + } + + #[test] + fn test_conversion_to_i64() { + let valid_decimal = IntermediateDecimal { + value: BigDecimal::from_str("9223372036854775807").unwrap(), + }; + assert_eq!( + i64::try_from(valid_decimal), + Ok(9_223_372_036_854_775_807_i64) + ); + + let valid_decimal = IntermediateDecimal { + value: BigDecimal::from_str("123.000").unwrap(), + }; + assert_eq!(i64::try_from(valid_decimal), Ok(123)); + + let overflow_decimal = IntermediateDecimal { + value: BigDecimal::from_str("9223372036854775808").unwrap(), + }; + assert_eq!(i64::try_from(overflow_decimal), Err(OutOfRange)); + + let valid_decimal_negative = IntermediateDecimal { + value: BigDecimal::from_str("-9223372036854775808").unwrap(), + }; + assert_eq!( + i64::try_from(valid_decimal_negative), + Ok(-9_223_372_036_854_775_808_i64) + ); + + let non_integer = IntermediateDecimal { + value: BigDecimal::from_str("100.5").unwrap(), + }; + assert_eq!(i64::try_from(non_integer), Err(LossyCast)); + } +} diff --git a/crates/proof-of-sql-parser/src/lib.rs b/crates/proof-of-sql-parser/src/lib.rs index a600d6b97..b92937815 100644 --- a/crates/proof-of-sql-parser/src/lib.rs +++ b/crates/proof-of-sql-parser/src/lib.rs @@ -3,6 +3,8 @@ #![cfg_attr(test, allow(clippy::missing_panics_doc))] extern crate alloc; +/// Module for handling an intermediate decimal type received from the lexer. +pub mod intermediate_decimal; /// Module for handling an intermediate timestamp type received from the lexer. pub mod posql_time; #[macro_use] diff --git a/crates/proof-of-sql-parser/src/sql.lalrpop b/crates/proof-of-sql-parser/src/sql.lalrpop index ce6dca007..3074833b9 100644 --- a/crates/proof-of-sql-parser/src/sql.lalrpop +++ b/crates/proof-of-sql-parser/src/sql.lalrpop @@ -2,12 +2,11 @@ use crate::intermediate_ast; use crate::select_statement; use crate::identifier; use lalrpop_util::ParseError::User; -use crate::posql_time::PoSQLTimestamp; +use crate::{intermediate_decimal::IntermediateDecimal, posql_time::PoSQLTimestamp}; use alloc::boxed::Box; use alloc::string::String; use alloc::vec; use alloc::vec::Vec; -use bigdecimal::BigDecimal; grammar; @@ -357,8 +356,8 @@ Int128UnaryNumericLiteral: i128 = { "-" =>? expr.checked_neg().ok_or(User {error: "Integer overflow"}), }; -DecimalNumericLiteral: BigDecimal = { - =>? <>.parse::().map_err(|e| User {error: "decimal out of range"}), +DecimalNumericLiteral: IntermediateDecimal = { + =>? IntermediateDecimal::try_from(lit).map_err(|e| User {error: "decimal out of range"}), }; Int128NumericLiteral: i128 = { diff --git a/crates/proof-of-sql/src/sql/parse/error.rs b/crates/proof-of-sql/src/sql/parse/error.rs index c608093dc..462a6a4ee 100644 --- a/crates/proof-of-sql/src/sql/parse/error.rs +++ b/crates/proof-of-sql/src/sql/parse/error.rs @@ -139,6 +139,26 @@ pub enum ConversionError { /// The underlying error error: String, }, + + #[snafu(display("Unsupported query type encountered"))] + /// Unsupported SQL query type in the AST + UnsupportedQueryType, + + #[snafu(display("Invalid table found in SQL query"))] + /// Invalid table reference in SQL query + InvalidTable, + + #[snafu(display("Unsupported SQL operator encountered"))] + /// Unsupported SQL operator in the query + UnsupportedOperator, + + #[snafu(display("Invalid projection in the SELECT clause"))] + /// Invalid projection found in SQL query + InvalidProjection, + + #[snafu(display("Unsupported SQL expression encountered"))] + /// Unsupported SQL expression in the query + UnsupportedExpression } impl From for ConversionError { diff --git a/crates/proof-of-sql/src/sql/parse/query_expr.rs b/crates/proof-of-sql/src/sql/parse/query_expr.rs index 91fb64135..b0ff84503 100644 --- a/crates/proof-of-sql/src/sql/parse/query_expr.rs +++ b/crates/proof-of-sql/src/sql/parse/query_expr.rs @@ -2,7 +2,7 @@ use super::{EnrichedExpr, FilterExecBuilder, QueryContextBuilder}; use crate::{ base::{commitment::Commitment, database::SchemaAccessor}, sql::{ - parse::ConversionResult, + parse::{ConversionError, ConversionResult}, postprocessing::{ GroupByPostprocessing, OrderByPostprocessing, OwnedTablePostprocessing, SelectPostprocessing, SlicePostprocessing, @@ -11,8 +11,15 @@ use crate::{ }, }; use alloc::{fmt, vec, vec::Vec}; -use proof_of_sql_parser::{intermediate_ast::SetExpression, Identifier, SelectStatement}; +use proof_of_sql_parser::{ + intermediate_ast::{ + AliasedResultExpr, BinaryOperator, Expression, SelectResultExpr, SetExpression, + TableExpression, + }, + Identifier, SelectStatement, +}; use serde::{Deserialize, Serialize}; +use sqlparser::ast::GroupByExpr; #[derive(PartialEq, Serialize, Deserialize)] /// A `QueryExpr` represents a Proof of SQL query that can be executed against a database. @@ -158,6 +165,94 @@ impl QueryExpr { } } + pub fn try_new_from_sqlparser( + ast: sqlparser::ast::Query, + default_schema: Identifier, + schema_accessor: &dyn SchemaAccessor, + ) -> ConversionResult { + // Extract the main components from the SQLParser AST + let query_body = match *ast.body { + sqlparser::ast::SetExpr::Select(select_stmt) => select_stmt, + _ => return Err(ConversionError::UnsupportedQueryType), + }; + + // Convert SQL AST components (SELECT, WHERE, etc.) into Proof of SQL structures + let from_clause = query_body + .from + .iter() + .map(|table| Self::convert_sql_table_to_proof_of_sql_table(table)) + .collect::, _>>()?; + + let result_exprs = query_body + .projection + .iter() + .map(|proj| Self::convert_sql_projection_to_proof_of_sql(proj)) + .collect::, _>>()?; + + let where_expr = query_body + .selection + .map(|expr| Self::convert_sql_expr_to_proof_of_sql(&expr)) + .transpose()?; + + let group_by = match query_body.group_by { + GroupByExpr::Expressions(exprs) => { + let mut identifiers = Vec::new(); + for group_by_expr in exprs { + match group_by_expr { + // If the expression is an identifier, attempt to create a valid Identifier + sqlparser::ast::Expr::Identifier(ident) => { + match Identifier::new_valid(ident.value.clone()) { + Ok(valid_ident) => identifiers.push(valid_ident), + Err(e) => { + return Err(ConversionError::ParseError { + error: format!("ParseError: {:?}", e), + }); + } + } + } + // Handle non-Identifier expressions in GROUP BY clause + _ => { + return Err(ConversionError::InvalidGroupByColumnRef { + column: format!( + "Expected identifier, found expression {:?}", + group_by_expr + ), + }); + } + } + } + identifiers + } + GroupByExpr::All => Vec::new(), + }; + + // Build a QueryContext using the Proof of SQL structures + let context = QueryContextBuilder::new(schema_accessor) + .visit_table_expr(&from_clause, default_schema) + .visit_group_by_exprs(group_by)? + .visit_result_exprs(result_exprs)? + .visit_where_expr(where_expr)? + .visit_order_by_exprs(ast.order_by) + .visit_slice_expr(ast.slice) + .build()?; + + // Create and return the QueryExpr with proof_expr and postprocessing + let enriched_exprs = context.get_aliased_result_exprs()?.to_vec(); + + // Build the FilterExec object + let filter = FilterExecBuilder::new(context.get_column_mapping()) + .add_table_expr(*context.get_table_ref()) + .add_where_expr(context.get_where_expr().clone())? + .add_result_columns(&enriched_exprs) + .build(); + + // Return the QueryExpr + Ok(QueryExpr { + proof_expr: DynProofPlan::Filter(filter), + postprocessing: vec![], + }) + } + /// Immutable access to this query's provable filter expression. pub fn proof_expr(&self) -> &DynProofPlan { &self.proof_expr @@ -167,4 +262,61 @@ impl QueryExpr { pub fn postprocessing(&self) -> &[OwnedTablePostprocessing] { &self.postprocessing } + + fn convert_sql_table_to_proof_of_sql_table( + sql_table: &sqlparser::ast::TableWithJoins, + ) -> Result, ConversionError> { + // Convert SQL table reference to Proof of SQL's TableExpression + match &sql_table.relation { + sqlparser::ast::TableFactor::Table { name, .. } => { + let schema = name + .0 + .get(0) + .map(|ident| Identifier::new_valid(ident.value.clone())); + let table = Identifier::new_valid(name.0.get(1).unwrap().value.clone()); + Ok(Box::new(TableExpression::Named { table, schema })) + } + _ => Err(ConversionError::InvalidTable), + } + } + + fn convert_sql_projection_to_proof_of_sql( + projection: &sqlparser::ast::SelectItem, + ) -> Result { + match projection { + sqlparser::ast::SelectItem::UnnamedExpr(expr) => { + Ok(SelectResultExpr::AliasedResultExpr(AliasedResultExpr { + expr: Box::new(convert_sql_expr_to_proof_of_sql(expr)?), + alias: Identifier::try_new("alias")?, + })) + } + _ => Err(ConversionError::InvalidProjection), + } + } + + fn convert_sql_expr_to_proof_of_sql( + expr: &sqlparser::ast::Expr, + ) -> Result, ConversionError> { + match expr { + sqlparser::ast::Expr::Identifier(ident) => Ok(Box::new(Expression::Column( + Identifier::new_valid(ident.value.clone()), + ))), + sqlparser::ast::Expr::BinaryOp { left, op, right } => { + let left_expr = convert_sql_expr_to_proof_of_sql(left)?; + let right_expr = convert_sql_expr_to_proof_of_sql(right)?; + let op = match op { + sqlparser::ast::BinaryOperator::Eq => BinaryOperator::Equal, + sqlparser::ast::BinaryOperator::Gt => BinaryOperator::GreaterThanOrEqual, + sqlparser::ast::BinaryOperator::Lt => BinaryOperator::LessThanOrEqual, + _ => return Err(ConversionError::UnsupportedOperator), + }; + Ok(Box::new(Expression::Binary { + left: left_expr, + right: right_expr, + op, + })) + } + _ => Err(ConversionError::UnsupportedExpression), + } + } }