diff --git a/CHANGELOG.md b/CHANGELOG.md
index 72392ea618..b8751ad390 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,8 +6,8 @@
- Added error codes support for the `mtree_verify` instruction (#1328).
- Added support for immediate values for `lt`, `lte`, `gt`, `gte` comparison instructions (#1346).
-- Change MAST to a table-based representation (#1349)
-- Introduce `MastForestStore` (#1359)
+- Changed MAST to a table-based representation (#1349)
+- Introduced `MastForestStore` (#1359)
- Adjusted prover's metal acceleration code to work with 0.9 versions of the crates (#1357)
- Added support for immediate values for `u32lt`, `u32lte`, `u32gt`, `u32gte`, `u32min` and `u32max` comparison instructions (#1358).
- Added support for the `nop` instruction, which corresponds to the VM opcode of the same name, and has the same semantics. This is implemented for use by compilers primarily.
@@ -16,9 +16,10 @@
- Added support for immediate values for `u32and`, `u32or`, `u32xor` and `u32not` bitwise instructions (#1362).
- Optimized `std::sys::truncate_stuck` procedure (#1384).
- Updated CI and Makefile to standardise it accross Miden repositories (#1342).
-- Add serialization/deserialization for `MastForest` (#1370)
+- Added serialization/deserialization for `MastForest` (#1370)
- Updated CI to support `CHANGELOG.md` modification checking and `no changelog` label (#1406)
-- Introduce `MastForestError` to enforce `MastForest` node count invariant (#1394)
+- Introduced `MastForestError` to enforce `MastForest` node count invariant (#1394)
+- Added functions to `MastForestBuilder` to allow ensuring of nodes with fewer LOC (#1404)
#### Changed
diff --git a/assembly/src/assembler/basic_block_builder.rs b/assembly/src/assembler/basic_block_builder.rs
index 2545d4eea8..5263faeff2 100644
--- a/assembly/src/assembler/basic_block_builder.rs
+++ b/assembly/src/assembler/basic_block_builder.rs
@@ -4,7 +4,7 @@ use super::{
};
use alloc::{borrow::Borrow, string::ToString, vec::Vec};
use vm_core::{
- mast::{MastForestError, MastNode, MastNodeId},
+ mast::{MastForestError, MastNodeId},
AdviceInjector, AssemblyOp, Operation,
};
@@ -134,8 +134,7 @@ impl BasicBlockBuilder {
let ops = self.ops.drain(..).collect();
let decorators = self.decorators.drain(..).collect();
- let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators);
- let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node)?;
+ let basic_block_node_id = mast_forest_builder.ensure_block(ops, Some(decorators))?;
Ok(Some(basic_block_node_id))
} else if !self.decorators.is_empty() {
diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs
index d6a96f0f1b..a86474fce7 100644
--- a/assembly/src/assembler/instruction/procedures.rs
+++ b/assembly/src/assembler/instruction/procedures.rs
@@ -6,7 +6,7 @@ use crate::{
};
use smallvec::SmallVec;
-use vm_core::mast::{MastForest, MastNode, MastNodeId};
+use vm_core::mast::{MastForest, MastNodeId};
/// Procedure Invocation
impl Assembler {
@@ -96,8 +96,7 @@ impl Assembler {
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
- let external_node = MastNode::new_external(mast_root);
- mast_forest_builder.ensure_node(external_node)?
+ mast_forest_builder.ensure_external(mast_root)?
}
}
}
@@ -107,13 +106,11 @@ impl Assembler {
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
- let external_node = MastNode::new_external(mast_root);
- mast_forest_builder.ensure_node(external_node)?
+ mast_forest_builder.ensure_external(mast_root)?
}
};
- let call_node = MastNode::new_call(callee_id, mast_forest_builder.forest());
- mast_forest_builder.ensure_node(call_node)?
+ mast_forest_builder.ensure_call(callee_id)?
}
InvokeKind::SysCall => {
let callee_id = match mast_forest_builder.find_procedure_root(mast_root) {
@@ -121,14 +118,11 @@ impl Assembler {
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
- let external_node = MastNode::new_external(mast_root);
- mast_forest_builder.ensure_node(external_node)?
+ mast_forest_builder.ensure_external(mast_root)?
}
};
- let syscall_node =
- MastNode::new_syscall(callee_id, mast_forest_builder.forest());
- mast_forest_builder.ensure_node(syscall_node)?
+ mast_forest_builder.ensure_syscall(callee_id)?
}
}
};
@@ -141,7 +135,7 @@ impl Assembler {
&self,
mast_forest_builder: &mut MastForestBuilder,
) -> Result, AssemblyError> {
- let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?;
+ let dyn_node_id = mast_forest_builder.ensure_dyn()?;
Ok(Some(dyn_node_id))
}
@@ -152,10 +146,8 @@ impl Assembler {
mast_forest_builder: &mut MastForestBuilder,
) -> Result , AssemblyError> {
let dyn_call_node_id = {
- let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?;
- let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest_builder.forest());
-
- mast_forest_builder.ensure_node(dyn_call_node)?
+ let dyn_node_id = mast_forest_builder.ensure_dyn()?;
+ mast_forest_builder.ensure_call(dyn_node_id)?
};
Ok(Some(dyn_call_node_id))
diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs
index a5854c8b01..0a1a44d037 100644
--- a/assembly/src/assembler/mast_forest_builder.rs
+++ b/assembly/src/assembler/mast_forest_builder.rs
@@ -1,9 +1,10 @@
use core::ops::Index;
-use alloc::collections::BTreeMap;
+use alloc::{collections::BTreeMap, vec::Vec};
use vm_core::{
crypto::hash::RpoDigest,
mast::{MastForest, MastForestError, MastNode, MastNodeId, MerkleTreeNode},
+ DecoratorList, Operation,
};
/// Builder for a [`MastForest`].
@@ -44,7 +45,7 @@ impl MastForestBuilder {
/// If a [`MastNode`] which is equal to the current node was previously added, the previously
/// returned [`MastNodeId`] will be returned. This enforces this invariant that equal
/// [`MastNode`]s have equal [`MastNodeId`]s.
- pub fn ensure_node(&mut self, node: MastNode) -> Result {
+ fn ensure_node(&mut self, node: MastNode) -> Result {
let node_digest = node.digest();
if let Some(node_id) = self.node_id_by_hash.get(&node_digest) {
@@ -58,6 +59,63 @@ impl MastForestBuilder {
}
}
+ /// Adds a basic block node to the forest, and returns the [`MastNodeId`] associated with it.
+ pub fn ensure_block(
+ &mut self,
+ operations: Vec,
+ decorators: Option,
+ ) -> Result {
+ match decorators {
+ Some(decorators) => {
+ self.ensure_node(MastNode::new_basic_block_with_decorators(operations, decorators))
+ }
+ None => self.ensure_node(MastNode::new_basic_block(operations)),
+ }
+ }
+
+ /// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it.
+ pub fn ensure_join(
+ &mut self,
+ left_child: MastNodeId,
+ right_child: MastNodeId,
+ ) -> Result {
+ self.ensure_node(MastNode::new_join(left_child, right_child, &self.mast_forest))
+ }
+
+ /// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it.
+ pub fn ensure_split(
+ &mut self,
+ if_branch: MastNodeId,
+ else_branch: MastNodeId,
+ ) -> Result {
+ self.ensure_node(MastNode::new_split(if_branch, else_branch, &self.mast_forest))
+ }
+
+ /// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it.
+ pub fn ensure_loop(&mut self, body: MastNodeId) -> Result {
+ self.ensure_node(MastNode::new_loop(body, &self.mast_forest))
+ }
+
+ /// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it.
+ pub fn ensure_call(&mut self, callee: MastNodeId) -> Result {
+ self.ensure_node(MastNode::new_call(callee, &self.mast_forest))
+ }
+
+ /// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it.
+ pub fn ensure_syscall(&mut self, callee: MastNodeId) -> Result {
+ self.ensure_node(MastNode::new_syscall(callee, &self.mast_forest))
+ }
+
+ /// Adds a dynexec node to the forest, and returns the [`MastNodeId`] associated with it.
+ pub fn ensure_dyn(&mut self) -> Result {
+ self.ensure_node(MastNode::new_dyn())
+ }
+
+ /// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it.
+ pub fn ensure_external(&mut self, mast_root: RpoDigest) -> Result {
+ self.ensure_node(MastNode::new_external(mast_root))
+ }
+
/// Marks the given [`MastNodeId`] as being the root of a procedure.
pub fn make_root(&mut self, new_root_id: MastNodeId) {
self.mast_forest.make_root(new_root_id)
diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs
index b6bb051d30..cfdfbfe592 100644
--- a/assembly/src/assembler/mod.rs
+++ b/assembly/src/assembler/mod.rs
@@ -11,7 +11,7 @@ use crate::{
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use mast_forest_builder::MastForestBuilder;
use vm_core::{
- mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode},
+ mast::{MastForest, MastNodeId, MerkleTreeNode},
Decorator, DecoratorList, Kernel, Operation, Program,
};
@@ -792,12 +792,9 @@ impl Assembler {
let else_blk =
self.compile_body(else_blk.iter(), context, None, mast_forest_builder)?;
- let split_node_id = {
- let split_node =
- MastNode::new_split(then_blk, else_blk, mast_forest_builder.forest());
-
- mast_forest_builder.ensure_node(split_node).map_err(AssemblyError::from)?
- };
+ let split_node_id = mast_forest_builder
+ .ensure_split(then_blk, else_blk)
+ .map_err(AssemblyError::from)?;
mast_node_ids.push(split_node_id);
}
@@ -828,11 +825,9 @@ impl Assembler {
let loop_body_node_id =
self.compile_body(body.iter(), context, None, mast_forest_builder)?;
- let loop_node_id = {
- let loop_node =
- MastNode::new_loop(loop_body_node_id, mast_forest_builder.forest());
- mast_forest_builder.ensure_node(loop_node).map_err(AssemblyError::from)?
- };
+ let loop_node_id = mast_forest_builder
+ .ensure_loop(loop_body_node_id)
+ .map_err(AssemblyError::from)?;
mast_node_ids.push(loop_node_id);
}
}
@@ -846,8 +841,9 @@ impl Assembler {
}
Ok(if mast_node_ids.is_empty() {
- let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]);
- mast_forest_builder.ensure_node(basic_block_node).map_err(AssemblyError::from)?
+ mast_forest_builder
+ .ensure_block(vec![Operation::Noop], None)
+ .map_err(AssemblyError::from)?
} else {
combine_mast_node_ids(mast_node_ids, mast_forest_builder)?
})
@@ -907,8 +903,7 @@ fn combine_mast_node_ids(
while let (Some(left), Some(right)) =
(source_mast_node_iter.next(), source_mast_node_iter.next())
{
- let join_mast_node = MastNode::new_join(left, right, mast_forest_builder.forest());
- let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node)?;
+ let join_mast_node_id = mast_forest_builder.ensure_join(left, right)?;
mast_node_ids.push(join_mast_node_id);
}
diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs
index e627d985ab..c557276b2c 100644
--- a/assembly/src/assembler/tests.rs
+++ b/assembly/src/assembler/tests.rs
@@ -81,12 +81,10 @@ fn nested_blocks() {
// contains the MAST nodes for the kernel after a call to
// `Assembler::with_kernel_from_module()`.
let syscall_foo_node_id = {
- let kernel_foo_node = MastNode::new_basic_block(vec![Operation::Add]);
- let kernel_foo_node_id = expected_mast_forest_builder.ensure_node(kernel_foo_node).unwrap();
+ let kernel_foo_node_id =
+ expected_mast_forest_builder.ensure_block(vec![Operation::Add], None).unwrap();
- let syscall_node =
- MastNode::new_syscall(kernel_foo_node_id, expected_mast_forest_builder.forest());
- expected_mast_forest_builder.ensure_node(syscall_node).unwrap()
+ expected_mast_forest_builder.ensure_syscall(kernel_foo_node_id).unwrap()
};
let program = r#"
@@ -129,95 +127,65 @@ fn nested_blocks() {
let exec_bar_node_id = {
// bar procedure
- let basic_block_1 = MastNode::new_basic_block(vec![Operation::Push(17_u32.into())]);
- let basic_block_1_id = expected_mast_forest_builder.ensure_node(basic_block_1).unwrap();
+ let basic_block_1_id = expected_mast_forest_builder
+ .ensure_block(vec![Operation::Push(17_u32.into())], None)
+ .unwrap();
// Basic block representing the `foo` procedure
- let basic_block_2 = MastNode::new_basic_block(vec![Operation::Push(19_u32.into())]);
- let basic_block_2_id = expected_mast_forest_builder.ensure_node(basic_block_2).unwrap();
-
- let join_node = MastNode::new_join(
- basic_block_1_id,
- basic_block_2_id,
- expected_mast_forest_builder.forest(),
- );
- expected_mast_forest_builder.ensure_node(join_node).unwrap()
- };
+ let basic_block_2_id = expected_mast_forest_builder
+ .ensure_block(vec![Operation::Push(19_u32.into())], None)
+ .unwrap();
- let exec_foo_bar_baz_node_id = {
- // basic block representing foo::bar.baz procedure
- let basic_block = MastNode::new_basic_block(vec![Operation::Push(29_u32.into())]);
- expected_mast_forest_builder.ensure_node(basic_block).unwrap()
+ expected_mast_forest_builder
+ .ensure_join(basic_block_1_id, basic_block_2_id)
+ .unwrap()
};
- let before = {
- let before_node = MastNode::new_basic_block(vec![Operation::Push(2u32.into())]);
- expected_mast_forest_builder.ensure_node(before_node).unwrap()
- };
+ // basic block representing foo::bar.baz procedure
+ let exec_foo_bar_baz_node_id = expected_mast_forest_builder
+ .ensure_block(vec![Operation::Push(29_u32.into())], None)
+ .unwrap();
- let r#true1 = {
- let r#true_node = MastNode::new_basic_block(vec![Operation::Push(3u32.into())]);
- expected_mast_forest_builder.ensure_node(r#true_node).unwrap()
- };
- let r#false1 = {
- let r#false_node = MastNode::new_basic_block(vec![Operation::Push(5u32.into())]);
- expected_mast_forest_builder.ensure_node(r#false_node).unwrap()
- };
- let r#if1 = {
- let r#if_node =
- MastNode::new_split(r#true1, r#false1, expected_mast_forest_builder.forest());
- expected_mast_forest_builder.ensure_node(r#if_node).unwrap()
- };
+ let before = expected_mast_forest_builder
+ .ensure_block(vec![Operation::Push(2u32.into())], None)
+ .unwrap();
- let r#true3 = {
- let r#true_node = MastNode::new_basic_block(vec![Operation::Push(7u32.into())]);
- expected_mast_forest_builder.ensure_node(r#true_node).unwrap()
- };
- let r#false3 = {
- let r#false_node = MastNode::new_basic_block(vec![Operation::Push(11u32.into())]);
- expected_mast_forest_builder.ensure_node(r#false_node).unwrap()
- };
- let r#true2 = {
- let r#if_node =
- MastNode::new_split(r#true3, r#false3, expected_mast_forest_builder.forest());
- expected_mast_forest_builder.ensure_node(r#if_node).unwrap()
- };
+ let r#true1 = expected_mast_forest_builder
+ .ensure_block(vec![Operation::Push(3u32.into())], None)
+ .unwrap();
+ let r#false1 = expected_mast_forest_builder
+ .ensure_block(vec![Operation::Push(5u32.into())], None)
+ .unwrap();
+ let r#if1 = expected_mast_forest_builder.ensure_split(r#true1, r#false1).unwrap();
+
+ let r#true3 = expected_mast_forest_builder
+ .ensure_block(vec![Operation::Push(7u32.into())], None)
+ .unwrap();
+ let r#false3 = expected_mast_forest_builder
+ .ensure_block(vec![Operation::Push(11u32.into())], None)
+ .unwrap();
+ let r#true2 = expected_mast_forest_builder.ensure_split(r#true3, r#false3).unwrap();
let r#while = {
let push_basic_block_id = {
- let push_basic_block = MastNode::new_basic_block(vec![Operation::Push(23u32.into())]);
- expected_mast_forest_builder.ensure_node(push_basic_block).unwrap()
- };
- let body_node_id = {
- let body_node = MastNode::new_join(
- exec_bar_node_id,
- push_basic_block_id,
- expected_mast_forest_builder.forest(),
- );
-
- expected_mast_forest_builder.ensure_node(body_node).unwrap()
+ expected_mast_forest_builder
+ .ensure_block(vec![Operation::Push(23u32.into())], None)
+ .unwrap()
};
+ let body_node_id = expected_mast_forest_builder
+ .ensure_join(exec_bar_node_id, push_basic_block_id)
+ .unwrap();
- let loop_node = MastNode::new_loop(body_node_id, expected_mast_forest_builder.forest());
- expected_mast_forest_builder.ensure_node(loop_node).unwrap()
- };
- let push_13_basic_block_id = {
- let node = MastNode::new_basic_block(vec![Operation::Push(13u32.into())]);
- expected_mast_forest_builder.ensure_node(node).unwrap()
+ expected_mast_forest_builder.ensure_loop(body_node_id).unwrap()
};
+ let push_13_basic_block_id = expected_mast_forest_builder
+ .ensure_block(vec![Operation::Push(13u32.into())], None)
+ .unwrap();
- let r#false2 = {
- let node = MastNode::new_join(
- push_13_basic_block_id,
- r#while,
- expected_mast_forest_builder.forest(),
- );
- expected_mast_forest_builder.ensure_node(node).unwrap()
- };
- let nested = {
- let node = MastNode::new_split(r#true2, r#false2, expected_mast_forest_builder.forest());
- expected_mast_forest_builder.ensure_node(node).unwrap()
- };
+ let r#false2 = expected_mast_forest_builder
+ .ensure_join(push_13_basic_block_id, r#while)
+ .unwrap();
+ let nested = expected_mast_forest_builder.ensure_split(r#true2, r#false2).unwrap();
let combined_node_id = combine_mast_node_ids(
vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id],
diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs
index 31cb297309..aab440d858 100644
--- a/core/src/mast/node/mod.rs
+++ b/core/src/mast/node/mod.rs
@@ -84,7 +84,7 @@ impl MastNode {
Self::Call(CallNode::new_syscall(callee, mast_forest))
}
- pub fn new_dynexec() -> Self {
+ pub fn new_dyn() -> Self {
Self::Dyn
}
diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs
index b72e1701e0..5b5a6aabdb 100644
--- a/core/src/mast/serialization/info.rs
+++ b/core/src/mast/serialization/info.rs
@@ -77,7 +77,7 @@ impl MastNodeInfo {
Ok(MastNode::new_syscall(callee_id, mast_forest))
}
- MastNodeType::Dyn => Ok(MastNode::new_dynexec()),
+ MastNodeType::Dyn => Ok(MastNode::new_dyn()),
MastNodeType::External => Ok(MastNode::new_external(self.digest)),
}?;
diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs
index 69d746ec38..e07e92296d 100644
--- a/core/src/mast/serialization/tests.rs
+++ b/core/src/mast/serialization/tests.rs
@@ -322,7 +322,7 @@ fn serialize_deserialize_all_nodes() {
mast_forest.add_node(node).unwrap()
};
let dyn_node_id = {
- let node = MastNode::new_dynexec();
+ let node = MastNode::new_dyn();
mast_forest.add_node(node).unwrap()
};
diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs
index 0cbdbd8f82..390a162e9e 100644
--- a/processor/src/decoder/tests.rs
+++ b/processor/src/decoder/tests.rs
@@ -1194,7 +1194,7 @@ fn dyn_block() {
let join_node_id = mast_forest.add_node(join_node.clone()).unwrap();
// This dyn will point to foo.
- let dyn_node = MastNode::new_dynexec();
+ let dyn_node = MastNode::new_dyn();
let dyn_node_id = mast_forest.add_node(dyn_node.clone()).unwrap();
let program_root_node = MastNode::new_join(join_node_id, dyn_node_id, &mast_forest);