Skip to content

Commit

Permalink
feat: add more functions for ensuring nodes via MastForestBuilder (#1404
Browse files Browse the repository at this point in the history
)
  • Loading branch information
sergerad authored Jul 20, 2024
1 parent 14995b4 commit 64c7401
Show file tree
Hide file tree
Showing 10 changed files with 139 additions and 126 deletions.
9 changes: 5 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

- Added error codes support for the `mtree_verify` instruction (#1328).
- Added support for immediate values for `lt`, `lte`, `gt`, `gte` comparison instructions (#1346).
- Change MAST to a table-based representation (#1349)
- Introduce `MastForestStore` (#1359)
- Changed MAST to a table-based representation (#1349)
- Introduced `MastForestStore` (#1359)
- Adjusted prover's metal acceleration code to work with 0.9 versions of the crates (#1357)
- Added support for immediate values for `u32lt`, `u32lte`, `u32gt`, `u32gte`, `u32min` and `u32max` comparison instructions (#1358).
- Added support for the `nop` instruction, which corresponds to the VM opcode of the same name, and has the same semantics. This is implemented for use by compilers primarily.
Expand All @@ -16,9 +16,10 @@
- Added support for immediate values for `u32and`, `u32or`, `u32xor` and `u32not` bitwise instructions (#1362).
- Optimized `std::sys::truncate_stuck` procedure (#1384).
- Updated CI and Makefile to standardise it accross Miden repositories (#1342).
- Add serialization/deserialization for `MastForest` (#1370)
- Added serialization/deserialization for `MastForest` (#1370)
- Updated CI to support `CHANGELOG.md` modification checking and `no changelog` label (#1406)
- Introduce `MastForestError` to enforce `MastForest` node count invariant (#1394)
- Introduced `MastForestError` to enforce `MastForest` node count invariant (#1394)
- Added functions to `MastForestBuilder` to allow ensuring of nodes with fewer LOC (#1404)

#### Changed

Expand Down
5 changes: 2 additions & 3 deletions assembly/src/assembler/basic_block_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::{
};
use alloc::{borrow::Borrow, string::ToString, vec::Vec};
use vm_core::{
mast::{MastForestError, MastNode, MastNodeId},
mast::{MastForestError, MastNodeId},
AdviceInjector, AssemblyOp, Operation,
};

Expand Down Expand Up @@ -134,8 +134,7 @@ impl BasicBlockBuilder {
let ops = self.ops.drain(..).collect();
let decorators = self.decorators.drain(..).collect();

let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators);
let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node)?;
let basic_block_node_id = mast_forest_builder.ensure_block(ops, Some(decorators))?;

Ok(Some(basic_block_node_id))
} else if !self.decorators.is_empty() {
Expand Down
26 changes: 9 additions & 17 deletions assembly/src/assembler/instruction/procedures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
};

use smallvec::SmallVec;
use vm_core::mast::{MastForest, MastNode, MastNodeId};
use vm_core::mast::{MastForest, MastNodeId};

/// Procedure Invocation
impl Assembler {
Expand Down Expand Up @@ -96,8 +96,7 @@ impl Assembler {
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)?
mast_forest_builder.ensure_external(mast_root)?
}
}
}
Expand All @@ -107,28 +106,23 @@ impl Assembler {
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)?
mast_forest_builder.ensure_external(mast_root)?
}
};

let call_node = MastNode::new_call(callee_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(call_node)?
mast_forest_builder.ensure_call(callee_id)?
}
InvokeKind::SysCall => {
let callee_id = match mast_forest_builder.find_procedure_root(mast_root) {
Some(callee_id) => callee_id,
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)?
mast_forest_builder.ensure_external(mast_root)?
}
};

let syscall_node =
MastNode::new_syscall(callee_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(syscall_node)?
mast_forest_builder.ensure_syscall(callee_id)?
}
}
};
Expand All @@ -141,7 +135,7 @@ impl Assembler {
&self,
mast_forest_builder: &mut MastForestBuilder,
) -> Result<Option<MastNodeId>, AssemblyError> {
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?;
let dyn_node_id = mast_forest_builder.ensure_dyn()?;

Ok(Some(dyn_node_id))
}
Expand All @@ -152,10 +146,8 @@ impl Assembler {
mast_forest_builder: &mut MastForestBuilder,
) -> Result<Option<MastNodeId>, AssemblyError> {
let dyn_call_node_id = {
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?;
let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest_builder.forest());

mast_forest_builder.ensure_node(dyn_call_node)?
let dyn_node_id = mast_forest_builder.ensure_dyn()?;
mast_forest_builder.ensure_call(dyn_node_id)?
};

Ok(Some(dyn_call_node_id))
Expand Down
62 changes: 60 additions & 2 deletions assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use core::ops::Index;

use alloc::collections::BTreeMap;
use alloc::{collections::BTreeMap, vec::Vec};
use vm_core::{
crypto::hash::RpoDigest,
mast::{MastForest, MastForestError, MastNode, MastNodeId, MerkleTreeNode},
DecoratorList, Operation,
};

/// Builder for a [`MastForest`].
Expand Down Expand Up @@ -44,7 +45,7 @@ impl MastForestBuilder {
/// If a [`MastNode`] which is equal to the current node was previously added, the previously
/// returned [`MastNodeId`] will be returned. This enforces this invariant that equal
/// [`MastNode`]s have equal [`MastNodeId`]s.
pub fn ensure_node(&mut self, node: MastNode) -> Result<MastNodeId, MastForestError> {
fn ensure_node(&mut self, node: MastNode) -> Result<MastNodeId, MastForestError> {
let node_digest = node.digest();

if let Some(node_id) = self.node_id_by_hash.get(&node_digest) {
Expand All @@ -58,6 +59,63 @@ impl MastForestBuilder {
}
}

/// Adds a basic block node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_block(
&mut self,
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> Result<MastNodeId, MastForestError> {
match decorators {
Some(decorators) => {
self.ensure_node(MastNode::new_basic_block_with_decorators(operations, decorators))
}
None => self.ensure_node(MastNode::new_basic_block(operations)),
}
}

/// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_join(
&mut self,
left_child: MastNodeId,
right_child: MastNodeId,
) -> Result<MastNodeId, MastForestError> {
self.ensure_node(MastNode::new_join(left_child, right_child, &self.mast_forest))
}

/// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_split(
&mut self,
if_branch: MastNodeId,
else_branch: MastNodeId,
) -> Result<MastNodeId, MastForestError> {
self.ensure_node(MastNode::new_split(if_branch, else_branch, &self.mast_forest))
}

/// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, MastForestError> {
self.ensure_node(MastNode::new_loop(body, &self.mast_forest))
}

/// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
self.ensure_node(MastNode::new_call(callee, &self.mast_forest))
}

/// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
self.ensure_node(MastNode::new_syscall(callee, &self.mast_forest))
}

/// Adds a dynexec node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_dyn(&mut self) -> Result<MastNodeId, MastForestError> {
self.ensure_node(MastNode::new_dyn())
}

/// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_external(&mut self, mast_root: RpoDigest) -> Result<MastNodeId, MastForestError> {
self.ensure_node(MastNode::new_external(mast_root))
}

/// Marks the given [`MastNodeId`] as being the root of a procedure.
pub fn make_root(&mut self, new_root_id: MastNodeId) {
self.mast_forest.make_root(new_root_id)
Expand Down
27 changes: 11 additions & 16 deletions assembly/src/assembler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use mast_forest_builder::MastForestBuilder;
use vm_core::{
mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode},
mast::{MastForest, MastNodeId, MerkleTreeNode},
Decorator, DecoratorList, Kernel, Operation, Program,
};

Expand Down Expand Up @@ -792,12 +792,9 @@ impl Assembler {
let else_blk =
self.compile_body(else_blk.iter(), context, None, mast_forest_builder)?;

let split_node_id = {
let split_node =
MastNode::new_split(then_blk, else_blk, mast_forest_builder.forest());

mast_forest_builder.ensure_node(split_node).map_err(AssemblyError::from)?
};
let split_node_id = mast_forest_builder
.ensure_split(then_blk, else_blk)
.map_err(AssemblyError::from)?;
mast_node_ids.push(split_node_id);
}

Expand Down Expand Up @@ -828,11 +825,9 @@ impl Assembler {
let loop_body_node_id =
self.compile_body(body.iter(), context, None, mast_forest_builder)?;

let loop_node_id = {
let loop_node =
MastNode::new_loop(loop_body_node_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(loop_node).map_err(AssemblyError::from)?
};
let loop_node_id = mast_forest_builder
.ensure_loop(loop_body_node_id)
.map_err(AssemblyError::from)?;
mast_node_ids.push(loop_node_id);
}
}
Expand All @@ -846,8 +841,9 @@ impl Assembler {
}

Ok(if mast_node_ids.is_empty() {
let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]);
mast_forest_builder.ensure_node(basic_block_node).map_err(AssemblyError::from)?
mast_forest_builder
.ensure_block(vec![Operation::Noop], None)
.map_err(AssemblyError::from)?
} else {
combine_mast_node_ids(mast_node_ids, mast_forest_builder)?
})
Expand Down Expand Up @@ -907,8 +903,7 @@ fn combine_mast_node_ids(
while let (Some(left), Some(right)) =
(source_mast_node_iter.next(), source_mast_node_iter.next())
{
let join_mast_node = MastNode::new_join(left, right, mast_forest_builder.forest());
let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node)?;
let join_mast_node_id = mast_forest_builder.ensure_join(left, right)?;

mast_node_ids.push(join_mast_node_id);
}
Expand Down
Loading

0 comments on commit 64c7401

Please sign in to comment.