diff --git a/src/extract/global_greedy_dag.rs b/src/extract/global_greedy_dag.rs index cfc5ca8..4bb35b9 100644 --- a/src/extract/global_greedy_dag.rs +++ b/src/extract/global_greedy_dag.rs @@ -1,3 +1,5 @@ +use std::iter; + use rpds::{HashTrieMap, HashTrieSet}; use super::*; @@ -15,19 +17,19 @@ type Reachable = HashTrieSet; struct TermInfo { node: NodeId, eclass: ClassId, - node_cost: NotNan, - total_cost: NotNan, + node_cost: Cost, + total_cost: Cost, // store the set of reachable terms from this term reachable: Reachable, size: usize, } -// A TermDag needs to store terms that share common -// subterms using a hashmap. -// However, it also critically needs to be able to answer -// reachability queries in this dag `reachable`. -// This prevents double-counting costs when -// computing the cost of a term. +/// A TermDag needs to store terms that share common +/// subterms using a hashmap. +/// However, it also critically needs to be able to answer +/// reachability queries in this dag `reachable`. +/// This prevents double-counting costs when +/// computing the cost of a term. #[derive(Default)] pub struct TermDag { nodes: Vec, @@ -36,16 +38,16 @@ pub struct TermDag { } impl TermDag { - // Makes a new term using a node and children terms - // Correctly computes total_cost with sharing - // If this term contains itself, returns None - // If this term costs more than target, returns None + /// Makes a new term using a node and children terms + /// Correctly computes total_cost with sharing + /// If this term contains itself, returns None + /// If this term costs more than target, returns None pub fn make( &mut self, node_id: NodeId, node: &Node, children: Vec, - target: NotNan, + target: Cost, ) -> Option { let term = Term { op: node.op.clone(), @@ -66,13 +68,15 @@ impl TermDag { eclass: node.eclass.clone(), node_cost, total_cost: node_cost, - reachable: [node.eclass.clone()].into_iter().collect(), + reachable: iter::once(node.eclass.clone()).collect(), size: 1, }); self.hash_cons.insert(term, next_id); Some(next_id) } else { - // check if children contains this node + // check if children contains this node, preventing cycles + // This is sound because `reachable` is the set of reachable eclasses + // from this term. for child in &children { if self.info[*child].reachable.contains(&node.eclass) { return None; @@ -115,10 +119,16 @@ impl TermDag { } } - // Return a new term, like this one but making use of shared terms. - // Also return the cost of the new nodes. - fn get_cost(&self, shared: &mut Box, id: TermId) -> NotNan { + /// Return a new term, like this one but making use of shared terms. + /// Also return the cost of the new nodes. + fn get_cost(&self, shared: &mut Box, id: TermId) -> Cost { let eclass = self.info[id].eclass.clone(); + + // This is the key to why this algorithm is faster than greedy_dag. + // While doing the set union between reachable sets, we can stop early + // if we find a shared term. + // Since the term with `id` is shared, the reachable set of `id` will already + // be in `shared`. if shared.contains(&eclass) { NotNan::::new(0.0).unwrap() } else { @@ -132,11 +142,11 @@ impl TermDag { } } - pub fn node_cost(&self, id: TermId) -> NotNan { + pub fn node_cost(&self, id: TermId) -> Cost { self.info[id].node_cost } - pub fn total_cost(&self, id: TermId) -> NotNan { + pub fn total_cost(&self, id: TermId) -> Cost { self.info[id].total_cost } }