From 9f9cc63b709d8b7c83841d8e421701dcd33792bb Mon Sep 17 00:00:00 2001 From: Philipp Gackstatter Date: Mon, 28 Oct 2024 15:49:47 +0100 Subject: [PATCH] feat: implement `MastForest` merging (#1534) --- CHANGELOG.md | 1 + assembly/src/assembler/mast_forest_builder.rs | 139 +-- .../src/assembler/mast_forest_merger_tests.rs | 73 ++ assembly/src/assembler/mod.rs | 3 + core/src/mast/merger/mod.rs | 411 +++++++++ core/src/mast/merger/tests.rs | 796 ++++++++++++++++++ core/src/mast/mod.rs | 240 +++++- core/src/mast/multi_forest_node_iterator.rs | 490 +++++++++++ stdlib/tests/main.rs | 1 + stdlib/tests/mast_forest_merge.rs | 19 + 10 files changed, 2036 insertions(+), 137 deletions(-) create mode 100644 assembly/src/assembler/mast_forest_merger_tests.rs create mode 100644 core/src/mast/merger/mod.rs create mode 100644 core/src/mast/merger/tests.rs create mode 100644 core/src/mast/multi_forest_node_iterator.rs create mode 100644 stdlib/tests/mast_forest_merge.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 273cbb9dd2..a1204ac13c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ - [BREAKING] The `run` and the `prove` commands in the cli will accept `--trace` flag instead of `--tracing` (#1502) - Migrated to new padding rule for RPO (#1343). - Migrated to `miden-crypto` v0.11.0 (#1343). +- Implemented `MastForest` merging (#1534) #### Fixes diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs index dc320c9ca0..72139b1a05 100644 --- a/assembly/src/assembler/mast_forest_builder.rs +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -5,8 +5,8 @@ use alloc::{ use core::ops::{Index, IndexMut}; use vm_core::{ - crypto::hash::{Blake3Digest, Blake3_256, Digest, RpoDigest}, - mast::{DecoratorId, MastForest, MastNode, MastNodeId}, + crypto::hash::{Blake3Digest, RpoDigest}, + mast::{DecoratorId, EqHash, MastForest, MastNode, MastNodeId}, Decorator, DecoratorList, Operation, }; @@ -445,115 +445,9 @@ impl MastForestBuilder { } } -/// Helpers impl MastForestBuilder { fn eq_hash_for_node(&self, node: &MastNode) -> EqHash { - match node { - MastNode::Block(node) => { - let mut bytes_to_hash = Vec::new(); - - for &(idx, decorator_id) in node.decorators() { - bytes_to_hash.extend(idx.to_le_bytes()); - bytes_to_hash.extend(self[decorator_id].eq_hash().as_bytes()); - } - - // Add any `Assert` or `U32assert2` opcodes present, since these are not included in - // the MAST root. - for (op_idx, op) in node.operations().enumerate() { - if let Operation::U32assert2(inner_value) - | Operation::Assert(inner_value) - | Operation::MpVerify(inner_value) = op - { - let op_idx: u32 = op_idx - .try_into() - .expect("there are more than 2^{32}-1 operations in basic block"); - - // we include the opcode to differentiate between `Assert` and `U32assert2` - bytes_to_hash.push(op.op_code()); - // we include the operation index to distinguish between basic blocks that - // would have the same assert instructions, but in a different order - bytes_to_hash.extend(op_idx.to_le_bytes()); - bytes_to_hash.extend(inner_value.to_le_bytes()); - } - } - - if bytes_to_hash.is_empty() { - EqHash::new(node.digest()) - } else { - let decorator_root = Blake3_256::hash(&bytes_to_hash); - EqHash::with_decorator_root(node.digest(), decorator_root) - } - }, - MastNode::Join(node) => self.eq_hash_from_parts( - node.before_enter(), - node.after_exit(), - &[node.first(), node.second()], - node.digest(), - ), - MastNode::Split(node) => self.eq_hash_from_parts( - node.before_enter(), - node.after_exit(), - &[node.on_true(), node.on_false()], - node.digest(), - ), - MastNode::Loop(node) => self.eq_hash_from_parts( - node.before_enter(), - node.after_exit(), - &[node.body()], - node.digest(), - ), - MastNode::Call(node) => self.eq_hash_from_parts( - node.before_enter(), - node.after_exit(), - &[node.callee()], - node.digest(), - ), - MastNode::Dyn(node) => { - self.eq_hash_from_parts(node.before_enter(), node.after_exit(), &[], node.digest()) - }, - MastNode::External(node) => { - self.eq_hash_from_parts(node.before_enter(), node.after_exit(), &[], node.digest()) - }, - } - } - - fn eq_hash_from_parts( - &self, - before_enter_ids: &[DecoratorId], - after_exit_ids: &[DecoratorId], - children_ids: &[MastNodeId], - node_digest: RpoDigest, - ) -> EqHash { - let pre_decorator_hash_bytes = - before_enter_ids.iter().flat_map(|&id| self[id].eq_hash().as_bytes()); - let post_decorator_hash_bytes = - after_exit_ids.iter().flat_map(|&id| self[id].eq_hash().as_bytes()); - - // Reminder: the `EqHash`'s decorator root will be `None` if and only if there are no - // decorators attached to the node, and all children have no decorator roots (meaning that - // there are no decorators in all the descendants). - if pre_decorator_hash_bytes.clone().next().is_none() - && post_decorator_hash_bytes.clone().next().is_none() - && children_ids - .iter() - .filter_map(|child_id| self.hash_by_node_id[child_id].decorator_root) - .next() - .is_none() - { - EqHash::new(node_digest) - } else { - let children_decorator_roots = children_ids - .iter() - .filter_map(|child_id| self.hash_by_node_id[child_id].decorator_root) - .flat_map(|decorator_root| decorator_root.as_bytes()); - let decorator_bytes_to_hash: Vec = pre_decorator_hash_bytes - .chain(post_decorator_hash_bytes) - .chain(children_decorator_roots) - .collect(); - - let decorator_root = Blake3_256::hash(&decorator_bytes_to_hash); - EqHash::with_decorator_root(node_digest, decorator_root) - } + EqHash::from_mast_node(&self.mast_forest, &self.hash_by_node_id, node) } } @@ -582,33 +476,6 @@ impl IndexMut for MastForestBuilder { } } -// EQ HASH -// ================================================================================================ - -/// Represents the hash used to test for equality between [`MastNode`]s. -/// -/// The decorator root will be `None` if and only if there are no decorators attached to the node, -/// and all children have no decorator roots (meaning that there are no decorators in all the -/// descendants). -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -struct EqHash { - mast_root: RpoDigest, - decorator_root: Option>, -} - -impl EqHash { - fn new(mast_root: RpoDigest) -> Self { - Self { mast_root, decorator_root: None } - } - - fn with_decorator_root(mast_root: RpoDigest, decorator_root: Blake3Digest<32>) -> Self { - Self { - mast_root, - decorator_root: Some(decorator_root), - } - } -} - // HELPER FUNCTIONS // ================================================================================================ diff --git a/assembly/src/assembler/mast_forest_merger_tests.rs b/assembly/src/assembler/mast_forest_merger_tests.rs new file mode 100644 index 0000000000..96e533992c --- /dev/null +++ b/assembly/src/assembler/mast_forest_merger_tests.rs @@ -0,0 +1,73 @@ +use miette::{IntoDiagnostic, Report}; +use vm_core::mast::{MastForest, MastForestRootMap}; + +use crate::{testing::TestContext, Assembler}; + +#[allow(clippy::type_complexity)] +fn merge_programs( + program_a: &str, + program_b: &str, +) -> Result<(MastForest, MastForest, MastForest, MastForestRootMap), Report> { + let context = TestContext::new(); + let module = context.parse_module_with_path("lib::mod".parse().unwrap(), program_a)?; + + let lib_a = Assembler::new(context.source_manager()).assemble_library([module])?; + + let mut assembler = Assembler::new(context.source_manager()); + assembler.add_library(lib_a.clone())?; + let lib_b = assembler.assemble_library([program_b])?.mast_forest().as_ref().clone(); + let lib_a = lib_a.mast_forest().as_ref().clone(); + + let (merged, root_maps) = MastForest::merge([&lib_a, &lib_b]).into_diagnostic()?; + + Ok((lib_a, lib_b, merged, root_maps)) +} + +/// Tests that an assembler-produced library's forests can be merged and that external nodes are +/// replaced by their referenced procedures. +#[test] +fn mast_forest_merge_assembler() { + let lib_a = r#" + export.foo + push.19 + end + + export.qux + swap drop + end +"#; + + let lib_b = r#" + use.lib::mod + + export.qux_duplicate + swap drop + end + + export.bar + push.2 + if.true + push.3 + else + while.true + add + push.23 + end + end + exec.mod::foo + end"#; + + let (forest_a, forest_b, merged, root_maps) = merge_programs(lib_a, lib_b).unwrap(); + + for (forest_idx, forest) in [forest_a, forest_b].into_iter().enumerate() { + for root in forest.procedure_roots() { + let original_digest = forest.nodes()[root.as_usize()].digest(); + let new_root = root_maps.map_root(forest_idx, root).unwrap(); + let new_digest = merged.nodes()[new_root.as_usize()].digest(); + assert_eq!(original_digest, new_digest); + } + } + + // Assert that the external node for the import was removed during merging. + merged.nodes().iter().for_each(|node| assert!(!node.is_external())); +} diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 6575056e9b..2ddc94e31a 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -28,6 +28,9 @@ mod procedure; #[cfg(test)] mod tests; +#[cfg(test)] +mod mast_forest_merger_tests; + use self::{ basic_block_builder::BasicBlockBuilder, module_graph::{CallerInfo, ModuleGraph, ResolvedTarget}, diff --git a/core/src/mast/merger/mod.rs b/core/src/mast/merger/mod.rs new file mode 100644 index 0000000000..d8e522dc0d --- /dev/null +++ b/core/src/mast/merger/mod.rs @@ -0,0 +1,411 @@ +use alloc::{collections::BTreeMap, vec::Vec}; + +use miden_crypto::hash::blake::Blake3Digest; + +use crate::mast::{ + DecoratorId, EqHash, MastForest, MastForestError, MastNode, MastNodeId, + MultiMastForestIteratorItem, MultiMastForestNodeIter, +}; + +#[cfg(test)] +mod tests; + +/// A type that allows merging [`MastForest`]s. +/// +/// This functionality is exposed via [`MastForest::merge`]. See its documentation for more details. +pub(crate) struct MastForestMerger { + mast_forest: MastForest, + // Internal indices needed for efficient duplicate checking and EqHash computation. + // + // These are always in-sync with the nodes in `mast_forest`, i.e. all nodes added to the + // `mast_forest` are also added to the indices. + node_id_by_hash: BTreeMap, + hash_by_node_id: BTreeMap, + decorators_by_hash: BTreeMap, DecoratorId>, + /// Mappings from old decorator and node ids to their new ids. + /// + /// Any decorator in `mast_forest` is present as the target of some mapping in this map. + decorator_id_mappings: Vec, + /// Mappings from previous `MastNodeId`s to their new ids. + /// + /// Any `MastNodeId` in `mast_forest` is present as the target of some mapping in this map. + node_id_mappings: Vec, +} + +impl MastForestMerger { + /// Creates a new merger with an initially empty forest and merges all provided [`MastForest`]s + /// into it. + pub(crate) fn merge<'forest>( + forests: impl IntoIterator, + ) -> Result<(MastForest, MastForestRootMap), MastForestError> { + let forests = forests.into_iter().collect::>(); + let decorator_id_mappings = Vec::with_capacity(forests.len()); + let node_id_mappings = vec![MastForestNodeIdMap::new(); forests.len()]; + + let mut merger = Self { + node_id_by_hash: BTreeMap::new(), + hash_by_node_id: BTreeMap::new(), + decorators_by_hash: BTreeMap::new(), + mast_forest: MastForest::new(), + decorator_id_mappings, + node_id_mappings, + }; + + merger.merge_inner(forests.clone())?; + + let Self { mast_forest, node_id_mappings, .. } = merger; + + let root_maps = MastForestRootMap::from_node_id_map(node_id_mappings, forests); + + Ok((mast_forest, root_maps)) + } + + /// Merges all `forests` into self. + /// + /// It does this in three steps: + /// + /// 1. Merge all decorators, which is a case of deduplication and creating a decorator id + /// mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the + /// merged forest. + /// 2. Merge all nodes of forests. + /// - Similar to decorators, node indices might move during merging, so the merger keeps a + /// node id mapping as it merges nodes. + /// - This is a depth-first traversal over all forests to ensure all children are processed + /// before their parents. See the documentation of [`MultiMastForestNodeIter`] for details + /// on this traversal. + /// - Because all parents are processed after their children, we can use the node id mapping + /// to remap all [`MastNodeId`]s of the children to their potentially new id in the merged + /// forest. + /// - If any external node is encountered during this traversal with a digest `foo` for which + /// a `replacement` node exists in another forest with digest `foo`, then the external node + /// will be replaced by that node. In particular, it means we do not want to add the + /// external node to the merged forest, so it is never yielded from the iterator. + /// - Assuming the simple case, where the `replacement` was not visited yet and is just a + /// single node (not a tree), the iterator would first yield the `replacement` node which + /// means it is going to be merged into the forest. + /// - Next the iterator yields [`MultiMastForestIteratorItem::ExternalNodeReplacement`] + /// which signals that an external node was replaced by another node. In this example, + /// the `replacement_*` indices contained in that variant would point to the + /// `replacement` node. Now we can simply add a mapping from the external node to the + /// `replacement` node in our node id mapping which means all nodes that referenced the + /// external node will point to the `replacement` instead. + /// 3. Finally, we merge all roots of all forests. Here we map the existing root indices to + /// their potentially new indices in the merged forest and add them to the forest, + /// deduplicating in the process, too. + fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> { + for other_forest in forests.iter() { + self.merge_decorators(other_forest)?; + } + + let iterator = MultiMastForestNodeIter::new(forests.clone()); + for item in iterator { + match item { + MultiMastForestIteratorItem::Node { forest_idx, node_id } => { + let node = &forests[forest_idx][node_id]; + self.merge_node(forest_idx, node_id, node)?; + }, + MultiMastForestIteratorItem::ExternalNodeReplacement { + // forest index of the node which replaces the external node + replacement_forest_idx, + // ID of the node that replaces the external node + replacement_mast_node_id, + // forest index of the external node + replaced_forest_idx, + // ID of the external node + replaced_mast_node_id, + } => { + // The iterator is not aware of the merged forest, so the node indices it yields + // are for the existing forests. That means we have to map the ID of the + // replacement to its new location, since it was previously merged and its IDs + // have very likely changed. + let mapped_replacement = self.node_id_mappings[replacement_forest_idx] + .get(&replacement_mast_node_id) + .copied() + .expect("every merged node id should be mapped"); + + // SAFETY: The iterator only yields valid forest indices, so it is safe to index + // directly. + self.node_id_mappings[replaced_forest_idx] + .insert(replaced_mast_node_id, mapped_replacement); + }, + } + } + + for (forest_idx, forest) in forests.iter().enumerate() { + self.merge_roots(forest_idx, forest)?; + } + + Ok(()) + } + + fn merge_decorators(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> { + let mut decorator_id_remapping = DecoratorIdMap::new(other_forest.decorators.len()); + + for (merging_id, merging_decorator) in other_forest.decorators.iter().enumerate() { + let merging_decorator_hash = merging_decorator.eq_hash(); + let new_decorator_id = if let Some(existing_decorator) = + self.decorators_by_hash.get(&merging_decorator_hash) + { + *existing_decorator + } else { + let new_decorator_id = self.mast_forest.add_decorator(merging_decorator.clone())?; + self.decorators_by_hash.insert(merging_decorator_hash, new_decorator_id); + new_decorator_id + }; + + decorator_id_remapping + .insert(DecoratorId::new_unchecked(merging_id as u32), new_decorator_id); + } + + self.decorator_id_mappings.push(decorator_id_remapping); + + Ok(()) + } + + fn merge_node( + &mut self, + forest_idx: usize, + merging_id: MastNodeId, + node: &MastNode, + ) -> Result<(), MastForestError> { + // We need to remap the node prior to computing the EqHash. + // + // This is because the EqHash computation looks up its descendants and decorators in + // the internal index, and if we were to pass the original node to that + // computation, it would look up the incorrect descendants and decorators (since the + // descendant's indices may have changed). + // + // Remapping at this point is guaranteed to be "complete", meaning all ids of children + // will be present in the node id mapping since the DFS iteration guarantees + // that all children of this `node` have been processed before this node and + // their indices have been added to the mappings. + let remapped_node = self.remap_node(forest_idx, node)?; + + let node_fingerprint = + EqHash::from_mast_node(&self.mast_forest, &self.hash_by_node_id, &remapped_node); + + match self.lookup_node_by_fingerprint(&node_fingerprint) { + Some(matching_node_id) => { + // If a node with a matching fingerprint exists, then the merging node is a + // duplicate and we remap it to the existing node. + self.node_id_mappings[forest_idx].insert(merging_id, matching_node_id); + }, + None => { + // If no node with a matching fingerprint exists, then the merging node is + // unique and we can add it to the merged forest. + let new_node_id = self.mast_forest.add_node(remapped_node)?; + self.node_id_mappings[forest_idx].insert(merging_id, new_node_id); + + // We need to update the indices with the newly inserted nodes + // since the EqHash computation requires all descendants of a node + // to be in this index. Hence when we encounter a node in the merging forest + // which has descendants (Call, Loop, Split, ...), then their descendants need to be + // in the indices. + self.node_id_by_hash.insert(node_fingerprint, new_node_id); + self.hash_by_node_id.insert(new_node_id, node_fingerprint); + }, + } + + Ok(()) + } + + fn merge_roots( + &mut self, + forest_idx: usize, + other_forest: &MastForest, + ) -> Result<(), MastForestError> { + for root_id in other_forest.roots.iter() { + // Map the previous root to its possibly new id. + let new_root = self.node_id_mappings[forest_idx] + .get(root_id) + .expect("all node ids should have an entry"); + // This takes O(n) where n is the number of roots in the merged forest every time to + // check if the root already exists. As the number of roots is relatively low generally, + // this should be okay. + self.mast_forest.make_root(*new_root); + } + + Ok(()) + } + + /// Remaps a nodes' potentially contained children and decorators to their new IDs according to + /// the given maps. + fn remap_node(&self, forest_idx: usize, node: &MastNode) -> Result { + let map_decorator_id = |decorator_id: &DecoratorId| { + self.decorator_id_mappings[forest_idx].get(decorator_id).ok_or_else(|| { + MastForestError::DecoratorIdOverflow( + *decorator_id, + self.decorator_id_mappings[forest_idx].len(), + ) + }) + }; + let map_decorators = |decorators: &[DecoratorId]| -> Result, MastForestError> { + decorators.iter().map(map_decorator_id).collect() + }; + + let map_node_id = |node_id: MastNodeId| { + self.node_id_mappings[forest_idx] + .get(&node_id) + .copied() + .expect("every node id should have an entry") + }; + + // Due to DFS postorder iteration all children of node's should have been inserted before + // their parents which is why we can `expect` the constructor calls here. + let mut mapped_node = match node { + MastNode::Join(join_node) => { + let first = map_node_id(join_node.first()); + let second = map_node_id(join_node.second()); + + MastNode::new_join(first, second, &self.mast_forest) + .expect("JoinNode children should have been mapped to a lower index") + }, + MastNode::Split(split_node) => { + let if_branch = map_node_id(split_node.on_true()); + let else_branch = map_node_id(split_node.on_false()); + + MastNode::new_split(if_branch, else_branch, &self.mast_forest) + .expect("SplitNode children should have been mapped to a lower index") + }, + MastNode::Loop(loop_node) => { + let body = map_node_id(loop_node.body()); + MastNode::new_loop(body, &self.mast_forest) + .expect("LoopNode children should have been mapped to a lower index") + }, + MastNode::Call(call_node) => { + let callee = map_node_id(call_node.callee()); + MastNode::new_call(callee, &self.mast_forest) + .expect("CallNode children should have been mapped to a lower index") + }, + // Other nodes are simply copied. + MastNode::Block(basic_block_node) => { + MastNode::new_basic_block( + basic_block_node.operations().copied().collect(), + // Operation Indices of decorators stay the same while decorator IDs need to be + // mapped. + Some( + basic_block_node + .decorators() + .iter() + .map(|(idx, decorator_id)| match map_decorator_id(decorator_id) { + Ok(mapped_decorator) => Ok((*idx, mapped_decorator)), + Err(err) => Err(err), + }) + .collect::, _>>()?, + ), + ) + .expect("previously valid BasicBlockNode should still be valid") + }, + MastNode::Dyn(_) => MastNode::new_dyn(), + MastNode::External(external_node) => MastNode::new_external(external_node.digest()), + }; + + // Decorators must be handled specially for basic block nodes. + // For other node types we can handle it centrally. + if !mapped_node.is_basic_block() { + mapped_node.set_before_enter(map_decorators(node.before_enter())?); + mapped_node.set_after_exit(map_decorators(node.after_exit())?); + } + + Ok(mapped_node) + } + + // HELPERS + // ================================================================================================ + + /// Returns a slice of nodes in the merged forest which have the given `mast_root`. + fn lookup_node_by_fingerprint(&self, fingerprint: &EqHash) -> Option { + self.node_id_by_hash.get(fingerprint).copied() + } +} + +// MAST FOREST ROOT MAP +// ================================================================================================ + +/// A mapping for the new location of the roots of a [`MastForest`] after a merge. +/// +/// It maps the roots ([`MastNodeId`]s) of a forest to their new [`MastNodeId`] in the merged +/// forest. See [`MastForest::merge`] for more details. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MastForestRootMap { + root_maps: Vec>, +} + +impl MastForestRootMap { + fn from_node_id_map(id_map: Vec, forests: Vec<&MastForest>) -> Self { + let mut root_maps = vec![BTreeMap::new(); forests.len()]; + + for (forest_idx, forest) in forests.into_iter().enumerate() { + for root in forest.procedure_roots() { + let new_id = id_map[forest_idx] + .get(root) + .copied() + .expect("every node id should be mapped to its new id"); + root_maps[forest_idx].insert(*root, new_id); + } + } + + Self { root_maps } + } + + /// Maps the given root to its new location in the merged forest, if such a mapping exists. + /// + /// It is guaranteed that every root of the map's corresponding forest is contained in the map. + pub fn map_root(&self, forest_index: usize, root: &MastNodeId) -> Option { + self.root_maps.get(forest_index).and_then(|map| map.get(root)).copied() + } +} + +// DECORATOR ID MAP +// ================================================================================================ + +/// A specialized map from [`DecoratorId`] -> [`DecoratorId`]. +/// +/// When mapping Decorator IDs during merging, we always map all IDs of the merging +/// forest to new ids. Hence it is more efficient to use a `Vec` instead of, say, a `BTreeMap`. +/// +/// In other words, this type is similar to `BTreeMap` but takes advantage of the fact that +/// the keys are contiguous. +/// +/// This type is meant to encapsulates some guarantees: +/// +/// - Indexing into the vector for any ID is safe if that ID is valid for the corresponding forest. +/// Despite that, we still cannot index unconditionally in case a node with invalid +/// [`DecoratorId`]s is passed to `merge`. +/// - The entry itself can be either None or Some. However: +/// - For `DecoratorId`s we iterate and insert all decorators into this map before retrieving any +/// entry, so all entries contain `Some`. Because of this, we can use `expect` in `get` for the +/// `Option` value. +/// - Similarly, inserting any ID from the corresponding forest is safe as the map contains a +/// pre-allocated `Vec` of the appropriate size. +struct DecoratorIdMap { + inner: Vec>, +} + +impl DecoratorIdMap { + fn new(num_ids: usize) -> Self { + Self { inner: vec![None; num_ids] } + } + + /// Maps the given key to the given value. + /// + /// It is the caller's responsibility to only pass keys that belong to the forest for which this + /// map was originally created. + fn insert(&mut self, key: DecoratorId, value: DecoratorId) { + self.inner[key.as_usize()] = Some(value); + } + + /// Retrieves the value for the given key. + fn get(&self, key: &DecoratorId) -> Option { + self.inner + .get(key.as_usize()) + .map(|id| id.expect("every id should have a Some entry in the map when calling get")) + } + + fn len(&self) -> usize { + self.inner.len() + } +} + +/// A type definition for increased readability in function signatures. +type MastForestNodeIdMap = BTreeMap; diff --git a/core/src/mast/merger/tests.rs b/core/src/mast/merger/tests.rs new file mode 100644 index 0000000000..881e22c2ca --- /dev/null +++ b/core/src/mast/merger/tests.rs @@ -0,0 +1,796 @@ +use miden_crypto::{hash::rpo::RpoDigest, ONE}; + +use super::*; +use crate::{Decorator, Operation}; + +fn block_foo() -> MastNode { + MastNode::new_basic_block(vec![Operation::Mul, Operation::Add], None).unwrap() +} + +fn block_bar() -> MastNode { + MastNode::new_basic_block(vec![Operation::And, Operation::Eq], None).unwrap() +} + +fn block_qux() -> MastNode { + MastNode::new_basic_block(vec![Operation::Swap, Operation::Push(ONE), Operation::Eq], None) + .unwrap() +} + +/// Asserts that the given forest contains exactly one node with the given digest. +/// +/// Returns a Result which can be unwrapped in the calling test function to assert. This way, if +/// this assertion fails it'll be clear which exact call failed. +fn assert_contains_node_once(forest: &MastForest, digest: RpoDigest) -> Result<(), &str> { + if forest.nodes.iter().filter(|node| node.digest() == digest).count() != 1 { + return Err("node digest contained more than once in the forest"); + } + + Ok(()) +} + +/// Asserts that every root of an original forest has an id to which it is mapped and that this +/// mapped root is in the set of roots in the merged forest. +/// +/// Returns a Result which can be unwrapped in the calling test function to assert. This way, if +/// this assertion fails it'll be clear which exact call failed. +fn assert_root_mapping( + root_map: &MastForestRootMap, + original_roots: Vec<&[MastNodeId]>, + merged_roots: &[MastNodeId], +) -> Result<(), &'static str> { + for (forest_idx, original_root) in original_roots.into_iter().enumerate() { + for root in original_root { + let mapped_root = root_map.map_root(forest_idx, root).unwrap(); + if !merged_roots.contains(&mapped_root) { + return Err("merged root does not contain mapped root"); + } + } + } + + Ok(()) +} + +/// Asserts that all children of nodes in the given forest have an id that is less than the parent's +/// ID. +/// +/// Returns a Result which can be unwrapped in the calling test function to assert. This way, if +/// this assertion fails it'll be clear which exact call failed. +fn assert_child_id_lt_parent_id(forest: &MastForest) -> Result<(), &str> { + for (mast_node_id, node) in forest.nodes().iter().enumerate() { + match node { + MastNode::Join(join_node) => { + if !join_node.first().as_usize() < mast_node_id { + return Err("join node first child id is not < parent id"); + }; + if !join_node.second().as_usize() < mast_node_id { + return Err("join node second child id is not < parent id"); + } + }, + MastNode::Split(split_node) => { + if !split_node.on_true().as_usize() < mast_node_id { + return Err("split node on true id is not < parent id"); + } + if !split_node.on_false().as_usize() < mast_node_id { + return Err("split node on false id is not < parent id"); + } + }, + MastNode::Loop(loop_node) => { + if !loop_node.body().as_usize() < mast_node_id { + return Err("loop node body id is not < parent id"); + } + }, + MastNode::Call(call_node) => { + if !call_node.callee().as_usize() < mast_node_id { + return Err("call node callee id is not < parent id"); + } + }, + MastNode::Block(_) => (), + MastNode::Dyn(_) => (), + MastNode::External(_) => (), + } + } + + Ok(()) +} + +/// Tests that Call(bar) still correctly calls the remapped bar block. +/// +/// [Block(foo), Call(foo)] +/// + +/// [Block(bar), Call(bar)] +/// = +/// [Block(foo), Call(foo), Block(bar), Call(bar)] +#[test] +fn mast_forest_merge_remap() { + let mut forest_a = MastForest::new(); + let id_foo = forest_a.add_node(block_foo()).unwrap(); + let id_call_a = forest_a.add_call(id_foo).unwrap(); + forest_a.make_root(id_call_a); + + let mut forest_b = MastForest::new(); + let id_bar = forest_b.add_node(block_bar()).unwrap(); + let id_call_b = forest_b.add_call(id_bar).unwrap(); + forest_b.make_root(id_call_b); + + let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap(); + + assert_eq!(merged.nodes().len(), 4); + assert_eq!(merged.nodes()[0], block_foo()); + assert_matches!(&merged.nodes()[1], MastNode::Call(call_node) if call_node.callee().as_u32() == 0); + assert_eq!(merged.nodes()[2], block_bar()); + assert_matches!(&merged.nodes()[3], MastNode::Call(call_node) if call_node.callee().as_u32() == 2); + + assert_eq!(root_maps.map_root(0, &id_call_a).unwrap().as_u32(), 1); + assert_eq!(root_maps.map_root(1, &id_call_b).unwrap().as_u32(), 3); + + assert_child_id_lt_parent_id(&merged).unwrap(); +} + +/// Tests that Forest_A + Forest_A = Forest_A (i.e. duplicates are removed). +#[test] +fn mast_forest_merge_duplicate() { + let mut forest_a = MastForest::new(); + forest_a.add_decorator(Decorator::Debug(crate::DebugOptions::MemAll)).unwrap(); + forest_a.add_decorator(Decorator::Trace(25)).unwrap(); + + let id_external = forest_a.add_external(block_bar().digest()).unwrap(); + let id_foo = forest_a.add_node(block_foo()).unwrap(); + let id_call = forest_a.add_call(id_foo).unwrap(); + let id_loop = forest_a.add_loop(id_external).unwrap(); + forest_a.make_root(id_call); + forest_a.make_root(id_loop); + + let (merged, root_maps) = MastForest::merge([&forest_a, &forest_a]).unwrap(); + + for merged_root in merged.procedure_digests() { + forest_a.procedure_digests().find(|root| root == &merged_root).unwrap(); + } + + // Both maps should map the roots to the same target id. + for original_root in forest_a.procedure_roots() { + assert_eq!(&root_maps.map_root(0, original_root), &root_maps.map_root(1, original_root)); + } + + for merged_node in merged.nodes().iter().map(MastNode::digest) { + forest_a.nodes.iter().find(|node| node.digest() == merged_node).unwrap(); + } + + for merged_decorator in merged.decorators.iter() { + assert!(forest_a.decorators.contains(merged_decorator)); + } + + assert_child_id_lt_parent_id(&merged).unwrap(); +} + +/// Tests that External(foo) is replaced by Block(foo) whether it is in forest A or B, and the +/// duplicate Call is removed. +/// +/// [External(foo), Call(foo)] +/// + +/// [Block(foo), Call(foo)] +/// = +/// [Block(foo), Call(foo)] +/// + +/// [External(foo), Call(foo)] +/// = +/// [Block(foo), Call(foo)] +#[test] +fn mast_forest_merge_replace_external() { + let mut forest_a = MastForest::new(); + let id_foo_a = forest_a.add_external(block_foo().digest()).unwrap(); + let id_call_a = forest_a.add_call(id_foo_a).unwrap(); + forest_a.make_root(id_call_a); + + let mut forest_b = MastForest::new(); + let id_foo_b = forest_b.add_node(block_foo()).unwrap(); + let id_call_b = forest_b.add_call(id_foo_b).unwrap(); + forest_b.make_root(id_call_b); + + let (merged_ab, root_maps_ab) = MastForest::merge([&forest_a, &forest_b]).unwrap(); + let (merged_ba, root_maps_ba) = MastForest::merge([&forest_b, &forest_a]).unwrap(); + + for (merged, root_map) in [(merged_ab, root_maps_ab), (merged_ba, root_maps_ba)] { + assert_eq!(merged.nodes().len(), 2); + assert_eq!(merged.nodes()[0], block_foo()); + assert_matches!(&merged.nodes()[1], MastNode::Call(call_node) if call_node.callee().as_u32() == 0); + // The only root node should be the call node. + assert_eq!(merged.roots.len(), 1); + assert_eq!(root_map.map_root(0, &id_call_a).unwrap().as_usize(), 1); + assert_eq!(root_map.map_root(1, &id_call_b).unwrap().as_usize(), 1); + assert_child_id_lt_parent_id(&merged).unwrap(); + } +} + +/// Test that roots are preserved and deduplicated if appropriate. +/// +/// Nodes: [Block(foo), Call(foo)] +/// Roots: [Call(foo)] +/// + +/// Nodes: [Block(foo), Block(bar), Call(foo)] +/// Roots: [Block(bar), Call(foo)] +/// = +/// Nodes: [Block(foo), Block(bar), Call(foo)] +/// Roots: [Block(bar), Call(foo)] +#[test] +fn mast_forest_merge_roots() { + let mut forest_a = MastForest::new(); + let id_foo_a = forest_a.add_node(block_foo()).unwrap(); + let call_a = forest_a.add_call(id_foo_a).unwrap(); + forest_a.make_root(call_a); + + let mut forest_b = MastForest::new(); + let id_foo_b = forest_b.add_node(block_foo()).unwrap(); + let id_bar_b = forest_b.add_node(block_bar()).unwrap(); + let call_b = forest_b.add_call(id_foo_b).unwrap(); + forest_b.make_root(id_bar_b); + forest_b.make_root(call_b); + + let root_digest_call_a = forest_a.get_node_by_id(call_a).unwrap().digest(); + let root_digest_bar_b = forest_b.get_node_by_id(id_bar_b).unwrap().digest(); + let root_digest_call_b = forest_b.get_node_by_id(call_b).unwrap().digest(); + + let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap(); + + // Asserts (together with the other assertions) that the duplicate Call(foo) roots have been + // deduplicated. + assert_eq!(merged.procedure_roots().len(), 2); + + // Assert that all root digests from A an B are still roots in the merged forest. + let root_digests = merged.procedure_digests().collect::>(); + assert!(root_digests.contains(&root_digest_call_a)); + assert!(root_digests.contains(&root_digest_bar_b)); + assert!(root_digests.contains(&root_digest_call_b)); + + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots).unwrap(); + + assert_child_id_lt_parent_id(&merged).unwrap(); +} + +/// Test that multiple trees can be merged when the same merger is reused. +/// +/// Nodes: [Block(foo), Call(foo)] +/// Roots: [Call(foo)] +/// + +/// Nodes: [Block(foo), Block(bar), Call(foo)] +/// Roots: [Block(bar), Call(foo)] +/// + +/// Nodes: [Block(foo), Block(qux), Call(foo)] +/// Roots: [Block(qux), Call(foo)] +/// = +/// Nodes: [Block(foo), Block(bar), Block(qux), Call(foo)] +/// Roots: [Block(bar), Block(qux), Call(foo)] +#[test] +fn mast_forest_merge_multiple() { + let mut forest_a = MastForest::new(); + let id_foo_a = forest_a.add_node(block_foo()).unwrap(); + let call_a = forest_a.add_call(id_foo_a).unwrap(); + forest_a.make_root(call_a); + + let mut forest_b = MastForest::new(); + let id_foo_b = forest_b.add_node(block_foo()).unwrap(); + let id_bar_b = forest_b.add_node(block_bar()).unwrap(); + let call_b = forest_b.add_call(id_foo_b).unwrap(); + forest_b.make_root(id_bar_b); + forest_b.make_root(call_b); + + let mut forest_c = MastForest::new(); + let id_foo_c = forest_c.add_node(block_foo()).unwrap(); + let id_qux_c = forest_c.add_node(block_qux()).unwrap(); + let call_c = forest_c.add_call(id_foo_c).unwrap(); + forest_c.make_root(id_qux_c); + forest_c.make_root(call_c); + + let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b, &forest_c]).unwrap(); + + let block_foo_digest = forest_b.get_node_by_id(id_foo_b).unwrap().digest(); + let block_bar_digest = forest_b.get_node_by_id(id_bar_b).unwrap().digest(); + let call_foo_digest = forest_b.get_node_by_id(call_b).unwrap().digest(); + let block_qux_digest = forest_c.get_node_by_id(id_qux_c).unwrap().digest(); + + assert_eq!(merged.procedure_roots().len(), 3); + + let root_digests = merged.procedure_digests().collect::>(); + assert!(root_digests.contains(&call_foo_digest)); + assert!(root_digests.contains(&block_bar_digest)); + assert!(root_digests.contains(&block_qux_digest)); + + assert_contains_node_once(&merged, block_foo_digest).unwrap(); + assert_contains_node_once(&merged, block_bar_digest).unwrap(); + assert_contains_node_once(&merged, block_qux_digest).unwrap(); + assert_contains_node_once(&merged, call_foo_digest).unwrap(); + + assert_root_mapping( + &root_maps, + vec![&forest_a.roots, &forest_b.roots, &forest_c.roots], + &merged.roots, + ) + .unwrap(); + + assert_child_id_lt_parent_id(&merged).unwrap(); +} + +/// Tests that decorators are merged and that nodes who are identical except for their +/// decorators are not deduplicated. +/// +/// Note in particular that the `Loop` nodes only differ in their decorator which ensures that +/// the merging takes decorators into account. +/// +/// Nodes: [Block(foo, [Trace(1), Trace(2)]), Loop(foo, [Trace(0), Trace(2)])] +/// Decorators: [Trace(0), Trace(1), Trace(2)] +/// + +/// Nodes: [Block(foo, [Trace(1), Trace(2)]), Loop(foo, [Trace(1), Trace(3)])] +/// Decorators: [Trace(1), Trace(2), Trace(3)] +/// = +/// Nodes: [ +/// Block(foo, [Trace(1), Trace(2)]), +/// Loop(foo, [Trace(0), Trace(2)]), +/// Loop(foo, [Trace(1), Trace(3)]), +/// ] +/// Decorators: [Trace(0), Trace(1), Trace(2), Trace(3)] +#[test] +fn mast_forest_merge_decorators() { + let mut forest_a = MastForest::new(); + let trace0 = Decorator::Trace(0); + let trace1 = Decorator::Trace(1); + let trace2 = Decorator::Trace(2); + let trace3 = Decorator::Trace(3); + + // Build Forest A + let deco0_a = forest_a.add_decorator(trace0.clone()).unwrap(); + let deco1_a = forest_a.add_decorator(trace1.clone()).unwrap(); + let deco2_a = forest_a.add_decorator(trace2.clone()).unwrap(); + + let mut foo_node_a = block_foo(); + foo_node_a.set_before_enter(vec![deco1_a, deco2_a]); + let id_foo_a = forest_a.add_node(foo_node_a).unwrap(); + + let mut loop_node_a = MastNode::new_loop(id_foo_a, &forest_a).unwrap(); + loop_node_a.set_after_exit(vec![deco0_a, deco2_a]); + let id_loop_a = forest_a.add_node(loop_node_a).unwrap(); + + forest_a.make_root(id_loop_a); + + // Build Forest B + let mut forest_b = MastForest::new(); + let deco1_b = forest_b.add_decorator(trace1.clone()).unwrap(); + let deco2_b = forest_b.add_decorator(trace2.clone()).unwrap(); + let deco3_b = forest_b.add_decorator(trace3.clone()).unwrap(); + + // This foo node is identical to the one in A, including its decorators. + let mut foo_node_b = block_foo(); + foo_node_b.set_before_enter(vec![deco1_b, deco2_b]); + let id_foo_b = forest_b.add_node(foo_node_b).unwrap(); + + // This loop node's decorators are different from the loop node in a. + let mut loop_node_b = MastNode::new_loop(id_foo_b, &forest_b).unwrap(); + loop_node_b.set_after_exit(vec![deco1_b, deco3_b]); + let id_loop_b = forest_b.add_node(loop_node_b).unwrap(); + + forest_b.make_root(id_loop_b); + + let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap(); + + // There are 4 unique decorators across both forests. + assert_eq!(merged.decorators.len(), 4); + assert!(merged.decorators.contains(&trace0)); + assert!(merged.decorators.contains(&trace1)); + assert!(merged.decorators.contains(&trace2)); + assert!(merged.decorators.contains(&trace3)); + + let find_decorator_id = |deco: &Decorator| { + let idx = merged + .decorators + .iter() + .enumerate() + .find_map( + |(deco_id, forest_deco)| if forest_deco == deco { Some(deco_id) } else { None }, + ) + .unwrap(); + DecoratorId::from_u32_safe(idx as u32, &merged).unwrap() + }; + + let merged_deco0 = find_decorator_id(&trace0); + let merged_deco1 = find_decorator_id(&trace1); + let merged_deco2 = find_decorator_id(&trace2); + let merged_deco3 = find_decorator_id(&trace3); + + assert_eq!(merged.nodes.len(), 3); + + let merged_foo_block = merged.nodes.iter().find(|node| node.is_basic_block()).unwrap(); + let MastNode::Block(merged_foo_block) = merged_foo_block else { + panic!("expected basic block node"); + }; + + assert_eq!( + merged_foo_block.decorators().as_slice(), + &[(0, merged_deco1), (0, merged_deco2)] + ); + + // Asserts that there exists exactly one Loop Node with the given decorators. + assert_eq!( + merged + .nodes + .iter() + .filter(|node| { + if let MastNode::Loop(loop_node) = node { + loop_node.after_exit() == [merged_deco0, merged_deco2] + } else { + false + } + }) + .count(), + 1 + ); + + // Asserts that there exists exactly one Loop Node with the given decorators. + assert_eq!( + merged + .nodes + .iter() + .filter(|node| { + if let MastNode::Loop(loop_node) = node { + loop_node.after_exit() == [merged_deco1, merged_deco3] + } else { + false + } + }) + .count(), + 1 + ); + + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots).unwrap(); + + assert_child_id_lt_parent_id(&merged).unwrap(); +} + +/// Tests that an external node without decorators is replaced by its referenced node which has +/// decorators. +/// +/// [External(foo)] +/// + +/// [Block(foo, Trace(1))] +/// = +/// [Block(foo, Trace(1))] +/// + +/// [External(foo)] +/// = +/// [Block(foo, Trace(1))] +#[test] +fn mast_forest_merge_external_node_reference_with_decorator() { + let mut forest_a = MastForest::new(); + let trace = Decorator::Trace(1); + + // Build Forest A + let deco = forest_a.add_decorator(trace.clone()).unwrap(); + + let mut foo_node_a = block_foo(); + foo_node_a.set_before_enter(vec![deco]); + let foo_node_digest = foo_node_a.digest(); + let id_foo_a = forest_a.add_node(foo_node_a).unwrap(); + + forest_a.make_root(id_foo_a); + + // Build Forest B + let mut forest_b = MastForest::new(); + let id_external_b = forest_b.add_external(foo_node_digest).unwrap(); + + forest_b.make_root(id_external_b); + + for (idx, (merged, root_maps)) in [ + MastForest::merge([&forest_a, &forest_b]).unwrap(), + MastForest::merge([&forest_b, &forest_a]).unwrap(), + ] + .into_iter() + .enumerate() + { + let id_foo_a_fingerprint = + EqHash::from_mast_node(&forest_a, &BTreeMap::new(), &forest_a[id_foo_a]); + + let fingerprints: Vec<_> = merged + .nodes() + .iter() + .map(|node| EqHash::from_mast_node(&merged, &BTreeMap::new(), node)) + .collect(); + + assert_eq!(merged.nodes.len(), 1); + assert!(fingerprints.contains(&id_foo_a_fingerprint)); + + if idx == 0 { + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots) + .unwrap(); + } else { + assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots) + .unwrap(); + } + + assert_child_id_lt_parent_id(&merged).unwrap(); + } +} + +/// Tests that an external node with decorators is replaced by its referenced node which does not +/// have decorators. +/// +/// [External(foo, Trace(1), Trace(2))] +/// + +/// [Block(foo)] +/// = +/// [Block(foo)] +/// + +/// [External(foo, Trace(1), Trace(2))] +/// = +/// [Block(foo)] +#[test] +fn mast_forest_merge_external_node_with_decorator() { + let mut forest_a = MastForest::new(); + let trace1 = Decorator::Trace(1); + let trace2 = Decorator::Trace(2); + + // Build Forest A + let deco1 = forest_a.add_decorator(trace1.clone()).unwrap(); + let deco2 = forest_a.add_decorator(trace2.clone()).unwrap(); + + let mut external_node_a = MastNode::new_external(block_foo().digest()); + external_node_a.set_before_enter(vec![deco1]); + external_node_a.set_after_exit(vec![deco2]); + let id_external_a = forest_a.add_node(external_node_a).unwrap(); + + forest_a.make_root(id_external_a); + + // Build Forest B + let mut forest_b = MastForest::new(); + let id_foo_b = forest_b.add_node(block_foo()).unwrap(); + + forest_b.make_root(id_foo_b); + + for (idx, (merged, root_maps)) in [ + MastForest::merge([&forest_a, &forest_b]).unwrap(), + MastForest::merge([&forest_b, &forest_a]).unwrap(), + ] + .into_iter() + .enumerate() + { + assert_eq!(merged.nodes.len(), 1); + + let id_foo_b_fingerprint = + EqHash::from_mast_node(&forest_a, &BTreeMap::new(), &forest_b[id_foo_b]); + + let fingerprints: Vec<_> = merged + .nodes() + .iter() + .map(|node| EqHash::from_mast_node(&merged, &BTreeMap::new(), node)) + .collect(); + + // Block foo should be unmodified. + assert!(fingerprints.contains(&id_foo_b_fingerprint)); + + if idx == 0 { + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots) + .unwrap(); + } else { + assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots) + .unwrap(); + } + + assert_child_id_lt_parent_id(&merged).unwrap(); + } +} + +/// Tests that an external node with decorators is replaced by its referenced node which also has +/// decorators. +/// +/// [External(foo, Trace(1))] +/// + +/// [Block(foo, Trace(2))] +/// = +/// [Block(foo, Trace(2))] +/// + +/// [External(foo, Trace(1))] +/// = +/// [Block(foo, Trace(2))] +#[test] +fn mast_forest_merge_external_node_and_referenced_node_have_decorators() { + let mut forest_a = MastForest::new(); + let trace1 = Decorator::Trace(1); + let trace2 = Decorator::Trace(2); + + // Build Forest A + let deco1_a = forest_a.add_decorator(trace1.clone()).unwrap(); + + let mut external_node_a = MastNode::new_external(block_foo().digest()); + external_node_a.set_before_enter(vec![deco1_a]); + let id_external_a = forest_a.add_node(external_node_a).unwrap(); + + forest_a.make_root(id_external_a); + + // Build Forest B + let mut forest_b = MastForest::new(); + let deco2_b = forest_b.add_decorator(trace2.clone()).unwrap(); + + let mut foo_node_b = block_foo(); + foo_node_b.set_before_enter(vec![deco2_b]); + let id_foo_b = forest_b.add_node(foo_node_b).unwrap(); + + forest_b.make_root(id_foo_b); + + for (idx, (merged, root_maps)) in [ + MastForest::merge([&forest_a, &forest_b]).unwrap(), + MastForest::merge([&forest_b, &forest_a]).unwrap(), + ] + .into_iter() + .enumerate() + { + assert_eq!(merged.nodes.len(), 1); + + let id_foo_b_fingerprint = + EqHash::from_mast_node(&forest_b, &BTreeMap::new(), &forest_b[id_foo_b]); + + let fingerprints: Vec<_> = merged + .nodes() + .iter() + .map(|node| EqHash::from_mast_node(&merged, &BTreeMap::new(), node)) + .collect(); + + // Block foo should be unmodified. + assert!(fingerprints.contains(&id_foo_b_fingerprint)); + + if idx == 0 { + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots) + .unwrap(); + } else { + assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots) + .unwrap(); + } + + assert_child_id_lt_parent_id(&merged).unwrap(); + } +} + +/// Tests that two external nodes with the same MAST root are deduplicated during merging and then +/// replaced by a block with the matching digest. +/// +/// [External(foo, Trace(1), Trace(2)), +/// External(foo, Trace(1))] +/// + +/// [Block(foo, Trace(1))] +/// = +/// [Block(foo, Trace(1))] +/// + +/// [External(foo, Trace(1), Trace(2)), +/// External(foo, Trace(1))] +/// = +/// [Block(foo, Trace(1))] +#[test] +fn mast_forest_merge_multiple_external_nodes_with_decorator() { + let mut forest_a = MastForest::new(); + let trace1 = Decorator::Trace(1); + let trace2 = Decorator::Trace(2); + + // Build Forest A + let deco1_a = forest_a.add_decorator(trace1.clone()).unwrap(); + let deco2_a = forest_a.add_decorator(trace2.clone()).unwrap(); + + let mut external_node_a = MastNode::new_external(block_foo().digest()); + external_node_a.set_before_enter(vec![deco1_a]); + external_node_a.set_after_exit(vec![deco2_a]); + let id_external_a = forest_a.add_node(external_node_a).unwrap(); + + let mut external_node_b = MastNode::new_external(block_foo().digest()); + external_node_b.set_before_enter(vec![deco1_a]); + let id_external_b = forest_a.add_node(external_node_b).unwrap(); + + forest_a.make_root(id_external_a); + forest_a.make_root(id_external_b); + + // Build Forest B + let mut forest_b = MastForest::new(); + let deco1_b = forest_b.add_decorator(trace1).unwrap(); + let mut block_foo_b = block_foo(); + block_foo_b.set_before_enter(vec![deco1_b]); + let id_foo_b = forest_b.add_node(block_foo_b).unwrap(); + + forest_b.make_root(id_foo_b); + + for (idx, (merged, root_maps)) in [ + MastForest::merge([&forest_a, &forest_b]).unwrap(), + MastForest::merge([&forest_b, &forest_a]).unwrap(), + ] + .into_iter() + .enumerate() + { + assert_eq!(merged.nodes.len(), 1); + + let id_foo_b_fingerprint = + EqHash::from_mast_node(&forest_a, &BTreeMap::new(), &forest_b[id_foo_b]); + + let fingerprints: Vec<_> = merged + .nodes() + .iter() + .map(|node| EqHash::from_mast_node(&merged, &BTreeMap::new(), node)) + .collect(); + + // Block foo should be unmodified. + assert!(fingerprints.contains(&id_foo_b_fingerprint)); + + if idx == 0 { + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots) + .unwrap(); + } else { + assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots) + .unwrap(); + } + + assert_child_id_lt_parent_id(&merged).unwrap(); + } +} + +/// Tests that dependencies between External nodes are correctly resolved. +/// +/// [External(foo), Call(0) = qux] +/// + +/// [External(qux), Call(0), Block(foo)] +/// = +/// [External(qux), Call(0), Block(foo)] +/// + +/// [External(foo), Call(0) = qux] +/// = +/// [Block(foo), Call(0), Call(1)] +#[test] +fn mast_forest_merge_external_dependencies() { + let mut forest_a = MastForest::new(); + let id_foo_a = forest_a.add_external(block_qux().digest()).unwrap(); + let id_call_a = forest_a.add_call(id_foo_a).unwrap(); + forest_a.make_root(id_call_a); + + let mut forest_b = MastForest::new(); + let id_ext_b = forest_b.add_external(forest_a[id_call_a].digest()).unwrap(); + let id_call_b = forest_b.add_call(id_ext_b).unwrap(); + let id_qux_b = forest_b.add_node(block_qux()).unwrap(); + forest_b.make_root(id_call_b); + forest_b.make_root(id_qux_b); + + for (merged, _) in [ + MastForest::merge([&forest_a, &forest_b]).unwrap(), + MastForest::merge([&forest_b, &forest_a]).unwrap(), + ] + .into_iter() + { + let digests = merged.nodes().iter().map(|node| node.digest()).collect::>(); + assert_eq!(merged.nodes().len(), 3); + assert!(digests.contains(&forest_b[id_ext_b].digest())); + assert!(digests.contains(&forest_b[id_call_b].digest())); + assert!(digests.contains(&forest_a[id_foo_a].digest())); + assert!(digests.contains(&forest_a[id_call_a].digest())); + assert!(digests.contains(&forest_b[id_qux_b].digest())); + assert_eq!(merged.nodes().iter().filter(|node| node.is_external()).count(), 0); + + assert_child_id_lt_parent_id(&merged).unwrap(); + } +} + +/// Tests that a forest with nodes who reference non-existent decorators return an error during +/// merging and does not panic. +#[test] +fn mast_forest_merge_invalid_decorator_index() { + let trace1 = Decorator::Trace(1); + let trace2 = Decorator::Trace(2); + + // Build Forest A + let mut forest_a = MastForest::new(); + let deco1_a = forest_a.add_decorator(trace1.clone()).unwrap(); + let deco2_a = forest_a.add_decorator(trace2.clone()).unwrap(); + let id_bar_a = forest_a.add_node(block_bar()).unwrap(); + + forest_a.make_root(id_bar_a); + + // Build Forest B + let mut forest_b = MastForest::new(); + let mut block_b = block_foo(); + // We're using a DecoratorId from forest A which is invalid. + block_b.set_before_enter(vec![deco1_a, deco2_a]); + let id_foo_b = forest_b.add_node(block_b).unwrap(); + + forest_b.make_root(id_foo_b); + + let err = MastForest::merge([&forest_a, &forest_b]).unwrap_err(); + assert_matches!(err, MastForestError::DecoratorIdOverflow(_, _)); +} diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index bc9cc8e3b5..a105ae996f 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -7,7 +7,11 @@ use core::{ ops::{Index, IndexMut}, }; -use miden_crypto::hash::rpo::RpoDigest; +use miden_crypto::hash::{ + blake::{Blake3Digest, Blake3_256}, + rpo::RpoDigest, + Digest, +}; mod node; pub use node::{ @@ -20,6 +24,13 @@ use crate::{Decorator, DecoratorList, Operation}; mod serialization; +mod merger; +pub(crate) use merger::MastForestMerger; +pub use merger::MastForestRootMap; + +mod multi_forest_node_iterator; +pub(crate) use multi_forest_node_iterator::*; + #[cfg(test)] mod tests; @@ -191,6 +202,61 @@ impl MastForest { self[node_id].set_after_exit(decorator_ids) } + /// Merges all `forests` into a new [`MastForest`]. + /// + /// Merging two forests means combining all their constituent parts, i.e. [`MastNode`]s, + /// [`Decorator`]s and roots. During this process, any duplicate or + /// unreachable nodes are removed. Additionally, [`MastNodeId`]s of nodes as well as + /// [`DecoratorId`]s of decorators may change and references to them are remapped to their new + /// location. + /// + /// For example, consider this representation of a forest's nodes with all of these nodes being + /// roots: + /// + /// ```text + /// [Block(foo), Block(bar)] + /// ``` + /// + /// If we merge another forest into it: + /// + /// ```text + /// [Block(bar), Call(0)] + /// ``` + /// + /// then we would expect this forest: + /// + /// ```text + /// [Block(foo), Block(bar), Call(1)] + /// ``` + /// + /// - The `Call` to the `bar` block was remapped to its new index (now 1, previously 0). + /// - The `Block(bar)` was deduplicated any only exists once in the merged forest. + /// + /// The function also returns a vector of [`MastForestRootMap`]s, whose length equals the number + /// of passed `forests`. The indices in the vector correspond to the ones in `forests`. The map + /// of a given forest contains the new locations of its roots in the merged forest. To + /// illustrate, the above example would return a vector of two maps: + /// + /// ```text + /// vec![{0 -> 0, 1 -> 1} + /// {0 -> 1, 1 -> 2}] + /// ``` + /// + /// - The root locations of the original forest are unchanged. + /// - For the second forest, the `bar` block has moved from index 0 to index 1 in the merged + /// forest, and the `Call` has moved from index 1 to 2. + /// + /// If any forest being merged contains an `External(qux)` node and another forest contains a + /// node whose digest is `qux`, then the external node will be replaced with the `qux` node, + /// which is effectively deduplication. Decorators are ignored when it comes to merging + /// External nodes. This means that an External node with decorators may be replaced by a node + /// without decorators or vice versa. + pub fn merge<'forest>( + forests: impl IntoIterator, + ) -> Result<(MastForest, MastForestRootMap), MastForestError> { + MastForestMerger::merge(forests) + } + /// Adds a basic block node to the forest, and returns the [`MastNodeId`] associated with it. /// /// It is assumed that the decorators have not already been added to the MAST forest. If they @@ -466,6 +532,11 @@ impl MastNodeId { } } + /// Returns a new [`MastNodeId`] from the given `value` without checking its validity. + pub(crate) fn new_unchecked(value: u32) -> Self { + Self(value) + } + pub fn as_usize(&self) -> usize { self.0 as usize } @@ -527,6 +598,11 @@ impl DecoratorId { } } + /// Creates a new [`DecoratorId`] without checking its validity. + pub(crate) fn new_unchecked(value: u32) -> Self { + Self(value) + } + pub fn as_usize(&self) -> usize { self.0 as usize } @@ -566,6 +642,166 @@ impl Serializable for DecoratorId { } } +// MAST NODE EQUALITY +// ================================================================================================ + +/// Represents the hash used to test for equality between [`MastNode`]s. +/// +/// The decorator root will be `None` if and only if there are no decorators attached to the node, +/// and all children have no decorator roots (meaning that there are no decorators in all the +/// descendants). +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct EqHash { + mast_root: RpoDigest, + decorator_root: Option>, +} + +// TODO: Document public functions and assumptions about forest and the index map. +impl EqHash { + pub fn new(mast_root: RpoDigest) -> Self { + Self { mast_root, decorator_root: None } + } + + pub fn with_decorator_root(mast_root: RpoDigest, decorator_root: Blake3Digest<32>) -> Self { + Self { + mast_root, + decorator_root: Some(decorator_root), + } + } + + pub fn from_mast_node( + forest: &MastForest, + hash_by_node_id: &BTreeMap, + node: &MastNode, + ) -> EqHash { + match node { + MastNode::Block(node) => { + let mut bytes_to_hash = Vec::new(); + + for &(idx, decorator_id) in node.decorators() { + bytes_to_hash.extend(idx.to_le_bytes()); + bytes_to_hash.extend(forest[decorator_id].eq_hash().as_bytes()); + } + + // Add any `Assert` or `U32assert2` opcodes present, since these are not included in + // the MAST root. + for (op_idx, op) in node.operations().enumerate() { + if let Operation::U32assert2(inner_value) + | Operation::Assert(inner_value) + | Operation::MpVerify(inner_value) = op + { + let op_idx: u32 = op_idx + .try_into() + .expect("there are more than 2^{32}-1 operations in basic block"); + + // we include the opcode to differentiate between `Assert` and `U32assert2` + bytes_to_hash.push(op.op_code()); + // we include the operation index to distinguish between basic blocks that + // would have the same assert instructions, but in a different order + bytes_to_hash.extend(op_idx.to_le_bytes()); + bytes_to_hash.extend(inner_value.to_le_bytes()); + } + } + + if bytes_to_hash.is_empty() { + EqHash::new(node.digest()) + } else { + let decorator_root = Blake3_256::hash(&bytes_to_hash); + EqHash::with_decorator_root(node.digest(), decorator_root) + } + }, + MastNode::Join(node) => eq_hash_from_parts( + forest, + hash_by_node_id, + node.before_enter(), + node.after_exit(), + &[node.first(), node.second()], + node.digest(), + ), + MastNode::Split(node) => eq_hash_from_parts( + forest, + hash_by_node_id, + node.before_enter(), + node.after_exit(), + &[node.on_true(), node.on_false()], + node.digest(), + ), + MastNode::Loop(node) => eq_hash_from_parts( + forest, + hash_by_node_id, + node.before_enter(), + node.after_exit(), + &[node.body()], + node.digest(), + ), + MastNode::Call(node) => eq_hash_from_parts( + forest, + hash_by_node_id, + node.before_enter(), + node.after_exit(), + &[node.callee()], + node.digest(), + ), + MastNode::Dyn(node) => eq_hash_from_parts( + forest, + hash_by_node_id, + node.before_enter(), + node.after_exit(), + &[], + node.digest(), + ), + MastNode::External(node) => eq_hash_from_parts( + forest, + hash_by_node_id, + node.before_enter(), + node.after_exit(), + &[], + node.digest(), + ), + } + } +} + +fn eq_hash_from_parts( + forest: &MastForest, + hash_by_node_id: &BTreeMap, + before_enter_ids: &[DecoratorId], + after_exit_ids: &[DecoratorId], + children_ids: &[MastNodeId], + node_digest: RpoDigest, +) -> EqHash { + let pre_decorator_hash_bytes = + before_enter_ids.iter().flat_map(|&id| forest[id].eq_hash().as_bytes()); + let post_decorator_hash_bytes = + after_exit_ids.iter().flat_map(|&id| forest[id].eq_hash().as_bytes()); + + // Reminder: the `EqHash`'s decorator root will be `None` if and only if there are no + // decorators attached to the node, and all children have no decorator roots (meaning that + // there are no decorators in all the descendants). + if pre_decorator_hash_bytes.clone().next().is_none() + && post_decorator_hash_bytes.clone().next().is_none() + && children_ids + .iter() + .filter_map(|child_id| hash_by_node_id[child_id].decorator_root) + .next() + .is_none() + { + EqHash::new(node_digest) + } else { + let children_decorator_roots = children_ids + .iter() + .filter_map(|child_id| hash_by_node_id[child_id].decorator_root) + .flat_map(|decorator_root| decorator_root.as_bytes()); + let decorator_bytes_to_hash: Vec = pre_decorator_hash_bytes + .chain(post_decorator_hash_bytes) + .chain(children_decorator_roots) + .collect(); + + let decorator_root = Blake3_256::hash(&decorator_bytes_to_hash); + EqHash::with_decorator_root(node_digest, decorator_root) + } +} + // MAST FOREST ERROR // ================================================================================================ @@ -584,6 +820,8 @@ pub enum MastForestError { TooManyNodes, #[error("node id: {0} is greater than or equal to forest length: {1}")] NodeIdOverflow(MastNodeId, usize), + #[error("decorator id: {0} is greater than or equal to decorator count: {1}")] + DecoratorIdOverflow(DecoratorId, usize), #[error("basic block cannot be created from an empty list of operations")] EmptyBasicBlock, } diff --git a/core/src/mast/multi_forest_node_iterator.rs b/core/src/mast/multi_forest_node_iterator.rs new file mode 100644 index 0000000000..be5ce1b5d2 --- /dev/null +++ b/core/src/mast/multi_forest_node_iterator.rs @@ -0,0 +1,490 @@ +use alloc::{ + collections::{BTreeMap, VecDeque}, + vec::Vec, +}; + +use miden_crypto::hash::rpo::RpoDigest; + +use crate::mast::{MastForest, MastForestError, MastNode, MastNodeId}; + +type ForestIndex = usize; + +/// Depth First Search Iterator in Post Order for [`MastForest`]s. +/// +/// This iterator iterates through all **reachable** nodes of all given forests exactly once. +/// +/// Since a `MastForest` has multiple possible entrypoints in the form of its roots, a depth-first +/// search must visit all of those roots and the trees they form. This iterator's `Item` is +/// [`MultiMastForestIteratorItem`]. It contains either a [`MultiMastForestIteratorItem::Node`] of a +/// forest, or the replacement of an external node. This is returned if one forest contains an +/// External node with digest `foo` and another forest contains a non-external node with digest +/// `foo`. In such a case the `foo` node is yielded first (unless it was already visited) and +/// subsequently a "replacement signal" ([`MultiMastForestIteratorItem::ExternalNodeReplacement`]) +/// for the external node is yielded to make the caller aware that this replacement has happened. +/// +/// All of this is useful to ensure that children are always processed before their parents, even if +/// a child is an External node which is replaced by a node in another forest. This guarantees that +/// **all [`MastNodeId`]s of child nodes are strictly less than the [`MastNodeId`] of their +/// parents**. +/// +/// For instance, consider these `MastForest`s being passed to this iterator with the `Call(0)`'s +/// digest being `qux`: +/// +/// ```text +/// Forest A Nodes: [Block(foo), External(qux), Join(0, 1)] +/// Forest A Roots: [2] +/// Forest B Nodes: [Block(bar), Call(0)] +/// Forest B Roots: [0] +/// ``` +/// +/// The only root of A is the `Join` node at index 2. The first three nodes of the forest form a +/// tree, since the `Join` node references index 0 and 1. This tree is discovered by +/// starting at the root at index 2 and following all children until we reach terminal nodes (like +/// `Block`s) and building up a deque of the discovered nodes. The special case here is the +/// `External` node whose digest matches that of a node in forest B. Instead of the External +/// node being added to the deque, the tree of the Call node is added instead. The deque is built +/// such that popping elements off the deque (from the front) yields a postorder. +/// +/// After the first tree is discovered, the deque looks like this: +/// +/// ```text +/// [Node(forest_idx: 0, node_id: 0), +/// Node(forest_idx: 1, node_id: 0), +/// Node(forest_idx: 1, node_id: 1), +/// ExternalNodeReplacement( +/// replacement_forest_idx: 1, replacement_node_id: 1 +/// replaced_forest_idx: 0, replaced_node_id: 1 +/// ), +/// Node(forest_idx: 0, node_id: 2)] +/// ``` +/// +/// If the deque is exhausted we start another discovery if more undiscovered roots exist. In this +/// example, the root of forest B was already discovered and visited due to the External node +/// reference, so the iteration is complete. +/// +/// The iteration on a higher level thus consists of a back and forth between discovering trees and +/// returning nodes from the deque. +pub(crate) struct MultiMastForestNodeIter<'forest> { + /// The forests that we're iterating. + mast_forests: Vec<&'forest MastForest>, + /// The index of the forest we're currently processing and discovering trees in. + /// + /// This value iterates through 0..mast_forests.len() which guarantees that we visit all + /// forests once. + current_forest_idx: ForestIndex, + /// The procedure root index at which we last started a tree discovery in the + /// current_forest_idx. + /// + /// This value iterates through 0..mast_forests[current_forest_idx].num_procedures() which + /// guarantees that we visit all nodes reachable from all roots. + current_procedure_root_idx: u32, + /// A map of MAST roots of all non-external nodes in mast_forests to their forest and node + /// indices. + non_external_nodes: BTreeMap, + /// Describes whether the node identified by [forest_index][node_index] has already been + /// discovered. Note that this is `true` for all nodes that are in the unvisited node deque. + discovered_nodes: Vec>, + /// This deque always contains the discovered, but unvisited nodes. + /// It holds that discovered_nodes[forest_idx][node_id] = true for all elements in this deque. + unvisited_nodes: VecDeque, +} + +impl<'forest> MultiMastForestNodeIter<'forest> { + /// Builds a map of MAST roots to non-external nodes in any of the given forests to initialize + /// the iterator. This enables an efficient check whether for any encountered External node + /// referencing digest `foo` a node with digest `foo` already exists in any forest. + pub(crate) fn new(mast_forests: Vec<&'forest MastForest>) -> Self { + let discovered_nodes = mast_forests + .iter() + .map(|forest| vec![false; forest.num_nodes() as usize]) + .collect(); + + let mut non_external_nodes = BTreeMap::new(); + + for (forest_idx, forest) in mast_forests.iter().enumerate() { + for (node_idx, node) in forest.nodes().iter().enumerate() { + // SAFETY: The passed id comes from the iterator over the nodes, so we never exceed + // the forest's number of nodes. + let node_id = MastNodeId::new_unchecked(node_idx as u32); + if !node.is_external() { + non_external_nodes.insert(node.digest(), (forest_idx, node_id)); + } + } + } + + Self { + mast_forests, + current_forest_idx: 0, + current_procedure_root_idx: 0, + non_external_nodes, + discovered_nodes, + unvisited_nodes: VecDeque::new(), + } + } + + /// Pushes the given node, uniquely identified by the forest and node index onto the deque + /// even if the node was already discovered before. + /// + /// It's the callers responsibility to only pass valid indices. + fn push_node(&mut self, forest_idx: usize, node_id: MastNodeId) { + self.unvisited_nodes + .push_back(MultiMastForestIteratorItem::Node { forest_idx, node_id }); + self.discovered_nodes[forest_idx][node_id.as_usize()] = true; + } + + /// Discovers a tree starting at the given forest index and node id. + /// + /// SAFETY: We only pass valid forest and node indices so we can index directly in this + /// function. + fn discover_tree( + &mut self, + forest_idx: ForestIndex, + node_id: MastNodeId, + ) -> Result<(), MastForestError> { + if self.discovered_nodes[forest_idx][node_id.as_usize()] { + return Ok(()); + } + + let current_node = + &self.mast_forests[forest_idx].nodes.get(node_id.as_usize()).ok_or_else(|| { + MastForestError::NodeIdOverflow( + node_id, + self.mast_forests[forest_idx].num_nodes() as usize, + ) + })?; + + // Note that we can process nodes in postorder, since we push them onto the back of the + // deque but pop them off the front. + match current_node { + MastNode::Block(_) => { + self.push_node(forest_idx, node_id); + }, + MastNode::Join(join_node) => { + self.discover_tree(forest_idx, join_node.first())?; + self.discover_tree(forest_idx, join_node.second())?; + self.push_node(forest_idx, node_id); + }, + MastNode::Split(split_node) => { + self.discover_tree(forest_idx, split_node.on_true())?; + self.discover_tree(forest_idx, split_node.on_false())?; + self.push_node(forest_idx, node_id); + }, + MastNode::Loop(loop_node) => { + self.discover_tree(forest_idx, loop_node.body())?; + self.push_node(forest_idx, node_id); + }, + MastNode::Call(call_node) => { + self.discover_tree(forest_idx, call_node.callee())?; + self.push_node(forest_idx, node_id); + }, + MastNode::Dyn(_) => { + self.push_node(forest_idx, node_id); + }, + MastNode::External(external_node) => { + // When we encounter an external node referencing digest `foo` there are two cases: + // - If there exists a node `replacement` in any forest with digest `foo`, we want + // to replace the external node with that node, which we do in two steps. + // - Discover the `replacement`'s tree and add it to the deque. + // - If `replacement` was already discovered before, it won't actually be + // returned. + // - In any case this means: The `replacement` node is processed before the + // replacement signal we're adding next. + // - Add a replacement signal to the deque, signaling that the `replacement` + // replaced the external node. + // - If no replacement exists, yield the External Node as a regular `Node`. + if let Some((other_forest_idx, other_node_id)) = + self.non_external_nodes.get(&external_node.digest()).copied() + { + self.discover_tree(other_forest_idx, other_node_id)?; + + self.unvisited_nodes.push_back( + MultiMastForestIteratorItem::ExternalNodeReplacement { + replacement_forest_idx: other_forest_idx, + replacement_mast_node_id: other_node_id, + replaced_forest_idx: forest_idx, + replaced_mast_node_id: node_id, + }, + ); + + self.discovered_nodes[forest_idx][node_id.as_usize()] = true; + } else { + self.push_node(forest_idx, node_id); + } + }, + } + + Ok(()) + } + + /// Finds the next undiscovered procedure root and discovers a tree from it. + /// + /// If the undiscovered node deque is empty after calling this function, the iteration is + /// complete. + /// + /// This function basically consists of two loops: + /// - The outer loop iterates over all forest indices. + /// - The inner loop iterates over all procedure root indices for the current forest. + fn discover_nodes(&mut self) { + 'forest_loop: while self.current_forest_idx < self.mast_forests.len() + && self.unvisited_nodes.is_empty() + { + // If we don't have any forests, there is nothing to do. + if self.mast_forests.is_empty() { + return; + } + + // If the current forest doesn't have roots, advance to the next one. + if self.mast_forests[self.current_forest_idx].num_procedures() == 0 { + self.current_forest_idx += 1; + continue; + } + + let procedure_roots = self.mast_forests[self.current_forest_idx].procedure_roots(); + let discovered_nodes = &self.discovered_nodes[self.current_forest_idx]; + + // Find the next undiscovered procedure root for the current forest by incrementing the + // current procedure root until we find one that was not yet discovered. + while discovered_nodes + [procedure_roots[self.current_procedure_root_idx as usize].as_usize()] + { + // If we have reached the end of the procedure roots for the current forest, + // continue searching in the next forest. + if self.current_procedure_root_idx + 1 + >= self.mast_forests[self.current_forest_idx].num_procedures() + { + // Reset current procedure root. + self.current_procedure_root_idx = 0; + // Increment forest index. + self.current_forest_idx += 1; + + continue 'forest_loop; + } + + // Since the current procedure root was already discovered, check the next one. + self.current_procedure_root_idx += 1; + } + + // We exited the loop, so the current procedure root is undiscovered and so we can start + // a discovery from that root. Since that root is undiscovered, it is guaranteed that + // after this discovery the deque will be non-empty. + let procedure_root_id = procedure_roots[self.current_procedure_root_idx as usize]; + self.discover_tree(self.current_forest_idx, procedure_root_id) + .expect("we should only pass root indices that are valid for the forest"); + } + } +} + +impl Iterator for MultiMastForestNodeIter<'_> { + type Item = MultiMastForestIteratorItem; + + fn next(&mut self) -> Option { + if let Some(deque_item) = self.unvisited_nodes.pop_front() { + return Some(deque_item); + } + + self.discover_nodes(); + + if !self.unvisited_nodes.is_empty() { + self.next() + } else { + // If the deque is empty after tree discovery, all (reachable) nodes have been + // discovered and visited. + None + } + } +} + +/// The iterator item for [`MultiMastForestNodeIter`]. See its documentation for details. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum MultiMastForestIteratorItem { + /// A regular node discovered by the iterator. + Node { + forest_idx: ForestIndex, + node_id: MastNodeId, + }, + /// Signals a replacement of an external node by some other node. + ExternalNodeReplacement { + replacement_forest_idx: usize, + replacement_mast_node_id: MastNodeId, + replaced_forest_idx: usize, + replaced_mast_node_id: MastNodeId, + }, +} + +// TESTS +// ================================================================================================ + +#[cfg(test)] +mod tests { + use miden_crypto::hash::rpo::RpoDigest; + + use super::*; + use crate::Operation; + + fn random_digest() -> RpoDigest { + RpoDigest::new([rand_utils::rand_value(); 4]) + } + + #[test] + fn multi_mast_forest_dfs_empty() { + let forest = MastForest::new(); + let mut iterator = MultiMastForestNodeIter::new(vec![&forest]); + assert!(iterator.next().is_none()); + } + + #[test] + fn multi_mast_forest_multiple_forests_dfs() { + let nodea0_digest = random_digest(); + let nodea1_digest = random_digest(); + let nodea2_digest = random_digest(); + let nodea3_digest = random_digest(); + + let nodeb0_digest = random_digest(); + + let mut forest_a = MastForest::new(); + forest_a.add_external(nodea0_digest).unwrap(); + let id1 = forest_a.add_external(nodea1_digest).unwrap(); + let id2 = forest_a.add_external(nodea2_digest).unwrap(); + let id3 = forest_a.add_external(nodea3_digest).unwrap(); + let id_split = forest_a.add_split(id2, id3).unwrap(); + let id_join = forest_a.add_join(id2, id_split).unwrap(); + + forest_a.make_root(id_join); + forest_a.make_root(id1); + + let mut forest_b = MastForest::new(); + let id_ext_b = forest_b.add_external(nodeb0_digest).unwrap(); + let id_block_b = forest_b.add_block(vec![Operation::Eqz], None).unwrap(); + let id_split_b = forest_b.add_split(id_ext_b, id_block_b).unwrap(); + + forest_b.make_root(id_split_b); + + // Note that the node at index 0 is not visited because it is not reachable from any root + // and is not a root itself. + let nodes = MultiMastForestNodeIter::new(vec![&forest_a, &forest_b]).collect::>(); + + assert_eq!(nodes.len(), 8); + assert_eq!(nodes[0], MultiMastForestIteratorItem::Node { forest_idx: 0, node_id: id2 }); + assert_eq!(nodes[1], MultiMastForestIteratorItem::Node { forest_idx: 0, node_id: id3 }); + assert_eq!( + nodes[2], + MultiMastForestIteratorItem::Node { forest_idx: 0, node_id: id_split } + ); + assert_eq!(nodes[3], MultiMastForestIteratorItem::Node { forest_idx: 0, node_id: id_join }); + assert_eq!(nodes[4], MultiMastForestIteratorItem::Node { forest_idx: 0, node_id: id1 }); + assert_eq!( + nodes[5], + MultiMastForestIteratorItem::Node { forest_idx: 1, node_id: id_ext_b } + ); + assert_eq!( + nodes[6], + MultiMastForestIteratorItem::Node { forest_idx: 1, node_id: id_block_b } + ); + assert_eq!( + nodes[7], + MultiMastForestIteratorItem::Node { forest_idx: 1, node_id: id_split_b } + ); + } + + #[test] + fn multi_mast_forest_external_dependencies() { + let block_foo = MastNode::new_basic_block(vec![Operation::Drop], None).unwrap(); + let mut forest_a = MastForest::new(); + let id_foo_a = forest_a.add_external(block_foo.digest()).unwrap(); + let id_call_a = forest_a.add_call(id_foo_a).unwrap(); + forest_a.make_root(id_call_a); + + let mut forest_b = MastForest::new(); + let id_ext_b = forest_b.add_external(forest_a[id_call_a].digest()).unwrap(); + let id_call_b = forest_b.add_call(id_ext_b).unwrap(); + forest_b.add_node(block_foo).unwrap(); + forest_b.make_root(id_call_b); + + let nodes = MultiMastForestNodeIter::new(vec![&forest_a, &forest_b]).collect::>(); + + assert_eq!(nodes.len(), 5); + + // The replacement for the external node from forest A. + assert_eq!( + nodes[0], + MultiMastForestIteratorItem::Node { + forest_idx: 1, + node_id: MastNodeId::new_unchecked(2) + } + ); + // The external node replaced by the block foo from forest B. + assert_eq!( + nodes[1], + MultiMastForestIteratorItem::ExternalNodeReplacement { + replacement_forest_idx: 1, + replacement_mast_node_id: MastNodeId::new_unchecked(2), + replaced_forest_idx: 0, + replaced_mast_node_id: MastNodeId::new_unchecked(0) + } + ); + // The call from forest A. + assert_eq!( + nodes[2], + MultiMastForestIteratorItem::Node { + forest_idx: 0, + node_id: MastNodeId::new_unchecked(1) + } + ); + // The replacement for the external node that is replaced by the Call in forest A. + assert_eq!( + nodes[3], + MultiMastForestIteratorItem::ExternalNodeReplacement { + replacement_forest_idx: 0, + replacement_mast_node_id: MastNodeId::new_unchecked(1), + replaced_forest_idx: 1, + replaced_mast_node_id: MastNodeId::new_unchecked(0) + } + ); + // The call from forest B. + assert_eq!( + nodes[4], + MultiMastForestIteratorItem::Node { + forest_idx: 1, + node_id: MastNodeId::new_unchecked(1) + } + ); + } + + /// Tests that a node which is referenced twice in a Mast Forest is returned in the required + /// order. + /// + /// In this test we have a MastForest with this graph: + /// + /// 3 <- Split Node + /// / \ + /// 1 2 + /// \ / + /// 0 + /// + /// We need to ensure that 0 is processed before 1 and that it is not processed again when + /// processing the children of node 2. + /// + /// This test and example is essentially a copy from a part of the MastForest of the Miden + /// Stdlib where this failed on a previous implementation. + #[test] + fn multi_mast_forest_child_duplicate() { + let block_foo = MastNode::new_basic_block(vec![Operation::Drop], None).unwrap(); + let mut forest = MastForest::new(); + let id_foo = forest.add_external(block_foo.digest()).unwrap(); + let id_call1 = forest.add_call(id_foo).unwrap(); + let id_call2 = forest.add_call(id_foo).unwrap(); + let id_split = forest.add_split(id_call1, id_call2).unwrap(); + forest.make_root(id_split); + + let nodes = MultiMastForestNodeIter::new(vec![&forest]).collect::>(); + + // The foo node should be yielded first and it should not be yielded twice. + for (i, expected_node_id) in [id_foo, id_call1, id_call2, id_split].into_iter().enumerate() + { + assert_eq!( + nodes[i], + MultiMastForestIteratorItem::Node { forest_idx: 0, node_id: expected_node_id } + ); + } + } +} diff --git a/stdlib/tests/main.rs b/stdlib/tests/main.rs index b50db64413..f159a08612 100644 --- a/stdlib/tests/main.rs +++ b/stdlib/tests/main.rs @@ -12,6 +12,7 @@ macro_rules! build_test { mod collections; mod crypto; +mod mast_forest_merge; mod math; mod mem; mod sys; diff --git a/stdlib/tests/mast_forest_merge.rs b/stdlib/tests/mast_forest_merge.rs new file mode 100644 index 0000000000..3040551b94 --- /dev/null +++ b/stdlib/tests/mast_forest_merge.rs @@ -0,0 +1,19 @@ +use processor::MastForest; + +/// Tests that the stdlib merged with itself produces a forest that has the same procedure +/// roots. +/// +/// This test is added here since we do not have the StdLib in miden-core where merging is +/// implemented and the StdLib serves as a convenient example of a large MastForest. +#[test] +fn mast_forest_merge_stdlib() { + let std_lib = miden_stdlib::StdLibrary::default(); + let std_forest = std_lib.mast_forest().as_ref(); + + let (merged, _) = MastForest::merge([std_forest, std_forest]).unwrap(); + + let merged_digests = merged.procedure_digests().collect::>(); + for digest in std_forest.procedure_digests() { + assert!(merged_digests.contains(&digest)); + } +}