Skip to content

Commit

Permalink
feat: implement MastForest merging (#1534)
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippGackstatter authored Oct 28, 2024
1 parent dc41735 commit 9f9cc63
Show file tree
Hide file tree
Showing 10 changed files with 2,036 additions and 137 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
- [BREAKING] The `run` and the `prove` commands in the cli will accept `--trace` flag instead of `--tracing` (#1502)
- Migrated to new padding rule for RPO (#1343).
- Migrated to `miden-crypto` v0.11.0 (#1343).
- Implemented `MastForest` merging (#1534)

#### Fixes

Expand Down
139 changes: 3 additions & 136 deletions assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use alloc::{
use core::ops::{Index, IndexMut};

use vm_core::{
crypto::hash::{Blake3Digest, Blake3_256, Digest, RpoDigest},
mast::{DecoratorId, MastForest, MastNode, MastNodeId},
crypto::hash::{Blake3Digest, RpoDigest},
mast::{DecoratorId, EqHash, MastForest, MastNode, MastNodeId},
Decorator, DecoratorList, Operation,
};

Expand Down Expand Up @@ -445,115 +445,9 @@ impl MastForestBuilder {
}
}

/// Helpers
impl MastForestBuilder {
fn eq_hash_for_node(&self, node: &MastNode) -> EqHash {
match node {
MastNode::Block(node) => {
let mut bytes_to_hash = Vec::new();

for &(idx, decorator_id) in node.decorators() {
bytes_to_hash.extend(idx.to_le_bytes());
bytes_to_hash.extend(self[decorator_id].eq_hash().as_bytes());
}

// Add any `Assert`, `U32assert2` or `MpVerify` opcodes present, since these are not included
// in the MAST root.
for (op_idx, op) in node.operations().enumerate() {
if let Operation::U32assert2(inner_value)
| Operation::Assert(inner_value)
| Operation::MpVerify(inner_value) = op
{
let op_idx: u32 = op_idx
.try_into()
.expect("there are more than 2^{32}-1 operations in basic block");

// we include the opcode to differentiate between `Assert`, `U32assert2` and `MpVerify`
bytes_to_hash.push(op.op_code());
// we include the operation index to distinguish between basic blocks that
// would have the same assert instructions, but in a different order
bytes_to_hash.extend(op_idx.to_le_bytes());
bytes_to_hash.extend(inner_value.to_le_bytes());
}
}

if bytes_to_hash.is_empty() {
EqHash::new(node.digest())
} else {
let decorator_root = Blake3_256::hash(&bytes_to_hash);
EqHash::with_decorator_root(node.digest(), decorator_root)
}
},
MastNode::Join(node) => self.eq_hash_from_parts(
node.before_enter(),
node.after_exit(),
&[node.first(), node.second()],
node.digest(),
),
MastNode::Split(node) => self.eq_hash_from_parts(
node.before_enter(),
node.after_exit(),
&[node.on_true(), node.on_false()],
node.digest(),
),
MastNode::Loop(node) => self.eq_hash_from_parts(
node.before_enter(),
node.after_exit(),
&[node.body()],
node.digest(),
),
MastNode::Call(node) => self.eq_hash_from_parts(
node.before_enter(),
node.after_exit(),
&[node.callee()],
node.digest(),
),
MastNode::Dyn(node) => {
self.eq_hash_from_parts(node.before_enter(), node.after_exit(), &[], node.digest())
},
MastNode::External(node) => {
self.eq_hash_from_parts(node.before_enter(), node.after_exit(), &[], node.digest())
},
}
}

fn eq_hash_from_parts(
&self,
before_enter_ids: &[DecoratorId],
after_exit_ids: &[DecoratorId],
children_ids: &[MastNodeId],
node_digest: RpoDigest,
) -> EqHash {
let pre_decorator_hash_bytes =
before_enter_ids.iter().flat_map(|&id| self[id].eq_hash().as_bytes());
let post_decorator_hash_bytes =
after_exit_ids.iter().flat_map(|&id| self[id].eq_hash().as_bytes());

// Reminder: the `EqHash`'s decorator root will be `None` if and only if there are no
// decorators attached to the node, and all children have no decorator roots (meaning that
// there are no decorators in all the descendants).
if pre_decorator_hash_bytes.clone().next().is_none()
&& post_decorator_hash_bytes.clone().next().is_none()
&& children_ids
.iter()
.filter_map(|child_id| self.hash_by_node_id[child_id].decorator_root)
.next()
.is_none()
{
EqHash::new(node_digest)
} else {
let children_decorator_roots = children_ids
.iter()
.filter_map(|child_id| self.hash_by_node_id[child_id].decorator_root)
.flat_map(|decorator_root| decorator_root.as_bytes());
let decorator_bytes_to_hash: Vec<u8> = pre_decorator_hash_bytes
.chain(post_decorator_hash_bytes)
.chain(children_decorator_roots)
.collect();

let decorator_root = Blake3_256::hash(&decorator_bytes_to_hash);
EqHash::with_decorator_root(node_digest, decorator_root)
}
EqHash::from_mast_node(&self.mast_forest, &self.hash_by_node_id, node)
}
}

Expand Down Expand Up @@ -582,33 +476,6 @@ impl IndexMut<DecoratorId> for MastForestBuilder {
}
}

// EQ HASH
// ================================================================================================

/// Represents the hash used to test for equality between [`MastNode`]s.
///
/// The decorator root will be `None` if and only if there are no decorators attached to the node,
/// and all children have no decorator roots (meaning that there are no decorators in all the
/// descendants).
///
/// The derived ordering compares `mast_root` first and `decorator_root` second, following the
/// field declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct EqHash {
/// MAST root of the node this hash was computed for.
mast_root: RpoDigest,
/// Blake3 digest of the decorator-related data, or `None` under the condition described in the
/// type-level documentation.
decorator_root: Option<Blake3Digest<32>>,
}

impl EqHash {
    /// Builds an [`EqHash`] that carries only the MAST root; the decorator root is left as `None`.
    fn new(mast_root: RpoDigest) -> Self {
        let decorator_root = None;
        Self { mast_root, decorator_root }
    }

    /// Builds an [`EqHash`] that carries both the MAST root and the given decorator root.
    fn with_decorator_root(mast_root: RpoDigest, decorator_root: Blake3Digest<32>) -> Self {
        // Start from the decorator-less hash and override only the decorator root.
        Self {
            decorator_root: Some(decorator_root),
            ..Self::new(mast_root)
        }
    }
}

// HELPER FUNCTIONS
// ================================================================================================

Expand Down
73 changes: 73 additions & 0 deletions assembly/src/assembler/mast_forest_merger_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use miette::{IntoDiagnostic, Report};
use vm_core::mast::{MastForest, MastForestRootMap};

use crate::{testing::TestContext, Assembler};

/// Assembles two library sources and merges their MAST forests.
///
/// `program_a` is parsed as a module at the path `lib::mod` and assembled into a library on its
/// own. `program_b` is then assembled into a second library by an assembler that has the first
/// library added to it. The forests of both libraries are cloned out and merged with
/// [`MastForest::merge`].
///
/// Returns `(forest_a, forest_b, merged_forest, root_map)`.
///
/// # Errors
///
/// Propagates any error from parsing, assembling, adding the library, or merging (the merge error
/// is converted with `into_diagnostic`).
// The four-element tuple in the return type trips clippy's `type_complexity` lint.
#[allow(clippy::type_complexity)]
fn merge_programs(
    program_a: &str,
    program_b: &str,
) -> Result<(MastForest, MastForest, MastForest, MastForestRootMap), Report> {
    let context = TestContext::new();

    // First library: `program_a` assembled on its own under `lib::mod`.
    let module_a = context.parse_module_with_path("lib::mod".parse().unwrap(), program_a)?;
    let library_a = Assembler::new(context.source_manager()).assemble_library([module_a])?;

    // Second library: `program_b` assembled with the first library available to the assembler.
    let mut assembler_b = Assembler::new(context.source_manager());
    assembler_b.add_library(library_a.clone())?;
    let library_b = assembler_b.assemble_library([program_b])?;

    let forest_b = library_b.mast_forest().as_ref().clone();
    let forest_a = library_a.mast_forest().as_ref().clone();

    let (merged_forest, root_map) =
        MastForest::merge([&forest_a, &forest_b]).into_diagnostic()?;

    Ok((forest_a, forest_b, merged_forest, root_map))
}

/// Checks that the forests of two assembler-produced libraries can be merged: every procedure
/// root of either input forest maps to a node with the same digest in the merged forest, and the
/// merged forest contains no external nodes.
#[test]
fn mast_forest_merge_assembler() {
    let lib_a = r#"
export.foo
push.19
end
export.qux
swap drop
end
"#;

    let lib_b = r#"
use.lib::mod
export.qux_duplicate
swap drop
end
export.bar
push.2
if.true
push.3
else
while.true
add
push.23
end
end
exec.mod::foo
end"#;

    let (forest_a, forest_b, merged, root_maps) = merge_programs(lib_a, lib_b).unwrap();
    let merged_nodes = merged.nodes();

    // The position in this array is the forest index expected by the root map.
    for (forest_idx, forest) in [&forest_a, &forest_b].into_iter().enumerate() {
        let source_nodes = forest.nodes();

        for root in forest.procedure_roots() {
            let digest_before = source_nodes[root.as_usize()].digest();
            let mapped_root = root_maps.map_root(forest_idx, root).unwrap();
            let digest_after = merged_nodes[mapped_root.as_usize()].digest();
            assert_eq!(digest_before, digest_after);
        }
    }

    // Assert that the external node for the import was removed during merging.
    for node in merged_nodes {
        assert!(!node.is_external());
    }
}
3 changes: 3 additions & 0 deletions assembly/src/assembler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ mod procedure;
#[cfg(test)]
mod tests;

#[cfg(test)]
mod mast_forest_merger_tests;

use self::{
basic_block_builder::BasicBlockBuilder,
module_graph::{CallerInfo, ModuleGraph, ResolvedTarget},
Expand Down
Loading

0 comments on commit 9f9cc63

Please sign in to comment.