Skip to content

Commit

Permalink
feat(core): Move Vec inside the MastForestRootMap
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippGackstatter committed Oct 24, 2024
1 parent cfcc972 commit 259f61b
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 85 deletions.
14 changes: 6 additions & 8 deletions assembly/src/assembler/mast_forest_merger_tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::vec::Vec;

use miette::{IntoDiagnostic, Report};
use vm_core::mast::{MastForest, MastForestRootMap};

Expand All @@ -9,7 +7,7 @@ use crate::{testing::TestContext, Assembler};
fn merge_programs(
program_a: &str,
program_b: &str,
) -> Result<(MastForest, MastForest, (MastForest, Vec<MastForestRootMap>)), Report> {
) -> Result<(MastForest, MastForest, MastForest, MastForestRootMap), Report> {
let context = TestContext::new();
let module = context.parse_module_with_path("lib::mod".parse().unwrap(), program_a)?;

Expand All @@ -20,9 +18,9 @@ fn merge_programs(
let lib_b = assembler.assemble_library([program_b])?.mast_forest().as_ref().clone();
let lib_a = lib_a.mast_forest().as_ref().clone();

let merged = MastForest::merge([&lib_a, &lib_b]).into_diagnostic()?;
let (merged, root_maps) = MastForest::merge([&lib_a, &lib_b]).into_diagnostic()?;

Ok((lib_a, lib_b, merged))
Ok((lib_a, lib_b, merged, root_maps))
}

/// Tests that an assembler-produced library's forests can be merged and that external nodes are
Expand Down Expand Up @@ -51,12 +49,12 @@ fn mast_forest_merge_assembler() {
exec.mod::foo
end"#;

let (forest_a, forest_b, (merged, root_maps)) = merge_programs(lib_a, lib_b).unwrap();
let (forest_a, forest_b, merged, root_maps) = merge_programs(lib_a, lib_b).unwrap();

for (forest, root_map) in [(forest_a, &root_maps[0]), (forest_b, &root_maps[1])] {
for (forest_idx, forest) in [forest_a, forest_b].into_iter().enumerate() {
for root in forest.procedure_roots() {
let original_digest = forest.nodes()[root.as_usize()].digest();
let new_root = root_map.map_root(root).unwrap();
let new_root = root_maps.map_root(forest_idx, root).unwrap();
let new_digest = forest.nodes()[new_root.as_usize()].digest();
assert_eq!(original_digest, new_digest);
}
Expand Down
34 changes: 17 additions & 17 deletions core/src/mast/merger/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl MastForestMerger {
/// [`MastForest`]s are merged.
pub(crate) fn merge<'forest>(
forests: impl IntoIterator<Item = &'forest MastForest>,
) -> Result<(MastForest, Vec<MastForestRootMap>), MastForestError> {
) -> Result<(MastForest, MastForestRootMap), MastForestError> {
let forests = forests.into_iter().collect::<Vec<_>>();
let decorator_id_mappings = Vec::with_capacity(forests.len());
let node_id_mappings = vec![MastForestNodeIdMap::new(); forests.len()];
Expand All @@ -47,11 +47,7 @@ impl MastForestMerger {

let Self { mast_forest, node_id_mappings, .. } = merger;

let mut root_maps = Vec::new();
for (forest_idx, mapping) in node_id_mappings.into_iter().enumerate() {
let forest = forests[forest_idx];
root_maps.push(MastForestRootMap::from_node_id_map(mapping, &forest.roots));
}
let root_maps = MastForestRootMap::from_node_id_map(node_id_mappings, forests);

Ok((mast_forest, root_maps))
}
Expand Down Expand Up @@ -320,27 +316,31 @@ impl MastForestMerger {
/// forest. See [`MastForest::merge`] for more details.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MastForestRootMap {
root_map: BTreeMap<MastNodeId, MastNodeId>,
root_maps: Vec<BTreeMap<MastNodeId, MastNodeId>>,
}

impl MastForestRootMap {
fn from_node_id_map(id_map: MastForestNodeIdMap, roots: &[MastNodeId]) -> Self {
let mut root_map = BTreeMap::new();

for root in roots {
let new_id =
id_map.get(root).copied().expect("every node id should be mapped to its new id");
root_map.insert(*root, new_id);
fn from_node_id_map(id_map: Vec<MastForestNodeIdMap>, forests: Vec<&MastForest>) -> Self {
let mut root_maps = vec![BTreeMap::new(); forests.len()];

for (forest_idx, forest) in forests.into_iter().enumerate() {
for root in forest.procedure_roots() {
let new_id = id_map[forest_idx]
.get(root)
.copied()
.expect("every node id should be mapped to its new id");
root_maps[forest_idx].insert(*root, new_id);
}
}

Self { root_map }
Self { root_maps }
}

/// Maps the given root to its new location in the merged forest, if such a mapping exists.
///
/// It is guaranteed that every root of the map's corresponding forest is contained in the map.
pub fn map_root(&self, root: &MastNodeId) -> Option<MastNodeId> {
self.root_map.get(root).copied()
pub fn map_root(&self, forest_index: usize, root: &MastNodeId) -> Option<MastNodeId> {
self.root_maps.get(forest_index).and_then(|map| map.get(root)).copied()
}
}

Expand Down
156 changes: 97 additions & 59 deletions core/src/mast/merger/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,81 @@ fn block_qux() -> MastNode {
.unwrap()
}

fn assert_contains_node_once(forest: &MastForest, digest: RpoDigest) {
assert_eq!(forest.nodes.iter().filter(|node| node.digest() == digest).count(), 1);
/// Asserts that the given forest contains exactly one node with the given digest.
///
/// Returns a Result which can be unwrapped in the calling test function to assert. This way, if
/// this assertion fails it'll be clear which exact call failed.
fn assert_contains_node_once(forest: &MastForest, digest: RpoDigest) -> Result<(), &str> {
if forest.nodes.iter().filter(|node| node.digest() == digest).count() != 1 {
return Err("node digest contained more than once in the forest");
}

Ok(())
}

/// Asserts that every root of an original forest has an id to which it is mapped and that this
/// mapped root is in the set of roots in the merged forest.
///
/// Returns a Result which can be unwrapped in the calling test function to assert. This way, if
/// this assertion fails it'll be clear which exact call failed.
fn assert_root_mapping(
root_map: &MastForestRootMap,
original_roots: &[MastNodeId],
original_roots: Vec<&[MastNodeId]>,
merged_roots: &[MastNodeId],
) {
for original_root in original_roots {
let mapped_root = root_map.map_root(original_root).unwrap();
assert!(merged_roots.contains(&mapped_root));
) -> Result<(), &'static str> {
for (forest_idx, original_root) in original_roots.into_iter().enumerate() {
for root in original_root {
let mapped_root = root_map.map_root(forest_idx, root).unwrap();
if !merged_roots.contains(&mapped_root) {
return Err("merged root does not contain mapped root");
}
}
}

Ok(())
}

fn assert_child_id_lt_parent_id(forest: &MastForest) {
for (idx, node) in forest.nodes().iter().enumerate() {
/// Asserts that all children of nodes in the given forest have an id that is less than the parent's
/// ID.
///
/// Returns a Result which can be unwrapped in the calling test function to assert. This way, if
/// this assertion fails it'll be clear which exact call failed.
fn assert_child_id_lt_parent_id(forest: &MastForest) -> Result<(), &str> {
for (mast_node_id, node) in forest.nodes().iter().enumerate() {
match node {
MastNode::Join(join_node) => {
assert!(join_node.first().as_usize() < idx);
assert!(join_node.second().as_usize() < idx);
if !join_node.first().as_usize() < mast_node_id {
return Err("join node first child id is not < parent id");
};
if !join_node.second().as_usize() < mast_node_id {
return Err("join node second child id is not < parent id");
}
},
MastNode::Split(split_node) => {
assert!(split_node.on_true().as_usize() < idx);
assert!(split_node.on_false().as_usize() < idx);
if !split_node.on_true().as_usize() < mast_node_id {
return Err("split node on true id is not < parent id");
}
if !split_node.on_false().as_usize() < mast_node_id {
return Err("split node on false id is not < parent id");
}
},
MastNode::Loop(loop_node) => {
assert!(loop_node.body().as_usize() < idx);
if !loop_node.body().as_usize() < mast_node_id {
return Err("loop node body id is not < parent id");
}
},
MastNode::Call(call_node) => {
assert!(call_node.callee().as_usize() < idx);
if !call_node.callee().as_usize() < mast_node_id {
return Err("call node callee id is not < parent id");
}
},
MastNode::Block(_) => (),
MastNode::Dyn(_) => (),
MastNode::External(_) => (),
}
}

Ok(())
}

/// Tests that Call(bar) still correctly calls the remapped bar block.
Expand Down Expand Up @@ -84,12 +120,10 @@ fn mast_forest_merge_remap() {
assert_eq!(merged.nodes()[2], block_bar());
assert_matches!(&merged.nodes()[3], MastNode::Call(call_node) if call_node.callee().as_u32() == 2);

let root_map_a = &root_maps[0];
let root_map_b = &root_maps[1];
assert_eq!(root_map_a.map_root(&id_call_a).unwrap().as_u32(), 1);
assert_eq!(root_map_b.map_root(&id_call_b).unwrap().as_u32(), 3);
assert_eq!(root_maps.map_root(0, &id_call_a).unwrap().as_u32(), 1);
assert_eq!(root_maps.map_root(1, &id_call_b).unwrap().as_u32(), 3);

assert_child_id_lt_parent_id(&merged);
assert_child_id_lt_parent_id(&merged).unwrap();
}

/// Tests that Forest_A + Forest_A = Forest_A (i.e. duplicates are removed).
Expand All @@ -113,7 +147,9 @@ fn mast_forest_merge_duplicate() {
}

// Both maps should map the roots to the same target id.
assert_eq!(&root_maps[0], &root_maps[1]);
for original_root in forest_a.procedure_roots() {
assert_eq!(&root_maps.map_root(0, original_root), &root_maps.map_root(1, original_root));
}

for merged_node in merged.nodes().iter().map(MastNode::digest) {
forest_a.nodes.iter().find(|node| node.digest() == merged_node).unwrap();
Expand All @@ -123,7 +159,7 @@ fn mast_forest_merge_duplicate() {
assert!(forest_a.decorators.contains(merged_decorator));
}

assert_child_id_lt_parent_id(&merged);
assert_child_id_lt_parent_id(&merged).unwrap();
}

/// Tests that External(foo) is replaced by Block(foo) whether it is in forest A or B, and the
Expand Down Expand Up @@ -159,8 +195,9 @@ fn mast_forest_merge_replace_external() {
assert_matches!(&merged.nodes()[1], MastNode::Call(call_node) if call_node.callee().as_u32() == 0);
// The only root node should be the call node.
assert_eq!(merged.roots.len(), 1);
assert_eq!(root_map[0].map_root(&merged.roots[0]).unwrap().as_usize(), 1);
assert_child_id_lt_parent_id(&merged);
assert_eq!(root_map.map_root(0, &id_call_a).unwrap().as_usize(), 1);
assert_eq!(root_map.map_root(1, &id_call_b).unwrap().as_usize(), 1);
assert_child_id_lt_parent_id(&merged).unwrap();
}
}

Expand Down Expand Up @@ -204,10 +241,9 @@ fn mast_forest_merge_roots() {
assert!(root_digests.contains(&root_digest_bar_b));
assert!(root_digests.contains(&root_digest_call_b));

assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots);
assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots);
assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots).unwrap();

assert_child_id_lt_parent_id(&merged);
assert_child_id_lt_parent_id(&merged).unwrap();
}

/// Test that multiple trees can be merged when the same merger is reused.
Expand Down Expand Up @@ -258,16 +294,19 @@ fn mast_forest_merge_multiple() {
assert!(root_digests.contains(&block_bar_digest));
assert!(root_digests.contains(&block_qux_digest));

assert_contains_node_once(&merged, block_foo_digest);
assert_contains_node_once(&merged, block_bar_digest);
assert_contains_node_once(&merged, block_qux_digest);
assert_contains_node_once(&merged, call_foo_digest);
assert_contains_node_once(&merged, block_foo_digest).unwrap();
assert_contains_node_once(&merged, block_bar_digest).unwrap();
assert_contains_node_once(&merged, block_qux_digest).unwrap();
assert_contains_node_once(&merged, call_foo_digest).unwrap();

assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots);
assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots);
assert_root_mapping(&root_maps[2], &forest_c.roots, &merged.roots);
assert_root_mapping(
&root_maps,
vec![&forest_a.roots, &forest_b.roots, &forest_c.roots],
&merged.roots,
)
.unwrap();

assert_child_id_lt_parent_id(&merged);
assert_child_id_lt_parent_id(&merged).unwrap();
}

/// Tests that decorators are merged and that nodes who are identical except for their
Expand Down Expand Up @@ -399,10 +438,9 @@ fn mast_forest_merge_decorators() {
1
);

assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots);
assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots);
assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots).unwrap();

assert_child_id_lt_parent_id(&merged);
assert_child_id_lt_parent_id(&merged).unwrap();
}

/// Tests that an external node without decorators is replaced by its referenced node which has
Expand Down Expand Up @@ -458,14 +496,14 @@ fn mast_forest_merge_external_node_reference_with_decorator() {
assert!(fingerprints.contains(&id_foo_a_fingerprint));

if idx == 0 {
assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots);
assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots);
assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots)
.unwrap();
} else {
assert_root_mapping(&root_maps[0], &forest_b.roots, &merged.roots);
assert_root_mapping(&root_maps[1], &forest_a.roots, &merged.roots);
assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots)
.unwrap();
}

assert_child_id_lt_parent_id(&merged);
assert_child_id_lt_parent_id(&merged).unwrap();
}
}

Expand Down Expand Up @@ -526,14 +564,14 @@ fn mast_forest_merge_external_node_with_decorator() {
assert!(fingerprints.contains(&id_foo_b_fingerprint));

if idx == 0 {
assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots);
assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots);
assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots)
.unwrap();
} else {
assert_root_mapping(&root_maps[0], &forest_b.roots, &merged.roots);
assert_root_mapping(&root_maps[1], &forest_a.roots, &merged.roots);
assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots)
.unwrap();
}

assert_child_id_lt_parent_id(&merged);
assert_child_id_lt_parent_id(&merged).unwrap();
}
}

Expand Down Expand Up @@ -596,14 +634,14 @@ fn mast_forest_merge_external_node_and_referenced_node_have_decorators() {
assert!(fingerprints.contains(&id_foo_b_fingerprint));

if idx == 0 {
assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots);
assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots);
assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots)
.unwrap();
} else {
assert_root_mapping(&root_maps[0], &forest_b.roots, &merged.roots);
assert_root_mapping(&root_maps[1], &forest_a.roots, &merged.roots);
assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots)
.unwrap();
}

assert_child_id_lt_parent_id(&merged);
assert_child_id_lt_parent_id(&merged).unwrap();
}
}

Expand Down Expand Up @@ -674,14 +712,14 @@ fn mast_forest_merge_multiple_external_nodes_with_decorator() {
assert!(fingerprints.contains(&id_foo_b_fingerprint));

if idx == 0 {
assert_root_mapping(&root_maps[0], &forest_a.roots, &merged.roots);
assert_root_mapping(&root_maps[1], &forest_b.roots, &merged.roots);
assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots)
.unwrap();
} else {
assert_root_mapping(&root_maps[0], &forest_b.roots, &merged.roots);
assert_root_mapping(&root_maps[1], &forest_a.roots, &merged.roots);
assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots)
.unwrap();
}

assert_child_id_lt_parent_id(&merged);
assert_child_id_lt_parent_id(&merged).unwrap();
}
}

Expand Down Expand Up @@ -725,7 +763,7 @@ fn mast_forest_merge_external_dependencies() {
assert!(digests.contains(&forest_b[id_qux_b].digest()));
assert_eq!(merged.nodes().iter().filter(|node| node.is_external()).count(), 0);

assert_child_id_lt_parent_id(&merged);
assert_child_id_lt_parent_id(&merged).unwrap();
}
}

Expand Down
Loading

0 comments on commit 259f61b

Please sign in to comment.