diff --git a/assembly/src/assembler/mast_forest_merger_tests.rs b/assembly/src/assembler/mast_forest_merger_tests.rs index 9d1b9faba..a2382bcb1 100644 --- a/assembly/src/assembler/mast_forest_merger_tests.rs +++ b/assembly/src/assembler/mast_forest_merger_tests.rs @@ -20,7 +20,7 @@ fn merge_programs( let lib_b = assembler.assemble_library([program_b])?.mast_forest().as_ref().clone(); let lib_a = lib_a.mast_forest().as_ref().clone(); - let merged = MastForest::merge([&lib_a, &lib_b]).into_diagnostic()?; + let merged = MastForest::merge([lib_a.clone(), lib_b.clone()]).into_diagnostic()?; Ok((lib_a, lib_b, merged)) } diff --git a/core/src/mast/merger/mod.rs b/core/src/mast/merger/mod.rs index 3c858f75c..289bcac08 100644 --- a/core/src/mast/merger/mod.rs +++ b/core/src/mast/merger/mod.rs @@ -3,7 +3,10 @@ use core::ops::ControlFlow; use miden_crypto::hash::{blake::Blake3Digest, rpo::RpoDigest}; -use crate::mast::{DecoratorId, EqHash, MastForest, MastForestError, MastNode, MastNodeId}; +use crate::{ + mast::{DecoratorId, EqHash, MastForest, MastForestError, MastNode, MastNodeId}, + Decorator, +}; #[cfg(test)] mod tests; @@ -34,34 +37,41 @@ impl MastForestMerger { /// Merges `other_forest` into the forest contained in self. pub(crate) fn merge( &mut self, - other_forest: &MastForest, + mut other_forest: MastForest, ) -> Result { let mut decorator_id_remapping = DecoratorIdMap::new(other_forest.decorators.len()); let mut node_id_remapping = MastForestNodeIdMap::new(); - self.merge_decorators(other_forest, &mut decorator_id_remapping)?; + // It's fine to take out the decorators here as they aren't accessed after this point. + self.merge_decorators( + core::mem::take(&mut other_forest.decorators), + &mut decorator_id_remapping, + )?; + // `merge_nodes` takes ownership of the forest and needs the roots intact for the DFS + // iteration, so we cannot core::mem::take them out, so we copy the roots as we need + // them after merging of nodes is done. + let roots = other_forest.roots.clone(); self.merge_nodes(other_forest, &decorator_id_remapping, &mut node_id_remapping)?; - self.merge_roots(other_forest, &node_id_remapping)?; + self.merge_roots(roots.as_slice(), &node_id_remapping)?; - let root_map = - MastForestRootMap::from_node_id_map(node_id_remapping, other_forest.roots.as_slice()); + let root_map = MastForestRootMap::from_node_id_map(node_id_remapping, roots.as_slice()); Ok(root_map) } fn merge_decorators( &mut self, - other_forest: &MastForest, + decorators: Vec, decorator_id_remapping: &mut DecoratorIdMap, ) -> Result<(), MastForestError> { - for (merging_id, merging_decorator) in other_forest.decorators.iter().enumerate() { + for (merging_id, merging_decorator) in decorators.into_iter().enumerate() { let merging_decorator_hash = merging_decorator.eq_hash(); let new_decorator_id = if let Some(existing_decorator) = self.decorators_by_hash.get(&merging_decorator_hash) { *existing_decorator } else { - let new_decorator_id = self.mast_forest.add_decorator(merging_decorator.clone())?; + let new_decorator_id = self.mast_forest.add_decorator(merging_decorator)?; self.decorators_by_hash.insert(merging_decorator_hash, new_decorator_id); new_decorator_id }; @@ -75,7 +85,7 @@ impl MastForestMerger { fn merge_nodes( &mut self, - other_forest: &MastForest, + other_forest: MastForest, decorator_id_remapping: &DecoratorIdMap, node_id_remapping: &mut MastForestNodeIdMap, ) -> Result<(), MastForestError> { @@ -127,10 +137,10 @@ impl MastForestMerger { fn merge_roots( &mut self, - other_forest: &MastForest, + roots: &[MastNodeId], node_id_remapping: &MastForestNodeIdMap, ) -> Result<(), MastForestError> { - for root_id in other_forest.roots.iter() { + for root_id in roots { // Map the previous root to its possibly new id. let new_root = node_id_remapping.get(root_id).expect("all node ids should have an entry"); @@ -226,7 +236,7 @@ impl MastForestMerger { /// the given maps. fn remap_node( &self, - node: &MastNode, + node: MastNode, decorator_id_remapping: &DecoratorIdMap, node_id_remapping: &MastForestNodeIdMap, ) -> Result { @@ -246,6 +256,15 @@ impl MastForestMerger { .expect("every node id should have an entry") }; + // Decorators must be handled specially for basic block nodes. + // For other node types we can handle it centrally. + let mut before_enter = Vec::new(); + let mut after_exit = Vec::new(); + if !node.is_basic_block() { + before_enter = map_decorators(node.before_enter())?; + after_exit = map_decorators(node.after_exit())?; + } + // Due to DFS postorder iteration all children of node's should have been inserted before // their parents which is why we can `expect` the constructor calls here. let mut mapped_node = match node { @@ -273,34 +292,24 @@ impl MastForestMerger { MastNode::new_call(callee, &self.mast_forest) .expect("CallNode children should have been mapped to a lower index") }, - // Other nodes are simply copied. - MastNode::Block(basic_block_node) => { - MastNode::new_basic_block( - basic_block_node.operations().copied().collect(), - // Operation Indices of decorators stay the same while decorator IDs need to be - // mapped. - Some( - basic_block_node - .decorators() - .iter() - .map(|(idx, decorator_id)| match map_decorator_id(decorator_id) { - Ok(mapped_decorator) => Ok((*idx, mapped_decorator)), - Err(err) => Err(err), - }) - .collect::, _>>()?, - ), - ) - .expect("previously valid BasicBlockNode should still be valid") + MastNode::Block(mut basic_block_node) => { + basic_block_node.map_decorators(|decorator_id| { + match map_decorator_id(decorator_id) { + Ok(mapped_decorator) => Ok(mapped_decorator), + Err(err) => Err(err), + } + })?; + + MastNode::Block(basic_block_node) }, - MastNode::Dyn(_) => MastNode::new_dyn(), - MastNode::External(external_node) => MastNode::new_external(external_node.digest()), + // Other nodes are simply copied. + MastNode::Dyn(_) => node, + MastNode::External(_) => node, }; - // Decorators must be handled specially for basic block nodes. - // For other node types we can handle it centrally. if !mapped_node.is_basic_block() { - mapped_node.set_before_enter(map_decorators(node.before_enter())?); - mapped_node.set_after_exit(map_decorators(node.after_exit())?); + mapped_node.set_before_enter(before_enter); + mapped_node.set_after_exit(after_exit); } Ok(mapped_node) diff --git a/core/src/mast/merger/tests.rs b/core/src/mast/merger/tests.rs index a9e6f72d9..2ea34f314 100644 --- a/core/src/mast/merger/tests.rs +++ b/core/src/mast/merger/tests.rs @@ -51,7 +51,7 @@ fn mast_forest_merge_remap() { let id_call_b = forest_b.add_call(id_bar).unwrap(); forest_b.make_root(id_call_b); - let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap(); + let (merged, root_maps) = MastForest::merge([forest_a, forest_b]).unwrap(); assert_eq!(merged.nodes().len(), 4); assert_eq!(merged.nodes()[0], block_foo()); @@ -79,7 +79,7 @@ fn mast_forest_merge_duplicate() { forest_a.make_root(id_call); forest_a.make_root(id_loop); - let (merged, root_maps) = MastForest::merge([&forest_a, &forest_a]).unwrap(); + let (merged, root_maps) = MastForest::merge([forest_a.clone(), forest_a.clone()]).unwrap(); for merged_root in merged.procedure_digests() { forest_a.procedure_digests().find(|root| root == &merged_root).unwrap(); @@ -121,8 +121,10 @@ fn mast_forest_merge_replace_external() { let id_call_b = forest_b.add_call(id_foo_b).unwrap(); forest_b.make_root(id_call_b); - let (merged_ab, root_maps_ab) = MastForest::merge([&forest_a, &forest_b]).unwrap(); - let (merged_ba, root_maps_ba) = MastForest::merge([&forest_b, &forest_a]).unwrap(); + let (merged_ab, root_maps_ab) = + MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap(); + let (merged_ba, root_maps_ba) = + MastForest::merge([forest_b.clone(), forest_a.clone()]).unwrap(); for (merged, root_map) in [(merged_ab, root_maps_ab), (merged_ba, root_maps_ba)] { assert_eq!(merged.nodes().len(), 2); @@ -162,7 +164,7 @@ fn mast_forest_merge_roots() { let root_digest_bar_b = forest_b.get_node_by_id(id_bar_b).unwrap().digest(); let root_digest_call_b = forest_b.get_node_by_id(call_b).unwrap().digest(); - let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap(); + let (merged, root_maps) = MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap(); // Asserts (together with the other assertions) that the duplicate Call(foo) roots have been // deduplicated. @@ -212,7 +214,8 @@ fn mast_forest_merge_multiple() { forest_c.make_root(id_qux_c); forest_c.make_root(call_c); - let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b, &forest_c]).unwrap(); + let (merged, root_maps) = + MastForest::merge([forest_a.clone(), forest_b.clone(), forest_c.clone()]).unwrap(); let block_foo_digest = forest_b.get_node_by_id(id_foo_b).unwrap().digest(); let block_bar_digest = forest_b.get_node_by_id(id_bar_b).unwrap().digest(); @@ -295,7 +298,7 @@ fn mast_forest_merge_decorators() { forest_b.make_root(id_loop_b); - let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap(); + let (merged, root_maps) = MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap(); // There are 4 unique decorators across both forests. assert_eq!(merged.decorators.len(), 4); @@ -403,8 +406,8 @@ fn mast_forest_merge_external_node_reference_with_decorator() { forest_b.make_root(id_external_b); for (idx, (merged, root_maps)) in [ - MastForest::merge([&forest_a, &forest_b]).unwrap(), - MastForest::merge([&forest_b, &forest_a]).unwrap(), + MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap(), + MastForest::merge([forest_b.clone(), forest_a.clone()]).unwrap(), ] .into_iter() .enumerate() @@ -467,8 +470,8 @@ fn mast_forest_merge_external_node_with_decorator() { forest_b.make_root(id_foo_b); for (idx, (merged, root_maps)) in [ - MastForest::merge([&forest_a, &forest_b]).unwrap(), - MastForest::merge([&forest_b, &forest_a]).unwrap(), + MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap(), + MastForest::merge([forest_b.clone(), forest_a.clone()]).unwrap(), ] .into_iter() .enumerate() @@ -535,8 +538,8 @@ fn mast_forest_merge_external_node_and_referenced_node_have_decorators() { forest_b.make_root(id_foo_b); for (idx, (merged, root_maps)) in [ - MastForest::merge([&forest_a, &forest_b]).unwrap(), - MastForest::merge([&forest_b, &forest_a]).unwrap(), + MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap(), + MastForest::merge([forest_b.clone(), forest_a.clone()]).unwrap(), ] .into_iter() .enumerate() @@ -611,8 +614,8 @@ fn mast_forest_merge_multiple_external_nodes_with_decorator() { forest_b.make_root(id_foo_b); for (idx, (merged, root_maps)) in [ - MastForest::merge([&forest_a, &forest_b]).unwrap(), - MastForest::merge([&forest_b, &forest_a]).unwrap(), + MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap(), + MastForest::merge([forest_b.clone(), forest_a.clone()]).unwrap(), ] .into_iter() .enumerate() @@ -665,6 +668,6 @@ fn mast_forest_merge_invalid_decorator_index() { forest_b.make_root(id_foo_b); - let err = MastForest::merge([&forest_a, &forest_b]).unwrap_err(); + let err = MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap_err(); assert_matches!(err, MastForestError::DecoratorIdOverflow(_, _)); } diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 71cb01a83..a62cc9445 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -251,8 +251,8 @@ impl MastForest { /// which is effectively deduplication. Decorators are ignored when it comes to merging /// External nodes. This means that an External node with decorators may be replaced by a node /// without decorators or vice versa. - pub fn merge<'forest>( - forests: impl IntoIterator, + pub fn merge( + forests: impl IntoIterator, ) -> Result<(MastForest, Vec), MastForestError> { let mut root_maps = Vec::new(); let mut merger = MastForestMerger::new(); @@ -497,7 +497,7 @@ impl MastForest { /// /// The iteration on a high-level thus consists of a constant back and forth between discovering /// trees and returning nodes from the stack. - pub fn iter_nodes(&self) -> impl Iterator { + pub fn iter_nodes(self) -> impl Iterator { MastForestNodeIter::new(self) } } diff --git a/core/src/mast/node/basic_block_node/mod.rs b/core/src/mast/node/basic_block_node/mod.rs index 726decbc7..c725419c6 100644 --- a/core/src/mast/node/basic_block_node/mod.rs +++ b/core/src/mast/node/basic_block_node/mod.rs @@ -198,6 +198,18 @@ impl BasicBlockNode { .expect("basic block contains more than 2^32 operations and decorators") } + /// A specialized function to map the decorators of self to new values while leaving the + /// remaining parts as-is. + pub(crate) fn map_decorators( + &mut self, + decorator_map: impl Fn(&DecoratorId) -> Result, + ) -> Result<(), E> { + for (_, decorator) in self.decorators.iter_mut() { + *decorator = decorator_map(decorator)?; + } + Ok(()) + } + /// Returns an iterator over all operations and decorator, in the order in which they appear in /// the program. pub fn iter(&self) -> impl Iterator { diff --git a/core/src/mast/node_iterator.rs b/core/src/mast/node_iterator.rs index 59bb675d1..724d674a1 100644 --- a/core/src/mast/node_iterator.rs +++ b/core/src/mast/node_iterator.rs @@ -33,9 +33,9 @@ use crate::mast::{MastForest, MastNode, MastNodeId}; /// trees and returning nodes from the stack. /// /// Note: This type could be made more general to implement pre-order or in-order iteration too. -pub(crate) struct MastForestNodeIter<'forest> { +pub(crate) struct MastForestNodeIter { /// The forest that we're iterating. - pub mast_forest: &'forest MastForest, + pub mast_forest: MastForest, /// The procedure root index at which we last started a tree discovery. /// /// This value iterates through 0..mast_forest.num_procedures() which guarantees that we visit @@ -50,8 +50,8 @@ pub(crate) struct MastForestNodeIter<'forest> { pub unvisited_node_stack: Vec, } -impl<'forest> MastForestNodeIter<'forest> { - pub(crate) fn new(mast_forest: &'forest MastForest) -> Self { +impl MastForestNodeIter { + pub(crate) fn new(mast_forest: MastForest) -> Self { let visited = vec![false; mast_forest.num_nodes() as usize]; Self { @@ -93,22 +93,32 @@ impl<'forest> MastForestNodeIter<'forest> { self.mark_for_visit(root_idx); }, MastNode::Join(join_node) => { + let second = join_node.second(); + let first = join_node.first(); + self.mark_for_visit(root_idx); - self.discover_tree(join_node.second()); - self.discover_tree(join_node.first()); + self.discover_tree(second); + self.discover_tree(first); }, MastNode::Split(split_node) => { + let on_false = split_node.on_false(); + let on_true = split_node.on_true(); + self.mark_for_visit(root_idx); - self.discover_tree(split_node.on_false()); - self.discover_tree(split_node.on_true()); + self.discover_tree(on_false); + self.discover_tree(on_true); }, MastNode::Loop(loop_node) => { + let body = loop_node.body(); + self.mark_for_visit(root_idx); - self.discover_tree(loop_node.body()); + self.discover_tree(body); }, MastNode::Call(call_node) => { + let callee = call_node.callee(); + self.mark_for_visit(root_idx); - self.discover_tree(call_node.callee()); + self.discover_tree(callee); }, MastNode::Dyn(_) => { self.mark_for_visit(root_idx); @@ -142,14 +152,20 @@ impl<'forest> MastForestNodeIter<'forest> { } } -impl<'forest> Iterator for MastForestNodeIter<'forest> { - type Item = (MastNodeId, &'forest MastNode); +impl Iterator for MastForestNodeIter { + type Item = (MastNodeId, MastNode); fn next(&mut self) -> Option { if let Some(next_node_id) = self.unvisited_node_stack.pop() { // SAFETY: We only add valid ids to the stack so it's fine to index the forest nodes // directly. - let node = &self.mast_forest.nodes[next_node_id.as_usize()]; + let node = std::mem::replace( + &mut self.mast_forest.nodes[next_node_id.as_usize()], + // A Dyn node is the `MastNode` with the smallest memory footprint so we use it to + // replace the node we take out. + // Since we visit each node exactly once, this is fine. + MastNode::new_dyn(), + ); return Some((next_node_id, node)); }