diff --git a/core/src/mast/merger/mod.rs b/core/src/mast/merger/mod.rs index 3c858f75c..dde1324e2 100644 --- a/core/src/mast/merger/mod.rs +++ b/core/src/mast/merger/mod.rs @@ -1,9 +1,10 @@ use alloc::{collections::BTreeMap, vec::Vec}; -use core::ops::ControlFlow; use miden_crypto::hash::{blake::Blake3Digest, rpo::RpoDigest}; -use crate::mast::{DecoratorId, EqHash, MastForest, MastForestError, MastNode, MastNodeId}; +use crate::mast::{ + DecoratorId, EqHash, MastForest, MastForestError, MastNode, MastNodeId, MultiMastForestNodeIter, +}; #[cfg(test)] mod tests; @@ -17,43 +18,84 @@ pub(crate) struct MastForestMerger { node_id_by_hash: BTreeMap>, hash_by_node_id: BTreeMap, decorators_by_hash: BTreeMap, DecoratorId>, + decorator_id_mappings: Vec, + node_id_mappings: Vec, } impl MastForestMerger { /// Creates a new merger which creates a new internal, empty forest into which other /// [`MastForest`]s are merged. - pub(crate) fn new() -> Self { - Self { + pub(crate) fn merge<'forest>( + forests: impl IntoIterator, + ) -> Result<(MastForest, Vec), MastForestError> { + let forests = forests.into_iter().collect::>(); + let decorator_id_mappings = Vec::with_capacity(forests.len()); + let node_id_mappings = vec![MastForestNodeIdMap::new(); forests.len()]; + + let mut merger = Self { node_id_by_hash: BTreeMap::new(), hash_by_node_id: BTreeMap::new(), decorators_by_hash: BTreeMap::new(), mast_forest: MastForest::new(), + decorator_id_mappings, + node_id_mappings, + }; + + merger.merge_inner(forests.clone())?; + + let Self { mast_forest, node_id_mappings, .. } = merger; + + let mut root_maps = Vec::new(); + for (forest_idx, mapping) in node_id_mappings.into_iter().enumerate() { + let forest = forests[forest_idx]; + root_maps.push(MastForestRootMap::from_node_id_map(mapping, &forest.roots)); } + + Ok((mast_forest, root_maps)) } - /// Merges `other_forest` into the forest contained in self. - pub(crate) fn merge( + fn merge_inner<'forest>( &mut self, - other_forest: &MastForest, - ) -> Result { - let mut decorator_id_remapping = DecoratorIdMap::new(other_forest.decorators.len()); - let mut node_id_remapping = MastForestNodeIdMap::new(); + forests: Vec<&'forest MastForest>, + ) -> Result<(), MastForestError> { + for other_forest in forests.iter() { + self.merge_decorators(other_forest)?; + } - self.merge_decorators(other_forest, &mut decorator_id_remapping)?; - self.merge_nodes(other_forest, &decorator_id_remapping, &mut node_id_remapping)?; - self.merge_roots(other_forest, &node_id_remapping)?; + let iterator = MultiMastForestNodeIter::new(forests.clone()); + for item in iterator { + match item { + super::MultiMastForestIteratorItem::Regular { forest_idx, node_id } => { + let node = &forests[forest_idx][node_id]; + self.merge_node(forest_idx, node_id, node)?; + }, + super::MultiMastForestIteratorItem::ExternalNodeReplacement { + replacement_forest_idx, + replacement_mast_node_id, + replaced_forest_idx, + replaced_mast_node_id, + } => { + let mapped_replacement = self.node_id_mappings[replacement_forest_idx] + .get(&replacement_mast_node_id) + .copied() + .expect("every node should be mapped"); + + self.node_id_mappings[replaced_forest_idx] + .insert(replaced_mast_node_id, mapped_replacement); + }, + } + } - let root_map = - MastForestRootMap::from_node_id_map(node_id_remapping, other_forest.roots.as_slice()); + for (forest_idx, forest) in forests.iter().enumerate() { + self.merge_roots(forest_idx, &forest)?; + } - Ok(root_map) + Ok(()) } - fn merge_decorators( - &mut self, - other_forest: &MastForest, - decorator_id_remapping: &mut DecoratorIdMap, - ) -> Result<(), MastForestError> { + fn merge_decorators(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> { + let mut decorator_id_remapping = DecoratorIdMap::new(other_forest.decorators.len()); + for (merging_id, merging_decorator) in other_forest.decorators.iter().enumerate() { let merging_decorator_hash = merging_decorator.eq_hash(); let new_decorator_id = if let Some(existing_decorator) = @@ -70,54 +112,76 @@ impl MastForestMerger { .insert(DecoratorId::new_unsafe(merging_id as u32), new_decorator_id); } + self.decorator_id_mappings.push(decorator_id_remapping); + Ok(()) } - fn merge_nodes( + fn merge_node( &mut self, - other_forest: &MastForest, - decorator_id_remapping: &DecoratorIdMap, - node_id_remapping: &mut MastForestNodeIdMap, + forest_idx: usize, + merging_id: MastNodeId, + node: &MastNode, ) -> Result<(), MastForestError> { - for (merging_id, node) in other_forest.iter_nodes() { - // We need to remap the node prior to computing the EqHash. - // - // This is because the EqHash computation looks up its descendants and decorators in - // the internal index, and if we were to pass the original node to that - // computation, it would look up the incorrect descendants and decorators (since the - // descendant's indices may have changed). - // - // Remapping at this point is guaranteed to be "complete", meaning all ids of children - // will be present in `node_id_remapping` since the DFS iteration guarantees - // that all children of this `node` have been processed before this node and - // their indices have been added to the mappings. - let remapped_node = self.remap_node(node, decorator_id_remapping, node_id_remapping)?; - - let node_eq = - EqHash::from_mast_node(&self.mast_forest, &self.hash_by_node_id, &remapped_node); - - match self.merge_external_nodes( - merging_id, - &node_eq, - &remapped_node, - node_id_remapping, - )? { - // Continue is interpreted as doing nothing. - ControlFlow::Continue(_) => (), - // Break is interpreted as continue in the loop sense. - ControlFlow::Break(_) => continue, - } + // We need to remap the node prior to computing the EqHash. + // + // This is because the EqHash computation looks up its descendants and decorators in + // the internal index, and if we were to pass the original node to that + // computation, it would look up the incorrect descendants and decorators (since the + // descendant's indices may have changed). + // + // Remapping at this point is guaranteed to be "complete", meaning all ids of children + // will be present in `node_id_remapping` since the DFS iteration guarantees + // that all children of this `node` have been processed before this node and + // their indices have been added to the mappings. + let remapped_node = self.remap_node(forest_idx, node)?; + + let node_fingerprint = + EqHash::from_mast_node(&self.mast_forest, &self.hash_by_node_id, &remapped_node); + + // If no node with a matching root exists, then the merging node is unique and we can add it + // to the merged forest. + let Some(matching_nodes) = self.lookup_all_nodes_by_root(&node_fingerprint.mast_root) + else { + return self.add_merged_node(forest_idx, merging_id, remapped_node, node_fingerprint); + }; - // If an external node was previously replaced by the remapped node, this will detect - // them as duplicates here if their fingerprints match exactly and add the appropriate - // mapping from the merging id to the existing id. - match self.lookup_node_by_fingerprint(&node_eq) { - Some((_, existing_node_id)) => { - // We have to map any occurence of `merging_id` to `existing_node_id`. - node_id_remapping.insert(merging_id, *existing_node_id); + if remapped_node.is_external() { + // If there already is _any_ node with the same MAST root, map the merging + // external node to that existing one. + let (_, existing_external_node_id) = matching_nodes + .first() + .copied() + .expect("we should never insert empty entries in the internal index"); + self.node_id_mappings[forest_idx].insert(merging_id, existing_external_node_id); + } else { + // It should never be the case that the MAST root of the merging node matches + // the referenced MAST root of an External node in the merged forest due to the + // preprocessing of external nodes. + debug_assert!(matching_nodes.into_iter().all(|(_, matching_node_id)| { + !self.mast_forest[*matching_node_id].is_external() + })); + + match matching_nodes + .into_iter() + .find_map(|(matching_node_fingerprint, node_id)| { + if matching_node_fingerprint == &node_fingerprint { + Some(node_id) + } else { + None + } + }) + .copied() + { + Some(matching_node_id) => { + // If a node with a matching fingerprint exists, then the merging node is a + // duplicate and we remap it to the existing node. + self.node_id_mappings[forest_idx].insert(merging_id, matching_node_id); }, None => { - self.add_merged_node(merging_id, remapped_node, node_id_remapping, node_eq)?; + // If no node with a matching fingerprint exists, then the merging node is + // unique and we can add it to the merged forest. + self.add_merged_node(forest_idx, merging_id, remapped_node, node_fingerprint)?; }, } } @@ -127,13 +191,14 @@ impl MastForestMerger { fn merge_roots( &mut self, + forest_idx: usize, other_forest: &MastForest, - node_id_remapping: &MastForestNodeIdMap, ) -> Result<(), MastForestError> { for root_id in other_forest.roots.iter() { // Map the previous root to its possibly new id. - let new_root = - node_id_remapping.get(root_id).expect("all node ids should have an entry"); + let new_root = self.node_id_mappings[forest_idx] + .get(root_id) + .expect("all node ids should have an entry"); // This will take O(n) every time to check if the root already exists. // We could improve this by keeping a BTreeSet of existing roots during // merging for a faster check. @@ -145,13 +210,13 @@ impl MastForestMerger { fn add_merged_node( &mut self, + forest_idx: usize, previous_id: MastNodeId, node: MastNode, - node_id_remapping: &mut MastForestNodeIdMap, node_eq: EqHash, ) -> Result<(), MastForestError> { let new_node_id = self.mast_forest.add_node(node)?; - node_id_remapping.insert(previous_id, new_node_id); + self.node_id_mappings[forest_idx].insert(previous_id, new_node_id); // We need to update the indices with the newly inserted nodes // since the EqHash computation requires all descendants of a node @@ -168,71 +233,15 @@ impl MastForestMerger { Ok(()) } - /// This will handle two cases: - /// - /// - The existing forest contains a node (external or non-external) with MAST root `foo` and - /// the merging External node refers to `foo`. In this case, the merging node will be mapped - /// to the existing node and dropped. - /// - The existing forest contains an External nodes with a MAST root `foo` and the non-external - /// merging node's digest is `foo`. In this case, the existing external node will be replaced - /// by the merging node. - /// - /// Returns whether the caller should continue in their code path for this node or skip it. - fn merge_external_nodes( - &mut self, - previous_id: MastNodeId, - node_eq: &EqHash, - remapped_node: &MastNode, - node_id_remapping: &mut MastForestNodeIdMap, - ) -> Result, MastForestError> { - if remapped_node.is_external() { - match self.lookup_node_by_root(&node_eq.mast_root) { - // If there already is any node with the same MAST root, map the merging external - // node to that existing one. - // This code path is also entered if the fingerprints match, so we can skip the - // general merging case by returning `Break`. - Some((_, existing_external_node_id)) => { - node_id_remapping.insert(previous_id, *existing_external_node_id); - Ok(ControlFlow::Break(())) - }, - // If no duplicate for the external node exists do nothing as `merge_nodes` - // will simply add the node to the forest. - None => Ok(ControlFlow::Continue(())), - } - } else { - // Replace an external node in self with the given MAST root with the non-external - // node from the merging forest. - // Any node in the existing forest that pointed to the external node will - // have the same MAST root due to the semantics of external nodes. - match self.lookup_external_node_by_root(&node_eq.mast_root) { - Some((_, external_node_id)) => { - self.mast_forest[external_node_id] = remapped_node.clone(); - node_id_remapping.insert(previous_id, external_node_id); - // The other branch of this function guarantees that no external and - // non-external node with the same MAST root exist in the - // merged forest, so if we found an external node with a - // given MAST root, it must be the only one in the merged - // forest, so we can skip the remainder of the `merge_nodes` code path. - Ok(ControlFlow::Break(())) - }, - // If we did not find a matching node, we can continue in the `merge_nodes` code - // path. - None => Ok(ControlFlow::Continue(())), - } - } - } - /// Remaps a nodes' potentially contained children and decorators to their new IDs according to /// the given maps. - fn remap_node( - &self, - node: &MastNode, - decorator_id_remapping: &DecoratorIdMap, - node_id_remapping: &MastForestNodeIdMap, - ) -> Result { + fn remap_node(&self, forest_idx: usize, node: &MastNode) -> Result { let map_decorator_id = |decorator_id: &DecoratorId| { - decorator_id_remapping.get(decorator_id).ok_or_else(|| { - MastForestError::DecoratorIdOverflow(*decorator_id, decorator_id_remapping.len()) + self.decorator_id_mappings[forest_idx].get(decorator_id).ok_or_else(|| { + MastForestError::DecoratorIdOverflow( + *decorator_id, + self.decorator_id_mappings[forest_idx].len(), + ) }) }; let map_decorators = |decorators: &[DecoratorId]| -> Result, MastForestError> { @@ -240,7 +249,7 @@ impl MastForestMerger { }; let map_node_id = |node_id: MastNodeId| { - node_id_remapping + self.node_id_mappings[forest_idx] .get(&node_id) .copied() .expect("every node id should have an entry") @@ -309,34 +318,8 @@ impl MastForestMerger { // HELPERS // ================================================================================================ - fn lookup_node_by_fingerprint(&self, eq_hash: &EqHash) -> Option<&(EqHash, MastNodeId)> { - self.node_id_by_hash.get(&eq_hash.mast_root).and_then(|node_ids| { - node_ids.iter().find(|(node_fingerprint, _)| node_fingerprint == eq_hash) - }) - } - - fn lookup_node_by_root(&self, mast_root: &RpoDigest) -> Option<&(EqHash, MastNodeId)> { - self.node_id_by_hash.get(mast_root).and_then(|node_ids| node_ids.first()) - } - - fn lookup_external_node_by_root(&self, mast_root: &RpoDigest) -> Option<(EqHash, MastNodeId)> { - self.node_id_by_hash.get(mast_root).and_then(|ids| { - let mut iterator = ids - .iter() - .filter(|(_, node_id)| self.mast_forest[*node_id].is_external()) - .copied(); - let external_node = iterator.next(); - // The merging implementation should guarantee that no two external nodes with the same - // MAST root exist. - debug_assert!(iterator.next().is_none()); - external_node - }) - } -} - -impl From for MastForest { - fn from(merger: MastForestMerger) -> Self { - merger.mast_forest + fn lookup_all_nodes_by_root(&self, mast_root: &RpoDigest) -> Option<&[(EqHash, MastNodeId)]> { + self.node_id_by_hash.get(mast_root).map(|node_ids| node_ids.as_slice()) } } diff --git a/core/src/mast/merger/tests.rs b/core/src/mast/merger/tests.rs index a9e6f72d9..f311dac61 100644 --- a/core/src/mast/merger/tests.rs +++ b/core/src/mast/merger/tests.rs @@ -12,7 +12,8 @@ fn block_bar() -> MastNode { } fn block_qux() -> MastNode { - MastNode::new_basic_block(vec![Operation::Swap, Operation::Push(ONE)], None).unwrap() + MastNode::new_basic_block(vec![Operation::Swap, Operation::Push(ONE), Operation::Eq], None) + .unwrap() } fn assert_contains_node_once(forest: &MastForest, digest: RpoDigest) { @@ -32,6 +33,30 @@ fn assert_root_mapping( } } +fn assert_child_id_lt_parent_id(forest: &MastForest) { + for (idx, node) in forest.nodes().iter().enumerate() { + match node { + MastNode::Join(join_node) => { + assert!(join_node.first().as_usize() < idx); + assert!(join_node.second().as_usize() < idx); + }, + MastNode::Split(split_node) => { + assert!(split_node.on_true().as_usize() < idx); + assert!(split_node.on_false().as_usize() < idx); + }, + MastNode::Loop(loop_node) => { + assert!(loop_node.body().as_usize() < idx); + }, + MastNode::Call(call_node) => { + assert!(call_node.callee().as_usize() < idx); + }, + MastNode::Block(_) => (), + MastNode::Dyn(_) => (), + MastNode::External(_) => (), + } + } +} + /// Tests that Call(bar) still correctly calls the remapped bar block. /// /// [Block(foo), Call(foo)] @@ -63,6 +88,8 @@ fn mast_forest_merge_remap() { let root_map_b = &root_maps[1]; assert_eq!(root_map_a.map_root(&id_call_a).unwrap().as_u32(), 1); assert_eq!(root_map_b.map_root(&id_call_b).unwrap().as_u32(), 3); + + assert_child_id_lt_parent_id(&merged); } /// Tests that Forest_A + Forest_A = Forest_A (i.e. duplicates are removed). @@ -95,6 +122,8 @@ fn mast_forest_merge_duplicate() { for merged_decorator in merged.decorators.iter() { assert!(forest_a.decorators.contains(merged_decorator)); } + + assert_child_id_lt_parent_id(&merged); } /// Tests that External(foo) is replaced by Block(foo) whether it is in forest A or B, and the @@ -131,6 +160,7 @@ fn mast_forest_merge_replace_external() { // The only root node should be the call node. assert_eq!(merged.roots.len(), 1); assert_eq!(root_map[0].map_root(&merged.roots[0]).unwrap().as_usize(), 1); + assert_child_id_lt_parent_id(&merged); } } @@ -176,6 +206,8 @@ fn mast_forest_merge_roots() { assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots); assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots); + + assert_child_id_lt_parent_id(&merged); } /// Test that multiple trees can be merged when the same merger is reused. @@ -234,6 +266,8 @@ fn mast_forest_merge_multiple() { assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots); assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots); assert_root_mapping(&root_maps[2], &forest_c.roots, &merged.roots); + + assert_child_id_lt_parent_id(&merged); } /// Tests that decorators are merged and that nodes who are identical except for their @@ -367,6 +401,8 @@ fn mast_forest_merge_decorators() { assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots); assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots); + + assert_child_id_lt_parent_id(&merged); } /// Tests that an external node without decorators is replaced by its referenced node which has @@ -428,6 +464,8 @@ fn mast_forest_merge_external_node_reference_with_decorator() { assert_root_mapping(&root_maps[0], &forest_b.roots, &merged.roots); assert_root_mapping(&root_maps[1], &forest_a.roots, &merged.roots); } + + assert_child_id_lt_parent_id(&merged); } } @@ -494,6 +532,8 @@ fn mast_forest_merge_external_node_with_decorator() { assert_root_mapping(&root_maps[0], &forest_b.roots, &merged.roots); assert_root_mapping(&root_maps[1], &forest_a.roots, &merged.roots); } + + assert_child_id_lt_parent_id(&merged); } } @@ -562,6 +602,8 @@ fn mast_forest_merge_external_node_and_referenced_node_have_decorators() { assert_root_mapping(&root_maps[0], &forest_b.roots, &merged.roots); assert_root_mapping(&root_maps[1], &forest_a.roots, &merged.roots); } + + assert_child_id_lt_parent_id(&merged); } } @@ -638,6 +680,47 @@ fn mast_forest_merge_multiple_external_nodes_with_decorator() { assert_root_mapping(&root_maps[0], &forest_b.roots, &merged.roots); assert_root_mapping(&root_maps[1], &forest_a.roots, &merged.roots); } + + assert_child_id_lt_parent_id(&merged); + } +} + +/// [External(foo), Call(0) = qux] +/// + +/// [External(qux), Call(0), Block(foo)] +/// = +/// [Block(foo), Call(0), Call(1)] +#[test] +fn mast_forest_merge_external_dependencies() { + let mut forest_a = MastForest::new(); + let id_foo_a = forest_a.add_external(block_qux().digest()).unwrap(); + let id_call_a = forest_a.add_call(id_foo_a).unwrap(); + forest_a.make_root(id_call_a); + + let mut forest_b = MastForest::new(); + let id_ext_b = forest_b.add_external(forest_a[id_call_a].digest()).unwrap(); + let id_call_b = forest_b.add_call(id_ext_b).unwrap(); + let id_qux_b = forest_b.add_node(block_qux()).unwrap(); + forest_b.make_root(id_call_b); + forest_b.make_root(id_qux_b); + + for (_, (merged, _)) in [ + MastForest::merge([&forest_a, &forest_b]).unwrap(), + MastForest::merge([&forest_b, &forest_a]).unwrap(), + ] + .into_iter() + .enumerate() + { + let digests = merged.nodes().iter().map(|node| node.digest()).collect::>(); + assert_eq!(merged.nodes().len(), 3); + assert!(digests.contains(&forest_b[id_ext_b].digest())); + assert!(digests.contains(&forest_b[id_call_b].digest())); + assert!(digests.contains(&forest_a[id_foo_a].digest())); + assert!(digests.contains(&forest_a[id_call_a].digest())); + assert!(digests.contains(&forest_b[id_qux_b].digest())); + assert_eq!(merged.nodes().iter().filter(|node| node.is_external()).count(), 0); + + assert_child_id_lt_parent_id(&merged); } } diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 71cb01a83..7fb482f85 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -28,8 +28,8 @@ mod merger; pub(crate) use merger::MastForestMerger; pub use merger::MastForestRootMap; -mod node_iterator; -pub(crate) use node_iterator::*; +mod multi_forest_node_iterator; +pub(crate) use multi_forest_node_iterator::*; #[cfg(test)] mod tests; @@ -254,14 +254,7 @@ impl MastForest { pub fn merge<'forest>( forests: impl IntoIterator, ) -> Result<(MastForest, Vec), MastForestError> { - let mut root_maps = Vec::new(); - let mut merger = MastForestMerger::new(); - - for forest in forests { - root_maps.push(merger.merge(forest)?); - } - - Ok((merger.into(), root_maps)) + MastForestMerger::merge(forests) } /// Adds a basic block node to the forest, and returns the [`MastNodeId`] associated with it. @@ -467,39 +460,40 @@ impl MastForest { &self.nodes } - /// Returns an iterator which traverses over the nodes in a depth-first search and returns nodes - /// in postorder. - /// - /// This iterator iterates through all **reachable** nodes of a forest exactly once. - /// - /// Since a `MastForest` has multiple possible entrypoints in the form of its roots, a - /// depth-first search must visit all of those roots and the trees they form. - /// - /// For instance, consider this `MastForest`: - /// - /// ```text - /// Nodes: [Block(foo), Block(bar), Join(0, 1), External(qux)] - /// Roots: [2] - /// ``` - /// - /// The only root is the `Join` node at index 2. The first three nodes of the forest form a - /// tree, since the `Join` node references index 0 and 1. This tree is discovered by - /// starting at the root at index 2 and following all children until we reach terminal nodes - /// (like `Block`s) and build up a stack of the discovered, but unvisited nodes. The stack - /// is built such that popping elements off the stack (from the back) yields a postorder. - /// - /// After the first tree is discovered, the stack looks like this: `[2, 1, 0]`. On each - /// call to `next` one element is popped off this stack and returned. - /// - /// If the stack is exhausted we start another discovery if more unvisited roots exist. Since - /// the `External` node is not a root and not referenced by any other tree in the forest, it - /// will not be visited. - /// - /// The iteration on a high-level thus consists of a constant back and forth between discovering - /// trees and returning nodes from the stack. - pub fn iter_nodes(&self) -> impl Iterator { - MastForestNodeIter::new(self) - } + // TODO: Replace with Multi iterator? + // /// Returns an iterator which traverses over the nodes in a depth-first search and returns nodes + // /// in postorder. + // /// + // /// This iterator iterates through all **reachable** nodes of a forest exactly once. + // /// + // /// Since a `MastForest` has multiple possible entrypoints in the form of its roots, a + // /// depth-first search must visit all of those roots and the trees they form. + // /// + // /// For instance, consider this `MastForest`: + // /// + // /// ```text + // /// Nodes: [Block(foo), Block(bar), Join(0, 1), External(qux)] + // /// Roots: [2] + // /// ``` + // /// + // /// The only root is the `Join` node at index 2. The first three nodes of the forest form a + // /// tree, since the `Join` node references index 0 and 1. This tree is discovered by + // /// starting at the root at index 2 and following all children until we reach terminal nodes + // /// (like `Block`s) and build up a stack of the discovered, but unvisited nodes. The stack + // /// is built such that popping elements off the stack (from the back) yields a postorder. + // /// + // /// After the first tree is discovered, the stack looks like this: `[2, 1, 0]`. On each + // /// call to `next` one element is popped off this stack and returned. + // /// + // /// If the stack is exhausted we start another discovery if more unvisited roots exist. Since + // /// the `External` node is not a root and not referenced by any other tree in the forest, it + // /// will not be visited. + // /// + // /// The iteration on a high-level thus consists of a constant back and forth between discovering + // /// trees and returning nodes from the stack. + // pub fn iter_nodes(&self) -> impl Iterator { + // MultiMastForestNodeIter::new(self) + // } } impl Index for MastForest { @@ -573,6 +567,11 @@ impl MastNodeId { } } + #[cfg(test)] + pub fn new_unsafe(value: u32) -> Self { + Self(value) + } + pub fn as_usize(&self) -> usize { self.0 as usize } diff --git a/core/src/mast/multi_forest_node_iterator.rs b/core/src/mast/multi_forest_node_iterator.rs new file mode 100644 index 000000000..0a5c8e638 --- /dev/null +++ b/core/src/mast/multi_forest_node_iterator.rs @@ -0,0 +1,396 @@ +use alloc::vec::Vec; +use std::collections::BTreeMap; + +use miden_crypto::hash::rpo::RpoDigest; + +use crate::mast::{MastForest, MastForestError, MastNode, MastNodeId}; + +type ForestIndex = usize; + +/// Depth First Search Iterator in Post Order for [`MastForest`]s. +/// +/// This iterator iterates through all **reachable** nodes of a forest exactly once. +/// +/// Since a `MastForest` has multiple possible entrypoints in the form of its roots, a depth-first +/// search must visit all of those roots and the trees they form. +/// +/// For instance, consider this `MastForest`: +/// +/// ```text +/// Nodes: [Block(foo), Block(bar), Join(0, 1), External(qux)] +/// Roots: [2] +/// ``` +/// +/// The only root is the `Join` node at index 2. The first three nodes of the forest form a +/// tree, since the `Join` node references index 0 and 1. This tree is discovered by +/// starting at the root at index 2 and following all children until we reach terminal nodes (like +/// `Block`s) and build up a stack of the discovered, but unvisited nodes. The stack is +/// built such that popping elements off the stack (from the back) yields a postorder. +/// +/// After the first tree is discovered, the stack looks like this: `[2, 1, 0]`. On each +/// call to `next` one element is popped off this stack and returned. +/// +/// If the stack is exhausted we start another discovery if more unvisited roots exist. Since the +/// `External` node is not a root and not referenced by any other tree in the forest, it will not be +/// visited. +/// +/// The iteration on a high-level thus consists of a constant back and forth between discovering +/// trees and returning nodes from the stack. +/// +/// Note: This type could be made more general to implement pre-order or in-order iteration too. +pub(crate) struct MultiMastForestNodeIter<'forest> { + /// The forest that we're iterating. + mast_forests: Vec<&'forest MastForest>, + /// The procedure root index at which we last started a tree discovery. + /// + /// This value iterates through 0..mast_forest.num_procedures() which guarantees that we visit + /// all nodes reachable from all roots. + last_forest_idx: usize, + last_procedure_root_idx: u32, + non_external_nodes: BTreeMap, + /// Describes whether the node at some index has already been visited. Note that this is set to + /// true for all nodes on the stack, even if the caller of the iterator has not yet seen the + /// node. See [`Self::visit_later`] for more details. + node_visited: Vec>, + /// This stack always contains the discovered but unvisited nodes. + /// For any id store on the stack it holds that `node_visited[id] = true`. + unvisited_node_stack: Vec, +} + +impl<'forest> MultiMastForestNodeIter<'forest> { + pub(crate) fn new(mast_forests: Vec<&'forest MastForest>) -> Self { + let visited = mast_forests + .iter() + .map(|forest| vec![false; forest.num_nodes() as usize]) + .collect(); + + let mut non_external_nodes = BTreeMap::new(); + + for (forest_idx, forest) in mast_forests.iter().enumerate() { + for (node_idx, node) in forest.nodes().iter().enumerate() { + let node_id = MastNodeId::from_u32_safe(node_idx as u32, &mast_forests[forest_idx]) + .expect("the passed id should be a valid node in the forest"); + if !node.is_external() { + non_external_nodes.insert(node.digest(), (forest_idx, node_id)); + } + } + } + + Self { + mast_forests, + last_forest_idx: 0, + last_procedure_root_idx: 0, + non_external_nodes, + node_visited: visited, + unvisited_node_stack: Vec::new(), + } + } + + /// Pushes the given index onto the stack unless the index was already visited. + fn mark_for_visit(&mut self, forest_idx: usize, node_id: MastNodeId) { + // SAFETY: The node_visited Vec's len is equal to the number of forest nodes + // so any `MastNodeId` from that forest is safe to use. + let node_visited_mut = self.node_visited[forest_idx] + .get_mut(node_id.as_usize()) + .expect("node_visited can be safely indexed by any valid MastNodeId"); + + if !*node_visited_mut { + self.unvisited_node_stack + .push(MultiMastForestIteratorItem::Regular { forest_idx, node_id }); + // Set nodes added to the stack as visited even though we have not technically visited + // them. This is however important to avoid visiting nodes twice that appear + // in the same tree. If we were to add all nodes to the stack that we + // discovered, then we would have duplicate ids on the stack. Marking them + // as visited immediately when adding them avoid this issue. + *node_visited_mut = true; + } + } + + /// Discovers a tree starting at the given root index. + fn discover_tree( + &mut self, + forest_idx: usize, + root_idx: MastNodeId, + ) -> Result<(), MastForestError> { + let current_node = + &self.mast_forests[forest_idx].nodes.get(root_idx.as_usize()).ok_or_else(|| { + MastForestError::NodeIdOverflow( + root_idx, + self.mast_forests[forest_idx].num_nodes() as usize, + ) + })?; + + // Note that the order in which we add or discover nodes is the reverse of postorder, since + // we're pushing them onto a stack, which reverses the order itself. Hence, reversing twice + // gives us the actual postorder we want. + match current_node { + MastNode::Block(_) => { + self.mark_for_visit(forest_idx, root_idx); + }, + MastNode::Join(join_node) => { + self.mark_for_visit(forest_idx, root_idx); + self.discover_tree(forest_idx, join_node.second())?; + self.discover_tree(forest_idx, join_node.first())?; + }, + MastNode::Split(split_node) => { + self.mark_for_visit(forest_idx, root_idx); + self.discover_tree(forest_idx, split_node.on_false())?; + self.discover_tree(forest_idx, split_node.on_true())?; + }, + MastNode::Loop(loop_node) => { + self.mark_for_visit(forest_idx, root_idx); + self.discover_tree(forest_idx, loop_node.body())?; + }, + MastNode::Call(call_node) => { + self.mark_for_visit(forest_idx, root_idx); + self.discover_tree(forest_idx, call_node.callee())?; + }, + MastNode::Dyn(_) => { + self.mark_for_visit(forest_idx, root_idx); + }, + MastNode::External(external_node) => { + if let Some((other_forest_idx, other_node_id)) = + self.non_external_nodes.get(&external_node.digest()).copied() + { + let visited = self.node_visited[forest_idx] + .get(root_idx.as_usize()) + .expect("node_visited can be safely indexed by any valid MastNodeId"); + if !visited { + self.unvisited_node_stack.push( + MultiMastForestIteratorItem::ExternalNodeReplacement { + replacement_forest_idx: other_forest_idx, + replacement_mast_node_id: other_node_id, + replaced_forest_idx: forest_idx, + replaced_mast_node_id: root_idx, + }, + ); + } + + self.discover_tree(other_forest_idx, other_node_id)?; + + // Skip external node. + *self.node_visited[forest_idx] + .get_mut(root_idx.as_usize()) + .expect("node_visited can be safely indexed by any valid MastNodeId") = + true; + } else { + self.mark_for_visit(forest_idx, root_idx); + } + }, + } + + Ok(()) + } + + /// Finds the next unvisited procedure root and discovers a tree from it. + /// + /// If the unvisited node stack is empty after calling this function, the iteration is complete. + fn discover_nodes(&mut self) { + 'forest_loop: while self.last_forest_idx < self.mast_forests.len() + && self.unvisited_node_stack.is_empty() + { + if self.mast_forests.is_empty() { + return; + } + if self.mast_forests[self.last_forest_idx].num_procedures() == 0 { + self.last_forest_idx += 1; + continue; + } + + let procedure_roots = self.mast_forests[self.last_forest_idx].procedure_roots(); + let node_visited = &self.node_visited[self.last_forest_idx]; + // Find the next unvisited procedure root. + while node_visited[procedure_roots[self.last_procedure_root_idx as usize].as_usize()] { + if self.last_procedure_root_idx + 1 + >= self.mast_forests[self.last_forest_idx].num_procedures() + { + self.last_procedure_root_idx = 0; + self.last_forest_idx += 1; + continue 'forest_loop; + } + self.last_procedure_root_idx += 1; + } + + let tree_root_id = procedure_roots[self.last_procedure_root_idx as usize]; + self.discover_tree(self.last_forest_idx, tree_root_id) + .expect("we should only pass root indices that are valid for the forest"); + } + } +} + +impl<'forest> Iterator for MultiMastForestNodeIter<'forest> { + type Item = MultiMastForestIteratorItem; + + fn next(&mut self) -> Option { + if let Some(stack_item) = self.unvisited_node_stack.pop() { + // SAFETY: We only add valid ids to the stack so it's fine to index the forest nodes + // directly. + // let node = &self.mast_forests[stack_item.forest_idx].nodes[next_node_id.as_usize()]; + + return Some(stack_item); + } + + self.discover_nodes(); + + if !self.unvisited_node_stack.is_empty() { + self.next() + } else { + // If the stack is empty after tree discovery, all (reachable) nodes have been + None + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum MultiMastForestIteratorItem { + Regular { + forest_idx: ForestIndex, + node_id: MastNodeId, + }, + ExternalNodeReplacement { + replacement_forest_idx: usize, + replacement_mast_node_id: MastNodeId, + replaced_forest_idx: usize, + replaced_mast_node_id: MastNodeId, + }, +} + +#[cfg(test)] +mod tests { + use miden_crypto::hash::rpo::RpoDigest; + + use super::*; + use crate::Operation; + + fn random_digest() -> RpoDigest { + RpoDigest::new([rand_utils::rand_value(); 4]) + } + + #[test] + fn multi_mast_forest_dfs_empty() { + let forest = MastForest::new(); + let mut iterator = MultiMastForestNodeIter::new(vec![&forest]); + assert!(iterator.next().is_none()); + } + + #[test] + fn multi_mast_forest_multiple_forests_dfs() { + let nodea0_digest = random_digest(); + let nodea1_digest = random_digest(); + let nodea2_digest = random_digest(); + let nodea3_digest = random_digest(); + + let nodeb0_digest = random_digest(); + + let mut forest_a = MastForest::new(); + forest_a.add_external(nodea0_digest).unwrap(); + let id1 = forest_a.add_external(nodea1_digest).unwrap(); + let id2 = forest_a.add_external(nodea2_digest).unwrap(); + let id3 = forest_a.add_external(nodea3_digest).unwrap(); + let id_split = forest_a.add_split(id2, id3).unwrap(); + let id_join = forest_a.add_join(id2, id_split).unwrap(); + + forest_a.make_root(id_join); + forest_a.make_root(id1); + + let mut forest_b = MastForest::new(); + let id_ext_b = forest_b.add_external(nodeb0_digest).unwrap(); + let id_block_b = forest_b.add_block(vec![Operation::Eqz], None).unwrap(); + let id_split_b = forest_b.add_split(id_ext_b, id_block_b).unwrap(); + + forest_b.make_root(id_split_b); + + // Note that the node at index 0 is not visited because it is not reachable from any root + // and is not a root itself. + let nodes = MultiMastForestNodeIter::new(vec![&forest_a, &forest_b]).collect::>(); + + assert_eq!(nodes.len(), 8); + assert_eq!(nodes[0], MultiMastForestIteratorItem::Regular { forest_idx: 0, node_id: id2 }); + assert_eq!(nodes[1], MultiMastForestIteratorItem::Regular { forest_idx: 0, node_id: id3 }); + assert_eq!( + nodes[2], + MultiMastForestIteratorItem::Regular { forest_idx: 0, node_id: id_split } + ); + assert_eq!( + nodes[3], + MultiMastForestIteratorItem::Regular { forest_idx: 0, node_id: id_join } + ); + assert_eq!(nodes[4], MultiMastForestIteratorItem::Regular { forest_idx: 0, node_id: id1 }); + assert_eq!( + nodes[5], + MultiMastForestIteratorItem::Regular { forest_idx: 1, node_id: id_ext_b } + ); + assert_eq!( + nodes[6], + MultiMastForestIteratorItem::Regular { forest_idx: 1, node_id: id_block_b } + ); + assert_eq!( + nodes[7], + MultiMastForestIteratorItem::Regular { forest_idx: 1, node_id: id_split_b } + ); + } + + #[test] + fn multi_mast_forest_external_dependencies() { + let block_foo = MastNode::new_basic_block(vec![Operation::Drop], None).unwrap(); + let mut forest_a = MastForest::new(); + let id_foo_a = forest_a.add_external(block_foo.digest()).unwrap(); + let id_call_a = forest_a.add_call(id_foo_a).unwrap(); + forest_a.make_root(id_call_a); + + let mut forest_b = MastForest::new(); + let id_ext_b = forest_b.add_external(forest_a[id_call_a].digest()).unwrap(); + let id_call_b = forest_b.add_call(id_ext_b).unwrap(); + forest_b.add_node(block_foo).unwrap(); + forest_b.make_root(id_call_b); + + let nodes = MultiMastForestNodeIter::new(vec![&forest_a, &forest_b]).collect::>(); + + assert_eq!(nodes.len(), 5); + + // The replacement for the external node from forest A. + assert_eq!( + nodes[0], + MultiMastForestIteratorItem::Regular { + forest_idx: 1, + node_id: MastNodeId::new_unsafe(2) + } + ); + // The external node replaced by the block foo from forest B. + assert_eq!( + nodes[1], + MultiMastForestIteratorItem::ExternalNodeReplacement { + replacement_forest_idx: 1, + replacement_mast_node_id: MastNodeId::new_unsafe(2), + replaced_forest_idx: 0, + replaced_mast_node_id: MastNodeId::new_unsafe(0) + } + ); + // The call from forest A. + assert_eq!( + nodes[2], + MultiMastForestIteratorItem::Regular { + forest_idx: 0, + node_id: MastNodeId::new_unsafe(1) + } + ); + // The replacement for the external node that is replaced by the Call in forest A. + assert_eq!( + nodes[3], + MultiMastForestIteratorItem::ExternalNodeReplacement { + replacement_forest_idx: 0, + replacement_mast_node_id: MastNodeId::new_unsafe(1), + replaced_forest_idx: 1, + replaced_mast_node_id: MastNodeId::new_unsafe(0) + } + ); + // The call from forest B. + assert_eq!( + nodes[4], + MultiMastForestIteratorItem::Regular { + forest_idx: 1, + node_id: MastNodeId::new_unsafe(1) + } + ); + } +} diff --git a/core/src/mast/node_iterator.rs b/core/src/mast/node_iterator.rs deleted file mode 100644 index 59bb675d1..000000000 --- a/core/src/mast/node_iterator.rs +++ /dev/null @@ -1,213 +0,0 @@ -use alloc::vec::Vec; - -use crate::mast::{MastForest, MastNode, MastNodeId}; - -/// Depth First Search Iterator in Post Order for [`MastForest`]s. -/// -/// This iterator iterates through all **reachable** nodes of a forest exactly once. -/// -/// Since a `MastForest` has multiple possible entrypoints in the form of its roots, a depth-first -/// search must visit all of those roots and the trees they form. -/// -/// For instance, consider this `MastForest`: -/// -/// ```text -/// Nodes: [Block(foo), Block(bar), Join(0, 1), External(qux)] -/// Roots: [2] -/// ``` -/// -/// The only root is the `Join` node at index 2. The first three nodes of the forest form a -/// tree, since the `Join` node references index 0 and 1. This tree is discovered by -/// starting at the root at index 2 and following all children until we reach terminal nodes (like -/// `Block`s) and build up a stack of the discovered, but unvisited nodes. The stack is -/// built such that popping elements off the stack (from the back) yields a postorder. -/// -/// After the first tree is discovered, the stack looks like this: `[2, 1, 0]`. On each -/// call to `next` one element is popped off this stack and returned. -/// -/// If the stack is exhausted we start another discovery if more unvisited roots exist. Since the -/// `External` node is not a root and not referenced by any other tree in the forest, it will not be -/// visited. -/// -/// The iteration on a high-level thus consists of a constant back and forth between discovering -/// trees and returning nodes from the stack. -/// -/// Note: This type could be made more general to implement pre-order or in-order iteration too. -pub(crate) struct MastForestNodeIter<'forest> { - /// The forest that we're iterating. - pub mast_forest: &'forest MastForest, - /// The procedure root index at which we last started a tree discovery. - /// - /// This value iterates through 0..mast_forest.num_procedures() which guarantees that we visit - /// all nodes reachable from all roots. - pub last_procedure_root_idx: u32, - /// Describes whether the node at some index has already been visited. Note that this is set to - /// true for all nodes on the stack, even if the caller of the iterator has not yet seen the - /// node. See [`Self::visit_later`] for more details. - pub node_visited: Vec, - /// This stack always contains the discovered but unvisited nodes. - /// For any id store on the stack it holds that `node_visited[id] = true`. - pub unvisited_node_stack: Vec, -} - -impl<'forest> MastForestNodeIter<'forest> { - pub(crate) fn new(mast_forest: &'forest MastForest) -> Self { - let visited = vec![false; mast_forest.num_nodes() as usize]; - - Self { - mast_forest, - last_procedure_root_idx: 0, - node_visited: visited, - unvisited_node_stack: Vec::new(), - } - } - - /// Pushes the given index onto the stack unless the index was already visited. - fn mark_for_visit(&mut self, node_id: MastNodeId) { - // SAFETY: The node_visited Vec's len is equal to the number of forest nodes - // so any `MastNodeId` from that forest is safe to use. - let node_visited_mut = self - .node_visited - .get_mut(node_id.as_usize()) - .expect("node_visited can be safely indexed by any valid MastNodeId"); - - if !*node_visited_mut { - self.unvisited_node_stack.push(node_id); - // Set nodes added to the stack as visited even though we have not technically visited - // them. This is however important to avoid visiting nodes twice that appear - // in the same tree. If we were to add all nodes to the stack that we - // discovered, then we would have duplicate ids on the stack. Marking them - // as visited immediately when adding them avoid this issue. - *node_visited_mut = true; - } - } - - /// Discovers a tree starting at the given root index. - fn discover_tree(&mut self, root_idx: MastNodeId) { - let current_node = &self.mast_forest.nodes[root_idx.as_usize()]; - // Note that the order in which we add or discover nodes is the reverse of postorder, since - // we're pushing them onto a stack, which reverses the order itself. Hence, reversing twice - // gives us the actual postorder we want. - match current_node { - MastNode::Block(_) => { - self.mark_for_visit(root_idx); - }, - MastNode::Join(join_node) => { - self.mark_for_visit(root_idx); - self.discover_tree(join_node.second()); - self.discover_tree(join_node.first()); - }, - MastNode::Split(split_node) => { - self.mark_for_visit(root_idx); - self.discover_tree(split_node.on_false()); - self.discover_tree(split_node.on_true()); - }, - MastNode::Loop(loop_node) => { - self.mark_for_visit(root_idx); - self.discover_tree(loop_node.body()); - }, - MastNode::Call(call_node) => { - self.mark_for_visit(root_idx); - self.discover_tree(call_node.callee()); - }, - MastNode::Dyn(_) => { - self.mark_for_visit(root_idx); - }, - MastNode::External(_) => { - self.mark_for_visit(root_idx); - }, - } - } - - /// Finds the next unvisited procedure root and discovers a tree from it. - /// - /// If the unvisited node stack is empty after calling this function, the iteration is complete. - fn discover_nodes(&mut self) { - if self.mast_forest.num_procedures() == 0 { - return; - } - - let procedure_roots = self.mast_forest.procedure_roots(); - // Find the next unvisited procedure root. - while self.node_visited[procedure_roots[self.last_procedure_root_idx as usize].as_usize()] { - if self.last_procedure_root_idx + 1 >= self.mast_forest.num_procedures() { - return; - } - - self.last_procedure_root_idx += 1; - } - - let tree_root_id = procedure_roots[self.last_procedure_root_idx as usize]; - self.discover_tree(tree_root_id); - } -} - -impl<'forest> Iterator for MastForestNodeIter<'forest> { - type Item = (MastNodeId, &'forest MastNode); - - fn next(&mut self) -> Option { - if let Some(next_node_id) = self.unvisited_node_stack.pop() { - // SAFETY: We only add valid ids to the stack so it's fine to index the forest nodes - // directly. - let node = &self.mast_forest.nodes[next_node_id.as_usize()]; - return Some((next_node_id, node)); - } - - self.discover_nodes(); - - if !self.unvisited_node_stack.is_empty() { - self.next() - } else { - // If the stack is empty after tree discovery, all (reachable) nodes have been visited. - None - } - } -} - -#[cfg(test)] -mod tests { - use miden_crypto::hash::rpo::RpoDigest; - - use super::*; - - fn random_digest() -> RpoDigest { - RpoDigest::new([rand_utils::rand_value(); 4]) - } - - #[test] - fn mast_forest_dfs_empty() { - let forest = MastForest::new(); - let mut iterator = forest.iter_nodes(); - assert!(iterator.next().is_none()); - } - - #[test] - fn mast_forest_dfs() { - let node0_digest = random_digest(); - let node1_digest = random_digest(); - let node2_digest = random_digest(); - let node3_digest = random_digest(); - - let mut forest = MastForest::new(); - forest.add_external(node0_digest).unwrap(); - let id1 = forest.add_external(node1_digest).unwrap(); - let id2 = forest.add_external(node2_digest).unwrap(); - let id3 = forest.add_external(node3_digest).unwrap(); - let id_split = forest.add_split(id2, id3).unwrap(); - let id_join = forest.add_join(id2, id_split).unwrap(); - - forest.make_root(id_join); - forest.make_root(id1); - - // Note that the node at index 0 is not visited because it is not reachable from any root - // and is not a root itself. - let mut iterator = forest.iter_nodes(); - // Node at id2 should only be visited once. - assert_matches!(iterator.next().unwrap(), (id, MastNode::External(digest)) if digest.digest() == node2_digest && id == id2); - assert_matches!(iterator.next().unwrap(), (id, MastNode::External(digest)) if digest.digest() == node3_digest && id == id3); - assert_matches!(iterator.next().unwrap(), (id, MastNode::Split(_)) if id == id_split); - assert_matches!(iterator.next().unwrap(), (id, MastNode::Join(_)) if id == id_join); - assert_matches!(iterator.next().unwrap(), (id, MastNode::External(digest)) if digest.digest() == node1_digest&& id == id1); - assert!(iterator.next().is_none()); - } -}