diff --git a/.github/workflows/lint-and-test.yml b/.github/workflows/lint-and-test.yml index 9a9b1828a..fb92317ad 100644 --- a/.github/workflows/lint-and-test.yml +++ b/.github/workflows/lint-and-test.yml @@ -182,7 +182,7 @@ jobs: # files: lcov.info # fail_ci_if_error: true - # Run cargo fmt --all -- --config imports_granularity=Crate,group_imports=One --check + # Run cargo fmt --all -- --check format: name: Format runs-on: ubuntu-latest @@ -194,7 +194,7 @@ jobs: curl https://sh.rustup.rs -sSf | bash -s -- -y --profile minimal && source ~/.cargo/env rustup component add rustfmt - name: Run cargo fmt - run: cargo fmt --all -- --config imports_granularity=Crate,group_imports=One --check + run: cargo fmt --all -- --check udeps: name: Unused Dependencies diff --git a/crates/proof-of-sql/src/base/database/columnar_value.rs b/crates/proof-of-sql/src/base/database/columnar_value.rs index 0505044b6..ce333216f 100644 --- a/crates/proof-of-sql/src/base/database/columnar_value.rs +++ b/crates/proof-of-sql/src/base/database/columnar_value.rs @@ -1,8 +1,9 @@ use crate::base::{ - database::{Column, ColumnType, LiteralValue}, + database::{Column, ColumnOperationError, ColumnOperationResult, ColumnType, LiteralValue}, scalar::Scalar, }; use bumpalo::Bump; +use proof_of_sql_parser::intermediate_ast::{BinaryOperator, UnaryOperator}; use snafu::Snafu; /// The result of evaluating an expression. @@ -37,6 +38,14 @@ impl<'a, S: Scalar> ColumnarValue<'a, S> { } } + /// Get default length of the [`ColumnarValue`] + pub fn default_length(&self) -> usize { + match self { + Self::Column(column) => column.len(), + Self::Literal(_) => 1, + } + } + /// Converts the [`ColumnarValue`] to a [`Column`] pub fn into_column( &self, @@ -59,6 +68,88 @@ impl<'a, S: Scalar> ColumnarValue<'a, S> { } } } + + /// Applies a unary operator to a [`ColumnarValue`]. + pub(crate) fn apply_boolean_unary_operator( + &self, + op: UnaryOperator, + alloc: &'a Bump, + ) -> ColumnOperationResult> { + match (self, op) { + (ColumnarValue::Literal(LiteralValue::Boolean(value)), UnaryOperator::Not) => { + Ok(ColumnarValue::Literal(LiteralValue::Boolean(!value))) + } + (ColumnarValue::Column(Column::Boolean(column)), UnaryOperator::Not) => { + Ok(ColumnarValue::Column(Column::Boolean( + alloc.alloc_slice_fill_with(column.len(), |i| !column[i]), + ))) + } + _ => Err(ColumnOperationError::UnaryOperationInvalidColumnType { + operator: op, + operand_type: self.column_type(), + }), + } + } + + /// Applies a binary operator to two [`ColumnarValue`]s. + pub(crate) fn apply_boolean_binary_operator( + &self, + rhs: &Self, + op: BinaryOperator, + alloc: &'a Bump, + ) -> ColumnOperationResult> { + let op_fn = match op { + BinaryOperator::And => |lhs, rhs| lhs && rhs, + BinaryOperator::Or => |lhs, rhs| lhs || rhs, + _ => { + return Err(ColumnOperationError::BinaryOperationInvalidColumnType { + operator: op, + left_type: self.column_type(), + right_type: rhs.column_type(), + }) + } + }; + match (self, rhs) { + ( + ColumnarValue::Literal(LiteralValue::Boolean(lhs)), + ColumnarValue::Literal(LiteralValue::Boolean(rhs)), + ) => Ok(ColumnarValue::Literal(LiteralValue::Boolean(op_fn( + *lhs, *rhs, + )))), + ( + ColumnarValue::Column(Column::Boolean(lhs)), + ColumnarValue::Literal(LiteralValue::Boolean(rhs)), + ) => Ok(ColumnarValue::Column(Column::Boolean( + alloc.alloc_slice_fill_with(lhs.len(), |i| op_fn(lhs[i], *rhs)), + ))), + ( + ColumnarValue::Literal(LiteralValue::Boolean(lhs)), + ColumnarValue::Column(Column::Boolean(rhs)), + ) => Ok(ColumnarValue::Column(Column::Boolean( + alloc.alloc_slice_fill_with(rhs.len(), |i| op_fn(*lhs, rhs[i])), + ))), + ( + ColumnarValue::Column(Column::Boolean(lhs)), + ColumnarValue::Column(Column::Boolean(rhs)), + ) => { + let len = lhs.len(); + if len != rhs.len() { + return Err(ColumnOperationError::DifferentColumnLength { + len_a: len, + len_b: rhs.len(), + }); + } + Ok(ColumnarValue::Column(Column::Boolean( + alloc.alloc_slice_fill_with(len, |i| op_fn(lhs[i], rhs[i])), + ))) + } + _ => Err(ColumnOperationError::BinaryOperationInvalidColumnType { + operator: op, + left_type: self.column_type(), + right_type: rhs.column_type(), + }), + } + } } #[cfg(test)] diff --git a/crates/proof-of-sql/src/sql/proof/proof_plan.rs b/crates/proof-of-sql/src/sql/proof/proof_plan.rs index 42ceceab1..d7770049c 100644 --- a/crates/proof-of-sql/src/sql/proof/proof_plan.rs +++ b/crates/proof-of-sql/src/sql/proof/proof_plan.rs @@ -55,7 +55,6 @@ pub trait ProverEvaluate { /// Evaluate the query and modify `FirstRoundBuilder` to track the result of the query. fn result_evaluate<'a>( &self, - input_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, ) -> Vec>; diff --git a/crates/proof-of-sql/src/sql/proof/query_proof.rs b/crates/proof-of-sql/src/sql/proof/query_proof.rs index 62f3ada00..c2c2385cb 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof.rs @@ -55,7 +55,7 @@ impl QueryProof { let alloc = Bump::new(); // Evaluate query result - let result_cols = expr.result_evaluate(table_length, &alloc, accessor); + let result_cols = expr.result_evaluate(&alloc, accessor); let output_length = result_cols.first().map_or(0, Column::len); let provable_result = ProvableQueryResult::new(output_length as u64, &result_cols); diff --git a/crates/proof-of-sql/src/sql/proof/query_proof_test.rs b/crates/proof-of-sql/src/sql/proof/query_proof_test.rs index e6e685673..6571fc514 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof_test.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof_test.rs @@ -43,7 +43,6 @@ impl Default for TrivialTestProofPlan { impl ProverEvaluate for TrivialTestProofPlan { fn result_evaluate<'a>( &self, - _input_length: usize, alloc: &'a Bump, _accessor: &'a dyn DataAccessor, ) -> Vec> { @@ -203,7 +202,6 @@ impl Default for SquareTestProofPlan { impl ProverEvaluate for SquareTestProofPlan { fn result_evaluate<'a>( &self, - _table_length: usize, alloc: &'a Bump, _accessor: &'a dyn DataAccessor, ) -> Vec> { @@ -388,7 +386,6 @@ impl Default for DoubleSquareTestProofPlan { impl ProverEvaluate for DoubleSquareTestProofPlan { fn result_evaluate<'a>( &self, - _input_length: usize, alloc: &'a Bump, _accessor: &'a dyn DataAccessor, ) -> Vec> { @@ -603,7 +600,6 @@ struct ChallengeTestProofPlan {} impl ProverEvaluate for ChallengeTestProofPlan { fn result_evaluate<'a>( &self, - _input_length: usize, _alloc: &'a Bump, _accessor: &'a dyn DataAccessor, ) -> Vec> { diff --git a/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test.rs b/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test.rs index d2db5df0e..889a8229f 100644 --- a/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test.rs +++ b/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test.rs @@ -27,7 +27,6 @@ pub(super) struct EmptyTestQueryExpr { impl ProverEvaluate for EmptyTestQueryExpr { fn result_evaluate<'a>( &self, - _input_length: usize, alloc: &'a Bump, _accessor: &'a dyn DataAccessor, ) -> Vec> { diff --git a/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr.rs index 10a623f85..6f8b32c94 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr.rs @@ -1,10 +1,13 @@ -use super::{add_subtract_columns, scale_and_add_subtract_eval, DynProofExpr, ProofExpr}; +use super::{ + add_subtract_columnar_values, add_subtract_columns, scale_and_add_subtract_eval, DynProofExpr, + ProofExpr, +}; use crate::{ base::{ commitment::Commitment, database::{ - try_add_subtract_column_types, Column, ColumnRef, ColumnType, CommitmentAccessor, - DataAccessor, + try_add_subtract_column_types, Column, ColumnRef, ColumnType, ColumnarValue, + CommitmentAccessor, DataAccessor, }, map::IndexSet, proof::ProofError, @@ -54,22 +57,19 @@ impl ProofExpr for AddSubtractExpr { fn result_evaluate<'a>( &self, - table_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, - ) -> Column<'a, C::Scalar> { - let lhs_column: Column<'a, C::Scalar> = - self.lhs.result_evaluate(table_length, alloc, accessor); - let rhs_column: Column<'a, C::Scalar> = - self.rhs.result_evaluate(table_length, alloc, accessor); - Column::Scalar(add_subtract_columns( - lhs_column, - rhs_column, + ) -> ColumnarValue<'a, C::Scalar> { + let lhs: ColumnarValue<'a, C::Scalar> = self.lhs.result_evaluate(alloc, accessor); + let rhs: ColumnarValue<'a, C::Scalar> = self.rhs.result_evaluate(alloc, accessor); + add_subtract_columnar_values( + lhs, + rhs, self.lhs.data_type().scale().unwrap_or(0), self.rhs.data_type().scale().unwrap_or(0), alloc, self.is_subtract, - )) + ) } #[tracing::instrument( diff --git a/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr_test.rs b/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr_test.rs index bb266bb38..04fdb0c3f 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr_test.rs @@ -1,7 +1,7 @@ use crate::{ base::{ commitment::InnerProductProof, - database::{owned_table_utility::*, Column, OwnedTableTestAccessor}, + database::{owned_table_utility::*, Column, ColumnarValue, OwnedTableTestAccessor}, scalar::Curve25519Scalar, }, sql::{ @@ -317,11 +317,11 @@ fn we_can_compute_the_correct_output_of_an_add_subtract_expr_using_result_evalua subtract(column(t, "a", &accessor), const_bigint(1)), ); let alloc = Bump::new(); - let res = add_subtract_expr.result_evaluate(4, &alloc, &accessor); + let res = add_subtract_expr.result_evaluate(&alloc, &accessor); let expected_res_scalar = [0, 2, 2, 4] .iter() .map(|v| Curve25519Scalar::from(*v)) .collect::>(); - let expected_res = Column::Scalar(&expected_res_scalar); + let expected_res = ColumnarValue::Column(Column::Scalar(&expected_res_scalar)); assert_eq!(res, expected_res); } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/aggregate_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/aggregate_expr.rs index d11c157b5..ad2fa1315 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/aggregate_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/aggregate_expr.rs @@ -2,7 +2,9 @@ use super::{DynProofExpr, ProofExpr}; use crate::{ base::{ commitment::Commitment, - database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, + database::{ + Column, ColumnRef, ColumnType, ColumnarValue, CommitmentAccessor, DataAccessor, + }, map::IndexSet, proof::ProofError, }, @@ -45,11 +47,10 @@ impl ProofExpr for AggregateExpr { #[tracing::instrument(name = "AggregateExpr::result_evaluate", level = "debug", skip_all)] fn result_evaluate<'a>( &self, - table_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, - ) -> Column<'a, C::Scalar> { - self.expr.result_evaluate(table_length, alloc, accessor) + ) -> ColumnarValue<'a, C::Scalar> { + self.expr.result_evaluate(alloc, accessor) } #[tracing::instrument(name = "AggregateExpr::prover_evaluate", level = "debug", skip_all)] diff --git a/crates/proof-of-sql/src/sql/proof_exprs/and_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/and_expr.rs index 48eff8997..2576aa690 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/and_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/and_expr.rs @@ -2,7 +2,9 @@ use super::{DynProofExpr, ProofExpr}; use crate::{ base::{ commitment::Commitment, - database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, + database::{ + Column, ColumnRef, ColumnType, ColumnarValue, CommitmentAccessor, DataAccessor, + }, map::IndexSet, proof::ProofError, }, @@ -11,6 +13,7 @@ use crate::{ use alloc::{boxed::Box, vec}; use bumpalo::Bump; use num_traits::One; +use proof_of_sql_parser::intermediate_ast::BinaryOperator; use serde::{Deserialize, Serialize}; /// Provable logical AND expression @@ -44,17 +47,16 @@ impl ProofExpr for AndExpr { #[tracing::instrument(name = "AndExpr::result_evaluate", level = "debug", skip_all)] fn result_evaluate<'a>( &self, - table_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, - ) -> Column<'a, C::Scalar> { - let lhs_column: Column<'a, C::Scalar> = - self.lhs.result_evaluate(table_length, alloc, accessor); - let rhs_column: Column<'a, C::Scalar> = - self.rhs.result_evaluate(table_length, alloc, accessor); - let lhs = lhs_column.as_boolean().expect("lhs is not boolean"); - let rhs = rhs_column.as_boolean().expect("rhs is not boolean"); - Column::Boolean(alloc.alloc_slice_fill_with(table_length, |i| lhs[i] && rhs[i])) + ) -> ColumnarValue<'a, C::Scalar> { + let lhs_columnar_value: ColumnarValue<'a, C::Scalar> = + self.lhs.result_evaluate(alloc, accessor); + let rhs_columnar_value: ColumnarValue<'a, C::Scalar> = + self.rhs.result_evaluate(alloc, accessor); + lhs_columnar_value + .apply_boolean_binary_operator(&rhs_columnar_value, BinaryOperator::And, alloc) + .expect("Failed to apply boolean binary operator") } #[tracing::instrument(name = "AndExpr::prover_evaluate", level = "debug", skip_all)] diff --git a/crates/proof-of-sql/src/sql/proof_exprs/and_expr_test.rs b/crates/proof-of-sql/src/sql/proof_exprs/and_expr_test.rs index c96b9cb8b..1fe53102c 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/and_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/and_expr_test.rs @@ -1,7 +1,7 @@ use crate::{ base::{ commitment::InnerProductProof, - database::{owned_table_utility::*, Column, OwnedTableTestAccessor}, + database::{owned_table_utility::*, Column, ColumnarValue, OwnedTableTestAccessor}, }, sql::{ proof::{exercise_verification, VerifiableQueryResult}, @@ -160,7 +160,7 @@ fn we_can_compute_the_correct_output_of_an_and_expr_using_result_evaluate() { equal(column(t, "d", &accessor), const_varchar("t")), ); let alloc = Bump::new(); - let res = and_expr.result_evaluate(4, &alloc, &accessor); - let expected_res = Column::Boolean(&[false, true, false, false]); + let res = and_expr.result_evaluate(&alloc, &accessor); + let expected_res = ColumnarValue::Column(Column::Boolean(&[false, true, false, false])); assert_eq!(res, expected_res); } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/column_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/column_expr.rs index 93b7be813..dac3159f9 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/column_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/column_expr.rs @@ -2,7 +2,10 @@ use super::ProofExpr; use crate::{ base::{ commitment::Commitment, - database::{Column, ColumnField, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, + database::{ + Column, ColumnField, ColumnRef, ColumnType, ColumnarValue, CommitmentAccessor, + DataAccessor, + }, map::IndexSet, proof::ProofError, }, @@ -62,13 +65,11 @@ impl ProofExpr for ColumnExpr { /// add the result to the [`FirstRoundBuilder`](crate::sql::proof::FirstRoundBuilder) fn result_evaluate<'a>( &self, - table_length: usize, _alloc: &'a Bump, accessor: &'a dyn DataAccessor, - ) -> Column<'a, C::Scalar> { + ) -> ColumnarValue<'a, C::Scalar> { let column = accessor.get_column(self.column_ref); - assert_eq!(column.len(), table_length); - column + ColumnarValue::Column(column) } /// Given the selected rows (as a slice of booleans), evaluate the column expression and diff --git a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs index c0a8cc291..f92f9befb 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs @@ -174,37 +174,30 @@ pub(crate) fn scale_and_subtract_columnar_value<'a, S: Scalar>( lhs_scale: i8, rhs_scale: i8, is_equal: bool, -) -> ConversionResult> { +) -> ConversionResult<&'a [S]> { match (lhs, rhs) { (ColumnarValue::Column(lhs), ColumnarValue::Column(rhs)) => { - Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract( - alloc, lhs, rhs, lhs_scale, rhs_scale, is_equal, - )?))) - } - (ColumnarValue::Literal(lhs), ColumnarValue::Column(rhs)) => { - Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract( - alloc, - Column::from_literal_with_length(&lhs, rhs.len(), alloc), - rhs, - lhs_scale, - rhs_scale, - is_equal, - )?))) - } - (ColumnarValue::Column(lhs), ColumnarValue::Literal(rhs)) => { - Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract( - alloc, - lhs, - Column::from_literal_with_length(&rhs, lhs.len(), alloc), - lhs_scale, - rhs_scale, - is_equal, - )?))) + scale_and_subtract(alloc, lhs, rhs, lhs_scale, rhs_scale, is_equal) } + (ColumnarValue::Literal(lhs), ColumnarValue::Column(rhs)) => scale_and_subtract( + alloc, + Column::from_literal_with_length(&lhs, rhs.len(), alloc), + rhs, + lhs_scale, + rhs_scale, + is_equal, + ), + (ColumnarValue::Column(lhs), ColumnarValue::Literal(rhs)) => scale_and_subtract( + alloc, + lhs, + Column::from_literal_with_length(&rhs, lhs.len(), alloc), + lhs_scale, + rhs_scale, + is_equal, + ), (ColumnarValue::Literal(lhs), ColumnarValue::Literal(rhs)) => { - Ok(ColumnarValue::Literal(LiteralValue::Scalar( - scale_and_subtract_literal(&lhs, &rhs, lhs_scale, rhs_scale, is_equal)?, - ))) + let res = scale_and_subtract_literal(&lhs, &rhs, lhs_scale, rhs_scale, is_equal)?; + Ok(alloc.alloc_slice_fill_with(1, |_| res)) } } } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs index 63611cb59..731656924 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs @@ -5,7 +5,10 @@ use super::{ use crate::{ base::{ commitment::Commitment, - database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor, LiteralValue}, + database::{ + Column, ColumnRef, ColumnType, ColumnarValue, CommitmentAccessor, DataAccessor, + LiteralValue, + }, map::IndexSet, proof::ProofError, }, @@ -213,41 +216,24 @@ impl ProofExpr for DynProofExpr { fn result_evaluate<'a>( &self, - table_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, - ) -> Column<'a, C::Scalar> { + ) -> ColumnarValue<'a, C::Scalar> { match self { - DynProofExpr::Column(expr) => { - ProofExpr::::result_evaluate(expr, table_length, alloc, accessor) - } - DynProofExpr::And(expr) => { - ProofExpr::::result_evaluate(expr, table_length, alloc, accessor) - } - DynProofExpr::Or(expr) => { - ProofExpr::::result_evaluate(expr, table_length, alloc, accessor) - } - DynProofExpr::Not(expr) => { - ProofExpr::::result_evaluate(expr, table_length, alloc, accessor) - } - DynProofExpr::Literal(expr) => { - ProofExpr::::result_evaluate(expr, table_length, alloc, accessor) - } - DynProofExpr::Equals(expr) => { - ProofExpr::::result_evaluate(expr, table_length, alloc, accessor) - } + DynProofExpr::Column(expr) => ProofExpr::::result_evaluate(expr, alloc, accessor), + DynProofExpr::And(expr) => ProofExpr::::result_evaluate(expr, alloc, accessor), + DynProofExpr::Or(expr) => ProofExpr::::result_evaluate(expr, alloc, accessor), + DynProofExpr::Not(expr) => ProofExpr::::result_evaluate(expr, alloc, accessor), + DynProofExpr::Literal(expr) => ProofExpr::::result_evaluate(expr, alloc, accessor), + DynProofExpr::Equals(expr) => ProofExpr::::result_evaluate(expr, alloc, accessor), DynProofExpr::Inequality(expr) => { - ProofExpr::::result_evaluate(expr, table_length, alloc, accessor) + ProofExpr::::result_evaluate(expr, alloc, accessor) } DynProofExpr::AddSubtract(expr) => { - ProofExpr::::result_evaluate(expr, table_length, alloc, accessor) - } - DynProofExpr::Multiply(expr) => { - ProofExpr::::result_evaluate(expr, table_length, alloc, accessor) - } - DynProofExpr::Aggregate(expr) => { - ProofExpr::::result_evaluate(expr, table_length, alloc, accessor) + ProofExpr::::result_evaluate(expr, alloc, accessor) } + DynProofExpr::Multiply(expr) => ProofExpr::::result_evaluate(expr, alloc, accessor), + DynProofExpr::Aggregate(expr) => ProofExpr::::result_evaluate(expr, alloc, accessor), } } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/equals_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/equals_expr.rs index 5805898b1..4ac090314 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/equals_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/equals_expr.rs @@ -1,8 +1,14 @@ -use super::{scale_and_add_subtract_eval, scale_and_subtract, DynProofExpr, ProofExpr}; +use super::{ + scale_and_add_subtract_eval, scale_and_subtract, scale_and_subtract_columnar_value, + DynProofExpr, ProofExpr, +}; use crate::{ base::{ commitment::Commitment, - database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, + database::{ + Column, ColumnRef, ColumnType, ColumnarValue, CommitmentAccessor, DataAccessor, + LiteralValue, + }, map::IndexSet, proof::ProofError, scalar::Scalar, @@ -43,17 +49,34 @@ impl ProofExpr for EqualsExpr { #[tracing::instrument(name = "EqualsExpr::result_evaluate", level = "debug", skip_all)] fn result_evaluate<'a>( &self, - table_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, - ) -> Column<'a, C::Scalar> { - let lhs_column = self.lhs.result_evaluate(table_length, alloc, accessor); - let rhs_column = self.rhs.result_evaluate(table_length, alloc, accessor); + ) -> ColumnarValue<'a, C::Scalar> { + let lhs_columnar_value = self.lhs.result_evaluate(alloc, accessor); + let rhs_columnar_value = self.rhs.result_evaluate(alloc, accessor); + // If both sides are literals we should return a literal. + // Otherwise we return a column. + let is_literal = matches!( + (&lhs_columnar_value, &rhs_columnar_value), + (&ColumnarValue::Literal(_), &ColumnarValue::Literal(_)) + ); let lhs_scale = self.lhs.data_type().scale().unwrap_or(0); let rhs_scale = self.rhs.data_type().scale().unwrap_or(0); - let res = scale_and_subtract(alloc, lhs_column, rhs_column, lhs_scale, rhs_scale, true) - .expect("Failed to scale and subtract"); - Column::Boolean(result_evaluate_equals_zero(table_length, alloc, res)) + let res = scale_and_subtract_columnar_value( + alloc, + lhs_columnar_value, + rhs_columnar_value, + lhs_scale, + rhs_scale, + true, + ) + .expect("Failed to scale and subtract"); + let raw_result = result_evaluate_equals_zero(res.len(), alloc, res); + if is_literal { + ColumnarValue::Literal(LiteralValue::Boolean(raw_result[0])) + } else { + ColumnarValue::Column(Column::Boolean(raw_result)) + } } #[tracing::instrument(name = "EqualsExpr::prover_evaluate", level = "debug", skip_all)] diff --git a/crates/proof-of-sql/src/sql/proof_exprs/equals_expr_test.rs b/crates/proof-of-sql/src/sql/proof_exprs/equals_expr_test.rs index ac1426ea3..dcc106aab 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/equals_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/equals_expr_test.rs @@ -1,7 +1,9 @@ use crate::{ base::{ commitment::InnerProductProof, - database::{owned_table_utility::*, Column, OwnedTable, OwnedTableTestAccessor}, + database::{ + owned_table_utility::*, Column, ColumnarValue, OwnedTable, OwnedTableTestAccessor, + }, scalar::{Curve25519Scalar, Scalar}, }, sql::{ @@ -412,7 +414,7 @@ fn we_can_compute_the_correct_output_of_an_equals_expr_using_result_evaluate() { const_scalar(Curve25519Scalar::ZERO), ); let alloc = Bump::new(); - let res = equals_expr.result_evaluate(4, &alloc, &accessor); - let expected_res = Column::Boolean(&[true, false, true, false]); + let res = equals_expr.result_evaluate(&alloc, &accessor); + let expected_res = ColumnarValue::Column(Column::Boolean(&[true, false, true, false])); assert_eq!(res, expected_res); } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs index f1f647682..bd919348c 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs @@ -1,13 +1,17 @@ use super::{ count_equals_zero, count_or, count_sign, prover_evaluate_equals_zero, prover_evaluate_or, prover_evaluate_sign, result_evaluate_equals_zero, result_evaluate_or, result_evaluate_sign, - scale_and_add_subtract_eval, scale_and_subtract, verifier_evaluate_equals_zero, - verifier_evaluate_or, verifier_evaluate_sign, DynProofExpr, ProofExpr, + scale_and_add_subtract_eval, scale_and_subtract, scale_and_subtract_columnar_value, + verifier_evaluate_equals_zero, verifier_evaluate_or, verifier_evaluate_sign, DynProofExpr, + ProofExpr, }; use crate::{ base::{ commitment::Commitment, - database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, + database::{ + Column, ColumnRef, ColumnType, ColumnarValue, CommitmentAccessor, DataAccessor, + LiteralValue, + }, map::IndexSet, proof::ProofError, }, @@ -57,30 +61,55 @@ impl ProofExpr for InequalityExpr { #[tracing::instrument(name = "InequalityExpr::result_evaluate", level = "debug", skip_all)] fn result_evaluate<'a>( &self, - table_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, - ) -> Column<'a, C::Scalar> { - let lhs_column = self.lhs.result_evaluate(table_length, alloc, accessor); - let rhs_column = self.rhs.result_evaluate(table_length, alloc, accessor); + ) -> ColumnarValue<'a, C::Scalar> { + let lhs_columnar_value = self.lhs.result_evaluate(alloc, accessor); + let rhs_columnar_value = self.rhs.result_evaluate(alloc, accessor); + // If both sides are literals we should return a literal. + // Otherwise we return a column. + let is_literal = matches!( + (&lhs_columnar_value, &rhs_columnar_value), + (&ColumnarValue::Literal(_), &ColumnarValue::Literal(_)) + ); let lhs_scale = self.lhs.data_type().scale().unwrap_or(0); let rhs_scale = self.rhs.data_type().scale().unwrap_or(0); let diff = if self.is_lte { - scale_and_subtract(alloc, lhs_column, rhs_column, lhs_scale, rhs_scale, false) - .expect("Failed to scale and subtract") + scale_and_subtract_columnar_value( + alloc, + lhs_columnar_value, + rhs_columnar_value, + lhs_scale, + rhs_scale, + false, + ) + .expect("Failed to scale and subtract") } else { - scale_and_subtract(alloc, rhs_column, lhs_column, rhs_scale, lhs_scale, false) - .expect("Failed to scale and subtract") + scale_and_subtract_columnar_value( + alloc, + rhs_columnar_value, + lhs_columnar_value, + rhs_scale, + lhs_scale, + false, + ) + .expect("Failed to scale and subtract") }; + let diff_len = diff.len(); // diff == 0 - let equals_zero = result_evaluate_equals_zero(table_length, alloc, diff); + let equals_zero = result_evaluate_equals_zero(diff_len, alloc, diff); // sign(diff) == -1 - let sign = result_evaluate_sign(table_length, alloc, diff); + let sign = result_evaluate_sign(diff_len, alloc, diff); // (diff == 0) || (sign(diff) == -1) - Column::Boolean(result_evaluate_or(table_length, alloc, equals_zero, sign)) + let raw_result = result_evaluate_or(diff_len, alloc, equals_zero, sign); + if is_literal { + ColumnarValue::Literal(LiteralValue::Boolean(raw_result[0])) + } else { + ColumnarValue::Column(Column::Boolean(raw_result)) + } } #[tracing::instrument(name = "InequalityExpr::prover_evaluate", level = "debug", skip_all)] diff --git a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs index 34605872c..4322e3ff0 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs @@ -2,8 +2,8 @@ use crate::{ base::{ commitment::InnerProductProof, database::{ - owned_table_utility::*, Column, LiteralValue, OwnedTable, OwnedTableTestAccessor, - TestAccessor, + owned_table_utility::*, Column, ColumnarValue, LiteralValue, OwnedTable, + OwnedTableTestAccessor, TestAccessor, }, scalar::{Curve25519Scalar, Scalar, ScalarExt}, }, @@ -569,8 +569,8 @@ fn we_can_compute_the_correct_output_of_a_lte_inequality_expr_using_result_evalu let rhs_expr = column(t, "b", &accessor); let lte_expr = lte(lhs_expr, rhs_expr); let alloc = Bump::new(); - let res = lte_expr.result_evaluate(3, &alloc, &accessor); - let expected_res = Column::Boolean(&[true, false, true]); + let res = lte_expr.result_evaluate(&alloc, &accessor); + let expected_res = ColumnarValue::Column(Column::Boolean(&[true, false, true])); assert_eq!(res, expected_res); } @@ -584,7 +584,7 @@ fn we_can_compute_the_correct_output_of_a_gte_inequality_expr_using_result_evalu let lit_expr = const_bigint(1); let gte_expr = gte(col_expr, lit_expr); let alloc = Bump::new(); - let res = gte_expr.result_evaluate(3, &alloc, &accessor); - let expected_res = Column::Boolean(&[false, true, true]); + let res = gte_expr.result_evaluate(&alloc, &accessor); + let expected_res = ColumnarValue::Column(Column::Boolean(&[false, true, true])); assert_eq!(res, expected_res); } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/literal_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/literal_expr.rs index c00af32d1..d232807bd 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/literal_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/literal_expr.rs @@ -2,7 +2,10 @@ use super::ProofExpr; use crate::{ base::{ commitment::Commitment, - database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor, LiteralValue}, + database::{ + Column, ColumnRef, ColumnType, ColumnarValue, CommitmentAccessor, DataAccessor, + LiteralValue, + }, map::IndexSet, proof::ProofError, scalar::Scalar, @@ -47,11 +50,10 @@ impl ProofExpr for LiteralExpr { #[tracing::instrument(name = "LiteralExpr::result_evaluate", level = "debug", skip_all)] fn result_evaluate<'a>( &self, - table_length: usize, - alloc: &'a Bump, + _alloc: &'a Bump, _accessor: &'a dyn DataAccessor, - ) -> Column<'a, C::Scalar> { - Column::from_literal_with_length(&self.value, table_length, alloc) + ) -> ColumnarValue<'a, C::Scalar> { + ColumnarValue::Literal(self.value.clone()) } #[tracing::instrument(name = "LiteralExpr::prover_evaluate", level = "debug", skip_all)] diff --git a/crates/proof-of-sql/src/sql/proof_exprs/literal_expr_test.rs b/crates/proof-of-sql/src/sql/proof_exprs/literal_expr_test.rs index a3572df59..234735d75 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/literal_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/literal_expr_test.rs @@ -2,7 +2,7 @@ use super::{DynProofExpr, ProofExpr}; use crate::{ base::{ commitment::InnerProductProof, - database::{owned_table_utility::*, Column, OwnedTableTestAccessor}, + database::{owned_table_utility::*, Column, ColumnarValue, OwnedTableTestAccessor}, }, sql::{ proof::{exercise_verification, VerifiableQueryResult}, @@ -124,7 +124,7 @@ fn we_can_compute_the_correct_output_of_a_literal_expr_using_result_evaluate() { let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); let literal_expr: DynProofExpr = const_bool(true); let alloc = Bump::new(); - let res = literal_expr.result_evaluate(4, &alloc, &accessor); - let expected_res = Column::Boolean(&[true, true, true, true]); + let res = literal_expr.result_evaluate(&alloc, &accessor); + let expected_res = ColumnarValue::Column(Column::Boolean(&[true, true, true, true])); assert_eq!(res, expected_res); } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/mod.rs b/crates/proof-of-sql/src/sql/proof_exprs/mod.rs index d2b8c3f27..ca31c57b6 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/mod.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/mod.rs @@ -57,11 +57,12 @@ use not_expr::NotExpr; mod not_expr_test; mod comparison_util; -pub(crate) use comparison_util::scale_and_subtract; +pub(crate) use comparison_util::{scale_and_subtract, scale_and_subtract_columnar_value}; mod numerical_util; pub(crate) use numerical_util::{ - add_subtract_columns, multiply_columns, scale_and_add_subtract_eval, + add_subtract_columnar_values, add_subtract_columns, multiply_columnar_values, multiply_columns, + scale_and_add_subtract_eval, }; mod equals_expr; diff --git a/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr.rs index 7392c1ca6..a8b00c862 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr.rs @@ -3,15 +3,15 @@ use crate::{ base::{ commitment::Commitment, database::{ - try_multiply_column_types, Column, ColumnRef, ColumnType, CommitmentAccessor, - DataAccessor, + try_multiply_column_types, Column, ColumnRef, ColumnType, ColumnarValue, + CommitmentAccessor, DataAccessor, }, map::IndexSet, proof::ProofError, }, sql::{ proof::{CountBuilder, FinalRoundBuilder, SumcheckSubpolynomialType, VerificationBuilder}, - proof_exprs::multiply_columns, + proof_exprs::{multiply_columnar_values, multiply_columns}, }, }; use alloc::{boxed::Box, vec}; @@ -50,16 +50,14 @@ impl ProofExpr for MultiplyExpr { fn result_evaluate<'a>( &self, - table_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, - ) -> Column<'a, C::Scalar> { - let lhs_column: Column<'a, C::Scalar> = - self.lhs.result_evaluate(table_length, alloc, accessor); - let rhs_column: Column<'a, C::Scalar> = - self.rhs.result_evaluate(table_length, alloc, accessor); - let scalars = multiply_columns(&lhs_column, &rhs_column, alloc); - Column::Scalar(scalars) + ) -> ColumnarValue<'a, C::Scalar> { + let lhs_columnar_value: ColumnarValue<'a, C::Scalar> = + self.lhs.result_evaluate(alloc, accessor); + let rhs_columnar_value: ColumnarValue<'a, C::Scalar> = + self.rhs.result_evaluate(alloc, accessor); + multiply_columnar_values(&lhs_columnar_value, &rhs_columnar_value, alloc) } #[tracing::instrument( diff --git a/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr_test.rs b/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr_test.rs index 2ce9290bf..e70b046c1 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr_test.rs @@ -1,7 +1,7 @@ use crate::{ base::{ commitment::InnerProductProof, - database::{owned_table_utility::*, Column, OwnedTableTestAccessor}, + database::{owned_table_utility::*, Column, ColumnarValue, OwnedTableTestAccessor}, scalar::Curve25519Scalar, }, sql::{ @@ -346,11 +346,11 @@ fn we_can_compute_the_correct_output_of_a_multiply_expr_using_result_evaluate() subtract(column(t, "a", &accessor), const_decimal75(2, 1, 15)), ); let alloc = Bump::new(); - let res = arithmetic_expr.result_evaluate(4, &alloc, &accessor); + let res = arithmetic_expr.result_evaluate(&alloc, &accessor); let expected_res_scalar = [0, 5, 75, 25] .iter() .map(|v| Curve25519Scalar::from(*v)) .collect::>(); - let expected_res = Column::Scalar(&expected_res_scalar); + let expected_res = ColumnarValue::Column(Column::Scalar(&expected_res_scalar)); assert_eq!(res, expected_res); } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/not_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/not_expr.rs index 194d5b9be..a0dc48882 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/not_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/not_expr.rs @@ -2,7 +2,9 @@ use super::{DynProofExpr, ProofExpr}; use crate::{ base::{ commitment::Commitment, - database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, + database::{ + Column, ColumnRef, ColumnType, ColumnarValue, CommitmentAccessor, DataAccessor, + }, map::IndexSet, proof::ProofError, }, @@ -10,6 +12,7 @@ use crate::{ }; use alloc::boxed::Box; use bumpalo::Bump; +use proof_of_sql_parser::intermediate_ast::UnaryOperator; use serde::{Deserialize, Serialize}; /// Provable logical NOT expression @@ -37,14 +40,14 @@ impl ProofExpr for NotExpr { #[tracing::instrument(name = "NotExpr::result_evaluate", level = "debug", skip_all)] fn result_evaluate<'a>( &self, - table_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, - ) -> Column<'a, C::Scalar> { - let expr_column: Column<'a, C::Scalar> = - self.expr.result_evaluate(table_length, alloc, accessor); - let expr = expr_column.as_boolean().expect("expr is not boolean"); - Column::Boolean(alloc.alloc_slice_fill_with(expr.len(), |i| !expr[i])) + ) -> ColumnarValue<'a, C::Scalar> { + let expr_columnar_value: ColumnarValue<'a, C::Scalar> = + self.expr.result_evaluate(alloc, accessor); + expr_columnar_value + .apply_boolean_unary_operator(UnaryOperator::Not, alloc) + .expect("Failed to apply boolean unary operator") } #[tracing::instrument(name = "NotExpr::prover_evaluate", level = "debug", skip_all)] diff --git a/crates/proof-of-sql/src/sql/proof_exprs/not_expr_test.rs b/crates/proof-of-sql/src/sql/proof_exprs/not_expr_test.rs index 60b041ea1..61f9956cf 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/not_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/not_expr_test.rs @@ -1,7 +1,9 @@ use crate::{ base::{ commitment::InnerProductProof, - database::{owned_table_utility::*, Column, OwnedTableTestAccessor, TestAccessor}, + database::{ + owned_table_utility::*, Column, ColumnarValue, OwnedTableTestAccessor, TestAccessor, + }, }, sql::{ proof::{exercise_verification, VerifiableQueryResult}, @@ -120,7 +122,7 @@ fn we_can_compute_the_correct_output_of_a_not_expr_using_result_evaluate() { let not_expr: DynProofExpr = not(equal(column(t, "b", &accessor), const_int128(1))); let alloc = Bump::new(); - let res = not_expr.result_evaluate(2, &alloc, &accessor); - let expected_res = Column::Boolean(&[true, false]); + let res = not_expr.result_evaluate(&alloc, &accessor); + let expected_res = ColumnarValue::Column(Column::Boolean(&[true, false])); assert_eq!(res, expected_res); } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/or_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/or_expr.rs index 561722232..3abf0d084 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/or_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/or_expr.rs @@ -2,7 +2,9 @@ use super::{DynProofExpr, ProofExpr}; use crate::{ base::{ commitment::Commitment, - database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, + database::{ + Column, ColumnRef, ColumnType, ColumnarValue, CommitmentAccessor, DataAccessor, + }, map::IndexSet, proof::ProofError, scalar::Scalar, @@ -11,6 +13,7 @@ use crate::{ }; use alloc::{boxed::Box, vec}; use bumpalo::Bump; +use proof_of_sql_parser::intermediate_ast::BinaryOperator; use serde::{Deserialize, Serialize}; /// Provable logical OR expression @@ -42,17 +45,16 @@ impl ProofExpr for OrExpr { #[tracing::instrument(name = "OrExpr::result_evaluate", level = "debug", skip_all)] fn result_evaluate<'a>( &self, - table_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, - ) -> Column<'a, C::Scalar> { - let lhs_column: Column<'a, C::Scalar> = - self.lhs.result_evaluate(table_length, alloc, accessor); - let rhs_column: Column<'a, C::Scalar> = - self.rhs.result_evaluate(table_length, alloc, accessor); - let lhs = lhs_column.as_boolean().expect("lhs is not boolean"); - let rhs = rhs_column.as_boolean().expect("rhs is not boolean"); - Column::Boolean(result_evaluate_or(table_length, alloc, lhs, rhs)) + ) -> ColumnarValue<'a, C::Scalar> { + let lhs_columnar_value: ColumnarValue<'a, C::Scalar> = + self.lhs.result_evaluate(alloc, accessor); + let rhs_columnar_value: ColumnarValue<'a, C::Scalar> = + self.rhs.result_evaluate(alloc, accessor); + lhs_columnar_value + .apply_boolean_binary_operator(&rhs_columnar_value, BinaryOperator::Or, alloc) + .expect("Failed to apply boolean binary operator") } #[tracing::instrument(name = "OrExpr::prover_evaluate", level = "debug", skip_all)] diff --git a/crates/proof-of-sql/src/sql/proof_exprs/or_expr_test.rs b/crates/proof-of-sql/src/sql/proof_exprs/or_expr_test.rs index ac79a280e..1cd5475dd 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/or_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/or_expr_test.rs @@ -1,7 +1,9 @@ use crate::{ base::{ commitment::InnerProductProof, - database::{owned_table_utility::*, Column, OwnedTableTestAccessor, TestAccessor}, + database::{ + owned_table_utility::*, Column, ColumnarValue, OwnedTableTestAccessor, TestAccessor, + }, }, sql::{ proof::{exercise_verification, VerifiableQueryResult}, @@ -184,7 +186,7 @@ fn we_can_compute_the_correct_output_of_an_or_expr_using_result_evaluate() { equal(column(t, "d", &accessor), const_varchar("g")), ); let alloc = Bump::new(); - let res = and_expr.result_evaluate(4, &alloc, &accessor); - let expected_res = Column::Boolean(&[false, true, true, true]); + let res = and_expr.result_evaluate(&alloc, &accessor); + let expected_res = ColumnarValue::Column(Column::Boolean(&[false, true, true, true])); assert_eq!(res, expected_res); } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/proof_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/proof_expr.rs index 88f215484..28e4bb1a1 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/proof_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/proof_expr.rs @@ -1,7 +1,9 @@ use crate::{ base::{ commitment::Commitment, - database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, + database::{ + Column, ColumnRef, ColumnType, ColumnarValue, CommitmentAccessor, DataAccessor, + }, map::IndexSet, proof::ProofError, }, @@ -19,14 +21,12 @@ pub trait ProofExpr: Debug + Send + Sync { fn data_type(&self) -> ColumnType; /// This returns the result of evaluating the expression on the given table, and returns - /// a column of values. This result slice is guarenteed to have length `table_length`. - /// Implementations must ensure that the returned slice has length `table_length`. + /// a column of values. fn result_evaluate<'a>( &self, - table_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, - ) -> Column<'a, C::Scalar>; + ) -> ColumnarValue<'a, C::Scalar>; /// Evaluate the expression, add components needed to prove it, and return thet resulting column /// of values diff --git a/crates/proof-of-sql/src/sql/proof_exprs/proof_expr_test.rs b/crates/proof-of-sql/src/sql/proof_exprs/proof_expr_test.rs index 8d6b07102..f86584093 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/proof_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/proof_expr_test.rs @@ -1,7 +1,9 @@ use super::{test_utility::*, DynProofExpr, ProofExpr}; use crate::base::{ commitment::InnerProductProof, - database::{owned_table_utility::*, Column, OwnedTableTestAccessor, TestAccessor}, + database::{ + owned_table_utility::*, Column, ColumnarValue, OwnedTableTestAccessor, TestAccessor, + }, }; use bumpalo::Bump; use curve25519_dalek::RistrettoPoint; @@ -36,10 +38,10 @@ fn we_can_compute_the_correct_result_of_a_complex_bool_expr_using_result_evaluat not(equal(column(t, "c", &accessor), const_int128(3))), ); let alloc = Bump::new(); - let res = bool_expr.result_evaluate(17, &alloc, &accessor); - let expected_res = Column::Boolean(&[ + let res = bool_expr.result_evaluate(&alloc, &accessor); + let expected_res = ColumnarValue::Column(Column::Boolean(&[ false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, false, false, - ]); + ])); assert_eq!(res, expected_res); } diff --git a/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs b/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs index b7edcc70a..44bd61305 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs @@ -115,14 +115,13 @@ impl ProverEvaluate for DynProofPlan { #[tracing::instrument(name = "DynProofPlan::result_evaluate", level = "debug", skip_all)] fn result_evaluate<'a>( &self, - input_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, ) -> Vec> { match self { - DynProofPlan::Projection(expr) => expr.result_evaluate(input_length, alloc, accessor), - DynProofPlan::GroupBy(expr) => expr.result_evaluate(input_length, alloc, accessor), - DynProofPlan::Filter(expr) => expr.result_evaluate(input_length, alloc, accessor), + DynProofPlan::Projection(expr) => expr.result_evaluate(alloc, accessor), + DynProofPlan::GroupBy(expr) => expr.result_evaluate(alloc, accessor), + DynProofPlan::Filter(expr) => expr.result_evaluate(alloc, accessor), } } diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs index 5a1b6106b..c8fdd0a05 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs @@ -152,14 +152,16 @@ impl ProverEvaluate for FilterExec { #[tracing::instrument(name = "FilterExec::result_evaluate", level = "debug", skip_all)] fn result_evaluate<'a>( &self, - input_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, ) -> Vec> { + let input_length = accessor.get_length(self.table.table_ref); // 1. selection - let selection_column: Column<'a, C::Scalar> = - self.where_clause - .result_evaluate(input_length, alloc, accessor); + let selection_column: Column<'a, C::Scalar> = self + .where_clause + .result_evaluate(alloc, accessor) + .into_column(input_length, alloc) + .expect("Failed to convert columnar value to column"); let selection = selection_column .as_boolean() .expect("selection is not boolean"); @@ -171,7 +173,9 @@ impl ProverEvaluate for FilterExec { .map(|aliased_expr| { aliased_expr .expr - .result_evaluate(input_length, alloc, accessor) + .result_evaluate(alloc, accessor) + .into_column(input_length, alloc) + .expect("Failed to convert columnar value to column") }) .collect(); diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs index 062781985..fa04b0656 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs @@ -196,7 +196,7 @@ fn we_can_get_an_empty_result_from_a_basic_filter_on_an_empty_table_using_result where_clause, ); let alloc = Bump::new(); - let result_cols = expr.result_evaluate(0, &alloc, &accessor); + let result_cols = expr.result_evaluate(&alloc, &accessor); let output_length = result_cols.first().map_or(0, Column::len) as u64; let mut builder = FirstRoundBuilder::new(); expr.first_round_evaluate(&mut builder); @@ -243,7 +243,7 @@ fn we_can_get_an_empty_result_from_a_basic_filter_using_result_evaluate() { where_clause, ); let alloc = Bump::new(); - let result_cols = expr.result_evaluate(5, &alloc, &accessor); + let result_cols = expr.result_evaluate(&alloc, &accessor); let output_length = result_cols.first().map_or(0, Column::len) as u64; let mut builder = FirstRoundBuilder::new(); expr.first_round_evaluate(&mut builder); @@ -286,7 +286,7 @@ fn we_can_get_no_columns_from_a_basic_filter_with_no_selected_columns_using_resu equal(column(t, "a", &accessor), const_int128(5)); let expr = filter(cols_expr_plan(t, &[], &accessor), tab(t), where_clause); let alloc = Bump::new(); - let result_cols = expr.result_evaluate(5, &alloc, &accessor); + let result_cols = expr.result_evaluate(&alloc, &accessor); let output_length = result_cols.first().map_or(0, Column::len) as u64; let mut builder = FirstRoundBuilder::new(); expr.first_round_evaluate(&mut builder); @@ -319,7 +319,7 @@ fn we_can_get_the_correct_result_from_a_basic_filter_using_result_evaluate() { where_clause, ); let alloc = Bump::new(); - let result_cols = expr.result_evaluate(5, &alloc, &accessor); + let result_cols = expr.result_evaluate(&alloc, &accessor); let output_length = result_cols.first().map_or(0, Column::len) as u64; let mut builder = FirstRoundBuilder::new(); expr.first_round_evaluate(&mut builder); diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs index 4d32bc735..f07c81d3b 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs @@ -37,14 +37,16 @@ impl ProverEvaluate for DishonestFilterExec { )] fn result_evaluate<'a>( &self, - input_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, ) -> Vec> { + let input_length = accessor.get_length(self.table.table_ref); // 1. selection - let selection_column: Column<'a, Curve25519Scalar> = - self.where_clause - .result_evaluate(input_length, alloc, accessor); + let selection_column: Column<'a, Curve25519Scalar> = self + .where_clause + .result_evaluate(alloc, accessor) + .into_column(input_length, alloc) + .expect("Failed to convert columnar value to column"); let selection = selection_column .as_boolean() .expect("selection is not boolean"); @@ -55,7 +57,9 @@ impl ProverEvaluate for DishonestFilterExec { .map(|aliased_expr| { aliased_expr .expr - .result_evaluate(input_length, alloc, accessor) + .result_evaluate(alloc, accessor) + .into_column(input_length, alloc) + .expect("Failed to convert columnar value to column") }) .collect(); // Compute filtered_columns diff --git a/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs index 385b8a2e7..d18d2740d 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs @@ -212,14 +212,16 @@ impl ProverEvaluate for GroupByExec { #[tracing::instrument(name = "GroupByExec::result_evaluate", level = "debug", skip_all)] fn result_evaluate<'a>( &self, - input_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, ) -> Vec> { + let input_length = accessor.get_length(self.table.table_ref); // 1. selection - let selection_column: Column<'a, C::Scalar> = - self.where_clause - .result_evaluate(input_length, alloc, accessor); + let selection_column: Column<'a, C::Scalar> = self + .where_clause + .result_evaluate(alloc, accessor) + .into_column(input_length, alloc) + .expect("Failed to convert columnar value to column"); let selection = selection_column .as_boolean() @@ -229,7 +231,11 @@ impl ProverEvaluate for GroupByExec { let group_by_columns = self .group_by_exprs .iter() - .map(|expr| expr.result_evaluate(input_length, alloc, accessor)) + .map(|expr| { + expr.result_evaluate(alloc, accessor) + .into_column(input_length, alloc) + .expect("Failed to convert columnar value to column") + }) .collect::>(); let sum_columns = self .sum_expr @@ -237,7 +243,9 @@ impl ProverEvaluate for GroupByExec { .map(|aliased_expr| { aliased_expr .expr - .result_evaluate(input_length, alloc, accessor) + .result_evaluate(alloc, accessor) + .into_column(input_length, alloc) + .expect("Failed to convert columnar value to column") }) .collect::>(); // Compute filtered_columns diff --git a/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs index f3038b310..a4d98ea13 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs @@ -102,17 +102,19 @@ impl ProverEvaluate for ProjectionExec { #[tracing::instrument(name = "ProjectionExec::result_evaluate", level = "debug", skip_all)] fn result_evaluate<'a>( &self, - input_length: usize, alloc: &'a Bump, accessor: &'a dyn DataAccessor, ) -> Vec> { + let input_length = accessor.get_length(self.table.table_ref); let columns: Vec<_> = self .aliased_results .iter() .map(|aliased_expr| { aliased_expr .expr - .result_evaluate(input_length, alloc, accessor) + .result_evaluate(alloc, accessor) + .into_column(input_length, alloc) + .expect("Failed to convert columnar value to column") }) .collect(); columns diff --git a/crates/proof-of-sql/src/sql/proof_plans/projection_exec_test.rs b/crates/proof-of-sql/src/sql/proof_plans/projection_exec_test.rs index c97ecf471..2aba5bb98 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/projection_exec_test.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/projection_exec_test.rs @@ -169,7 +169,7 @@ fn we_can_get_an_empty_result_from_a_basic_projection_on_an_empty_table_using_re let expr: DynProofPlan = projection(cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), tab(t)); let alloc = Bump::new(); - let result_cols = expr.result_evaluate(0, &alloc, &accessor); + let result_cols = expr.result_evaluate(&alloc, &accessor); let output_length = result_cols.first().map_or(0, Column::len) as u64; let mut builder = FirstRoundBuilder::new(); expr.first_round_evaluate(&mut builder); @@ -210,7 +210,7 @@ fn we_can_get_no_columns_from_a_basic_projection_with_no_selected_columns_using_ accessor.add_table(t, data, 0); let expr: DynProofPlan = projection(cols_expr_plan(t, &[], &accessor), tab(t)); let alloc = Bump::new(); - let result_cols = expr.result_evaluate(5, &alloc, &accessor); + let result_cols = expr.result_evaluate(&alloc, &accessor); let output_length = result_cols.first().map_or(0, Column::len) as u64; let mut builder = FirstRoundBuilder::new(); expr.first_round_evaluate(&mut builder); @@ -248,7 +248,7 @@ fn we_can_get_the_correct_result_from_a_basic_projection_using_result_evaluate() tab(t), ); let alloc = Bump::new(); - let result_cols = expr.result_evaluate(5, &alloc, &accessor); + let result_cols = expr.result_evaluate(&alloc, &accessor); let output_length = result_cols.first().map_or(0, Column::len) as u64; let mut builder = FirstRoundBuilder::new(); expr.first_round_evaluate(&mut builder);