Skip to content

Commit

Permalink
feat(core): Take ownership of merging forests
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippGackstatter committed Oct 22, 2024
1 parent 3521b63 commit 4525cd6
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 70 deletions.
2 changes: 1 addition & 1 deletion assembly/src/assembler/mast_forest_merger_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fn merge_programs(
let lib_b = assembler.assemble_library([program_b])?.mast_forest().as_ref().clone();
let lib_a = lib_a.mast_forest().as_ref().clone();

let merged = MastForest::merge([&lib_a, &lib_b]).into_diagnostic()?;
let merged = MastForest::merge([lib_a.clone(), lib_b.clone()]).into_diagnostic()?;

Ok((lib_a, lib_b, merged))
}
Expand Down
83 changes: 46 additions & 37 deletions core/src/mast/merger/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use core::ops::ControlFlow;

use miden_crypto::hash::{blake::Blake3Digest, rpo::RpoDigest};

use crate::mast::{DecoratorId, EqHash, MastForest, MastForestError, MastNode, MastNodeId};
use crate::{
mast::{DecoratorId, EqHash, MastForest, MastForestError, MastNode, MastNodeId},
Decorator,
};

#[cfg(test)]
mod tests;
Expand Down Expand Up @@ -34,34 +37,41 @@ impl MastForestMerger {
/// Merges `other_forest` into the forest contained in self.
pub(crate) fn merge(
&mut self,
other_forest: &MastForest,
mut other_forest: MastForest,
) -> Result<MastForestRootMap, MastForestError> {
let mut decorator_id_remapping = DecoratorIdMap::new(other_forest.decorators.len());
let mut node_id_remapping = MastForestNodeIdMap::new();

self.merge_decorators(other_forest, &mut decorator_id_remapping)?;
// It's fine to take out the decorators here as they aren't accessed after this point.
self.merge_decorators(
core::mem::take(&mut other_forest.decorators),
&mut decorator_id_remapping,
)?;
// `merge_nodes` takes ownership of the forest and needs the roots intact for the DFS
// iteration, so we cannot core::mem::take them out, so we copy the roots as we need
// them after merging of nodes is done.
let roots = other_forest.roots.clone();
self.merge_nodes(other_forest, &decorator_id_remapping, &mut node_id_remapping)?;
self.merge_roots(other_forest, &node_id_remapping)?;
self.merge_roots(roots.as_slice(), &node_id_remapping)?;

let root_map =
MastForestRootMap::from_node_id_map(node_id_remapping, other_forest.roots.as_slice());
let root_map = MastForestRootMap::from_node_id_map(node_id_remapping, roots.as_slice());

Ok(root_map)
}

fn merge_decorators(
&mut self,
other_forest: &MastForest,
decorators: Vec<Decorator>,
decorator_id_remapping: &mut DecoratorIdMap,
) -> Result<(), MastForestError> {
for (merging_id, merging_decorator) in other_forest.decorators.iter().enumerate() {
for (merging_id, merging_decorator) in decorators.into_iter().enumerate() {
let merging_decorator_hash = merging_decorator.eq_hash();
let new_decorator_id = if let Some(existing_decorator) =
self.decorators_by_hash.get(&merging_decorator_hash)
{
*existing_decorator
} else {
let new_decorator_id = self.mast_forest.add_decorator(merging_decorator.clone())?;
let new_decorator_id = self.mast_forest.add_decorator(merging_decorator)?;
self.decorators_by_hash.insert(merging_decorator_hash, new_decorator_id);
new_decorator_id
};
Expand All @@ -75,7 +85,7 @@ impl MastForestMerger {

fn merge_nodes(
&mut self,
other_forest: &MastForest,
other_forest: MastForest,
decorator_id_remapping: &DecoratorIdMap,
node_id_remapping: &mut MastForestNodeIdMap,
) -> Result<(), MastForestError> {
Expand Down Expand Up @@ -127,10 +137,10 @@ impl MastForestMerger {

fn merge_roots(
&mut self,
other_forest: &MastForest,
roots: &[MastNodeId],
node_id_remapping: &MastForestNodeIdMap,
) -> Result<(), MastForestError> {
for root_id in other_forest.roots.iter() {
for root_id in roots {
// Map the previous root to its possibly new id.
let new_root =
node_id_remapping.get(root_id).expect("all node ids should have an entry");
Expand Down Expand Up @@ -226,7 +236,7 @@ impl MastForestMerger {
/// the given maps.
fn remap_node(
&self,
node: &MastNode,
node: MastNode,
decorator_id_remapping: &DecoratorIdMap,
node_id_remapping: &MastForestNodeIdMap,
) -> Result<MastNode, MastForestError> {
Expand All @@ -246,6 +256,15 @@ impl MastForestMerger {
.expect("every node id should have an entry")
};

// Decorators must be handled specially for basic block nodes.
// For other node types we can handle it centrally.
let mut before_enter = Vec::new();
let mut after_exit = Vec::new();
if !node.is_basic_block() {
before_enter = map_decorators(node.before_enter())?;
after_exit = map_decorators(node.after_exit())?;
}

// Due to DFS postorder iteration all children of node's should have been inserted before
// their parents which is why we can `expect` the constructor calls here.
let mut mapped_node = match node {
Expand Down Expand Up @@ -273,34 +292,24 @@ impl MastForestMerger {
MastNode::new_call(callee, &self.mast_forest)
.expect("CallNode children should have been mapped to a lower index")
},
// Other nodes are simply copied.
MastNode::Block(basic_block_node) => {
MastNode::new_basic_block(
basic_block_node.operations().copied().collect(),
// Operation Indices of decorators stay the same while decorator IDs need to be
// mapped.
Some(
basic_block_node
.decorators()
.iter()
.map(|(idx, decorator_id)| match map_decorator_id(decorator_id) {
Ok(mapped_decorator) => Ok((*idx, mapped_decorator)),
Err(err) => Err(err),
})
.collect::<Result<Vec<_>, _>>()?,
),
)
.expect("previously valid BasicBlockNode should still be valid")
MastNode::Block(mut basic_block_node) => {
basic_block_node.map_decorators(|decorator_id| {
match map_decorator_id(decorator_id) {
Ok(mapped_decorator) => Ok(mapped_decorator),
Err(err) => Err(err),
}
})?;

MastNode::Block(basic_block_node)
},
MastNode::Dyn(_) => MastNode::new_dyn(),
MastNode::External(external_node) => MastNode::new_external(external_node.digest()),
// Other nodes are simply copied.
MastNode::Dyn(_) => node,
MastNode::External(_) => node,
};

// Decorators must be handled specially for basic block nodes.
// For other node types we can handle it centrally.
if !mapped_node.is_basic_block() {
mapped_node.set_before_enter(map_decorators(node.before_enter())?);
mapped_node.set_after_exit(map_decorators(node.after_exit())?);
mapped_node.set_before_enter(before_enter);
mapped_node.set_after_exit(after_exit);
}

Ok(mapped_node)
Expand Down
35 changes: 19 additions & 16 deletions core/src/mast/merger/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ fn mast_forest_merge_remap() {
let id_call_b = forest_b.add_call(id_bar).unwrap();
forest_b.make_root(id_call_b);

let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap();
let (merged, root_maps) = MastForest::merge([forest_a, forest_b]).unwrap();

assert_eq!(merged.nodes().len(), 4);
assert_eq!(merged.nodes()[0], block_foo());
Expand Down Expand Up @@ -79,7 +79,7 @@ fn mast_forest_merge_duplicate() {
forest_a.make_root(id_call);
forest_a.make_root(id_loop);

let (merged, root_maps) = MastForest::merge([&forest_a, &forest_a]).unwrap();
let (merged, root_maps) = MastForest::merge([forest_a.clone(), forest_a.clone()]).unwrap();

for merged_root in merged.procedure_digests() {
forest_a.procedure_digests().find(|root| root == &merged_root).unwrap();
Expand Down Expand Up @@ -121,8 +121,10 @@ fn mast_forest_merge_replace_external() {
let id_call_b = forest_b.add_call(id_foo_b).unwrap();
forest_b.make_root(id_call_b);

let (merged_ab, root_maps_ab) = MastForest::merge([&forest_a, &forest_b]).unwrap();
let (merged_ba, root_maps_ba) = MastForest::merge([&forest_b, &forest_a]).unwrap();
let (merged_ab, root_maps_ab) =
MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap();
let (merged_ba, root_maps_ba) =
MastForest::merge([forest_b.clone(), forest_a.clone()]).unwrap();

for (merged, root_map) in [(merged_ab, root_maps_ab), (merged_ba, root_maps_ba)] {
assert_eq!(merged.nodes().len(), 2);
Expand Down Expand Up @@ -162,7 +164,7 @@ fn mast_forest_merge_roots() {
let root_digest_bar_b = forest_b.get_node_by_id(id_bar_b).unwrap().digest();
let root_digest_call_b = forest_b.get_node_by_id(call_b).unwrap().digest();

let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap();
let (merged, root_maps) = MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap();

// Asserts (together with the other assertions) that the duplicate Call(foo) roots have been
// deduplicated.
Expand Down Expand Up @@ -212,7 +214,8 @@ fn mast_forest_merge_multiple() {
forest_c.make_root(id_qux_c);
forest_c.make_root(call_c);

let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b, &forest_c]).unwrap();
let (merged, root_maps) =
MastForest::merge([forest_a.clone(), forest_b.clone(), forest_c.clone()]).unwrap();

let block_foo_digest = forest_b.get_node_by_id(id_foo_b).unwrap().digest();
let block_bar_digest = forest_b.get_node_by_id(id_bar_b).unwrap().digest();
Expand Down Expand Up @@ -295,7 +298,7 @@ fn mast_forest_merge_decorators() {

forest_b.make_root(id_loop_b);

let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap();
let (merged, root_maps) = MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap();

// There are 4 unique decorators across both forests.
assert_eq!(merged.decorators.len(), 4);
Expand Down Expand Up @@ -403,8 +406,8 @@ fn mast_forest_merge_external_node_reference_with_decorator() {
forest_b.make_root(id_external_b);

for (idx, (merged, root_maps)) in [
MastForest::merge([&forest_a, &forest_b]).unwrap(),
MastForest::merge([&forest_b, &forest_a]).unwrap(),
MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap(),
MastForest::merge([forest_b.clone(), forest_a.clone()]).unwrap(),
]
.into_iter()
.enumerate()
Expand Down Expand Up @@ -467,8 +470,8 @@ fn mast_forest_merge_external_node_with_decorator() {
forest_b.make_root(id_foo_b);

for (idx, (merged, root_maps)) in [
MastForest::merge([&forest_a, &forest_b]).unwrap(),
MastForest::merge([&forest_b, &forest_a]).unwrap(),
MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap(),
MastForest::merge([forest_b.clone(), forest_a.clone()]).unwrap(),
]
.into_iter()
.enumerate()
Expand Down Expand Up @@ -535,8 +538,8 @@ fn mast_forest_merge_external_node_and_referenced_node_have_decorators() {
forest_b.make_root(id_foo_b);

for (idx, (merged, root_maps)) in [
MastForest::merge([&forest_a, &forest_b]).unwrap(),
MastForest::merge([&forest_b, &forest_a]).unwrap(),
MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap(),
MastForest::merge([forest_b.clone(), forest_a.clone()]).unwrap(),
]
.into_iter()
.enumerate()
Expand Down Expand Up @@ -611,8 +614,8 @@ fn mast_forest_merge_multiple_external_nodes_with_decorator() {
forest_b.make_root(id_foo_b);

for (idx, (merged, root_maps)) in [
MastForest::merge([&forest_a, &forest_b]).unwrap(),
MastForest::merge([&forest_b, &forest_a]).unwrap(),
MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap(),
MastForest::merge([forest_b.clone(), forest_a.clone()]).unwrap(),
]
.into_iter()
.enumerate()
Expand Down Expand Up @@ -665,6 +668,6 @@ fn mast_forest_merge_invalid_decorator_index() {

forest_b.make_root(id_foo_b);

let err = MastForest::merge([&forest_a, &forest_b]).unwrap_err();
let err = MastForest::merge([forest_a.clone(), forest_b.clone()]).unwrap_err();
assert_matches!(err, MastForestError::DecoratorIdOverflow(_, _));
}
6 changes: 3 additions & 3 deletions core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ impl MastForest {
/// which is effectively deduplication. Decorators are ignored when it comes to merging
/// External nodes. This means that an External node with decorators may be replaced by a node
/// without decorators or vice versa.
pub fn merge<'forest>(
forests: impl IntoIterator<Item = &'forest MastForest>,
pub fn merge(
forests: impl IntoIterator<Item = MastForest>,
) -> Result<(MastForest, Vec<MastForestRootMap>), MastForestError> {
let mut root_maps = Vec::new();
let mut merger = MastForestMerger::new();
Expand Down Expand Up @@ -497,7 +497,7 @@ impl MastForest {
///
/// The iteration on a high-level thus consists of a constant back and forth between discovering
/// trees and returning nodes from the stack.
pub fn iter_nodes(&self) -> impl Iterator<Item = (MastNodeId, &MastNode)> {
pub fn iter_nodes(self) -> impl Iterator<Item = (MastNodeId, MastNode)> {
MastForestNodeIter::new(self)
}
}
Expand Down
12 changes: 12 additions & 0 deletions core/src/mast/node/basic_block_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,18 @@ impl BasicBlockNode {
.expect("basic block contains more than 2^32 operations and decorators")
}

/// A specialized function to map the decorators of self to new values while leaving the
/// remaining parts as-is.
pub(crate) fn map_decorators<E>(
&mut self,
decorator_map: impl Fn(&DecoratorId) -> Result<DecoratorId, E>,
) -> Result<(), E> {
for (_, decorator) in self.decorators.iter_mut() {
*decorator = decorator_map(decorator)?;
}
Ok(())
}

/// Returns an iterator over all operations and decorator, in the order in which they appear in
/// the program.
pub fn iter(&self) -> impl Iterator<Item = OperationOrDecorator> {
Expand Down
Loading

0 comments on commit 4525cd6

Please sign in to comment.