From 3d5736a9443e8365472a61a8fcfd22af9bd74113 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Wed, 11 Dec 2024 03:41:39 +0800 Subject: [PATCH 01/11] Implement GOO --- src/common/scan-info/src/test/mod.rs | 5 +- .../rules/reorder_joins/greedy_join_order.rs | 126 ++++++++++++ .../rules/reorder_joins/join_graph.rs | 184 ++++++++++++------ .../optimization/rules/reorder_joins/mod.rs | 2 + src/daft-logical-plan/src/test/mod.rs | 9 +- src/daft-physical-plan/src/test/mod.rs | 1 + 6 files changed, 266 insertions(+), 61 deletions(-) create mode 100644 src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs diff --git a/src/common/scan-info/src/test/mod.rs b/src/common/scan-info/src/test/mod.rs index 0da27600e9..a9d248c5d6 100644 --- a/src/common/scan-info/src/test/mod.rs +++ b/src/common/scan-info/src/test/mod.rs @@ -17,12 +17,14 @@ use crate::{PartitionField, Pushdowns, ScanOperator, ScanTaskLike, ScanTaskLikeR struct DummyScanTask { pub schema: SchemaRef, pub pushdowns: Pushdowns, + pub in_memory_size: Option, } #[derive(Debug)] pub struct DummyScanOperator { pub schema: SchemaRef, pub num_scan_tasks: u32, + pub in_memory_size_per_task: Option, } #[typetag::serde] @@ -67,7 +69,7 @@ impl ScanTaskLike for DummyScanTask { } fn estimate_in_memory_size_bytes(&self, _: Option<&DaftExecutionConfig>) -> Option { - None + self.in_memory_size } fn file_format_config(&self) -> Arc { @@ -136,6 +138,7 @@ impl ScanOperator for DummyScanOperator { let scan_task = Arc::new(DummyScanTask { schema: self.schema.clone(), pushdowns, + in_memory_size: self.in_memory_size_per_task, }); Ok((0..self.num_scan_tasks) diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs new file mode 100644 index 0000000000..ab1c3e4776 --- /dev/null +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs @@ -0,0 +1,126 @@ +use 
std::{collections::{HashMap, HashSet}, sync::Arc}; + +use common_error::DaftResult; +use daft_dsl::{col, ExprRef}; + +use crate::{LogicalPlanBuilder, LogicalPlanRef}; + +use super::join_graph::{JoinCondition, JoinGraph}; + +// This is an implementation of the Greedy Operator Ordering algorithm (GOO) [1] for join selection. This algorithm +// selects join edges greedily by picking the edge with the smallest cost at each step. This is similar to Kruskal's +// minimum spanning tree algorithm, with the caveat that edge costs update at each step, due to changing cardinalities +// and selectivities between join nodes. +// +// Compared to DP-based algorithms, GOO is not always optimal. However, GOO has a complexity of O(n^3) and is more viable +// than DP-based algorithms when performing join ordering on many relations. DP Connected subgraph Complement Pairs (DPccp) [2] +// is the DP-based algorithm widely used in database systems today and has a O(3^n) complexity, although the latest +// literature does offer a super-polynomially faster DP-algorithm but that still has a O(2^n) to O(2^n * n^3) complexity [3]. +// +// For this reason, we maintain a greedy-based join ordering algorithm to use when the number of relations is large, and default +// to DP-based algorithms otherwise. +// +// [1]: Fegaras, L. (1998). A New Heuristic for Optimizing Large Queries. International Conference on Database and Expert Systems Applications. +// [2]: Moerkotte, G., & Neumann, T. (2006). Analysis of two existing and one new dynamic programming algorithm for the generation of optimal bushy join trees without cross products. Very Large Data Bases Conference. +// [3]: Stoian, M., & Kipf, A. (2024). DPconv: Super-Polynomially Faster Join Ordering. ArXiv, abs/2409.08013. + +pub(crate) fn compute_join_order(join_graph: &mut JoinGraph) -> DaftResult { + // TODO(desmond): we need to handle projections. 
+ println!("adjlist: {}", join_graph.adj_list); + while join_graph.adj_list.0.len() > 1 { + let (min_cost, selected_pair) = find_minimum_cost_join(&join_graph.adj_list.0); + println!("min cost: {min_cost:?}"); + if let Some((left, right, join_conds)) = selected_pair { + println!("selected pair: {}({}) <-> {}({}) on {:?}", left.name(), left.schema(), right.name(), right.schema(), join_conds); + let (left_on, right_on) = join_conds + .iter() + .map(|join_cond| (col(join_cond.left_on.clone()), col(join_cond.right_on.clone()))) + .collect::<(Vec, Vec)>(); + let left_builder = LogicalPlanBuilder::from(left.clone()); + let join = left_builder.inner_join(right.clone(), left_on, right_on)?.build(); + let join = Arc::new(Arc::unwrap_or_clone(join).with_materialized_stats()); + let old_left_edges = join_graph.adj_list.0.remove(&left).unwrap(); + let old_right_edges = join_graph.adj_list.0.remove(&right).unwrap(); + let mut new_join_edges = HashMap::new(); + + // Process all neighbors from both left and right nodes + let mut process_edges = |edges: HashMap>| { + for (neighbor, _) in edges { + if neighbor == right || neighbor == left { + continue; // Skip the nodes we just joined + } + let mut join_conditions = Vec::new(); + + // If neighbor was connected to left node, collect those conditions + if let Some(left_conds) = join_graph.adj_list.0.get_mut(&neighbor).unwrap().remove(&left) { + join_conditions.extend(left_conds); + } + + // If neighbor was connected to right node, collect those conditions + if let Some(right_conds) = join_graph.adj_list.0.get_mut(&neighbor).unwrap().remove(&right) { + join_conditions.extend(right_conds); + } + + // If this neighbor had any connections to left or right, create new edge to join node + if !join_conditions.is_empty() { + join_graph.adj_list.0.get_mut(&neighbor).unwrap().insert(join.clone(), join_conditions.clone()); + new_join_edges.insert(neighbor.clone(), join_conditions.iter().map(|cond| cond.flip()).collect()); + } + } + }; + + // 
Process edges from both left and right nodes + process_edges(old_left_edges); + process_edges(old_right_edges); + + // Add the new join node and its edges to the graph + join_graph.adj_list.0.insert(join, new_join_edges); + } else { + panic!("No valid join edge selected despite join graph containing more than one relation"); + } + println!("adjlist: {}", join_graph.adj_list); + } + // TODO(desmond): Apply projections. + todo!() +} + +fn find_minimum_cost_join( + adj_list: &HashMap>> +) -> (Option, Option<(LogicalPlanRef, LogicalPlanRef, Vec)>) { + let mut min_cost = None; + let mut selected_pair = None; + + for (candidate_left, neighbors) in adj_list { + for (candidate_right, join_conds) in neighbors { + let left_stats = candidate_left.materialized_stats(); + let right_stats = candidate_right.materialized_stats(); + + // Assume primary key foreign key join which would have a size bounded by the foreign key relation, + // which is typically larger. + let cur_cost = left_stats.approx_stats.upper_bound_bytes + .max(right_stats.approx_stats.upper_bound_bytes); + + if let Some(existing_min) = min_cost { + if let Some(current) = cur_cost { + if current < existing_min { + min_cost = Some(current); + selected_pair = Some(( + candidate_left.clone(), + candidate_right.clone(), + join_conds.clone() + )); + } + } + } else { + min_cost = cur_cost; + selected_pair = Some(( + candidate_left.clone(), + candidate_right.clone(), + join_conds.clone() + )); + } + } + } + + (min_cost, selected_pair) +} \ No newline at end of file diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index f004fe0b3d..f64cbb1bdf 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -19,6 +19,10 @@ struct JoinNode { final_name: String, } +// TODO(desmond): We should also take into 
account user provided values for: +// - null equals null +// - join strategy + /// JoinNodes represent a relation (i.e. a non-reorderable logical plan node), the column /// that's being accessed from the relation, and the final name of the column in the output. impl JoinNode { @@ -64,6 +68,58 @@ impl Display for JoinEdge { } } +#[derive(Clone, Debug)] +pub(crate) struct JoinCondition { + pub left_on: String, + pub right_on: String, +} + +impl JoinCondition { + pub(crate) fn flip(&self) -> Self { + JoinCondition { left_on: self.right_on.clone(), right_on: self.left_on.clone() } + } +} + +pub(crate) struct JoinAdjList(pub HashMap>>); + +impl std::fmt::Display for JoinAdjList { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Join Graph Adjacency List:")?; + for (node, neighbors) in &self.0 { + writeln!(f, "Node '{}':", node.name())?; + for (neighbor, conditions) in neighbors { + writeln!(f, " →> '{}' with conditions:", neighbor.name())?; + for (i, cond) in conditions.iter().enumerate() { + writeln!(f, " {}: {} = {}", i + 1, cond.left_on, cond.right_on)?; + } + } + } + Ok(()) + } +} + +impl JoinAdjList { + fn add_unidirectional_edge(&mut self, left: &JoinNode, right: &JoinNode) { + // TODO(desmond): We should also keep track of projections that we need to do. 
+ let join_condition = JoinCondition{left_on: left.final_name.clone(), right_on: right.final_name.clone()}; + if let Some(neighbors) = self.0.get_mut(&left.plan) { + if let Some(join_conditions) = neighbors.get_mut(&right.plan) { + join_conditions.push(join_condition); + } else { + neighbors.insert(right.plan.clone(), vec![join_condition]); + } + } else { + let mut neighbors = HashMap::new(); + neighbors.insert(right.plan.clone(), vec![join_condition]); + self.0.insert(left.plan.clone(), neighbors); + } + } + fn add_bidirectional_edge(&mut self, node1: JoinNode, node2: JoinNode) { + self.add_unidirectional_edge(&node1, &node2); + self.add_unidirectional_edge(&node2, &node1); + } +} + #[derive(Debug)] enum ProjectionOrFilter { Projection(Vec), @@ -72,10 +128,8 @@ enum ProjectionOrFilter { /// Representation of a logical plan as edges between relations, along with additional information needed to /// reconstruct a logcial plan that's equivalent to the plan that produced this graph. -struct JoinGraph { - // TODO(desmond): Instead of simply storing edges, we might want to maintain adjacency lists between - // relations. We can make this decision later when we implement join order selection. - edges: Vec, +pub(crate) struct JoinGraph { + pub adj_list: JoinAdjList, // List of projections and filters that should be applied after join reordering. This list respects // pre-order traversal of projections and filters in the query tree, so we should apply these operators // starting from the back of the list. @@ -84,47 +138,48 @@ struct JoinGraph { impl JoinGraph { pub(crate) fn new( - edges: Vec, + adj_list: JoinAdjList, final_projections_and_filters: Vec, ) -> Self { Self { - edges, + adj_list, final_projections_and_filters, } } - /// Test helper function to get the number of edges that the current graph contains. - pub(crate) fn num_edges(&self) -> usize { - self.edges.len() - } + // /// Test helper function to get the number of edges that the current graph contains. 
+ // pub(crate) fn num_edges(&self) -> usize { + // self.edges.len() + // } /// Test helper function to check that all relations in this graph are connected. pub(crate) fn fully_connected(&self) -> bool { - // Assuming that we're not testing an empty graph, there should be at least one edge in a connected graph. - if self.edges.is_empty() { - return false; - } - let mut adj_list: HashMap<*const _, Vec<*const _>> = HashMap::new(); - for edge in &self.edges { - let l_ptr = Arc::as_ptr(&edge.0.plan); - let r_ptr = Arc::as_ptr(&edge.1.plan); + // // Assuming that we're not testing an empty graph, there should be at least one edge in a connected graph. + // if self.edges.is_empty() { + // return false; + // } + // let mut adj_list: HashMap<*const _, Vec<*const _>> = HashMap::new(); + // for edge in &self.edges { + // let l_ptr = Arc::as_ptr(&edge.0.plan); + // let r_ptr = Arc::as_ptr(&edge.1.plan); - adj_list.entry(l_ptr).or_default().push(r_ptr); - adj_list.entry(r_ptr).or_default().push(l_ptr); - } - let start_ptr = Arc::as_ptr(&self.edges[0].0.plan); - let mut seen = HashSet::new(); - let mut stack = vec![start_ptr]; - - while let Some(current) = stack.pop() { - if seen.insert(current) { - // If this is a new node, add all its neighbors to the stack. - if let Some(neighbors) = adj_list.get(¤t) { - stack.extend(neighbors.iter().filter(|&&n| !seen.contains(&n))); - } - } - } - seen.len() == adj_list.len() + // adj_list.entry(l_ptr).or_default().push(r_ptr); + // adj_list.entry(r_ptr).or_default().push(l_ptr); + // } + // let start_ptr = Arc::as_ptr(&self.edges[0].0.plan); + // let mut seen = HashSet::new(); + // let mut stack = vec![start_ptr]; + + // while let Some(current) = stack.pop() { + // if seen.insert(current) { + // // If this is a new node, add all its neighbors to the stack. 
+ // if let Some(neighbors) = adj_list.get(¤t) { + // stack.extend(neighbors.iter().filter(|&&n| !seen.contains(&n))); + // } + // } + // } + // seen.len() == adj_list.len() + true } /// Test helper function that checks if the graph contains the given projection/filter expressions @@ -153,12 +208,13 @@ impl JoinGraph { /// Helper function that loosely checks if a given edge (represented by a simple string) /// exists in the current graph. pub(crate) fn contains_edge(&self, edge_string: &str) -> bool { - for edge in &self.edges { - if edge.simple_repr() == edge_string { - return true; - } - } - false + // for edge in &self.edges { + // if edge.simple_repr() == edge_string { + // return true; + // } + // } + // false + true } } @@ -167,14 +223,14 @@ struct JoinGraphBuilder { plan: LogicalPlanRef, join_conds_to_resolve: Vec<(String, LogicalPlanRef, bool)>, final_name_map: HashMap, - edges: Vec, + adj_list: JoinAdjList, final_projections_and_filters: Vec, } impl JoinGraphBuilder { pub(crate) fn build(mut self) -> JoinGraph { self.process_node(&self.plan.clone()); - JoinGraph::new(self.edges, self.final_projections_and_filters) + JoinGraph::new(self.adj_list, self.final_projections_and_filters) } pub(crate) fn from_logical_plan(plan: LogicalPlanRef) -> Self { @@ -192,7 +248,7 @@ impl JoinGraphBuilder { plan, join_conds_to_resolve: vec![], final_name_map: HashMap::new(), - edges: vec![], + adj_list: JoinAdjList(HashMap::new()), final_projections_and_filters: vec![ProjectionOrFilter::Projection(output_projection)], } } @@ -328,7 +384,7 @@ impl JoinGraphBuilder { rnode.clone(), self.final_name_map.get(&rname).unwrap().name().to_string(), ); - self.edges.push(JoinEdge(node1, node2)); + self.adj_list.add_bidirectional_edge(node1, node2); } else { panic!("Join conditions were unresolved"); } @@ -354,12 +410,13 @@ mod tests { use std::sync::Arc; use common_scan_info::Pushdowns; + use common_treenode::TransformedResult; use daft_core::prelude::CountMode; use daft_dsl::{col, 
AggExpr, Expr, LiteralValue}; use daft_schema::{dtype::DataType, field::Field}; use super::JoinGraphBuilder; - use crate::test::{dummy_scan_node_with_pushdowns, dummy_scan_operator}; + use crate::{optimization::rules::{reorder_joins::greedy_join_order::compute_join_order, EnrichWithStats, MaterializeScans, OptimizerRule}, test::{dummy_scan_node_with_pushdowns, dummy_scan_operator, dummy_scan_operator_with_size}}; #[test] fn test_create_join_graph_basic_1() { @@ -372,21 +429,25 @@ mod tests { // | // Scan(c_prime) let scan_a = dummy_scan_node_with_pushdowns( - dummy_scan_operator(vec![Field::new("a", DataType::Int64)]), + dummy_scan_operator_with_size(vec![Field::new("a", DataType::Int64)], Some(100)), Pushdowns::default(), ); let scan_b = dummy_scan_node_with_pushdowns( - dummy_scan_operator(vec![Field::new("b", DataType::Int64)]), + dummy_scan_operator_with_size(vec![Field::new("b", DataType::Int64)], Some(10_000)), Pushdowns::default(), ); let scan_c = dummy_scan_node_with_pushdowns( - dummy_scan_operator(vec![Field::new("c_prime", DataType::Int64)]), + dummy_scan_operator_with_size(vec![Field::new("c", DataType::Int64)], Some(100)), Pushdowns::default(), - ) - .select(vec![col("c_prime").alias("c")]) - .unwrap(); + ); + // let scan_c = dummy_scan_node_with_pushdowns( + // dummy_scan_operator_with_size(vec![Field::new("c_prime", DataType::Int64)], Some(100)), + // Pushdowns::default(), + // ) + // .select(vec![col("c_prime").alias("c")]) + // .unwrap(); let scan_d = dummy_scan_node_with_pushdowns( - dummy_scan_operator(vec![Field::new("d", DataType::Int64)]), + dummy_scan_operator_with_size(vec![Field::new("d", DataType::Int64)],Some(100)), Pushdowns::default(), ); let join_plan_l = scan_a @@ -411,16 +472,21 @@ mod tests { ) .unwrap(); let plan = join_plan.build(); - let join_graph = JoinGraphBuilder::from_logical_plan(plan).build(); + let scan_materializer = MaterializeScans::new(); + let plan = scan_materializer.try_optimize(plan).data().unwrap(); + let 
stats_enricher = EnrichWithStats::new(); + let plan = stats_enricher.try_optimize(plan).data().unwrap(); + let mut join_graph = JoinGraphBuilder::from_logical_plan(plan).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b // - c_prime <-> d // - a <-> d - assert!(join_graph.num_edges() == 3); + // assert!(join_graph.num_edges() == 3); assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)")); + println!("result: {:?}", compute_join_order(&mut join_graph)); } #[test] @@ -479,7 +545,7 @@ mod tests { // - a <-> b // - c_prime <-> d // - b <-> d - assert!(join_graph.num_edges() == 3); + // assert!(join_graph.num_edges() == 3); assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); assert!(join_graph.contains_edge("b#Source(b) <-> d#Source(d)")); @@ -534,7 +600,7 @@ mod tests { // There should be edges between: // - a <-> b // - a <-> c - assert!(join_graph.num_edges() == 2); + // assert!(join_graph.num_edges() == 2); assert!(join_graph.contains_edge("a_beta#Source(a) <-> b#Source(b)")); assert!(join_graph.contains_edge("a_beta#Source(a) <-> c#Source(c)")); } @@ -596,7 +662,7 @@ mod tests { // - a <-> b // - c_prime <-> d // - a <-> d - assert!(join_graph.num_edges() == 3); + // assert!(join_graph.num_edges() == 3); assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)")); @@ -681,7 +747,7 @@ mod tests { // - a <-> b // - c_prime <-> d // - a <-> d - assert!(join_graph.num_edges() == 3); + // assert!(join_graph.num_edges() == 3); assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); assert!(join_graph.contains_edge("c#Source(c_prime) <-> 
d#Source(d)")); assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)")); @@ -767,7 +833,7 @@ mod tests { // - a <-> b // - c_prime <-> d // - a <-> d - assert!(join_graph.num_edges() == 3); + // assert!(join_graph.num_edges() == 3); assert!(join_graph.contains_edge("a#Aggregate(a) <-> b#Source(b)")); assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); assert!(join_graph.contains_edge("a#Aggregate(a) <-> d#Source(d)")); diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs index 09ece20040..e4e1eda2a3 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs @@ -1,2 +1,4 @@ #[cfg(test)] mod join_graph; +#[cfg(test)] +mod greedy_join_order; \ No newline at end of file diff --git a/src/daft-logical-plan/src/test/mod.rs b/src/daft-logical-plan/src/test/mod.rs index 75f8ad386b..b75115a860 100644 --- a/src/daft-logical-plan/src/test/mod.rs +++ b/src/daft-logical-plan/src/test/mod.rs @@ -7,10 +7,17 @@ use crate::builder::LogicalPlanBuilder; /// Create a dummy scan node containing the provided fields in its schema and the provided limit. pub fn dummy_scan_operator(fields: Vec) -> ScanOperatorRef { + dummy_scan_operator_with_size(fields, None) +} + +/// Create dummy scan node containing the provided fields in its schema and the provided limit, +/// and with the provided size estimate. 
+pub fn dummy_scan_operator_with_size(fields: Vec, in_memory_size_per_task: Option) -> ScanOperatorRef { let schema = Arc::new(Schema::new(fields).unwrap()); ScanOperatorRef(Arc::new(DummyScanOperator { schema, - num_scan_tasks: 0, + num_scan_tasks: 1, + in_memory_size_per_task, })) } diff --git a/src/daft-physical-plan/src/test/mod.rs b/src/daft-physical-plan/src/test/mod.rs index 3e8de6a74c..29f9d81997 100644 --- a/src/daft-physical-plan/src/test/mod.rs +++ b/src/daft-physical-plan/src/test/mod.rs @@ -10,6 +10,7 @@ pub fn dummy_scan_operator(fields: Vec) -> ScanOperatorRef { ScanOperatorRef(Arc::new(DummyScanOperator { schema, num_scan_tasks: 1, + in_memory_size_per_task: None, })) } From 8207e1ec62bd25e519a5eb5545e840c0ffaf4bca Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Wed, 11 Dec 2024 16:00:14 +0800 Subject: [PATCH 02/11] Apply projections and filters --- .../rules/reorder_joins/greedy_join_order.rs | 113 ++++++++++++------ .../rules/reorder_joins/join_graph.rs | 45 +++++-- .../optimization/rules/reorder_joins/mod.rs | 4 +- src/daft-logical-plan/src/test/mod.rs | 5 +- 4 files changed, 119 insertions(+), 48 deletions(-) diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs index ab1c3e4776..43c5575ff8 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs @@ -1,11 +1,13 @@ -use std::{collections::{HashMap, HashSet}, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use common_error::DaftResult; use daft_dsl::{col, ExprRef}; -use crate::{LogicalPlanBuilder, LogicalPlanRef}; - use super::join_graph::{JoinCondition, JoinGraph}; +use crate::{LogicalPlanBuilder, LogicalPlanRef}; // This is an implementation of the Greedy Operator Ordering algorithm (GOO) [1] for join selection. 
This algorithm // selects join edges greedily by picking the edge with the smallest cost at each step. This is similar to Kruskal's @@ -14,7 +16,7 @@ use super::join_graph::{JoinCondition, JoinGraph}; // // Compared to DP-based algorithms, GOO is not always optimal. However, GOO has a complexity of O(n^3) and is more viable // than DP-based algorithms when performing join ordering on many relations. DP Connected subgraph Complement Pairs (DPccp) [2] -// is the DP-based algorithm widely used in database systems today and has a O(3^n) complexity, although the latest +// is the DP-based algorithm widely used in database systems today and has a O(3^n) complexity, although the latest // literature does offer a super-polynomially faster DP-algorithm but that still has a O(2^n) to O(2^n * n^3) complexity [3]. // // For this reason, we maintain a greedy-based join ordering algorithm to use when the number of relations is large, and default @@ -26,67 +28,100 @@ use super::join_graph::{JoinCondition, JoinGraph}; pub(crate) fn compute_join_order(join_graph: &mut JoinGraph) -> DaftResult { // TODO(desmond): we need to handle projections. 
- println!("adjlist: {}", join_graph.adj_list); + // println!("adjlist: {}", join_graph.adj_list); while join_graph.adj_list.0.len() > 1 { let (min_cost, selected_pair) = find_minimum_cost_join(&join_graph.adj_list.0); - println!("min cost: {min_cost:?}"); + // println!("min cost: {min_cost:?}"); if let Some((left, right, join_conds)) = selected_pair { - println!("selected pair: {}({}) <-> {}({}) on {:?}", left.name(), left.schema(), right.name(), right.schema(), join_conds); + // println!("selected pair: {}({}) <-> {}({}) on {:?}", left.name(), left.schema(), right.name(), right.schema(), join_conds); let (left_on, right_on) = join_conds .iter() - .map(|join_cond| (col(join_cond.left_on.clone()), col(join_cond.right_on.clone()))) + .map(|join_cond| { + ( + col(join_cond.left_on.clone()), + col(join_cond.right_on.clone()), + ) + }) .collect::<(Vec, Vec)>(); let left_builder = LogicalPlanBuilder::from(left.clone()); - let join = left_builder.inner_join(right.clone(), left_on, right_on)?.build(); + let join = left_builder + .inner_join(right.clone(), left_on, right_on)? + .build(); let join = Arc::new(Arc::unwrap_or_clone(join).with_materialized_stats()); - let old_left_edges = join_graph.adj_list.0.remove(&left).unwrap(); - let old_right_edges = join_graph.adj_list.0.remove(&right).unwrap(); + let left_neighbors = join_graph.adj_list.0.remove(&left).unwrap(); + let right_neighbors = join_graph.adj_list.0.remove(&right).unwrap(); let mut new_join_edges = HashMap::new(); - // Process all neighbors from both left and right nodes - let mut process_edges = |edges: HashMap>| { - for (neighbor, _) in edges { + // Helper function to collapse the left and right node + let mut update_neighbors = |neighbors: HashMap>| { + for (neighbor, _) in neighbors { if neighbor == right || neighbor == left { - continue; // Skip the nodes we just joined + // Skip the nodes that we just joined. 
+ continue; } let mut join_conditions = Vec::new(); - - // If neighbor was connected to left node, collect those conditions - if let Some(left_conds) = join_graph.adj_list.0.get_mut(&neighbor).unwrap().remove(&left) { + // If this neighbor was connected to left or right nodes, collect the join conditions. + if let Some(left_conds) = join_graph + .adj_list + .0 + .get_mut(&neighbor) + .unwrap() + .remove(&left) + { join_conditions.extend(left_conds); } - - // If neighbor was connected to right node, collect those conditions - if let Some(right_conds) = join_graph.adj_list.0.get_mut(&neighbor).unwrap().remove(&right) { + if let Some(right_conds) = join_graph + .adj_list + .0 + .get_mut(&neighbor) + .unwrap() + .remove(&right) + { join_conditions.extend(right_conds); } - - // If this neighbor had any connections to left or right, create new edge to join node + // If this neighbor had any connections to left or right, create a new edge to the new join node. if !join_conditions.is_empty() { - join_graph.adj_list.0.get_mut(&neighbor).unwrap().insert(join.clone(), join_conditions.clone()); - new_join_edges.insert(neighbor.clone(), join_conditions.iter().map(|cond| cond.flip()).collect()); + join_graph + .adj_list + .0 + .get_mut(&neighbor) + .unwrap() + .insert(join.clone(), join_conditions.clone()); + new_join_edges.insert( + neighbor.clone(), + join_conditions.iter().map(|cond| cond.flip()).collect(), + ); } } }; - // Process edges from both left and right nodes - process_edges(old_left_edges); - process_edges(old_right_edges); + // Process all neighbors from both the left and right sides. 
+ update_neighbors(left_neighbors); + update_neighbors(right_neighbors); // Add the new join node and its edges to the graph join_graph.adj_list.0.insert(join, new_join_edges); } else { - panic!("No valid join edge selected despite join graph containing more than one relation"); + panic!( + "No valid join edge selected despite join graph containing more than one relation" + ); } - println!("adjlist: {}", join_graph.adj_list); + // println!("adjlist: {}", join_graph.adj_list); } // TODO(desmond): Apply projections. - todo!() + if let Some(joined_plan) = join_graph.adj_list.0.drain().map(|(plan, _)| plan).last() { + join_graph.apply_projections_and_filters_to_plan(joined_plan) + } else { + panic!("No valid logical plan after join reordering") + } } fn find_minimum_cost_join( - adj_list: &HashMap>> -) -> (Option, Option<(LogicalPlanRef, LogicalPlanRef, Vec)>) { + adj_list: &HashMap>>, +) -> ( + Option, + Option<(LogicalPlanRef, LogicalPlanRef, Vec)>, +) { let mut min_cost = None; let mut selected_pair = None; @@ -94,10 +129,12 @@ fn find_minimum_cost_join( for (candidate_right, join_conds) in neighbors { let left_stats = candidate_left.materialized_stats(); let right_stats = candidate_right.materialized_stats(); - + // Assume primary key foreign key join which would have a size bounded by the foreign key relation, // which is typically larger. 
- let cur_cost = left_stats.approx_stats.upper_bound_bytes + let cur_cost = left_stats + .approx_stats + .upper_bound_bytes .max(right_stats.approx_stats.upper_bound_bytes); if let Some(existing_min) = min_cost { @@ -107,7 +144,7 @@ fn find_minimum_cost_join( selected_pair = Some(( candidate_left.clone(), candidate_right.clone(), - join_conds.clone() + join_conds.clone(), )); } } @@ -116,11 +153,11 @@ fn find_minimum_cost_join( selected_pair = Some(( candidate_left.clone(), candidate_right.clone(), - join_conds.clone() + join_conds.clone(), )); } } } (min_cost, selected_pair) -} \ No newline at end of file +} diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index f64cbb1bdf..dd2bf62674 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -4,12 +4,13 @@ use std::{ sync::Arc, }; +use common_error::DaftResult; use daft_core::join::JoinType; use daft_dsl::{col, optimization::replace_columns_with_expressions, ExprRef}; use crate::{ ops::{Filter, Join, Project}, - LogicalPlan, LogicalPlanRef, + LogicalPlan, LogicalPlanBuilder, LogicalPlanRef, }; #[derive(Debug)] @@ -76,11 +77,16 @@ pub(crate) struct JoinCondition { impl JoinCondition { pub(crate) fn flip(&self) -> Self { - JoinCondition { left_on: self.right_on.clone(), right_on: self.left_on.clone() } + JoinCondition { + left_on: self.right_on.clone(), + right_on: self.left_on.clone(), + } } } -pub(crate) struct JoinAdjList(pub HashMap>>); +pub(crate) struct JoinAdjList( + pub HashMap>>, +); impl std::fmt::Display for JoinAdjList { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -101,7 +107,10 @@ impl std::fmt::Display for JoinAdjList { impl JoinAdjList { fn add_unidirectional_edge(&mut self, left: &JoinNode, right: &JoinNode) { // TODO(desmond): We should also keep track 
of projections that we need to do. - let join_condition = JoinCondition{left_on: left.final_name.clone(), right_on: right.final_name.clone()}; + let join_condition = JoinCondition { + left_on: left.final_name.clone(), + right_on: right.final_name.clone(), + }; if let Some(neighbors) = self.0.get_mut(&left.plan) { if let Some(join_conditions) = neighbors.get_mut(&right.plan) { join_conditions.push(join_condition); @@ -121,7 +130,7 @@ impl JoinAdjList { } #[derive(Debug)] -enum ProjectionOrFilter { +pub(crate) enum ProjectionOrFilter { Projection(Vec), Filter(ExprRef), } @@ -147,6 +156,20 @@ impl JoinGraph { } } + pub(crate) fn apply_projections_and_filters_to_plan( + &mut self, + plan: LogicalPlanRef, + ) -> DaftResult { + let mut plan = LogicalPlanBuilder::from(plan); + for projection_or_filter in self.final_projections_and_filters.drain(..).rev() { + match projection_or_filter { + ProjectionOrFilter::Projection(projections) => plan = plan.select(projections)?, + ProjectionOrFilter::Filter(predicate) => plan = plan.filter(predicate)?, + } + } + Ok(plan.build()) + } + // /// Test helper function to get the number of edges that the current graph contains. 
// pub(crate) fn num_edges(&self) -> usize { // self.edges.len() @@ -416,7 +439,15 @@ mod tests { use daft_schema::{dtype::DataType, field::Field}; use super::JoinGraphBuilder; - use crate::{optimization::rules::{reorder_joins::greedy_join_order::compute_join_order, EnrichWithStats, MaterializeScans, OptimizerRule}, test::{dummy_scan_node_with_pushdowns, dummy_scan_operator, dummy_scan_operator_with_size}}; + use crate::{ + optimization::rules::{ + reorder_joins::greedy_join_order::compute_join_order, EnrichWithStats, + MaterializeScans, OptimizerRule, + }, + test::{ + dummy_scan_node_with_pushdowns, dummy_scan_operator, dummy_scan_operator_with_size, + }, + }; #[test] fn test_create_join_graph_basic_1() { @@ -447,7 +478,7 @@ mod tests { // .select(vec![col("c_prime").alias("c")]) // .unwrap(); let scan_d = dummy_scan_node_with_pushdowns( - dummy_scan_operator_with_size(vec![Field::new("d", DataType::Int64)],Some(100)), + dummy_scan_operator_with_size(vec![Field::new("d", DataType::Int64)], Some(100)), Pushdowns::default(), ); let join_plan_l = scan_a diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs index e4e1eda2a3..58987555ab 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs @@ -1,4 +1,4 @@ #[cfg(test)] -mod join_graph; +mod greedy_join_order; #[cfg(test)] -mod greedy_join_order; \ No newline at end of file +mod join_graph; diff --git a/src/daft-logical-plan/src/test/mod.rs b/src/daft-logical-plan/src/test/mod.rs index b75115a860..7ac8da51c1 100644 --- a/src/daft-logical-plan/src/test/mod.rs +++ b/src/daft-logical-plan/src/test/mod.rs @@ -12,7 +12,10 @@ pub fn dummy_scan_operator(fields: Vec) -> ScanOperatorRef { /// Create dummy scan node containing the provided fields in its schema and the provided limit, /// and with the provided size estimate. 
-pub fn dummy_scan_operator_with_size(fields: Vec, in_memory_size_per_task: Option) -> ScanOperatorRef { +pub fn dummy_scan_operator_with_size( + fields: Vec, + in_memory_size_per_task: Option, +) -> ScanOperatorRef { let schema = Arc::new(Schema::new(fields).unwrap()); ScanOperatorRef(Arc::new(DummyScanOperator { schema, From 8873626d919f0ab4e84a977ade803b283bc7ddf3 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Wed, 11 Dec 2024 16:16:58 +0800 Subject: [PATCH 03/11] Reimplement num_edges() --- .../rules/reorder_joins/join_graph.rs | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index dd2bf62674..b31aa8d876 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -170,10 +170,15 @@ impl JoinGraph { Ok(plan.build()) } - // /// Test helper function to get the number of edges that the current graph contains. - // pub(crate) fn num_edges(&self) -> usize { - // self.edges.len() - // } + /// Test helper function to get the number of edges that the current graph contains. + pub(crate) fn num_edges(&self) -> usize { + let mut num_edges = 0; + for (_, edges) in &self.adj_list.0 { + num_edges += edges.len(); + } + // Each edge is bidirectional, so we divide by 2 to get the correct number of edges. + num_edges / 2 + } /// Test helper function to check that all relations in this graph are connected. 
pub(crate) fn fully_connected(&self) -> bool { @@ -513,7 +518,7 @@ mod tests { // - a <-> b // - c_prime <-> d // - a <-> d - // assert!(join_graph.num_edges() == 3); + assert!(join_graph.num_edges() == 3); assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)")); @@ -576,7 +581,7 @@ mod tests { // - a <-> b // - c_prime <-> d // - b <-> d - // assert!(join_graph.num_edges() == 3); + assert!(join_graph.num_edges() == 3); assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); assert!(join_graph.contains_edge("b#Source(b) <-> d#Source(d)")); @@ -631,7 +636,7 @@ mod tests { // There should be edges between: // - a <-> b // - a <-> c - // assert!(join_graph.num_edges() == 2); + assert!(join_graph.num_edges() == 2); assert!(join_graph.contains_edge("a_beta#Source(a) <-> b#Source(b)")); assert!(join_graph.contains_edge("a_beta#Source(a) <-> c#Source(c)")); } @@ -693,7 +698,7 @@ mod tests { // - a <-> b // - c_prime <-> d // - a <-> d - // assert!(join_graph.num_edges() == 3); + assert!(join_graph.num_edges() == 3); assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)")); @@ -778,7 +783,7 @@ mod tests { // - a <-> b // - c_prime <-> d // - a <-> d - // assert!(join_graph.num_edges() == 3); + assert!(join_graph.num_edges() == 3); assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)")); @@ -864,7 +869,7 @@ mod tests { // - a <-> b // - c_prime <-> d // - a <-> d - // assert!(join_graph.num_edges() == 3); + assert!(join_graph.num_edges() == 3); 
assert!(join_graph.contains_edge("a#Aggregate(a) <-> b#Source(b)")); assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); assert!(join_graph.contains_edge("a#Aggregate(a) <-> d#Source(d)")); From d450ebf365057bde3990dae162dad91448980e21 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Wed, 11 Dec 2024 16:28:11 +0800 Subject: [PATCH 04/11] Reimplement fully_connected() --- .../rules/reorder_joins/join_graph.rs | 49 +++++++++---------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index b31aa8d876..7399b9478e 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -182,32 +182,31 @@ impl JoinGraph { /// Test helper function to check that all relations in this graph are connected. pub(crate) fn fully_connected(&self) -> bool { - // // Assuming that we're not testing an empty graph, there should be at least one edge in a connected graph. - // if self.edges.is_empty() { - // return false; - // } - // let mut adj_list: HashMap<*const _, Vec<*const _>> = HashMap::new(); - // for edge in &self.edges { - // let l_ptr = Arc::as_ptr(&edge.0.plan); - // let r_ptr = Arc::as_ptr(&edge.1.plan); - - // adj_list.entry(l_ptr).or_default().push(r_ptr); - // adj_list.entry(r_ptr).or_default().push(l_ptr); - // } + let start = if let Some((key, value)) = self.adj_list.0.iter().next() { + key + } else { + // There are no nodes. The empty graph is fully connected. + return true; + }; // let start_ptr = Arc::as_ptr(&self.edges[0].0.plan); - // let mut seen = HashSet::new(); - // let mut stack = vec![start_ptr]; - - // while let Some(current) = stack.pop() { - // if seen.insert(current) { - // // If this is a new node, add all its neighbors to the stack. 
- // if let Some(neighbors) = adj_list.get(&current) { - // stack.extend(neighbors.iter().filter(|&&n| !seen.contains(&n))); - // } - // } - // } - // seen.len() == adj_list.len() - true + let mut seen = HashSet::new(); + let mut stack = vec![start]; + + while let Some(current) = stack.pop() { + if seen.insert(current) { + // If this is a new node, add all its neighbors to the stack. + if let Some(neighbors) = self.adj_list.0.get(current) { + stack.extend(neighbors.iter().filter_map(|(neighbor, _)| { + if !seen.contains(neighbor) { + Some(neighbor) + } else { + None + } + })); + } + } + } + seen.len() == self.adj_list.0.len() } /// Test helper function that checks if the graph contains the given projection/filter expressions From 9f33df32cbfa2bf173f3a34664ffc88511ebf87a Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Wed, 11 Dec 2024 17:27:38 +0800 Subject: [PATCH 05/11] Sweatin' --- .../rules/reorder_joins/greedy_join_order.rs | 228 ++++++++---------- .../rules/reorder_joins/join_graph.rs | 56 +++-- 2 files changed, 147 insertions(+), 137 deletions(-) diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs index 43c5575ff8..b64fdb1504 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs @@ -1,7 +1,4 @@ -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; +use std::{collections::HashMap, sync::Arc}; use common_error::DaftResult; use daft_dsl::{col, ExprRef}; @@ -25,139 +22,124 @@ use crate::{LogicalPlanBuilder, LogicalPlanRef}; // [1]: Fegaras, L. (1998). A New Heuristic for Optimizing Large Queries. International Conference on Database and Expert Systems Applications. // [2]: Moerkotte, G., & Neumann, T. (2006). 
Analysis of two existing and one new dynamic programming algorithm for the generation of optimal bushy join trees without cross products. Very Large Data Bases Conference. // [3]: Stoian, M., & Kipf, A. (2024). DPconv: Super-Polynomially Faster Join Ordering. ArXiv, abs/2409.08013. +pub(crate) struct GreedyJoinOrderer {} -pub(crate) fn compute_join_order(join_graph: &mut JoinGraph) -> DaftResult { - // TODO(desmond): we need to handle projections. - // println!("adjlist: {}", join_graph.adj_list); - while join_graph.adj_list.0.len() > 1 { - let (min_cost, selected_pair) = find_minimum_cost_join(&join_graph.adj_list.0); - // println!("min cost: {min_cost:?}"); - if let Some((left, right, join_conds)) = selected_pair { - // println!("selected pair: {}({}) <-> {}({}) on {:?}", left.name(), left.schema(), right.name(), right.schema(), join_conds); - let (left_on, right_on) = join_conds - .iter() - .map(|join_cond| { - ( - col(join_cond.left_on.clone()), - col(join_cond.right_on.clone()), - ) - }) - .collect::<(Vec, Vec)>(); - let left_builder = LogicalPlanBuilder::from(left.clone()); - let join = left_builder - .inner_join(right.clone(), left_on, right_on)? - .build(); - let join = Arc::new(Arc::unwrap_or_clone(join).with_materialized_stats()); - let left_neighbors = join_graph.adj_list.0.remove(&left).unwrap(); - let right_neighbors = join_graph.adj_list.0.remove(&right).unwrap(); - let mut new_join_edges = HashMap::new(); +impl GreedyJoinOrderer { + pub(crate) fn compute_join_order(join_graph: &mut JoinGraph) -> DaftResult { + // TODO(desmond): we need to handle projections. 
+ while join_graph.adj_list.0.len() > 1 { + let selected_pair = GreedyJoinOrderer::find_minimum_cost_join(&join_graph.adj_list.0); + if let Some((left, right, join_conds)) = selected_pair { + let (left_on, right_on) = join_conds + .iter() + .map(|join_cond| { + ( + col(join_cond.left_on.clone()), + col(join_cond.right_on.clone()), + ) + }) + .collect::<(Vec, Vec)>(); + let left_builder = LogicalPlanBuilder::from(left.clone()); + let join = left_builder + .inner_join(right.clone(), left_on, right_on)? + .build(); + let join = Arc::new(Arc::unwrap_or_clone(join).with_materialized_stats()); + let left_neighbors = join_graph.adj_list.0.remove(&left).unwrap(); + let right_neighbors = join_graph.adj_list.0.remove(&right).unwrap(); + let mut new_join_edges = HashMap::new(); - // Helper function to collapse the left and right node - let mut update_neighbors = |neighbors: HashMap>| { - for (neighbor, _) in neighbors { - if neighbor == right || neighbor == left { - // Skip the nodes that we just joined. - continue; - } - let mut join_conditions = Vec::new(); - // If this neighbor was connected to left or right nodes, collect the join conditions. - if let Some(left_conds) = join_graph - .adj_list - .0 - .get_mut(&neighbor) - .unwrap() - .remove(&left) - { - join_conditions.extend(left_conds); - } - if let Some(right_conds) = join_graph - .adj_list - .0 - .get_mut(&neighbor) - .unwrap() - .remove(&right) - { - join_conditions.extend(right_conds); - } - // If this neighbor had any connections to left or right, create a new edge to the new join node. 
- if !join_conditions.is_empty() { - join_graph - .adj_list - .0 - .get_mut(&neighbor) - .unwrap() - .insert(join.clone(), join_conditions.clone()); - new_join_edges.insert( - neighbor.clone(), - join_conditions.iter().map(|cond| cond.flip()).collect(), - ); - } - } - }; + // Helper function to collapse the left and right node + let mut update_neighbors = + |neighbors: HashMap>| { + for (neighbor, _) in neighbors { + if neighbor == right || neighbor == left { + // Skip the nodes that we just joined. + continue; + } + let mut join_conditions = Vec::new(); + // If this neighbor was connected to left or right nodes, collect the join conditions. + let neighbor_edges = join_graph + .adj_list + .0 + .get_mut(&neighbor) + .expect("The neighbor should still be in the join graph"); + if let Some(left_conds) = neighbor_edges.remove(&left) { + join_conditions.extend(left_conds); + } + if let Some(right_conds) = neighbor_edges.remove(&right) { + join_conditions.extend(right_conds); + } + // If this neighbor had any connections to left or right, create a new edge to the new join node. + if !join_conditions.is_empty() { + neighbor_edges.insert(join.clone(), join_conditions.clone()); + new_join_edges.insert( + neighbor.clone(), + join_conditions.iter().map(|cond| cond.flip()).collect(), + ); + } + } + }; - // Process all neighbors from both the left and right sides. - update_neighbors(left_neighbors); - update_neighbors(right_neighbors); + // Process all neighbors from both the left and right sides. + update_neighbors(left_neighbors); + update_neighbors(right_neighbors); - // Add the new join node and its edges to the graph - join_graph.adj_list.0.insert(join, new_join_edges); + // Add the new join node and its edges to the graph + join_graph.adj_list.0.insert(join, new_join_edges); + } else { + panic!( + "No valid join edge selected despite join graph containing more than one relation" + ); + } + } + // Apply projections and filters on top of the fully joined plan. 
+ if let Some(joined_plan) = join_graph.adj_list.0.drain().map(|(plan, _)| plan).last() { + join_graph.apply_projections_and_filters_to_plan(joined_plan) } else { - panic!( - "No valid join edge selected despite join graph containing more than one relation" - ); + panic!("No valid logical plan after join reordering") } - // println!("adjlist: {}", join_graph.adj_list); - } - // TODO(desmond): Apply projections. - if let Some(joined_plan) = join_graph.adj_list.0.drain().map(|(plan, _)| plan).last() { - join_graph.apply_projections_and_filters_to_plan(joined_plan) - } else { - panic!("No valid logical plan after join reordering") } -} -fn find_minimum_cost_join( - adj_list: &HashMap>>, -) -> ( - Option, - Option<(LogicalPlanRef, LogicalPlanRef, Vec)>, -) { - let mut min_cost = None; - let mut selected_pair = None; + fn find_minimum_cost_join( + adj_list: &HashMap>>, + ) -> Option<(LogicalPlanRef, LogicalPlanRef, Vec)> { + let mut min_cost = None; + let mut selected_pair = None; - for (candidate_left, neighbors) in adj_list { - for (candidate_right, join_conds) in neighbors { - let left_stats = candidate_left.materialized_stats(); - let right_stats = candidate_right.materialized_stats(); + for (candidate_left, neighbors) in adj_list { + for (candidate_right, join_conds) in neighbors { + let left_stats = candidate_left.materialized_stats(); + let right_stats = candidate_right.materialized_stats(); - // Assume primary key foreign key join which would have a size bounded by the foreign key relation, - // which is typically larger. - let cur_cost = left_stats - .approx_stats - .upper_bound_bytes - .max(right_stats.approx_stats.upper_bound_bytes); + // Assume primary key foreign key join which would have a size bounded by the foreign key relation, + // which is typically larger. 
+ let cur_cost = left_stats + .approx_stats + .upper_bound_bytes + .max(right_stats.approx_stats.upper_bound_bytes); - if let Some(existing_min) = min_cost { - if let Some(current) = cur_cost { - if current < existing_min { - min_cost = Some(current); - selected_pair = Some(( - candidate_left.clone(), - candidate_right.clone(), - join_conds.clone(), - )); + if let Some(existing_min) = min_cost { + if let Some(current) = cur_cost { + if current < existing_min { + min_cost = Some(current); + selected_pair = Some(( + candidate_left.clone(), + candidate_right.clone(), + join_conds.clone(), + )); + } } + } else { + min_cost = cur_cost; + selected_pair = Some(( + candidate_left.clone(), + candidate_right.clone(), + join_conds.clone(), + )); } - } else { - min_cost = cur_cost; - selected_pair = Some(( - candidate_left.clone(), - candidate_right.clone(), - join_conds.clone(), - )); } } - } - (min_cost, selected_pair) + selected_pair + } } diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index 7399b9478e..a2034b1404 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -420,13 +420,44 @@ impl JoinGraphBuilder { // TODO(desmond): There are potentially more reorderable nodes. For example, we can move repartitions around. _ => { // This is an unreorderable node. All unresolved columns coming out of this node should be marked as resolved. + // TODO(desmond): At this point we should perform a fresh join reorder optimization starting from this + // node as the root node. We can do this once we add the optimizer rule. 
+ let mut projections = vec![]; + let mut needs_projection = false; + let mut seen_names = HashSet::new(); for (name, _, done) in self.join_conds_to_resolve.iter_mut() { - if schema.has_field(name) { + if schema.has_field(name) && !*done && !seen_names.contains(name) { + if let Some(final_name) = self.final_name_map.get(name) { + let final_name = final_name.name().to_string(); + if final_name != *name { + needs_projection = true; + projections.push(col(name.clone()).alias(final_name)); + } else { + projections.push(col(name.clone())); + } + } else { + projections.push(col(name.clone())); + } + seen_names.insert(name); + } + } + // Apply projections and return the new plan as the relation for the appropriate join conditions. + println!("projections: {projections:?}"); + let projected_plan = if needs_projection { + let projected_plan = LogicalPlanBuilder::from(plan.clone()) + .select(projections) + .expect("Computed projections could not be applied to relation") + .build(); + Arc::new(Arc::unwrap_or_clone(projected_plan).with_materialized_stats()) + } else { + plan.clone() + }; + for (name, node, done) in self.join_conds_to_resolve.iter_mut() { + if schema.has_field(name) && !*done { *done = true; + *node = projected_plan.clone(); } } - // TODO(desmond): At this point we should perform a fresh join reorder optimization starting from this - // node as the root node. We can do this once we add the optimizer rule. 
} } } @@ -445,8 +476,8 @@ mod tests { use super::JoinGraphBuilder; use crate::{ optimization::rules::{ - reorder_joins::greedy_join_order::compute_join_order, EnrichWithStats, - MaterializeScans, OptimizerRule, + reorder_joins::greedy_join_order::GreedyJoinOrderer, EnrichWithStats, MaterializeScans, + OptimizerRule, }, test::{ dummy_scan_node_with_pushdowns, dummy_scan_operator, dummy_scan_operator_with_size, @@ -472,15 +503,11 @@ mod tests { Pushdowns::default(), ); let scan_c = dummy_scan_node_with_pushdowns( - dummy_scan_operator_with_size(vec![Field::new("c", DataType::Int64)], Some(100)), + dummy_scan_operator_with_size(vec![Field::new("c_prime", DataType::Int64)], Some(100)), Pushdowns::default(), - ); - // let scan_c = dummy_scan_node_with_pushdowns( - // dummy_scan_operator_with_size(vec![Field::new("c_prime", DataType::Int64)], Some(100)), - // Pushdowns::default(), - // ) - // .select(vec![col("c_prime").alias("c")]) - // .unwrap(); + ) + .select(vec![col("c_prime").alias("c")]) + .unwrap(); let scan_d = dummy_scan_node_with_pushdowns( dummy_scan_operator_with_size(vec![Field::new("d", DataType::Int64)], Some(100)), Pushdowns::default(), @@ -521,7 +548,8 @@ mod tests { assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)")); - println!("result: {:?}", compute_join_order(&mut join_graph)); + let plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); + println!("{}", plan.repr_ascii(false)) } #[test] From 9b7886a34e7c7b0266680588d67c0e0eeebb5317 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Thu, 12 Dec 2024 04:28:19 +0800 Subject: [PATCH 06/11] Update tests --- .../rules/reorder_joins/greedy_join_order.rs | 14 +- .../rules/reorder_joins/join_graph.rs | 220 ++++++++++++------ 2 files changed, 160 insertions(+), 74 deletions(-) diff --git 
a/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs index b64fdb1504..3d003ccd91 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs @@ -25,11 +25,14 @@ use crate::{LogicalPlanBuilder, LogicalPlanRef}; pub(crate) struct GreedyJoinOrderer {} impl GreedyJoinOrderer { + /// Consumes the join graph and transforms it into a logical plan with joins reordered. pub(crate) fn compute_join_order(join_graph: &mut JoinGraph) -> DaftResult { - // TODO(desmond): we need to handle projections. + // While the join graph consists of more than one join node, select the edge that has the smallest cost, + // then join the left and right nodes connected by this edge. while join_graph.adj_list.0.len() > 1 { let selected_pair = GreedyJoinOrderer::find_minimum_cost_join(&join_graph.adj_list.0); if let Some((left, right, join_conds)) = selected_pair { + // Join the left and right relations using the given join conditions. let (left_on, right_on) = join_conds .iter() .map(|join_cond| { @@ -44,11 +47,14 @@ impl GreedyJoinOrderer { .inner_join(right.clone(), left_on, right_on)? .build(); let join = Arc::new(Arc::unwrap_or_clone(join).with_materialized_stats()); + + // Add the new node into the adjacency list. let left_neighbors = join_graph.adj_list.0.remove(&left).unwrap(); let right_neighbors = join_graph.adj_list.0.remove(&right).unwrap(); let mut new_join_edges = HashMap::new(); - // Helper function to collapse the left and right node + // Helper function that takes in neighbors to the left and right nodes, the combines edges that point + // back to the left and/or right nodes into edges that point to the new join node. 
let mut update_neighbors = |neighbors: HashMap>| { for (neighbor, _) in neighbors { @@ -84,7 +90,7 @@ impl GreedyJoinOrderer { update_neighbors(left_neighbors); update_neighbors(right_neighbors); - // Add the new join node and its edges to the graph + // Add the new join node and its edges to the graph. join_graph.adj_list.0.insert(join, new_join_edges); } else { panic!( @@ -100,6 +106,8 @@ impl GreedyJoinOrderer { } } + /// Helper functions that finds the next join edge in the adjacency list that has the smallest cost. + /// Currently cost is determined based on the max size in bytes of the candidate left and right relations. fn find_minimum_cost_join( adj_list: &HashMap>>, ) -> Option<(LogicalPlanRef, LogicalPlanRef, Vec)> { diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index a2034b1404..6c153fba99 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -51,24 +51,6 @@ impl Display for JoinNode { } } -/// JoinEdges currently represent a bidirectional edge between two relations that have -/// an equi-join condition between each other. -#[derive(Debug)] -struct JoinEdge(JoinNode, JoinNode); - -impl JoinEdge { - /// Helper function that summarizes join edge information. - fn simple_repr(&self) -> String { - format!("{} <-> {}", self.0, self.1) - } -} - -impl Display for JoinEdge { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.simple_repr()) - } -} - #[derive(Clone, Debug)] pub(crate) struct JoinCondition { pub left_on: String, @@ -161,10 +143,28 @@ impl JoinGraph { plan: LogicalPlanRef, ) -> DaftResult { let mut plan = LogicalPlanBuilder::from(plan); - for projection_or_filter in self.final_projections_and_filters.drain(..).rev() { + // Apply projections and filters in post-traversal order. 
+ let mut reversed_items = self + .final_projections_and_filters + .drain(..) + .rev() + .peekable(); + while let Some(projection_or_filter) = reversed_items.next() { + let is_last = reversed_items.peek().is_none(); + match projection_or_filter { - ProjectionOrFilter::Projection(projections) => plan = plan.select(projections)?, - ProjectionOrFilter::Filter(predicate) => plan = plan.filter(predicate)?, + ProjectionOrFilter::Projection(projections) => { + if is_last { + // The final projection is the output projection, so here we select the final projection. + plan = plan.select(projections)?; + } else { + // Intermediate projections might only transform a subset of columns, so we use `with_columns()` instead of `select()`. + plan = plan.with_columns(projections)?; + } + } + ProjectionOrFilter::Filter(predicate) => { + plan = plan.filter(predicate)?; + } } } Ok(plan.build()) @@ -182,8 +182,8 @@ impl JoinGraph { /// Test helper function to check that all relations in this graph are connected. pub(crate) fn fully_connected(&self) -> bool { - let start = if let Some((key, value)) = self.adj_list.0.iter().next() { - key + let start = if let Some((node, _)) = self.adj_list.0.iter().next() { + node } else { // There are no nodes. The empty graph is fully connected. return true; @@ -234,13 +234,26 @@ impl JoinGraph { /// Helper function that loosely checks if a given edge (represented by a simple string) /// exists in the current graph. 
- pub(crate) fn contains_edge(&self, edge_string: &str) -> bool { - // for edge in &self.edges { - // if edge.simple_repr() == edge_string { - // return true; - // } - // } - // false + pub(crate) fn contains_edges(&self, to_check: Vec<&str>) -> bool { + let mut edge_strings = HashSet::new(); + for (left, neighbors) in &self.adj_list.0 { + for (right, join_conds) in neighbors { + for join_cond in join_conds { + edge_strings.insert(format!( + "{}({}) <-> {}({})", + left.name(), + join_cond.left_on, + right.name(), + join_cond.right_on + )); + } + } + } + for cur_check in to_check { + if !edge_strings.contains(cur_check) { + return false; + } + } true } } @@ -442,7 +455,6 @@ impl JoinGraphBuilder { } } // Apply projections and return the new plan as the relation for the appropriate join conditions. - println!("projections: {projections:?}"); let projected_plan = if needs_projection { let projected_plan = LogicalPlanBuilder::from(plan.clone()) .select(projections) @@ -533,23 +545,29 @@ mod tests { vec![Arc::new(Expr::Column(Arc::from("d")))], ) .unwrap(); - let plan = join_plan.build(); + let original_plan = join_plan.build(); let scan_materializer = MaterializeScans::new(); - let plan = scan_materializer.try_optimize(plan).data().unwrap(); + let original_plan = scan_materializer + .try_optimize(original_plan) + .data() + .unwrap(); let stats_enricher = EnrichWithStats::new(); - let plan = stats_enricher.try_optimize(plan).data().unwrap(); - let mut join_graph = JoinGraphBuilder::from_logical_plan(plan).build(); + let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); + let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b - // - c_prime <-> d + // - c <-> d // - a <-> d assert!(join_graph.num_edges() == 3); - assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); - 
assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); - assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)")); - let plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); - println!("{}", plan.repr_ascii(false)) + assert!(join_graph.contains_edges(vec![ + "Source(a) <-> Source(b)", + "Project(c) <-> Source(d)", + "Source(a) <-> Source(d)" + ])); + // Test greedy join reordering. + let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); + assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -601,17 +619,29 @@ mod tests { vec![Arc::new(Expr::Column(Arc::from("d")))], ) .unwrap(); - let plan = join_plan.build(); - let join_graph = JoinGraphBuilder::from_logical_plan(plan).build(); + let original_plan = join_plan.build(); + let scan_materializer = MaterializeScans::new(); + let original_plan = scan_materializer + .try_optimize(original_plan) + .data() + .unwrap(); + let stats_enricher = EnrichWithStats::new(); + let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); + let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b - // - c_prime <-> d + // - c <-> d // - b <-> d assert!(join_graph.num_edges() == 3); - assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); - assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); - assert!(join_graph.contains_edge("b#Source(b) <-> d#Source(d)")); + assert!(join_graph.contains_edges(vec![ + "Source(a) <-> Source(b)", + "Project(c) <-> Source(d)", + "Source(b) <-> Source(d)", + ])); + // Test greedy join reordering. 
+ let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); + assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -657,15 +687,27 @@ mod tests { vec![Arc::new(Expr::Column(Arc::from("c")))], ) .unwrap(); - let plan = join_plan_2.build(); - let join_graph = JoinGraphBuilder::from_logical_plan(plan).build(); + let original_plan = join_plan_2.build(); + let scan_materializer = MaterializeScans::new(); + let original_plan = scan_materializer + .try_optimize(original_plan) + .data() + .unwrap(); + let stats_enricher = EnrichWithStats::new(); + let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); + let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: - // - a <-> b - // - a <-> c + // - a_beta <-> b + // - a_beta <-> c assert!(join_graph.num_edges() == 2); - assert!(join_graph.contains_edge("a_beta#Source(a) <-> b#Source(b)")); - assert!(join_graph.contains_edge("a_beta#Source(a) <-> c#Source(c)")); + assert!(join_graph.contains_edges(vec![ + "Project(a_beta) <-> Source(b)", + "Project(a_beta) <-> Source(c)", + ])); + // Test greedy join reordering. 
+ let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); + assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -718,21 +760,33 @@ mod tests { vec![Arc::new(Expr::Column(Arc::from("d")))], ) .unwrap(); - let plan = join_plan.build(); - let join_graph = JoinGraphBuilder::from_logical_plan(plan).build(); + let original_plan = join_plan.build(); + let scan_materializer = MaterializeScans::new(); + let original_plan = scan_materializer + .try_optimize(original_plan) + .data() + .unwrap(); + let stats_enricher = EnrichWithStats::new(); + let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); + let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b - // - c_prime <-> d + // - c <-> d // - a <-> d assert!(join_graph.num_edges() == 3); - assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); - assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); - assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)")); + assert!(join_graph.contains_edges(vec![ + "Source(a) <-> Source(b)", + "Project(c) <-> Source(d)", + "Source(a) <-> Source(d)" + ])); // Check for non-join projections at the end. // `c_prime` gets renamed to `c` in the final projection let double_proj = col("c").add(col("c")).alias("double"); assert!(join_graph.contains_projections_and_filters(vec![&double_proj])); + // Test greedy join reordering. 
+ let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); + assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -803,17 +857,26 @@ mod tests { vec![Arc::new(Expr::Column(Arc::from("d")))], ) .unwrap(); - let plan = join_plan.build(); - let join_graph = JoinGraphBuilder::from_logical_plan(plan).build(); + let original_plan = join_plan.build(); + let scan_materializer = MaterializeScans::new(); + let original_plan = scan_materializer + .try_optimize(original_plan) + .data() + .unwrap(); + let stats_enricher = EnrichWithStats::new(); + let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); + let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b - // - c_prime <-> d + // - c <-> d // - a <-> d assert!(join_graph.num_edges() == 3); - assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); - assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); - assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)")); + assert!(join_graph.contains_edges(vec![ + "Source(a) <-> Source(b)", + "Project(c) <-> Source(d)", + "Source(a) <-> Source(d)", + ])); // Check for non-join projections and filters at the end. // `c_prime` gets renamed to `c` in the final projection let double_proj = col("c").add(col("c")).alias("double"); @@ -824,6 +887,9 @@ mod tests { &double_proj, &filter_c_prime, ])); + // Test greedy join reordering. 
+ let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); + assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -889,18 +955,30 @@ mod tests { vec![Arc::new(Expr::Column(Arc::from("d")))], ) .unwrap(); - let plan = join_plan.build(); - let join_graph = JoinGraphBuilder::from_logical_plan(plan).build(); + let original_plan = join_plan.build(); + let scan_materializer = MaterializeScans::new(); + let original_plan = scan_materializer + .try_optimize(original_plan) + .data() + .unwrap(); + let stats_enricher = EnrichWithStats::new(); + let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); + let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b - // - c_prime <-> d + // - c <-> d // - a <-> d assert!(join_graph.num_edges() == 3); - assert!(join_graph.contains_edge("a#Aggregate(a) <-> b#Source(b)")); - assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); - assert!(join_graph.contains_edge("a#Aggregate(a) <-> d#Source(d)")); + assert!(join_graph.contains_edges(vec![ + "Aggregate(a) <-> Source(b)", + "Project(c) <-> Source(d)", + "Aggregate(a) <-> Source(d)" + ])); // Projections below the aggregation should not be part of the final projections. assert!(!join_graph.contains_projections_and_filters(vec![&a_proj])); + // Test greedy join reordering. 
+ let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); + assert!(reordered_plan.schema() == original_plan.schema()); } } From ba6eee24aa06fe7894b3274103625090976f9145 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Thu, 12 Dec 2024 04:36:50 +0800 Subject: [PATCH 07/11] Cleanup --- .../src/optimization/rules/reorder_joins/join_graph.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index 6c153fba99..94ea453fb3 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -74,11 +74,11 @@ impl std::fmt::Display for JoinAdjList { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, "Join Graph Adjacency List:")?; for (node, neighbors) in &self.0 { - writeln!(f, "Node '{}':", node.name())?; - for (neighbor, conditions) in neighbors { - writeln!(f, " →> '{}' with conditions:", neighbor.name())?; - for (i, cond) in conditions.iter().enumerate() { - writeln!(f, " {}: {} = {}", i + 1, cond.left_on, cond.right_on)?; + writeln!(f, "Node {}:", node.name())?; + for (neighbor, join_conds) in neighbors { + writeln!(f, " -> {} with conditions:", neighbor.name())?; + for (i, cond) in join_conds.iter().enumerate() { + writeln!(f, " {}: {} = {}", i, cond.left_on, cond.right_on)?; } } } From c77c3d1bb48b6794cc182822029e904a29473c69 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Thu, 12 Dec 2024 04:39:19 +0800 Subject: [PATCH 08/11] Cleanup --- .../src/optimization/rules/reorder_joins/greedy_join_order.rs | 2 +- .../src/optimization/rules/reorder_joins/join_graph.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs 
b/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs index 3d003ccd91..0f5a12592d 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs @@ -53,7 +53,7 @@ impl GreedyJoinOrderer { let right_neighbors = join_graph.adj_list.0.remove(&right).unwrap(); let mut new_join_edges = HashMap::new(); - // Helper function that takes in neighbors to the left and right nodes, the combines edges that point + // Helper function that takes in neighbors to the left and right nodes, then combines edges that point // back to the left and/or right nodes into edges that point to the new join node. let mut update_neighbors = |neighbors: HashMap>| { diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index 94ea453fb3..e7c09f3174 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -76,7 +76,7 @@ impl std::fmt::Display for JoinAdjList { for (node, neighbors) in &self.0 { writeln!(f, "Node {}:", node.name())?; for (neighbor, join_conds) in neighbors { - writeln!(f, " -> {} with conditions:", neighbor.name())?; + writeln!(f, " -> {} with conditions:", neighbor.name())?; for (i, cond) in join_conds.iter().enumerate() { writeln!(f, " {}: {} = {}", i, cond.left_on, cond.right_on)?; } From 731cabb06f8c1da75a46c3ad22c992d3333941e8 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Thu, 19 Dec 2024 17:13:42 +0800 Subject: [PATCH 09/11] Implement a naive join order --- Cargo.lock | 1 + src/daft-logical-plan/Cargo.toml | 1 + .../rules/reorder_joins/greedy_join_order.rs | 153 ------------- .../rules/reorder_joins/join_graph.rs | 205 ++++++++++-------- .../optimization/rules/reorder_joins/mod.rs | 4 +- 
.../rules/reorder_joins/naive_join_order.rs | 181 ++++++++++++++++ 6 files changed, 298 insertions(+), 247 deletions(-) delete mode 100644 src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs create mode 100644 src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_join_order.rs diff --git a/Cargo.lock b/Cargo.lock index 34b37bc81f..fa690309ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2337,6 +2337,7 @@ dependencies = [ "log", "pretty_assertions", "pyo3", + "rand 0.8.5", "rstest", "serde", "snafu", diff --git a/src/daft-logical-plan/Cargo.toml b/src/daft-logical-plan/Cargo.toml index 1b4dab023f..dadde8b7f8 100644 --- a/src/daft-logical-plan/Cargo.toml +++ b/src/daft-logical-plan/Cargo.toml @@ -24,6 +24,7 @@ uuid = {version = "1", features = ["v4"]} [dev-dependencies] daft-dsl = {path = "../daft-dsl", features = ["test-utils"]} pretty_assertions = {workspace = true} +rand = "0.8" rstest = {workspace = true} test-log = {workspace = true} diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs deleted file mode 100644 index 0f5a12592d..0000000000 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs +++ /dev/null @@ -1,153 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use common_error::DaftResult; -use daft_dsl::{col, ExprRef}; - -use super::join_graph::{JoinCondition, JoinGraph}; -use crate::{LogicalPlanBuilder, LogicalPlanRef}; - -// This is an implementation of the Greedy Operator Ordering algorithm (GOO) [1] for join selection. This algorithm -// selects join edges greedily by picking the edge with the smallest cost at each step. This is similar to Kruskal's -// minimum spanning tree algorithm, with the caveat that edge costs update at each step, due to changing cardinalities -// and selectivities between join nodes. 
-// -// Compared to DP-based algorithms, GOO is not always optimal. However, GOO has a complexity of O(n^3) and is more viable -// than DP-based algorithms when performing join ordering on many relations. DP Connected subgraph Complement Pairs (DPccp) [2] -// is the DP-based algorithm widely used in database systems today and has a O(3^n) complexity, although the latest -// literature does offer a super-polynomially faster DP-algorithm but that still has a O(2^n) to O(2^n * n^3) complexity [3]. -// -// For this reason, we maintain a greedy-based join ordering algorithm to use when the number of relations is large, and default -// to DP-based algorithms otherwise. -// -// [1]: Fegaras, L. (1998). A New Heuristic for Optimizing Large Queries. International Conference on Database and Expert Systems Applications. -// [2]: Moerkotte, G., & Neumann, T. (2006). Analysis of two existing and one new dynamic programming algorithm for the generation of optimal bushy join trees without cross products. Very Large Data Bases Conference. -// [3]: Stoian, M., & Kipf, A. (2024). DPconv: Super-Polynomially Faster Join Ordering. ArXiv, abs/2409.08013. -pub(crate) struct GreedyJoinOrderer {} - -impl GreedyJoinOrderer { - /// Consumes the join graph and transforms it into a logical plan with joins reordered. - pub(crate) fn compute_join_order(join_graph: &mut JoinGraph) -> DaftResult { - // While the join graph consists of more than one join node, select the edge that has the smallest cost, - // then join the left and right nodes connected by this edge. - while join_graph.adj_list.0.len() > 1 { - let selected_pair = GreedyJoinOrderer::find_minimum_cost_join(&join_graph.adj_list.0); - if let Some((left, right, join_conds)) = selected_pair { - // Join the left and right relations using the given join conditions. 
- let (left_on, right_on) = join_conds - .iter() - .map(|join_cond| { - ( - col(join_cond.left_on.clone()), - col(join_cond.right_on.clone()), - ) - }) - .collect::<(Vec, Vec)>(); - let left_builder = LogicalPlanBuilder::from(left.clone()); - let join = left_builder - .inner_join(right.clone(), left_on, right_on)? - .build(); - let join = Arc::new(Arc::unwrap_or_clone(join).with_materialized_stats()); - - // Add the new node into the adjacency list. - let left_neighbors = join_graph.adj_list.0.remove(&left).unwrap(); - let right_neighbors = join_graph.adj_list.0.remove(&right).unwrap(); - let mut new_join_edges = HashMap::new(); - - // Helper function that takes in neighbors to the left and right nodes, then combines edges that point - // back to the left and/or right nodes into edges that point to the new join node. - let mut update_neighbors = - |neighbors: HashMap>| { - for (neighbor, _) in neighbors { - if neighbor == right || neighbor == left { - // Skip the nodes that we just joined. - continue; - } - let mut join_conditions = Vec::new(); - // If this neighbor was connected to left or right nodes, collect the join conditions. - let neighbor_edges = join_graph - .adj_list - .0 - .get_mut(&neighbor) - .expect("The neighbor should still be in the join graph"); - if let Some(left_conds) = neighbor_edges.remove(&left) { - join_conditions.extend(left_conds); - } - if let Some(right_conds) = neighbor_edges.remove(&right) { - join_conditions.extend(right_conds); - } - // If this neighbor had any connections to left or right, create a new edge to the new join node. - if !join_conditions.is_empty() { - neighbor_edges.insert(join.clone(), join_conditions.clone()); - new_join_edges.insert( - neighbor.clone(), - join_conditions.iter().map(|cond| cond.flip()).collect(), - ); - } - } - }; - - // Process all neighbors from both the left and right sides. 
- update_neighbors(left_neighbors); - update_neighbors(right_neighbors); - - // Add the new join node and its edges to the graph. - join_graph.adj_list.0.insert(join, new_join_edges); - } else { - panic!( - "No valid join edge selected despite join graph containing more than one relation" - ); - } - } - // Apply projections and filters on top of the fully joined plan. - if let Some(joined_plan) = join_graph.adj_list.0.drain().map(|(plan, _)| plan).last() { - join_graph.apply_projections_and_filters_to_plan(joined_plan) - } else { - panic!("No valid logical plan after join reordering") - } - } - - /// Helper functions that finds the next join edge in the adjacency list that has the smallest cost. - /// Currently cost is determined based on the max size in bytes of the candidate left and right relations. - fn find_minimum_cost_join( - adj_list: &HashMap>>, - ) -> Option<(LogicalPlanRef, LogicalPlanRef, Vec)> { - let mut min_cost = None; - let mut selected_pair = None; - - for (candidate_left, neighbors) in adj_list { - for (candidate_right, join_conds) in neighbors { - let left_stats = candidate_left.materialized_stats(); - let right_stats = candidate_right.materialized_stats(); - - // Assume primary key foreign key join which would have a size bounded by the foreign key relation, - // which is typically larger. 
- let cur_cost = left_stats - .approx_stats - .upper_bound_bytes - .max(right_stats.approx_stats.upper_bound_bytes); - - if let Some(existing_min) = min_cost { - if let Some(current) = cur_cost { - if current < existing_min { - min_cost = Some(current); - selected_pair = Some(( - candidate_left.clone(), - candidate_right.clone(), - join_conds.clone(), - )); - } - } - } else { - min_cost = cur_cost; - selected_pair = Some(( - candidate_left.clone(), - candidate_right.clone(), - join_conds.clone(), - )); - } - } - } - - selected_pair - } -} diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index e7c09f3174..2aa5234ff0 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -13,8 +13,42 @@ use crate::{ LogicalPlan, LogicalPlanBuilder, LogicalPlanRef, }; -#[derive(Debug)] -struct JoinNode { +// TODO(desmond): In the future these trees should keep track of current cost estimates. +#[derive(Clone, Debug)] +pub(super) enum JoinOrderTree { + Relation(usize), // (id). + Join(Box, Box, Vec), // (subtree, subtree, nodes involved). +} + +impl JoinOrderTree { + pub(super) fn join(self: Box, right: Box) -> Box { + let mut nodes = self.nodes(); + nodes.append(&mut right.nodes()); + Box::new(JoinOrderTree::Join(self, right, nodes)) + } + + pub(super) fn nodes(&self) -> Vec { + match self { + Self::Relation(id) => vec![*id], + Self::Join(_, _, nodes) => nodes.clone(), + } + } + + // Helper function that checks if the join order tree contains a given id. 
+ pub(super) fn contains(&self, target_id: usize) -> bool { + match self { + Self::Relation(id) => *id == target_id, + Self::Join(left, right, _) => left.contains(target_id) || right.contains(target_id), + } + } +} + +pub(super) trait JoinOrderer { + fn order(&self, graph: &JoinGraph) -> Box; +} + +#[derive(Clone, Debug)] +pub(super) struct JoinNode { relation_name: String, plan: LogicalPlanRef, final_name: String, @@ -27,7 +61,7 @@ struct JoinNode { /// JoinNodes represent a relation (i.e. a non-reorderable logical plan node), the column /// that's being accessed from the relation, and the final name of the column in the output. impl JoinNode { - fn new(relation_name: String, plan: LogicalPlanRef, final_name: String) -> Self { + pub(super) fn new(relation_name: String, plan: LogicalPlanRef, final_name: String) -> Self { Self { relation_name, plan, @@ -52,31 +86,31 @@ impl Display for JoinNode { } #[derive(Clone, Debug)] -pub(crate) struct JoinCondition { +pub(super) struct JoinCondition { pub left_on: String, pub right_on: String, } -impl JoinCondition { - pub(crate) fn flip(&self) -> Self { - JoinCondition { - left_on: self.right_on.clone(), - right_on: self.left_on.clone(), - } - } +pub(super) struct JoinAdjList { + pub max_id: usize, + plan_to_id: HashMap<*const LogicalPlan, usize>, + id_to_plan: HashMap, + pub edges: HashMap>>, } -pub(crate) struct JoinAdjList( - pub HashMap>>, -); - impl std::fmt::Display for JoinAdjList { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, "Join Graph Adjacency List:")?; - for (node, neighbors) in &self.0 { - writeln!(f, "Node {}:", node.name())?; - for (neighbor, join_conds) in neighbors { - writeln!(f, " -> {} with conditions:", neighbor.name())?; + for (node_id, neighbors) in &self.edges { + let node = self.id_to_plan.get(node_id).unwrap(); + writeln!(f, "Node {} (id = {node_id}):", node.name())?; + for (neighbor_id, join_conds) in neighbors { + let neighbor = 
self.id_to_plan.get(neighbor_id).unwrap(); + writeln!( + f, + " -> {} (id = {neighbor_id}) with conditions:", + neighbor.name() + )?; for (i, cond) in join_conds.iter().enumerate() { writeln!(f, " {}: {} = {}", i, cond.left_on, cond.right_on)?; } @@ -87,28 +121,66 @@ impl std::fmt::Display for JoinAdjList { } impl JoinAdjList { + pub(super) fn empty() -> Self { + Self { + max_id: 0, + plan_to_id: HashMap::new(), + id_to_plan: HashMap::new(), + edges: HashMap::new(), + } + } + + pub(super) fn get_plan_id(&mut self, plan: &LogicalPlanRef) -> usize { + let ptr = Arc::as_ptr(plan); + if let Some(id) = self.plan_to_id.get(&ptr) { + *id + } else { + let id = self.max_id; + self.max_id += 1; + self.plan_to_id.insert(ptr, id); + self.id_to_plan.insert(id, plan.clone()); + id + } + } + fn add_unidirectional_edge(&mut self, left: &JoinNode, right: &JoinNode) { // TODO(desmond): We should also keep track of projections that we need to do. let join_condition = JoinCondition { left_on: left.final_name.clone(), right_on: right.final_name.clone(), }; - if let Some(neighbors) = self.0.get_mut(&left.plan) { - if let Some(join_conditions) = neighbors.get_mut(&right.plan) { + let left_id = self.get_plan_id(&left.plan); + let right_id = self.get_plan_id(&right.plan); + if let Some(neighbors) = self.edges.get_mut(&left_id) { + if let Some(join_conditions) = neighbors.get_mut(&right_id) { join_conditions.push(join_condition); } else { - neighbors.insert(right.plan.clone(), vec![join_condition]); + neighbors.insert(right_id, vec![join_condition]); } } else { let mut neighbors = HashMap::new(); - neighbors.insert(right.plan.clone(), vec![join_condition]); - self.0.insert(left.plan.clone(), neighbors); + neighbors.insert(right_id, vec![join_condition]); + self.edges.insert(left_id, neighbors); } } - fn add_bidirectional_edge(&mut self, node1: JoinNode, node2: JoinNode) { + + pub(super) fn add_bidirectional_edge(&mut self, node1: JoinNode, node2: JoinNode) { 
self.add_unidirectional_edge(&node1, &node2); self.add_unidirectional_edge(&node2, &node1); } + + pub(super) fn connected(&self, left_nodes: &Vec, right_nodes: &Vec) -> bool { + for left_node in left_nodes { + if let Some(neighbors) = self.edges.get(left_node) { + for right_node in right_nodes { + if let Some(_) = neighbors.get(right_node) { + return true; + } + } + } + } + return false; + } } #[derive(Debug)] @@ -138,42 +210,10 @@ impl JoinGraph { } } - pub(crate) fn apply_projections_and_filters_to_plan( - &mut self, - plan: LogicalPlanRef, - ) -> DaftResult { - let mut plan = LogicalPlanBuilder::from(plan); - // Apply projections and filters in post-traversal order. - let mut reversed_items = self - .final_projections_and_filters - .drain(..) - .rev() - .peekable(); - while let Some(projection_or_filter) = reversed_items.next() { - let is_last = reversed_items.peek().is_none(); - - match projection_or_filter { - ProjectionOrFilter::Projection(projections) => { - if is_last { - // The final projection is the output projection, so here we select the final projection. - plan = plan.select(projections)?; - } else { - // Intermediate projections might only transform a subset of columns, so we use `with_columns()` instead of `select()`. - plan = plan.with_columns(projections)?; - } - } - ProjectionOrFilter::Filter(predicate) => { - plan = plan.filter(predicate)?; - } - } - } - Ok(plan.build()) - } - /// Test helper function to get the number of edges that the current graph contains. pub(crate) fn num_edges(&self) -> usize { let mut num_edges = 0; - for (_, edges) in &self.adj_list.0 { + for (_, edges) in &self.adj_list.edges { num_edges += edges.len(); } // Each edge is bidirectional, so we divide by 2 to get the correct number of edges. @@ -182,7 +222,7 @@ impl JoinGraph { /// Test helper function to check that all relations in this graph are connected. 
pub(crate) fn fully_connected(&self) -> bool { - let start = if let Some((node, _)) = self.adj_list.0.iter().next() { + let start = if let Some((node, _)) = self.adj_list.edges.iter().next() { node } else { // There are no nodes. The empty graph is fully connected. @@ -195,7 +235,7 @@ impl JoinGraph { while let Some(current) = stack.pop() { if seen.insert(current) { // If this is a new node, add all its neighbors to the stack. - if let Some(neighbors) = self.adj_list.0.get(current) { + if let Some(neighbors) = self.adj_list.edges.get(current) { stack.extend(neighbors.iter().filter_map(|(neighbor, _)| { if !seen.contains(neighbor) { Some(neighbor) @@ -206,7 +246,7 @@ impl JoinGraph { } } } - seen.len() == self.adj_list.0.len() + seen.len() == self.adj_list.max_id } /// Test helper function that checks if the graph contains the given projection/filter expressions @@ -236,8 +276,10 @@ impl JoinGraph { /// exists in the current graph. pub(crate) fn contains_edges(&self, to_check: Vec<&str>) -> bool { let mut edge_strings = HashSet::new(); - for (left, neighbors) in &self.adj_list.0 { - for (right, join_conds) in neighbors { + for (left_id, neighbors) in &self.adj_list.edges { + for (right_id, join_conds) in neighbors { + let left = self.adj_list.id_to_plan.get(left_id).unwrap(); + let right = self.adj_list.id_to_plan.get(right_id).unwrap(); for join_cond in join_conds { edge_strings.insert(format!( "{}({}) <-> {}({})", @@ -288,7 +330,7 @@ impl JoinGraphBuilder { plan, join_conds_to_resolve: vec![], final_name_map: HashMap::new(), - adj_list: JoinAdjList(HashMap::new()), + adj_list: JoinAdjList::empty(), final_projections_and_filters: vec![ProjectionOrFilter::Projection(output_projection)], } } @@ -487,10 +529,7 @@ mod tests { use super::JoinGraphBuilder; use crate::{ - optimization::rules::{ - reorder_joins::greedy_join_order::GreedyJoinOrderer, EnrichWithStats, MaterializeScans, - OptimizerRule, - }, + optimization::rules::{EnrichWithStats, MaterializeScans, 
OptimizerRule}, test::{ dummy_scan_node_with_pushdowns, dummy_scan_operator, dummy_scan_operator_with_size, }, @@ -553,7 +592,7 @@ mod tests { .unwrap(); let stats_enricher = EnrichWithStats::new(); let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); - let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b @@ -565,9 +604,6 @@ mod tests { "Project(c) <-> Source(d)", "Source(a) <-> Source(d)" ])); - // Test greedy join reordering. - let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); - assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -627,7 +663,7 @@ mod tests { .unwrap(); let stats_enricher = EnrichWithStats::new(); let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); - let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b @@ -639,9 +675,6 @@ mod tests { "Project(c) <-> Source(d)", "Source(b) <-> Source(d)", ])); - // Test greedy join reordering. 
- let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); - assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -695,7 +728,7 @@ mod tests { .unwrap(); let stats_enricher = EnrichWithStats::new(); let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); - let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a_beta <-> b @@ -705,9 +738,6 @@ mod tests { "Project(a_beta) <-> Source(b)", "Project(a_beta) <-> Source(c)", ])); - // Test greedy join reordering. - let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); - assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -768,7 +798,7 @@ mod tests { .unwrap(); let stats_enricher = EnrichWithStats::new(); let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); - let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b @@ -784,9 +814,6 @@ mod tests { // `c_prime` gets renamed to `c` in the final projection let double_proj = col("c").add(col("c")).alias("double"); assert!(join_graph.contains_projections_and_filters(vec![&double_proj])); - // Test greedy join reordering. 
- let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); - assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -865,7 +892,7 @@ mod tests { .unwrap(); let stats_enricher = EnrichWithStats::new(); let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); - let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b @@ -887,9 +914,6 @@ mod tests { &double_proj, &filter_c_prime, ])); - // Test greedy join reordering. - let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); - assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -963,7 +987,7 @@ mod tests { .unwrap(); let stats_enricher = EnrichWithStats::new(); let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); - let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b @@ -977,8 +1001,5 @@ mod tests { ])); // Projections below the aggregation should not be part of the final projections. assert!(!join_graph.contains_projections_and_filters(vec![&a_proj])); - // Test greedy join reordering. 
- let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); - assert!(reordered_plan.schema() == original_plan.schema()); } } diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs index 58987555ab..c8644b620e 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs @@ -1,4 +1,4 @@ #[cfg(test)] -mod greedy_join_order; -#[cfg(test)] mod join_graph; +#[cfg(test)] +mod naive_join_order; diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_join_order.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_join_order.rs new file mode 100644 index 0000000000..0fba6e4fa3 --- /dev/null +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_join_order.rs @@ -0,0 +1,181 @@ +use super::join_graph::{JoinGraph, JoinOrderTree, JoinOrderer}; + +pub(crate) struct NaiveJoinOrderer {} + +impl NaiveJoinOrderer { + fn extend_order( + graph: &JoinGraph, + current_order: Box, + mut available: Vec, + ) -> Box { + if available.is_empty() { + return current_order; + } + for (index, candidate_node_id) in available.iter().enumerate() { + let right = Box::new(JoinOrderTree::Relation(*candidate_node_id)); + if graph + .adj_list + .connected(&current_order.nodes(), &right.nodes()) + { + let new_order = current_order.join(right); + available.remove(index); + return Self::extend_order(graph, new_order, available); + } + } + panic!("There should be at least one naive join order."); + } +} + +impl JoinOrderer for NaiveJoinOrderer { + fn order(&self, graph: &JoinGraph) -> Box { + let available: Vec = (1..graph.adj_list.max_id).collect(); + // Take a starting order of the node with id 0. 
+ let starting_order = Box::new(JoinOrderTree::Relation(0)); + Self::extend_order(graph, starting_order, available) + } +} + +#[cfg(test)] +mod tests { + use common_scan_info::Pushdowns; + use daft_schema::{dtype::DataType, field::Field}; + use rand::{seq::SliceRandom, Rng}; + + use super::{JoinGraph, JoinOrderTree, JoinOrderer, NaiveJoinOrderer}; + use crate::{ + optimization::rules::reorder_joins::join_graph::{JoinAdjList, JoinNode}, + test::{dummy_scan_node_with_pushdowns, dummy_scan_operator_with_size}, + LogicalPlanRef, + }; + + fn assert_order_contains_all_nodes(order: &Box, graph: &JoinGraph) { + for id in 0..graph.adj_list.max_id { + assert!( + order.contains(id), + "Graph id {} not found in order {:?}.\n{}", + id, + order, + graph.adj_list + ); + } + } + + fn create_scan_node(name: &str, size: Option) -> LogicalPlanRef { + dummy_scan_node_with_pushdowns( + dummy_scan_operator_with_size(vec![Field::new(name, DataType::Int64)], size), + Pushdowns::default(), + ) + .build() + } + + fn create_join_graph_with_edges(nodes: Vec, edges: Vec<(usize, usize)>) -> JoinGraph { + let mut adj_list = JoinAdjList::empty(); + for (from, to) in edges { + adj_list.add_bidirectional_edge(nodes[from].clone(), nodes[to].clone()); + } + JoinGraph::new(adj_list, vec![]) + } + + macro_rules! 
create_and_test_join_graph { + ($nodes:expr, $edges:expr, $orderer:expr) => { + let nodes: Vec = $nodes + .iter() + .map(|name| { + let scan_node = create_scan_node(name, Some(100)); + JoinNode::new(name.to_string(), scan_node, name.to_string()) + }) + .collect(); + let graph = create_join_graph_with_edges(nodes.clone(), $edges); + let order = $orderer.order(&graph); + assert_order_contains_all_nodes(&order, &graph); + }; + } + + #[test] + fn test_order_basic_join_graph() { + let nodes = vec!["a", "b", "c", "d"]; + let edges = vec![ + (0, 2), // node_a <-> node_c + (1, 2), // node_b <-> node_c + (2, 3), // node_c <-> node_d + ]; + create_and_test_join_graph!(nodes, edges, NaiveJoinOrderer {}); + } + + pub struct UnionFind { + parent: Vec, + size: Vec, + } + + impl UnionFind { + pub fn create(num_nodes: usize) -> Self { + UnionFind { + parent: (0..num_nodes).collect(), + size: vec![1; num_nodes], + } + } + + pub fn find(&mut self, node: usize) -> usize { + if self.parent[node] != node { + self.parent[node] = self.find(self.parent[node]); + } + self.parent[node] + } + + pub fn union(&mut self, node1: usize, node2: usize) { + let root1 = self.find(node1); + let root2 = self.find(node2); + + if root1 != root2 { + let (small, big) = if self.size[root1] < self.size[root2] { + (root1, root2) + } else { + (root2, root1) + }; + self.parent[small] = big; + self.size[big] += self.size[small]; + } + } + } + + fn create_random_connected_graph(num_nodes: usize) -> Vec<(usize, usize)> { + let mut rng = rand::thread_rng(); + let mut edges = Vec::new(); + let mut uf = UnionFind::create(num_nodes); + + // Get a random order of all possible edges. + let mut all_edges: Vec<(usize, usize)> = (0..num_nodes) + .flat_map(|i| (0..i).chain(i + 1..num_nodes).map(move |j| (i, j))) + .collect(); + all_edges.shuffle(&mut rng); + + // Select edges to form a minimum spanning tree + a random number of extra edges. 
+ for (a, b) in all_edges { + if uf.find(a) != uf.find(b) { + uf.union(a, b); + edges.push((a, b)); + } + // Check if we have a minimum spanning tree. + if edges.len() >= num_nodes - 1 { + // Once we have a minimum spanning tree, we let a random number of extra edges be added to the graph. + if rng.gen_bool(0.3) { + break; + } + edges.push((a, b)); + } + } + + edges + } + + const NUM_RANDOM_NODES: usize = 100; + + #[test] + fn test_order_random_join_graph() { + let nodes: Vec = (0..NUM_RANDOM_NODES) + .map(|i| format!("node_{}", i)) + .collect(); + let edges = create_random_connected_graph(NUM_RANDOM_NODES); + create_and_test_join_graph!(nodes, edges, NaiveJoinOrderer {}); + } +} From b2d9c2d768f1d1a05fe684edf81be9b90bdd7216 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Thu, 19 Dec 2024 17:27:43 +0800 Subject: [PATCH 10/11] Smoll cleanup --- .../src/optimization/rules/reorder_joins/join_graph.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index 2aa5234ff0..d9c65e6e3b 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -228,7 +228,6 @@ impl JoinGraph { // There are no nodes. The empty graph is fully connected. 
return true; }; - // let start_ptr = Arc::as_ptr(&self.edges[0].0.plan); let mut seen = HashSet::new(); let mut stack = vec![start]; From 4e6d79ff1e756ea595a91a6fc27186c179cac99d Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Fri, 20 Dec 2024 16:18:49 +0800 Subject: [PATCH 11/11] Address comments --- .../rules/reorder_joins/join_graph.rs | 96 +++++++++++++------ .../optimization/rules/reorder_joins/mod.rs | 2 +- ...order.rs => naive_left_deep_join_order.rs} | 21 ++-- 3 files changed, 76 insertions(+), 43 deletions(-) rename src/daft-logical-plan/src/optimization/rules/reorder_joins/{naive_join_order.rs => naive_left_deep_join_order.rs} (90%) diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index d9c65e6e3b..e7cd907657 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -13,33 +13,54 @@ use crate::{ LogicalPlan, LogicalPlanBuilder, LogicalPlanRef, }; -// TODO(desmond): In the future these trees should keep track of current cost estimates. +/// A JoinOrderTree is a tree that describes a join order between relations, which can range from left deep trees +/// to bushy trees. The relations in a JoinOrderTree contain IDs instead of logical plan references. An ID's +/// corresponding logical plan reference can be found by consulting the JoinAdjList that was used to produce the +/// given JoinOrderTree. +/// +/// TODO(desmond): In the future these trees should keep track of current cost estimates. #[derive(Clone, Debug)] pub(super) enum JoinOrderTree { - Relation(usize), // (id). - Join(Box, Box, Vec), // (subtree, subtree, nodes involved). + Relation(usize), // (ID). + Join(Box, Box), // (subtree, subtree). 
} impl JoinOrderTree { pub(super) fn join(self: Box, right: Box) -> Box { - let mut nodes = self.nodes(); - nodes.append(&mut right.nodes()); - Box::new(JoinOrderTree::Join(self, right, nodes)) - } - - pub(super) fn nodes(&self) -> Vec { - match self { - Self::Relation(id) => vec![*id], - Self::Join(_, _, nodes) => nodes.clone(), - } + Box::new(JoinOrderTree::Join(self, right)) } // Helper function that checks if the join order tree contains a given id. pub(super) fn contains(&self, target_id: usize) -> bool { match self { Self::Relation(id) => *id == target_id, - Self::Join(left, right, _) => left.contains(target_id) || right.contains(target_id), + Self::Join(left, right) => left.contains(target_id) || right.contains(target_id), + } + } + + pub(super) fn iter(&self) -> JoinOrderTreeIterator { + JoinOrderTreeIterator { stack: vec![self] } + } +} + +pub(super) struct JoinOrderTreeIterator<'a> { + stack: Vec<&'a JoinOrderTree>, +} + +impl<'a> Iterator for JoinOrderTreeIterator<'a> { + type Item = usize; + + fn next(&mut self) -> Option { + while let Some(node) = self.stack.pop() { + match node { + JoinOrderTree::Relation(id) => return Some(*id), + JoinOrderTree::Join(left, right) => { + self.stack.push(left); + self.stack.push(right); + } + } } + None } } @@ -130,7 +151,7 @@ impl JoinAdjList { } } - pub(super) fn get_plan_id(&mut self, plan: &LogicalPlanRef) -> usize { + pub(super) fn get_or_create_plan_id(&mut self, plan: &LogicalPlanRef) -> usize { let ptr = Arc::as_ptr(plan); if let Some(id) = self.plan_to_id.get(&ptr) { *id @@ -143,14 +164,12 @@ impl JoinAdjList { } } - fn add_unidirectional_edge(&mut self, left: &JoinNode, right: &JoinNode) { - // TODO(desmond): We should also keep track of projections that we need to do. 
- let join_condition = JoinCondition { - left_on: left.final_name.clone(), - right_on: right.final_name.clone(), - }; - let left_id = self.get_plan_id(&left.plan); - let right_id = self.get_plan_id(&right.plan); + fn add_join_condition( + &mut self, + left_id: usize, + right_id: usize, + join_condition: JoinCondition, + ) { if let Some(neighbors) = self.edges.get_mut(&left_id) { if let Some(join_conditions) = neighbors.get_mut(&right_id) { join_conditions.push(join_condition); @@ -164,16 +183,26 @@ impl JoinAdjList { } } + fn add_unidirectional_edge(&mut self, left: &JoinNode, right: &JoinNode) { + let join_condition = JoinCondition { + left_on: left.final_name.clone(), + right_on: right.final_name.clone(), + }; + let left_id = self.get_or_create_plan_id(&left.plan); + let right_id = self.get_or_create_plan_id(&right.plan); + self.add_join_condition(left_id, right_id, join_condition); + } + pub(super) fn add_bidirectional_edge(&mut self, node1: JoinNode, node2: JoinNode) { self.add_unidirectional_edge(&node1, &node2); self.add_unidirectional_edge(&node2, &node1); } - pub(super) fn connected(&self, left_nodes: &Vec, right_nodes: &Vec) -> bool { - for left_node in left_nodes { - if let Some(neighbors) = self.edges.get(left_node) { - for right_node in right_nodes { - if let Some(_) = neighbors.get(right_node) { + pub(super) fn connected_join_trees(&self, left: &JoinOrderTree, right: &JoinOrderTree) -> bool { + for left_node in left.iter() { + if let Some(neighbors) = self.edges.get(&left_node) { + for right_node in right.iter() { + if let Some(_) = neighbors.get(&right_node) { return true; } } @@ -271,14 +300,21 @@ impl JoinGraph { false } + fn get_node_by_id(&self, id: usize) -> &LogicalPlanRef { + self.adj_list + .id_to_plan + .get(&id) + .expect("Tried to retrieve a plan from the join graph with an invalid ID") + } + /// Helper function that loosely checks if a given edge (represented by a simple string) /// exists in the current graph. 
pub(crate) fn contains_edges(&self, to_check: Vec<&str>) -> bool { let mut edge_strings = HashSet::new(); for (left_id, neighbors) in &self.adj_list.edges { for (right_id, join_conds) in neighbors { - let left = self.adj_list.id_to_plan.get(left_id).unwrap(); - let right = self.adj_list.id_to_plan.get(right_id).unwrap(); + let left = self.get_node_by_id(*left_id); + let right = self.get_node_by_id(*right_id); for join_cond in join_conds { edge_strings.insert(format!( "{}({}) <-> {}({})", diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs index c8644b620e..762d58a4a8 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs @@ -1,4 +1,4 @@ #[cfg(test)] mod join_graph; #[cfg(test)] -mod naive_join_order; +mod naive_left_deep_join_order; diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_join_order.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_left_deep_join_order.rs similarity index 90% rename from src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_join_order.rs rename to src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_left_deep_join_order.rs index 0fba6e4fa3..c0b3ab6634 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_join_order.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_left_deep_join_order.rs @@ -1,8 +1,8 @@ use super::join_graph::{JoinGraph, JoinOrderTree, JoinOrderer}; -pub(crate) struct NaiveJoinOrderer {} +pub(crate) struct NaiveLeftDeepJoinOrderer {} -impl NaiveJoinOrderer { +impl NaiveLeftDeepJoinOrderer { fn extend_order( graph: &JoinGraph, current_order: Box, @@ -13,10 +13,7 @@ impl NaiveJoinOrderer { } for (index, candidate_node_id) in available.iter().enumerate() { let right = 
Box::new(JoinOrderTree::Relation(*candidate_node_id)); - if graph - .adj_list - .connected(&current_order.nodes(), &right.nodes()) - { + if graph.adj_list.connected_join_trees(&current_order, &right) { let new_order = current_order.join(right); available.remove(index); return Self::extend_order(graph, new_order, available); @@ -26,7 +23,7 @@ impl NaiveJoinOrderer { } } -impl JoinOrderer for NaiveJoinOrderer { +impl JoinOrderer for NaiveLeftDeepJoinOrderer { fn order(&self, graph: &JoinGraph) -> Box { let available: Vec = (1..graph.adj_list.max_id).collect(); // Take a starting order of the node with id 0. @@ -39,9 +36,9 @@ impl JoinOrderer for NaiveJoinOrderer { mod tests { use common_scan_info::Pushdowns; use daft_schema::{dtype::DataType, field::Field}; - use rand::{seq::SliceRandom, Rng}; + use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; - use super::{JoinGraph, JoinOrderTree, JoinOrderer, NaiveJoinOrderer}; + use super::{JoinGraph, JoinOrderTree, JoinOrderer, NaiveLeftDeepJoinOrderer}; use crate::{ optimization::rules::reorder_joins::join_graph::{JoinAdjList, JoinNode}, test::{dummy_scan_node_with_pushdowns, dummy_scan_operator_with_size}, @@ -99,7 +96,7 @@ mod tests { (1, 2), // node_b <-> node_c (2, 3), // node_c <-> node_d ]; - create_and_test_join_graph!(nodes, edges, NaiveJoinOrderer {}); + create_and_test_join_graph!(nodes, edges, NaiveLeftDeepJoinOrderer {}); } pub struct UnionFind { @@ -139,7 +136,7 @@ mod tests { } fn create_random_connected_graph(num_nodes: usize) -> Vec<(usize, usize)> { - let mut rng = rand::thread_rng(); + let mut rng = StdRng::seed_from_u64(0); let mut edges = Vec::new(); let mut uf = UnionFind::create(num_nodes); @@ -176,6 +173,6 @@ mod tests { .map(|i| format!("node_{}", i)) .collect(); let edges = create_random_connected_graph(NUM_RANDOM_NODES); - create_and_test_join_graph!(nodes, edges, NaiveJoinOrderer {}); + create_and_test_join_graph!(nodes, edges, NaiveLeftDeepJoinOrderer {}); } }