Skip to content

Commit

Permalink
feat: make polars postprocessing optional behind feature flag
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Jul 8, 2024
1 parent 962088a commit 30ae609
Show file tree
Hide file tree
Showing 14 changed files with 172 additions and 65 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/lint-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ jobs:
run: cargo check -p proof-of-sql --no-default-features --features="test"
- name: Run cargo check (proof-of-sql) (just "blitzar" feature)
run: cargo check -p proof-of-sql --no-default-features --features="blitzar"
- name: Run cargo check (proof-of-sql) (just "polars" feature)
run: cargo check -p proof-of-sql --no-default-features --features="polars"
- name: Run cargo check (proof-of-sql) (no "test" feature)
run: cargo check -p proof-of-sql --no-default-features --features="blitzar polars"
- name: Run cargo check (proof-of-sql) (no "blitzar" feature)
run: cargo check -p proof-of-sql --no-default-features --features="test polars"
- name: Run cargo check (proof-of-sql) (no "polars" feature)
run: cargo check -p proof-of-sql --no-default-features --features="blitzar test"

test:
name: Test Suite
Expand Down
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ bigdecimal = { version = "0.4.5", features = ["serde"] }
blake3 = { version = "1.3.3" }
blitzar = { version = "3.0.2" }
bumpalo = { version = "3.11.0" }
bytemuck = {version = "1.14.2" }
bytemuck = {version = "1.14.2", features = ["derive"] }
byte-slice-cast = { version = "1.2.1" }
clap = { version = "4.5.4" }
criterion = { version = "0.5.1" }
Expand All @@ -35,7 +35,7 @@ derive_more = { version = "0.99" }
dyn_partial_eq = { version = "0.1.2" }
flexbuffers = { version = "2.0.0" }
hashbrown = { version = "0.14.0" }
indexmap = { version = "2.1" }
indexmap = { version = "2.1", features = ["serde"] }
itertools = { version = "0.13.0" }
lalrpop-util = { version = "0.20.0" }
lazy_static = { version = "1.4.0" }
Expand Down
6 changes: 3 additions & 3 deletions crates/proof-of-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ lazy_static = { workspace = true }
merlin = { workspace = true }
num-traits = { workspace = true }
num-bigint = { workspace = true, default-features = false }
polars = { workspace = true, features = ["lazy", "bigidx", "dtype-decimal", "serde-lazy"] }
polars = { workspace = true, features = ["lazy", "bigidx", "dtype-decimal", "serde-lazy"], optional = true }
postcard = { workspace = true, features = ["alloc"] }
proof-of-sql-parser = { workspace = true }
rand = { workspace = true, optional = true }
Expand All @@ -59,7 +59,7 @@ clap = { workspace = true, features = ["derive"] }
criterion = { workspace = true, features = ["html_reports"] }
opentelemetry = { workspace = true }
opentelemetry-jaeger = { workspace = true }
polars = { workspace = true, features = ["lazy"] }
polars = { workspace = true, features = ["lazy", "dtype-decimal"] }
rand = { workspace = true }
rand_core = { workspace = true }
serde_json = { workspace = true }
Expand All @@ -69,7 +69,7 @@ tracing-subscriber = { workspace = true }
flexbuffers = { workspace = true }

[features]
default = ["blitzar"]
default = ["blitzar", "polars"]
test = ["dep:rand"]

[lints]
Expand Down
2 changes: 2 additions & 0 deletions crates/proof-of-sql/src/base/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ pub use table_ref::TableRef;
mod arrow_array_to_column_conversion;
pub use arrow_array_to_column_conversion::{ArrayRefExt, ArrowArrayToColumnConversionError};

#[cfg(any(test, feature = "polars"))]
mod record_batch_dataframe_conversion;
#[cfg(any(test, feature = "polars"))]
pub(crate) use record_batch_dataframe_conversion::{
dataframe_to_record_batch, record_batch_to_dataframe,
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#[cfg(any(test, feature = "polars"))]
use super::{dataframe_to_record_batch, record_batch_to_dataframe};
use super::{
dataframe_to_record_batch, record_batch_to_dataframe, ArrayRefExt, Column, ColumnRef,
ColumnType, CommitmentAccessor, DataAccessor, MetadataAccessor, SchemaAccessor, TableRef,
TestAccessor,
ArrayRefExt, Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor, MetadataAccessor,
SchemaAccessor, TableRef, TestAccessor,
};
use crate::base::scalar::{compute_commitment_for_testing, Curve25519Scalar};
use arrow::{array::ArrayRef, datatypes::DataType, record_batch::RecordBatch};
use bumpalo::Bump;
use curve25519_dalek::ristretto::RistrettoPoint;
use indexmap::IndexMap;
#[cfg(any(test, feature = "polars"))]
use polars::prelude::DataFrame;
use proof_of_sql_parser::Identifier;
use std::collections::HashMap;
Expand Down Expand Up @@ -114,6 +116,7 @@ impl TestAccessor<RistrettoPoint> for RecordBatchTestAccessor {

impl RecordBatchTestAccessor {
/// Apply a query function to table and then convert the result to a RecordBatch
#[cfg(any(test, feature = "polars"))]
pub fn query_table(
&self,
table_ref: TableRef,
Expand Down
7 changes: 2 additions & 5 deletions crates/proof-of-sql/src/sql/parse/result_expr_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,10 @@ impl ResultExprBuilder {
.iter()
.map(|aliased_expr| Expression::Column(aliased_expr.alias))
.collect();
self.composition
.add(Box::new(SelectExpr::new_from_expressions(&exprs)));
self.composition.add(Box::new(SelectExpr::new(&exprs)));
} else {
self.composition
.add(Box::new(SelectExpr::new_from_aliased_result_exprs(
aliased_exprs,
)));
.add(Box::new(SelectExpr::new(aliased_exprs)));
}
self
}
Expand Down
48 changes: 32 additions & 16 deletions crates/proof-of-sql/src/sql/transform/group_by_expr.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#[allow(deprecated)]
#[cfg(feature = "polars")]
use super::DataFrameExpr;
#[cfg(feature = "polars")]
use super::{ToPolarsExpr, INT128_PRECISION, INT128_SCALE};
use dyn_partial_eq::DynPartialEq;
#[cfg(feature = "polars")]
use polars::prelude::{col, DataType, Expr, GetOutput, LazyFrame, NamedFrom, Series};
use proof_of_sql_parser::{intermediate_ast::AliasedResultExpr, Identifier};
use serde::{Deserialize, Serialize};
Expand All @@ -10,56 +13,67 @@ use serde::{Deserialize, Serialize};
#[derive(Debug, DynPartialEq, PartialEq, Serialize, Deserialize)]
pub struct GroupByExpr {
/// A list of aggregation column expressions
agg_exprs: Vec<Expr>,
aliased_exprs: Vec<AliasedResultExpr>,

/// A list of group by column expressions
by_exprs: Vec<Expr>,
by_ids: Vec<Identifier>,
}

impl GroupByExpr {
/// Create a new group by expression containing the group by and aggregation expressions
pub fn new(by_ids: &[Identifier], aliased_exprs: &[AliasedResultExpr]) -> Self {
let by_exprs = Vec::from_iter(by_ids.iter().map(|id| col(id.as_str())));
let agg_exprs = Vec::from_iter(aliased_exprs.iter().map(ToPolarsExpr::to_polars_expr));
assert!(!agg_exprs.is_empty(), "Agg expressions must not be empty");
assert!(
!by_exprs.is_empty(),
"Group by expressions must not be empty"
!aliased_exprs.is_empty(),
"Agg expressions must not be empty"
);
assert!(!by_ids.is_empty(), "Group by expressions must not be empty");

Self {
by_exprs,
agg_exprs,
by_ids: by_ids.to_vec(),
aliased_exprs: aliased_exprs.to_vec(),
}
}

/// Builds the polars aggregation expressions from the stored aliased result
/// expressions: one `Expr` per entry, in the same order as `aliased_exprs`.
#[cfg(feature = "polars")]
fn agg_exprs(&self) -> Vec<Expr> {
    let mut polars_exprs = Vec::with_capacity(self.aliased_exprs.len());
    for aliased in &self.aliased_exprs {
        polars_exprs.push(ToPolarsExpr::to_polars_expr(aliased));
    }
    polars_exprs
}
}

// Built only when the `polars` feature is disabled: the impl body is empty, so
// `GroupByExpr` uses the trait's default `apply_transformation`, which returns `None`.
#[cfg(not(feature = "polars"))]
#[typetag::serde]
impl super::RecordBatchExpr for GroupByExpr {}
#[cfg(feature = "polars")]
super::impl_record_batch_expr_for_data_frame_expr!(GroupByExpr);
#[allow(deprecated)]
#[cfg(feature = "polars")]
impl DataFrameExpr for GroupByExpr {
fn lazy_transformation(&self, lazy_frame: LazyFrame, num_input_rows: usize) -> LazyFrame {
// TODO: polars currently lacks support for min/max aggregation in data frames
// with either zero or one element when a group by operation is applied.
// We remove the group by clause to temporarily work around this limitation.
// Issue created to track progress: https://github.com/pola-rs/polars/issues/11232
if num_input_rows == 0 {
return lazy_frame.select(&self.agg_exprs).limit(0);
return lazy_frame.select(&self.agg_exprs()).limit(0);
}

if num_input_rows == 1 {
return lazy_frame.select(&self.agg_exprs);
return lazy_frame.select(&self.agg_exprs());
}

// Add invalid column aliases to group by expressions so that we can
// exclude them from the final result.
let by_expr_aliases = (0..self.by_exprs.len())
let by_expr_aliases = (0..self.by_ids.len())
.map(|pos| "#$".to_owned() + pos.to_string().as_str())
.collect::<Vec<_>>();

let by_exprs: Vec<_> = self
.by_exprs
.clone()
.into_iter()
.by_ids
.iter()
.map(|id| col(id.as_str()))
.zip(by_expr_aliases.iter())
.map(|(expr, alias)| expr.alias(alias.as_str()))
// TODO: remove this mapping once Polars supports decimal columns inside group by
Expand All @@ -71,11 +85,12 @@ impl DataFrameExpr for GroupByExpr {
// to avoid non-deterministic results with our tests.
lazy_frame
.group_by_stable(&by_exprs)
.agg(&self.agg_exprs)
.agg(&self.agg_exprs())
.select(&[col("*").exclude(by_expr_aliases)])
}
}

#[cfg(any(test, feature = "polars"))]
pub(crate) fn group_by_map_i128_to_utf8(v: i128) -> String {
// use big-endian byte order to allow
// skipping leading zeros
Expand All @@ -99,6 +114,7 @@ pub(crate) fn group_by_map_i128_to_utf8(v: i128) -> String {

// Polars doesn't support Decimal columns inside group by.
// So we need to remap them to the supported UTF8 type.
#[cfg(feature = "polars")]
fn group_by_map_to_utf8_if_decimal(expr: Expr) -> Expr {
expr.map(
|series| match series.dtype().clone() {
Expand Down
11 changes: 11 additions & 0 deletions crates/proof-of-sql/src/sql/transform/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
//! This module contains postprocessing for non-provable components.
/// The precision for [ColumnType::INT128] values
#[cfg(feature = "polars")]
pub const INT128_PRECISION: usize = 38;

/// The scale for [ColumnType::INT128] values
#[cfg(feature = "polars")]
pub const INT128_SCALE: usize = 0;

mod result_expr;
Expand All @@ -17,10 +19,13 @@ pub use composition_expr::CompositionExpr;
#[cfg(test)]
pub mod composition_expr_test;

#[cfg(feature = "polars")]
mod data_frame_expr;
#[allow(deprecated)]
#[cfg(feature = "polars")]
pub(crate) use data_frame_expr::DataFrameExpr;
mod record_batch_expr;
#[cfg(feature = "polars")]
pub(crate) use record_batch_expr::impl_record_batch_expr_for_data_frame_expr;
pub use record_batch_expr::RecordBatchExpr;

Expand Down Expand Up @@ -53,10 +58,16 @@ pub use group_by_expr::GroupByExpr;
#[cfg(test)]
mod group_by_expr_test;

#[cfg(feature = "polars")]
mod polars_conversions;
#[cfg(feature = "polars")]
pub use polars_conversions::LiteralConversion;

#[cfg(feature = "polars")]
mod polars_arithmetic;
#[cfg(feature = "polars")]
pub use polars_arithmetic::SafeDivision;
#[cfg(feature = "polars")]
mod to_polars_expr;
#[cfg(feature = "polars")]
pub(crate) use to_polars_expr::ToPolarsExpr;
15 changes: 14 additions & 1 deletion crates/proof-of-sql/src/sql/transform/order_by_exprs.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
#[allow(deprecated)]
#[cfg(feature = "polars")]
use super::DataFrameExpr;
#[cfg(feature = "polars")]
use super::{INT128_PRECISION, INT128_SCALE};
#[cfg(any(test, feature = "polars"))]
use arrow::datatypes::ArrowNativeTypeOp;
use dyn_partial_eq::DynPartialEq;
#[cfg(feature = "polars")]
use polars::prelude::{col, DataType, Expr, GetOutput, LazyFrame, NamedFrom, Series};
use proof_of_sql_parser::intermediate_ast::{OrderBy, OrderByDirection};
use proof_of_sql_parser::intermediate_ast::OrderBy;
#[cfg(feature = "polars")]
use proof_of_sql_parser::intermediate_ast::OrderByDirection;
use serde::{Deserialize, Serialize};

/// A node representing a list of `OrderBy` expressions.
Expand All @@ -20,8 +26,13 @@ impl OrderByExprs {
}
}

// Built only when the `polars` feature is disabled: the impl body is empty, so
// `OrderByExprs` uses the trait's default `apply_transformation`, which returns `None`.
#[cfg(not(feature = "polars"))]
#[typetag::serde]
impl super::RecordBatchExpr for OrderByExprs {}
#[cfg(feature = "polars")]
super::impl_record_batch_expr_for_data_frame_expr!(OrderByExprs);
#[allow(deprecated)]
#[cfg(feature = "polars")]
impl DataFrameExpr for OrderByExprs {
/// Sort the `LazyFrame` by the `OrderBy` expressions.
fn lazy_transformation(&self, lazy_frame: LazyFrame, _: usize) -> LazyFrame {
Expand Down Expand Up @@ -51,6 +62,7 @@ impl DataFrameExpr for OrderByExprs {
/// * `a < b` if and only if `map_i128_to_utf8(a) < map_i128_to_utf8(b)`.
/// * `a == b` if and only if `map_i128_to_utf8(a) == map_i128_to_utf8(b)`.
/// * `a > b` if and only if `map_i128_to_utf8(a) > map_i128_to_utf8(b)`.
#[cfg(any(test, feature = "polars"))]
pub(crate) fn order_by_map_i128_to_utf8(v: i128) -> String {
let is_neg = v.is_negative() as u8;
v.abs()
Expand Down Expand Up @@ -78,6 +90,7 @@ pub(crate) fn order_by_map_i128_to_utf8(v: i128) -> String {

// Polars doesn't support Decimal columns inside order by.
// So we need to remap them to the supported UTF8 type.
#[cfg(feature = "polars")]
fn order_by_map_to_utf8_if_decimal(expr: Expr) -> Expr {
expr.map(
|series| match series.dtype().clone() {
Expand Down
7 changes: 6 additions & 1 deletion crates/proof-of-sql/src/sql/transform/record_batch_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ use std::fmt::Debug;
#[dyn_partial_eq]
pub trait RecordBatchExpr: Debug + Send + Sync {
/// Apply the transformation to the `RecordBatch` and return the result.
fn apply_transformation(&self, record_batch: RecordBatch) -> Option<RecordBatch>;
#[allow(unused_variables)]
fn apply_transformation(&self, record_batch: RecordBatch) -> Option<RecordBatch> {
None
}
}

#[cfg(feature = "polars")]
macro_rules! impl_record_batch_expr_for_data_frame_expr {
($t:ty) => {
#[typetag::serde]
Expand All @@ -29,4 +33,5 @@ macro_rules! impl_record_batch_expr_for_data_frame_expr {
};
}

#[cfg(feature = "polars")]
pub(crate) use impl_record_batch_expr_for_data_frame_expr;
10 changes: 6 additions & 4 deletions crates/proof-of-sql/src/sql/transform/result_expr.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::{
base::database::{dataframe_to_record_batch, record_batch_to_dataframe},
sql::transform::RecordBatchExpr,
};
#[cfg(feature = "polars")]
use crate::base::database::{dataframe_to_record_batch, record_batch_to_dataframe};
use crate::sql::transform::RecordBatchExpr;
use arrow::record_batch::RecordBatch;
use dyn_partial_eq::DynPartialEq;
#[cfg(feature = "polars")]
use polars::prelude::{IntoLazy, LazyFrame};
use serde::{Deserialize, Serialize};

Expand All @@ -23,11 +23,13 @@ impl ResultExpr {
}
}

/// Converts a `RecordBatch` into a polars `LazyFrame`, paired with the number
/// of rows the input batch held.
///
/// Returns `None` when `record_batch_to_dataframe` yields `None`.
#[cfg(feature = "polars")]
pub(super) fn record_batch_to_lazy_frame(result_batch: RecordBatch) -> Option<(LazyFrame, usize)> {
    // Capture the row count first: the batch is moved into the conversion below.
    let row_count = result_batch.num_rows();
    record_batch_to_dataframe(result_batch).map(|data_frame| (data_frame.lazy(), row_count))
}
#[cfg(feature = "polars")]
pub(super) fn lazy_frame_to_record_batch(lazy_frame: LazyFrame) -> Option<RecordBatch> {
// We're currently excluding NULLs in post-processing due to a lack of
// prover support, aiming to avoid future complexities.
Expand Down
Loading

0 comments on commit 30ae609

Please sign in to comment.