feat(core): Use a deque instead of a stack in the iterator
PhilippGackstatter committed Oct 25, 2024
1 parent f5a7327 commit 2fbe528
Showing 2 changed files with 76 additions and 96 deletions.
2 changes: 1 addition & 1 deletion core/src/mast/merger/mod.rs
@@ -62,7 +62,7 @@ impl MastForestMerger {
/// 2. Merge all nodes of forests.
/// - Similar to decorators, node indices might move during merging, so the merger keeps a
/// node id mapping as it merges nodes.
/// - This is a depth-first traversal over the forests to ensure all children are processed
/// - This is a depth-first traversal over all forests to ensure all children are processed
/// before their parents. See the documentation of [`MultiMastForestNodeIter`] for details
/// on this traversal.
/// - Because all parents are processed after their children, we can use the node id mapping
Expand Down
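The doc comment above relies on children being merged before their parents, so that a parent's child references can be rewritten through the node id mapping. A minimal sketch of that idea follows. The `Node` enum, the `merge_postorder` function, and the linear-scan deduplication are hypothetical simplifications for illustration and are not the actual `MastForestMerger` API.

```rust
use std::collections::BTreeMap;

/// Hypothetical, simplified node type (not the real `MastNode`).
#[derive(Debug, Clone, PartialEq)]
enum Node {
    /// A leaf with an arbitrary payload.
    Block(u32),
    /// An inner node referencing two children by index.
    Join(usize, usize),
}

/// Merges nodes, given in postorder as `(old_id, node)` pairs, into `merged`
/// and returns the mapping from old ids to ids in `merged`.
///
/// Because the input is in postorder, the mapping already contains every child
/// when its parent is processed, so indexing `id_map` cannot fail for valid input.
fn merge_postorder(order: &[(usize, Node)], merged: &mut Vec<Node>) -> BTreeMap<usize, usize> {
    let mut id_map = BTreeMap::new();
    for (old_id, node) in order {
        // Rewrite child references so they point into the merged forest.
        let remapped = match node {
            Node::Block(payload) => Node::Block(*payload),
            Node::Join(first, second) => Node::Join(id_map[first], id_map[second]),
        };
        // Reuse an identical node if one exists (assumption: equality stands in
        // for the digest comparison a real merger would use).
        let new_id = match merged.iter().position(|existing| *existing == remapped) {
            Some(position) => position,
            None => {
                merged.push(remapped);
                merged.len() - 1
            },
        };
        id_map.insert(*old_id, new_id);
    }
    id_map
}
```

If a parent appeared before one of its children, the `id_map[first]` lookup would panic, which is why the traversal order matters for this approach.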
170 changes: 75 additions & 95 deletions core/src/mast/multi_forest_node_iterator.rs
@@ -1,4 +1,5 @@
use alloc::{collections::BTreeMap, vec::Vec};
use alloc::collections::VecDeque;

use miden_crypto::hash::rpo::RpoDigest;

@@ -37,29 +38,30 @@ type ForestIndex = usize;
/// The only root of A is the `Join` node at index 2. The first three nodes of the forest form a
/// tree, since the `Join` node references index 0 and 1. This tree is discovered by
/// starting at the root at index 2 and following all children until we reach terminal nodes (like
/// `Block`s) and building up a stack of the discovered nodes. The special case here is the
/// `Block`s) and building up a deque of the discovered nodes. The special case here is the
/// `External` node whose digest matches that of a node in forest B. Instead of the External
/// node being added to the stack, the tree of the Call node is added instead. The stack is built
/// such that popping elements off the stack (from the back) yields a postorder.
/// node being added to the deque, the tree of the Call node is added instead. The deque is built
/// such that popping elements off the deque (from the front) yields a postorder.
///
/// After the first tree is discovered, the deque looks like this:
///
/// After the first tree is discovered, the stack looks like this:
/// ```text
/// [Node(forest_idx: 0, node_id: 2),
/// [Node(forest_idx: 0, node_id: 0),
/// Node(forest_idx: 1, node_id: 0),
/// Node(forest_idx: 1, node_id: 1),
/// ExternalNodeReplacement(
/// replacement_forest_idx: 1, replacement_node_id: 1
/// replaced_forest_idx: 0, replaced_node_id: 1
/// ),
/// Node(forest_idx: 1, node_id: 1),
/// Node(forest_idx: 1, node_id: 0),
/// Node(forest_idx: 0, node_id: 0)]
/// Node(forest_idx: 0, node_id: 2)]
/// ```
///
/// If the stack is exhausted we start another discovery if more unvisited roots exist. In this
/// example, the root of forest B was already visited due to the External node reference, so the
/// iteration is complete.
/// If the deque is exhausted we start another discovery if more undiscovered roots exist. In this
/// example, the root of forest B was already discovered and visited due to the External node
/// reference, so the iteration is complete.
///
/// The iteration on a higher level thus consists of a back and forth between discovering trees and
/// returning nodes from the stack.
/// returning nodes from the deque.
pub(crate) struct MultiMastForestNodeIter<'forest> {
/// The forests that we're iterating.
mast_forests: Vec<&'forest MastForest>,
@@ -77,21 +79,20 @@ pub(crate) struct MultiMastForestNodeIter<'forest> {
/// A map of MAST roots of all non-external nodes in mast_forests to their forest and node
/// indices.
non_external_nodes: BTreeMap<RpoDigest, (ForestIndex, MastNodeId)>,
/// Describes whether the node at some [forest_index][node_index] has already been visited.
/// Note that this is set to true for all nodes that have been returned from the iterator.
visited_nodes: Vec<Vec<bool>>,
/// This stack always contains the discovered nodes.
/// The stack might contain a node twice. However we only ever return a node once, which is
/// checked in the `next` function.
discovered_nodes: Vec<MultiMastForestIteratorItem>,
/// Describes whether the node identified by [forest_index][node_index] has already been
/// discovered. Note that this is `true` for all nodes that are in the unvisited node deque.
discovered_nodes: Vec<Vec<bool>>,
/// This deque always contains the discovered, but unvisited nodes.
/// It holds that discovered_nodes[forest_idx][node_id] = true for all elements in this deque.
unvisited_nodes: VecDeque<MultiMastForestIteratorItem>,
}

impl<'forest> MultiMastForestNodeIter<'forest> {
/// Builds a map of MAST roots to non-external nodes in any of the given forests to initialize
/// the iterator. This enables an efficient check whether for any encountered External node
/// referencing digest `foo` a node with digest `foo` already exists in any forest.
pub(crate) fn new(mast_forests: Vec<&'forest MastForest>) -> Self {
let visited_nodes = mast_forests
let discovered_nodes = mast_forests
.iter()
.map(|forest| vec![false; forest.num_nodes() as usize])
.collect();
@@ -114,28 +115,34 @@ impl<'forest> MultiMastForestNodeIter<'forest> {
current_forest_idx: 0,
current_procedure_root_idx: 0,
non_external_nodes,
visited_nodes,
discovered_nodes: Vec::new(),
discovered_nodes,
unvisited_nodes: VecDeque::new(),
}
}

/// Pushes the given node, uniquely identified by the forest and node index onto the stack
/// Pushes the given node, uniquely identified by the forest and node index onto the deque
/// even if the node was already discovered before.
///
/// It's the caller's responsibility to only pass valid indices.
fn push_node(&mut self, forest_idx: usize, node_id: MastNodeId) {
self.discovered_nodes
.push(MultiMastForestIteratorItem::Node { forest_idx, node_id });
self.unvisited_nodes
.push_back(MultiMastForestIteratorItem::Node { forest_idx, node_id });
self.discovered_nodes[forest_idx][node_id.as_usize()] = true;
}

/// Discovers a tree starting at the given forest index and node id.
///
/// It's the caller's responsibility to only pass valid indices.
/// SAFETY: We only pass valid forest and node indices so we can index directly in this
/// function.
fn discover_tree(
&mut self,
forest_idx: ForestIndex,
node_id: MastNodeId,
) -> Result<(), MastForestError> {
if self.discovered_nodes[forest_idx][node_id.as_usize()] {
return Ok(());
}

let current_node =
&self.mast_forests[forest_idx].nodes.get(node_id.as_usize()).ok_or_else(|| {
MastForestError::NodeIdOverflow(
@@ -144,30 +151,29 @@ impl<'forest> MultiMastForestNodeIter<'forest> {
)
})?;

// Note that the order in which we add or discover nodes is the reverse of postorder, since
// we're pushing them onto a stack, which reverses the order itself. Hence, reversing twice
// gives us the actual postorder we want.
// Note that we can process nodes in postorder, since we push them onto the back of the
// deque but pop them off the front.
match current_node {
MastNode::Block(_) => {
self.push_node(forest_idx, node_id);
},
MastNode::Join(join_node) => {
self.push_node(forest_idx, node_id);
self.discover_tree(forest_idx, join_node.second())?;
self.discover_tree(forest_idx, join_node.first())?;
self.discover_tree(forest_idx, join_node.second())?;
self.push_node(forest_idx, node_id);
},
MastNode::Split(split_node) => {
self.push_node(forest_idx, node_id);
self.discover_tree(forest_idx, split_node.on_false())?;
self.discover_tree(forest_idx, split_node.on_true())?;
self.discover_tree(forest_idx, split_node.on_false())?;
self.push_node(forest_idx, node_id);
},
MastNode::Loop(loop_node) => {
self.push_node(forest_idx, node_id);
self.discover_tree(forest_idx, loop_node.body())?;
self.push_node(forest_idx, node_id);
},
MastNode::Call(call_node) => {
self.push_node(forest_idx, node_id);
self.discover_tree(forest_idx, call_node.callee())?;
self.push_node(forest_idx, node_id);
},
MastNode::Dyn(_) => {
self.push_node(forest_idx, node_id);
@@ -176,19 +182,20 @@ impl<'forest> MultiMastForestNodeIter<'forest> {
// When we encounter an external node referencing digest `foo` there are two cases:
// - If there exists a node `replacement` in any forest with digest `foo`, we want
// to replace the external node with that node, which we do in two steps.
// - Discover the `replacement`'s tree and add it to the stack.
// - If `replacement` was already visited before, it won't actually be returned.
// - Discover the `replacement`'s tree and add it to the deque.
// - If `replacement` was already discovered before, it won't actually be
// returned.
// - In any case this means: The `replacement` node is processed before the
// replacement signal we're adding next.
// - Add a replacement signal to the stack, signaling that the `replacement`
// - Add a replacement signal to the deque, signaling that the `replacement`
// replaced the external node.
// - Note that the order of these operations in code is reversed, since the stack
// we're pushing the operations onto reverses the order once more.
// - If no replacement exists, yield the External Node as a regular `Node`.
if let Some((other_forest_idx, other_node_id)) =
self.non_external_nodes.get(&external_node.digest()).copied()
{
self.discovered_nodes.push(
self.discover_tree(other_forest_idx, other_node_id)?;

self.unvisited_nodes.push_back(
MultiMastForestIteratorItem::ExternalNodeReplacement {
replacement_forest_idx: other_forest_idx,
replacement_mast_node_id: other_node_id,
@@ -197,7 +204,7 @@ impl<'forest> MultiMastForestNodeIter<'forest> {
},
);

self.discover_tree(other_forest_idx, other_node_id)?;
self.discovered_nodes[forest_idx][node_id.as_usize()] = true;
} else {
self.push_node(forest_idx, node_id);
}
@@ -207,16 +214,17 @@ impl<'forest> MultiMastForestNodeIter<'forest> {
Ok(())
}

/// Finds the next unvisited procedure root and discovers a tree from it.
/// Finds the next undiscovered procedure root and discovers a tree from it.
///
/// If the unvisited node stack is empty after calling this function, the iteration is complete.
/// If the unvisited node deque is empty after calling this function, the iteration is
/// complete.
///
/// This function basically consists of two loops:
/// - The outer loop iterates over all forest indices.
/// - The inner loop iterates over all procedure root indices for the current forest.
fn discover_nodes(&mut self) {
'forest_loop: while self.current_forest_idx < self.mast_forests.len()
&& self.discovered_nodes.is_empty()
&& self.unvisited_nodes.is_empty()
{
// If we don't have any forests, there is nothing to do.
if self.mast_forests.is_empty() {
@@ -230,11 +238,11 @@ impl<'forest> MultiMastForestNodeIter<'forest> {
}

let procedure_roots = self.mast_forests[self.current_forest_idx].procedure_roots();
let visited_nodes = &self.visited_nodes[self.current_forest_idx];
let discovered_nodes = &self.discovered_nodes[self.current_forest_idx];

// Find the next unvisited procedure root for the current forest by incrementing the
// current procedure root until we find one that was not yet visited.
while visited_nodes
// Find the next undiscovered procedure root for the current forest by incrementing the
// current procedure root until we find one that was not yet discovered.
while discovered_nodes
[procedure_roots[self.current_procedure_root_idx as usize].as_usize()]
{
// If we have reached the end of the procedure roots for the current forest,
@@ -250,15 +258,15 @@ impl<'forest> MultiMastForestNodeIter<'forest> {
continue 'forest_loop;
}

// Since the current procedure root was already visited, check the next one.
// Since the current procedure root was already discovered, check the next one.
self.current_procedure_root_idx += 1;
}

// We exited the loop, so the current procedure root is unvisited and so we can start
// a discovery from that root. Since that root is unvisited and undiscovered, it is
// guaranteed that after this discovery the stack will be non-empty.
let tree_root_id = procedure_roots[self.current_procedure_root_idx as usize];
self.discover_tree(self.current_forest_idx, tree_root_id)
// We exited the loop, so the current procedure root is undiscovered and so we can start
// a discovery from that root. Since that root is undiscovered, it is guaranteed that
// after this discovery the deque will be non-empty.
let procedure_root_id = procedure_roots[self.current_procedure_root_idx as usize];
self.discover_tree(self.current_forest_idx, procedure_root_id)
.expect("we should only pass root indices that are valid for the forest");
}
}
@@ -268,36 +276,16 @@ impl Iterator for MultiMastForestNodeIter<'_> {
type Item = MultiMastForestIteratorItem;

fn next(&mut self) -> Option<Self::Item> {
while let Some(stack_item) = self.discovered_nodes.pop() {
// Get the forest and node index of the node being processed on the stack so we can
// check if it has already been visited.
let (forest_idx, node_id) = match &stack_item {
MultiMastForestIteratorItem::Node { forest_idx, node_id } => (forest_idx, node_id),
MultiMastForestIteratorItem::ExternalNodeReplacement {
replaced_forest_idx,
replaced_mast_node_id,
..
} => (replaced_forest_idx, replaced_mast_node_id),
};

let is_node_visited_mut = self.visited_nodes[*forest_idx]
.get_mut(node_id.as_usize())
.expect("visited_nodes can be safely indexed by any valid MastNodeId");
if *is_node_visited_mut {
continue;
} else {
*is_node_visited_mut = true;
}

return Some(stack_item);
if let Some(deque_item) = self.unvisited_nodes.pop_front() {
return Some(deque_item);
}

self.discover_nodes();

if !self.discovered_nodes.is_empty() {
if !self.unvisited_nodes.is_empty() {
self.next()
} else {
// If the stack is empty after tree discovery, all (reachable) nodes have been
// If the deque is empty after tree discovery, all (reachable) nodes have been
// discovered and visited.
None
}
@@ -460,30 +448,22 @@ mod tests {
);
}

/// Tests that a node which appears twice in a tree is returned in the required order.
/// Tests that a node which is referenced twice in a Mast Forest is returned in the required
/// order.
///
/// In this test we have a MastForest with a tree like this:
/// In this test we have a MastForest with this graph:
///
/// 3 <- Split Node
/// / \
/// 1 2
/// / \
/// 0 0
///
/// In a previous implementation we marked the nodes that we discovered as visited immediately
/// and did not add them again when encountering them again. This was not correct because of
/// this example. In the iterator, we need to discover nodes in reverse so we can later pop
/// them off the stack in postorder. So we first descend down child 2 of the split node.
/// This means we would mark 0 as visited and then add 2 to the stack. Then we descend down
/// child 1 but do not add 0 again because it is already visited and we add 1. So we end up
/// with this stack for the tree: [3, 2, 0, 1]. This means when we pop off nodes from the
/// stack we get to 1 before we get to 0 and that violates the guarantees from this
/// iterator.
/// \ /
/// 0
///
/// Hence this test to ensure that we do return nodes in the right order.
/// We need to ensure that 0 is processed before 1 and that it is not processed again when
/// processing the children of node 2.
///
/// This test and example is essentially a copy from a part of the MastForest of the Miden
/// Stdlib where this error occurred.
/// Stdlib where this failed on a previous implementation.
#[test]
fn multi_mast_forest_child_duplicate() {
let block_foo = MastNode::new_basic_block(vec![Operation::Drop], None).unwrap();
Expand Down
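The core idea of the change can be sketched independently of the MAST types: children are discovered first, each node is pushed onto the back of a deque exactly once, and popping from the front yields a postorder. The `Node` struct and the `discover`, `postorder`, and `diamond` functions below are hypothetical simplifications for illustration. They are not the real `MultiMastForestNodeIter` and they ignore `External` nodes and multiple forests.

```rust
use std::collections::VecDeque;

/// Hypothetical, simplified node: it only lists the indices of its children.
struct Node {
    children: Vec<usize>,
}

/// Discovers the graph reachable from `id`, pushing every child before its parent.
/// A node that was already discovered is skipped, so it enters the deque only once.
fn discover(nodes: &[Node], id: usize, discovered: &mut [bool], unvisited: &mut VecDeque<usize>) {
    if discovered[id] {
        return;
    }
    for &child in &nodes[id].children {
        discover(nodes, child, discovered, unvisited);
    }
    unvisited.push_back(id);
    discovered[id] = true;
}

/// Returns the nodes reachable from `root` in postorder by draining the deque
/// from the front.
fn postorder(nodes: &[Node], root: usize) -> Vec<usize> {
    let mut discovered = vec![false; nodes.len()];
    let mut unvisited = VecDeque::new();
    discover(nodes, root, &mut discovered, &mut unvisited);

    let mut order = Vec::new();
    while let Some(id) = unvisited.pop_front() {
        order.push(id);
    }
    order
}

/// The graph from the test above: node 3 has children 1 and 2, which share child 0.
fn diamond() -> Vec<Node> {
    vec![
        Node { children: vec![] },     // 0
        Node { children: vec![0] },    // 1
        Node { children: vec![0] },    // 2
        Node { children: vec![1, 2] }, // 3
    ]
}
```

For the diamond graph, `postorder(&diamond(), 3)` yields `[0, 1, 2, 3]`: node 0 comes before both of its parents and is returned only once, which is the property the `multi_mast_forest_child_duplicate` test checks.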
