From 5165e5e67d14be17b9ee4713dd6f6b01d29e963a Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Tue, 17 Dec 2024 12:25:50 -0800 Subject: [PATCH] chore: move symbolic and boolean algebra code into new crate (#3570) Moving some of the code we have around symbolic/boolean algebra on expressions into its own crate, because I anticipate that we will be building more of this kind of thing, so it would be nicer to consolidate it as well as make it easier to reuse. It also allows us to better test these things in isolation of the context they are being used in. For example, we'll be building some optimization rules that more intelligently finds predicates for filter pushdown into joins, and that may use both `split_conjunction` as well as some expression simplification logic. Also took this opportunity to fix a typo (conjuct -> conjunct) and rename `conjunct` (which is an adjective) to `combine_conjunction`. Otherwise everything else is pretty much a straightforward move --- Cargo.lock | 15 + Cargo.toml | 2 + src/common/scan-info/Cargo.toml | 1 + src/common/scan-info/src/expr_rewriter.rs | 9 +- src/daft-algebra/Cargo.toml | 16 + src/daft-algebra/src/boolean.rs | 24 + src/daft-algebra/src/lib.rs | 4 + src/daft-algebra/src/simplify.rs | 465 ++++++++++++++++++ src/daft-dsl/src/optimization.rs | 30 +- src/daft-logical-plan/Cargo.toml | 1 + .../optimization/rules/push_down_filter.rs | 37 +- .../rules/simplify_expressions.rs | 463 +---------------- .../src/optimization/rules/unnest_subquery.rs | 27 +- src/daft-sql/Cargo.toml | 1 + src/daft-sql/src/planner.rs | 10 +- 15 files changed, 570 insertions(+), 535 deletions(-) create mode 100644 src/daft-algebra/Cargo.toml create mode 100644 src/daft-algebra/src/boolean.rs create mode 100644 src/daft-algebra/src/lib.rs create mode 100644 src/daft-algebra/src/simplify.rs diff --git a/Cargo.lock b/Cargo.lock index 3011a56b24..048324ca60 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1577,6 +1577,7 @@ dependencies = [ "common-display", "common-error", "common-file-formats", + "daft-algebra", "daft-dsl", "daft-schema", "pyo3", @@ -1905,6 +1906,7 @@ dependencies = [ "common-system-info", "common-tracing", "common-version", + "daft-algebra", "daft-catalog", "daft-catalog-python-catalog", "daft-compression", @@ -1941,6 +1943,17 @@ dependencies = [ "tikv-jemallocator", ] +[[package]] +name = "daft-algebra" +version = "0.3.0-dev0" +dependencies = [ + "common-error", + "common-treenode", + "daft-dsl", + "daft-schema", + "rstest", +] + [[package]] name = "daft-catalog" version = "0.3.0-dev0" @@ -2329,6 +2342,7 @@ dependencies = [ "common-resource-request", "common-scan-info", "common-treenode", + "daft-algebra", "daft-core", "daft-dsl", "daft-functions", @@ -2536,6 +2550,7 @@ dependencies = [ "common-error", "common-io-config", "common-runtime", + "daft-algebra", "daft-core", "daft-dsl", "daft-functions", diff --git a/Cargo.toml b/Cargo.toml index d5a5cf218d..ba00e45cd4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ common-scan-info = {path = "src/common/scan-info", default-features = false} common-system-info = {path = "src/common/system-info", default-features = false} common-tracing = {path = "src/common/tracing", default-features = false} common-version = {path = "src/common/version", default-features = false} +daft-algebra = {path = "src/daft-algebra", default-features = false} daft-catalog = {path = "src/daft-catalog", default-features = false} daft-catalog-python-catalog = {path = "src/daft-catalog/python-catalog", optional = true} daft-compression = {path = "src/daft-compression", default-features = false} @@ -149,6 +150,7 @@ members = [ "src/common/scan-info", "src/common/system-info", "src/common/treenode", + "src/daft-algebra", "src/daft-catalog", "src/daft-core", "src/daft-csv", diff --git a/src/common/scan-info/Cargo.toml b/src/common/scan-info/Cargo.toml index 04f9550997..0aecf55f6e 100644 --- a/src/common/scan-info/Cargo.toml +++ b/src/common/scan-info/Cargo.toml @@ -3,6 +3,7 @@ common-daft-config = {path = "../daft-config", default-features = false} common-display = {path = "../display", default-features = false} common-error = {path = "../error", default-features = false} common-file-formats = {path = "../file-formats", default-features = false} +daft-algebra = {path = "../../daft-algebra", default-features = false} daft-dsl = {path = "../../daft-dsl", default-features = false} daft-schema = {path = "../../daft-schema", default-features = false} pyo3 = {workspace = true, optional = true} diff --git a/src/common/scan-info/src/expr_rewriter.rs b/src/common/scan-info/src/expr_rewriter.rs index f678ad07c1..fedf212a18 100644 --- a/src/common/scan-info/src/expr_rewriter.rs +++ b/src/common/scan-info/src/expr_rewriter.rs @@ -1,13 +1,12 @@ use std::collections::HashMap; use common_error::DaftResult; +use daft_algebra::boolean::split_conjunction; use daft_dsl::{ col, common_treenode::{Transformed, TreeNode, TreeNodeRecursion}, functions::{partitioning, FunctionExpr}, - null_lit, - optimization::split_conjuction, - Expr, ExprRef, Operator, + null_lit, Expr, ExprRef, Operator, }; use crate::{PartitionField, PartitionTransform}; @@ -93,7 +92,7 @@ pub fn rewrite_predicate_for_partitioning( // Before rewriting predicate for partition filter pushdown, partition predicate clauses into groups that will need // to be applied at the data level (i.e. any clauses that aren't pure partition predicates with identity // transformations). - let data_split = split_conjuction(predicate); + let data_split = split_conjunction(predicate); // Predicates that reference both partition columns and data columns. let mut needs_filter_op_preds: Vec = vec![]; // Predicates that only reference data columns (no partition column references) or only reference partition columns @@ -332,7 +331,7 @@ pub fn rewrite_predicate_for_partitioning( let with_part_cols = with_part_cols.data; // Filter to predicate clauses that only involve partition columns. - let split = split_conjuction(&with_part_cols); + let split = split_conjunction(&with_part_cols); let mut part_preds: Vec = vec![]; for e in split { let mut all_part_keys = true; diff --git a/src/daft-algebra/Cargo.toml b/src/daft-algebra/Cargo.toml new file mode 100644 index 0000000000..89e942c700 --- /dev/null +++ b/src/daft-algebra/Cargo.toml @@ -0,0 +1,16 @@ +[dependencies] +common-error = {path = "../common/error", default-features = false} +common-treenode = {path = "../common/treenode", default-features = false} +daft-dsl = {path = "../daft-dsl", default-features = false} +daft-schema = {path = "../daft-schema", default-features = false} + +[dev-dependencies] +rstest = {workspace = true} + +[lints] +workspace = true + +[package] +edition = {workspace = true} +name = "daft-algebra" +version = {workspace = true} diff --git a/src/daft-algebra/src/boolean.rs b/src/daft-algebra/src/boolean.rs new file mode 100644 index 0000000000..38f659e00c --- /dev/null +++ b/src/daft-algebra/src/boolean.rs @@ -0,0 +1,24 @@ +use common_treenode::{TreeNode, TreeNodeRecursion}; +use daft_dsl::{Expr, ExprRef, Operator}; + +pub fn split_conjunction(expr: &ExprRef) -> Vec { + let mut splits = vec![]; + + expr.apply(|e| match e.as_ref() { + Expr::BinaryOp { + op: Operator::And, .. + } + | Expr::Alias(..) => Ok(TreeNodeRecursion::Continue), + _ => { + splits.push(e.clone()); + Ok(TreeNodeRecursion::Jump) + } + }) + .unwrap(); + + splits +} + +pub fn combine_conjunction>(exprs: T) -> Option { + exprs.into_iter().reduce(|acc, e| acc.and(e)) +} diff --git a/src/daft-algebra/src/lib.rs b/src/daft-algebra/src/lib.rs new file mode 100644 index 0000000000..317ef5eea1 --- /dev/null +++ b/src/daft-algebra/src/lib.rs @@ -0,0 +1,4 @@ +pub mod boolean; +mod simplify; + +pub use simplify::simplify_expr; diff --git a/src/daft-algebra/src/simplify.rs b/src/daft-algebra/src/simplify.rs new file mode 100644 index 0000000000..698a48fa1b --- /dev/null +++ b/src/daft-algebra/src/simplify.rs @@ -0,0 +1,465 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use common_treenode::Transformed; +use daft_dsl::{lit, null_lit, Expr, ExprRef, LiteralValue, Operator}; +use daft_schema::{dtype::DataType, schema::SchemaRef}; + +pub fn simplify_expr(expr: Expr, schema: &SchemaRef) -> DaftResult> { + Ok(match expr { + // ---------------- + // Eq + // ---------------- + // true = A --> A + // false = A --> !A + Expr::BinaryOp { + op: Operator::Eq, + left, + right, + } + // A = true --> A + // A = false --> !A + | Expr::BinaryOp { + op: Operator::Eq, + left: right, + right: left, + } if is_bool_lit(&left) && is_bool_type(&right, schema) => { + Transformed::yes(match as_bool_lit(&left) { + Some(true) => right, + Some(false) => right.not(), + None => unreachable!(), + }) + } + + // null = A --> null + // A = null --> null + Expr::BinaryOp { + op: Operator::Eq, + left, + right, + } + | Expr::BinaryOp { + op: Operator::Eq, + left: right, + right: left, + } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), + + // ---------------- + // Neq + // ---------------- + // true != A --> !A + // false != A --> A + Expr::BinaryOp { + op: Operator::NotEq, + left, + right, + } + // A != true --> !A + // A != false --> A + | Expr::BinaryOp { + op: Operator::NotEq, + left: right, + right: left, + } if is_bool_lit(&left) && is_bool_type(&right, schema) => { + Transformed::yes(match as_bool_lit(&left) { + Some(true) => right.not(), + Some(false) => right, + None => unreachable!(), + }) + } + + // null != A --> null + // A != null --> null + Expr::BinaryOp { + op: Operator::NotEq, + left, + right, + } + | Expr::BinaryOp { + op: Operator::NotEq, + left: right, + right: left, + } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), + + // ---------------- + // OR + // ---------------- + + // true OR A --> true + Expr::BinaryOp { + op: Operator::Or, + left, + right: _, + } if is_true(&left) => Transformed::yes(left), + // false OR A --> A + Expr::BinaryOp { + op: Operator::Or, + left, + right, + } if is_false(&left) => Transformed::yes(right), + // A OR true --> true + Expr::BinaryOp { + op: Operator::Or, + left: _, + right, + } if is_true(&right) => Transformed::yes(right), + // A OR false --> A + Expr::BinaryOp { + left, + op: Operator::Or, + right, + } if is_false(&right) => Transformed::yes(left), + + // ---------------- + // AND (TODO) + // ---------------- + + // ---------------- + // Multiplication + // ---------------- + + // A * 1 --> A + // 1 * A --> A + Expr::BinaryOp { + op: Operator::Multiply, + left, + right, + }| Expr::BinaryOp { + op: Operator::Multiply, + left: right, + right: left, + } if is_one(&right) => Transformed::yes(left), + + // A * null --> null + Expr::BinaryOp { + op: Operator::Multiply, + left: _, + right, + } if is_null(&right) => Transformed::yes(right), + // null * A --> null + Expr::BinaryOp { + op: Operator::Multiply, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + + // TODO: Can't do this one because we don't have a way to determine if an expr potentially contains nulls (nullable) + // A * 0 --> 0 (if A is not null and not floating/decimal) + // 0 * A --> 0 (if A is not null and not floating/decimal) + + // ---------------- + // Division + // ---------------- + // A / 1 --> A + Expr::BinaryOp { + op: Operator::TrueDivide, + left, + right, + } if is_one(&right) => Transformed::yes(left), + // null / A --> null + Expr::BinaryOp { + op: Operator::TrueDivide, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + // A / null --> null + Expr::BinaryOp { + op: Operator::TrueDivide, + left: _, + right, + } if is_null(&right) => Transformed::yes(right), + + // ---------------- + // Addition + // ---------------- + // A + 0 --> A + Expr::BinaryOp { + op: Operator::Plus, + left, + right, + } if is_zero(&right) => Transformed::yes(left), + + // 0 + A --> A + Expr::BinaryOp { + op: Operator::Plus, + left, + right, + } if is_zero(&left) => Transformed::yes(right), + + // ---------------- + // Subtraction + // ---------------- + + // A - 0 --> A + Expr::BinaryOp { + op: Operator::Minus, + left, + right, + } if is_zero(&right) => Transformed::yes(left), + + // A - null --> null + Expr::BinaryOp { + op: Operator::Minus, + left: _, + right, + } if is_null(&right) => Transformed::yes(right), + // null - A --> null + Expr::BinaryOp { + op: Operator::Minus, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + + // ---------------- + // Modulus + // ---------------- + + // A % null --> null + Expr::BinaryOp { + op: Operator::Modulus, + left: _, + right, + } if is_null(&right) => Transformed::yes(right), + + // null % A --> null + Expr::BinaryOp { + op: Operator::Modulus, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + + // A BETWEEN low AND high --> A >= low AND A <= high + Expr::Between(expr, low, high) => { + Transformed::yes(expr.clone().lt_eq(high).and(expr.gt_eq(low))) + } + Expr::Not(expr) => match Arc::unwrap_or_clone(expr) { + // NOT (BETWEEN A AND B) --> A < low OR A > high + Expr::Between(expr, low, high) => { + Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) + } + // expr NOT IN () --> true + Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(true)), + + expr => { + let expr = simplify_expr(expr, schema)?; + if expr.transformed { + Transformed::yes(expr.data.not()) + } else { + Transformed::no(expr.data.not()) + } + } + }, + // expr IN () --> false + Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(false)), + + other => Transformed::no(Arc::new(other)), + }) +} + +fn is_zero(s: &Expr) -> bool { + match s { + Expr::Literal(LiteralValue::Int32(0)) + | Expr::Literal(LiteralValue::Int64(0)) + | Expr::Literal(LiteralValue::UInt32(0)) + | Expr::Literal(LiteralValue::UInt64(0)) + | Expr::Literal(LiteralValue::Float64(0.)) => true, + Expr::Literal(LiteralValue::Decimal(v, _p, _s)) if *v == 0 => true, + _ => false, + } +} + +fn is_one(s: &Expr) -> bool { + match s { + Expr::Literal(LiteralValue::Int32(1)) + | Expr::Literal(LiteralValue::Int64(1)) + | Expr::Literal(LiteralValue::UInt32(1)) + | Expr::Literal(LiteralValue::UInt64(1)) + | Expr::Literal(LiteralValue::Float64(1.)) => true, + + Expr::Literal(LiteralValue::Decimal(v, _p, s)) => { + *s >= 0 && POWS_OF_TEN.get(*s as usize).is_some_and(|pow| v == pow) + } + _ => false, + } +} + +fn is_true(expr: &Expr) -> bool { + match expr { + Expr::Literal(LiteralValue::Boolean(v)) => *v, + _ => false, + } +} +fn is_false(expr: &Expr) -> bool { + match expr { + Expr::Literal(LiteralValue::Boolean(v)) => !*v, + _ => false, + } +} + +/// returns true if expr is a +/// `Expr::Literal(LiteralValue::Boolean(v))` , false otherwise +fn is_bool_lit(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(LiteralValue::Boolean(_))) +} + +fn is_bool_type(expr: &Expr, schema: &SchemaRef) -> bool { + matches!(expr.get_type(schema), Ok(DataType::Boolean)) +} + +fn as_bool_lit(expr: &Expr) -> Option { + expr.as_literal().and_then(|l| l.as_bool()) +} + +fn is_null(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(LiteralValue::Null)) +} + +static POWS_OF_TEN: [i128; 38] = [ + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000, + 10000000000, + 100000000000, + 1000000000000, + 10000000000000, + 100000000000000, + 1000000000000000, + 10000000000000000, + 100000000000000000, + 1000000000000000000, + 10000000000000000000, + 100000000000000000000, + 1000000000000000000000, + 10000000000000000000000, + 100000000000000000000000, + 1000000000000000000000000, + 10000000000000000000000000, + 100000000000000000000000000, + 1000000000000000000000000000, + 10000000000000000000000000000, + 100000000000000000000000000000, + 1000000000000000000000000000000, + 10000000000000000000000000000000, + 100000000000000000000000000000000, + 1000000000000000000000000000000000, + 10000000000000000000000000000000000, + 100000000000000000000000000000000000, + 1000000000000000000000000000000000000, + 10000000000000000000000000000000000000, +]; + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use common_error::DaftResult; + use daft_dsl::{col, lit, null_lit, ExprRef}; + use daft_schema::{ + dtype::DataType, + field::Field, + schema::{Schema, SchemaRef}, + }; + use rstest::{fixture, rstest}; + + use crate::simplify_expr; + + #[fixture] + fn schema() -> SchemaRef { + Arc::new( + Schema::new(vec![ + Field::new("bool", DataType::Boolean), + Field::new("int", DataType::Int32), + ]) + .unwrap(), + ) + } + + #[rstest] + // true = A --> A + #[case(col("bool").eq(lit(true)), col("bool"))] + // false = A --> !A + #[case(col("bool").eq(lit(false)), col("bool").not())] + // A == true ---> A + #[case(col("bool").eq(lit(true)), col("bool"))] + // null = A --> null + #[case(null_lit().eq(col("bool")), null_lit())] + // A == false ---> !A + #[case(col("bool").eq(lit(false)), col("bool").not())] + // true != A --> !A + #[case(lit(true).not_eq(col("bool")), col("bool").not())] + // false != A --> A + #[case(lit(false).not_eq(col("bool")), col("bool"))] + // true OR A --> true + #[case(lit(true).or(col("bool")), lit(true))] + // false OR A --> A + #[case(lit(false).or(col("bool")), col("bool"))] + // A OR true --> true + #[case(col("bool").or(lit(true)), lit(true))] + // A OR false --> A + #[case(col("bool").or(lit(false)), col("bool"))] + fn test_simplify_bool_exprs( + #[case] input: ExprRef, + #[case] expected: ExprRef, + schema: SchemaRef, + ) -> DaftResult<()> { + let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } + + #[rstest] + // A * 1 --> A + #[case(col("int").mul(lit(1)), col("int"))] + // 1 * A --> A + #[case(lit(1).mul(col("int")), col("int"))] + // A / 1 --> A + #[case(col("int").div(lit(1)), col("int"))] + // A + 0 --> A + #[case(col("int").add(lit(0)), col("int"))] + // A - 0 --> A + #[case(col("int").sub(lit(0)), col("int"))] + fn test_math_exprs( + #[case] input: ExprRef, + #[case] expected: ExprRef, + schema: SchemaRef, + ) -> DaftResult<()> { + let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } + + #[rstest] + fn test_not_between(schema: SchemaRef) -> DaftResult<()> { + let input = col("int").between(lit(1), lit(10)).not(); + let expected = col("int").lt(lit(1)).or(col("int").gt(lit(10))); + + let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } + + #[rstest] + fn test_between(schema: SchemaRef) -> DaftResult<()> { + let input = col("int").between(lit(1), lit(10)); + let expected = col("int").lt_eq(lit(10)).and(col("int").gt_eq(lit(1))); + + let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } +} diff --git a/src/daft-dsl/src/optimization.rs b/src/daft-dsl/src/optimization.rs index 06cff96959..38a9c8588e 100644 --- a/src/daft-dsl/src/optimization.rs +++ b/src/daft-dsl/src/optimization.rs @@ -2,8 +2,7 @@ use std::collections::HashMap; use common_treenode::{Transformed, TreeNode, TreeNodeRecursion}; -use super::expr::Expr; -use crate::{ExprRef, Operator}; +use crate::{Expr, ExprRef}; pub fn get_required_columns(e: &ExprRef) -> Vec { let mut cols = vec![]; @@ -57,30 +56,3 @@ pub fn replace_columns_with_expressions( .expect("Error occurred when rewriting column expressions"); transformed.data } - -pub fn split_conjuction(expr: &ExprRef) -> Vec<&ExprRef> { - let mut splits = vec![]; - _split_conjuction(expr, &mut splits); - splits -} - -fn _split_conjuction<'a>(expr: &'a ExprRef, out_exprs: &mut Vec<&'a ExprRef>) { - match expr.as_ref() { - Expr::BinaryOp { - op: Operator::And, - left, - right, - } => { - _split_conjuction(left, out_exprs); - _split_conjuction(right, out_exprs); - } - Expr::Alias(inner_expr, ..) => _split_conjuction(inner_expr, out_exprs), - _ => { - out_exprs.push(expr); - } - } -} - -pub fn conjuct>(exprs: T) -> Option { - exprs.into_iter().reduce(|acc, expr| acc.and(expr)) -} diff --git a/src/daft-logical-plan/Cargo.toml b/src/daft-logical-plan/Cargo.toml index 1b4dab023f..707d881977 100644 --- a/src/daft-logical-plan/Cargo.toml +++ b/src/daft-logical-plan/Cargo.toml @@ -8,6 +8,7 @@ common-py-serde = {path = "../common/py-serde", default-features = false} common-resource-request = {path = "../common/resource-request", default-features = false} common-scan-info = {path = "../common/scan-info", default-features = false} common-treenode = {path = "../common/treenode", default-features = false} +daft-algebra = {path = "../daft-algebra", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} daft-functions = {path = "../daft-functions", default-features = false} diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs index 2b77bd8e9a..6e5be33c40 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs @@ -6,12 +6,11 @@ use std::{ use common_error::DaftResult; use common_scan_info::{rewrite_predicate_for_partitioning, PredicateGroups}; use common_treenode::{DynTreeNode, Transformed, TreeNode}; +use daft_algebra::boolean::{combine_conjunction, split_conjunction}; use daft_core::join::JoinType; use daft_dsl::{ col, - optimization::{ - conjuct, get_required_columns, replace_columns_with_expressions, split_conjuction, - }, + optimization::{get_required_columns, replace_columns_with_expressions}, ExprRef, }; @@ -56,20 +55,20 @@ impl PushDownFilter { // Filter-Filter --> Filter // Split predicate expression on conjunctions (ANDs). - let parent_predicates = split_conjuction(&filter.predicate); - let predicate_set: HashSet<&ExprRef> = parent_predicates.iter().copied().collect(); + let parent_predicates = split_conjunction(&filter.predicate); + let predicate_set: HashSet<&ExprRef> = parent_predicates.iter().collect(); // Add child predicate expressions to parent predicate expressions, eliminating duplicates. let new_predicates: Vec = parent_predicates .iter() .chain( - split_conjuction(&child_filter.predicate) + split_conjunction(&child_filter.predicate) .iter() - .filter(|e| !predicate_set.contains(**e)), + .filter(|e| !predicate_set.contains(*e)), ) .map(|e| (*e).clone()) .collect::>(); // Reconjunct predicate expressions. - let new_predicate = conjuct(new_predicates).unwrap(); + let new_predicate = combine_conjunction(new_predicates).unwrap(); let new_filter: Arc = LogicalPlan::from(Filter::try_new(child_filter.input.clone(), new_predicate)?) .into(); @@ -133,8 +132,8 @@ impl PushDownFilter { return Ok(Transformed::no(plan)); } - let data_filter = conjuct(data_only_filter); - let partition_filter = conjuct(partition_only_filter); + let data_filter = combine_conjunction(data_only_filter); + let partition_filter = combine_conjunction(partition_only_filter); assert!(data_filter.is_some() || partition_filter.is_some()); let new_pushdowns = if let Some(data_filter) = data_filter { @@ -158,7 +157,7 @@ impl PushDownFilter { // TODO(Clark): Support pushing predicates referencing both partition and data columns into the scan. let filter_op: LogicalPlan = Filter::try_new( new_source.into(), - conjuct(needing_filter_op).unwrap(), + combine_conjunction(needing_filter_op).unwrap(), )? .into(); return Ok(Transformed::yes(filter_op.into())); @@ -176,7 +175,7 @@ impl PushDownFilter { // don't involve compute. // // Filter-Projection --> {Filter-}Projection-Filter - let predicates = split_conjuction(&filter.predicate); + let predicates = split_conjunction(&filter.predicate); let projection_input_mapping = child_project .projection .iter() @@ -191,7 +190,7 @@ impl PushDownFilter { let mut can_push: Vec = vec![]; let mut can_not_push: Vec = vec![]; for predicate in predicates { - let predicate_cols = get_required_columns(predicate); + let predicate_cols = get_required_columns(&predicate); if predicate_cols .iter() .all(|col| projection_input_mapping.contains_key(col)) @@ -212,7 +211,7 @@ impl PushDownFilter { return Ok(Transformed::no(plan)); } // Create new Filter with predicates that can be pushed past Projection. - let predicates_to_push = conjuct(can_push).unwrap(); + let predicates_to_push = combine_conjunction(can_push).unwrap(); let push_down_filter: LogicalPlan = Filter::try_new(child_project.input.clone(), predicates_to_push)?.into(); // Create new Projection. @@ -226,7 +225,7 @@ impl PushDownFilter { } else { // Otherwise, add a Filter after Projection that filters with predicate expressions // that couldn't be pushed past the Projection, returning a Filter-Projection-Filter subplan. - let post_projection_predicate = conjuct(can_not_push).unwrap(); + let post_projection_predicate = combine_conjunction(can_not_push).unwrap(); let post_projection_filter: LogicalPlan = Filter::try_new(new_projection.into(), post_projection_predicate)?.into(); post_projection_filter.into() @@ -274,7 +273,7 @@ impl PushDownFilter { let left_cols = HashSet::<_>::from_iter(child_join.left.schema().names()); let right_cols = HashSet::<_>::from_iter(child_join.right.schema().names()); - for predicate in split_conjuction(&filter.predicate).into_iter().cloned() { + for predicate in split_conjunction(&filter.predicate) { let pred_cols = HashSet::<_>::from_iter(get_required_columns(&predicate)); match ( @@ -307,11 +306,11 @@ impl PushDownFilter { } } - let left_pushdowns = conjuct(left_pushdowns); - let right_pushdowns = conjuct(right_pushdowns); + let left_pushdowns = combine_conjunction(left_pushdowns); + let right_pushdowns = combine_conjunction(right_pushdowns); if left_pushdowns.is_some() || right_pushdowns.is_some() { - let kept_predicates = conjuct(kept_predicates); + let kept_predicates = combine_conjunction(kept_predicates); let new_left = left_pushdowns.map_or_else( || child_join.left.clone(), diff --git a/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs b/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs index bb890e2a17..bb395ae428 100644 --- a/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs +++ b/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs @@ -3,9 +3,7 @@ use std::sync::Arc; use common_error::DaftResult; use common_scan_info::{PhysicalScanInfo, ScanState}; use common_treenode::{Transformed, TreeNode}; -use daft_core::prelude::SchemaRef; -use daft_dsl::{lit, null_lit, Expr, ExprRef, LiteralValue, Operator}; -use daft_schema::dtype::DataType; +use daft_algebra::simplify_expr; use super::OptimizerRule; use crate::LogicalPlan; @@ -46,364 +44,13 @@ impl OptimizerRule for SimplifyExpressionsRule { } } -fn simplify_expr(expr: Expr, schema: &SchemaRef) -> DaftResult> { - Ok(match expr { - // ---------------- - // Eq - // ---------------- - // true = A --> A - // false = A --> !A - Expr::BinaryOp { - op: Operator::Eq, - left, - right, - } - // A = true --> A - // A = false --> !A - | Expr::BinaryOp { - op: Operator::Eq, - left: right, - right: left, - } if is_bool_lit(&left) && is_bool_type(&right, schema) => { - Transformed::yes(match as_bool_lit(&left) { - Some(true) => right, - Some(false) => right.not(), - None => unreachable!(), - }) - } - - // null = A --> null - // A = null --> null - Expr::BinaryOp { - op: Operator::Eq, - left, - right, - } - | Expr::BinaryOp { - op: Operator::Eq, - left: right, - right: left, - } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), - - // ---------------- - // Neq - // ---------------- - // true != A --> !A - // false != A --> A - Expr::BinaryOp { - op: Operator::NotEq, - left, - right, - } - // A != true --> !A - // A != false --> A - | Expr::BinaryOp { - op: Operator::NotEq, - left: right, - right: left, - } if is_bool_lit(&left) && is_bool_type(&right, schema) => { - Transformed::yes(match as_bool_lit(&left) { - Some(true) => right.not(), - Some(false) => right, - None => unreachable!(), - }) - } - - // null != A --> null - // A != null --> null - Expr::BinaryOp { - op: Operator::NotEq, - left, - right, - } - | Expr::BinaryOp { - op: Operator::NotEq, - left: right, - right: left, - } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), - - // ---------------- - // OR - // ---------------- - - // true OR A --> true - Expr::BinaryOp { - op: Operator::Or, - left, - right: _, - } if is_true(&left) => Transformed::yes(left), - // false OR A --> A - Expr::BinaryOp { - op: Operator::Or, - left, - right, - } if is_false(&left) => Transformed::yes(right), - // A OR true --> true - Expr::BinaryOp { - op: Operator::Or, - left: _, - right, - } if is_true(&right) => Transformed::yes(right), - // A OR false --> A - Expr::BinaryOp { - left, - op: Operator::Or, - right, - } if is_false(&right) => Transformed::yes(left), - - // ---------------- - // AND (TODO) - // ---------------- - - // ---------------- - // Multiplication - // ---------------- - - // A * 1 --> A - // 1 * A --> A - Expr::BinaryOp { - op: Operator::Multiply, - left, - right, - }| Expr::BinaryOp { - op: Operator::Multiply, - left: right, - right: left, - } if is_one(&right) => Transformed::yes(left), - - // A * null --> null - Expr::BinaryOp { - op: Operator::Multiply, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - // null * A --> null - Expr::BinaryOp { - op: Operator::Multiply, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - - // TODO: Can't do this one because we don't have a way to determine if an expr potentially contains nulls (nullable) - // A * 0 --> 0 (if A is not null and not floating/decimal) - // 0 * A --> 0 (if A is not null and not floating/decimal) - - // ---------------- - // Division - // ---------------- - // A / 1 --> A - Expr::BinaryOp { - op: Operator::TrueDivide, - left, - right, - } if is_one(&right) => Transformed::yes(left), - // null / A --> null - Expr::BinaryOp { - op: Operator::TrueDivide, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - // A / null --> null - Expr::BinaryOp { - op: Operator::TrueDivide, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - - // ---------------- - // Addition - // ---------------- - // A + 0 --> A - Expr::BinaryOp { - op: Operator::Plus, - left, - right, - } if is_zero(&right) => Transformed::yes(left), - - // 0 + A --> A - Expr::BinaryOp { - op: Operator::Plus, - left, - right, - } if is_zero(&left) => Transformed::yes(right), - - // ---------------- - // Subtraction - // ---------------- - - // A - 0 --> A - Expr::BinaryOp { - op: Operator::Minus, - left, - right, - } if is_zero(&right) => Transformed::yes(left), - - // A - null --> null - Expr::BinaryOp { - op: Operator::Minus, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - // null - A --> null - Expr::BinaryOp { - op: Operator::Minus, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - - // ---------------- - // Modulus - // ---------------- - - // A % null --> null - Expr::BinaryOp { - op: Operator::Modulus, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - - // null % A --> null - Expr::BinaryOp { - op: Operator::Modulus, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - - // A BETWEEN low AND high --> A >= low AND A <= high - Expr::Between(expr, low, high) => { - Transformed::yes(expr.clone().lt_eq(high).and(expr.gt_eq(low))) - } - Expr::Not(expr) => match Arc::unwrap_or_clone(expr) { - // NOT (BETWEEN A AND B) --> A < low OR A > high - Expr::Between(expr, low, high) => { - Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) - } - // expr NOT IN () --> true - Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(true)), - - expr => { - let expr = simplify_expr(expr, schema)?; - if expr.transformed { - Transformed::yes(expr.data.not()) - } else { - Transformed::no(expr.data.not()) - } - } - }, - // expr IN () --> false - Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(false)), - - other => Transformed::no(Arc::new(other)), - }) -} - -fn is_zero(s: &Expr) -> bool { - match s { - Expr::Literal(LiteralValue::Int32(0)) - | Expr::Literal(LiteralValue::Int64(0)) - | Expr::Literal(LiteralValue::UInt32(0)) - | Expr::Literal(LiteralValue::UInt64(0)) - | Expr::Literal(LiteralValue::Float64(0.)) => true, - Expr::Literal(LiteralValue::Decimal(v, _p, _s)) if *v == 0 => true, - _ => false, - } -} - -fn is_one(s: &Expr) -> bool { - match s { - Expr::Literal(LiteralValue::Int32(1)) - | Expr::Literal(LiteralValue::Int64(1)) - | Expr::Literal(LiteralValue::UInt32(1)) - | Expr::Literal(LiteralValue::UInt64(1)) - | Expr::Literal(LiteralValue::Float64(1.)) => true, - - Expr::Literal(LiteralValue::Decimal(v, _p, s)) => { - *s >= 0 && POWS_OF_TEN.get(*s as usize).is_some_and(|pow| v == pow) - } - _ => false, - } -} - -fn is_true(expr: &Expr) -> bool { - match expr { - Expr::Literal(LiteralValue::Boolean(v)) => *v, - _ => false, - } -} -fn is_false(expr: &Expr) -> bool { - match expr { - Expr::Literal(LiteralValue::Boolean(v)) => !*v, - _ => false, - } -} - -/// returns true if expr is a -/// `Expr::Literal(LiteralValue::Boolean(v))` , false otherwise -fn is_bool_lit(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(LiteralValue::Boolean(_))) -} - -fn is_bool_type(expr: &Expr, schema: &SchemaRef) -> bool { - matches!(expr.get_type(schema), Ok(DataType::Boolean)) -} - -fn as_bool_lit(expr: &Expr) -> Option { - expr.as_literal().and_then(|l| l.as_bool()) -} - -fn is_null(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(LiteralValue::Null)) -} - -static POWS_OF_TEN: [i128; 38] = [ - 1, - 10, - 100, - 1000, - 10000, - 100000, - 1000000, - 10000000, - 100000000, - 1000000000, - 10000000000, - 100000000000, - 1000000000000, - 10000000000000, - 100000000000000, - 1000000000000000, - 10000000000000000, - 100000000000000000, - 1000000000000000000, - 10000000000000000000, - 100000000000000000000, - 1000000000000000000000, - 10000000000000000000000, - 100000000000000000000000, - 1000000000000000000000000, - 10000000000000000000000000, - 100000000000000000000000000, - 1000000000000000000000000000, - 10000000000000000000000000000, - 100000000000000000000000000000, - 1000000000000000000000000000000, - 10000000000000000000000000000000, - 100000000000000000000000000000000, - 1000000000000000000000000000000000, - 10000000000000000000000000000000000, - 100000000000000000000000000000000000, - 1000000000000000000000000000000000000, - 10000000000000000000000000000000000000, -]; - #[cfg(test)] mod test { use std::sync::Arc; use daft_core::prelude::Schema; - use daft_dsl::{col, lit, null_lit, ExprRef}; + use daft_dsl::{col, lit}; use daft_schema::{dtype::DataType, field::Field}; - use rstest::rstest; use super::SimplifyExpressionsRule; use crate::{ @@ -436,112 +83,6 @@ mod test { ) } - #[rstest] - // true = A --> A - #[case(col("bool").eq(lit(true)), col("bool"))] - // false = A --> !A - #[case(col("bool").eq(lit(false)), col("bool").not())] - // A == true ---> A - #[case(col("bool").eq(lit(true)), col("bool"))] - // null = A --> null - #[case(null_lit().eq(col("bool")), null_lit())] - // A == false ---> !A - #[case(col("bool").eq(lit(false)), col("bool").not())] - // true != A --> !A - #[case(lit(true).not_eq(col("bool")), col("bool").not())] - // false != A --> A - #[case(lit(false).not_eq(col("bool")), col("bool"))] - // true OR A --> true - #[case(lit(true).or(col("bool")), lit(true))] - // false OR A --> A - #[case(lit(false).or(col("bool")), col("bool"))] - // A OR true --> true - #[case(col("bool").or(lit(true)), lit(true))] - // A OR false --> A - #[case(col("bool").or(lit(false)), col("bool"))] - fn test_simplify_bool_exprs(#[case] input: ExprRef, #[case] expected: ExprRef) { - let source = make_source().filter(input).unwrap().build(); - let optimizer = SimplifyExpressionsRule::new(); - let optimized = optimizer.try_optimize(source).unwrap(); - - let LogicalPlan::Filter(Filter { predicate, .. }) = optimized.data.as_ref() else { - panic!("Expected Filter, got {:?}", optimized.data) - }; - - // make sure the expression is simplified - assert!(optimized.transformed); - - assert_eq!(predicate, &expected); - } - - #[rstest] - // A * 1 --> A - #[case(col("int").mul(lit(1)), col("int"))] - // 1 * A --> A - #[case(lit(1).mul(col("int")), col("int"))] - // A / 1 --> A - #[case(col("int").div(lit(1)), col("int"))] - // A + 0 --> A - #[case(col("int").add(lit(0)), col("int"))] - // A - 0 --> A - #[case(col("int").sub(lit(0)), col("int"))] - fn test_math_exprs(#[case] input: ExprRef, #[case] expected: ExprRef) { - let source = make_source().select(vec![input]).unwrap().build(); - let optimizer = SimplifyExpressionsRule::new(); - let optimized = optimizer.try_optimize(source).unwrap(); - - let LogicalPlan::Project(Project { projection, .. }) = optimized.data.as_ref() else { - panic!("Expected Filter, got {:?}", optimized.data) - }; - - let projection = projection.first().unwrap(); - - // make sure the expression is simplified - assert!(optimized.transformed); - - assert_eq!(projection, &expected); - } - - #[test] - fn test_not_between() { - let source = make_source() - .filter(col("int").between(lit(1), lit(10)).not()) - .unwrap() - .build(); - let optimizer = SimplifyExpressionsRule::new(); - let optimized = optimizer.try_optimize(source).unwrap(); - - let LogicalPlan::Filter(Filter { predicate, .. }) = optimized.data.as_ref() else { - panic!("Expected Filter, got {:?}", optimized.data) - }; - - // make sure the expression is simplified - assert!(optimized.transformed); - - assert_eq!(predicate, &col("int").lt(lit(1)).or(col("int").gt(lit(10)))); - } - - #[test] - fn test_between() { - let source = make_source() - .filter(col("int").between(lit(1), lit(10))) - .unwrap() - .build(); - let optimizer = SimplifyExpressionsRule::new(); - let optimized = optimizer.try_optimize(source).unwrap(); - - let LogicalPlan::Filter(Filter { predicate, .. }) = optimized.data.as_ref() else { - panic!("Expected Filter, got {:?}", optimized.data) - }; - - // make sure the expression is simplified - assert!(optimized.transformed); - - assert_eq!( - predicate, - &col("int").lt_eq(lit(10)).and(col("int").gt_eq(lit(1))) - ); - } #[test] fn test_nested_plan() { let source = make_source() diff --git a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs index 3413e8cc53..5039cc9767 100644 --- a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs +++ b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs @@ -2,12 +2,9 @@ use std::{collections::HashSet, sync::Arc}; use common_error::{DaftError, DaftResult}; use common_treenode::{DynTreeNode, Transformed, TreeNode}; +use daft_algebra::boolean::{combine_conjunction, split_conjunction}; use daft_core::{join::JoinType, prelude::SchemaRef}; -use daft_dsl::{ - col, - optimization::{conjuct, split_conjuction}, - Expr, ExprRef, Operator, Subquery, -}; +use daft_dsl::{col, Expr, ExprRef, Operator, Subquery}; use itertools::multiunzip; use uuid::Uuid; @@ -73,12 +70,12 @@ impl UnnestScalarSubquery { impl UnnestScalarSubquery { fn unnest_subqueries( input: LogicalPlanRef, - exprs: Vec<&ExprRef>, + exprs: &[ExprRef], ) -> DaftResult)>> { let mut subqueries = HashSet::new(); let new_exprs = exprs - .into_iter() + .iter() .map(|expr| { expr.clone() .transform_down(|e| { @@ -164,7 +161,7 @@ impl OptimizerRule for UnnestScalarSubquery { input, predicate, .. }) => { let unnest_result = - Self::unnest_subqueries(input.clone(), split_conjuction(predicate))?; + Self::unnest_subqueries(input.clone(), &split_conjunction(predicate))?; if !unnest_result.transformed { return Ok(Transformed::no(node)); @@ -172,7 +169,7 @@ impl OptimizerRule for UnnestScalarSubquery { let (new_input, new_predicates) = unnest_result.data; - let new_predicate = conjuct(new_predicates) + let new_predicate = combine_conjunction(new_predicates) .expect("predicates are guaranteed to exist at this point, so 'conjunct' should never return 'None'"); let new_filter = Arc::new(LogicalPlan::Filter(Filter::try_new( @@ -192,7 +189,7 @@ impl OptimizerRule for UnnestScalarSubquery { input, projection, .. }) => { let unnest_result = - Self::unnest_subqueries(input.clone(), projection.iter().collect())?; + Self::unnest_subqueries(input.clone(), projection)?; if !unnest_result.transformed { return Ok(Transformed::no(node)); @@ -275,7 +272,7 @@ impl OptimizerRule for UnnestPredicateSubquery { }) => { let mut subqueries = HashSet::new(); - let new_predicates = split_conjuction(predicate) + let new_predicates = split_conjunction(predicate) .into_iter() .filter(|expr| { match expr.as_ref() { @@ -303,7 +300,6 @@ impl OptimizerRule for UnnestPredicateSubquery { _ => true } }) - .cloned() .collect::>(); if subqueries.is_empty() { @@ -345,7 +341,7 @@ impl OptimizerRule for UnnestPredicateSubquery { )?))) })?; - let new_plan = if let Some(new_predicate) = conjuct(new_predicates) { + let new_plan = if let Some(new_predicate) = combine_conjunction(new_predicates) { // add filter back if there are non-subquery predicates Arc::new(LogicalPlan::Filter(Filter::try_new( new_input, @@ -387,7 +383,7 @@ fn pull_up_correlated_cols( }) => { let mut found_correlated_col = false; - let preds = split_conjuction(predicate) + let preds = split_conjunction(predicate) .into_iter() .filter(|expr| { if let Expr::BinaryOp { @@ -418,7 +414,6 @@ fn pull_up_correlated_cols( true }) - .cloned() .collect::>(); // no new correlated cols found @@ -426,7 +421,7 @@ fn pull_up_correlated_cols( return Ok((plan.clone(), subquery_on, outer_on)); } - if let Some(new_predicate) = conjuct(preds) { + if let Some(new_predicate) = combine_conjunction(preds) { let new_plan = Arc::new(LogicalPlan::Filter(Filter::try_new( input.clone(), new_predicate, diff --git a/src/daft-sql/Cargo.toml b/src/daft-sql/Cargo.toml index 6e45c23741..a402235011 100644 --- a/src/daft-sql/Cargo.toml +++ b/src/daft-sql/Cargo.toml @@ -3,6 +3,7 @@ common-daft-config = {path = "../common/daft-config"} common-error = {path = "../common/error"} common-io-config = {path = "../common/io-config", default-features = false} common-runtime = {workspace = true} +daft-algebra = {path = "../daft-algebra"} daft-core = {path = "../daft-core"} daft-dsl = {path = "../daft-dsl"} daft-functions = {path = "../daft-functions"} diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 683391b601..ce2ef703a3 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -6,13 +6,13 @@ use std::{ }; use common_error::{DaftError, DaftResult}; +use daft_algebra::boolean::combine_conjunction; use daft_core::prelude::*; use daft_dsl::{ col, common_treenode::{Transformed, TreeNode}, - has_agg, lit, literals_to_series, null_lit, - optimization::conjuct, - AggExpr, Expr, ExprRef, LiteralValue, Operator, OuterReferenceColumn, Subquery, + has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, Operator, + OuterReferenceColumn, Subquery, }; use daft_functions::{ numeric::{ceil::ceil, floor::floor}, @@ -959,12 +959,12 @@ impl<'a> SQLPlanner<'a> { }; let mut left_plan = self.current_relation.as_ref().unwrap().inner.clone(); - if let Some(left_predicate) = conjuct(left_filters) { + if let Some(left_predicate) = combine_conjunction(left_filters) { left_plan = left_plan.filter(left_predicate)?; } let mut right_plan = right_rel.inner.clone(); - if let Some(right_predicate) = conjuct(right_filters) { + if let Some(right_predicate) = combine_conjunction(right_filters) { right_plan = right_plan.filter(right_predicate)?; }