diff --git a/Cargo.lock b/Cargo.lock index 3011a56b24..048324ca60 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1577,6 +1577,7 @@ dependencies = [ "common-display", "common-error", "common-file-formats", + "daft-algebra", "daft-dsl", "daft-schema", "pyo3", @@ -1905,6 +1906,7 @@ dependencies = [ "common-system-info", "common-tracing", "common-version", + "daft-algebra", "daft-catalog", "daft-catalog-python-catalog", "daft-compression", @@ -1941,6 +1943,17 @@ dependencies = [ "tikv-jemallocator", ] +[[package]] +name = "daft-algebra" +version = "0.3.0-dev0" +dependencies = [ + "common-error", + "common-treenode", + "daft-dsl", + "daft-schema", + "rstest", +] + [[package]] name = "daft-catalog" version = "0.3.0-dev0" @@ -2329,6 +2342,7 @@ dependencies = [ "common-resource-request", "common-scan-info", "common-treenode", + "daft-algebra", "daft-core", "daft-dsl", "daft-functions", @@ -2536,6 +2550,7 @@ dependencies = [ "common-error", "common-io-config", "common-runtime", + "daft-algebra", "daft-core", "daft-dsl", "daft-functions", diff --git a/Cargo.toml b/Cargo.toml index d5a5cf218d..ba00e45cd4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ common-scan-info = {path = "src/common/scan-info", default-features = false} common-system-info = {path = "src/common/system-info", default-features = false} common-tracing = {path = "src/common/tracing", default-features = false} common-version = {path = "src/common/version", default-features = false} +daft-algebra = {path = "src/daft-algebra", default-features = false} daft-catalog = {path = "src/daft-catalog", default-features = false} daft-catalog-python-catalog = {path = "src/daft-catalog/python-catalog", optional = true} daft-compression = {path = "src/daft-compression", default-features = false} @@ -149,6 +150,7 @@ members = [ "src/common/scan-info", "src/common/system-info", "src/common/treenode", + "src/daft-algebra", "src/daft-catalog", "src/daft-core", "src/daft-csv", diff --git a/src/common/scan-info/Cargo.toml b/src/common/scan-info/Cargo.toml index 04f9550997..0aecf55f6e 100644 --- a/src/common/scan-info/Cargo.toml +++ b/src/common/scan-info/Cargo.toml @@ -3,6 +3,7 @@ common-daft-config = {path = "../daft-config", default-features = false} common-display = {path = "../display", default-features = false} common-error = {path = "../error", default-features = false} common-file-formats = {path = "../file-formats", default-features = false} +daft-algebra = {path = "../../daft-algebra", default-features = false} daft-dsl = {path = "../../daft-dsl", default-features = false} daft-schema = {path = "../../daft-schema", default-features = false} pyo3 = {workspace = true, optional = true} diff --git a/src/common/scan-info/src/expr_rewriter.rs b/src/common/scan-info/src/expr_rewriter.rs index f678ad07c1..fedf212a18 100644 --- a/src/common/scan-info/src/expr_rewriter.rs +++ b/src/common/scan-info/src/expr_rewriter.rs @@ -1,13 +1,12 @@ use std::collections::HashMap; use common_error::DaftResult; +use daft_algebra::boolean::split_conjunction; use daft_dsl::{ col, common_treenode::{Transformed, TreeNode, TreeNodeRecursion}, functions::{partitioning, FunctionExpr}, - null_lit, - optimization::split_conjuction, - Expr, ExprRef, Operator, + null_lit, Expr, ExprRef, Operator, }; use crate::{PartitionField, PartitionTransform}; @@ -93,7 +92,7 @@ pub fn rewrite_predicate_for_partitioning( // Before rewriting predicate for partition filter pushdown, partition predicate clauses into groups that will need // to be applied at the data level (i.e. 
any clauses that aren't pure partition predicates with identity // transformations). - let data_split = split_conjuction(predicate); + let data_split = split_conjunction(predicate); // Predicates that reference both partition columns and data columns. let mut needs_filter_op_preds: Vec = vec![]; // Predicates that only reference data columns (no partition column references) or only reference partition columns @@ -332,7 +331,7 @@ pub fn rewrite_predicate_for_partitioning( let with_part_cols = with_part_cols.data; // Filter to predicate clauses that only involve partition columns. - let split = split_conjuction(&with_part_cols); + let split = split_conjunction(&with_part_cols); let mut part_preds: Vec = vec![]; for e in split { let mut all_part_keys = true; diff --git a/src/daft-algebra/Cargo.toml b/src/daft-algebra/Cargo.toml new file mode 100644 index 0000000000..89e942c700 --- /dev/null +++ b/src/daft-algebra/Cargo.toml @@ -0,0 +1,16 @@ +[dependencies] +common-error = {path = "../common/error", default-features = false} +common-treenode = {path = "../common/treenode", default-features = false} +daft-dsl = {path = "../daft-dsl", default-features = false} +daft-schema = {path = "../daft-schema", default-features = false} + +[dev-dependencies] +rstest = {workspace = true} + +[lints] +workspace = true + +[package] +edition = {workspace = true} +name = "daft-algebra" +version = {workspace = true} diff --git a/src/daft-algebra/src/boolean.rs b/src/daft-algebra/src/boolean.rs new file mode 100644 index 0000000000..38f659e00c --- /dev/null +++ b/src/daft-algebra/src/boolean.rs @@ -0,0 +1,24 @@ +use common_treenode::{TreeNode, TreeNodeRecursion}; +use daft_dsl::{Expr, ExprRef, Operator}; + +pub fn split_conjunction(expr: &ExprRef) -> Vec { + let mut splits = vec![]; + + expr.apply(|e| match e.as_ref() { + Expr::BinaryOp { + op: Operator::And, .. + } + | Expr::Alias(..) 
=> Ok(TreeNodeRecursion::Continue), + _ => { + splits.push(e.clone()); + Ok(TreeNodeRecursion::Jump) + } + }) + .unwrap(); + + splits +} + +pub fn combine_conjunction>(exprs: T) -> Option { + exprs.into_iter().reduce(|acc, e| acc.and(e)) +} diff --git a/src/daft-algebra/src/lib.rs b/src/daft-algebra/src/lib.rs new file mode 100644 index 0000000000..317ef5eea1 --- /dev/null +++ b/src/daft-algebra/src/lib.rs @@ -0,0 +1,4 @@ +pub mod boolean; +mod simplify; + +pub use simplify::simplify_expr; diff --git a/src/daft-algebra/src/simplify.rs b/src/daft-algebra/src/simplify.rs new file mode 100644 index 0000000000..698a48fa1b --- /dev/null +++ b/src/daft-algebra/src/simplify.rs @@ -0,0 +1,465 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use common_treenode::Transformed; +use daft_dsl::{lit, null_lit, Expr, ExprRef, LiteralValue, Operator}; +use daft_schema::{dtype::DataType, schema::SchemaRef}; + +pub fn simplify_expr(expr: Expr, schema: &SchemaRef) -> DaftResult> { + Ok(match expr { + // ---------------- + // Eq + // ---------------- + // true = A --> A + // false = A --> !A + Expr::BinaryOp { + op: Operator::Eq, + left, + right, + } + // A = true --> A + // A = false --> !A + | Expr::BinaryOp { + op: Operator::Eq, + left: right, + right: left, + } if is_bool_lit(&left) && is_bool_type(&right, schema) => { + Transformed::yes(match as_bool_lit(&left) { + Some(true) => right, + Some(false) => right.not(), + None => unreachable!(), + }) + } + + // null = A --> null + // A = null --> null + Expr::BinaryOp { + op: Operator::Eq, + left, + right, + } + | Expr::BinaryOp { + op: Operator::Eq, + left: right, + right: left, + } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), + + // ---------------- + // Neq + // ---------------- + // true != A --> !A + // false != A --> A + Expr::BinaryOp { + op: Operator::NotEq, + left, + right, + } + // A != true --> !A + // A != false --> A + | Expr::BinaryOp { + op: Operator::NotEq, + left: right, + right: left, + } if is_bool_lit(&left) && is_bool_type(&right, schema) => { + Transformed::yes(match as_bool_lit(&left) { + Some(true) => right.not(), + Some(false) => right, + None => unreachable!(), + }) + } + + // null != A --> null + // A != null --> null + Expr::BinaryOp { + op: Operator::NotEq, + left, + right, + } + | Expr::BinaryOp { + op: Operator::NotEq, + left: right, + right: left, + } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), + + // ---------------- + // OR + // ---------------- + + // true OR A --> true + Expr::BinaryOp { + op: Operator::Or, + left, + right: _, + } if is_true(&left) => Transformed::yes(left), + // false OR A --> A + Expr::BinaryOp { + op: Operator::Or, + left, + right, + } if is_false(&left) => Transformed::yes(right), + // A OR true --> true + Expr::BinaryOp { + op: Operator::Or, + left: _, + right, + } if is_true(&right) => Transformed::yes(right), + // A OR false --> A + Expr::BinaryOp { + left, + op: Operator::Or, + right, + } if is_false(&right) => Transformed::yes(left), + + // ---------------- + // AND (TODO) + // ---------------- + + // ---------------- + // Multiplication + // ---------------- + + // A * 1 --> A + // 1 * A --> A + Expr::BinaryOp { + op: Operator::Multiply, + left, + right, + }| Expr::BinaryOp { + op: Operator::Multiply, + left: right, + right: left, + } if is_one(&right) => Transformed::yes(left), + + // A * null --> null + Expr::BinaryOp { + op: Operator::Multiply, + left: _, + right, + } if is_null(&right) => 
Transformed::yes(right), + // null * A --> null + Expr::BinaryOp { + op: Operator::Multiply, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + + // TODO: Can't do this one because we don't have a way to determine if an expr potentially contains nulls (nullable) + // A * 0 --> 0 (if A is not null and not floating/decimal) + // 0 * A --> 0 (if A is not null and not floating/decimal) + + // ---------------- + // Division + // ---------------- + // A / 1 --> A + Expr::BinaryOp { + op: Operator::TrueDivide, + left, + right, + } if is_one(&right) => Transformed::yes(left), + // null / A --> null + Expr::BinaryOp { + op: Operator::TrueDivide, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + // A / null --> null + Expr::BinaryOp { + op: Operator::TrueDivide, + left: _, + right, + } if is_null(&right) => Transformed::yes(right), + + // ---------------- + // Addition + // ---------------- + // A + 0 --> A + Expr::BinaryOp { + op: Operator::Plus, + left, + right, + } if is_zero(&right) => Transformed::yes(left), + + // 0 + A --> A + Expr::BinaryOp { + op: Operator::Plus, + left, + right, + } if is_zero(&left) => Transformed::yes(right), + + // ---------------- + // Subtraction + // ---------------- + + // A - 0 --> A + Expr::BinaryOp { + op: Operator::Minus, + left, + right, + } if is_zero(&right) => Transformed::yes(left), + + // A - null --> null + Expr::BinaryOp { + op: Operator::Minus, + left: _, + right, + } if is_null(&right) => Transformed::yes(right), + // null - A --> null + Expr::BinaryOp { + op: Operator::Minus, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + + // ---------------- + // Modulus + // ---------------- + + // A % null --> null + Expr::BinaryOp { + op: Operator::Modulus, + left: _, + right, + } if is_null(&right) => Transformed::yes(right), + + // null % A --> null + Expr::BinaryOp { + op: Operator::Modulus, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + + // A BETWEEN low AND high --> A >= low AND A <= high + Expr::Between(expr, low, high) => { + Transformed::yes(expr.clone().lt_eq(high).and(expr.gt_eq(low))) + } + Expr::Not(expr) => match Arc::unwrap_or_clone(expr) { + // NOT (BETWEEN A AND B) --> A < low OR A > high + Expr::Between(expr, low, high) => { + Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) + } + // expr NOT IN () --> true + Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(true)), + + expr => { + let expr = simplify_expr(expr, schema)?; + if expr.transformed { + Transformed::yes(expr.data.not()) + } else { + Transformed::no(expr.data.not()) + } + } + }, + // expr IN () --> false + Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(false)), + + other => Transformed::no(Arc::new(other)), + }) +} + +fn is_zero(s: &Expr) -> bool { + match s { + Expr::Literal(LiteralValue::Int32(0)) + | Expr::Literal(LiteralValue::Int64(0)) + | Expr::Literal(LiteralValue::UInt32(0)) + | Expr::Literal(LiteralValue::UInt64(0)) + | Expr::Literal(LiteralValue::Float64(0.)) => true, + Expr::Literal(LiteralValue::Decimal(v, _p, _s)) if *v == 0 => true, + _ => false, + } +} + +fn is_one(s: &Expr) -> bool { + match s { + Expr::Literal(LiteralValue::Int32(1)) + | Expr::Literal(LiteralValue::Int64(1)) + | Expr::Literal(LiteralValue::UInt32(1)) + | Expr::Literal(LiteralValue::UInt64(1)) + | Expr::Literal(LiteralValue::Float64(1.)) => true, + + Expr::Literal(LiteralValue::Decimal(v, _p, s)) => { + *s >= 0 && POWS_OF_TEN.get(*s as usize).is_some_and(|pow| v == pow) + } + _ 
=> false, + } +} + +fn is_true(expr: &Expr) -> bool { + match expr { + Expr::Literal(LiteralValue::Boolean(v)) => *v, + _ => false, + } +} +fn is_false(expr: &Expr) -> bool { + match expr { + Expr::Literal(LiteralValue::Boolean(v)) => !*v, + _ => false, + } +} + +/// returns true if expr is a +/// `Expr::Literal(LiteralValue::Boolean(v))` , false otherwise +fn is_bool_lit(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(LiteralValue::Boolean(_))) +} + +fn is_bool_type(expr: &Expr, schema: &SchemaRef) -> bool { + matches!(expr.get_type(schema), Ok(DataType::Boolean)) +} + +fn as_bool_lit(expr: &Expr) -> Option { + expr.as_literal().and_then(|l| l.as_bool()) +} + +fn is_null(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(LiteralValue::Null)) +} + +static POWS_OF_TEN: [i128; 38] = [ + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000, + 10000000000, + 100000000000, + 1000000000000, + 10000000000000, + 100000000000000, + 1000000000000000, + 10000000000000000, + 100000000000000000, + 1000000000000000000, + 10000000000000000000, + 100000000000000000000, + 1000000000000000000000, + 10000000000000000000000, + 100000000000000000000000, + 1000000000000000000000000, + 10000000000000000000000000, + 100000000000000000000000000, + 1000000000000000000000000000, + 10000000000000000000000000000, + 100000000000000000000000000000, + 1000000000000000000000000000000, + 10000000000000000000000000000000, + 100000000000000000000000000000000, + 1000000000000000000000000000000000, + 10000000000000000000000000000000000, + 100000000000000000000000000000000000, + 1000000000000000000000000000000000000, + 10000000000000000000000000000000000000, +]; + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use common_error::DaftResult; + use daft_dsl::{col, lit, null_lit, ExprRef}; + use daft_schema::{ + dtype::DataType, + field::Field, + schema::{Schema, SchemaRef}, + }; + use rstest::{fixture, rstest}; + + use crate::simplify_expr; + + #[fixture] + fn schema() -> SchemaRef { + Arc::new( + Schema::new(vec![ + Field::new("bool", DataType::Boolean), + Field::new("int", DataType::Int32), + ]) + .unwrap(), + ) + } + + #[rstest] + // true = A --> A + #[case(col("bool").eq(lit(true)), col("bool"))] + // false = A --> !A + #[case(col("bool").eq(lit(false)), col("bool").not())] + // A == true ---> A + #[case(col("bool").eq(lit(true)), col("bool"))] + // null = A --> null + #[case(null_lit().eq(col("bool")), null_lit())] + // A == false ---> !A + #[case(col("bool").eq(lit(false)), col("bool").not())] + // true != A --> !A + #[case(lit(true).not_eq(col("bool")), col("bool").not())] + // false != A --> A + #[case(lit(false).not_eq(col("bool")), col("bool"))] + // true OR A --> true + #[case(lit(true).or(col("bool")), lit(true))] + // false OR A --> A + #[case(lit(false).or(col("bool")), col("bool"))] + // A OR true --> true + #[case(col("bool").or(lit(true)), lit(true))] + // A OR false --> A + #[case(col("bool").or(lit(false)), col("bool"))] + fn test_simplify_bool_exprs( + #[case] input: ExprRef, + #[case] expected: ExprRef, + schema: SchemaRef, + ) -> DaftResult<()> { + let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } + + #[rstest] + // A * 1 --> A + #[case(col("int").mul(lit(1)), col("int"))] + // 1 * A --> A + #[case(lit(1).mul(col("int")), col("int"))] + // A / 1 --> A + #[case(col("int").div(lit(1)), col("int"))] + // A + 0 --> A + #[case(col("int").add(lit(0)), 
col("int"))] + // A - 0 --> A + #[case(col("int").sub(lit(0)), col("int"))] + fn test_math_exprs( + #[case] input: ExprRef, + #[case] expected: ExprRef, + schema: SchemaRef, + ) -> DaftResult<()> { + let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } + + #[rstest] + fn test_not_between(schema: SchemaRef) -> DaftResult<()> { + let input = col("int").between(lit(1), lit(10)).not(); + let expected = col("int").lt(lit(1)).or(col("int").gt(lit(10))); + + let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } + + #[rstest] + fn test_between(schema: SchemaRef) -> DaftResult<()> { + let input = col("int").between(lit(1), lit(10)); + let expected = col("int").lt_eq(lit(10)).and(col("int").gt_eq(lit(1))); + + let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } +} diff --git a/src/daft-dsl/src/optimization.rs b/src/daft-dsl/src/optimization.rs index 06cff96959..38a9c8588e 100644 --- a/src/daft-dsl/src/optimization.rs +++ b/src/daft-dsl/src/optimization.rs @@ -2,8 +2,7 @@ use std::collections::HashMap; use common_treenode::{Transformed, TreeNode, TreeNodeRecursion}; -use super::expr::Expr; -use crate::{ExprRef, Operator}; +use crate::{Expr, ExprRef}; pub fn get_required_columns(e: &ExprRef) -> Vec { let mut cols = vec![]; @@ -57,30 +56,3 @@ pub fn replace_columns_with_expressions( .expect("Error occurred when rewriting column expressions"); transformed.data } - -pub fn split_conjuction(expr: &ExprRef) -> Vec<&ExprRef> { - let mut splits = vec![]; - _split_conjuction(expr, &mut splits); - splits -} - -fn _split_conjuction<'a>(expr: &'a ExprRef, out_exprs: &mut Vec<&'a ExprRef>) { - match expr.as_ref() { - Expr::BinaryOp { - op: Operator::And, - left, - right, - } => { - _split_conjuction(left, out_exprs); - _split_conjuction(right, out_exprs); - } - Expr::Alias(inner_expr, ..) 
=> _split_conjuction(inner_expr, out_exprs), - _ => { - out_exprs.push(expr); - } - } -} - -pub fn conjuct>(exprs: T) -> Option { - exprs.into_iter().reduce(|acc, expr| acc.and(expr)) -} diff --git a/src/daft-logical-plan/Cargo.toml b/src/daft-logical-plan/Cargo.toml index 1b4dab023f..707d881977 100644 --- a/src/daft-logical-plan/Cargo.toml +++ b/src/daft-logical-plan/Cargo.toml @@ -8,6 +8,7 @@ common-py-serde = {path = "../common/py-serde", default-features = false} common-resource-request = {path = "../common/resource-request", default-features = false} common-scan-info = {path = "../common/scan-info", default-features = false} common-treenode = {path = "../common/treenode", default-features = false} +daft-algebra = {path = "../daft-algebra", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} daft-functions = {path = "../daft-functions", default-features = false} diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs index 2b77bd8e9a..6e5be33c40 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs @@ -6,12 +6,11 @@ use std::{ use common_error::DaftResult; use common_scan_info::{rewrite_predicate_for_partitioning, PredicateGroups}; use common_treenode::{DynTreeNode, Transformed, TreeNode}; +use daft_algebra::boolean::{combine_conjunction, split_conjunction}; use daft_core::join::JoinType; use daft_dsl::{ col, - optimization::{ - conjuct, get_required_columns, replace_columns_with_expressions, split_conjuction, - }, + optimization::{get_required_columns, replace_columns_with_expressions}, ExprRef, }; @@ -56,20 +55,20 @@ impl PushDownFilter { // Filter-Filter --> Filter // Split predicate expression on conjunctions (ANDs). - let parent_predicates = split_conjuction(&filter.predicate); - let predicate_set: HashSet<&ExprRef> = parent_predicates.iter().copied().collect(); + let parent_predicates = split_conjunction(&filter.predicate); + let predicate_set: HashSet<&ExprRef> = parent_predicates.iter().collect(); // Add child predicate expressions to parent predicate expressions, eliminating duplicates. let new_predicates: Vec = parent_predicates .iter() .chain( - split_conjuction(&child_filter.predicate) + split_conjunction(&child_filter.predicate) .iter() - .filter(|e| !predicate_set.contains(**e)), + .filter(|e| !predicate_set.contains(*e)), ) .map(|e| (*e).clone()) .collect::>(); // Reconjunct predicate expressions. - let new_predicate = conjuct(new_predicates).unwrap(); + let new_predicate = combine_conjunction(new_predicates).unwrap(); let new_filter: Arc = LogicalPlan::from(Filter::try_new(child_filter.input.clone(), new_predicate)?) .into(); @@ -133,8 +132,8 @@ impl PushDownFilter { return Ok(Transformed::no(plan)); } - let data_filter = conjuct(data_only_filter); - let partition_filter = conjuct(partition_only_filter); + let data_filter = combine_conjunction(data_only_filter); + let partition_filter = combine_conjunction(partition_only_filter); assert!(data_filter.is_some() || partition_filter.is_some()); let new_pushdowns = if let Some(data_filter) = data_filter { @@ -158,7 +157,7 @@ impl PushDownFilter { // TODO(Clark): Support pushing predicates referencing both partition and data columns into the scan. 
let filter_op: LogicalPlan = Filter::try_new( new_source.into(), - conjuct(needing_filter_op).unwrap(), + combine_conjunction(needing_filter_op).unwrap(), )? .into(); return Ok(Transformed::yes(filter_op.into())); @@ -176,7 +175,7 @@ impl PushDownFilter { // don't involve compute. // // Filter-Projection --> {Filter-}Projection-Filter - let predicates = split_conjuction(&filter.predicate); + let predicates = split_conjunction(&filter.predicate); let projection_input_mapping = child_project .projection .iter() @@ -191,7 +190,7 @@ impl PushDownFilter { let mut can_push: Vec = vec![]; let mut can_not_push: Vec = vec![]; for predicate in predicates { - let predicate_cols = get_required_columns(predicate); + let predicate_cols = get_required_columns(&predicate); if predicate_cols .iter() .all(|col| projection_input_mapping.contains_key(col)) @@ -212,7 +211,7 @@ impl PushDownFilter { return Ok(Transformed::no(plan)); } // Create new Filter with predicates that can be pushed past Projection. - let predicates_to_push = conjuct(can_push).unwrap(); + let predicates_to_push = combine_conjunction(can_push).unwrap(); let push_down_filter: LogicalPlan = Filter::try_new(child_project.input.clone(), predicates_to_push)?.into(); // Create new Projection. @@ -226,7 +225,7 @@ impl PushDownFilter { } else { // Otherwise, add a Filter after Projection that filters with predicate expressions // that couldn't be pushed past the Projection, returning a Filter-Projection-Filter subplan. - let post_projection_predicate = conjuct(can_not_push).unwrap(); + let post_projection_predicate = combine_conjunction(can_not_push).unwrap(); let post_projection_filter: LogicalPlan = Filter::try_new(new_projection.into(), post_projection_predicate)?.into(); post_projection_filter.into() @@ -274,7 +273,7 @@ impl PushDownFilter { let left_cols = HashSet::<_>::from_iter(child_join.left.schema().names()); let right_cols = HashSet::<_>::from_iter(child_join.right.schema().names()); - for predicate in split_conjuction(&filter.predicate).into_iter().cloned() { + for predicate in split_conjunction(&filter.predicate) { let pred_cols = HashSet::<_>::from_iter(get_required_columns(&predicate)); match ( @@ -307,11 +306,11 @@ impl PushDownFilter { } } - let left_pushdowns = conjuct(left_pushdowns); - let right_pushdowns = conjuct(right_pushdowns); + let left_pushdowns = combine_conjunction(left_pushdowns); + let right_pushdowns = combine_conjunction(right_pushdowns); if left_pushdowns.is_some() || right_pushdowns.is_some() { - let kept_predicates = conjuct(kept_predicates); + let kept_predicates = combine_conjunction(kept_predicates); let new_left = left_pushdowns.map_or_else( || child_join.left.clone(), diff --git a/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs b/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs index bb890e2a17..bb395ae428 100644 --- a/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs +++ b/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs @@ -3,9 +3,7 @@ use std::sync::Arc; use common_error::DaftResult; use common_scan_info::{PhysicalScanInfo, ScanState}; use common_treenode::{Transformed, TreeNode}; -use daft_core::prelude::SchemaRef; -use daft_dsl::{lit, null_lit, Expr, ExprRef, LiteralValue, Operator}; -use daft_schema::dtype::DataType; +use daft_algebra::simplify_expr; use super::OptimizerRule; use crate::LogicalPlan; @@ -46,364 +44,13 @@ impl OptimizerRule for SimplifyExpressionsRule { } } -fn simplify_expr(expr: Expr, schema: &SchemaRef) 
-> DaftResult> { - Ok(match expr { - // ---------------- - // Eq - // ---------------- - // true = A --> A - // false = A --> !A - Expr::BinaryOp { - op: Operator::Eq, - left, - right, - } - // A = true --> A - // A = false --> !A - | Expr::BinaryOp { - op: Operator::Eq, - left: right, - right: left, - } if is_bool_lit(&left) && is_bool_type(&right, schema) => { - Transformed::yes(match as_bool_lit(&left) { - Some(true) => right, - Some(false) => right.not(), - None => unreachable!(), - }) - } - - // null = A --> null - // A = null --> null - Expr::BinaryOp { - op: Operator::Eq, - left, - right, - } - | Expr::BinaryOp { - op: Operator::Eq, - left: right, - right: left, - } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), - - // ---------------- - // Neq - // ---------------- - // true != A --> !A - // false != A --> A - Expr::BinaryOp { - op: Operator::NotEq, - left, - right, - } - // A != true --> !A - // A != false --> A - | Expr::BinaryOp { - op: Operator::NotEq, - left: right, - right: left, - } if is_bool_lit(&left) && is_bool_type(&right, schema) => { - Transformed::yes(match as_bool_lit(&left) { - Some(true) => right.not(), - Some(false) => right, - None => unreachable!(), - }) - } - - // null != A --> null - // A != null --> null - Expr::BinaryOp { - op: Operator::NotEq, - left, - right, - } - | Expr::BinaryOp { - op: Operator::NotEq, - left: right, - right: left, - } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), - - // ---------------- - // OR - // ---------------- - - // true OR A --> true - Expr::BinaryOp { - op: Operator::Or, - left, - right: _, - } if is_true(&left) => Transformed::yes(left), - // false OR A --> A - Expr::BinaryOp { - op: Operator::Or, - left, - right, - } if is_false(&left) => Transformed::yes(right), - // A OR true --> true - Expr::BinaryOp { - op: Operator::Or, - left: _, - right, - } if is_true(&right) => Transformed::yes(right), - // A OR false --> A - Expr::BinaryOp { - left, - op: Operator::Or, - right, - } if is_false(&right) => Transformed::yes(left), - - // ---------------- - // AND (TODO) - // ---------------- - - // ---------------- - // Multiplication - // ---------------- - - // A * 1 --> A - // 1 * A --> A - Expr::BinaryOp { - op: Operator::Multiply, - left, - right, - }| Expr::BinaryOp { - op: Operator::Multiply, - left: right, - right: left, - } if is_one(&right) => Transformed::yes(left), - - // A * null --> null - Expr::BinaryOp { - op: Operator::Multiply, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - // null * A --> null - Expr::BinaryOp { - op: Operator::Multiply, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - - // TODO: Can't do this one because we don't have a way to determine if an expr potentially contains nulls (nullable) - // A * 0 --> 0 (if A is not null and not floating/decimal) - // 0 * A --> 0 (if A is not null and not floating/decimal) - - // ---------------- - // Division - // ---------------- - // A / 1 --> A - Expr::BinaryOp { - op: Operator::TrueDivide, - left, - right, - } if is_one(&right) => Transformed::yes(left), - // null / A --> null - Expr::BinaryOp { - op: Operator::TrueDivide, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - // A / null --> null - Expr::BinaryOp { - op: Operator::TrueDivide, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - - // ---------------- - // Addition - // ---------------- - // A + 0 --> A - Expr::BinaryOp { - op: 
Operator::Plus, - left, - right, - } if is_zero(&right) => Transformed::yes(left), - - // 0 + A --> A - Expr::BinaryOp { - op: Operator::Plus, - left, - right, - } if is_zero(&left) => Transformed::yes(right), - - // ---------------- - // Subtraction - // ---------------- - - // A - 0 --> A - Expr::BinaryOp { - op: Operator::Minus, - left, - right, - } if is_zero(&right) => Transformed::yes(left), - - // A - null --> null - Expr::BinaryOp { - op: Operator::Minus, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - // null - A --> null - Expr::BinaryOp { - op: Operator::Minus, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - - // ---------------- - // Modulus - // ---------------- - - // A % null --> null - Expr::BinaryOp { - op: Operator::Modulus, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - - // null % A --> null - Expr::BinaryOp { - op: Operator::Modulus, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - - // A BETWEEN low AND high --> A >= low AND A <= high - Expr::Between(expr, low, high) => { - Transformed::yes(expr.clone().lt_eq(high).and(expr.gt_eq(low))) - } - Expr::Not(expr) => match Arc::unwrap_or_clone(expr) { - // NOT (BETWEEN A AND B) --> A < low OR A > high - Expr::Between(expr, low, high) => { - Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) - } - // expr NOT IN () --> true - Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(true)), - - expr => { - let expr = simplify_expr(expr, schema)?; - if expr.transformed { - Transformed::yes(expr.data.not()) - } else { - Transformed::no(expr.data.not()) - } - } - }, - // expr IN () --> false - Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(false)), - - other => Transformed::no(Arc::new(other)), - }) -} - -fn is_zero(s: &Expr) -> bool { - match s { - Expr::Literal(LiteralValue::Int32(0)) - | Expr::Literal(LiteralValue::Int64(0)) - | Expr::Literal(LiteralValue::UInt32(0)) - | Expr::Literal(LiteralValue::UInt64(0)) - | Expr::Literal(LiteralValue::Float64(0.)) => true, - Expr::Literal(LiteralValue::Decimal(v, _p, _s)) if *v == 0 => true, - _ => false, - } -} - -fn is_one(s: &Expr) -> bool { - match s { - Expr::Literal(LiteralValue::Int32(1)) - | Expr::Literal(LiteralValue::Int64(1)) - | Expr::Literal(LiteralValue::UInt32(1)) - | Expr::Literal(LiteralValue::UInt64(1)) - | Expr::Literal(LiteralValue::Float64(1.)) => true, - - Expr::Literal(LiteralValue::Decimal(v, _p, s)) => { - *s >= 0 && POWS_OF_TEN.get(*s as usize).is_some_and(|pow| v == pow) - } - _ => false, - } -} - -fn is_true(expr: &Expr) -> bool { - match expr { - Expr::Literal(LiteralValue::Boolean(v)) => *v, - _ => false, - } -} -fn is_false(expr: &Expr) -> bool { - match expr { - Expr::Literal(LiteralValue::Boolean(v)) => !*v, - _ => false, - } -} - -/// returns true if expr is a -/// `Expr::Literal(LiteralValue::Boolean(v))` , false otherwise -fn is_bool_lit(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(LiteralValue::Boolean(_))) -} - -fn is_bool_type(expr: &Expr, schema: &SchemaRef) -> bool { - matches!(expr.get_type(schema), Ok(DataType::Boolean)) -} - -fn as_bool_lit(expr: &Expr) -> Option { - expr.as_literal().and_then(|l| l.as_bool()) -} - -fn is_null(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(LiteralValue::Null)) -} - -static POWS_OF_TEN: [i128; 38] = [ - 1, - 10, - 100, - 1000, - 10000, - 100000, - 1000000, - 10000000, - 100000000, - 1000000000, - 10000000000, - 100000000000, - 1000000000000, - 10000000000000, - 
100000000000000, - 1000000000000000, - 10000000000000000, - 100000000000000000, - 1000000000000000000, - 10000000000000000000, - 100000000000000000000, - 1000000000000000000000, - 10000000000000000000000, - 100000000000000000000000, - 1000000000000000000000000, - 10000000000000000000000000, - 100000000000000000000000000, - 1000000000000000000000000000, - 10000000000000000000000000000, - 100000000000000000000000000000, - 1000000000000000000000000000000, - 10000000000000000000000000000000, - 100000000000000000000000000000000, - 1000000000000000000000000000000000, - 10000000000000000000000000000000000, - 100000000000000000000000000000000000, - 1000000000000000000000000000000000000, - 10000000000000000000000000000000000000, -]; - #[cfg(test)] mod test { use std::sync::Arc; use daft_core::prelude::Schema; - use daft_dsl::{col, lit, null_lit, ExprRef}; + use daft_dsl::{col, lit}; use daft_schema::{dtype::DataType, field::Field}; - use rstest::rstest; use super::SimplifyExpressionsRule; use crate::{ @@ -436,112 +83,6 @@ mod test { ) } - #[rstest] - // true = A --> A - #[case(col("bool").eq(lit(true)), col("bool"))] - // false = A --> !A - #[case(col("bool").eq(lit(false)), col("bool").not())] - // A == true ---> A - #[case(col("bool").eq(lit(true)), col("bool"))] - // null = A --> null - #[case(null_lit().eq(col("bool")), null_lit())] - // A == false ---> !A - #[case(col("bool").eq(lit(false)), col("bool").not())] - // true != A --> !A - #[case(lit(true).not_eq(col("bool")), col("bool").not())] - // false != A --> A - #[case(lit(false).not_eq(col("bool")), col("bool"))] - // true OR A --> true - #[case(lit(true).or(col("bool")), lit(true))] - // false OR A --> A - #[case(lit(false).or(col("bool")), col("bool"))] - // A OR true --> true - #[case(col("bool").or(lit(true)), lit(true))] - // A OR false --> A - #[case(col("bool").or(lit(false)), col("bool"))] - fn test_simplify_bool_exprs(#[case] input: ExprRef, #[case] expected: ExprRef) { - let source = make_source().filter(input).unwrap().build(); - let optimizer = SimplifyExpressionsRule::new(); - let optimized = optimizer.try_optimize(source).unwrap(); - - let LogicalPlan::Filter(Filter { predicate, .. }) = optimized.data.as_ref() else { - panic!("Expected Filter, got {:?}", optimized.data) - }; - - // make sure the expression is simplified - assert!(optimized.transformed); - - assert_eq!(predicate, &expected); - } - - #[rstest] - // A * 1 --> A - #[case(col("int").mul(lit(1)), col("int"))] - // 1 * A --> A - #[case(lit(1).mul(col("int")), col("int"))] - // A / 1 --> A - #[case(col("int").div(lit(1)), col("int"))] - // A + 0 --> A - #[case(col("int").add(lit(0)), col("int"))] - // A - 0 --> A - #[case(col("int").sub(lit(0)), col("int"))] - fn test_math_exprs(#[case] input: ExprRef, #[case] expected: ExprRef) { - let source = make_source().select(vec![input]).unwrap().build(); - let optimizer = SimplifyExpressionsRule::new(); - let optimized = optimizer.try_optimize(source).unwrap(); - - let LogicalPlan::Project(Project { projection, .. 
}) = optimized.data.as_ref() else { - panic!("Expected Filter, got {:?}", optimized.data) - }; - - let projection = projection.first().unwrap(); - - // make sure the expression is simplified - assert!(optimized.transformed); - - assert_eq!(projection, &expected); - } - - #[test] - fn test_not_between() { - let source = make_source() - .filter(col("int").between(lit(1), lit(10)).not()) - .unwrap() - .build(); - let optimizer = SimplifyExpressionsRule::new(); - let optimized = optimizer.try_optimize(source).unwrap(); - - let LogicalPlan::Filter(Filter { predicate, .. }) = optimized.data.as_ref() else { - panic!("Expected Filter, got {:?}", optimized.data) - }; - - // make sure the expression is simplified - assert!(optimized.transformed); - - assert_eq!(predicate, &col("int").lt(lit(1)).or(col("int").gt(lit(10)))); - } - - #[test] - fn test_between() { - let source = make_source() - .filter(col("int").between(lit(1), lit(10))) - .unwrap() - .build(); - let optimizer = SimplifyExpressionsRule::new(); - let optimized = optimizer.try_optimize(source).unwrap(); - - let LogicalPlan::Filter(Filter { predicate, .. }) = optimized.data.as_ref() else { - panic!("Expected Filter, got {:?}", optimized.data) - }; - - // make sure the expression is simplified - assert!(optimized.transformed); - - assert_eq!( - predicate, - &col("int").lt_eq(lit(10)).and(col("int").gt_eq(lit(1))) - ); - } #[test] fn test_nested_plan() { let source = make_source() diff --git a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs index 3413e8cc53..5039cc9767 100644 --- a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs +++ b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs @@ -2,12 +2,9 @@ use std::{collections::HashSet, sync::Arc}; use common_error::{DaftError, DaftResult}; use common_treenode::{DynTreeNode, Transformed, TreeNode}; +use daft_algebra::boolean::{combine_conjunction, split_conjunction}; use daft_core::{join::JoinType, prelude::SchemaRef}; -use daft_dsl::{ - col, - optimization::{conjuct, split_conjuction}, - Expr, ExprRef, Operator, Subquery, -}; +use daft_dsl::{col, Expr, ExprRef, Operator, Subquery}; use itertools::multiunzip; use uuid::Uuid; @@ -73,12 +70,12 @@ impl UnnestScalarSubquery { impl UnnestScalarSubquery { fn unnest_subqueries( input: LogicalPlanRef, - exprs: Vec<&ExprRef>, + exprs: &[ExprRef], ) -> DaftResult)>> { let mut subqueries = HashSet::new(); let new_exprs = exprs - .into_iter() + .iter() .map(|expr| { expr.clone() .transform_down(|e| { @@ -164,7 +161,7 @@ impl OptimizerRule for UnnestScalarSubquery { input, predicate, .. }) => { let unnest_result = - Self::unnest_subqueries(input.clone(), split_conjuction(predicate))?; + Self::unnest_subqueries(input.clone(), &split_conjunction(predicate))?; if !unnest_result.transformed { return Ok(Transformed::no(node)); @@ -172,7 +169,7 @@ impl OptimizerRule for UnnestScalarSubquery { let (new_input, new_predicates) = unnest_result.data; - let new_predicate = conjuct(new_predicates) + let new_predicate = combine_conjunction(new_predicates) .expect("predicates are guaranteed to exist at this point, so 'conjunct' should never return 'None'"); let new_filter = Arc::new(LogicalPlan::Filter(Filter::try_new( @@ -192,7 +189,7 @@ impl OptimizerRule for UnnestScalarSubquery { input, projection, .. 
}) => { let unnest_result = - Self::unnest_subqueries(input.clone(), projection.iter().collect())?; + Self::unnest_subqueries(input.clone(), projection)?; if !unnest_result.transformed { return Ok(Transformed::no(node)); @@ -275,7 +272,7 @@ impl OptimizerRule for UnnestPredicateSubquery { }) => { let mut subqueries = HashSet::new(); - let new_predicates = split_conjuction(predicate) + let new_predicates = split_conjunction(predicate) .into_iter() .filter(|expr| { match expr.as_ref() { @@ -303,7 +300,6 @@ impl OptimizerRule for UnnestPredicateSubquery { _ => true } }) - .cloned() .collect::>(); if subqueries.is_empty() { @@ -345,7 +341,7 @@ impl OptimizerRule for UnnestPredicateSubquery { )?))) })?; - let new_plan = if let Some(new_predicate) = conjuct(new_predicates) { + let new_plan = if let Some(new_predicate) = combine_conjunction(new_predicates) { // add filter back if there are non-subquery predicates Arc::new(LogicalPlan::Filter(Filter::try_new( new_input, @@ -387,7 +383,7 @@ fn pull_up_correlated_cols( }) => { let mut found_correlated_col = false; - let preds = split_conjuction(predicate) + let preds = split_conjunction(predicate) .into_iter() .filter(|expr| { if let Expr::BinaryOp { @@ -418,7 +414,6 @@ fn pull_up_correlated_cols( true }) - .cloned() .collect::>(); // no new correlated cols found @@ -426,7 +421,7 @@ fn pull_up_correlated_cols( return Ok((plan.clone(), subquery_on, outer_on)); } - if let Some(new_predicate) = conjuct(preds) { + if let Some(new_predicate) = combine_conjunction(preds) { let new_plan = Arc::new(LogicalPlan::Filter(Filter::try_new( input.clone(), new_predicate, diff --git a/src/daft-sql/Cargo.toml b/src/daft-sql/Cargo.toml index 6e45c23741..a402235011 100644 --- a/src/daft-sql/Cargo.toml +++ b/src/daft-sql/Cargo.toml @@ -3,6 +3,7 @@ common-daft-config = {path = "../common/daft-config"} common-error = {path = "../common/error"} common-io-config = {path = "../common/io-config", default-features = false} common-runtime = {workspace = true} +daft-algebra = {path = "../daft-algebra"} daft-core = {path = "../daft-core"} daft-dsl = {path = "../daft-dsl"} daft-functions = {path = "../daft-functions"} diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 683391b601..ce2ef703a3 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -6,13 +6,13 @@ use std::{ }; use common_error::{DaftError, DaftResult}; +use daft_algebra::boolean::combine_conjunction; use daft_core::prelude::*; use daft_dsl::{ col, common_treenode::{Transformed, TreeNode}, - has_agg, lit, literals_to_series, null_lit, - optimization::conjuct, - AggExpr, Expr, ExprRef, LiteralValue, Operator, OuterReferenceColumn, Subquery, + has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, Operator, + OuterReferenceColumn, Subquery, }; use daft_functions::{ numeric::{ceil::ceil, floor::floor}, @@ -959,12 +959,12 @@ impl<'a> SQLPlanner<'a> { }; let mut left_plan = self.current_relation.as_ref().unwrap().inner.clone(); - if let Some(left_predicate) = conjuct(left_filters) { + if let Some(left_predicate) = combine_conjunction(left_filters) { left_plan = left_plan.filter(left_predicate)?; } let mut right_plan = right_rel.inner.clone(); - if let Some(right_predicate) = conjuct(right_filters) { + if let Some(right_predicate) = combine_conjunction(right_filters) { right_plan = right_plan.filter(right_predicate)?; }
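For reference, a minimal usage sketch of the relocated daft-algebra helpers (not part of this diff). It assumes the full signatures introduced in src/daft-algebra above: split_conjunction(&ExprRef) -> Vec<ExprRef>, combine_conjunction over an IntoIterator<Item = ExprRef> returning Option<ExprRef>, and simplify_expr(Expr, &SchemaRef) -> DaftResult<Transformed<ExprRef>>.

// Illustrative sketch only; mirrors how the optimizer rules and tests above call the new crate.
use std::sync::Arc;

use common_error::DaftResult;
use daft_algebra::{
    boolean::{combine_conjunction, split_conjunction},
    simplify_expr,
};
use daft_dsl::{col, lit, ExprRef};
use daft_schema::{dtype::DataType, field::Field, schema::Schema};

fn demo() -> DaftResult<()> {
    // (a AND b) AND (c = 1) splits into its three conjuncts.
    let pred = col("a").and(col("b")).and(col("c").eq(lit(1)));
    let conjuncts: Vec<ExprRef> = split_conjunction(&pred);
    assert_eq!(conjuncts.len(), 3);

    // combine_conjunction folds the pieces back together with AND
    // (it returns None for an empty iterator).
    let recombined = combine_conjunction(conjuncts).expect("non-empty input");
    assert_eq!(recombined, pred);

    // simplify_expr rewrites an expression against a schema, e.g. (x = true) --> x.
    let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Boolean)])?);
    let simplified = simplify_expr(Arc::unwrap_or_clone(col("x").eq(lit(true))), &schema)?;
    assert!(simplified.transformed);
    assert_eq!(simplified.data, col("x"));
    Ok(())
}

Splitting, regrouping, and re-ANDing conjuncts this way preserves predicate semantics, which is what lets push_down_filter.rs and unnest_subquery.rs partition clauses between operators and reassemble them with combine_conjunction.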