From 259f61b0292b0f6f508cd632db5ca3e80b8664bb Mon Sep 17 00:00:00 2001 From: Philipp Gackstatter Date: Thu, 24 Oct 2024 14:34:13 +0200 Subject: [PATCH] feat(core): Move Vec inside the `MastForestRootMap` --- .../src/assembler/mast_forest_merger_tests.rs | 14 +- core/src/mast/merger/mod.rs | 34 ++-- core/src/mast/merger/tests.rs | 156 +++++++++++------- core/src/mast/mod.rs | 2 +- 4 files changed, 121 insertions(+), 85 deletions(-) diff --git a/assembly/src/assembler/mast_forest_merger_tests.rs b/assembly/src/assembler/mast_forest_merger_tests.rs index 9d1b9faba..63197cfbd 100644 --- a/assembly/src/assembler/mast_forest_merger_tests.rs +++ b/assembly/src/assembler/mast_forest_merger_tests.rs @@ -1,5 +1,3 @@ -use std::vec::Vec; - use miette::{IntoDiagnostic, Report}; use vm_core::mast::{MastForest, MastForestRootMap}; @@ -9,7 +7,7 @@ use crate::{testing::TestContext, Assembler}; fn merge_programs( program_a: &str, program_b: &str, -) -> Result<(MastForest, MastForest, (MastForest, Vec)), Report> { +) -> Result<(MastForest, MastForest, MastForest, MastForestRootMap), Report> { let context = TestContext::new(); let module = context.parse_module_with_path("lib::mod".parse().unwrap(), program_a)?; @@ -20,9 +18,9 @@ fn merge_programs( let lib_b = assembler.assemble_library([program_b])?.mast_forest().as_ref().clone(); let lib_a = lib_a.mast_forest().as_ref().clone(); - let merged = MastForest::merge([&lib_a, &lib_b]).into_diagnostic()?; + let (merged, root_maps) = MastForest::merge([&lib_a, &lib_b]).into_diagnostic()?; - Ok((lib_a, lib_b, merged)) + Ok((lib_a, lib_b, merged, root_maps)) } /// Tests that an assembler-produced library's forests can be merged and that external nodes are @@ -51,12 +49,12 @@ fn mast_forest_merge_assembler() { exec.mod::foo end"#; - let (forest_a, forest_b, (merged, root_maps)) = merge_programs(lib_a, lib_b).unwrap(); + let (forest_a, forest_b, merged, root_maps) = merge_programs(lib_a, lib_b).unwrap(); - for (forest, root_map) in [(forest_a, &root_maps[0]), (forest_b, &root_maps[1])] { + for (forest_idx, forest) in [forest_a, forest_b].into_iter().enumerate() { for root in forest.procedure_roots() { let original_digest = forest.nodes()[root.as_usize()].digest(); - let new_root = root_map.map_root(root).unwrap(); + let new_root = root_maps.map_root(forest_idx, root).unwrap(); let new_digest = forest.nodes()[new_root.as_usize()].digest(); assert_eq!(original_digest, new_digest); } diff --git a/core/src/mast/merger/mod.rs b/core/src/mast/merger/mod.rs index 8d99e91c8..73daf2d20 100644 --- a/core/src/mast/merger/mod.rs +++ b/core/src/mast/merger/mod.rs @@ -29,7 +29,7 @@ impl MastForestMerger { /// [`MastForest`]s are merged. pub(crate) fn merge<'forest>( forests: impl IntoIterator, - ) -> Result<(MastForest, Vec), MastForestError> { + ) -> Result<(MastForest, MastForestRootMap), MastForestError> { let forests = forests.into_iter().collect::>(); let decorator_id_mappings = Vec::with_capacity(forests.len()); let node_id_mappings = vec![MastForestNodeIdMap::new(); forests.len()]; @@ -47,11 +47,7 @@ impl MastForestMerger { let Self { mast_forest, node_id_mappings, .. } = merger; - let mut root_maps = Vec::new(); - for (forest_idx, mapping) in node_id_mappings.into_iter().enumerate() { - let forest = forests[forest_idx]; - root_maps.push(MastForestRootMap::from_node_id_map(mapping, &forest.roots)); - } + let root_maps = MastForestRootMap::from_node_id_map(node_id_mappings, forests); Ok((mast_forest, root_maps)) } @@ -320,27 +316,31 @@ impl MastForestMerger { /// forest. See [`MastForest::merge`] for more details. #[derive(Debug, Clone, PartialEq, Eq)] pub struct MastForestRootMap { - root_map: BTreeMap, + root_maps: Vec>, } impl MastForestRootMap { - fn from_node_id_map(id_map: MastForestNodeIdMap, roots: &[MastNodeId]) -> Self { - let mut root_map = BTreeMap::new(); - - for root in roots { - let new_id = - id_map.get(root).copied().expect("every node id should be mapped to its new id"); - root_map.insert(*root, new_id); + fn from_node_id_map(id_map: Vec, forests: Vec<&MastForest>) -> Self { + let mut root_maps = vec![BTreeMap::new(); forests.len()]; + + for (forest_idx, forest) in forests.into_iter().enumerate() { + for root in forest.procedure_roots() { + let new_id = id_map[forest_idx] + .get(root) + .copied() + .expect("every node id should be mapped to its new id"); + root_maps[forest_idx].insert(*root, new_id); + } } - Self { root_map } + Self { root_maps } } /// Maps the given root to its new location in the merged forest, if such a mapping exists. /// /// It is guaranteed that every root of the map's corresponding forest is contained in the map. - pub fn map_root(&self, root: &MastNodeId) -> Option { - self.root_map.get(root).copied() + pub fn map_root(&self, forest_index: usize, root: &MastNodeId) -> Option { + self.root_maps.get(forest_index).and_then(|map| map.get(root)).copied() } } diff --git a/core/src/mast/merger/tests.rs b/core/src/mast/merger/tests.rs index 52d397761..881e22c2c 100644 --- a/core/src/mast/merger/tests.rs +++ b/core/src/mast/merger/tests.rs @@ -16,45 +16,81 @@ fn block_qux() -> MastNode { .unwrap() } -fn assert_contains_node_once(forest: &MastForest, digest: RpoDigest) { - assert_eq!(forest.nodes.iter().filter(|node| node.digest() == digest).count(), 1); +/// Asserts that the given forest contains exactly one node with the given digest. +/// +/// Returns a Result which can be unwrapped in the calling test function to assert. This way, if +/// this assertion fails it'll be clear which exact call failed. +fn assert_contains_node_once(forest: &MastForest, digest: RpoDigest) -> Result<(), &str> { + if forest.nodes.iter().filter(|node| node.digest() == digest).count() != 1 { + return Err("node digest contained more than once in the forest"); + } + + Ok(()) } /// Asserts that every root of an original forest has an id to which it is mapped and that this /// mapped root is in the set of roots in the merged forest. +/// +/// Returns a Result which can be unwrapped in the calling test function to assert. This way, if +/// this assertion fails it'll be clear which exact call failed. fn assert_root_mapping( root_map: &MastForestRootMap, - original_roots: &[MastNodeId], + original_roots: Vec<&[MastNodeId]>, merged_roots: &[MastNodeId], -) { - for original_root in original_roots { - let mapped_root = root_map.map_root(original_root).unwrap(); - assert!(merged_roots.contains(&mapped_root)); +) -> Result<(), &'static str> { + for (forest_idx, original_root) in original_roots.into_iter().enumerate() { + for root in original_root { + let mapped_root = root_map.map_root(forest_idx, root).unwrap(); + if !merged_roots.contains(&mapped_root) { + return Err("merged root does not contain mapped root"); + } + } } + + Ok(()) } -fn assert_child_id_lt_parent_id(forest: &MastForest) { - for (idx, node) in forest.nodes().iter().enumerate() { +/// Asserts that all children of nodes in the given forest have an id that is less than the parent's +/// ID. +/// +/// Returns a Result which can be unwrapped in the calling test function to assert. This way, if +/// this assertion fails it'll be clear which exact call failed. +fn assert_child_id_lt_parent_id(forest: &MastForest) -> Result<(), &str> { + for (mast_node_id, node) in forest.nodes().iter().enumerate() { match node { MastNode::Join(join_node) => { - assert!(join_node.first().as_usize() < idx); - assert!(join_node.second().as_usize() < idx); + if !join_node.first().as_usize() < mast_node_id { + return Err("join node first child id is not < parent id"); + }; + if !join_node.second().as_usize() < mast_node_id { + return Err("join node second child id is not < parent id"); + } }, MastNode::Split(split_node) => { - assert!(split_node.on_true().as_usize() < idx); - assert!(split_node.on_false().as_usize() < idx); + if !split_node.on_true().as_usize() < mast_node_id { + return Err("split node on true id is not < parent id"); + } + if !split_node.on_false().as_usize() < mast_node_id { + return Err("split node on false id is not < parent id"); + } }, MastNode::Loop(loop_node) => { - assert!(loop_node.body().as_usize() < idx); + if !loop_node.body().as_usize() < mast_node_id { + return Err("loop node body id is not < parent id"); + } }, MastNode::Call(call_node) => { - assert!(call_node.callee().as_usize() < idx); + if !call_node.callee().as_usize() < mast_node_id { + return Err("call node callee id is not < parent id"); + } }, MastNode::Block(_) => (), MastNode::Dyn(_) => (), MastNode::External(_) => (), } } + + Ok(()) } /// Tests that Call(bar) still correctly calls the remapped bar block. @@ -84,12 +120,10 @@ fn mast_forest_merge_remap() { assert_eq!(merged.nodes()[2], block_bar()); assert_matches!(&merged.nodes()[3], MastNode::Call(call_node) if call_node.callee().as_u32() == 2); - let root_map_a = &root_maps[0]; - let root_map_b = &root_maps[1]; - assert_eq!(root_map_a.map_root(&id_call_a).unwrap().as_u32(), 1); - assert_eq!(root_map_b.map_root(&id_call_b).unwrap().as_u32(), 3); + assert_eq!(root_maps.map_root(0, &id_call_a).unwrap().as_u32(), 1); + assert_eq!(root_maps.map_root(1, &id_call_b).unwrap().as_u32(), 3); - assert_child_id_lt_parent_id(&merged); + assert_child_id_lt_parent_id(&merged).unwrap(); } /// Tests that Forest_A + Forest_A = Forest_A (i.e. duplicates are removed). @@ -113,7 +147,9 @@ fn mast_forest_merge_duplicate() { } // Both maps should map the roots to the same target id. - assert_eq!(&root_maps[0], &root_maps[1]); + for original_root in forest_a.procedure_roots() { + assert_eq!(&root_maps.map_root(0, original_root), &root_maps.map_root(1, original_root)); + } for merged_node in merged.nodes().iter().map(MastNode::digest) { forest_a.nodes.iter().find(|node| node.digest() == merged_node).unwrap(); @@ -123,7 +159,7 @@ fn mast_forest_merge_duplicate() { assert!(forest_a.decorators.contains(merged_decorator)); } - assert_child_id_lt_parent_id(&merged); + assert_child_id_lt_parent_id(&merged).unwrap(); } /// Tests that External(foo) is replaced by Block(foo) whether it is in forest A or B, and the @@ -159,8 +195,9 @@ fn mast_forest_merge_replace_external() { assert_matches!(&merged.nodes()[1], MastNode::Call(call_node) if call_node.callee().as_u32() == 0); // The only root node should be the call node. assert_eq!(merged.roots.len(), 1); - assert_eq!(root_map[0].map_root(&merged.roots[0]).unwrap().as_usize(), 1); - assert_child_id_lt_parent_id(&merged); + assert_eq!(root_map.map_root(0, &id_call_a).unwrap().as_usize(), 1); + assert_eq!(root_map.map_root(1, &id_call_b).unwrap().as_usize(), 1); + assert_child_id_lt_parent_id(&merged).unwrap(); } } @@ -204,10 +241,9 @@ fn mast_forest_merge_roots() { assert!(root_digests.contains(&root_digest_bar_b)); assert!(root_digests.contains(&root_digest_call_b)); - assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots); - assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots); + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots).unwrap(); - assert_child_id_lt_parent_id(&merged); + assert_child_id_lt_parent_id(&merged).unwrap(); } /// Test that multiple trees can be merged when the same merger is reused. @@ -258,16 +294,19 @@ fn mast_forest_merge_multiple() { assert!(root_digests.contains(&block_bar_digest)); assert!(root_digests.contains(&block_qux_digest)); - assert_contains_node_once(&merged, block_foo_digest); - assert_contains_node_once(&merged, block_bar_digest); - assert_contains_node_once(&merged, block_qux_digest); - assert_contains_node_once(&merged, call_foo_digest); + assert_contains_node_once(&merged, block_foo_digest).unwrap(); + assert_contains_node_once(&merged, block_bar_digest).unwrap(); + assert_contains_node_once(&merged, block_qux_digest).unwrap(); + assert_contains_node_once(&merged, call_foo_digest).unwrap(); - assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots); - assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots); - assert_root_mapping(&root_maps[2], &forest_c.roots, &merged.roots); + assert_root_mapping( + &root_maps, + vec![&forest_a.roots, &forest_b.roots, &forest_c.roots], + &merged.roots, + ) + .unwrap(); - assert_child_id_lt_parent_id(&merged); + assert_child_id_lt_parent_id(&merged).unwrap(); } /// Tests that decorators are merged and that nodes who are identical except for their @@ -399,10 +438,9 @@ fn mast_forest_merge_decorators() { 1 ); - assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots); - assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots); + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots).unwrap(); - assert_child_id_lt_parent_id(&merged); + assert_child_id_lt_parent_id(&merged).unwrap(); } /// Tests that an external node without decorators is replaced by its referenced node which has @@ -458,14 +496,14 @@ fn mast_forest_merge_external_node_reference_with_decorator() { assert!(fingerprints.contains(&id_foo_a_fingerprint)); if idx == 0 { - assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots); - assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots); + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots) + .unwrap(); } else { - assert_root_mapping(&root_maps[0], &forest_b.roots, &merged.roots); - assert_root_mapping(&root_maps[1], &forest_a.roots, &merged.roots); + assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots) + .unwrap(); } - assert_child_id_lt_parent_id(&merged); + assert_child_id_lt_parent_id(&merged).unwrap(); } } @@ -526,14 +564,14 @@ fn mast_forest_merge_external_node_with_decorator() { assert!(fingerprints.contains(&id_foo_b_fingerprint)); if idx == 0 { - assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots); - assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots); + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots) + .unwrap(); } else { - assert_root_mapping(&root_maps[0], &forest_b.roots, &merged.roots); - assert_root_mapping(&root_maps[1], &forest_a.roots, &merged.roots); + assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots) + .unwrap(); } - assert_child_id_lt_parent_id(&merged); + assert_child_id_lt_parent_id(&merged).unwrap(); } } @@ -596,14 +634,14 @@ fn mast_forest_merge_external_node_and_referenced_node_have_decorators() { assert!(fingerprints.contains(&id_foo_b_fingerprint)); if idx == 0 { - assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots); - assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots); + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots) + .unwrap(); } else { - assert_root_mapping(&root_maps[0], &forest_b.roots, &merged.roots); - assert_root_mapping(&root_maps[1], &forest_a.roots, &merged.roots); + assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots) + .unwrap(); } - assert_child_id_lt_parent_id(&merged); + assert_child_id_lt_parent_id(&merged).unwrap(); } } @@ -674,14 +712,14 @@ fn mast_forest_merge_multiple_external_nodes_with_decorator() { assert!(fingerprints.contains(&id_foo_b_fingerprint)); if idx == 0 { - assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots); - assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots); + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots) + .unwrap(); } else { - assert_root_mapping(&root_maps[0], &forest_b.roots, &merged.roots); - assert_root_mapping(&root_maps[1], &forest_a.roots, &merged.roots); + assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots) + .unwrap(); } - assert_child_id_lt_parent_id(&merged); + assert_child_id_lt_parent_id(&merged).unwrap(); } } @@ -725,7 +763,7 @@ fn mast_forest_merge_external_dependencies() { assert!(digests.contains(&forest_b[id_qux_b].digest())); assert_eq!(merged.nodes().iter().filter(|node| node.is_external()).count(), 0); - assert_child_id_lt_parent_id(&merged); + assert_child_id_lt_parent_id(&merged).unwrap(); } } diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 5898cb9e1..172817af2 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -253,7 +253,7 @@ impl MastForest { /// without decorators or vice versa. pub fn merge<'forest>( forests: impl IntoIterator, - ) -> Result<(MastForest, Vec), MastForestError> { + ) -> Result<(MastForest, MastForestRootMap), MastForestError> { MastForestMerger::merge(forests) }