diff --git a/.github/workflows/lint-and-test.yml b/.github/workflows/lint-and-test.yml index dfebc0ea6..9a9b1828a 100644 --- a/.github/workflows/lint-and-test.yml +++ b/.github/workflows/lint-and-test.yml @@ -118,17 +118,10 @@ jobs: run: cargo run --example dog_breeds - name: Run wood types example run: cargo run --example wood_types - - name: Run dinosaurs example - run: cargo run --example dinosaurs - - name: Run books example - run: cargo run --example books - - name: Run brands example - run: cargo run --example brands - name: Run posql_db example (With Blitzar) run: bash crates/proof-of-sql/examples/posql_db/run_example.sh - name: Run posql_db example (Without Blitzar) run: bash crates/proof-of-sql/examples/posql_db/run_example.sh --no-default-features --features="rayon" - clippy: name: Clippy runs-on: large-8-core-32gb-22-04 @@ -238,4 +231,4 @@ jobs: - name: Install solhint run: npm install -g solhint - name: Run tests - run: solhint -c 'crates/proof-of-sql/.solhint.json' 'crates/proof-of-sql/**/*.sol' -w 0 + run: solhint -c 'crates/proof-of-sql/.solhint.json' 'crates/proof-of-sql/**/*.sol' -w 0 \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index a0d8f7216..035636f51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,8 +20,8 @@ ark-poly = { version = "0.4.0" } ark-serialize = { version = "0.4.0" } ark-std = { version = "0.4.0", default-features = false } arrayvec = { version = "0.7", default-features = false } -arrow = { version = "51.0" } -arrow-csv = { version = "51.0" } +arrow = { version = "51.0.0" } +arrow-csv = { version = "51.0.0" } bit-iter = { version = "1.1.1" } bigdecimal = { version = "0.4.5", default-features = false, features = ["serde"] } blake3 = { version = "1.3.3", default-features = false } diff --git a/crates/proof-of-sql-parser/Cargo.toml b/crates/proof-of-sql-parser/Cargo.toml index 228833b7e..29c23bc56 100644 --- a/crates/proof-of-sql-parser/Cargo.toml +++ b/crates/proof-of-sql-parser/Cargo.toml @@ -21,6 +21,7 @@ chrono = { workspace = true, features = ["serde"] } lalrpop-util = { workspace = true, features = ["lexer", "unicode"] } serde = { workspace = true, features = ["serde_derive", "alloc"] } snafu = { workspace = true } +sqlparser = { workspace = true } [build-dependencies] lalrpop = { workspace = true } diff --git a/crates/proof-of-sql-parser/src/identifier.rs b/crates/proof-of-sql-parser/src/identifier.rs index b11df6862..87ccc9de4 100644 --- a/crates/proof-of-sql-parser/src/identifier.rs +++ b/crates/proof-of-sql-parser/src/identifier.rs @@ -2,6 +2,7 @@ use crate::{sql::IdentifierParser, ParseError, ParseResult}; use alloc::{format, string::ToString}; use arrayvec::ArrayString; use core::{cmp::Ordering, fmt, ops::Deref, str::FromStr}; +use sqlparser::ast::Ident; /// Top-level unique identifier. 
#[derive(Debug, PartialEq, Eq, Clone, Hash, Ord, PartialOrd, Copy)] @@ -71,6 +72,16 @@ impl fmt::Display for Identifier { } } +// TryFrom<Ident> for Identifier +impl TryFrom<Ident> for Identifier { + type Error = ParseError; + + fn try_from(ident: Ident) -> ParseResult<Self> { + // Convert Ident's value to Identifier + Identifier::try_new(ident.value) + } +} + impl PartialEq<str> for Identifier { fn eq(&self, other: &str) -> bool { other.eq_ignore_ascii_case(&self.name) } @@ -278,4 +289,14 @@ mod tests { Identifier::new("t".repeat(64)); Identifier::new("茶".repeat(21)); } + + #[test] + fn try_from_ident() { + let ident = Ident::new("ValidIdentifier"); + let identifier = Identifier::try_from(ident).unwrap(); + assert_eq!(identifier.name(), "valididentifier"); + + let invalid_ident = Ident::new("INVALID$IDENTIFIER"); + assert!(Identifier::try_from(invalid_ident).is_err()); + } } diff --git a/crates/proof-of-sql-parser/src/resource_id.rs b/crates/proof-of-sql-parser/src/resource_id.rs index 2fe0d0848..10045d03b 100644 --- a/crates/proof-of-sql-parser/src/resource_id.rs +++ b/crates/proof-of-sql-parser/src/resource_id.rs @@ -3,11 +3,13 @@ use crate::{impl_serde_from_str, sql::ResourceIdParser, Identifier, ParseError, use alloc::{ format, string::{String, ToString}, + vec::Vec, }; use core::{ fmt::{self, Display}, str::FromStr, }; +use sqlparser::ast::Ident; /// Unique resource identifier, like `schema.object_name`. #[derive(Debug, PartialEq, Eq, Clone, Hash, Copy)] @@ -110,6 +112,22 @@ impl FromStr for ResourceId { } impl_serde_from_str!(ResourceId); +impl TryFrom<Vec<Ident>> for ResourceId { + type Error = ParseError; + + fn try_from(identifiers: Vec<Ident>) -> ParseResult<Self> { + if identifiers.len() != 2 { + return Err(ParseError::ResourceIdParseError { + error: "Expected exactly two identifiers for ResourceId".to_string(), + }); + } + + let schema = Identifier::try_from(identifiers[0].clone())?; + let object_name = Identifier::try_from(identifiers[1].clone())?; + Ok(ResourceId::new(schema, object_name)) + } +} + #[cfg(test)] mod tests { use super::*; @@ -233,4 +251,15 @@ mod tests { serde_json::from_str(r#""good_identifier.bad!identifier"#); assert!(deserialized.is_err()); } + + #[test] + fn test_try_from_vec_ident() { + let identifiers = alloc::vec![Ident::new("schema_name"), Ident::new("object_name")]; + let resource_id = ResourceId::try_from(identifiers).unwrap(); + assert_eq!(resource_id.schema().name(), "schema_name"); + assert_eq!(resource_id.object_name().name(), "object_name"); + + let invalid_identifiers = alloc::vec![Ident::new("only_one_ident")]; + assert!(ResourceId::try_from(invalid_identifiers).is_err()); + } } diff --git a/crates/proof-of-sql/Cargo.toml b/crates/proof-of-sql/Cargo.toml index d4fe9f69d..423f704f0 100644 --- a/crates/proof-of-sql/Cargo.toml +++ b/crates/proof-of-sql/Cargo.toml @@ -89,43 +89,88 @@ required-features = ["test"] [[example]] name = "posql_db" -required-features = [ "arrow" ] +required-features = ["arrow"] [[example]] name = "space" -required-features = [ "arrow" ] +required-features = ["arrow"] [[example]] name = "dog_breeds" -required-features = [ "arrow" ] +required-features = ["arrow"] [[example]] name = "wood_types" -required-features = [ "arrow" ] +required-features = ["arrow"] [[example]] name = "dinosaurs" -required-features = [ "arrow" ] +required-features = ["arrow"] [[example]] name = "books" -required-features = [ "arrow" ] +required-features = ["arrow"] + +[[example]] +name = "programming_books" +required-features = ["arrow"] [[example]] name = "brands" +required-features = ["arrow"] +
+[[example]] +name = "census" required-features = [ "arrow" ] +[[example]] +name = "plastics" +required-features = ["arrow"] + +[[example]] +name = "avocado-prices" +required-features = ["arrow"] + +[[example]] +name = "sushi" +required-features = ["arrow"] + +[[example]] +name = "stocks" +required-features = ["arrow"] + +[[example]] +name = "tech_gadget_prices" +required-features = [ "arrow" ] + +[[example]] +name = "albums" +required-features = [ "arrow" ] + +[[example]] +name = "vehicles" +required-features = [ "arrow" ] + +[[example]] +name = "countries" +required-features = [ "arrow" ] + +[[example]] +name = "rockets" +required-features = [ "arrow" ] + + [[bench]] name = "posql_benches" harness = false -required-features = [ "blitzar" ] +required-features = ["blitzar"] [[bench]] name = "bench_append_rows" harness = false -required-features = [ "test" ] +required-features = ["test"] [[bench]] name = "jaeger_benches" harness = false -required-features = [ "blitzar" ] +required-features = ["blitzar"] diff --git a/crates/proof-of-sql/examples/albums/albums.csv b/crates/proof-of-sql/examples/albums/albums.csv new file mode 100644 index 000000000..f62c9868d --- /dev/null +++ b/crates/proof-of-sql/examples/albums/albums.csv @@ -0,0 +1,100 @@ +artist,year,genre,album +Michael Jackson,1982,Pop,Thriller +The Beatles,1969,Rock,Abbey Road +Pink Floyd,1973,Progressive Rock,Dark Side of the Moon +Eagles,1976,Rock,Hotel California +Fleetwood Mac,1977,Rock,Rumours +AC/DC,1980,Hard Rock,Back in Black +Whitney Houston,1992,Pop/R&B,The Bodyguard +Bee Gees,1977,Disco,Saturday Night Fever +Queen,1975,Rock,A Night at the Opera +Bruce Springsteen,1984,Rock,Born in the U.S.A. +Nirvana,1991,Grunge,Nevermind +Adele,2011,Pop/Soul,21 +Bob Marley & The Wailers,1977,Reggae,Exodus +Metallica,1991,Metal,Metallica +Prince,1984,Pop/Funk,Purple Rain +Led Zeppelin,1971,Rock,Led Zeppelin IV +The Rolling Stones,1972,Rock,Exile on Main St. +David Bowie,1972,Rock,The Rise and Fall of Ziggy Stardust and the Spiders from Mars +Stevie Wonder,1976,Soul/R&B,Songs in the Key of Life +Madonna,1984,Pop,Like a Virgin +Amy Winehouse,2006,Soul/R&B,Back to Black +Carole King,1971,Folk/Rock,Tapestry +Dr. Dre,1992,Hip-Hop,The Chronic +Bruce Springsteen,1975,Rock,Born to Run +The Beach Boys,1966,Rock,Pet Sounds +Joni Mitchell,1971,Folk,Blue +Miles Davis,1959,Jazz,Kind of Blue +The Clash,1979,Punk Rock,London Calling +Simon & Garfunkel,1970,Folk Rock,Bridge Over Troubled Water +Paul Simon,1986,World/Pop,Graceland +U2,1987,Rock,The Joshua Tree +Marvin Gaye,1971,Soul/R&B,What's Going On +Radiohead,1997,Alternative Rock,OK Computer +The Who,1971,Rock,Who's Next +Bob Dylan,1965,Folk Rock,Highway 61 Revisited +Guns N' Roses,1987,Hard Rock,Appetite for Destruction +The Doors,1967,Rock,The Doors +Elton John,1973,Rock/Pop,Goodbye Yellow Brick Road +R.E.M.,1992,Alternative Rock,Automatic for the People +Kendrick Lamar,2015,Hip-Hop,To Pimp a Butterfly +The Strokes,2001,Indie Rock,Is This It +Kanye West,2010,Hip-Hop,My Beautiful Dark Twisted Fantasy +Beyoncé,2016,R&B/Pop,Lemonade +Arcade Fire,2004,Indie Rock,Funeral +Oasis,1995,Britpop,(What's the Story) Morning Glory? 
+Daft Punk,2001,Electronic,Discovery +Nas,1994,Hip-Hop,Illmatic +Green Day,1994,Punk Rock,Dookie +Jay-Z,2001,Hip-Hop,The Blueprint +Taylor Swift,2014,Pop,1989 +Arctic Monkeys,2013,Alternative Rock,AM +The Weeknd,2020,Pop/R&B,After Hours +Lana Del Rey,2012,Alternative/Pop,Born to Die +Tame Impala,2015,Psychedelic Rock,Currents +Frank Ocean,2012,R&B/Soul,Channel Orange +Coldplay,2002,Alternative Rock,A Rush of Blood to the Head +Lady Gaga,2011,Pop,Born This Way +Black Keys,2010,Blues Rock,Brothers +Ed Sheeran,2014,Pop,x +Tyler The Creator,2019,Hip-Hop,IGOR +Billie Eilish,2019,Pop/Alternative,When We All Fall Asleep +Tool,2019,Progressive Metal,Fear Inoculum +SZA,2022,R&B/Soul,SOS +Rosalía,2022,Flamenco Pop,Motomami +Harry Styles,2022,Pop,Harry's House +Bad Bunny,2022,Reggaeton/Latin Trap,Un Verano Sin Ti +Pearl Jam,1991,Grunge,Ten +Red Hot Chili Peppers,1991,Alternative Rock,Blood Sugar Sex Magik +Björk,1997,Art Pop,Homogenic +The Weeknd,2022,Pop/R&B,Dawn FM +Kendrick Lamar,2022,Hip-Hop,Mr. Morale & the Big Steppers +Taylor Swift,2022,Pop,Midnights +Arctic Monkeys,2022,Alternative Rock,The Car +Beyoncé,2022,Dance/Pop,Renaissance +Drake,2022,Hip-Hop/R&B,Honestly +Post Malone,2022,Pop/Hip-Hop,Twelve Carat Toothache +Florence + The Machine,2022,Art Rock,Dance Fever +Jack Harlow,2022,Hip-Hop,Come Home the Kids Miss You +Lizzo,2022,Pop/R&B,Special +Olivia Rodrigo,2023,Pop/Rock,GUTS +Lorde,2013,Art Pop,Pure Heroine +Talking Heads,1980,New Wave,Remain in Light +The Velvet Underground,1967,Art Rock,The Velvet Underground & Nico +Kate Bush,1985,Art Pop,Hounds of Love +Stevie Nicks,1981,Rock,Bella Donna +Travis Scott,2018,Hip-Hop,Astroworld +Portishead,1994,Trip Hop,Dummy +The Smiths,1986,Alternative Rock,The Queen Is Dead +Calvin Harris,2012,Electronic Dance,18 Months +Rihanna,2016,Pop/R&B,Anti +Dua Lipa,2020,Pop,Future Nostalgia +The Cure,1989,Gothic Rock,Disintegration +Foo Fighters,1997,Alternative Rock,The Colour and the Shape +A Tribe Called Quest,1991,Hip-Hop,The Low End Theory +Massive Attack,1998,Trip Hop,Mezzanine +Gorillaz,2001,Alternative/Hip-Hop,Gorillaz +Depeche Mode,1990,Electronic,Violator +Rage Against The Machine,1992,Rap Metal,Rage Against The Machine +Joy Division,1979,Post-Punk,Unknown Pleasures \ No newline at end of file diff --git a/crates/proof-of-sql/examples/albums/main.rs b/crates/proof-of-sql/examples/albums/main.rs new file mode 100644 index 000000000..486ca1f12 --- /dev/null +++ b/crates/proof-of-sql/examples/albums/main.rs @@ -0,0 +1,135 @@ +//! This is a non-interactive example of using Proof of SQL with an albums dataset. +//! To run this, use `cargo run --release --example albums`. +//! +//! NOTE: If this doesn't work because you do not have the appropriate GPU drivers installed, +//! you can run `cargo run --release --example albums --no-default-features --features="arrow cpu-perf"` instead. It will be slower for proof generation. 
+ +use arrow::datatypes::SchemaRef; +use arrow_csv::{infer_schema_from_files, ReaderBuilder}; +use proof_of_sql::{ + base::database::{ + arrow_schema_utility::get_posql_compatible_schema, OwnedTable, OwnedTableTestAccessor, + TestAccessor, + }, + proof_primitive::dory::{ + DynamicDoryCommitment, DynamicDoryEvaluationProof, ProverSetup, PublicParameters, + VerifierSetup, + }, + sql::{parse::QueryExpr, postprocessing::apply_postprocessing_steps, proof::QueryProof}, +}; +use rand::{rngs::StdRng, SeedableRng}; +use std::{fs::File, time::Instant}; + +// We generate the public parameters and the setups used by the prover and verifier for the Dory PCS. +// The `max_nu` should be set such that the maximum table size is less than `2^(2*max_nu-1)`. +const DORY_SETUP_MAX_NU: usize = 8; +// This should be a "nothing-up-my-sleeve" phrase or number. +const DORY_SEED: [u8; 32] = *b"32f7f321c4ab1234d5e6f7a8b9c0d1e2"; + +/// # Panics +/// Will panic if the query does not parse or the proof fails to verify. +fn prove_and_verify_query( + sql: &str, + accessor: &OwnedTableTestAccessor, + prover_setup: &ProverSetup, + verifier_setup: &VerifierSetup, +) { + // Parse the query: + println!("Parsing the query: {sql}..."); + let now = Instant::now(); + let query_plan = QueryExpr::::try_new( + sql.parse().unwrap(), + "albums".parse().unwrap(), + accessor, + ) + .unwrap(); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Generate the proof and result: + print!("Generating proof..."); + let now = Instant::now(); + let (proof, provable_result) = QueryProof::::new( + query_plan.proof_expr(), + accessor, + &prover_setup, + ); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Verify the result with the proof: + print!("Verifying proof..."); + let now = Instant::now(); + let result = proof + .verify( + query_plan.proof_expr(), + accessor, + &provable_result, + &verifier_setup, + ) + .unwrap(); + let result = apply_postprocessing_steps(result.table, query_plan.postprocessing()); + println!("Verified in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Display the result + println!("Query Result:"); + println!("{result:?}"); +} + +fn main() { + let mut rng = StdRng::from_seed(DORY_SEED); + let public_parameters = PublicParameters::rand(DORY_SETUP_MAX_NU, &mut rng); + let prover_setup = ProverSetup::from(&public_parameters); + let verifier_setup = VerifierSetup::from(&public_parameters); + + let filename = "crates/proof-of-sql/examples/albums/albums.csv"; + let schema = get_posql_compatible_schema(&SchemaRef::new( + infer_schema_from_files(&[filename.to_string()], b',', None, true).unwrap(), + )); + let albums_batch = ReaderBuilder::new(schema) + .with_header(true) + .build(File::open(filename).unwrap()) + .unwrap() + .next() + .unwrap() + .unwrap(); + + // Load the table into an "Accessor" so that the prover and verifier can access the data/commitments. 
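+ // The table is registered below under the resource id "albums.collection"; the queries in
+ // this example refer to it by that fully qualified name, while "albums" is the default
+ // schema passed to `QueryExpr::try_new` above.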
+ let mut accessor = + OwnedTableTestAccessor::::new_empty_with_setup(&prover_setup); + accessor.add_table( + "albums.collection".parse().unwrap(), + OwnedTable::try_from(albums_batch).unwrap(), + 0, + ); + + // Query 1: Count number of albums by genre + prove_and_verify_query( + "SELECT genre, COUNT(*) AS album_count FROM albums.collection GROUP BY genre ORDER BY genre", + &accessor, + &prover_setup, + &verifier_setup, + ); + + // Query 2: Find all albums from the 1970s + prove_and_verify_query( + "SELECT artist, album, year FROM albums.collection WHERE year >= 1970 AND year < 1980 ORDER BY year", + &accessor, + &prover_setup, + &verifier_setup, + ); + + // Query 3: Count total number of albums + prove_and_verify_query( + "SELECT COUNT(*) AS total_albums FROM albums.collection", + &accessor, + &prover_setup, + &verifier_setup, + ); + + // Query 4: List all rock albums after 1975 (using exact matches for Rock genres) + prove_and_verify_query( + "SELECT artist, album, year FROM albums.collection WHERE (genre = 'Rock' OR genre = 'Hard Rock' OR genre = 'Progressive Rock') AND year > 1975 ORDER BY year DESC", + &accessor, + &prover_setup, + &verifier_setup, + ); +} diff --git a/crates/proof-of-sql/examples/avocado-prices/avocado-prices.csv b/crates/proof-of-sql/examples/avocado-prices/avocado-prices.csv new file mode 100644 index 000000000..7750f7a46 --- /dev/null +++ b/crates/proof-of-sql/examples/avocado-prices/avocado-prices.csv @@ -0,0 +1,37 @@ +Year,Price +1990,96 +1991,100 +1992,269 +1993,149 +1994,127 +1995,153 +1996,232 +1997,127 +1998,249 +1999,240 +2000,241 +2001,90 +2002,91 +2003,169 +2004,167 +2005,56 +2006,230 +2007,174 +2008,124 +2009,92 +2010,201 +2011,167 +2012,125 +2013,147 +2014,285 +2015,154 +2016,106 +2017,223 +2018,85 +2019,145 +2020,147 +2021,68 +2022,142 +2023,281 +2024,164 + diff --git a/crates/proof-of-sql/examples/avocado-prices/main.rs b/crates/proof-of-sql/examples/avocado-prices/main.rs new file mode 100644 index 000000000..85d5e50b8 --- /dev/null +++ b/crates/proof-of-sql/examples/avocado-prices/main.rs @@ -0,0 +1,124 @@ +//! Example to use Proof of SQL with datasets +//! To run, use `cargo run --example avocado-prices`. +//! +//! NOTE: If this doesn't work because you do not have the appropriate GPU drivers installed, +//! you can run `cargo run --release --example avocado-prices --no-default-features --features="arrow cpu-perf"` instead. It will be slower for proof generation. +use arrow::datatypes::SchemaRef; +use arrow_csv::{infer_schema_from_files, ReaderBuilder}; +use proof_of_sql::{ + base::database::{OwnedTable, OwnedTableTestAccessor}, + proof_primitive::dory::{ + DynamicDoryCommitment, DynamicDoryEvaluationProof, ProverSetup, PublicParameters, + VerifierSetup, + }, + sql::{parse::QueryExpr, postprocessing::apply_postprocessing_steps, proof::QueryProof}, +}; +use rand::{rngs::StdRng, SeedableRng}; +use std::{fs::File, time::Instant}; + +// We generate the public parameters and the setups used by the prover and verifier for the Dory PCS. +// The `max_nu` should be set such that the maximum table size is less than `2^(2*max_nu-1)`. +// For a sampling: +// max_nu = 3 => max table size is 32 rows +// max_nu = 4 => max table size is 128 rows +// max_nu = 8 => max table size is 32768 rows +// max_nu = 10 => max table size is 0.5 million rows +// max_nu = 15 => max table size is 0.5 billion rows +// max_nu = 20 => max table size is 0.5 trillion rows +// Note: we will eventually load these from a file. 
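+// Illustrative arithmetic (not part of the original comment): with max_nu = 8 the bound is
+// 2^(2*8-1) = 2^15 = 32768 rows, far more than the handful of rows in avocado-prices.csv.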
+const DORY_SETUP_MAX_NU: usize = 8; +// This should be a "nothing-up-my-sleeve" phrase or number. +const DORY_SEED: [u8; 32] = *b"len 32 rng seed - Space and Time"; + +/// # Panics +/// Will panic if the query does not parse or the proof fails to verify. +fn prove_and_verify_query( + sql: &str, + accessor: &OwnedTableTestAccessor, + prover_setup: &ProverSetup, + verifier_setup: &VerifierSetup, +) { + // Parse the query: + println!("Parsing the query: {sql}..."); + let now = Instant::now(); + let query_plan = QueryExpr::::try_new( + sql.parse().unwrap(), + "avocado".parse().unwrap(), + accessor, + ) + .unwrap(); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Generate the proof and result: + print!("Generating proof..."); + let now = Instant::now(); + let (proof, provable_result) = QueryProof::::new( + query_plan.proof_expr(), + accessor, + &prover_setup, + ); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Verify the result with the proof: + print!("Verifying proof..."); + let now = Instant::now(); + let result = proof + .verify( + query_plan.proof_expr(), + accessor, + &provable_result, + &verifier_setup, + ) + .unwrap(); + let result = apply_postprocessing_steps(result.table, query_plan.postprocessing()); + println!("Verified in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Display the result + println!("Query Result:"); + println!("{result:?}"); +} + +fn main() { + let mut rng = StdRng::from_seed(DORY_SEED); + let public_parameters = PublicParameters::rand(DORY_SETUP_MAX_NU, &mut rng); + let prover_setup = ProverSetup::from(&public_parameters); + let verifier_setup = VerifierSetup::from(&public_parameters); + + let filename = "./crates/proof-of-sql/examples/avocado-prices/avocado-prices.csv"; + let data_batch = ReaderBuilder::new(SchemaRef::new( + infer_schema_from_files(&[filename.to_string()], b',', None, true).unwrap(), + )) + .with_header(true) + .build(File::open(filename).unwrap()) + .unwrap() + .next() + .unwrap() + .unwrap(); + + // Load the table into an "Accessor" so that the prover and verifier can access the data/commitments. 
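+ // Unlike the examples that call `new_empty_with_setup` and then `add_table`, this example
+ // registers the table and the prover setup in a single `new_from_table` call; for a
+ // single-table example the two patterns serve the same purpose.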
+ let accessor = OwnedTableTestAccessor::::new_from_table( + "avocado.prices".parse().unwrap(), + OwnedTable::try_from(data_batch).unwrap(), + 0, + &prover_setup, + ); + + prove_and_verify_query( + "SELECT COUNT(*) AS total FROM prices", + &accessor, + &prover_setup, + &verifier_setup, + ); + prove_and_verify_query( + "SELECT Price, COUNT(*) AS total FROM prices GROUP BY Price ORDER BY total", + &accessor, + &prover_setup, + &verifier_setup, + ); + prove_and_verify_query( + "SELECT Year, COUNT(*) AS total FROM prices WHERE Price > 100 GROUP BY Year ORDER BY total DESC LIMIT 5", + &accessor, + &prover_setup, + &verifier_setup, + ); +} diff --git a/crates/proof-of-sql/examples/census/census-income.csv b/crates/proof-of-sql/examples/census/census-income.csv new file mode 100644 index 000000000..0accbfc01 --- /dev/null +++ b/crates/proof-of-sql/examples/census/census-income.csv @@ -0,0 +1,53 @@ +Id,Geography,Id2,Households_Estimate_Total +0400000US01,Alabama,1,1837292 +0400000US02,Alaska,2,250875 +0400000US04,Arizona,4,2381501 +0400000US05,Arkansas,5,1130417 +0400000US06,California,6,12581722 +0400000US08,Colorado,8,1989371 +0400000US09,Connecticut,9,1348275 +0400000US10,Delaware,10,337245 +0400000US11,District of Columbia,11,268015 +0400000US12,Florida,12,7168502 +0400000US13,Georgia,13,3522934 +0400000US15,Hawaii,15,449296 +0400000US16,Idaho,16,583452 +0400000US17,Illinois,17,4763457 +0400000US18,Indiana,18,2482558 +0400000US19,Iowa,19,1227201 +0400000US20,Kansas,20,1109747 +0400000US21,Kentucky,21,1693399 +0400000US22,Louisiana,22,1715997 +0400000US23,Maine,23,552589 +0400000US24,Maryland,24,2149424 +0400000US25,Massachusetts,25,2528592 +0400000US26,Michigan,26,3815532 +0400000US27,Minnesota,27,2109924 +0400000US28,Mississippi,28,1086898 +0400000US29,Missouri,29,2353778 +0400000US30,Montana,30,405504 +0400000US31,Nebraska,31,729572 +0400000US32,Nevada,32,995980 +0400000US33,New Hampshire,33,518088 +0400000US34,New Jersey,34,3181152 +0400000US35,New Mexico,35,760251 +0400000US36,New York,36,7214163 +0400000US37,North Carolina,37,3721358 +0400000US38,North Dakota,38,291468 +0400000US39,Ohio,39,4551497 +0400000US40,Oklahoma,40,1445059 +0400000US41,Oregon,41,1516591 +0400000US42,Pennsylvania,42,4945140 +0400000US44,Rhode Island,44,410347 +0400000US45,South Carolina,45,1781957 +0400000US46,South Dakota,46,326086 +0400000US47,Tennessee,47,2480467 +0400000US48,Texas,48,8965352 +0400000US49,Utah,49,891240 +0400000US50,Vermont,50,256563 +0400000US51,Virginia,51,3026761 +0400000US53,Washington,53,2634496 +0400000US54,West Virginia,54,739759 +0400000US55,Wisconsin,55,2281781 +0400000US56,Wyoming,56,222679 +0400000US72,Puerto Rico,72,1254274 diff --git a/crates/proof-of-sql/examples/census/main.rs b/crates/proof-of-sql/examples/census/main.rs new file mode 100644 index 000000000..3f3f31e25 --- /dev/null +++ b/crates/proof-of-sql/examples/census/main.rs @@ -0,0 +1,127 @@ +//! Example to use Proof of SQL with census datasets +//! To run, use `cargo run --release --example census`. +//! +//! NOTE: If this doesn't work because you do not have the appropriate GPU drivers installed, +//! you can run `cargo run --release --example census --no-default-features --features="arrow cpu-perf"` instead. It will be slower for proof generation. 
+ +// Note: the census-income.csv was obtained from +// https://github.com/domoritz/vis-examples/blob/master/data/census-income.csv +use arrow::datatypes::SchemaRef; +use arrow_csv::{infer_schema_from_files, ReaderBuilder}; +use proof_of_sql::{ + base::database::{OwnedTable, OwnedTableTestAccessor}, + proof_primitive::dory::{ + DynamicDoryCommitment, DynamicDoryEvaluationProof, ProverSetup, PublicParameters, + VerifierSetup, + }, + sql::{parse::QueryExpr, postprocessing::apply_postprocessing_steps, proof::QueryProof}, +}; +use rand::{rngs::StdRng, SeedableRng}; +use std::{fs::File, time::Instant}; + +// We generate the public parameters and the setups used by the prover and verifier for the Dory PCS. +// The `max_nu` should be set such that the maximum table size is less than `2^(2*max_nu-1)`. +// For a sampling: +// max_nu = 3 => max table size is 32 rows +// max_nu = 4 => max table size is 128 rows +// max_nu = 8 => max table size is 32768 rows +// max_nu = 10 => max table size is 0.5 million rows +// max_nu = 15 => max table size is 0.5 billion rows +// max_nu = 20 => max table size is 0.5 trillion rows +// Note: we will eventually load these from a file. +const DORY_SETUP_MAX_NU: usize = 8; +// This should be a "nothing-up-my-sleeve" phrase or number. +const DORY_SEED: [u8; 32] = *b"len 32 rng seed - Space and Time"; + +/// # Panics +/// Will panic if the query does not parse or the proof fails to verify. +fn prove_and_verify_query( + sql: &str, + accessor: &OwnedTableTestAccessor, + prover_setup: &ProverSetup, + verifier_setup: &VerifierSetup, +) { + // Parse the query: + println!("Parsing the query: {sql}..."); + let now = Instant::now(); + let query_plan = QueryExpr::::try_new( + sql.parse().unwrap(), + "census".parse().unwrap(), + accessor, + ) + .unwrap(); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Generate the proof and result: + print!("Generating proof..."); + let now = Instant::now(); + let (proof, provable_result) = QueryProof::::new( + query_plan.proof_expr(), + accessor, + &prover_setup, + ); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Verify the result with the proof: + print!("Verifying proof..."); + let now = Instant::now(); + let result = proof + .verify( + query_plan.proof_expr(), + accessor, + &provable_result, + &verifier_setup, + ) + .unwrap(); + let result = apply_postprocessing_steps(result.table, query_plan.postprocessing()); + println!("Verified in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Display the result + println!("Query Result:"); + println!("{result:?}"); +} + +fn main() { + let mut rng = StdRng::from_seed(DORY_SEED); + let public_parameters = PublicParameters::rand(DORY_SETUP_MAX_NU, &mut rng); + let prover_setup = ProverSetup::from(&public_parameters); + let verifier_setup = VerifierSetup::from(&public_parameters); + + let filename = "./crates/proof-of-sql/examples/census/census-income.csv"; + let census_income_batch = ReaderBuilder::new(SchemaRef::new( + infer_schema_from_files(&[filename.to_string()], b',', None, true).unwrap(), + )) + .with_header(true) + .build(File::open(filename).unwrap()) + .unwrap() + .next() + .unwrap() + .unwrap(); + + // Load the table into an "Accessor" so that the prover and verifier can access the data/commitments. 
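+ // The table is registered as "census.income" while "census" is passed to
+ // `QueryExpr::try_new` as the default schema, so the queries below can name the table
+ // simply as "income".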
+ let accessor = OwnedTableTestAccessor::::new_from_table( + "census.income".parse().unwrap(), + OwnedTable::try_from(census_income_batch).unwrap(), + 0, + &prover_setup, + ); + + prove_and_verify_query( + "SELECT COUNT(*) AS total_geographies FROM income", + &accessor, + &prover_setup, + &verifier_setup, + ); + prove_and_verify_query( + "SELECT Geography, COUNT(*) AS num_geographies FROM income GROUP BY Geography ORDER BY num_geographies", + &accessor, + &prover_setup, + &verifier_setup, + ); + prove_and_verify_query( + "SELECT Geography, COUNT(*) AS num_geographies FROM income WHERE Households_Estimate_total > 2000000 GROUP BY Geography ORDER BY num_geographies DESC LIMIT 5", + &accessor, + &prover_setup, + &verifier_setup, + ); +} diff --git a/crates/proof-of-sql/examples/countries/countries_gdp.csv b/crates/proof-of-sql/examples/countries/countries_gdp.csv new file mode 100644 index 000000000..397102f8f --- /dev/null +++ b/crates/proof-of-sql/examples/countries/countries_gdp.csv @@ -0,0 +1,35 @@ +Country,Continent,GDP,GDPP +UnitedStates,NorthAmerica,21137,63543 +China,Asia,14342,10261 +Japan,Asia,5081,40293 +Germany,Europe,3846,46329 +India,Asia,2875,2099 +UnitedKingdom,Europe,2825,42330 +France,Europe,2716,41463 +Italy,Europe,2001,33279 +Brazil,SouthAmerica,1839,8718 +Canada,NorthAmerica,1643,43119 +Russia,EuropeAsia,1637,11229 +SouthKorea,Asia,1622,31489 +Australia,Oceania,1382,53799 +Spain,Europe,1316,28152 +Mexico,NorthAmerica,1265,9958 +Indonesia,Asia,1119,4152 +Netherlands,Europe,902,52477 +SaudiArabia,Asia,793,23206 +Turkey,EuropeAsia,761,9005 +Switzerland,Europe,703,81392 +Argentina,SouthAmerica,449,9921 +Sweden,Europe,528,52073 +Nigeria,Africa,448,2190 +Poland,Europe,594,15673 +Thailand,Asia,509,7306 +SouthAfrica,Africa,350,5883 +Philippines,Asia,402,3685 +Colombia,SouthAmerica,323,6458 +Egypt,Africa,302,3012 +Pakistan,Asia,278,1450 +Bangladesh,Asia,302,1855 +Vietnam,Asia,283,2900 +Chile,SouthAmerica,252,13120 +Finland,Europe,268,48888 \ No newline at end of file diff --git a/crates/proof-of-sql/examples/countries/main.rs b/crates/proof-of-sql/examples/countries/main.rs new file mode 100644 index 000000000..10bfb8705 --- /dev/null +++ b/crates/proof-of-sql/examples/countries/main.rs @@ -0,0 +1,132 @@ +//! This is a non-interactive example of using Proof of SQL with a countries dataset. +//! To run this, use `cargo run --release --example countries`. +//! +//! NOTE: If this doesn't work because you do not have the appropriate GPU drivers installed, +//! you can run `cargo run --release --example countries --no-default-features --features="arrow cpu-perf"` instead. It will be slower for proof generation. + +use arrow::datatypes::SchemaRef; +use arrow_csv::{infer_schema_from_files, ReaderBuilder}; +use proof_of_sql::{ + base::database::{ + arrow_schema_utility::get_posql_compatible_schema, OwnedTable, OwnedTableTestAccessor, + TestAccessor, + }, + proof_primitive::dory::{ + DynamicDoryCommitment, DynamicDoryEvaluationProof, ProverSetup, PublicParameters, + VerifierSetup, + }, + sql::{parse::QueryExpr, postprocessing::apply_postprocessing_steps, proof::QueryProof}, +}; +use rand::{rngs::StdRng, SeedableRng}; +use std::{fs::File, time::Instant}; + +// We generate the public parameters and the setups used by the prover and verifier for the Dory PCS. +// The `max_nu` should be set such that the maximum table size is less than `2^(2*max_nu-1)`. +const DORY_SETUP_MAX_NU: usize = 8; +// This should be a "nothing-up-my-sleeve" phrase or number. 
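+// The seed below must be exactly 32 bytes, since the byte-string literal is dereferenced
+// into a `[u8; 32]`; a "nothing-up-my-sleeve" value is simply one that is fixed and
+// publicly auditable rather than secret.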
+const DORY_SEED: [u8; 32] = *b"7a1b3c8d2e4f9g6h5i0j7k2l8m3n9o1p"; + +/// # Panics +/// Will panic if the query does not parse or the proof fails to verify. +fn prove_and_verify_query( + sql: &str, + accessor: &OwnedTableTestAccessor, + prover_setup: &ProverSetup, + verifier_setup: &VerifierSetup, +) { + // Parse the query: + println!("Parsing the query: {sql}..."); + let now = Instant::now(); + let query_plan = QueryExpr::::try_new( + sql.parse().unwrap(), + "countries".parse().unwrap(), + accessor, + ) + .unwrap(); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Generate the proof and result: + print!("Generating proof..."); + let now = Instant::now(); + let (proof, provable_result) = QueryProof::::new( + query_plan.proof_expr(), + accessor, + &prover_setup, + ); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Verify the result with the proof: + print!("Verifying proof..."); + let now = Instant::now(); + let result = proof + .verify( + query_plan.proof_expr(), + accessor, + &provable_result, + &verifier_setup, + ) + .unwrap(); + let result = apply_postprocessing_steps(result.table, query_plan.postprocessing()); + println!("Verified in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Display the result + println!("Query Result:"); + println!("{result:?}"); +} + +fn main() { + let mut rng = StdRng::from_seed(DORY_SEED); + let public_parameters = PublicParameters::rand(DORY_SETUP_MAX_NU, &mut rng); + let prover_setup = ProverSetup::from(&public_parameters); + let verifier_setup = VerifierSetup::from(&public_parameters); + + let filename = "./crates/proof-of-sql/examples/countries/countries_gdp.csv"; + let inferred_schema = + SchemaRef::new(infer_schema_from_files(&[filename.to_string()], b',', None, true).unwrap()); + let posql_compatible_schema = get_posql_compatible_schema(&inferred_schema); + + let countries_batch = ReaderBuilder::new(posql_compatible_schema) + .with_header(true) + .build(File::open(filename).unwrap()) + .unwrap() + .next() + .unwrap() + .unwrap(); + + // Load the table into an "Accessor" so that the prover and verifier can access the data/commitments. + let mut accessor = + OwnedTableTestAccessor::::new_empty_with_setup(&prover_setup); + accessor.add_table( + "countries.countries".parse().unwrap(), + OwnedTable::try_from(countries_batch).unwrap(), + 0, + ); + + prove_and_verify_query( + "SELECT COUNT(*) AS total_countries FROM countries", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT country FROM countries WHERE continent = 'Asia'", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT country FROM countries WHERE gdp > 500 AND gdp < 1500", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT SUM(gdp) AS total_market_cap FROM countries WHERE country = 'China' OR country = 'India'", + &accessor, + &prover_setup, + &verifier_setup, + ); +} diff --git a/crates/proof-of-sql/examples/plastics/main.rs b/crates/proof-of-sql/examples/plastics/main.rs new file mode 100644 index 000000000..7263e7538 --- /dev/null +++ b/crates/proof-of-sql/examples/plastics/main.rs @@ -0,0 +1,135 @@ +//! This is a non-interactive example of using Proof of SQL with a plastics dataset. +//! To run this, use `cargo run --release --example plastics`. +//! +//! NOTE: If this doesn't work because you do not have the appropriate GPU drivers installed, +//! 
you can run `cargo run --release --example plastics --no-default-features --features="arrow cpu-perf"` instead. It will be slower for proof generation. + +use arrow::datatypes::SchemaRef; +use arrow_csv::{infer_schema_from_files, ReaderBuilder}; +use proof_of_sql::{ + base::database::{ + arrow_schema_utility::get_posql_compatible_schema, OwnedTable, OwnedTableTestAccessor, + TestAccessor, + }, + proof_primitive::dory::{ + DynamicDoryCommitment, DynamicDoryEvaluationProof, ProverSetup, PublicParameters, + VerifierSetup, + }, + sql::{parse::QueryExpr, postprocessing::apply_postprocessing_steps, proof::QueryProof}, +}; +use rand::{rngs::StdRng, SeedableRng}; +use std::{fs::File, time::Instant}; + +// We generate the public parameters and the setups used by the prover and verifier for the Dory PCS. +// The `max_nu` should be set such that the maximum table size is less than `2^(2*max_nu-1)`. +const DORY_SETUP_MAX_NU: usize = 8; +// This should be a "nothing-up-my-sleeve" phrase or number. +const DORY_SEED: [u8; 32] = *b"32f7f321c4ab1234d5e6f7a8b9c0d1e2"; + +/// # Panics +/// Will panic if the query does not parse or the proof fails to verify. +fn prove_and_verify_query( + sql: &str, + accessor: &OwnedTableTestAccessor, + prover_setup: &ProverSetup, + verifier_setup: &VerifierSetup, +) { + // Parse the query: + println!("Parsing the query: {sql}..."); + let now = Instant::now(); + let query_plan = QueryExpr::::try_new( + sql.parse().unwrap(), + "plastics".parse().unwrap(), + accessor, + ) + .unwrap(); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Generate the proof and result: + print!("Generating proof..."); + let now = Instant::now(); + let (proof, provable_result) = QueryProof::::new( + query_plan.proof_expr(), + accessor, + &prover_setup, + ); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Verify the result with the proof: + print!("Verifying proof..."); + let now = Instant::now(); + let result = proof + .verify( + query_plan.proof_expr(), + accessor, + &provable_result, + &verifier_setup, + ) + .unwrap(); + let result = apply_postprocessing_steps(result.table, query_plan.postprocessing()); + println!("Verified in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Display the result + println!("Query Result:"); + println!("{result:?}"); +} + +fn main() { + let mut rng = StdRng::from_seed(DORY_SEED); + let public_parameters = PublicParameters::rand(DORY_SETUP_MAX_NU, &mut rng); + let prover_setup = ProverSetup::from(&public_parameters); + let verifier_setup = VerifierSetup::from(&public_parameters); + + let filename = "./crates/proof-of-sql/examples/plastics/plastics.csv"; + let schema = get_posql_compatible_schema(&SchemaRef::new( + infer_schema_from_files(&[filename.to_string()], b',', None, true).unwrap(), + )); + let plastics_batch = ReaderBuilder::new(schema) + .with_header(true) + .build(File::open(filename).unwrap()) + .unwrap() + .next() + .unwrap() + .unwrap(); + + // Load the table into an "Accessor" so that the prover and verifier can access the data/commitments. 
+ let mut accessor = + OwnedTableTestAccessor::::new_empty_with_setup(&prover_setup); + accessor.add_table( + "plastics.types".parse().unwrap(), + OwnedTable::try_from(plastics_batch).unwrap(), + 0, + ); + + // Query 1: Count total number of plastic types + prove_and_verify_query( + "SELECT COUNT(*) AS total_types FROM types", + &accessor, + &prover_setup, + &verifier_setup, + ); + + // Query 2: List names of biodegradable plastics + prove_and_verify_query( + "SELECT Name FROM types WHERE Biodegradable = TRUE ORDER BY Name", + &accessor, + &prover_setup, + &verifier_setup, + ); + + // Query 3: Show average density of plastics by recycling code + prove_and_verify_query( + "SELECT Code, SUM(Density)/COUNT(*) as avg_density FROM types GROUP BY Code ORDER BY Code", + &accessor, + &prover_setup, + &verifier_setup, + ); + + // Query 4: List plastics with density greater than 1.0 g/cm³ + prove_and_verify_query( + "SELECT Name, Density FROM types WHERE Density > 1.0 ORDER BY Density DESC", + &accessor, + &prover_setup, + &verifier_setup, + ); +} diff --git a/crates/proof-of-sql/examples/plastics/plastics.csv b/crates/proof-of-sql/examples/plastics/plastics.csv new file mode 100644 index 000000000..9b793da0a --- /dev/null +++ b/crates/proof-of-sql/examples/plastics/plastics.csv @@ -0,0 +1,19 @@ +Name,Code,Density,Biodegradable +Polyethylene Terephthalate (PET),1,1.38,FALSE +High-Density Polyethylene (HDPE),2,0.97,FALSE +Polyvinyl Chloride (PVC),3,1.40,FALSE +Low-Density Polyethylene (LDPE),4,0.92,FALSE +Polypropylene (PP),5,0.90,FALSE +Polystyrene (PS),6,1.05,FALSE +Polylactic Acid (PLA),7,1.25,TRUE +Polybutylene Adipate Terephthalate (PBAT),7,1.26,TRUE +Polyhydroxyalkanoates (PHA),7,1.24,TRUE +Polybutylene Succinate (PBS),7,1.26,TRUE +Acrylic (PMMA),7,1.18,FALSE +Polycarbonate (PC),7,1.20,FALSE +Polyurethane (PU),7,1.05,FALSE +Acrylonitrile Butadiene Styrene (ABS),7,1.04,FALSE +Polyamide (Nylon),7,1.15,FALSE +Polyethylene Furanoate (PEF),7,1.43,TRUE +Thermoplastic Starch (TPS),7,1.35,TRUE +Cellulose Acetate,7,1.30,TRUE \ No newline at end of file diff --git a/crates/proof-of-sql/examples/posql_db/main.rs b/crates/proof-of-sql/examples/posql_db/main.rs index a796ed25e..f2facf2c8 100644 --- a/crates/proof-of-sql/examples/posql_db/main.rs +++ b/crates/proof-of-sql/examples/posql_db/main.rs @@ -5,6 +5,7 @@ mod commit_accessor; mod csv_accessor; /// TODO: add docs mod record_batch_accessor; + use arrow::{ datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, @@ -273,7 +274,7 @@ fn main() { end_timer(timer); println!( "Verified Result: {:?}", - RecordBatch::try_from(query_result).unwrap() + RecordBatch::try_from(query_result.table).unwrap() ); } } diff --git a/crates/proof-of-sql/examples/posql_db/record_batch_accessor.rs b/crates/proof-of-sql/examples/posql_db/record_batch_accessor.rs index 8af046972..08e25f4fe 100644 --- a/crates/proof-of-sql/examples/posql_db/record_batch_accessor.rs +++ b/crates/proof-of-sql/examples/posql_db/record_batch_accessor.rs @@ -2,9 +2,9 @@ use arrow::record_batch::RecordBatch; use bumpalo::Bump; use indexmap::IndexMap; use proof_of_sql::base::{ + arrow::arrow_array_to_column_conversion::ArrayRefExt, database::{ - ArrayRefExt, Column, ColumnRef, ColumnType, DataAccessor, MetadataAccessor, SchemaAccessor, - TableRef, + Column, ColumnRef, ColumnType, DataAccessor, MetadataAccessor, SchemaAccessor, TableRef, }, scalar::Scalar, }; diff --git a/crates/proof-of-sql/examples/programming_books/main.rs b/crates/proof-of-sql/examples/programming_books/main.rs new file mode 
100644 index 000000000..09af38488 --- /dev/null +++ b/crates/proof-of-sql/examples/programming_books/main.rs @@ -0,0 +1,133 @@ +//! This is a non-interactive example of using Proof of SQL with an extended books dataset. +//! To run this, use `cargo run --example programming_books`. +//! +//! NOTE: If this doesn't work because you do not have the appropriate GPU drivers installed, +//! you can run `cargo run --example programming_books --no-default-features --features="arrow cpu-perf"` instead. It will be slower for proof generation. + +use arrow::datatypes::SchemaRef; +use arrow_csv::{infer_schema_from_files, ReaderBuilder}; +use proof_of_sql::{ + base::database::{ + arrow_schema_utility::get_posql_compatible_schema, OwnedTable, OwnedTableTestAccessor, + TestAccessor, + }, + proof_primitive::dory::{ + DynamicDoryCommitment, DynamicDoryEvaluationProof, ProverSetup, PublicParameters, + VerifierSetup, + }, + sql::{parse::QueryExpr, postprocessing::apply_postprocessing_steps, proof::QueryProof}, +}; +use rand::{rngs::StdRng, SeedableRng}; +use std::{fs::File, time::Instant}; + +const DORY_SETUP_MAX_NU: usize = 8; +const DORY_SEED: [u8; 32] = *b"ebab60d58dee4cc69658939b7c2a582d"; + +/// # Panics +/// Will panic if the query does not parse or the proof fails to verify. +fn prove_and_verify_query( + sql: &str, + accessor: &OwnedTableTestAccessor, + prover_setup: &ProverSetup, + verifier_setup: &VerifierSetup, +) { + // Parse the query: + println!("Parsing the query: {sql}..."); + let now = Instant::now(); + let query_plan = QueryExpr::::try_new( + sql.parse().unwrap(), + "programming_books".parse().unwrap(), + accessor, + ) + .unwrap(); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Generate the proof and result: + print!("Generating proof..."); + let now = Instant::now(); + let (proof, provable_result) = QueryProof::::new( + query_plan.proof_expr(), + accessor, + &prover_setup, + ); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Verify the result with the proof: + print!("Verifying proof..."); + let now = Instant::now(); + let result = proof + .verify( + query_plan.proof_expr(), + accessor, + &provable_result, + &verifier_setup, + ) + .unwrap(); + let result = apply_postprocessing_steps(result.table, query_plan.postprocessing()); + println!("Verified in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Display the result + println!("Query Result:"); + println!("{result:?}"); +} + +fn main() { + let mut rng = StdRng::from_seed(DORY_SEED); + let public_parameters = PublicParameters::rand(DORY_SETUP_MAX_NU, &mut rng); + let prover_setup = ProverSetup::from(&public_parameters); + let verifier_setup = VerifierSetup::from(&public_parameters); + + let filename = "./crates/proof-of-sql/examples/programming_books/programming_books.csv"; + let inferred_schema = + SchemaRef::new(infer_schema_from_files(&[filename.to_string()], b',', None, true).unwrap()); + let posql_compatible_schema = get_posql_compatible_schema(&inferred_schema); + + let books_extra_batch = ReaderBuilder::new(posql_compatible_schema) + .with_header(true) + .build(File::open(filename).unwrap()) + .unwrap() + .next() + .unwrap() + .unwrap(); + + // Load the table into an "Accessor" so that the prover and verifier can access the data/commitments. 
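+ // `get_posql_compatible_schema` above coerces the inferred Arrow schema into types that
+ // Proof of SQL can commit to; without it, a column such as `rating` may be inferred as a
+ // type the prover does not support.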
+ let mut accessor = + OwnedTableTestAccessor::::new_empty_with_setup(&prover_setup); + accessor.add_table( + "programming_books.books".parse().unwrap(), + OwnedTable::try_from(books_extra_batch).unwrap(), + 0, + ); + + // Query 1: Count the total number of books + prove_and_verify_query( + "SELECT COUNT(*) AS total_books FROM books", + &accessor, + &prover_setup, + &verifier_setup, + ); + + // Query 2: Find books with a rating higher than 4.5 + prove_and_verify_query( + "SELECT title, author FROM books WHERE rating > 4.5", + &accessor, + &prover_setup, + &verifier_setup, + ); + + // Query 3: List all programming books published after 2000 + prove_and_verify_query( + "SELECT title, publication_year FROM books WHERE genre = 'Programming' AND publication_year > 2000", + &accessor, + &prover_setup, + &verifier_setup, + ); + + // Query 4: Find the top 5 authors with the most books + prove_and_verify_query( + "SELECT author, COUNT(*) AS book_count FROM books GROUP BY author ORDER BY book_count DESC LIMIT 5", + &accessor, + &prover_setup, + &verifier_setup, + ); +} diff --git a/crates/proof-of-sql/examples/programming_books/programming_books.csv b/crates/proof-of-sql/examples/programming_books/programming_books.csv new file mode 100644 index 000000000..dbad4ba3b --- /dev/null +++ b/crates/proof-of-sql/examples/programming_books/programming_books.csv @@ -0,0 +1,11 @@ +title,author,publication_year,genre,rating +The Pragmatic Programmer,Andrew Hunt,1999,Programming,4.5 +Clean Code,Robert C. Martin,2008,Programming,4.7 +The Clean Coder,Robert C. Martin,2011,Programming,4.6 +Design Patterns,Erich Gamma,1994,Software Engineering,4.8 +Refactoring,Martin Fowler,1999,Programming,4.5 +Effective Java,Joshua Bloch,2008,Programming,4.7 +Introduction to Algorithms,Thomas H. Cormen,2009,Computer Science,4.8 +Code Complete,Steve McConnell,2004,Programming,4.6 +The Mythical Man-Month,Fred Brooks,1975,Software Engineering,4.3 +Algorithms,Robert Sedgewick,1983,Computer Science,4.5 diff --git a/crates/proof-of-sql/examples/rockets/launch_vehicles.csv b/crates/proof-of-sql/examples/rockets/launch_vehicles.csv new file mode 100644 index 000000000..cba1aeb2f --- /dev/null +++ b/crates/proof-of-sql/examples/rockets/launch_vehicles.csv @@ -0,0 +1,28 @@ +name,country,year,mtow +Saturn V,USA,1967,2976000 +Falcon Heavy,USA,2018,1420788 +Space Shuttle,USA,1981,2041167 +Energia,USSR,1987,2400000 +Ariane 5,Europe,1996,780000 +Delta IV Heavy,USA,2004,733400 +Long March 5,China,2016,869000 +Proton,USSR/Russia,1965,705000 +Atlas V,USA,2002,546700 +H-IIA,Japan,2001,445000 +Soyuz,USSR/Russia,1966,308000 +Falcon 9,USA,2010,549054 +Vega,Europe,2012,137000 +PSLV,India,1993,320000 +GSLV Mk III,India,2017,640000 +Titan II,USA,1962,153800 +Angara A5,Russia,2014,1335000 +Delta II,USA,1989,231870 +Electron,New Zealand,2017,12500 +Antares,USA,2013,240000 +Zenit,USSR/Ukraine,1985,462000 +N1,USSR,1969,2735000 +New Glenn,USA,2024,1300000 +Redstone,USA,1953,29500 +Black Arrow,UK,1971,18800 +Diamant,France,1965,18000 +Pegasus,USA,1990,23300 diff --git a/crates/proof-of-sql/examples/rockets/main.rs b/crates/proof-of-sql/examples/rockets/main.rs new file mode 100644 index 000000000..79ad4c4a4 --- /dev/null +++ b/crates/proof-of-sql/examples/rockets/main.rs @@ -0,0 +1,132 @@ +//! This is a non-interactive example of using Proof of SQL with a rockets dataset. +//! To run this, use `cargo run --release --example rockets`. +//! +//! NOTE: If this doesn't work because you do not have the appropriate GPU drivers installed, +//! 
you can run `cargo run --release --example rockets --no-default-features --features="arrow cpu-perf"` instead. It will be slower for proof generation. + +use arrow::datatypes::SchemaRef; +use arrow_csv::{infer_schema_from_files, ReaderBuilder}; +use proof_of_sql::{ + base::database::{ + arrow_schema_utility::get_posql_compatible_schema, OwnedTable, OwnedTableTestAccessor, + TestAccessor, + }, + proof_primitive::dory::{ + DynamicDoryCommitment, DynamicDoryEvaluationProof, ProverSetup, PublicParameters, + VerifierSetup, + }, + sql::{parse::QueryExpr, postprocessing::apply_postprocessing_steps, proof::QueryProof}, +}; +use rand::{rngs::StdRng, SeedableRng}; +use std::{fs::File, time::Instant}; + +// We generate the public parameters and the setups used by the prover and verifier for the Dory PCS. +// The `max_nu` should be set such that the maximum table size is less than `2^(2*max_nu-1)`. +const DORY_SETUP_MAX_NU: usize = 8; +// This should be a "nothing-up-my-sleeve" phrase or number. +const DORY_SEED: [u8; 32] = *b"7a1b3c8d2e4f9g6h5i0j7k2l8m3n9o1p"; + +/// # Panics +/// Will panic if the query does not parse or the proof fails to verify. +fn prove_and_verify_query( + sql: &str, + accessor: &OwnedTableTestAccessor, + prover_setup: &ProverSetup, + verifier_setup: &VerifierSetup, +) { + // Parse the query: + println!("Parsing the query: {sql}..."); + let now = Instant::now(); + let query_plan = QueryExpr::::try_new( + sql.parse().unwrap(), + "rockets".parse().unwrap(), + accessor, + ) + .unwrap(); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Generate the proof and result: + print!("Generating proof..."); + let now = Instant::now(); + let (proof, provable_result) = QueryProof::::new( + query_plan.proof_expr(), + accessor, + &prover_setup, + ); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Verify the result with the proof: + print!("Verifying proof..."); + let now = Instant::now(); + let result = proof + .verify( + query_plan.proof_expr(), + accessor, + &provable_result, + &verifier_setup, + ) + .unwrap(); + let result = apply_postprocessing_steps(result.table, query_plan.postprocessing()); + println!("Verified in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Display the result + println!("Query Result:"); + println!("{result:?}"); +} + +fn main() { + let mut rng = StdRng::from_seed(DORY_SEED); + let public_parameters = PublicParameters::rand(DORY_SETUP_MAX_NU, &mut rng); + let prover_setup = ProverSetup::from(&public_parameters); + let verifier_setup = VerifierSetup::from(&public_parameters); + + let filename = "./crates/proof-of-sql/examples/rockets/launch_vehicles.csv"; + let inferred_schema = + SchemaRef::new(infer_schema_from_files(&[filename.to_string()], b',', None, true).unwrap()); + let posql_compatible_schema = get_posql_compatible_schema(&inferred_schema); + + let rockets_batch = ReaderBuilder::new(posql_compatible_schema) + .with_header(true) + .build(File::open(filename).unwrap()) + .unwrap() + .next() + .unwrap() + .unwrap(); + + // Load the table into an "Accessor" so that the prover and verifier can access the data/commitments. 
+ let mut accessor = + OwnedTableTestAccessor::::new_empty_with_setup(&prover_setup); + accessor.add_table( + "rockets.launch_vehicles".parse().unwrap(), + OwnedTable::try_from(rockets_batch).unwrap(), + 0, + ); + + prove_and_verify_query( + "SELECT COUNT(*) AS total_rockets FROM launch_vehicles", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT country, MAX(mtow) as max_mtow, COUNT(*) as rocket_count FROM launch_vehicles GROUP BY country ORDER BY max_mtow DESC", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT name FROM launch_vehicles WHERE country = 'USA'", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT name FROM launch_vehicles WHERE mtow > 100000 and mtow < 150000", + &accessor, + &prover_setup, + &verifier_setup, + ); +} diff --git a/crates/proof-of-sql/examples/stocks/main.rs b/crates/proof-of-sql/examples/stocks/main.rs new file mode 100644 index 000000000..6a3da79aa --- /dev/null +++ b/crates/proof-of-sql/examples/stocks/main.rs @@ -0,0 +1,144 @@ +//! This is a non-interactive example of using Proof of SQL with a stocks dataset. +//! To run this, use cargo run --release --example stocks. +//! +//! NOTE: If this doesn't work because you do not have the appropriate GPU drivers installed, +//! you can run cargo run --release --example stocks --no-default-features --features="arrow cpu-perf" instead. It will be slower for proof generation. + +use arrow::datatypes::SchemaRef; +use arrow_csv::{infer_schema_from_files, ReaderBuilder}; +use proof_of_sql::{ + base::database::{ + arrow_schema_utility::get_posql_compatible_schema, OwnedTable, OwnedTableTestAccessor, + TestAccessor, + }, + proof_primitive::dory::{ + DynamicDoryCommitment, DynamicDoryEvaluationProof, ProverSetup, PublicParameters, + VerifierSetup, + }, + sql::{parse::QueryExpr, postprocessing::apply_postprocessing_steps, proof::QueryProof}, +}; +use rand::{rngs::StdRng, SeedableRng}; +use std::{fs::File, time::Instant}; + +// We generate the public parameters and the setups used by the prover and verifier for the Dory PCS. +// The max_nu should be set such that the maximum table size is less than 2^(2*max_nu-1). +const DORY_SETUP_MAX_NU: usize = 8; +// This should be a "nothing-up-my-sleeve" phrase or number. +const DORY_SEED: [u8; 32] = *b"f9d2e8c1b7a654309cfe81d2b7a3c940"; + +/// # Panics +/// Will panic if the query does not parse or the proof fails to verify. 
+fn prove_and_verify_query( + sql: &str, + accessor: &OwnedTableTestAccessor, + prover_setup: &ProverSetup, + verifier_setup: &VerifierSetup, +) { + // Parse the query: + println!("Parsing the query: {sql}..."); + let now = Instant::now(); + let query_plan = QueryExpr::::try_new( + sql.parse().unwrap(), + "stocks".parse().unwrap(), + accessor, + ) + .unwrap(); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Generate the proof and result: + print!("Generating proof..."); + let now = Instant::now(); + let (proof, provable_result) = QueryProof::::new( + query_plan.proof_expr(), + accessor, + &prover_setup, + ); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Verify the result with the proof: + print!("Verifying proof..."); + let now = Instant::now(); + let result = proof + .verify( + query_plan.proof_expr(), + accessor, + &provable_result, + &verifier_setup, + ) + .unwrap(); + let result = apply_postprocessing_steps(result.table, query_plan.postprocessing()); + println!("Verified in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Display the result + println!("Query Result:"); + println!("{result:?}"); +} + +fn main() { + let mut rng = StdRng::from_seed(DORY_SEED); + let public_parameters = PublicParameters::rand(DORY_SETUP_MAX_NU, &mut rng); + let prover_setup = ProverSetup::from(&public_parameters); + let verifier_setup = VerifierSetup::from(&public_parameters); + + let filename = "./crates/proof-of-sql/examples/stocks/stocks.csv"; + let schema = get_posql_compatible_schema(&SchemaRef::new( + infer_schema_from_files(&[filename.to_string()], b',', None, true).unwrap(), + )); + let stocks_batch = ReaderBuilder::new(schema) + .with_header(true) + .build(File::open(filename).unwrap()) + .unwrap() + .next() + .unwrap() + .unwrap(); + + // Load the table into an "Accessor" so that the prover and verifier can access the data/commitments. 
+ let mut accessor = + OwnedTableTestAccessor::::new_empty_with_setup(&prover_setup); + accessor.add_table( + "stocks.stocks".parse().unwrap(), + OwnedTable::try_from(stocks_batch).unwrap(), + 0, + ); + + // Query 1: Calculate total market cap and count of stocks + prove_and_verify_query( + "SELECT SUM(MarketCap) as total_market_cap, COUNT(*) as c FROM stocks", + &accessor, + &prover_setup, + &verifier_setup, + ); + + // Query 2: Find technology stocks with PE ratio under 30 and dividend yield > 0 + prove_and_verify_query( + "SELECT Symbol, Company, PE_Ratio, DividendYield + FROM stocks + WHERE Sector = 'Technology' AND PE_Ratio < 30 AND DividendYield > 0 + ORDER BY PE_Ratio DESC", + &accessor, + &prover_setup, + &verifier_setup, + ); + + // Query 3: Average market cap by sector (using SUM/COUNT instead of AVG) + prove_and_verify_query( + "SELECT Sector, SUM(MarketCap)/COUNT(*) as avg_market_cap, COUNT(*) as c + FROM stocks + GROUP BY Sector + ORDER BY avg_market_cap DESC", + &accessor, + &prover_setup, + &verifier_setup, + ); + + // Query 4: High value stocks with significant volume and dividend yield + prove_and_verify_query( + "SELECT Symbol, Company, Price, Volume, DividendYield + FROM stocks + WHERE Volume > 20000000 AND DividendYield > 0 AND Price > 100 + ORDER BY Volume DESC", + &accessor, + &prover_setup, + &verifier_setup, + ); +} diff --git a/crates/proof-of-sql/examples/stocks/stocks.csv b/crates/proof-of-sql/examples/stocks/stocks.csv new file mode 100644 index 000000000..1e00ae80d --- /dev/null +++ b/crates/proof-of-sql/examples/stocks/stocks.csv @@ -0,0 +1,19 @@ +Symbol,Company,Sector,Price,Volume,MarketCap,PE_Ratio,DividendYield +AAPL,Apple Inc.,Technology,175.50,52000000,2850.25,28.5,0.5 +MSFT,Microsoft Corporation,Technology,325.75,25000000,2425.80,32.8,0.8 +GOOGL,Alphabet Inc.,Technology,135.20,18000000,1720.40,25.2,0.0 +AMZN,Amazon.com Inc.,Consumer Cyclical,128.90,35000000,1325.60,42.1,0.0 +META,Meta Platforms Inc.,Technology,308.45,22000000,785.30,31.5,0.0 +TSLA,Tesla Inc.,Automotive,238.45,100000000,755.90,75.2,0.0 +JPM,JPMorgan Chase & Co.,Financial Services,148.75,12000000,428.90,11.2,2.8 +V,Visa Inc.,Financial Services,245.60,8000000,510.30,28.9,0.7 +JNJ,Johnson & Johnson,Healthcare,152.30,6000000,395.80,15.5,3.1 +PG,Procter & Gamble Co.,Consumer Defensive,150.20,7000000,355.40,25.3,2.4 +XOM,Exxon Mobil Corporation,Energy,105.80,20000000,425.60,8.9,3.5 +WMT,Walmart Inc.,Consumer Defensive,158.90,8500000,428.70,26.4,1.5 +KO,Coca-Cola Company,Consumer Defensive,58.75,12000000,254.30,24.2,3.0 +DIS,Walt Disney Company,Communication Services,85.50,15000000,156.80,42.8,0.0 +NFLX,Netflix Inc.,Communication Services,385.20,7500000,171.20,38.5,0.0 +NVDA,NVIDIA Corporation,Technology,455.80,45000000,1125.40,110.5,0.1 +INTC,Intel Corporation,Technology,35.80,42000000,150.20,15.8,1.5 +AMD,Advanced Micro Devices,Technology,105.25,65000000,170.30,220.5,0.0 \ No newline at end of file diff --git a/crates/proof-of-sql/examples/sushi/fish.csv b/crates/proof-of-sql/examples/sushi/fish.csv new file mode 100644 index 000000000..e0a14ebc0 --- /dev/null +++ b/crates/proof-of-sql/examples/sushi/fish.csv @@ -0,0 +1,13 @@ +nameEn,nameJa,kindEn,kindJa,pricePerPound +Tuna,Maguro,Lean Red Meat,Akami,25 +Tuna,Maguro,Medium Fat Red Meat,Toro,65 +Tuna,Maguro,Fatty Red Meat,Otoro,115 +Bonito,Katsuo,Red Meat,Akami,20 +Yellowtail,Hamachi,Red Meat,Akami,27 +Salmon,Salmon,White Fish,Shiromi,17 +Sea Bream,Tai,White Fish,Shiromi,32 +Sea Bass,Suzuki,White Fish,Shiromi,28 +Mackerel,Aji,Silver 
Skinned,Hikarimono,14 +Sardine,Iwashi,Silver Skinned,Hikarimono,11 +Scallops,Hotate,Shellfish,Kai,26 +Ark-shell clams,Akagai,Shellfish,Kai,29 diff --git a/crates/proof-of-sql/examples/sushi/main.rs b/crates/proof-of-sql/examples/sushi/main.rs new file mode 100644 index 000000000..0c7f89545 --- /dev/null +++ b/crates/proof-of-sql/examples/sushi/main.rs @@ -0,0 +1,141 @@ +//! This is a non-interactive example of using Proof of SQL with some sushi-related datasets. +//! To run this, use `cargo run --example sushi`. + +//! NOTE: If this doesn't work because you do not have the appropriate GPU drivers installed, +//! you can run `cargo run --release --example sushi --no-default-features --features="arrow cpu-perf"` instead. It will be slower for proof generation. +use arrow::datatypes::SchemaRef; +use arrow_csv::{infer_schema_from_files, ReaderBuilder}; +use proof_of_sql::{ + base::database::{OwnedTable, OwnedTableTestAccessor, TestAccessor}, + proof_primitive::dory::{ + DynamicDoryCommitment, DynamicDoryEvaluationProof, ProverSetup, PublicParameters, + VerifierSetup, + }, + sql::{parse::QueryExpr, proof::QueryProof}, +}; +use rand::{rngs::StdRng, SeedableRng}; +use std::{fs::File, time::Instant}; + +const DORY_SETUP_MAX_NU: usize = 8; +const DORY_SEED: [u8; 32] = *b"sushi-is-the-best-food-available"; + +/// # Panics +/// Will panic if the query does not parse or the proof fails to verify. +fn prove_and_verify_query( + sql: &str, + accessor: &OwnedTableTestAccessor<DynamicDoryEvaluationProof>, + prover_setup: &ProverSetup, + verifier_setup: &VerifierSetup, +) { + // Parse the query: + println!("Parsing the query: {sql}..."); + let now = Instant::now(); + let query_plan = QueryExpr::<DynamicDoryCommitment>::try_new( + sql.parse().unwrap(), + "sushi".parse().unwrap(), + accessor, + ) + .unwrap(); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + // Generate the proof and result: + print!("Generating proof..."); + let now = Instant::now(); + let (proof, provable_result) = QueryProof::<DynamicDoryEvaluationProof>::new( + query_plan.proof_expr(), + accessor, + &prover_setup, + ); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + // Verify the result with the proof: + print!("Verifying proof..."); + let now = Instant::now(); + let result = proof + .verify( + query_plan.proof_expr(), + accessor, + &provable_result, + &verifier_setup, + ) + .unwrap(); + println!("Verified in {} ms.", now.elapsed().as_secs_f64() * 1000.); + // Display the result + println!("Query Result:"); + println!("{:?}", result.table); +} + +fn main() { + let mut rng = StdRng::from_seed(DORY_SEED); + let public_parameters = PublicParameters::rand(DORY_SETUP_MAX_NU, &mut rng); + let prover_setup = ProverSetup::from(&public_parameters); + let verifier_setup = VerifierSetup::from(&public_parameters); + + let filename = "./crates/proof-of-sql/examples/sushi/fish.csv"; + let fish_batch = ReaderBuilder::new(SchemaRef::new( + infer_schema_from_files(&[filename.to_string()], b',', None, true).unwrap(), + )) + .with_header(true) + .build(File::open(filename).unwrap()) + .unwrap() + .next() + .unwrap() + .unwrap(); + println!("{fish_batch:?}"); + + // Load the table into an "Accessor" so that the prover and verifier can access the data/commitments.
+ let mut accessor = + OwnedTableTestAccessor::::new_empty_with_setup(&prover_setup); + accessor.add_table( + "sushi.fish".parse().unwrap(), + OwnedTable::try_from(fish_batch).unwrap(), + 0, + ); + + prove_and_verify_query( + "SELECT * FROM fish", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT COUNT(*) FROM fish WHERE nameEn = 'Tuna'", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT kindEn FROM fish WHERE kindJa = 'Otoro'", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT kindEn FROM fish WHERE kindJa = 'Otoro'", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT * FROM fish WHERE pricePerPound > 25 AND pricePerPound < 75", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT kindJa, COUNT(*) FROM fish GROUP BY kindJa", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT kindJa, pricePerPound FROM fish WHERE nameEn = 'Tuna' ORDER BY pricePerPound ASC", + &accessor, + &prover_setup, + &verifier_setup, + ); +} diff --git a/crates/proof-of-sql/examples/tech_gadget_prices/main.rs b/crates/proof-of-sql/examples/tech_gadget_prices/main.rs new file mode 100644 index 000000000..2ae44ee29 --- /dev/null +++ b/crates/proof-of-sql/examples/tech_gadget_prices/main.rs @@ -0,0 +1,108 @@ +//! This is a non-interactive example of using Proof of SQL with a `tech_gadget_prices` dataset. +//! To run this, use cargo run --release --example `tech_gadget_prices`. +//! +//! NOTE: If this doesn't work because you do not have the appropriate GPU drivers installed, +//! you can run cargo run --release --example `tech_gadget_prices` --no-default-features --features="arrow cpu-perf" instead. It will be slower for proof generation. 
+ +use arrow::datatypes::SchemaRef; +use arrow_csv::{infer_schema_from_files, ReaderBuilder}; +use proof_of_sql::{ + base::database::{OwnedTable, OwnedTableTestAccessor}, + proof_primitive::dory::{ + DynamicDoryCommitment, DynamicDoryEvaluationProof, ProverSetup, PublicParameters, + VerifierSetup, + }, + sql::{parse::QueryExpr, proof::QueryProof}, +}; +use rand::{rngs::StdRng, SeedableRng}; +use std::{error::Error, fs::File, time::Instant}; + +const DORY_SETUP_MAX_NU: usize = 8; +const DORY_SEED: [u8; 32] = *b"tech-gadget-prices-dataset-seed!"; + +fn prove_and_verify_query( + sql: &str, + accessor: &OwnedTableTestAccessor<DynamicDoryEvaluationProof>, + prover_setup: &ProverSetup, + verifier_setup: &VerifierSetup, +) -> Result<(), Box<dyn Error>> { + println!("Parsing the query: {sql}..."); + let now = Instant::now(); + let query_plan = QueryExpr::<DynamicDoryCommitment>::try_new( + sql.parse()?, + "tech_gadget_prices".parse()?, + accessor, + )?; + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + print!("Generating proof..."); + let now = Instant::now(); + let (proof, provable_result) = QueryProof::<DynamicDoryEvaluationProof>::new( + query_plan.proof_expr(), + accessor, + &prover_setup, + ); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + print!("Verifying proof..."); + let now = Instant::now(); + let result = proof.verify( + query_plan.proof_expr(), + accessor, + &provable_result, + &verifier_setup, + )?; + println!("Verified in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + println!("Query Result:"); + println!("{:?}", result.table); + Ok(()) +} + +fn main() -> Result<(), Box<dyn Error>> { + let mut rng = StdRng::from_seed(DORY_SEED); + let public_parameters = PublicParameters::rand(DORY_SETUP_MAX_NU, &mut rng); + let prover_setup = ProverSetup::from(&public_parameters); + let verifier_setup = VerifierSetup::from(&public_parameters); + + let filename = "./crates/proof-of-sql/examples/tech_gadget_prices/tech_gadget_prices.csv"; + let schema = infer_schema_from_files(&[filename.to_string()], b',', None, true)?; + let data_batch = ReaderBuilder::new(SchemaRef::new(schema)) + .with_header(true) + .build(File::open(filename)?)?
+ .next() + .ok_or("No data found in CSV file")??; + + let accessor = OwnedTableTestAccessor::::new_from_table( + "tech_gadget_prices.prices".parse()?, + OwnedTable::try_from(data_batch)?, + 0, + &prover_setup, + ); + + prove_and_verify_query( + "SELECT COUNT(*) AS total FROM prices", + &accessor, + &prover_setup, + &verifier_setup, + )?; + prove_and_verify_query( + "SELECT Brand, COUNT(*) AS total FROM prices GROUP BY Brand ORDER BY total", + &accessor, + &prover_setup, + &verifier_setup, + )?; + prove_and_verify_query( + "SELECT Name, Price FROM prices WHERE Category = 'Smartphone' ORDER BY Price DESC LIMIT 3", + &accessor, + &prover_setup, + &verifier_setup, + )?; + prove_and_verify_query( + "SELECT Name, ReleaseYear FROM prices WHERE Price > 500 ORDER BY ReleaseYear DESC", + &accessor, + &prover_setup, + &verifier_setup, + )?; + Ok(()) +} diff --git a/crates/proof-of-sql/examples/tech_gadget_prices/tech_gadget_prices.csv b/crates/proof-of-sql/examples/tech_gadget_prices/tech_gadget_prices.csv new file mode 100644 index 000000000..e03e8e90d --- /dev/null +++ b/crates/proof-of-sql/examples/tech_gadget_prices/tech_gadget_prices.csv @@ -0,0 +1,9 @@ +Name,Brand,Category,ReleaseYear,Price +iPhone 13,Apple,Smartphone,2021,799 +Galaxy S21,Samsung,Smartphone,2021,799 +PlayStation 5,Sony,Game Console,2020,499 +Xbox Series X,Microsoft,Game Console,2020,499 +iPad Pro,Apple,Tablet,2021,799 +Surface Pro 7,Microsoft,Tablet,2019,749 +MacBook Air,Apple,Laptop,2020,999 +Pixel 5,Google,Smartphone,2020,699 \ No newline at end of file diff --git a/crates/proof-of-sql/examples/vehicles/main.rs b/crates/proof-of-sql/examples/vehicles/main.rs new file mode 100644 index 000000000..dede0042e --- /dev/null +++ b/crates/proof-of-sql/examples/vehicles/main.rs @@ -0,0 +1,132 @@ +//! This is a non-interactive example of using Proof of SQL with a vehicles dataset. +//! To run this, use `cargo run --release --example vehicles`. +//! +//! NOTE: If this doesn't work because you do not have the appropriate GPU drivers installed, +//! you can run `cargo run --release --example vehicles --no-default-features --features="arrow cpu-perf"` instead. It will be slower for proof generation. + +use arrow::datatypes::SchemaRef; +use arrow_csv::{infer_schema_from_files, ReaderBuilder}; +use proof_of_sql::{ + base::database::{ + arrow_schema_utility::get_posql_compatible_schema, OwnedTable, OwnedTableTestAccessor, + TestAccessor, + }, + proof_primitive::dory::{ + DynamicDoryCommitment, DynamicDoryEvaluationProof, ProverSetup, PublicParameters, + VerifierSetup, + }, + sql::{parse::QueryExpr, postprocessing::apply_postprocessing_steps, proof::QueryProof}, +}; +use rand::{rngs::StdRng, SeedableRng}; +use std::{fs::File, time::Instant}; + +// We generate the public parameters and the setups used by the prover and verifier for the Dory PCS. +// The `max_nu` should be set such that the maximum table size is less than `2^(2*max_nu-1)`. +const DORY_SETUP_MAX_NU: usize = 8; +// This should be a "nothing-up-my-sleeve" phrase or number. +const DORY_SEED: [u8; 32] = *b"97b13c8che4f9g4050jjkk2l5m3nbo1p"; + +/// # Panics +/// Will panic if the query does not parse or the proof fails to verify. 
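A quick sanity check on the `max_nu` sizing rule quoted in the comment above: with `DORY_SETUP_MAX_NU = 8`, any table with fewer than 2^(2*8 - 1) = 2^15 = 32768 rows fits, so the 10-row vehicles.csv used by this example has ample headroom. A minimal sketch of that arithmetic (illustrative only, not part of the patch):

    // Illustrative check of the bound `rows < 2^(2 * max_nu - 1)` stated above.
    const DORY_SETUP_MAX_NU: usize = 8;
    const ROW_BOUND: usize = 1 << (2 * DORY_SETUP_MAX_NU - 1); // 2^15 = 32768

    fn main() {
        // vehicles.csv in this example holds 10 data rows, well under the bound.
        assert!(10 < ROW_BOUND);
        println!("this setup supports tables with fewer than {} rows", ROW_BOUND);
    }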
+fn prove_and_verify_query( + sql: &str, + accessor: &OwnedTableTestAccessor, + prover_setup: &ProverSetup, + verifier_setup: &VerifierSetup, +) { + // Parse the query: + println!("Parsing the query: {sql}..."); + let now = Instant::now(); + let query_plan = QueryExpr::::try_new( + sql.parse().unwrap(), + "vehicles".parse().unwrap(), + accessor, + ) + .unwrap(); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Generate the proof and result: + print!("Generating proof..."); + let now = Instant::now(); + let (proof, provable_result) = QueryProof::::new( + query_plan.proof_expr(), + accessor, + &prover_setup, + ); + println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Verify the result with the proof: + print!("Verifying proof..."); + let now = Instant::now(); + let result = proof + .verify( + query_plan.proof_expr(), + accessor, + &provable_result, + &verifier_setup, + ) + .unwrap(); + let result = apply_postprocessing_steps(result.table, query_plan.postprocessing()); + println!("Verified in {} ms.", now.elapsed().as_secs_f64() * 1000.); + + // Display the result + println!("Query Result:"); + println!("{result:?}"); +} + +fn main() { + let mut rng = StdRng::from_seed(DORY_SEED); + let public_parameters = PublicParameters::rand(DORY_SETUP_MAX_NU, &mut rng); + let prover_setup = ProverSetup::from(&public_parameters); + let verifier_setup = VerifierSetup::from(&public_parameters); + + let filename = "./crates/proof-of-sql/examples/vehicles/vehicles.csv"; + let inferred_schema = + SchemaRef::new(infer_schema_from_files(&[filename.to_string()], b',', None, true).unwrap()); + let posql_compatible_schema = get_posql_compatible_schema(&inferred_schema); + + let vehicles_batch = ReaderBuilder::new(posql_compatible_schema) + .with_header(true) + .build(File::open(filename).unwrap()) + .unwrap() + .next() + .unwrap() + .unwrap(); + + // Load the table into an "Accessor" so that the prover and verifier can access the data/commitments. 
+ let mut accessor = + OwnedTableTestAccessor::::new_empty_with_setup(&prover_setup); + accessor.add_table( + "vehicles.vehicles".parse().unwrap(), + OwnedTable::try_from(vehicles_batch).unwrap(), + 0, + ); + + prove_and_verify_query( + "SELECT COUNT(*) AS total_vehicles FROM vehicles", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT model FROM vehicles WHERE make = 'Ford'", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT make,model FROM vehicles WHERE price > 30000 AND price < 50000", + &accessor, + &prover_setup, + &verifier_setup, + ); + + prove_and_verify_query( + "SELECT MAX(price) FROM vehicles", + &accessor, + &prover_setup, + &verifier_setup, + ); +} diff --git a/crates/proof-of-sql/examples/vehicles/vehicles.csv b/crates/proof-of-sql/examples/vehicles/vehicles.csv new file mode 100644 index 000000000..16b8a8bc9 --- /dev/null +++ b/crates/proof-of-sql/examples/vehicles/vehicles.csv @@ -0,0 +1,11 @@ +id,make,model,year,price +1,Tesla,Model S,2020,79999 +2,Ford,Mustang,2019,55999 +3,Chevrolet,Camaro,2018,42999 +4,BMW,3 Series,2021,41300 +5,Audi,A4,2021,39900 +6,Ford,Maverick,2024,27990 +7,Hyundai,Santa Cruz,2024,29895 +8,Toyota,Tacoma,2024,32995 +9,Ram,1500 TRX,2024,98335 +10,Ford,F-150,2025,39345 diff --git a/crates/proof-of-sql/src/base/database/arrow_array_to_column_conversion.rs b/crates/proof-of-sql/src/base/arrow/arrow_array_to_column_conversion.rs similarity index 100% rename from crates/proof-of-sql/src/base/database/arrow_array_to_column_conversion.rs rename to crates/proof-of-sql/src/base/arrow/arrow_array_to_column_conversion.rs diff --git a/crates/proof-of-sql/src/base/arrow/column_arrow_conversions.rs b/crates/proof-of-sql/src/base/arrow/column_arrow_conversions.rs new file mode 100644 index 000000000..5eade6cf3 --- /dev/null +++ b/crates/proof-of-sql/src/base/arrow/column_arrow_conversions.rs @@ -0,0 +1,79 @@ +use crate::base::{ + database::{ColumnField, ColumnType}, + math::decimal::Precision, +}; +use alloc::sync::Arc; +use arrow::datatypes::{DataType, Field, TimeUnit as ArrowTimeUnit}; +use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; + +/// Convert [`ColumnType`] values to some arrow [`DataType`] +impl From<&ColumnType> for DataType { + fn from(column_type: &ColumnType) -> Self { + match column_type { + ColumnType::Boolean => DataType::Boolean, + ColumnType::TinyInt => DataType::Int8, + ColumnType::SmallInt => DataType::Int16, + ColumnType::Int => DataType::Int32, + ColumnType::BigInt => DataType::Int64, + ColumnType::Int128 => DataType::Decimal128(38, 0), + ColumnType::Decimal75(precision, scale) => { + DataType::Decimal256(precision.value(), *scale) + } + ColumnType::VarChar => DataType::Utf8, + ColumnType::Scalar => unimplemented!("Cannot convert Scalar type to arrow type"), + ColumnType::TimestampTZ(timeunit, timezone) => { + let arrow_timezone = Some(Arc::from(timezone.to_string())); + let arrow_timeunit = match timeunit { + PoSQLTimeUnit::Second => ArrowTimeUnit::Second, + PoSQLTimeUnit::Millisecond => ArrowTimeUnit::Millisecond, + PoSQLTimeUnit::Microsecond => ArrowTimeUnit::Microsecond, + PoSQLTimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond, + }; + DataType::Timestamp(arrow_timeunit, arrow_timezone) + } + } + } +} + +/// Convert arrow [`DataType`] values to some [`ColumnType`] +impl TryFrom for ColumnType { + type Error = String; + + fn try_from(data_type: DataType) -> Result { + match data_type { + DataType::Boolean => Ok(ColumnType::Boolean), + 
DataType::Int8 => Ok(ColumnType::TinyInt), + DataType::Int16 => Ok(ColumnType::SmallInt), + DataType::Int32 => Ok(ColumnType::Int), + DataType::Int64 => Ok(ColumnType::BigInt), + DataType::Decimal128(38, 0) => Ok(ColumnType::Int128), + DataType::Decimal256(precision, scale) if precision <= 75 => { + Ok(ColumnType::Decimal75(Precision::new(precision)?, scale)) + } + DataType::Timestamp(time_unit, timezone_option) => { + let posql_time_unit = match time_unit { + ArrowTimeUnit::Second => PoSQLTimeUnit::Second, + ArrowTimeUnit::Millisecond => PoSQLTimeUnit::Millisecond, + ArrowTimeUnit::Microsecond => PoSQLTimeUnit::Microsecond, + ArrowTimeUnit::Nanosecond => PoSQLTimeUnit::Nanosecond, + }; + Ok(ColumnType::TimestampTZ( + posql_time_unit, + PoSQLTimeZone::try_from(&timezone_option)?, + )) + } + DataType::Utf8 => Ok(ColumnType::VarChar), + _ => Err(format!("Unsupported arrow data type {data_type:?}")), + } + } +} +/// Convert [`ColumnField`] values to arrow Field +impl From<&ColumnField> for Field { + fn from(column_field: &ColumnField) -> Self { + Field::new( + column_field.name().name(), + (&column_field.data_type()).into(), + false, + ) + } +} diff --git a/crates/proof-of-sql/src/base/arrow/mod.rs b/crates/proof-of-sql/src/base/arrow/mod.rs new file mode 100644 index 000000000..0bcac183d --- /dev/null +++ b/crates/proof-of-sql/src/base/arrow/mod.rs @@ -0,0 +1,26 @@ +//! This module provides conversions and utilities for working with Arrow data structures. + +/// Module for handling conversion from Arrow arrays to columns. +pub mod arrow_array_to_column_conversion; + +/// Module for converting between owned and Arrow data structures. +pub mod owned_and_arrow_conversions; + +#[cfg(test)] +/// Tests for owned and Arrow conversions. +mod owned_and_arrow_conversions_test; + +/// Module for converting record batches. +pub mod record_batch_conversion; + +/// Module for record batch error definitions. +pub mod record_batch_errors; + +/// Utility functions for record batches. +pub mod record_batch_utility; + +/// Module for scalar and i256 conversions. +pub mod scalar_and_i256_conversions; + +/// Module for handling conversions between columns and Arrow arrays. +pub mod column_arrow_conversions; diff --git a/crates/proof-of-sql/src/base/database/owned_and_arrow_conversions.rs b/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions.rs similarity index 98% rename from crates/proof-of-sql/src/base/database/owned_and_arrow_conversions.rs rename to crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions.rs index adf4f94af..74ad96839 100644 --- a/crates/proof-of-sql/src/base/database/owned_and_arrow_conversions.rs +++ b/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions.rs @@ -12,12 +12,9 @@ //! This is because there is no `Int128` type in Arrow. //! This does not check that the values are less than 39 digits. //! However, the actual arrow backing `i128` is the correct value. 
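The `ColumnType` and arrow `DataType` conversions collected in this new module are designed to round-trip for the supported primitive types. A minimal sketch of that round trip, assuming a dependency on `proof-of-sql` with the `arrow` feature enabled (illustrative only, not part of the patch):

    use arrow::datatypes::DataType;
    use proof_of_sql::base::database::ColumnType;

    fn main() {
        // `From<&ColumnType> for DataType` maps BigInt to Int64...
        let arrow_type = DataType::from(&ColumnType::BigInt);
        assert_eq!(arrow_type, DataType::Int64);
        // ...and `TryFrom<DataType> for ColumnType` maps it back.
        let round_trip = ColumnType::try_from(arrow_type).unwrap();
        assert_eq!(round_trip, ColumnType::BigInt);
    }

This round-trip property is what the CSV-backed examples lean on when converting an arrow-inferred schema into a table the prover can commit to.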
-use super::scalar_and_i256_conversions::convert_scalar_to_i256; +use super::scalar_and_i256_conversions::{convert_i256_to_scalar, convert_scalar_to_i256}; use crate::base::{ - database::{ - scalar_and_i256_conversions::convert_i256_to_scalar, OwnedColumn, OwnedTable, - OwnedTableError, - }, + database::{OwnedColumn, OwnedTable, OwnedTableError}, map::IndexMap, math::decimal::Precision, scalar::Scalar, diff --git a/crates/proof-of-sql/src/base/database/owned_and_arrow_conversions_test.rs b/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions_test.rs similarity index 97% rename from crates/proof-of-sql/src/base/database/owned_and_arrow_conversions_test.rs rename to crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions_test.rs index 970df4bad..539d94eaa 100644 --- a/crates/proof-of-sql/src/base/database/owned_and_arrow_conversions_test.rs +++ b/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions_test.rs @@ -1,7 +1,7 @@ -use super::{OwnedColumn, OwnedTable}; +use super::owned_and_arrow_conversions::OwnedArrowConversionError; use crate::{ base::{ - database::{owned_table_utility::*, OwnedArrowConversionError}, + database::{owned_table_utility::*, OwnedColumn, OwnedTable}, map::IndexMap, scalar::Curve25519Scalar, }, diff --git a/crates/proof-of-sql/src/base/arrow/record_batch_conversion.rs b/crates/proof-of-sql/src/base/arrow/record_batch_conversion.rs new file mode 100644 index 000000000..6f24457cc --- /dev/null +++ b/crates/proof-of-sql/src/base/arrow/record_batch_conversion.rs @@ -0,0 +1,160 @@ +use super::{ + arrow_array_to_column_conversion::ArrayRefExt, + record_batch_errors::{AppendRecordBatchTableCommitmentError, RecordBatchToColumnsError}, +}; +use crate::base::{ + commitment::{ + AppendColumnCommitmentsError, AppendTableCommitmentError, Commitment, TableCommitment, + TableCommitmentFromColumnsError, + }, + database::Column, + scalar::Scalar, +}; +use arrow::record_batch::RecordBatch; +use bumpalo::Bump; +use proof_of_sql_parser::Identifier; + +/// This function will return an error if: +/// - The field name cannot be parsed into an [`Identifier`]. +/// - The conversion of an Arrow array to a [`Column`] fails. +pub fn batch_to_columns<'a, S: Scalar + 'a>( + batch: &'a RecordBatch, + alloc: &'a Bump, +) -> Result)>, RecordBatchToColumnsError> { + batch + .schema() + .fields() + .into_iter() + .zip(batch.columns()) + .map(|(field, array)| { + let identifier: Identifier = field.name().parse()?; + let column: Column = array.to_column(alloc, &(0..array.len()), None)?; + Ok((identifier, column)) + }) + .collect() +} + +impl TableCommitment { + /// Append an arrow [`RecordBatch`] to the existing [`TableCommitment`]. + /// + /// The row offset is assumed to be the end of the [`TableCommitment`]'s current range. + /// + /// Will error on a variety of mismatches, or if the provided columns have mixed length. + #[allow(clippy::missing_panics_doc)] + pub fn try_append_record_batch( + &mut self, + batch: &RecordBatch, + setup: &C::PublicSetup<'_>, + ) -> Result<(), AppendRecordBatchTableCommitmentError> { + match self.try_append_rows( + batch_to_columns::(batch, &Bump::new())? + .iter() + .map(|(a, b)| (a, b)), + setup, + ) { + Ok(()) => Ok(()), + Err(AppendTableCommitmentError::MixedLengthColumns { .. }) => { + panic!("RecordBatches cannot have columns of mixed length") + } + Err(AppendTableCommitmentError::AppendColumnCommitments { + source: AppendColumnCommitmentsError::DuplicateIdentifiers { .. 
}, + }) => { + panic!("RecordBatches cannot have duplicate identifiers") + } + Err(AppendTableCommitmentError::AppendColumnCommitments { + source: AppendColumnCommitmentsError::Mismatch { source: e }, + }) => Err(e)?, + } + } + /// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`]. + pub fn try_from_record_batch( + batch: &RecordBatch, + setup: &C::PublicSetup<'_>, + ) -> Result, RecordBatchToColumnsError> { + Self::try_from_record_batch_with_offset(batch, 0, setup) + } + + /// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`] with the given row offset. + #[allow(clippy::missing_panics_doc)] + pub fn try_from_record_batch_with_offset( + batch: &RecordBatch, + offset: usize, + setup: &C::PublicSetup<'_>, + ) -> Result, RecordBatchToColumnsError> { + match Self::try_from_columns_with_offset( + batch_to_columns::(batch, &Bump::new())? + .iter() + .map(|(a, b)| (a, b)), + offset, + setup, + ) { + Ok(commitment) => Ok(commitment), + Err(TableCommitmentFromColumnsError::MixedLengthColumns { .. }) => { + panic!("RecordBatches cannot have columns of mixed length") + } + Err(TableCommitmentFromColumnsError::DuplicateIdentifiers { .. }) => { + panic!("RecordBatches cannot have duplicate identifiers") + } + } + } +} + +#[cfg(all(test, feature = "blitzar"))] +mod tests { + use super::*; + use crate::{base::scalar::Curve25519Scalar, record_batch}; + use curve25519_dalek::RistrettoPoint; + + #[test] + fn we_can_create_and_append_table_commitments_with_record_batchs() { + let batch = record_batch!( + "a" => [1i64, 2, 3], + "b" => ["1", "2", "3"], + ); + + let b_scals = ["1".into(), "2".into(), "3".into()]; + + let columns = [ + ( + &"a".parse().unwrap(), + &Column::::BigInt(&[1, 2, 3]), + ), + ( + &"b".parse().unwrap(), + &Column::::VarChar((&["1", "2", "3"], &b_scals)), + ), + ]; + + let mut expected_commitment = + TableCommitment::::try_from_columns_with_offset(columns, 0, &()) + .unwrap(); + + let mut commitment = + TableCommitment::::try_from_record_batch(&batch, &()).unwrap(); + + assert_eq!(commitment, expected_commitment); + + let batch2 = record_batch!( + "a" => [4i64, 5, 6], + "b" => ["4", "5", "6"], + ); + + let b_scals2 = ["4".into(), "5".into(), "6".into()]; + + let columns2 = [ + ( + &"a".parse().unwrap(), + &Column::::BigInt(&[4, 5, 6]), + ), + ( + &"b".parse().unwrap(), + &Column::::VarChar((&["4", "5", "6"], &b_scals2)), + ), + ]; + + expected_commitment.try_append_rows(columns2, &()).unwrap(); + commitment.try_append_record_batch(&batch2, &()).unwrap(); + + assert_eq!(commitment, expected_commitment); + } +} diff --git a/crates/proof-of-sql/src/base/arrow/record_batch_errors.rs b/crates/proof-of-sql/src/base/arrow/record_batch_errors.rs new file mode 100644 index 000000000..b3986d1a6 --- /dev/null +++ b/crates/proof-of-sql/src/base/arrow/record_batch_errors.rs @@ -0,0 +1,38 @@ +use super::arrow_array_to_column_conversion::ArrowArrayToColumnConversionError; +use crate::base::commitment::ColumnCommitmentsMismatch; +use proof_of_sql_parser::ParseError; +use snafu::Snafu; + +/// Errors that can occur when trying to create or extend a [`TableCommitment`] from a record batch. +#[derive(Debug, Snafu)] +pub enum RecordBatchToColumnsError { + /// Error converting from arrow array + #[snafu(transparent)] + ArrowArrayToColumnConversionError { + /// The underlying source error + source: ArrowArrayToColumnConversionError, + }, + #[snafu(transparent)] + /// This error occurs when convering from a record batch name to an identifier fails. (Which may be impossible.) 
+ FieldParseFail { + /// The underlying source error + source: ParseError, + }, +} + +/// Errors that can occur when attempting to append a record batch to a [`TableCommitment`]. +#[derive(Debug, Snafu)] +pub enum AppendRecordBatchTableCommitmentError { + /// During commitment operation, metadata indicates that operand tables cannot be the same. + #[snafu(transparent)] + ColumnCommitmentsMismatch { + /// The underlying source error + source: ColumnCommitmentsMismatch, + }, + /// Error converting from arrow array + #[snafu(transparent)] + ArrowBatchToColumnError { + /// The underlying source error + source: RecordBatchToColumnsError, + }, +} diff --git a/crates/proof-of-sql/src/base/database/record_batch_utility.rs b/crates/proof-of-sql/src/base/arrow/record_batch_utility.rs similarity index 99% rename from crates/proof-of-sql/src/base/database/record_batch_utility.rs rename to crates/proof-of-sql/src/base/arrow/record_batch_utility.rs index d1180005b..3ede592bd 100644 --- a/crates/proof-of-sql/src/base/database/record_batch_utility.rs +++ b/crates/proof-of-sql/src/base/arrow/record_batch_utility.rs @@ -169,7 +169,7 @@ macro_rules! record_batch { use arrow::datatypes::Field; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; - use $crate::base::database::ToArrow; + use $crate::base::arrow::record_batch_utility::ToArrow; let schema = Arc::new(Schema::new( vec![$( diff --git a/crates/proof-of-sql/src/base/database/scalar_and_i256_conversions.rs b/crates/proof-of-sql/src/base/arrow/scalar_and_i256_conversions.rs similarity index 96% rename from crates/proof-of-sql/src/base/database/scalar_and_i256_conversions.rs rename to crates/proof-of-sql/src/base/arrow/scalar_and_i256_conversions.rs index 9a44c3766..f606c03cb 100644 --- a/crates/proof-of-sql/src/base/database/scalar_and_i256_conversions.rs +++ b/crates/proof-of-sql/src/base/arrow/scalar_and_i256_conversions.rs @@ -54,12 +54,10 @@ pub fn convert_i256_to_scalar(value: &i256) -> Option { #[cfg(test)] mod tests { - - use super::{convert_i256_to_scalar, convert_scalar_to_i256}; - use crate::base::{ - database::scalar_and_i256_conversions::{MAX_SUPPORTED_I256, MIN_SUPPORTED_I256}, - scalar::{Curve25519Scalar, Scalar}, + use super::{ + convert_i256_to_scalar, convert_scalar_to_i256, MAX_SUPPORTED_I256, MIN_SUPPORTED_I256, }; + use crate::base::scalar::{Curve25519Scalar, Scalar}; use arrow::datatypes::i256; use num_traits::Zero; use rand::RngCore; diff --git a/crates/proof-of-sql/src/base/commitment/table_commitment.rs b/crates/proof-of-sql/src/base/commitment/table_commitment.rs index 0f9e21783..1a52b7cea 100644 --- a/crates/proof-of-sql/src/base/commitment/table_commitment.rs +++ b/crates/proof-of-sql/src/base/commitment/table_commitment.rs @@ -2,18 +2,13 @@ use super::{ committable_column::CommittableColumn, AppendColumnCommitmentsError, ColumnCommitments, ColumnCommitmentsMismatch, Commitment, DuplicateIdentifiers, }; -#[cfg(feature = "arrow")] -use crate::base::database::{ArrayRefExt, ArrowArrayToColumnConversionError}; use crate::base::{ - database::{Column, ColumnField, CommitmentAccessor, OwnedTable, TableRef}, + database::{ColumnField, CommitmentAccessor, OwnedTable, TableRef}, scalar::Scalar, }; use alloc::vec::Vec; -#[cfg(feature = "arrow")] -use arrow::record_batch::RecordBatch; -use bumpalo::Bump; use core::ops::Range; -use proof_of_sql_parser::{Identifier, ParseError}; +use proof_of_sql_parser::Identifier; use serde::{Deserialize, Serialize}; use snafu::Snafu; @@ -83,42 +78,6 @@ pub enum 
TableCommitmentArithmeticError { NonContiguous, } -/// Errors that can occur when trying to create or extend a [`TableCommitment`] from a record batch. -#[cfg(feature = "arrow")] -#[derive(Debug, Snafu)] -pub enum RecordBatchToColumnsError { - /// Error converting from arrow array - #[snafu(transparent)] - ArrowArrayToColumnConversionError { - /// The underlying source error - source: ArrowArrayToColumnConversionError, - }, - #[snafu(transparent)] - /// This error occurs when convering from a record batch name to an identifier fails. (Which may be impossible.) - FieldParseFail { - /// The underlying source error - source: ParseError, - }, -} - -/// Errors that can occur when attempting to append a record batch to a [`TableCommitment`]. -#[cfg(feature = "arrow")] -#[derive(Debug, Snafu)] -pub enum AppendRecordBatchTableCommitmentError { - /// During commitment operation, metadata indicates that operand tables cannot be the same. - #[snafu(transparent)] - ColumnCommitmentsMismatch { - /// The underlying source error - source: ColumnCommitmentsMismatch, - }, - /// Error converting from arrow array - #[snafu(transparent)] - ArrowBatchToColumnError { - /// The underlying source error - source: RecordBatchToColumnsError, - }, -} - /// Commitment for an entire table, with column and table metadata. /// /// Unlike [`ColumnCommitments`], all columns in this commitment must have the same length. @@ -398,90 +357,6 @@ impl TableCommitment { range, }) } - - /// Append an arrow [`RecordBatch`] to the existing [`TableCommitment`]. - /// - /// The row offset is assumed to be the end of the [`TableCommitment`]'s current range. - /// - /// Will error on a variety of mismatches, or if the provided columns have mixed length. - #[cfg(feature = "arrow")] - #[allow(clippy::missing_panics_doc)] - pub fn try_append_record_batch( - &mut self, - batch: &RecordBatch, - setup: &C::PublicSetup<'_>, - ) -> Result<(), AppendRecordBatchTableCommitmentError> { - match self.try_append_rows( - batch_to_columns::(batch, &Bump::new())? - .iter() - .map(|(a, b)| (a, b)), - setup, - ) { - Ok(()) => Ok(()), - Err(AppendTableCommitmentError::MixedLengthColumns { .. }) => { - panic!("RecordBatches cannot have columns of mixed length") - } - Err(AppendTableCommitmentError::AppendColumnCommitments { - source: AppendColumnCommitmentsError::DuplicateIdentifiers { .. }, - }) => { - panic!("RecordBatches cannot have duplicate identifiers") - } - Err(AppendTableCommitmentError::AppendColumnCommitments { - source: AppendColumnCommitmentsError::Mismatch { source: e }, - }) => Err(e)?, - } - } - /// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`]. - #[cfg(feature = "arrow")] - pub fn try_from_record_batch( - batch: &RecordBatch, - setup: &C::PublicSetup<'_>, - ) -> Result, RecordBatchToColumnsError> { - Self::try_from_record_batch_with_offset(batch, 0, setup) - } - - /// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`] with the given row offset. - #[allow(clippy::missing_panics_doc)] - #[cfg(feature = "arrow")] - pub fn try_from_record_batch_with_offset( - batch: &RecordBatch, - offset: usize, - setup: &C::PublicSetup<'_>, - ) -> Result, RecordBatchToColumnsError> { - match Self::try_from_columns_with_offset( - batch_to_columns::(batch, &Bump::new())? - .iter() - .map(|(a, b)| (a, b)), - offset, - setup, - ) { - Ok(commitment) => Ok(commitment), - Err(TableCommitmentFromColumnsError::MixedLengthColumns { .. 
}) => { - panic!("RecordBatches cannot have columns of mixed length") - } - Err(TableCommitmentFromColumnsError::DuplicateIdentifiers { .. }) => { - panic!("RecordBatches cannot have duplicate identifiers") - } - } - } -} - -#[cfg(feature = "arrow")] -fn batch_to_columns<'a, S: Scalar + 'a>( - batch: &'a RecordBatch, - alloc: &'a Bump, -) -> Result)>, RecordBatchToColumnsError> { - batch - .schema() - .fields() - .into_iter() - .zip(batch.columns()) - .map(|(field, array)| { - let identifier: Identifier = field.name().parse()?; - let column: Column = array.to_column(alloc, &(0..array.len()), None)?; - Ok((identifier, column)) - }) - .collect() } /// Return the number of rows for the provided columns, erroring if they have mixed length. @@ -505,13 +380,10 @@ fn num_rows_of_columns<'a>( #[cfg(all(test, feature = "arrow", feature = "blitzar"))] mod tests { use super::*; - use crate::{ - base::{ - database::{owned_table_utility::*, OwnedColumn}, - map::IndexMap, - scalar::Curve25519Scalar, - }, - record_batch, + use crate::base::{ + database::{owned_table_utility::*, OwnedColumn}, + map::IndexMap, + scalar::Curve25519Scalar, }; use curve25519_dalek::RistrettoPoint; @@ -1263,57 +1135,4 @@ mod tests { Err(TableCommitmentArithmeticError::NegativeRange { .. }) )); } - - #[test] - fn we_can_create_and_append_table_commitments_with_record_batchs() { - let batch = record_batch!( - "a" => [1i64, 2, 3], - "b" => ["1", "2", "3"], - ); - - let b_scals = ["1".into(), "2".into(), "3".into()]; - - let columns = [ - ( - &"a".parse().unwrap(), - &Column::::BigInt(&[1, 2, 3]), - ), - ( - &"b".parse().unwrap(), - &Column::::VarChar((&["1", "2", "3"], &b_scals)), - ), - ]; - - let mut expected_commitment = - TableCommitment::::try_from_columns_with_offset(columns, 0, &()) - .unwrap(); - - let mut commitment = - TableCommitment::::try_from_record_batch(&batch, &()).unwrap(); - - assert_eq!(commitment, expected_commitment); - - let batch2 = record_batch!( - "a" => [4i64, 5, 6], - "b" => ["4", "5", "6"], - ); - - let b_scals2 = ["4".into(), "5".into(), "6".into()]; - - let columns2 = [ - ( - &"a".parse().unwrap(), - &Column::::BigInt(&[4, 5, 6]), - ), - ( - &"b".parse().unwrap(), - &Column::::VarChar((&["4", "5", "6"], &b_scals2)), - ), - ]; - - expected_commitment.try_append_rows(columns2, &()).unwrap(); - commitment.try_append_record_batch(&batch2, &()).unwrap(); - - assert_eq!(commitment, expected_commitment); - } } diff --git a/crates/proof-of-sql/src/base/database/column.rs b/crates/proof-of-sql/src/base/database/column.rs index 3d3b11372..be536b1d5 100644 --- a/crates/proof-of-sql/src/base/database/column.rs +++ b/crates/proof-of-sql/src/base/database/column.rs @@ -4,9 +4,7 @@ use crate::base::{ scalar::{Scalar, ScalarExt}, slice_ops::slice_cast_with, }; -use alloc::{sync::Arc, vec::Vec}; -#[cfg(feature = "arrow")] -use arrow::datatypes::{DataType, Field, TimeUnit as ArrowTimeUnit}; +use alloc::vec::Vec; use bumpalo::Bump; use core::{ fmt, @@ -412,70 +410,6 @@ impl ColumnType { } } -/// Convert [`ColumnType`] values to some arrow [`DataType`] -#[cfg(feature = "arrow")] -impl From<&ColumnType> for DataType { - fn from(column_type: &ColumnType) -> Self { - match column_type { - ColumnType::Boolean => DataType::Boolean, - ColumnType::TinyInt => DataType::Int8, - ColumnType::SmallInt => DataType::Int16, - ColumnType::Int => DataType::Int32, - ColumnType::BigInt => DataType::Int64, - ColumnType::Int128 => DataType::Decimal128(38, 0), - ColumnType::Decimal75(precision, scale) => { - 
DataType::Decimal256(precision.value(), *scale) - } - ColumnType::VarChar => DataType::Utf8, - ColumnType::Scalar => unimplemented!("Cannot convert Scalar type to arrow type"), - ColumnType::TimestampTZ(timeunit, timezone) => { - let arrow_timezone = Some(Arc::from(timezone.to_string())); - let arrow_timeunit = match timeunit { - PoSQLTimeUnit::Second => ArrowTimeUnit::Second, - PoSQLTimeUnit::Millisecond => ArrowTimeUnit::Millisecond, - PoSQLTimeUnit::Microsecond => ArrowTimeUnit::Microsecond, - PoSQLTimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond, - }; - DataType::Timestamp(arrow_timeunit, arrow_timezone) - } - } - } -} - -/// Convert arrow [`DataType`] values to some [`ColumnType`] -#[cfg(feature = "arrow")] -impl TryFrom for ColumnType { - type Error = String; - - fn try_from(data_type: DataType) -> Result { - match data_type { - DataType::Boolean => Ok(ColumnType::Boolean), - DataType::Int8 => Ok(ColumnType::TinyInt), - DataType::Int16 => Ok(ColumnType::SmallInt), - DataType::Int32 => Ok(ColumnType::Int), - DataType::Int64 => Ok(ColumnType::BigInt), - DataType::Decimal128(38, 0) => Ok(ColumnType::Int128), - DataType::Decimal256(precision, scale) if precision <= 75 => { - Ok(ColumnType::Decimal75(Precision::new(precision)?, scale)) - } - DataType::Timestamp(time_unit, timezone_option) => { - let posql_time_unit = match time_unit { - ArrowTimeUnit::Second => PoSQLTimeUnit::Second, - ArrowTimeUnit::Millisecond => PoSQLTimeUnit::Millisecond, - ArrowTimeUnit::Microsecond => PoSQLTimeUnit::Microsecond, - ArrowTimeUnit::Nanosecond => PoSQLTimeUnit::Nanosecond, - }; - Ok(ColumnType::TimestampTZ( - posql_time_unit, - PoSQLTimeZone::try_from(&timezone_option)?, - )) - } - DataType::Utf8 => Ok(ColumnType::VarChar), - _ => Err(format!("Unsupported arrow data type {data_type:?}")), - } - } -} - /// Display the column type as a str name (in all caps) impl Display for ColumnType { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { @@ -570,18 +504,6 @@ impl ColumnField { } } -/// Convert [`ColumnField`] values to arrow Field -#[cfg(feature = "arrow")] -impl From<&ColumnField> for Field { - fn from(column_field: &ColumnField) -> Self { - Field::new( - column_field.name().name(), - (&column_field.data_type()).into(), - false, - ) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/proof-of-sql/src/base/database/columnar_value.rs b/crates/proof-of-sql/src/base/database/columnar_value.rs new file mode 100644 index 000000000..0505044b6 --- /dev/null +++ b/crates/proof-of-sql/src/base/database/columnar_value.rs @@ -0,0 +1,138 @@ +use crate::base::{ + database::{Column, ColumnType, LiteralValue}, + scalar::Scalar, +}; +use bumpalo::Bump; +use snafu::Snafu; + +/// The result of evaluating an expression. +/// +/// Inspired by [`datafusion_expr_common::ColumnarValue`] +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum ColumnarValue<'a, S: Scalar> { + /// A [ `ColumnarValue::Column` ] is a list of values. + Column(Column<'a, S>), + /// A [ `ColumnarValue::Literal` ] is a single value with indeterminate size. + Literal(LiteralValue), +} + +/// Errors from operations on [`ColumnarValue`]s. 
+#[derive(Snafu, Debug, PartialEq, Eq)] +pub enum ColumnarValueError { + /// Attempt to convert a `[ColumnarValue::Column]` to a column of a different length + ColumnLengthMismatch { + /// The length of the `[ColumnarValue::Column]` + columnar_value_length: usize, + /// The length we attempted to convert the `[ColumnarValue::Column]` to + attempt_to_convert_length: usize, + }, +} + +impl<'a, S: Scalar> ColumnarValue<'a, S> { + /// Provides the column type associated with the column + pub fn column_type(&self) -> ColumnType { + match self { + Self::Column(column) => column.column_type(), + Self::Literal(literal) => literal.column_type(), + } + } + + /// Converts the [`ColumnarValue`] to a [`Column`] + pub fn into_column( + &self, + num_rows: usize, + alloc: &'a Bump, + ) -> Result, ColumnarValueError> { + match self { + Self::Column(column) => { + if column.len() == num_rows { + Ok(*column) + } else { + Err(ColumnarValueError::ColumnLengthMismatch { + columnar_value_length: column.len(), + attempt_to_convert_length: num_rows, + }) + } + } + Self::Literal(literal) => { + Ok(Column::from_literal_with_length(literal, num_rows, alloc)) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::base::scalar::test_scalar::TestScalar; + use core::convert::Into; + + #[test] + fn we_can_get_column_type_of_columnar_values() { + let column = ColumnarValue::Column(Column::::Int(&[1, 2, 3])); + assert_eq!(column.column_type(), ColumnType::Int); + + let column = ColumnarValue::Literal(LiteralValue::::Boolean(true)); + assert_eq!(column.column_type(), ColumnType::Boolean); + } + + #[test] + fn we_can_transform_columnar_values_into_columns() { + let bump = Bump::new(); + + let columnar_value = ColumnarValue::Column(Column::::Int(&[1, 2, 3])); + let column = columnar_value.into_column(3, &bump).unwrap(); + assert_eq!(column, Column::Int(&[1, 2, 3])); + + let columnar_value = ColumnarValue::Literal(LiteralValue::::Boolean(false)); + let column = columnar_value.into_column(5, &bump).unwrap(); + assert_eq!(column, Column::Boolean(&[false; 5])); + + // Check whether it works if `num_rows` is 0 + let columnar_value = ColumnarValue::Literal(LiteralValue::::TinyInt(2)); + let column = columnar_value.into_column(0, &bump).unwrap(); + assert_eq!(column, Column::TinyInt(&[])); + + let columnar_value = ColumnarValue::Column(Column::::SmallInt(&[])); + let column = columnar_value.into_column(0, &bump).unwrap(); + assert_eq!(column, Column::SmallInt(&[])); + } + + #[test] + fn we_cannot_transform_columnar_values_into_columns_of_different_length() { + let bump = Bump::new(); + + let columnar_value = ColumnarValue::Column(Column::::Int(&[1, 2, 3])); + let res = columnar_value.into_column(2, &bump); + assert_eq!( + res, + Err(ColumnarValueError::ColumnLengthMismatch { + columnar_value_length: 3, + attempt_to_convert_length: 2, + }) + ); + + let strings = ["a", "b", "c"]; + let scalars: Vec = strings.iter().map(Into::into).collect(); + let columnar_value = + ColumnarValue::Column(Column::::VarChar((&strings, &scalars))); + let res = columnar_value.into_column(0, &bump); + assert_eq!( + res, + Err(ColumnarValueError::ColumnLengthMismatch { + columnar_value_length: 3, + attempt_to_convert_length: 0, + }) + ); + + let columnar_value = ColumnarValue::Column(Column::::Boolean(&[])); + let res = columnar_value.into_column(1, &bump); + assert_eq!( + res, + Err(ColumnarValueError::ColumnLengthMismatch { + columnar_value_length: 0, + attempt_to_convert_length: 1, + }) + ); + } +} diff --git 
a/crates/proof-of-sql/src/base/database/mod.rs b/crates/proof-of-sql/src/base/database/mod.rs index e65b7efb5..822b798ee 100644 --- a/crates/proof-of-sql/src/base/database/mod.rs +++ b/crates/proof-of-sql/src/base/database/mod.rs @@ -15,30 +15,25 @@ pub use column_operation::{ mod column_operation_error; pub use column_operation_error::{ColumnOperationError, ColumnOperationResult}; +mod columnar_value; +pub use columnar_value::ColumnarValue; + mod literal_value; pub use literal_value::LiteralValue; mod table_ref; -pub use table_ref::TableRef; - -#[cfg(feature = "arrow")] -mod arrow_array_to_column_conversion; -#[cfg(feature = "arrow")] -pub use arrow_array_to_column_conversion::{ArrayRefExt, ArrowArrayToColumnConversionError}; - -#[cfg(feature = "arrow")] -mod record_batch_utility; #[cfg(feature = "arrow")] -pub use record_batch_utility::ToArrow; +pub use crate::base::arrow::{ + arrow_array_to_column_conversion::{ArrayRefExt, ArrowArrayToColumnConversionError}, + owned_and_arrow_conversions::OwnedArrowConversionError, + record_batch_utility::ToArrow, + scalar_and_i256_conversions, +}; +pub use table_ref::TableRef; #[cfg(feature = "arrow")] pub mod arrow_schema_utility; -#[cfg(all(test, feature = "arrow", feature = "test"))] -mod test_accessor_utility; -#[cfg(all(test, feature = "arrow", feature = "test"))] -pub use test_accessor_utility::{make_random_test_accessor_data, RandomTestAccessorDescriptor}; - mod owned_column; pub(crate) use owned_column::compare_indexes_by_owned_columns_with_direction; pub use owned_column::OwnedColumn; @@ -63,13 +58,6 @@ mod expression_evaluation_error; mod expression_evaluation_test; pub use expression_evaluation_error::{ExpressionEvaluationError, ExpressionEvaluationResult}; -#[cfg(feature = "arrow")] -mod owned_and_arrow_conversions; -#[cfg(feature = "arrow")] -pub use owned_and_arrow_conversions::OwnedArrowConversionError; -#[cfg(all(test, feature = "arrow"))] -mod owned_and_arrow_conversions_test; - mod test_accessor; pub use test_accessor::TestAccessor; #[cfg(test)] @@ -84,9 +72,6 @@ mod owned_table_test_accessor; pub use owned_table_test_accessor::OwnedTableTestAccessor; #[cfg(all(test, feature = "blitzar"))] mod owned_table_test_accessor_test; -/// Contains traits for scalar <-> i256 conversions -#[cfg(feature = "arrow")] -pub mod scalar_and_i256_conversions; /// TODO: add docs pub(crate) mod filter_util; diff --git a/crates/proof-of-sql/src/base/database/test_accessor_utility.rs b/crates/proof-of-sql/src/base/database/test_accessor_utility.rs deleted file mode 100644 index 2b06081dd..000000000 --- a/crates/proof-of-sql/src/base/database/test_accessor_utility.rs +++ /dev/null @@ -1,218 +0,0 @@ -use crate::base::database::ColumnType; -use arrow::{ - array::{ - Array, BooleanArray, Decimal128Array, Decimal256Array, Int16Array, Int32Array, Int64Array, - Int8Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, - }, - datatypes::{i256, DataType, Field, Schema, TimeUnit}, - record_batch::RecordBatch, -}; -use proof_of_sql_parser::posql_time::PoSQLTimeUnit; -use rand::{ - distributions::{Distribution, Uniform}, - rngs::StdRng, -}; -use std::sync::Arc; - -/// Specify what form a randomly generated `TestAccessor` can take -pub struct RandomTestAccessorDescriptor { - /// The minimum number of rows in the generated `RecordBatch` - pub min_rows: usize, - /// The maximum number of rows in the generated `RecordBatch` - pub max_rows: usize, - /// The minimum value of the generated data - pub 
min_value: i64, - /// The maximum value of the generated data - pub max_value: i64, -} - -impl Default for RandomTestAccessorDescriptor { - fn default() -> Self { - Self { - min_rows: 0, - max_rows: 100, - min_value: -5, - max_value: 5, - } - } -} - -/// Generate a `DataFrame` with random data -/// -/// # Panics -/// -/// This function may panic in the following cases: -/// - If `Precision::new(7)` fails when creating a `Decimal75` column type, which would occur -/// if the precision is invalid. -/// - When calling `.unwrap()` on the result of `RecordBatch::try_new(schema, columns)`, which -/// will panic if the schema and columns do not align correctly or if there are any other -/// underlying errors. -#[allow(dead_code, clippy::too_many_lines)] -pub fn make_random_test_accessor_data( - rng: &mut StdRng, - cols: &[(&str, ColumnType)], - descriptor: &RandomTestAccessorDescriptor, -) -> RecordBatch { - let n = Uniform::new(descriptor.min_rows, descriptor.max_rows + 1).sample(rng); - let dist = Uniform::new(descriptor.min_value, descriptor.max_value + 1); - - let mut columns: Vec> = Vec::with_capacity(n); - let mut column_fields: Vec<_> = Vec::with_capacity(n); - - for (col_name, col_type) in cols { - let values: Vec = dist.sample_iter(&mut *rng).take(n).collect(); - - match col_type { - ColumnType::Boolean => { - column_fields.push(Field::new(*col_name, DataType::Boolean, false)); - let boolean_values: Vec = values.iter().map(|x| x % 2 != 0).collect(); - columns.push(Arc::new(BooleanArray::from(boolean_values))); - } - ColumnType::TinyInt => { - column_fields.push(Field::new(*col_name, DataType::Int8, false)); - let values: Vec = values - .iter() - .map(|x| ((*x >> 56) as i8)) // Shift right to align the lower 8 bits - .collect(); - columns.push(Arc::new(Int8Array::from(values))); - } - ColumnType::SmallInt => { - column_fields.push(Field::new(*col_name, DataType::Int16, false)); - let values: Vec = values - .iter() - .map(|x| ((*x >> 48) as i16)) // Shift right to align the lower 16 bits - .collect(); - columns.push(Arc::new(Int16Array::from(values))); - } - ColumnType::Int => { - column_fields.push(Field::new(*col_name, DataType::Int32, false)); - let values: Vec = values - .iter() - .map(|x| ((*x >> 32) as i32)) // Shift right to align the lower 32 bits - .collect(); - columns.push(Arc::new(Int32Array::from(values))); - } - ColumnType::BigInt => { - column_fields.push(Field::new(*col_name, DataType::Int64, false)); - let values: Vec = values.clone(); - columns.push(Arc::new(Int64Array::from(values))); - } - ColumnType::Int128 => { - column_fields.push(Field::new(*col_name, DataType::Decimal128(38, 0), false)); - - let values: Vec = values.iter().map(|x| i128::from(*x)).collect(); - columns.push(Arc::new( - Decimal128Array::from(values.clone()) - .with_precision_and_scale(38, 0) - .unwrap(), - )); - } - ColumnType::Decimal75(precision, scale) => { - column_fields.push(Field::new( - *col_name, - DataType::Decimal256(precision.value(), *scale), - false, - )); - - let values: Vec = values.iter().map(|x| i256::from(*x)).collect(); - columns.push(Arc::new( - Decimal256Array::from(values.clone()) - .with_precision_and_scale(precision.value(), *scale) - .unwrap(), - )); - } - ColumnType::VarChar => { - let col = &values - .iter() - .map(|v| "s".to_owned() + &v.to_string()[..]) - .collect::>()[..]; - let col: Vec<_> = col.iter().map(String::as_str).collect(); - - column_fields.push(Field::new(*col_name, DataType::Utf8, false)); - - columns.push(Arc::new(StringArray::from(col))); - } - 
ColumnType::Scalar => unimplemented!("Scalar columns are not supported by arrow"), - ColumnType::TimestampTZ(tu, tz) => { - column_fields.push(Field::new( - *col_name, - DataType::Timestamp( - match tu { - PoSQLTimeUnit::Second => TimeUnit::Second, - PoSQLTimeUnit::Millisecond => TimeUnit::Millisecond, - PoSQLTimeUnit::Microsecond => TimeUnit::Microsecond, - PoSQLTimeUnit::Nanosecond => TimeUnit::Nanosecond, - }, - Some(Arc::from(tz.to_string())), - ), - false, - )); - // Create the correct timestamp array based on the time unit - let timestamp_array: Arc = match tu { - PoSQLTimeUnit::Second => Arc::new(TimestampSecondArray::from(values.clone())), - PoSQLTimeUnit::Millisecond => { - Arc::new(TimestampMillisecondArray::from(values.clone())) - } - PoSQLTimeUnit::Microsecond => { - Arc::new(TimestampMicrosecondArray::from(values.clone())) - } - PoSQLTimeUnit::Nanosecond => { - Arc::new(TimestampNanosecondArray::from(values.clone())) - } - }; - columns.push(timestamp_array); - } - } - } - - let schema = Arc::new(Schema::new(column_fields)); - RecordBatch::try_new(schema, columns).unwrap() -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::record_batch; - use rand_core::SeedableRng; - - #[test] - fn we_can_construct_a_random_test_data() { - let descriptor = RandomTestAccessorDescriptor::default(); - let mut rng = StdRng::from_seed([0u8; 32]); - let cols = [ - ("a", ColumnType::BigInt), - ("b", ColumnType::VarChar), - ("c", ColumnType::Int128), - ("d", ColumnType::SmallInt), - ("e", ColumnType::Int), - ("f", ColumnType::TinyInt), - ]; - - let data1 = make_random_test_accessor_data(&mut rng, &cols, &descriptor); - let data2 = make_random_test_accessor_data(&mut rng, &cols, &descriptor); - assert_ne!(data1.num_rows(), data2.num_rows()); - } - - #[test] - fn we_can_construct_a_random_test_data_with_the_correct_data() { - let descriptor = RandomTestAccessorDescriptor { - min_rows: 1, - max_rows: 1, - min_value: -2, - max_value: -2, - }; - let mut rng = StdRng::from_seed([0u8; 32]); - let cols = [ - ("b", ColumnType::BigInt), - ("a", ColumnType::VarChar), - ("c", ColumnType::Int128), - ]; - let data = make_random_test_accessor_data(&mut rng, &cols, &descriptor); - - assert_eq!( - data, - record_batch!("b" => [-2_i64], "a" => ["s-2"], "c" => [-2_i128]) - ); - } -} diff --git a/crates/proof-of-sql/src/base/mod.rs b/crates/proof-of-sql/src/base/mod.rs index ad5573639..657b855d1 100644 --- a/crates/proof-of-sql/src/base/mod.rs +++ b/crates/proof-of-sql/src/base/mod.rs @@ -1,5 +1,8 @@ //! This module contains basic shared functionalities of the library. 
/// TODO: add docs +#[cfg(feature = "arrow")] +pub mod arrow; + pub(crate) mod bit; pub mod commitment; pub mod database; diff --git a/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_commitment_helper_cpu.rs b/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_commitment_helper_cpu.rs index 8ed2ddbb5..b36e1177d 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_commitment_helper_cpu.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_commitment_helper_cpu.rs @@ -1,10 +1,14 @@ use super::{ - dynamic_dory_structure::row_and_column_from_index, pairings, DoryScalar, DynamicDoryCommitment, - G1Affine, G1Projective, ProverSetup, GT, + dynamic_dory_structure::{full_width_of_row, row_and_column_from_index, row_start_index}, + pairings, DoryScalar, DynamicDoryCommitment, G1Projective, ProverSetup, GT, }; -use crate::base::commitment::CommittableColumn; -use alloc::{vec, vec::Vec}; +use crate::base::{commitment::CommittableColumn, if_rayon, slice_ops::slice_cast}; +use alloc::vec::Vec; +use ark_ec::VariableBaseMSM; +use bytemuck::TransparentWrapper; use num_traits::Zero; +#[cfg(feature = "rayon")] +use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; #[tracing::instrument(name = "compute_dory_commitment_impl (cpu)", level = "debug", skip_all)] /// # Panics @@ -13,6 +17,7 @@ use num_traits::Zero; /// - `setup.Gamma_1.last()` returns `None`, indicating that `Gamma_1` is empty. /// - `setup.Gamma_2.last()` returns `None`, indicating that `Gamma_2` is empty. /// - The indexing for `Gamma_2` with `first_row..=last_row` goes out of bounds. +#[allow(clippy::range_plus_one)] fn compute_dory_commitment_impl<'a, T>( column: &'a [T], offset: usize, @@ -22,18 +27,39 @@ where &'a T: Into, T: Sync, { + if column.is_empty() { + return DynamicDoryCommitment::default(); + } let Gamma_1 = setup.Gamma_1.last().unwrap(); let Gamma_2 = setup.Gamma_2.last().unwrap(); - let (first_row, _) = row_and_column_from_index(offset); - let (last_row, _) = row_and_column_from_index(offset + column.len() - 1); - let row_commits = column.iter().enumerate().fold( - vec![G1Projective::from(G1Affine::identity()); last_row - first_row + 1], - |mut row_commits, (i, v)| { - let (row, col) = row_and_column_from_index(i + offset); - row_commits[row - first_row] += Gamma_1[col] * v.into().0; - row_commits - }, - ); + let (first_row, first_col) = row_and_column_from_index(offset); + let (last_row, last_col) = row_and_column_from_index(offset + column.len() - 1); + + let row_commits: Vec<_> = if_rayon!( + (first_row..=last_row).into_par_iter(), + (first_row..=last_row) + ) + .map(|row| { + let width = full_width_of_row(row); + let row_start = row_start_index(row); + let (gamma_range, column_range) = if first_row == last_row { + (first_col..last_col + 1, 0..column.len()) + } else if row == 1 { + (1..2, (1 - offset)..(2 - offset)) + } else if row == first_row { + (first_col..width, 0..width - first_col) + } else if row == last_row { + (0..last_col + 1, column.len() - last_col - 1..column.len()) + } else { + (0..width, row_start - offset..width + row_start - offset) + }; + G1Projective::msm_unchecked( + &Gamma_1[gamma_range], + TransparentWrapper::peel_slice(&slice_cast::<_, DoryScalar>(&column[column_range])), + ) + }) + .collect(); + DynamicDoryCommitment(pairings::multi_pairing( row_commits, &Gamma_2[first_row..=last_row], @@ -70,8 +96,7 @@ pub(super) fn compute_dynamic_dory_commitments( offset: usize, setup: &ProverSetup, ) -> Vec { - committable_columns - 
.iter() + if_rayon!(committable_columns.par_iter(), committable_columns.iter()) .map(|column| { column .is_empty() diff --git a/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_structure.rs b/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_structure.rs index 2598db988..03f6ffaa3 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_structure.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_structure.rs @@ -40,7 +40,7 @@ pub(crate) const fn full_width_of_row(row: usize) -> usize { /// Returns the index that belongs in the first column in a particular row. /// /// Note: when row = 1, this correctly returns 0, even though no data belongs there. -#[cfg(test)] +#[cfg(any(test, not(feature = "blitzar")))] pub(crate) const fn row_start_index(row: usize) -> usize { let width_of_row = full_width_of_row(row); width_of_row * (row - width_of_row / 2) diff --git a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs index 4f68869d9..0c1cfd965 100644 --- a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs +++ b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs @@ -54,25 +54,6 @@ pub fn schema_accessor_from_table_ref_with_schema( TestSchemaAccessor::new(indexmap! {table => schema}) } -fn get_test_accessor() -> (TableRef, TestSchemaAccessor) { - let table = "sxt.t".parse().unwrap(); - let accessor = schema_accessor_from_table_ref_with_schema( - table, - indexmap! { - "s".parse().unwrap() => ColumnType::VarChar, - "i".parse().unwrap() => ColumnType::BigInt, - "d".parse().unwrap() => ColumnType::Int128, - "s0".parse().unwrap() => ColumnType::VarChar, - "i0".parse().unwrap() => ColumnType::BigInt, - "d0".parse().unwrap() => ColumnType::Int128, - "s1".parse().unwrap() => ColumnType::VarChar, - "i1".parse().unwrap() => ColumnType::BigInt, - "d1".parse().unwrap() => ColumnType::Int128, - }, - ); - (table, accessor) -} - #[test] fn we_can_convert_an_ast_with_one_column() { let t = "sxt.sxt_tab".parse().unwrap(); @@ -1128,8 +1109,17 @@ fn we_can_group_by_without_using_aggregate_functions() { #[test] fn group_by_expressions_are_parsed_before_an_order_by_referencing_an_aggregate_alias_result() { let query_text = - "select max(i) max_sal, i0 d, count(i0) from sxt.t group by i0, i1 order by max_sal"; - let (t, accessor) = get_test_accessor(); + "select max(salary) max_sal, department_budget d, count(department_budget) from sxt.employees group by department_budget, tax order by max_sal"; + + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! 
{ + "department_budget".parse().unwrap() => ColumnType::BigInt, + "salary".parse().unwrap() => ColumnType::BigInt, + "tax".parse().unwrap() => ColumnType::BigInt, + }, + ); let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let query = @@ -1138,20 +1128,20 @@ fn group_by_expressions_are_parsed_before_an_order_by_referencing_an_aggregate_a let expected_query = QueryExpr::new( filter( vec![ - col_expr_plan(t, "i", &accessor), - col_expr_plan(t, "i0", &accessor), - col_expr_plan(t, "i1", &accessor), + col_expr_plan(t, "department_budget", &accessor), + col_expr_plan(t, "salary", &accessor), + col_expr_plan(t, "tax", &accessor), ], tab(t), const_bool(true), ), vec![ group_by_postprocessing( - &["i0", "i1"], + &["department_budget", "tax"], &[ - aliased_expr(max(col("i")), "max_sal"), - aliased_expr(col("i0"), "d"), - aliased_expr(count(col("i0")), "__count__"), + aliased_expr(max(col("salary")), "max_sal"), + aliased_expr(col("department_budget"), "d"), + aliased_expr(count(col("department_budget")), "__count__"), ], ), orders(&["max_sal"], &[Asc]), @@ -1240,8 +1230,14 @@ fn group_by_column_cannot_be_a_column_result_alias() { #[test] fn we_can_have_aggregate_functions_without_a_group_by_clause() { - let query_text = "select count(s) from sxt.t"; - let (t, accessor) = get_test_accessor(); + let query_text = "select count(name) from sxt.employees"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "name".parse().unwrap() => ColumnType::VarChar, + }, + ); let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let ast = @@ -1401,8 +1397,17 @@ fn we_can_use_the_same_result_columns_with_different_aliases_and_associate_it_wi #[test] fn we_can_use_multiple_group_by_clauses_with_multiple_agg_and_non_agg_exprs() { - let (t, accessor) = get_test_accessor(); - let query_text = "select i d1, max(i1), i d2, sum(i0) sum_bonus, count(s) count_s from sxt.t group by i, i0, i"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! 
{ + "bonus".parse().unwrap() => ColumnType::BigInt, + "name".parse().unwrap() => ColumnType::VarChar, + "salary".parse().unwrap() => ColumnType::BigInt, + "tax".parse().unwrap() => ColumnType::BigInt, + }, + ); + let query_text = "select salary d1, max(tax), salary d2, sum(bonus) sum_bonus, count(name) count_s from sxt.employees group by salary, bonus, salary"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let ast = @@ -1410,18 +1415,18 @@ fn we_can_use_multiple_group_by_clauses_with_multiple_agg_and_non_agg_exprs() { let expected_ast = QueryExpr::new( filter( - cols_expr_plan(t, &["i", "i0", "i1", "s"], &accessor), + cols_expr_plan(t, &["bonus", "name", "salary", "tax"], &accessor), tab(t), const_bool(true), ), vec![group_by_postprocessing( - &["i", "i0", "i"], + &["salary", "bonus", "salary"], &[ - aliased_expr(col("i"), "d1"), - aliased_expr(max(col("i1")), "__max__"), - aliased_expr(col("i"), "d2"), - aliased_expr(sum(col("i0")), "sum_bonus"), - aliased_expr(count(col("s")), "count_s"), + aliased_expr(col("salary"), "d1"), + aliased_expr(max(col("tax")), "__max__"), + aliased_expr(col("salary"), "d2"), + aliased_expr(sum(col("bonus")), "sum_bonus"), + aliased_expr(count(col("name")), "count_s"), ], )], ); @@ -1567,12 +1572,19 @@ fn we_can_parse_arithmetic_expression_within_aggregations_in_the_result_expr() { #[test] fn we_cannot_use_non_grouped_columns_outside_agg() { - let (t, accessor) = get_test_accessor(); + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "salary".parse().unwrap() => ColumnType::BigInt, + "name".parse().unwrap() => ColumnType::VarChar, + }, + ); let identifier_not_in_agg_queries = vec![ - "select i from sxt.t group by s", - "select sum(i), i from sxt.t group by s", - "select min(i) + i from sxt.t group by s", - "select 2 * i, min(i) from sxt.t group by s", + "select salary from sxt.employees group by name", + "select sum(salary), salary from sxt.employees group by name", + "select min(salary) + salary from sxt.employees group by name", + "select 2 * salary, min(salary) from sxt.employees group by name", ]; for query_text in &identifier_not_in_agg_queries { @@ -1589,9 +1601,9 @@ fn we_cannot_use_non_grouped_columns_outside_agg() { } let invalid_group_by_queries = vec![ - "select 2 * i, min(i) from sxt.t", - "select sum(i), i from sxt.t", - "select max(i) + 2 * i from sxt.t", + "select 2 * salary, min(salary) from sxt.employees", + "select sum(salary), salary from sxt.employees", + "select max(salary) + 2 * salary from sxt.employees", ]; for query_text in &invalid_group_by_queries { @@ -1608,11 +1620,23 @@ fn we_cannot_use_non_grouped_columns_outside_agg() { #[test] fn varchar_column_is_not_compatible_with_integer_column() { - let bigint_to_varchar_queries = vec!["select -123 * s from sxt.t", "select i - s from sxt.t"]; - let (t, accessor) = get_test_accessor(); + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! 
{ + "salary".parse().unwrap() => ColumnType::BigInt, + "name".parse().unwrap() => ColumnType::VarChar, + }, + ); + + let bigint_to_varchar_queries = vec![ + "select -123 * name from sxt.employees", + "select salary - name from sxt.employees", + ]; + let varchar_to_bigint_queries = vec![ - "select s from sxt.t where 'abc' = i", - "select s from sxt.t where 'abc' != i", + "select name from sxt.employees where 'abc' = salary", + "select name from sxt.employees where 'abc' != salary", ]; for query_text in &bigint_to_varchar_queries { @@ -1646,8 +1670,16 @@ fn varchar_column_is_not_compatible_with_integer_column() { #[test] fn arithmetic_operations_are_not_allowed_with_varchar_column() { - let (t, accessor) = get_test_accessor(); - let query_text = "select s - s1 from sxt.t"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "name".parse().unwrap() => ColumnType::VarChar, + "position".parse().unwrap() => ColumnType::VarChar, + }, + ); + + let query_text = "select name - position from sxt.employees"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let result = QueryExpr::::try_new(intermediate_ast, t.schema_id(), &accessor); @@ -1662,8 +1694,14 @@ fn arithmetic_operations_are_not_allowed_with_varchar_column() { #[test] fn varchar_column_is_not_allowed_within_numeric_aggregations() { - let (t, accessor) = get_test_accessor(); - let sum_query = "select sum(s) from sxt.t"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "name".parse().unwrap() => ColumnType::VarChar, + }, + ); + let sum_query = "select sum(name) from sxt.employees"; let intermediate_ast = SelectStatementParser::new().parse(sum_query).unwrap(); let result = QueryExpr::::try_new(intermediate_ast, t.schema_id(), &accessor); @@ -1673,7 +1711,7 @@ fn varchar_column_is_not_allowed_within_numeric_aggregations() { if expression == "cannot use expression of type 'varchar' with numeric aggregation function 'sum'" )); - let max_query = "select max(s) from sxt.t"; + let max_query = "select max(name) from sxt.employees"; let intermediate_ast = SelectStatementParser::new().parse(max_query).unwrap(); let result = QueryExpr::::try_new(intermediate_ast, t.schema_id(), &accessor); @@ -1683,7 +1721,7 @@ fn varchar_column_is_not_allowed_within_numeric_aggregations() { if expression == "cannot use expression of type 'varchar' with numeric aggregation function 'max'" )); - let min_query = "select min(s) from sxt.t"; + let min_query = "select min(name) from sxt.employees"; let intermediate_ast = SelectStatementParser::new().parse(min_query).unwrap(); let result = QueryExpr::::try_new(intermediate_ast, t.schema_id(), &accessor); @@ -1696,8 +1734,14 @@ fn varchar_column_is_not_allowed_within_numeric_aggregations() { #[test] fn group_by_with_bigint_column_is_valid() { - let (t, accessor) = get_test_accessor(); - let query_text = "select i from sxt.t group by i"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! 
{ + "salary".parse().unwrap() => ColumnType::BigInt, + }, + ); + let query_text = "select salary from sxt.employees group by salary"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let query = @@ -1705,13 +1749,13 @@ fn group_by_with_bigint_column_is_valid() { let expected_query = QueryExpr::new( filter( - cols_expr_plan(t, &["i"], &accessor), + cols_expr_plan(t, &["salary"], &accessor), tab(t), const_bool(true), ), vec![group_by_postprocessing( - &["i"], - &[aliased_expr(col("i"), "i")], + &["salary"], + &[aliased_expr(col("salary"), "salary")], )], ); assert_eq!(query, expected_query); @@ -1719,8 +1763,14 @@ fn group_by_with_bigint_column_is_valid() { #[test] fn group_by_with_decimal_column_is_valid() { - let (t, accessor) = get_test_accessor(); - let query_text = "select d from sxt.t group by d"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "salary".parse().unwrap() => ColumnType::Int128, + }, + ); + let query_text = "select salary from sxt.employees group by salary"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let query = @@ -1728,13 +1778,13 @@ fn group_by_with_decimal_column_is_valid() { let expected_query = QueryExpr::new( filter( - cols_expr_plan(t, &["d"], &accessor), + cols_expr_plan(t, &["salary"], &accessor), tab(t), const_bool(true), ), vec![group_by_postprocessing( - &["d"], - &[aliased_expr(col("d"), "d")], + &["salary"], + &[aliased_expr(col("salary"), "salary")], )], ); assert_eq!(query, expected_query); @@ -1742,8 +1792,14 @@ fn group_by_with_decimal_column_is_valid() { #[test] fn group_by_with_varchar_column_is_valid() { - let (t, accessor) = get_test_accessor(); - let query_text = "select s from sxt.t group by s"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "name".parse().unwrap() => ColumnType::VarChar, + }, + ); + let query_text = "select name from sxt.employees group by name"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let query = @@ -1751,13 +1807,13 @@ fn group_by_with_varchar_column_is_valid() { let expected_query = QueryExpr::new( filter( - cols_expr_plan(t, &["s"], &accessor), + cols_expr_plan(t, &["name"], &accessor), tab(t), const_bool(true), ), vec![group_by_postprocessing( - &["s"], - &[aliased_expr(col("s"), "s")], + &["name"], + &[aliased_expr(col("name"), "name")], )], ); assert_eq!(query, expected_query); @@ -1765,8 +1821,16 @@ fn group_by_with_varchar_column_is_valid() { #[test] fn we_can_use_arithmetic_outside_agg_expressions_while_using_group_by() { - let (t, accessor) = get_test_accessor(); - let query_text = "select 2 * i + sum(i) - i1 from sxt.t group by i, i1"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! 
{ + "salary".parse().unwrap() => ColumnType::BigInt, + "tax".parse().unwrap() => ColumnType::BigInt, + }, + ); + let query_text = + "select 2 * salary + sum(salary) - tax from sxt.employees group by salary, tax"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let query = @@ -1774,20 +1838,26 @@ fn we_can_use_arithmetic_outside_agg_expressions_while_using_group_by() { let expected_query = QueryExpr::new( filter( - cols_expr_plan(t, &["i", "i1"], &accessor), + cols_expr_plan(t, &["salary", "tax"], &accessor), tab(t), const_bool(true), ), vec![ group_by_postprocessing( - &["i", "i1"], + &["salary", "tax"], &[aliased_expr( - psub(padd(pmul(lit(2), col("i")), sum(col("i"))), col("i1")), + psub( + padd(pmul(lit(2), col("salary")), sum(col("salary"))), + col("tax"), + ), "__expr__", )], ), select_expr(&[aliased_expr( - psub(padd(pmul(lit(2), col("i")), col("__col_agg_0")), col("i1")), + psub( + padd(pmul(lit(2), col("salary")), col("__col_agg_0")), + col("tax"), + ), "__expr__", )]), ], @@ -1797,8 +1867,15 @@ fn we_can_use_arithmetic_outside_agg_expressions_while_using_group_by() { #[test] fn we_can_use_arithmetic_outside_agg_expressions_without_using_group_by() { - let (t, accessor) = get_test_accessor(); - let query_text = "select 7 + max(i) as max_i, min(i + 777 * d) * -5 as min_d from t"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "salary".parse().unwrap() => ColumnType::BigInt, + "bonus".parse().unwrap() => ColumnType::Int128, + }, + ); + let query_text = "select 7 + max(salary) as max_i, min(salary + 777 * bonus) * -5 as min_d from sxt.employees"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let ast = @@ -1806,7 +1883,7 @@ fn we_can_use_arithmetic_outside_agg_expressions_without_using_group_by() { let expected_ast = QueryExpr::new( filter( - cols_expr_plan(t, &["d", "i"], &accessor), + cols_expr_plan(t, &["bonus", "salary"], &accessor), tab(t), const_bool(true), ), @@ -1814,9 +1891,12 @@ fn we_can_use_arithmetic_outside_agg_expressions_without_using_group_by() { group_by_postprocessing( &[], &[ - aliased_expr(padd(lit(7), max(col("i"))), "max_i"), + aliased_expr(padd(lit(7), max(col("salary"))), "max_i"), aliased_expr( - pmul(min(padd(col("i"), pmul(lit(777), col("d")))), lit(-5)), + pmul( + min(padd(col("salary"), pmul(lit(777), col("bonus")))), + lit(-5), + ), "min_d", ), ], @@ -1832,8 +1912,17 @@ fn we_can_use_arithmetic_outside_agg_expressions_without_using_group_by() { #[test] fn count_aggregation_always_have_integer_type() { - let (t, accessor) = get_test_accessor(); - let query_text = "select 7 + count(s) as cs, count(i) * -5 as ci, count(d) from t"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! 
{ + "name".parse().unwrap() => ColumnType::VarChar, + "salary".parse().unwrap() => ColumnType::BigInt, + "tax".parse().unwrap() => ColumnType::Int128, + }, + ); + let query_text = + "select 7 + count(name) as cs, count(salary) * -5 as ci, count(tax) from sxt.employees"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let ast = @@ -1841,7 +1930,7 @@ fn count_aggregation_always_have_integer_type() { let expected_ast = QueryExpr::new( filter( - cols_expr_plan(t, &["d", "i", "s"], &accessor), + cols_expr_plan(t, &["name", "salary", "tax"], &accessor), tab(t), const_bool(true), ), @@ -1849,9 +1938,9 @@ fn count_aggregation_always_have_integer_type() { group_by_postprocessing( &[], &[ - aliased_expr(padd(lit(7), count(col("s"))), "cs"), - aliased_expr(pmul(count(col("i")), lit(-5)), "ci"), - aliased_expr(count(col("d")), "__count__"), + aliased_expr(padd(lit(7), count(col("name"))), "cs"), + aliased_expr(pmul(count(col("salary")), lit(-5)), "ci"), + aliased_expr(count(col("tax")), "__count__"), ], ), select_expr(&[ @@ -1866,17 +1955,41 @@ fn count_aggregation_always_have_integer_type() { #[test] fn select_wildcard_is_valid_with_group_by_exprs() { - let columns = ["s", "i", "d", "s0", "i0", "d0", "s1", "i1", "d1"]; + let columns = [ + "employee_name", + "base_salary", + "annual_bonus", + "manager_name", + "manager_salary", + "manager_bonus", + "department_name", + "department_budget", + "department_headcount", + ]; let sorted_columns = columns.iter().sorted().collect::>(); let aliased_exprs = columns .iter() .map(|c| aliased_expr(col(c), c)) .collect::>(); - let (t, accessor) = get_test_accessor(); - let table_name = "sxt.t"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "employee_name".parse().unwrap() => ColumnType::VarChar, + "base_salary".parse().unwrap() => ColumnType::BigInt, + "annual_bonus".parse().unwrap() => ColumnType::Int128, + "manager_name".parse().unwrap() => ColumnType::VarChar, + "manager_salary".parse().unwrap() => ColumnType::BigInt, + "manager_bonus".parse().unwrap() => ColumnType::Int128, + "department_name".parse().unwrap() => ColumnType::VarChar, + "department_budget".parse().unwrap() => ColumnType::BigInt, + "department_headcount".parse().unwrap() => ColumnType::Int128, + }, + ); + let query_text = format!( "SELECT * FROM {} GROUP BY {}", - table_name, + "sxt.employees", columns.join(", ") ); @@ -1901,10 +2014,19 @@ fn select_wildcard_is_valid_with_group_by_exprs() { #[test] fn nested_aggregations_are_not_supported() { let supported_agg = ["max", "min", "sum", "count"]; - let (t, accessor) = get_test_accessor(); + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "salary".parse().unwrap() => ColumnType::BigInt, + }, + ); for perm_aggs in supported_agg.iter().permutations(2) { - let query_text = format!("SELECT {}({}(i)) FROM t", perm_aggs[0], perm_aggs[1]); + let query_text = format!( + "SELECT {}({}(salary)) FROM sxt.employees", + perm_aggs[0], perm_aggs[1] + ); let intermediate_ast = SelectStatementParser::new().parse(&query_text).unwrap(); let result = @@ -1922,8 +2044,17 @@ fn nested_aggregations_are_not_supported() { #[test] fn select_group_and_order_by_preserve_the_column_order_reference() { const N: usize = 4; - let (t, accessor) = get_test_accessor(); - let base_cols: [&str; N] = ["i", "i0", "i1", "s"]; // sorted because of `select: [cols = ... 
]` + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "salary".parse().unwrap() => ColumnType::BigInt, + "department".parse().unwrap() => ColumnType::BigInt, + "tax".parse().unwrap() => ColumnType::BigInt, + "name".parse().unwrap() => ColumnType::VarChar, + }, + ); + let base_cols: [&str; N] = ["salary", "department", "tax", "name"]; // sorted because of `select: [cols = ... ]` let base_ordering = [Asc, Desc, Asc, Desc]; for (idx, perm_cols) in base_cols .into_iter() diff --git a/crates/proof-of-sql/src/sql/proof/mod.rs b/crates/proof-of-sql/src/sql/proof/mod.rs index 48139dc22..b33be315c 100644 --- a/crates/proof-of-sql/src/sql/proof/mod.rs +++ b/crates/proof-of-sql/src/sql/proof/mod.rs @@ -25,8 +25,6 @@ pub(crate) use provable_result_column::ProvableResultColumn; mod provable_query_result; pub use provable_query_result::ProvableQueryResult; -#[cfg(all(test, feature = "arrow"))] -mod provable_query_result_test; mod sumcheck_mle_evaluations; pub(crate) use sumcheck_mle_evaluations::SumcheckMleEvaluations; @@ -70,3 +68,6 @@ pub(crate) use result_element_serialization::{ mod first_round_builder; pub(crate) use first_round_builder::FirstRoundBuilder; + +#[cfg(all(test, feature = "arrow"))] +mod provable_query_result_test; diff --git a/crates/proof-of-sql/src/sql/proof/proof_plan.rs b/crates/proof-of-sql/src/sql/proof/proof_plan.rs index 430485308..42ceceab1 100644 --- a/crates/proof-of-sql/src/sql/proof/proof_plan.rs +++ b/crates/proof-of-sql/src/sql/proof/proof_plan.rs @@ -3,7 +3,7 @@ use crate::base::{ commitment::Commitment, database::{ Column, ColumnField, ColumnRef, CommitmentAccessor, DataAccessor, MetadataAccessor, - OwnedTable, + OwnedTable, TableRef, }, map::IndexSet, proof::ProofError, @@ -46,6 +46,9 @@ pub trait ProofPlan: Debug + Send + Sync + ProverEvaluate IndexSet; + + /// Return all the tables referenced in the Query + fn get_table_references(&self) -> IndexSet; } pub trait ProverEvaluate { diff --git a/crates/proof-of-sql/src/sql/proof/query_proof_test.rs b/crates/proof-of-sql/src/sql/proof/query_proof_test.rs index a4fa8a65a..e6e685673 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof_test.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof_test.rs @@ -7,7 +7,7 @@ use crate::{ database::{ owned_table_utility::{bigint, owned_table}, Column, ColumnField, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor, - MetadataAccessor, OwnedTable, OwnedTableTestAccessor, TestAccessor, + MetadataAccessor, OwnedTable, OwnedTableTestAccessor, TableRef, TestAccessor, UnimplementedTestAccessor, }, map::IndexSet, @@ -109,6 +109,9 @@ impl ProofPlan for TrivialTestProofPlan { fn get_column_references(&self) -> IndexSet { unimplemented!("no real usage for this function yet") } + fn get_table_references(&self) -> IndexSet { + unimplemented!("no real usage for this function yet") + } } fn verify_a_trivial_query_proof_with_given_offset(n: usize, offset_generators: usize) { @@ -278,6 +281,9 @@ impl ProofPlan for SquareTestProofPlan { fn get_column_references(&self) -> IndexSet { unimplemented!("no real usage for this function yet") } + fn get_table_references(&self) -> IndexSet { + unimplemented!("no real usage for this function yet") + } } fn verify_a_proof_with_an_anchored_commitment_and_given_offset(offset_generators: usize) { @@ -481,6 +487,9 @@ impl ProofPlan for DoubleSquareTestProofPlan { fn get_column_references(&self) -> IndexSet { unimplemented!("no real usage for this function yet") } + fn 
get_table_references(&self) -> IndexSet { + unimplemented!("no real usage for this function yet") + } } fn verify_a_proof_with_an_intermediate_commitment_and_given_offset(offset_generators: usize) { @@ -677,6 +686,9 @@ impl ProofPlan for ChallengeTestProofPlan { fn get_column_references(&self) -> IndexSet { unimplemented!("no real usage for this function yet") } + fn get_table_references(&self) -> IndexSet { + unimplemented!("no real usage for this function yet") + } } fn verify_a_proof_with_a_post_result_challenge_and_given_offset(offset_generators: usize) { diff --git a/crates/proof-of-sql/src/sql/proof/query_result.rs b/crates/proof-of-sql/src/sql/proof/query_result.rs index 31b9ad994..647e4ad0b 100644 --- a/crates/proof-of-sql/src/sql/proof/query_result.rs +++ b/crates/proof-of-sql/src/sql/proof/query_result.rs @@ -3,8 +3,6 @@ use crate::base::{ proof::ProofError, scalar::Scalar, }; -#[cfg(feature = "arrow")] -use arrow::{error::ArrowError, record_batch::RecordBatch}; use snafu::Snafu; /// Verifiable query errors @@ -54,22 +52,5 @@ pub struct QueryData { pub verification_hash: [u8; 32], } -impl QueryData { - #[cfg(all(test, feature = "arrow"))] - #[must_use] - pub fn into_record_batch(self) -> RecordBatch { - self.try_into().unwrap() - } -} - -#[cfg(feature = "arrow")] -impl TryFrom> for RecordBatch { - type Error = ArrowError; - - fn try_from(value: QueryData) -> Result { - Self::try_from(value.table) - } -} - /// The result of a query -- either an error or a table. pub type QueryResult = Result, QueryError>; diff --git a/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test.rs b/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test.rs index 5d299e408..d2db5df0e 100644 --- a/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test.rs +++ b/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test.rs @@ -8,7 +8,7 @@ use crate::{ database::{ owned_table_utility::{bigint, owned_table}, Column, ColumnField, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor, - MetadataAccessor, OwnedTable, TestAccessor, UnimplementedTestAccessor, + MetadataAccessor, OwnedTable, TableRef, TestAccessor, UnimplementedTestAccessor, }, map::IndexSet, proof::ProofError, @@ -88,6 +88,10 @@ impl ProofPlan for EmptyTestQueryExpr { fn get_column_references(&self) -> IndexSet { unimplemented!("no real usage for this function yet") } + + fn get_table_references(&self) -> IndexSet { + unimplemented!("no real usage for this function yet") + } } #[test] diff --git a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs index c25445340..c0a8cc291 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs @@ -1,15 +1,15 @@ use crate::{ base::{ - database::Column, + database::{Column, ColumnarValue, LiteralValue}, if_rayon, math::decimal::{DecimalError, Precision}, - scalar::Scalar, + scalar::{Scalar, ScalarExt}, }, sql::parse::{type_check_binary_operation, ConversionError, ConversionResult}, }; use alloc::string::ToString; use bumpalo::Bump; -use core::cmp; +use core::cmp::{max, Ordering}; use proof_of_sql_parser::intermediate_ast::BinaryOperator; #[cfg(feature = "rayon")] use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; @@ -31,6 +31,70 @@ fn unchecked_subtract_impl<'a, S: Scalar>( Ok(result) } +/// Scale LHS and RHS to the same scale if at least one of them is decimal +/// and take the difference. 
This function is used for comparisons. +/// +/// # Panics +/// This function will panic if `lhs` and `rhs` have [`ColumnType`]s that are not comparable +/// or if we have precision overflow issues. +#[allow(clippy::cast_sign_loss)] +pub fn scale_and_subtract_literal( + lhs: &LiteralValue, + rhs: &LiteralValue, + lhs_scale: i8, + rhs_scale: i8, + is_equal: bool, +) -> ConversionResult { + let lhs_type = lhs.column_type(); + let rhs_type = rhs.column_type(); + let operator = if is_equal { + BinaryOperator::Equal + } else { + BinaryOperator::LessThanOrEqual + }; + if !type_check_binary_operation(&lhs_type, &rhs_type, operator) { + return Err(ConversionError::DataTypeMismatch { + left_type: lhs_type.to_string(), + right_type: rhs_type.to_string(), + }); + } + let max_scale = max(lhs_scale, rhs_scale); + let lhs_upscale = max_scale - lhs_scale; + let rhs_upscale = max_scale - rhs_scale; + // Only check precision overflow issues if at least one side is decimal + if max_scale != 0 { + let lhs_precision_value = lhs_type + .precision_value() + .expect("If scale is set, precision must be set"); + let rhs_precision_value = rhs_type + .precision_value() + .expect("If scale is set, precision must be set"); + let max_precision_value = max( + lhs_precision_value + (max_scale - lhs_scale) as u8, + rhs_precision_value + (max_scale - rhs_scale) as u8, + ); + // Check if the precision is valid + let _max_precision = Precision::new(max_precision_value).map_err(|_| { + ConversionError::DecimalConversionError { + source: DecimalError::InvalidPrecision { + error: max_precision_value.to_string(), + }, + } + })?; + } + match lhs_scale.cmp(&rhs_scale) { + Ordering::Less => { + let upscale_factor = S::pow10(rhs_upscale as u8); + Ok(lhs.to_scalar() * upscale_factor - rhs.to_scalar()) + } + Ordering::Equal => Ok(lhs.to_scalar() - rhs.to_scalar()), + Ordering::Greater => { + let upscale_factor = S::pow10(lhs_upscale as u8); + Ok(lhs.to_scalar() - rhs.to_scalar() * upscale_factor) + } + } +} + #[allow( clippy::missing_panics_doc, reason = "precision and scale are validated prior to calling this function, ensuring no panic occurs" @@ -67,7 +131,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( right_type: rhs_type.to_string(), }); } - let max_scale = cmp::max(lhs_scale, rhs_scale); + let max_scale = max(lhs_scale, rhs_scale); let lhs_upscale = max_scale - lhs_scale; let rhs_upscale = max_scale - rhs_scale; // Only check precision overflow issues if at least one side is decimal @@ -78,7 +142,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( let rhs_precision_value = rhs_type .precision_value() .expect("If scale is set, precision must be set"); - let max_precision_value = cmp::max( + let max_precision_value = max( lhs_precision_value + (max_scale - lhs_scale) as u8, rhs_precision_value + (max_scale - rhs_scale) as u8, ); @@ -98,3 +162,49 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( lhs_len, ) } + +#[allow(clippy::cast_sign_loss)] +#[allow(dead_code)] +/// Scale LHS and RHS to the same scale if at least one of them is decimal +/// and take the difference. This function is used for comparisons. 
+pub(crate) fn scale_and_subtract_columnar_value<'a, S: Scalar>( + alloc: &'a Bump, + lhs: ColumnarValue<'a, S>, + rhs: ColumnarValue<'a, S>, + lhs_scale: i8, + rhs_scale: i8, + is_equal: bool, +) -> ConversionResult> { + match (lhs, rhs) { + (ColumnarValue::Column(lhs), ColumnarValue::Column(rhs)) => { + Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract( + alloc, lhs, rhs, lhs_scale, rhs_scale, is_equal, + )?))) + } + (ColumnarValue::Literal(lhs), ColumnarValue::Column(rhs)) => { + Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract( + alloc, + Column::from_literal_with_length(&lhs, rhs.len(), alloc), + rhs, + lhs_scale, + rhs_scale, + is_equal, + )?))) + } + (ColumnarValue::Column(lhs), ColumnarValue::Literal(rhs)) => { + Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract( + alloc, + lhs, + Column::from_literal_with_length(&rhs, lhs.len(), alloc), + lhs_scale, + rhs_scale, + is_equal, + )?))) + } + (ColumnarValue::Literal(lhs), ColumnarValue::Literal(rhs)) => { + Ok(ColumnarValue::Literal(LiteralValue::Scalar( + scale_and_subtract_literal(&lhs, &rhs, lhs_scale, rhs_scale, is_equal)?, + ))) + } + } +} diff --git a/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs b/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs index eacc03142..8d584d0a1 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs @@ -1,8 +1,36 @@ use crate::base::{ - database::Column, + database::{Column, ColumnarValue, LiteralValue}, scalar::{Scalar, ScalarExt}, }; use bumpalo::Bump; +use core::cmp::Ordering; + +#[allow(clippy::cast_sign_loss)] +/// Add or subtract two literals together. +pub(crate) fn add_subtract_literals( + lhs: &LiteralValue, + rhs: &LiteralValue, + lhs_scale: i8, + rhs_scale: i8, + is_subtract: bool, +) -> S { + let (lhs_scaled, rhs_scaled) = match lhs_scale.cmp(&rhs_scale) { + Ordering::Less => { + let scaling_factor = S::pow10((rhs_scale - lhs_scale) as u8); + (lhs.to_scalar() * scaling_factor, rhs.to_scalar()) + } + Ordering::Equal => (lhs.to_scalar(), rhs.to_scalar()), + Ordering::Greater => { + let scaling_factor = S::pow10((lhs_scale - rhs_scale) as u8); + (lhs.to_scalar(), rhs.to_scalar() * scaling_factor) + } + }; + if is_subtract { + lhs_scaled - rhs_scaled + } else { + lhs_scaled + rhs_scaled + } +} #[allow( clippy::missing_panics_doc, @@ -36,9 +64,62 @@ pub(crate) fn add_subtract_columns<'a, S: Scalar>( result } +/// Add or subtract two [`ColumnarValues`] together. 
+#[allow(dead_code)] +pub(crate) fn add_subtract_columnar_values<'a, S: Scalar>( + lhs: ColumnarValue<'a, S>, + rhs: ColumnarValue<'a, S>, + lhs_scale: i8, + rhs_scale: i8, + alloc: &'a Bump, + is_subtract: bool, +) -> ColumnarValue<'a, S> { + match (lhs, rhs) { + (ColumnarValue::Column(lhs), ColumnarValue::Column(rhs)) => { + ColumnarValue::Column(Column::Scalar(add_subtract_columns( + lhs, + rhs, + lhs_scale, + rhs_scale, + alloc, + is_subtract, + ))) + } + (ColumnarValue::Literal(lhs), ColumnarValue::Column(rhs)) => { + ColumnarValue::Column(Column::Scalar(add_subtract_columns( + Column::from_literal_with_length(&lhs, rhs.len(), alloc), + rhs, + lhs_scale, + rhs_scale, + alloc, + is_subtract, + ))) + } + (ColumnarValue::Column(lhs), ColumnarValue::Literal(rhs)) => { + ColumnarValue::Column(Column::Scalar(add_subtract_columns( + lhs, + Column::from_literal_with_length(&rhs, lhs.len(), alloc), + lhs_scale, + rhs_scale, + alloc, + is_subtract, + ))) + } + (ColumnarValue::Literal(lhs), ColumnarValue::Literal(rhs)) => { + ColumnarValue::Literal(LiteralValue::Scalar(add_subtract_literals( + &lhs, + &rhs, + lhs_scale, + rhs_scale, + is_subtract, + ))) + } + } +} + /// Multiply two columns together. /// # Panics -/// Panics if: The lengths of `lhs` and `rhs` are not equal.`lhs.scalar_at(i)` or `rhs.scalar_at(i)` returns `None`, which occurs if the column does not have, a scalar at the given index `i`. +/// Panics if: `lhs` and `rhs` are not of the same length. pub(crate) fn multiply_columns<'a, S: Scalar>( lhs: &Column<'a, S>, rhs: &Column<'a, S>, @@ -55,6 +136,38 @@ pub(crate) fn multiply_columns<'a, S: Scalar>( }) } +#[allow(dead_code)] +/// Multiply two [`ColumnarValues`] together. +/// # Panics +/// Panics if: `lhs` and `rhs` are not of the same length. 
+pub(crate) fn multiply_columnar_values<'a, S: Scalar>( + lhs: &ColumnarValue<'a, S>, + rhs: &ColumnarValue<'a, S>, + alloc: &'a Bump, +) -> ColumnarValue<'a, S> { + match (lhs, rhs) { + (ColumnarValue::Column(lhs), ColumnarValue::Column(rhs)) => { + ColumnarValue::Column(Column::Scalar(multiply_columns(lhs, rhs, alloc))) + } + (ColumnarValue::Literal(lhs), ColumnarValue::Column(rhs)) => { + let lhs_scalar = lhs.to_scalar(); + let result = + alloc.alloc_slice_fill_with(rhs.len(), |i| lhs_scalar * rhs.scalar_at(i).unwrap()); + ColumnarValue::Column(Column::Scalar(result)) + } + (ColumnarValue::Column(lhs), ColumnarValue::Literal(rhs)) => { + let rhs_scalar = rhs.to_scalar(); + let result = + alloc.alloc_slice_fill_with(lhs.len(), |i| lhs.scalar_at(i).unwrap() * rhs_scalar); + ColumnarValue::Column(Column::Scalar(result)) + } + (ColumnarValue::Literal(lhs), ColumnarValue::Literal(rhs)) => { + let result = lhs.to_scalar() * rhs.to_scalar(); + ColumnarValue::Literal(LiteralValue::Scalar(result)) + } + } +} + #[allow( clippy::missing_panics_doc, reason = "scaling factor is guaranteed to not be negative based on input validation prior to calling this function" diff --git a/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs b/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs index c524a2c76..b7edcc70a 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs @@ -1,9 +1,21 @@ use super::{FilterExec, GroupByExec, ProjectionExec}; use crate::{ - base::{commitment::Commitment, database::Column, map::IndexSet}, - sql::proof::{ProofPlan, ProverEvaluate}, + base::{ + commitment::Commitment, + database::{ + Column, ColumnField, ColumnRef, CommitmentAccessor, DataAccessor, MetadataAccessor, + OwnedTable, TableRef, + }, + map::IndexSet, + proof::ProofError, + }, + sql::proof::{ + CountBuilder, FinalRoundBuilder, FirstRoundBuilder, ProofPlan, ProverEvaluate, + VerificationBuilder, + }, }; use alloc::vec::Vec; +use bumpalo::Bump; use serde::{Deserialize, Serialize}; /// The query plan for proving a query @@ -34,9 +46,9 @@ pub enum DynProofPlan { impl ProofPlan for DynProofPlan { fn count( &self, - builder: &mut crate::sql::proof::CountBuilder, - accessor: &dyn crate::base::database::MetadataAccessor, - ) -> Result<(), crate::base::proof::ProofError> { + builder: &mut CountBuilder, + accessor: &dyn MetadataAccessor, + ) -> Result<(), ProofError> { match self { DynProofPlan::Projection(expr) => expr.count(builder, accessor), DynProofPlan::GroupBy(expr) => expr.count(builder, accessor), @@ -44,7 +56,7 @@ impl ProofPlan for DynProofPlan { } } - fn get_length(&self, accessor: &dyn crate::base::database::MetadataAccessor) -> usize { + fn get_length(&self, accessor: &dyn MetadataAccessor) -> usize { match self { DynProofPlan::Projection(expr) => expr.get_length(accessor), DynProofPlan::GroupBy(expr) => expr.get_length(accessor), @@ -52,7 +64,7 @@ impl ProofPlan for DynProofPlan { } } - fn get_offset(&self, accessor: &dyn crate::base::database::MetadataAccessor) -> usize { + fn get_offset(&self, accessor: &dyn MetadataAccessor) -> usize { match self { DynProofPlan::Projection(expr) => expr.get_offset(accessor), DynProofPlan::GroupBy(expr) => expr.get_offset(accessor), @@ -63,10 +75,10 @@ impl ProofPlan for DynProofPlan { #[tracing::instrument(name = "DynProofPlan::verifier_evaluate", level = "debug", skip_all)] fn verifier_evaluate( &self, - builder: &mut crate::sql::proof::VerificationBuilder, - accessor: &dyn 
crate::base::database::CommitmentAccessor, - result: Option<&crate::base::database::OwnedTable>, - ) -> Result, crate::base::proof::ProofError> { + builder: &mut VerificationBuilder, + accessor: &dyn CommitmentAccessor, + result: Option<&OwnedTable>, + ) -> Result, ProofError> { match self { DynProofPlan::Projection(expr) => expr.verifier_evaluate(builder, accessor, result), DynProofPlan::GroupBy(expr) => expr.verifier_evaluate(builder, accessor, result), @@ -74,7 +86,7 @@ impl ProofPlan for DynProofPlan { } } - fn get_column_result_fields(&self) -> Vec { + fn get_column_result_fields(&self) -> Vec { match self { DynProofPlan::Projection(expr) => expr.get_column_result_fields(), DynProofPlan::GroupBy(expr) => expr.get_column_result_fields(), @@ -82,13 +94,21 @@ impl ProofPlan for DynProofPlan { } } - fn get_column_references(&self) -> IndexSet { + fn get_column_references(&self) -> IndexSet { match self { DynProofPlan::Projection(expr) => expr.get_column_references(), DynProofPlan::GroupBy(expr) => expr.get_column_references(), DynProofPlan::Filter(expr) => expr.get_column_references(), } } + + fn get_table_references(&self) -> IndexSet { + match self { + DynProofPlan::Projection(expr) => expr.get_table_references(), + DynProofPlan::GroupBy(expr) => expr.get_table_references(), + DynProofPlan::Filter(expr) => expr.get_table_references(), + } + } } impl ProverEvaluate for DynProofPlan { @@ -96,8 +116,8 @@ impl ProverEvaluate for DynProofPlan { fn result_evaluate<'a>( &self, input_length: usize, - alloc: &'a bumpalo::Bump, - accessor: &'a dyn crate::base::database::DataAccessor, + alloc: &'a Bump, + accessor: &'a dyn DataAccessor, ) -> Vec> { match self { DynProofPlan::Projection(expr) => expr.result_evaluate(input_length, alloc, accessor), @@ -106,7 +126,7 @@ impl ProverEvaluate for DynProofPlan { } } - fn first_round_evaluate(&self, builder: &mut crate::sql::proof::FirstRoundBuilder) { + fn first_round_evaluate(&self, builder: &mut FirstRoundBuilder) { match self { DynProofPlan::Projection(expr) => expr.first_round_evaluate(builder), DynProofPlan::GroupBy(expr) => expr.first_round_evaluate(builder), @@ -117,9 +137,9 @@ impl ProverEvaluate for DynProofPlan { #[tracing::instrument(name = "DynProofPlan::final_round_evaluate", level = "debug", skip_all)] fn final_round_evaluate<'a>( &self, - builder: &mut crate::sql::proof::FinalRoundBuilder<'a, C::Scalar>, - alloc: &'a bumpalo::Bump, - accessor: &'a dyn crate::base::database::DataAccessor, + builder: &mut FinalRoundBuilder<'a, C::Scalar>, + alloc: &'a Bump, + accessor: &'a dyn DataAccessor, ) -> Vec> { match self { DynProofPlan::Projection(expr) => expr.final_round_evaluate(builder, alloc, accessor), diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs index 4259d3d88..5a1b6106b 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs @@ -4,7 +4,7 @@ use crate::{ commitment::Commitment, database::{ filter_util::filter_columns, Column, ColumnField, ColumnRef, CommitmentAccessor, - DataAccessor, MetadataAccessor, OwnedTable, + DataAccessor, MetadataAccessor, OwnedTable, TableRef, }, map::IndexSet, proof::ProofError, @@ -139,6 +139,10 @@ where columns } + + fn get_table_references(&self) -> IndexSet { + IndexSet::from_iter([self.table.table_ref]) + } } /// Alias for a filter expression with a honest prover. 
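The new `add_subtract_literals` and `scale_and_subtract_literal` helpers introduced in the comparison_util.rs and numerical_util.rs hunks above both follow the same pattern: upscale whichever operand has the smaller decimal scale by a power of ten so both sides sit at the common (maximum) scale, then combine them. A minimal standalone sketch of that arithmetic, using plain `i128` in place of the crate's `Scalar` type; the function names below are illustrative only and not part of this patch:

// Standalone sketch of the rescale-then-combine logic used by
// `add_subtract_literals` / `scale_and_subtract_literal` above.
// Plain i128 stands in for the crate's `Scalar`; names are hypothetical.
fn pow10(exp: u8) -> i128 {
    10i128.pow(u32::from(exp))
}

/// Bring two fixed-point values to the larger of their two scales, then subtract.
fn scale_and_subtract_sketch(lhs: i128, lhs_scale: i8, rhs: i128, rhs_scale: i8) -> i128 {
    let max_scale = lhs_scale.max(rhs_scale);
    // Only the side with the smaller scale is multiplied up; the other factor is 10^0 = 1.
    let lhs_scaled = lhs * pow10((max_scale - lhs_scale) as u8);
    let rhs_scaled = rhs * pow10((max_scale - rhs_scale) as u8);
    lhs_scaled - rhs_scaled
}

fn main() {
    // 1.5 (15 at scale 1) minus 0.25 (25 at scale 2) is 1.25, i.e. 125 at scale 2.
    assert_eq!(scale_and_subtract_sketch(15, 1, 25, 2), 125);
}

The real helpers additionally type-check the operands and verify that the upscaled precision still fits, which is why `scale_and_subtract_literal` returns a `ConversionResult` rather than a bare scalar.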
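The other recurring change in this part of the diff is the new `ProofPlan::get_table_references` method: each single-table plan (`FilterExec`, `GroupByExec`, `ProjectionExec`) reports the one table it scans, and `DynProofPlan` simply delegates to its variant. A rough sketch of that pattern and of how a caller might aggregate the result over several plans; the `MyScanExec` type and the free helper below are hypothetical and not part of this patch:

use crate::base::{database::TableRef, map::IndexSet};

// Hypothetical single-table plan mirroring the pattern added to FilterExec and friends.
struct MyScanExec {
    table_ref: TableRef,
}

impl MyScanExec {
    fn get_table_references(&self) -> IndexSet<TableRef> {
        // A plan that scans one table reports exactly that table; a join- or
        // union-style plan would instead return the union of its children's sets.
        IndexSet::from_iter([self.table_ref])
    }
}

// Illustrative caller: gather every table touched by a batch of plans, e.g. so a
// verifier could fetch all required table metadata or commitments up front.
fn all_referenced_tables(plans: &[MyScanExec]) -> IndexSet<TableRef> {
    let mut tables = IndexSet::default();
    for plan in plans {
        tables.extend(plan.get_table_references());
    }
    tables
}

The tests touched below (filter_exec_test.rs and projection_exec_test.rs) exercise exactly this contract by asserting that the returned set equals the single `table_ref` the plan was built from.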
diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs index c6252d133..062781985 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs @@ -153,6 +153,10 @@ fn we_can_correctly_fetch_all_the_referenced_columns() { ) ]) ); + + let ref_tables = provable_ast.get_table_references(); + + assert_eq!(ref_tables, IndexSet::from_iter([table_ref])); } #[test] diff --git a/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs index 0a43da82f..385b8a2e7 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs @@ -7,7 +7,7 @@ use crate::{ aggregate_columns, compare_indexes_by_owned_columns, AggregatedColumns, }, Column, ColumnField, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor, - MetadataAccessor, OwnedTable, + MetadataAccessor, OwnedTable, TableRef, }, map::IndexSet, proof::ProofError, @@ -202,6 +202,10 @@ impl ProofPlan for GroupByExec { columns } + + fn get_table_references(&self) -> IndexSet { + IndexSet::from_iter([self.table.table_ref]) + } } impl ProverEvaluate for GroupByExec { diff --git a/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs index fb66bff00..f3038b310 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs @@ -3,7 +3,7 @@ use crate::{ commitment::Commitment, database::{ Column, ColumnField, ColumnRef, CommitmentAccessor, DataAccessor, MetadataAccessor, - OwnedTable, + OwnedTable, TableRef, }, map::IndexSet, proof::ProofError, @@ -92,6 +92,10 @@ impl ProofPlan for ProjectionExec { }); columns } + + fn get_table_references(&self) -> IndexSet { + IndexSet::from_iter([self.table.table_ref]) + } } impl ProverEvaluate for ProjectionExec { diff --git a/crates/proof-of-sql/src/sql/proof_plans/projection_exec_test.rs b/crates/proof-of-sql/src/sql/proof_plans/projection_exec_test.rs index 3addcfb17..c97ecf471 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/projection_exec_test.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/projection_exec_test.rs @@ -102,6 +102,10 @@ fn we_can_correctly_fetch_all_the_referenced_columns() { ), ]) ); + + let ref_tables = provable_ast.get_table_references(); + + assert_eq!(ref_tables, IndexSet::from_iter([table_ref])); } #[test]