From 5d4db4faecbfae2cea8dd0831c3ba19e1c59a08c Mon Sep 17 00:00:00 2001 From: Desmond Cheong Date: Fri, 20 Dec 2024 01:20:52 -0800 Subject: [PATCH] feat(optimizer): Implement naive join ordering (#3616) Implements a naive join orderer that simply joins relations arbitrarily (as long as a valid join condition exists). This is intended as a building block to ensure that our join graphs can be correctly reconstructed into logical plans. The PR that will immediately follow this will create an optimization rule that applies naive join ordering. The optimization rule will be hidden behind a config flag, but will allow us to test logical plan reconstruction on all our integration tests. --- Cargo.lock | 1 + src/common/scan-info/src/test/mod.rs | 5 +- src/daft-logical-plan/Cargo.toml | 1 + .../rules/reorder_joins/join_graph.rs | 449 ++++++++++++++---- .../optimization/rules/reorder_joins/mod.rs | 2 + .../naive_left_deep_join_order.rs | 178 +++++++ src/daft-logical-plan/src/test/mod.rs | 12 +- src/daft-physical-plan/src/test/mod.rs | 1 + 8 files changed, 554 insertions(+), 95 deletions(-) create mode 100644 src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_left_deep_join_order.rs diff --git a/Cargo.lock b/Cargo.lock index 3b93f43998..e648250c9f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2351,6 +2351,7 @@ dependencies = [ "log", "pretty_assertions", "pyo3", + "rand 0.8.5", "rstest", "serde", "snafu", diff --git a/src/common/scan-info/src/test/mod.rs b/src/common/scan-info/src/test/mod.rs index 0da27600e9..a9d248c5d6 100644 --- a/src/common/scan-info/src/test/mod.rs +++ b/src/common/scan-info/src/test/mod.rs @@ -17,12 +17,14 @@ use crate::{PartitionField, Pushdowns, ScanOperator, ScanTaskLike, ScanTaskLikeR struct DummyScanTask { pub schema: SchemaRef, pub pushdowns: Pushdowns, + pub in_memory_size: Option, } #[derive(Debug)] pub struct DummyScanOperator { pub schema: SchemaRef, pub num_scan_tasks: u32, + pub in_memory_size_per_task: Option, } 
#[typetag::serde] @@ -67,7 +69,7 @@ impl ScanTaskLike for DummyScanTask { } fn estimate_in_memory_size_bytes(&self, _: Option<&DaftExecutionConfig>) -> Option { - None + self.in_memory_size } fn file_format_config(&self) -> Arc { @@ -136,6 +138,7 @@ impl ScanOperator for DummyScanOperator { let scan_task = Arc::new(DummyScanTask { schema: self.schema.clone(), pushdowns, + in_memory_size: self.in_memory_size_per_task, }); Ok((0..self.num_scan_tasks) diff --git a/src/daft-logical-plan/Cargo.toml b/src/daft-logical-plan/Cargo.toml index f183a63ef3..cf70c38998 100644 --- a/src/daft-logical-plan/Cargo.toml +++ b/src/daft-logical-plan/Cargo.toml @@ -26,6 +26,7 @@ uuid = {version = "1", features = ["v4"]} [dev-dependencies] daft-dsl = {path = "../daft-dsl", features = ["test-utils"]} pretty_assertions = {workspace = true} +rand = "0.8" rstest = {workspace = true} test-log = {workspace = true} diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index f004fe0b3d..e7cd907657 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -4,25 +4,85 @@ use std::{ sync::Arc, }; +use common_error::DaftResult; use daft_core::join::JoinType; use daft_dsl::{col, optimization::replace_columns_with_expressions, ExprRef}; use crate::{ ops::{Filter, Join, Project}, - LogicalPlan, LogicalPlanRef, + LogicalPlan, LogicalPlanBuilder, LogicalPlanRef, }; -#[derive(Debug)] -struct JoinNode { +/// A JoinOrderTree is a tree that describes a join order between relations, which can range from left deep trees +/// to bushy trees. The relations in a JoinOrderTree contain IDs instead of logical plan references. An ID's +/// corresponding logical plan reference can be found by consulting the JoinAdjList that was used to produce the +/// given JoinOrderTree. 
+/// +/// TODO(desmond): In the future these trees should keep track of current cost estimates. +#[derive(Clone, Debug)] +pub(super) enum JoinOrderTree { + Relation(usize), // (ID). + Join(Box, Box), // (subtree, subtree). +} + +impl JoinOrderTree { + pub(super) fn join(self: Box, right: Box) -> Box { + Box::new(JoinOrderTree::Join(self, right)) + } + + // Helper function that checks if the join order tree contains a given id. + pub(super) fn contains(&self, target_id: usize) -> bool { + match self { + Self::Relation(id) => *id == target_id, + Self::Join(left, right) => left.contains(target_id) || right.contains(target_id), + } + } + + pub(super) fn iter(&self) -> JoinOrderTreeIterator { + JoinOrderTreeIterator { stack: vec![self] } + } +} + +pub(super) struct JoinOrderTreeIterator<'a> { + stack: Vec<&'a JoinOrderTree>, +} + +impl<'a> Iterator for JoinOrderTreeIterator<'a> { + type Item = usize; + + fn next(&mut self) -> Option { + while let Some(node) = self.stack.pop() { + match node { + JoinOrderTree::Relation(id) => return Some(*id), + JoinOrderTree::Join(left, right) => { + self.stack.push(left); + self.stack.push(right); + } + } + } + None + } +} + +pub(super) trait JoinOrderer { + fn order(&self, graph: &JoinGraph) -> Box; +} + +#[derive(Clone, Debug)] +pub(super) struct JoinNode { relation_name: String, plan: LogicalPlanRef, final_name: String, } +// TODO(desmond): We should also take into account user provided values for: +// - null equals null +// - join strategy + /// JoinNodes represent a relation (i.e. a non-reorderable logical plan node), the column /// that's being accessed from the relation, and the final name of the column in the output. 
impl JoinNode { - fn new(relation_name: String, plan: LogicalPlanRef, final_name: String) -> Self { + pub(super) fn new(relation_name: String, plan: LogicalPlanRef, final_name: String) -> Self { Self { relation_name, plan, @@ -46,36 +106,122 @@ impl Display for JoinNode { } } -/// JoinEdges currently represent a bidirectional edge between two relations that have -/// an equi-join condition between each other. -#[derive(Debug)] -struct JoinEdge(JoinNode, JoinNode); +#[derive(Clone, Debug)] +pub(super) struct JoinCondition { + pub left_on: String, + pub right_on: String, +} -impl JoinEdge { - /// Helper function that summarizes join edge information. - fn simple_repr(&self) -> String { - format!("{} <-> {}", self.0, self.1) - } +pub(super) struct JoinAdjList { + pub max_id: usize, + plan_to_id: HashMap<*const LogicalPlan, usize>, + id_to_plan: HashMap, + pub edges: HashMap>>, } -impl Display for JoinEdge { +impl std::fmt::Display for JoinAdjList { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.simple_repr()) + writeln!(f, "Join Graph Adjacency List:")?; + for (node_id, neighbors) in &self.edges { + let node = self.id_to_plan.get(node_id).unwrap(); + writeln!(f, "Node {} (id = {node_id}):", node.name())?; + for (neighbor_id, join_conds) in neighbors { + let neighbor = self.id_to_plan.get(neighbor_id).unwrap(); + writeln!( + f, + " -> {} (id = {neighbor_id}) with conditions:", + neighbor.name() + )?; + for (i, cond) in join_conds.iter().enumerate() { + writeln!(f, " {}: {} = {}", i, cond.left_on, cond.right_on)?; + } + } + } + Ok(()) + } +} + +impl JoinAdjList { + pub(super) fn empty() -> Self { + Self { + max_id: 0, + plan_to_id: HashMap::new(), + id_to_plan: HashMap::new(), + edges: HashMap::new(), + } + } + + pub(super) fn get_or_create_plan_id(&mut self, plan: &LogicalPlanRef) -> usize { + let ptr = Arc::as_ptr(plan); + if let Some(id) = self.plan_to_id.get(&ptr) { + *id + } else { + let id = self.max_id; + self.max_id 
+= 1; + self.plan_to_id.insert(ptr, id); + self.id_to_plan.insert(id, plan.clone()); + id + } + } + + fn add_join_condition( + &mut self, + left_id: usize, + right_id: usize, + join_condition: JoinCondition, + ) { + if let Some(neighbors) = self.edges.get_mut(&left_id) { + if let Some(join_conditions) = neighbors.get_mut(&right_id) { + join_conditions.push(join_condition); + } else { + neighbors.insert(right_id, vec![join_condition]); + } + } else { + let mut neighbors = HashMap::new(); + neighbors.insert(right_id, vec![join_condition]); + self.edges.insert(left_id, neighbors); + } + } + + fn add_unidirectional_edge(&mut self, left: &JoinNode, right: &JoinNode) { + let join_condition = JoinCondition { + left_on: left.final_name.clone(), + right_on: right.final_name.clone(), + }; + let left_id = self.get_or_create_plan_id(&left.plan); + let right_id = self.get_or_create_plan_id(&right.plan); + self.add_join_condition(left_id, right_id, join_condition); + } + + pub(super) fn add_bidirectional_edge(&mut self, node1: JoinNode, node2: JoinNode) { + self.add_unidirectional_edge(&node1, &node2); + self.add_unidirectional_edge(&node2, &node1); + } + + pub(super) fn connected_join_trees(&self, left: &JoinOrderTree, right: &JoinOrderTree) -> bool { + for left_node in left.iter() { + if let Some(neighbors) = self.edges.get(&left_node) { + for right_node in right.iter() { + if let Some(_) = neighbors.get(&right_node) { + return true; + } + } + } + } + return false; } } #[derive(Debug)] -enum ProjectionOrFilter { +pub(crate) enum ProjectionOrFilter { Projection(Vec), Filter(ExprRef), } /// Representation of a logical plan as edges between relations, along with additional information needed to /// reconstruct a logcial plan that's equivalent to the plan that produced this graph. -struct JoinGraph { - // TODO(desmond): Instead of simply storing edges, we might want to maintain adjacency lists between - // relations. 
We can make this decision later when we implement join order selection. - edges: Vec, +pub(crate) struct JoinGraph { + pub adj_list: JoinAdjList, // List of projections and filters that should be applied after join reordering. This list respects // pre-order traversal of projections and filters in the query tree, so we should apply these operators // starting from the back of the list. @@ -84,47 +230,51 @@ struct JoinGraph { impl JoinGraph { pub(crate) fn new( - edges: Vec, + adj_list: JoinAdjList, final_projections_and_filters: Vec, ) -> Self { Self { - edges, + adj_list, final_projections_and_filters, } } /// Test helper function to get the number of edges that the current graph contains. pub(crate) fn num_edges(&self) -> usize { - self.edges.len() + let mut num_edges = 0; + for (_, edges) in &self.adj_list.edges { + num_edges += edges.len(); + } + // Each edge is bidirectional, so we divide by 2 to get the correct number of edges. + num_edges / 2 } /// Test helper function to check that all relations in this graph are connected. pub(crate) fn fully_connected(&self) -> bool { - // Assuming that we're not testing an empty graph, there should be at least one edge in a connected graph. - if self.edges.is_empty() { - return false; - } - let mut adj_list: HashMap<*const _, Vec<*const _>> = HashMap::new(); - for edge in &self.edges { - let l_ptr = Arc::as_ptr(&edge.0.plan); - let r_ptr = Arc::as_ptr(&edge.1.plan); - - adj_list.entry(l_ptr).or_default().push(r_ptr); - adj_list.entry(r_ptr).or_default().push(l_ptr); - } - let start_ptr = Arc::as_ptr(&self.edges[0].0.plan); + let start = if let Some((node, _)) = self.adj_list.edges.iter().next() { + node + } else { + // There are no nodes. The empty graph is fully connected. + return true; + }; let mut seen = HashSet::new(); - let mut stack = vec![start_ptr]; + let mut stack = vec![start]; while let Some(current) = stack.pop() { if seen.insert(current) { // If this is a new node, add all its neighbors to the stack. 
- if let Some(neighbors) = adj_list.get(&current) { - stack.extend(neighbors.iter().filter(|&&n| !seen.contains(&n))); + if let Some(neighbors) = self.adj_list.edges.get(current) { + stack.extend(neighbors.iter().filter_map(|(neighbor, _)| { + if !seen.contains(neighbor) { + Some(neighbor) + } else { + None + } + })); } } } - seen.len() == adj_list.len() + seen.len() == self.adj_list.max_id } /// Test helper function that checks if the graph contains the given projection/filter expressions @@ -150,15 +300,38 @@ impl JoinGraph { false } + fn get_node_by_id(&self, id: usize) -> &LogicalPlanRef { + self.adj_list + .id_to_plan + .get(&id) + .expect("Tried to retrieve a plan from the join graph with an invalid ID") + } + /// Helper function that loosely checks if a given edge (represented by a simple string) /// exists in the current graph. - pub(crate) fn contains_edge(&self, edge_string: &str) -> bool { - for edge in &self.edges { - if edge.simple_repr() == edge_string { - return true; + pub(crate) fn contains_edges(&self, to_check: Vec<&str>) -> bool { + let mut edge_strings = HashSet::new(); + for (left_id, neighbors) in &self.adj_list.edges { + for (right_id, join_conds) in neighbors { + let left = self.get_node_by_id(*left_id); + let right = self.get_node_by_id(*right_id); + for join_cond in join_conds { + edge_strings.insert(format!( + "{}({}) <-> {}({})", + left.name(), + join_cond.left_on, + right.name(), + join_cond.right_on + )); + } } } - false + for cur_check in to_check { + if !edge_strings.contains(cur_check) { + return false; + } + } + true } } @@ -167,14 +340,14 @@ struct JoinGraphBuilder { plan: LogicalPlanRef, join_conds_to_resolve: Vec<(String, LogicalPlanRef, bool)>, final_name_map: HashMap, - edges: Vec, + adj_list: JoinAdjList, final_projections_and_filters: Vec, } impl JoinGraphBuilder { pub(crate) fn build(mut self) -> JoinGraph { self.process_node(&self.plan.clone()); - JoinGraph::new(self.edges, self.final_projections_and_filters) + 
JoinGraph::new(self.adj_list, self.final_projections_and_filters) } pub(crate) fn from_logical_plan(plan: LogicalPlanRef) -> Self { @@ -192,7 +365,7 @@ impl JoinGraphBuilder { plan, join_conds_to_resolve: vec![], final_name_map: HashMap::new(), - edges: vec![], + adj_list: JoinAdjList::empty(), final_projections_and_filters: vec![ProjectionOrFilter::Projection(output_projection)], } } @@ -328,7 +501,7 @@ impl JoinGraphBuilder { rnode.clone(), self.final_name_map.get(&rname).unwrap().name().to_string(), ); - self.edges.push(JoinEdge(node1, node2)); + self.adj_list.add_bidirectional_edge(node1, node2); } else { panic!("Join conditions were unresolved"); } @@ -337,13 +510,43 @@ impl JoinGraphBuilder { // TODO(desmond): There are potentially more reorderable nodes. For example, we can move repartitions around. _ => { // This is an unreorderable node. All unresolved columns coming out of this node should be marked as resolved. + // TODO(desmond): At this point we should perform a fresh join reorder optimization starting from this + // node as the root node. We can do this once we add the optimizer rule. + let mut projections = vec![]; + let mut needs_projection = false; + let mut seen_names = HashSet::new(); for (name, _, done) in self.join_conds_to_resolve.iter_mut() { - if schema.has_field(name) { + if schema.has_field(name) && !*done && !seen_names.contains(name) { + if let Some(final_name) = self.final_name_map.get(name) { + let final_name = final_name.name().to_string(); + if final_name != *name { + needs_projection = true; + projections.push(col(name.clone()).alias(final_name)); + } else { + projections.push(col(name.clone())); + } + } else { + projections.push(col(name.clone())); + } + seen_names.insert(name); + } + } + // Apply projections and return the new plan as the relation for the appropriate join conditions. 
+ let projected_plan = if needs_projection { + let projected_plan = LogicalPlanBuilder::from(plan.clone()) + .select(projections) + .expect("Computed projections could not be applied to relation") + .build(); + Arc::new(Arc::unwrap_or_clone(projected_plan).with_materialized_stats()) + } else { + plan.clone() + }; + for (name, node, done) in self.join_conds_to_resolve.iter_mut() { + if schema.has_field(name) && !*done { *done = true; + *node = projected_plan.clone(); } } - // TODO(desmond): At this point we should perform a fresh join reorder optimization starting from this - // node as the root node. We can do this once we add the optimizer rule. } } } @@ -354,12 +557,18 @@ mod tests { use std::sync::Arc; use common_scan_info::Pushdowns; + use common_treenode::TransformedResult; use daft_core::prelude::CountMode; use daft_dsl::{col, AggExpr, Expr, LiteralValue}; use daft_schema::{dtype::DataType, field::Field}; use super::JoinGraphBuilder; - use crate::test::{dummy_scan_node_with_pushdowns, dummy_scan_operator}; + use crate::{ + optimization::rules::{EnrichWithStats, MaterializeScans, OptimizerRule}, + test::{ + dummy_scan_node_with_pushdowns, dummy_scan_operator, dummy_scan_operator_with_size, + }, + }; #[test] fn test_create_join_graph_basic_1() { @@ -372,21 +581,21 @@ mod tests { // | // Scan(c_prime) let scan_a = dummy_scan_node_with_pushdowns( - dummy_scan_operator(vec![Field::new("a", DataType::Int64)]), + dummy_scan_operator_with_size(vec![Field::new("a", DataType::Int64)], Some(100)), Pushdowns::default(), ); let scan_b = dummy_scan_node_with_pushdowns( - dummy_scan_operator(vec![Field::new("b", DataType::Int64)]), + dummy_scan_operator_with_size(vec![Field::new("b", DataType::Int64)], Some(10_000)), Pushdowns::default(), ); let scan_c = dummy_scan_node_with_pushdowns( - dummy_scan_operator(vec![Field::new("c_prime", DataType::Int64)]), + dummy_scan_operator_with_size(vec![Field::new("c_prime", DataType::Int64)], Some(100)), Pushdowns::default(), ) 
.select(vec![col("c_prime").alias("c")]) .unwrap(); let scan_d = dummy_scan_node_with_pushdowns( - dummy_scan_operator(vec![Field::new("d", DataType::Int64)]), + dummy_scan_operator_with_size(vec![Field::new("d", DataType::Int64)], Some(100)), Pushdowns::default(), ); let join_plan_l = scan_a @@ -410,17 +619,26 @@ mod tests { vec![Arc::new(Expr::Column(Arc::from("d")))], ) .unwrap(); - let plan = join_plan.build(); - let join_graph = JoinGraphBuilder::from_logical_plan(plan).build(); + let original_plan = join_plan.build(); + let scan_materializer = MaterializeScans::new(); + let original_plan = scan_materializer + .try_optimize(original_plan) + .data() + .unwrap(); + let stats_enricher = EnrichWithStats::new(); + let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b - // - c_prime <-> d + // - c <-> d // - a <-> d assert!(join_graph.num_edges() == 3); - assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); - assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); - assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)")); + assert!(join_graph.contains_edges(vec![ + "Source(a) <-> Source(b)", + "Project(c) <-> Source(d)", + "Source(a) <-> Source(d)" + ])); } #[test] @@ -472,17 +690,26 @@ mod tests { vec![Arc::new(Expr::Column(Arc::from("d")))], ) .unwrap(); - let plan = join_plan.build(); - let join_graph = JoinGraphBuilder::from_logical_plan(plan).build(); + let original_plan = join_plan.build(); + let scan_materializer = MaterializeScans::new(); + let original_plan = scan_materializer + .try_optimize(original_plan) + .data() + .unwrap(); + let stats_enricher = EnrichWithStats::new(); + let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); + let join_graph = 
JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b - // - c_prime <-> d + // - c <-> d // - b <-> d assert!(join_graph.num_edges() == 3); - assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); - assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); - assert!(join_graph.contains_edge("b#Source(b) <-> d#Source(d)")); + assert!(join_graph.contains_edges(vec![ + "Source(a) <-> Source(b)", + "Project(c) <-> Source(d)", + "Source(b) <-> Source(d)", + ])); } #[test] @@ -528,15 +755,24 @@ mod tests { vec![Arc::new(Expr::Column(Arc::from("c")))], ) .unwrap(); - let plan = join_plan_2.build(); - let join_graph = JoinGraphBuilder::from_logical_plan(plan).build(); + let original_plan = join_plan_2.build(); + let scan_materializer = MaterializeScans::new(); + let original_plan = scan_materializer + .try_optimize(original_plan) + .data() + .unwrap(); + let stats_enricher = EnrichWithStats::new(); + let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: - // - a <-> b - // - a <-> c + // - a_beta <-> b + // - a_beta <-> c assert!(join_graph.num_edges() == 2); - assert!(join_graph.contains_edge("a_beta#Source(a) <-> b#Source(b)")); - assert!(join_graph.contains_edge("a_beta#Source(a) <-> c#Source(c)")); + assert!(join_graph.contains_edges(vec![ + "Project(a_beta) <-> Source(b)", + "Project(a_beta) <-> Source(c)", + ])); } #[test] @@ -589,17 +825,26 @@ mod tests { vec![Arc::new(Expr::Column(Arc::from("d")))], ) .unwrap(); - let plan = join_plan.build(); - let join_graph = JoinGraphBuilder::from_logical_plan(plan).build(); + let original_plan = join_plan.build(); + let scan_materializer = MaterializeScans::new(); + let original_plan = scan_materializer + 
.try_optimize(original_plan) + .data() + .unwrap(); + let stats_enricher = EnrichWithStats::new(); + let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b - // - c_prime <-> d + // - c <-> d // - a <-> d assert!(join_graph.num_edges() == 3); - assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); - assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); - assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)")); + assert!(join_graph.contains_edges(vec![ + "Source(a) <-> Source(b)", + "Project(c) <-> Source(d)", + "Source(a) <-> Source(d)" + ])); // Check for non-join projections at the end. // `c_prime` gets renamed to `c` in the final projection let double_proj = col("c").add(col("c")).alias("double"); @@ -674,17 +919,26 @@ mod tests { vec![Arc::new(Expr::Column(Arc::from("d")))], ) .unwrap(); - let plan = join_plan.build(); - let join_graph = JoinGraphBuilder::from_logical_plan(plan).build(); + let original_plan = join_plan.build(); + let scan_materializer = MaterializeScans::new(); + let original_plan = scan_materializer + .try_optimize(original_plan) + .data() + .unwrap(); + let stats_enricher = EnrichWithStats::new(); + let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b - // - c_prime <-> d + // - c <-> d // - a <-> d assert!(join_graph.num_edges() == 3); - assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)")); - assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); - assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)")); + assert!(join_graph.contains_edges(vec![ + "Source(a) <-> Source(b)", 
+ "Project(c) <-> Source(d)", + "Source(a) <-> Source(d)", + ])); // Check for non-join projections and filters at the end. // `c_prime` gets renamed to `c` in the final projection let double_proj = col("c").add(col("c")).alias("double"); @@ -760,17 +1014,26 @@ mod tests { vec![Arc::new(Expr::Column(Arc::from("d")))], ) .unwrap(); - let plan = join_plan.build(); - let join_graph = JoinGraphBuilder::from_logical_plan(plan).build(); + let original_plan = join_plan.build(); + let scan_materializer = MaterializeScans::new(); + let original_plan = scan_materializer + .try_optimize(original_plan) + .data() + .unwrap(); + let stats_enricher = EnrichWithStats::new(); + let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b - // - c_prime <-> d + // - c <-> d // - a <-> d assert!(join_graph.num_edges() == 3); - assert!(join_graph.contains_edge("a#Aggregate(a) <-> b#Source(b)")); - assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)")); - assert!(join_graph.contains_edge("a#Aggregate(a) <-> d#Source(d)")); + assert!(join_graph.contains_edges(vec![ + "Aggregate(a) <-> Source(b)", + "Project(c) <-> Source(d)", + "Aggregate(a) <-> Source(d)" + ])); // Projections below the aggregation should not be part of the final projections. 
assert!(!join_graph.contains_projections_and_filters(vec![&a_proj])); } diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs index 09ece20040..762d58a4a8 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs @@ -1,2 +1,4 @@ #[cfg(test)] mod join_graph; +#[cfg(test)] +mod naive_left_deep_join_order; diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_left_deep_join_order.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_left_deep_join_order.rs new file mode 100644 index 0000000000..c0b3ab6634 --- /dev/null +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_left_deep_join_order.rs @@ -0,0 +1,178 @@ +use super::join_graph::{JoinGraph, JoinOrderTree, JoinOrderer}; + +pub(crate) struct NaiveLeftDeepJoinOrderer {} + +impl NaiveLeftDeepJoinOrderer { + fn extend_order( + graph: &JoinGraph, + current_order: Box, + mut available: Vec, + ) -> Box { + if available.is_empty() { + return current_order; + } + for (index, candidate_node_id) in available.iter().enumerate() { + let right = Box::new(JoinOrderTree::Relation(*candidate_node_id)); + if graph.adj_list.connected_join_trees(&current_order, &right) { + let new_order = current_order.join(right); + available.remove(index); + return Self::extend_order(graph, new_order, available); + } + } + panic!("There should be at least one naive join order."); + } +} + +impl JoinOrderer for NaiveLeftDeepJoinOrderer { + fn order(&self, graph: &JoinGraph) -> Box { + let available: Vec = (1..graph.adj_list.max_id).collect(); + // Take a starting order of the node with id 0. 
+ let starting_order = Box::new(JoinOrderTree::Relation(0)); + Self::extend_order(graph, starting_order, available) + } +} + +#[cfg(test)] +mod tests { + use common_scan_info::Pushdowns; + use daft_schema::{dtype::DataType, field::Field}; + use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; + + use super::{JoinGraph, JoinOrderTree, JoinOrderer, NaiveLeftDeepJoinOrderer}; + use crate::{ + optimization::rules::reorder_joins::join_graph::{JoinAdjList, JoinNode}, + test::{dummy_scan_node_with_pushdowns, dummy_scan_operator_with_size}, + LogicalPlanRef, + }; + + fn assert_order_contains_all_nodes(order: &Box, graph: &JoinGraph) { + for id in 0..graph.adj_list.max_id { + assert!( + order.contains(id), + "Graph id {} not found in order {:?}.\n{}", + id, + order, + graph.adj_list + ); + } + } + + fn create_scan_node(name: &str, size: Option) -> LogicalPlanRef { + dummy_scan_node_with_pushdowns( + dummy_scan_operator_with_size(vec![Field::new(name, DataType::Int64)], size), + Pushdowns::default(), + ) + .build() + } + + fn create_join_graph_with_edges(nodes: Vec, edges: Vec<(usize, usize)>) -> JoinGraph { + let mut adj_list = JoinAdjList::empty(); + for (from, to) in edges { + adj_list.add_bidirectional_edge(nodes[from].clone(), nodes[to].clone()); + } + JoinGraph::new(adj_list, vec![]) + } + + macro_rules! 
create_and_test_join_graph { + ($nodes:expr, $edges:expr, $orderer:expr) => { + let nodes: Vec = $nodes + .iter() + .map(|name| { + let scan_node = create_scan_node(name, Some(100)); + JoinNode::new(name.to_string(), scan_node, name.to_string()) + }) + .collect(); + let graph = create_join_graph_with_edges(nodes.clone(), $edges); + let order = $orderer.order(&graph); + assert_order_contains_all_nodes(&order, &graph); + }; + } + + #[test] + fn test_order_basic_join_graph() { + let nodes = vec!["a", "b", "c", "d"]; + let edges = vec![ + (0, 2), // node_a <-> node_c + (1, 2), // node_b <-> node_c + (2, 3), // node_c <-> node_d + ]; + create_and_test_join_graph!(nodes, edges, NaiveLeftDeepJoinOrderer {}); + } + + pub struct UnionFind { + parent: Vec, + size: Vec, + } + + impl UnionFind { + pub fn create(num_nodes: usize) -> Self { + UnionFind { + parent: (0..num_nodes).collect(), + size: vec![1; num_nodes], + } + } + + pub fn find(&mut self, node: usize) -> usize { + if self.parent[node] != node { + self.parent[node] = self.find(self.parent[node]); + } + self.parent[node] + } + + pub fn union(&mut self, node1: usize, node2: usize) { + let root1 = self.find(node1); + let root2 = self.find(node2); + + if root1 != root2 { + let (small, big) = if self.size[root1] < self.size[root2] { + (root1, root2) + } else { + (root2, root1) + }; + self.parent[small] = big; + self.size[big] += self.size[small]; + } + } + } + + fn create_random_connected_graph(num_nodes: usize) -> Vec<(usize, usize)> { + let mut rng = StdRng::seed_from_u64(0); + let mut edges = Vec::new(); + let mut uf = UnionFind::create(num_nodes); + + // Get a random order of all possible edges. + let mut all_edges: Vec<(usize, usize)> = (0..num_nodes) + .flat_map(|i| (0..i).chain(i + 1..num_nodes).map(move |j| (i, j))) + .collect(); + all_edges.shuffle(&mut rng); + + // Select edges to form a minimum spanning tree + a random number of extra edges. 
+ for (a, b) in all_edges { + if uf.find(a) != uf.find(b) { + uf.union(a, b); + edges.push((a, b)); + } + // Check if we have a minimum spanning tree. + if edges.len() >= num_nodes - 1 { + // Once we have a minimum spanning tree, we let a random number of extra edges be added to the graph. + if rng.gen_bool(0.3) { + break; + } + edges.push((a, b)); + } + } + + edges + } + + const NUM_RANDOM_NODES: usize = 100; + + #[test] + fn test_order_random_join_graph() { + let nodes: Vec = (0..NUM_RANDOM_NODES) + .map(|i| format!("node_{}", i)) + .collect(); + let edges = create_random_connected_graph(NUM_RANDOM_NODES); + create_and_test_join_graph!(nodes, edges, NaiveLeftDeepJoinOrderer {}); + } +} diff --git a/src/daft-logical-plan/src/test/mod.rs b/src/daft-logical-plan/src/test/mod.rs index 75f8ad386b..7ac8da51c1 100644 --- a/src/daft-logical-plan/src/test/mod.rs +++ b/src/daft-logical-plan/src/test/mod.rs @@ -7,10 +7,20 @@ use crate::builder::LogicalPlanBuilder; /// Create a dummy scan node containing the provided fields in its schema and the provided limit. pub fn dummy_scan_operator(fields: Vec) -> ScanOperatorRef { + dummy_scan_operator_with_size(fields, None) +} + +/// Create dummy scan node containing the provided fields in its schema and the provided limit, +/// and with the provided size estimate. 
+pub fn dummy_scan_operator_with_size( + fields: Vec, + in_memory_size_per_task: Option, +) -> ScanOperatorRef { let schema = Arc::new(Schema::new(fields).unwrap()); ScanOperatorRef(Arc::new(DummyScanOperator { schema, - num_scan_tasks: 0, + num_scan_tasks: 1, + in_memory_size_per_task, })) } diff --git a/src/daft-physical-plan/src/test/mod.rs b/src/daft-physical-plan/src/test/mod.rs index 3e8de6a74c..29f9d81997 100644 --- a/src/daft-physical-plan/src/test/mod.rs +++ b/src/daft-physical-plan/src/test/mod.rs @@ -10,6 +10,7 @@ pub fn dummy_scan_operator(fields: Vec) -> ScanOperatorRef { ScanOperatorRef(Arc::new(DummyScanOperator { schema, num_scan_tasks: 1, + in_memory_size_per_task: None, })) }