From 2fbe528a7b3bbc11e8b5927f0521d3c27dfececb Mon Sep 17 00:00:00 2001 From: Philipp Gackstatter Date: Fri, 25 Oct 2024 13:32:20 +0200 Subject: [PATCH] feat(core): Use a deque instead of a stack in the iterator --- core/src/mast/merger/mod.rs | 2 +- core/src/mast/multi_forest_node_iterator.rs | 170 +++++++++----------- 2 files changed, 76 insertions(+), 96 deletions(-) diff --git a/core/src/mast/merger/mod.rs b/core/src/mast/merger/mod.rs index c497f0cad..30ad7e2d6 100644 --- a/core/src/mast/merger/mod.rs +++ b/core/src/mast/merger/mod.rs @@ -62,7 +62,7 @@ impl MastForestMerger { /// 2. Merge all nodes of forests. /// - Similar to decorators, node indices might move during merging, so the merger keeps a /// node id mapping as it merges nodes. - /// - This is a depth-first traversal over the forests to ensure all children are processed + /// - This is a depth-first traversal over all forests to ensure all children are processed /// before their parents. See the documentation of [`MultiMastForestNodeIter`] for details /// on this traversal. /// - Because all parents are processed after their children, we can use the node id mapping diff --git a/core/src/mast/multi_forest_node_iterator.rs b/core/src/mast/multi_forest_node_iterator.rs index daaf8bf51..48b5fdf9e 100644 --- a/core/src/mast/multi_forest_node_iterator.rs +++ b/core/src/mast/multi_forest_node_iterator.rs @@ -1,4 +1,5 @@ use alloc::{collections::BTreeMap, vec::Vec}; +use std::collections::VecDeque; use miden_crypto::hash::rpo::RpoDigest; @@ -37,29 +38,30 @@ type ForestIndex = usize; /// The only root of A is the `Join` node at index 2. The first three nodes of the forest form a /// tree, since the `Join` node references index 0 and 1. This tree is discovered by /// starting at the root at index 2 and following all children until we reach terminal nodes (like -/// `Block`s) and building up a stack of the discovered nodes. 
The special case here is the +/// `Block`s) and building up a deque of the discovered nodes. The special case here is the /// `External` node whose digest matches that of a node in forest B. Instead of the External -/// node being added to the stack, the tree of the Call node is added instead. The stack is built -/// such that popping elements off the stack (from the back) yields a postorder. +/// node being added to the deque, the tree of the Call node is added instead. The deque is built +/// such that popping elements off the deque (from the front) yields a postorder. +/// +/// After the first tree is discovered, the deque looks like this: /// -/// After the first tree is discovered, the stack looks like this: /// ```text -/// [Node(forest_idx: 0, node_id: 2), +/// [Node(forest_idx: 0, node_id: 0), +/// Node(forest_idx: 1, node_id: 0), +/// Node(forest_idx: 1, node_id: 1), /// ExternalNodeReplacement( /// replacement_forest_idx: 1, replacement_node_id: 1 /// replaced_forest_idx: 0, replaced_node_id: 1 /// ), -/// Node(forest_idx: 1, node_id: 1), -/// Node(forest_idx: 1, node_id: 0), -/// Node(forest_idx: 0, node_id: 0)] +/// Node(forest_idx: 0, node_id: 2)] /// ``` /// -/// If the stack is exhausted we start another discovery if more unvisited roots exist. In this -/// example, the root of forest B was already visited due to the External node reference, so the -/// iteration is complete. +/// If the deque is exhausted we start another discovery if more undiscovered roots exist. In this +/// example, the root of forest B was already discovered and visited due to the External node +/// reference, so the iteration is complete. /// /// The iteration on a higher level thus consists of a back and forth between discovering trees and -/// returning nodes from the stack. +/// returning nodes from the deque. pub(crate) struct MultiMastForestNodeIter<'forest> { /// The forests that we're iterating. 
mast_forests: Vec<&'forest MastForest>, @@ -77,13 +79,12 @@ pub(crate) struct MultiMastForestNodeIter<'forest> { /// A map of MAST roots of all non-external nodes in mast_forests to their forest and node /// indices. non_external_nodes: BTreeMap, - /// Describes whether the node at some [forest_index][node_index] has already been visited. - /// Note that this is set to true for all nodes that have been returned from the iterator. - visited_nodes: Vec>, - /// This stack always contains the discovered nodes. - /// The stack might contain a node twice. However we only ever return a node once, which is - /// checked in the `next` function. - discovered_nodes: Vec, + /// Describes whether the node identified by [forest_index][node_index] has already been + /// discovered. Note that this is `true` for all nodes that are in the unvisited node deque. + discovered_nodes: Vec>, + /// This deque always contains the discovered, but unvisited nodes. + /// It holds that discovered_nodes[forest_idx][node_id] = true for all elements in this deque. + unvisited_nodes: VecDeque, } impl<'forest> MultiMastForestNodeIter<'forest> { @@ -91,7 +92,7 @@ impl<'forest> MultiMastForestNodeIter<'forest> { /// the iterator. This enables an efficient check whether for any encountered External node /// referencing digest `foo` a node with digest `foo` already exists in any forest. 
pub(crate) fn new(mast_forests: Vec<&'forest MastForest>) -> Self { - let visited_nodes = mast_forests + let discovered_nodes = mast_forests .iter() .map(|forest| vec![false; forest.num_nodes() as usize]) .collect(); @@ -114,28 +115,34 @@ impl<'forest> MultiMastForestNodeIter<'forest> { current_forest_idx: 0, current_procedure_root_idx: 0, non_external_nodes, - visited_nodes, - discovered_nodes: Vec::new(), + discovered_nodes, + unvisited_nodes: VecDeque::new(), } } - /// Pushes the given node, uniquely identified by the forest and node index onto the stack + /// Pushes the given node, uniquely identified by the forest and node index onto the deque /// even if the node was already discovered before. /// /// It's the callers responsibility to only pass valid indices. fn push_node(&mut self, forest_idx: usize, node_id: MastNodeId) { - self.discovered_nodes - .push(MultiMastForestIteratorItem::Node { forest_idx, node_id }); + self.unvisited_nodes + .push_back(MultiMastForestIteratorItem::Node { forest_idx, node_id }); + self.discovered_nodes[forest_idx][node_id.as_usize()] = true; } /// Discovers a tree starting at the given forest index and node id. /// - /// It's the callers responsibility to only pass valid indices. + /// Callers must pass valid forest and node indices: the `discovered_nodes` lookup below + /// indexes with them directly and panics if they are out of bounds. fn discover_tree( &mut self, forest_idx: ForestIndex, node_id: MastNodeId, ) -> Result<(), MastForestError> { + if self.discovered_nodes[forest_idx][node_id.as_usize()] { + return Ok(()); + } + let current_node = &self.mast_forests[forest_idx].nodes.get(node_id.as_usize()).ok_or_else(|| { MastForestError::NodeIdOverflow( @@ -144,30 +151,29 @@ impl<'forest> MultiMastForestNodeIter<'forest> { ) })?; - // Note that the order in which we add or discover nodes is the reverse of postorder, since - // we're pushing them onto a stack, which reverses the order itself. Hence, reversing twice - // gives us the actual postorder we want. 
+ // Note that we can process nodes in postorder, since we push them onto the back of the + // deque but pop them off the front. match current_node { MastNode::Block(_) => { self.push_node(forest_idx, node_id); }, MastNode::Join(join_node) => { - self.push_node(forest_idx, node_id); - self.discover_tree(forest_idx, join_node.second())?; self.discover_tree(forest_idx, join_node.first())?; + self.discover_tree(forest_idx, join_node.second())?; + self.push_node(forest_idx, node_id); }, MastNode::Split(split_node) => { - self.push_node(forest_idx, node_id); - self.discover_tree(forest_idx, split_node.on_false())?; self.discover_tree(forest_idx, split_node.on_true())?; + self.discover_tree(forest_idx, split_node.on_false())?; + self.push_node(forest_idx, node_id); }, MastNode::Loop(loop_node) => { - self.push_node(forest_idx, node_id); self.discover_tree(forest_idx, loop_node.body())?; + self.push_node(forest_idx, node_id); }, MastNode::Call(call_node) => { - self.push_node(forest_idx, node_id); self.discover_tree(forest_idx, call_node.callee())?; + self.push_node(forest_idx, node_id); }, MastNode::Dyn(_) => { self.push_node(forest_idx, node_id); @@ -176,19 +182,20 @@ impl<'forest> MultiMastForestNodeIter<'forest> { // When we encounter an external node referencing digest `foo` there are two cases: // - If there exists a node `replacement` in any forest with digest `foo`, we want // to replace the external node with that node, which we do in two steps. - // - Discover the `replacement`'s tree and add it to the stack. - // - If `replacement` was already visited before, it won't actually be returned. + // - Discover the `replacement`'s tree and add it to the deque. + // - If `replacement` was already discovered before, it won't actually be + // returned. // - In any case this means: The `replacement` node is processed before the // replacement signal we're adding next. 
- // - Add a replacement signal to the stack, signaling that the `replacement` + // - Add a replacement signal to the deque, signaling that the `replacement` // replaced the external node. - // - Note that the order of these operations in code is reversed, since the stack - // we're pushing the operations onto reverses the order once more. // - If no replacement exists, yield the External Node as a regular `Node`. if let Some((other_forest_idx, other_node_id)) = self.non_external_nodes.get(&external_node.digest()).copied() { - self.discovered_nodes.push( + self.discover_tree(other_forest_idx, other_node_id)?; + + self.unvisited_nodes.push_back( MultiMastForestIteratorItem::ExternalNodeReplacement { replacement_forest_idx: other_forest_idx, replacement_mast_node_id: other_node_id, @@ -197,7 +204,7 @@ impl<'forest> MultiMastForestNodeIter<'forest> { }, ); - self.discover_tree(other_forest_idx, other_node_id)?; + self.discovered_nodes[forest_idx][node_id.as_usize()] = true; } else { self.push_node(forest_idx, node_id); } @@ -207,16 +214,17 @@ impl<'forest> MultiMastForestNodeIter<'forest> { Ok(()) } - /// Finds the next unvisited procedure root and discovers a tree from it. + /// Finds the next undiscovered procedure root and discovers a tree from it. /// - /// If the unvisited node stack is empty after calling this function, the iteration is complete. + /// If the unvisited node deque is empty after calling this function, the iteration is + /// complete. /// /// This function basically consists of two loops: /// - The outer loop iterates over all forest indices. /// - The inner loop iterates over all procedure root indices for the current forest. fn discover_nodes(&mut self) { 'forest_loop: while self.current_forest_idx < self.mast_forests.len() - && self.discovered_nodes.is_empty() + && self.unvisited_nodes.is_empty() { // If we don't have any forests, there is nothing to do. 
if self.mast_forests.is_empty() { @@ -230,11 +238,11 @@ impl<'forest> MultiMastForestNodeIter<'forest> { } let procedure_roots = self.mast_forests[self.current_forest_idx].procedure_roots(); - let visited_nodes = &self.visited_nodes[self.current_forest_idx]; + let discovered_nodes = &self.discovered_nodes[self.current_forest_idx]; - // Find the next unvisited procedure root for the current forest by incrementing the - // current procedure root until we find one that was not yet visited. - while visited_nodes + // Find the next undiscovered procedure root for the current forest by incrementing the + // current procedure root until we find one that was not yet discovered. + while discovered_nodes [procedure_roots[self.current_procedure_root_idx as usize].as_usize()] { // If we have reached the end of the procedure roots for the current forest, @@ -250,15 +258,15 @@ impl<'forest> MultiMastForestNodeIter<'forest> { continue 'forest_loop; } - // Since the current procedure root was already visited, check the next one. + // Since the current procedure root was already discovered, check the next one. self.current_procedure_root_idx += 1; } - // We exited the loop, so the current procedure root is unvisited and so we can start - // a discovery from that root. Since that root is unvisited and undiscovered, it is - // guaranteed that after this discovery the stack will be non-empty. - let tree_root_id = procedure_roots[self.current_procedure_root_idx as usize]; - self.discover_tree(self.current_forest_idx, tree_root_id) + // We exited the loop, so the current procedure root is undiscovered and so we can start + // a discovery from that root. Since that root is undiscovered, it is guaranteed that + // after this discovery the deque will be non-empty. 
+ let procedure_root_id = procedure_roots[self.current_procedure_root_idx as usize]; + self.discover_tree(self.current_forest_idx, procedure_root_id) .expect("we should only pass root indices that are valid for the forest"); } } @@ -268,36 +276,16 @@ impl Iterator for MultiMastForestNodeIter<'_> { type Item = MultiMastForestIteratorItem; fn next(&mut self) -> Option { - while let Some(stack_item) = self.discovered_nodes.pop() { - // Get the forest and node index of the node being processed on the stack so we can - // check if it has already been visited. - let (forest_idx, node_id) = match &stack_item { - MultiMastForestIteratorItem::Node { forest_idx, node_id } => (forest_idx, node_id), - MultiMastForestIteratorItem::ExternalNodeReplacement { - replaced_forest_idx, - replaced_mast_node_id, - .. - } => (replaced_forest_idx, replaced_mast_node_id), - }; - - let is_node_visited_mut = self.visited_nodes[*forest_idx] - .get_mut(node_id.as_usize()) - .expect("visited_nodes can be safely indexed by any valid MastNodeId"); - if *is_node_visited_mut { - continue; - } else { - *is_node_visited_mut = true; - } - - return Some(stack_item); + if let Some(deque_item) = self.unvisited_nodes.pop_front() { + return Some(deque_item); } self.discover_nodes(); - if !self.discovered_nodes.is_empty() { + if !self.unvisited_nodes.is_empty() { self.next() } else { - // If the stack is empty after tree discovery, all (reachable) nodes have been + // If the deque is empty after tree discovery, all (reachable) nodes have been // discovered and visited. None } @@ -460,30 +448,22 @@ mod tests { ); } - /// Tests that a node which appears twice in a tree is returned in the required order. + /// Tests that a node which is referenced twice in a Mast Forest is returned in the required + /// order. 
/// - /// In this test we have a MastForest with a tree like this: + /// In this test we have a MastForest with this graph: /// /// 3 <- Split Node /// / \ /// 1 2 - /// / \ - /// 0 0 - /// - /// In a previous implementation we marked the nodes that we discovered as visited immediately - /// and did not add them again when encountering them again. This was not correct because of - /// this example. In the iterator, we need to discover nodes in reverse so we can later pop - /// them off the stack in postorder. So we first descend down child 2 of the split node. - /// This means we would mark 0 as visited and then add 2 to the stack. Then we descend down - /// child 1 but do not add 0 again because it is already visited and we add 1. So we end up - /// with this stack for the tree: [3, 2, 0, 1]. This means when we pop off nodes from the - /// stack we get to 1 before we get to 0 and that violates the guarantees from this - /// iterator. + /// \ / + /// 0 /// - /// Hence this test to ensure that we do return nodes in the right order. + /// We need to ensure that 0 is processed before 1 and that it is not processed again when + /// processing the children of node 2. /// /// This test and example is essentially a copy from a part of the MastForest of the Miden - /// Stdlib where this error occured. + /// Stdlib where this failed on a previous implementation. #[test] fn multi_mast_forest_child_duplicate() { let block_foo = MastNode::new_basic_block(vec![Operation::Drop], None).unwrap();