Skip to content

Commit

Permalink
feat: add AddSubExpr
Browse files Browse the repository at this point in the history
- add `AddSubExpr` and enable + and - elsewhere
- generalize `scale_and_subtract_eval` to `scale_and_add_subtract_eval`
  • Loading branch information
iajoiner committed Jun 20, 2024
1 parent 91d3a20 commit 025cd19
Show file tree
Hide file tree
Showing 14 changed files with 708 additions and 88 deletions.
44 changes: 41 additions & 3 deletions crates/proof-of-sql/src/base/database/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,15 @@ pub enum ColumnType {
/// Mapped to String
#[serde(alias = "VARCHAR", alias = "varchar")]
VarChar,
/// Mapped to Curve25519Scalar
#[serde(alias = "SCALAR", alias = "scalar")]
Scalar,
/// Mapped to i256
#[serde(rename = "Decimal75", alias = "DECIMAL75", alias = "decimal75")]
Decimal75(Precision, i8),
/// Mapped to i64
#[serde(alias = "TIMESTAMP", alias = "timestamp")]
TimestampTZ(PoSQLTimeUnit, PoSQLTimeZone),
/// Mapped to Curve25519Scalar
#[serde(alias = "SCALAR", alias = "scalar")]
Scalar,
}

impl ColumnType {
Expand All @@ -240,6 +240,44 @@ impl ColumnType {
)
}

/// Returns the number of bits in the integer type if it is an integer type. Otherwise, return None.
fn to_integer_bits(&self) -> Option<usize> {
match self {
ColumnType::SmallInt => Some(16),
ColumnType::Int => Some(32),
ColumnType::BigInt => Some(64),
ColumnType::Int128 => Some(128),
_ => None,
}
}

/// Returns the ColumnType of the integer type with the given number of bits if it is a valid integer type.
///
/// Otherwise, return None.
fn from_integer_bits(bits: usize) -> Option<Self> {
match bits {
16 => Some(ColumnType::SmallInt),
32 => Some(ColumnType::Int),
64 => Some(ColumnType::BigInt),
128 => Some(ColumnType::Int128),
_ => None,
}
}

/// Returns the larger integer type of two ColumnTypes if they are both integers.
///
/// If either of the columns is not an integer, return None.
pub fn max_integer_type(&self, other: &Self) -> Option<Self> {
// If either of the columns is not an integer, return None
if !self.is_integer() || !other.is_integer() {
return None;
}
Self::from_integer_bits(std::cmp::max(
self.to_integer_bits().unwrap(),
other.to_integer_bits().unwrap(),
))
}

/// Returns the precision of a ColumnType if it is converted to a decimal wrapped in Some(). If it can not be converted to a decimal, return None.
pub fn precision_value(&self) -> Option<u8> {
match self {
Expand Down
101 changes: 101 additions & 0 deletions crates/proof-of-sql/src/sql/ast/add_subtract_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use super::{
scale_and_add_subtract_eval, try_add_subtract_column_types, try_add_subtract_columns,
ProvableExpr, ProvableExprPlan,
};
use crate::{
base::{
commitment::Commitment,
database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor},
proof::ProofError,
},
sql::proof::{CountBuilder, ProofBuilder, VerificationBuilder},
};
use bumpalo::Bump;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable numerical + / - expression
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AddSubtractExpr<C: Commitment> {
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
is_subtract: bool,
}

impl<C: Commitment> AddSubtractExpr<C> {
/// Create numerical + / - expression
pub fn new(
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
is_subtract: bool,
) -> Self {
Self {
lhs,
rhs,
is_subtract,
}
}
}

impl<C: Commitment> ProvableExpr<C> for AddSubtractExpr<C> {
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> {
self.lhs.count(builder)?;
self.rhs.count(builder)?;
Ok(())
}

fn data_type(&self) -> ColumnType {
try_add_subtract_column_types(self.lhs.data_type(), self.rhs.data_type())
.expect("Failed to add/subtract column types")
}

fn result_evaluate<'a>(
&self,
table_length: usize,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> =
self.lhs.result_evaluate(table_length, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> =
self.rhs.result_evaluate(table_length, alloc, accessor);
try_add_subtract_columns(lhs_column, rhs_column, alloc, self.is_subtract)
.expect("Failed to add/subtract columns")
}

#[tracing::instrument(
name = "proofs.sql.ast.not_expr.prover_evaluate",
level = "info",
skip_all
)]
fn prover_evaluate<'a>(
&self,
builder: &mut ProofBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor);
try_add_subtract_columns(lhs_column, rhs_column, alloc, self.is_subtract)
.expect("Failed to add/subtract columns")
}

fn verifier_evaluate(
&self,
builder: &mut VerificationBuilder<C>,
accessor: &dyn CommitmentAccessor<C>,
) -> Result<C::Scalar, ProofError> {
let lhs_eval = self.lhs.verifier_evaluate(builder, accessor)?;
let rhs_eval = self.rhs.verifier_evaluate(builder, accessor)?;
let lhs_scale = self.lhs.data_type().scale().unwrap_or(0);
let rhs_scale = self.rhs.data_type().scale().unwrap_or(0);
let res =
scale_and_add_subtract_eval(lhs_eval, rhs_eval, lhs_scale, rhs_scale, self.is_subtract);
Ok(res)
}

fn get_column_references(&self, columns: &mut HashSet<ColumnRef>) {
self.lhs.get_column_references(columns);
self.rhs.get_column_references(columns);
}
}
181 changes: 181 additions & 0 deletions crates/proof-of-sql/src/sql/ast/add_subtract_expr_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
use crate::{
base::{
commitment::InnerProductProof,
database::{
make_random_test_accessor_data, owned_table_utility::*, Column, ColumnType,
OwnedTableTestAccessor, RandomTestAccessorDescriptor, RecordBatchTestAccessor,
TestAccessor,
},
},
record_batch,
sql::ast::{
test_expr::TestExprNode,
test_utility::{add, column, equal, subtract},
ProvableExpr, ProvableExprPlan,
},
};
use arrow::record_batch::RecordBatch;
use bumpalo::Bump;
use curve25519_dalek::ristretto::RistrettoPoint;
use polars::prelude::*;
use rand::{rngs::StdRng, Rng};
use rand_core::SeedableRng;

// select results from table_ref where filter_col_l = filter_col_r0 + / - filter_col_r1
#[allow(clippy::too_many_arguments)]
fn create_test_add_subtract_expr(
table_ref: &str,
results: &[&str],
filter_col_l: &str,
filter_col_r0: &str,
filter_col_r1: &str,
data: RecordBatch,
offset: usize,
is_subtract: bool,
) -> TestExprNode {
let mut accessor = RecordBatchTestAccessor::new_empty();
let t = table_ref.parse().unwrap();
accessor.add_table(t, data, offset);
let df_filter = if is_subtract {
polars::prelude::col(filter_col_l).eq(col(filter_col_r0) - col(filter_col_r1))
} else {
polars::prelude::col(filter_col_l).eq(col(filter_col_r0) + col(filter_col_r1))
};
let filter_expr = equal(
column(t, filter_col_l, &accessor),
if is_subtract {
subtract(
column(t, filter_col_r0, &accessor),
column(t, filter_col_r1, &accessor),
)
} else {
add(
column(t, filter_col_r0, &accessor),
column(t, filter_col_r1, &accessor),
)
},
);
TestExprNode::new(t, results, filter_expr, df_filter, accessor)
}

#[test]
fn we_can_prove_a_equals_add_query_with_a_single_selected_row() {
let data = record_batch!(
"a" => [123_i64, 456],
"b" => [4_i64, 1],
"c" => [123_i64, 457],
"d" => ["alfa", "gama"]
);
let test_expr =
create_test_add_subtract_expr("sxt.t", &["a", "d"], "c", "a", "b", data, 0, false);
let res = test_expr.verify_expr();
let expected_res = record_batch!(
"a" => [456_i64],
"d" => ["gama"]
);
assert_eq!(res, expected_res);
}

#[test]
fn we_can_prove_a_equals_subtract_query_with_a_single_selected_row() {
let data = record_batch!(
"a" => [127_i64, 458],
"b" => [4_i64, 1],
"c" => [123_i64, 457],
"d" => ["alfa", "gama"]
);
let test_expr =
create_test_add_subtract_expr("sxt.t", &["a", "d"], "c", "a", "b", data, 0, true);
let res = test_expr.verify_expr();
let expected_res = record_batch!(
"a" => [127_i64, 458],
"d" => ["alfa", "gama"]
);
assert_eq!(res, expected_res);
}

fn test_random_tables_with_given_offset(offset: usize) {
let descr = RandomTestAccessorDescriptor {
min_rows: 1,
max_rows: 20,
min_value: -3,
max_value: 3,
};
let mut rng = StdRng::from_seed([0u8; 32]);
let cols = [
("l", ColumnType::BigInt),
("r0", ColumnType::BigInt),
("r1", ColumnType::BigInt),
("varchar", ColumnType::VarChar),
("integer", ColumnType::BigInt),
];
for _ in 0..20 {
let data = make_random_test_accessor_data(&mut rng, &cols, &descr);
let is_subtract = rng.gen::<bool>();
let test_expr = create_test_add_subtract_expr(
"sxt.t",
&["l", "varchar", "integer"],
"l",
"r0",
"r1",
data,
offset,
is_subtract,
);
let res = test_expr.verify_expr();
let expected_res = test_expr.query_table();
assert_eq!(res, expected_res);
}
}

#[test]
fn we_can_query_random_tables_with_a_zero_offset() {
test_random_tables_with_given_offset(0);
}

#[test]
fn we_can_query_random_tables_with_a_non_zero_offset() {
test_random_tables_with_given_offset(75);
}

#[test]
fn we_can_compute_the_correct_output_of_an_add_expr_using_result_evaluate() {
let data = owned_table([
bigint("a", [123, 456]),
bigint("b", [3, 1]),
bigint("c", [126, 453]),
varchar("d", ["alfa", "gama"]),
]);
let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
let t = "sxt.t".parse().unwrap();
accessor.add_table(t, data, 0);
let eq_expr: ProvableExprPlan<RistrettoPoint> = equal(
column(t, "c", &accessor),
add(column(t, "a", &accessor), column(t, "b", &accessor)),
);
let alloc = Bump::new();
let res = eq_expr.result_evaluate(2, &alloc, &accessor);
let expected_res = Column::Boolean(&[true, false]);
assert_eq!(res, expected_res);
}

#[test]
fn we_can_compute_the_correct_output_of_a_subtract_expr_using_result_evaluate() {
let data = owned_table([
bigint("a", [123, 456]),
bigint("b", [3, 1]),
bigint("c", [126, 455]),
varchar("d", ["alfa", "gama"]),
]);
let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
let t = "sxt.t".parse().unwrap();
accessor.add_table(t, data, 0);
let eq_expr: ProvableExprPlan<RistrettoPoint> = equal(
column(t, "c", &accessor),
subtract(column(t, "a", &accessor), column(t, "b", &accessor)),
);
let alloc = Bump::new();
let res = eq_expr.result_evaluate(2, &alloc, &accessor);
let expected_res = Column::Boolean(&[false, true]);
assert_eq!(res, expected_res);
}
21 changes: 2 additions & 19 deletions crates/proof-of-sql/src/sql/ast/comparison_util.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
use crate::{
base::{
database::Column,
math::decimal::{scale_scalar, Precision},
scalar::Scalar,
},
base::{database::Column, math::decimal::Precision, scalar::Scalar},
sql::parse::{type_check_binary_operation, ConversionError, ConversionResult},
};
use bumpalo::Bump;
Expand Down Expand Up @@ -72,7 +68,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
);
// Check if the precision is valid
let _max_precision = Precision::new(max_precision_value)
.map_err(|_| ConversionError::InvalidPrecision(max_precision_value))?;
.map_err(|_| ConversionError::InvalidPrecision(max_precision_value as i16))?;
}
unchecked_subtract_impl(
alloc,
Expand All @@ -81,16 +77,3 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
lhs_len,
)
}

/// The counterpart of `scale_and_subtract` for evaluating decimal expressions.
pub(crate) fn scale_and_subtract_eval<S: Scalar>(
lhs_eval: S,
rhs_eval: S,
lhs_scale: i8,
rhs_scale: i8,
) -> ConversionResult<S> {
let max_scale = lhs_scale.max(rhs_scale);
let scaled_lhs_eval = scale_scalar(lhs_eval, max_scale - lhs_scale)?;
let scaled_rhs_eval = scale_scalar(rhs_eval, max_scale - rhs_scale)?;
Ok(scaled_lhs_eval - scaled_rhs_eval)
}
Loading

0 comments on commit 025cd19

Please sign in to comment.