Skip to content

Commit

Permalink
feat: add more functions for adding nodes to MastForest (#1412)
Browse files Browse the repository at this point in the history
  • Loading branch information
sergerad authored Jul 27, 2024
1 parent b8a767f commit 74f1c2f
Show file tree
Hide file tree
Showing 18 changed files with 272 additions and 189 deletions.
8 changes: 4 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
- Added serialization/deserialization for `MastForest` (#1370)
- Updated CI to support `CHANGELOG.md` modification checking and `no changelog` label (#1406)
- Introduced `MastForestError` to enforce `MastForest` node count invariant (#1394)
- Added functions to `MastForestBuilder` to allow ensuring of nodes with fewer LOC (#1404)
- Make `Assembler` single-use (#1409)
- Remove `ProcedureCache` from the assembler (#1411).
- Add `Assembler::assemble_library()` (#1413)
- Added functions to `MastForest` and `MastForestBuilder` to add and ensure nodes with fewer LOC (#1404, #1412)
- Made `Assembler` single-use (#1409)
- Removed `ProcedureCache` from the assembler (#1411).
- Added `Assembler::assemble_library()` (#1413)

#### Changed

Expand Down
15 changes: 10 additions & 5 deletions assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ impl MastForestBuilder {
left_child: MastNodeId,
right_child: MastNodeId,
) -> Result<MastNodeId, AssemblyError> {
self.ensure_node(MastNode::new_join(left_child, right_child, &self.mast_forest))
let join = MastNode::new_join(left_child, right_child, &self.mast_forest)?;
self.ensure_node(join)
}

/// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it.
Expand All @@ -177,22 +178,26 @@ impl MastForestBuilder {
if_branch: MastNodeId,
else_branch: MastNodeId,
) -> Result<MastNodeId, AssemblyError> {
self.ensure_node(MastNode::new_split(if_branch, else_branch, &self.mast_forest))
let split = MastNode::new_split(if_branch, else_branch, &self.mast_forest)?;
self.ensure_node(split)
}

/// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, AssemblyError> {
self.ensure_node(MastNode::new_loop(body, &self.mast_forest))
let loop_node = MastNode::new_loop(body, &self.mast_forest)?;
self.ensure_node(loop_node)
}

/// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, AssemblyError> {
self.ensure_node(MastNode::new_call(callee, &self.mast_forest))
let call = MastNode::new_call(callee, &self.mast_forest)?;
self.ensure_node(call)
}

/// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, AssemblyError> {
self.ensure_node(MastNode::new_syscall(callee, &self.mast_forest))
let syscall = MastNode::new_syscall(callee, &self.mast_forest)?;
self.ensure_node(syscall)
}

/// Adds a dyn node to the forest, and returns the [`MastNodeId`] associated with it.
Expand Down
29 changes: 7 additions & 22 deletions assembly/src/assembler/tests.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
use alloc::{boxed::Box, vec::Vec};
use pretty_assertions::assert_eq;
use vm_core::{
assert_matches,
mast::{MastForest, MastNode},
Program,
};
use vm_core::{assert_matches, mast::MastForest, Program};

use super::{Assembler, Library, Operation};
use crate::{
Expand Down Expand Up @@ -248,29 +244,18 @@ fn duplicate_nodes() {
let mut expected_mast_forest = MastForest::new();

// basic block: mul
let mul_basic_block_id = {
let node = MastNode::new_basic_block(vec![Operation::Mul]);
expected_mast_forest.add_node(node).unwrap()
};
let mul_basic_block_id = expected_mast_forest.add_block(vec![Operation::Mul], None).unwrap();

// basic block: add
let add_basic_block_id = {
let node = MastNode::new_basic_block(vec![Operation::Add]);
expected_mast_forest.add_node(node).unwrap()
};
let add_basic_block_id = expected_mast_forest.add_block(vec![Operation::Add], None).unwrap();

// inner split: `if.true add else mul end`
let inner_split_id = {
let node =
MastNode::new_split(add_basic_block_id, mul_basic_block_id, &expected_mast_forest);
expected_mast_forest.add_node(node).unwrap()
};
let inner_split_id =
expected_mast_forest.add_split(add_basic_block_id, mul_basic_block_id).unwrap();

// root: outer split
let root_id = {
let node = MastNode::new_split(mul_basic_block_id, inner_split_id, &expected_mast_forest);
expected_mast_forest.add_node(node).unwrap()
};
let root_id = expected_mast_forest.add_split(mul_basic_block_id, inner_split_id).unwrap();

expected_mast_forest.make_root(root_id);

let expected_program = Program::new(expected_mast_forest, root_id);
Expand Down
74 changes: 73 additions & 1 deletion core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub use node::{
};
use winter_utils::DeserializationError;

use crate::{DecoratorList, Operation};

mod serialization;

#[cfg(test)]
Expand Down Expand Up @@ -60,6 +62,68 @@ impl MastForest {
Ok(new_node_id)
}

/// Adds a basic block node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_block(
&mut self,
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> Result<MastNodeId, MastForestError> {
match decorators {
Some(decorators) => {
self.add_node(MastNode::new_basic_block_with_decorators(operations, decorators))
}
None => self.add_node(MastNode::new_basic_block(operations)),
}
}

/// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_join(
&mut self,
left_child: MastNodeId,
right_child: MastNodeId,
) -> Result<MastNodeId, MastForestError> {
let join = MastNode::new_join(left_child, right_child, self)?;
self.add_node(join)
}

/// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_split(
&mut self,
if_branch: MastNodeId,
else_branch: MastNodeId,
) -> Result<MastNodeId, MastForestError> {
let split = MastNode::new_split(if_branch, else_branch, self)?;
self.add_node(split)
}

/// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, MastForestError> {
let loop_node = MastNode::new_loop(body, self)?;
self.add_node(loop_node)
}

/// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
let call = MastNode::new_call(callee, self)?;
self.add_node(call)
}

/// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
let syscall = MastNode::new_syscall(callee, self)?;
self.add_node(syscall)
}

/// Adds a dyn node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_dyn(&mut self) -> Result<MastNodeId, MastForestError> {
self.add_node(MastNode::new_dyn())
}

/// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_external(&mut self, mast_root: RpoDigest) -> Result<MastNodeId, MastForestError> {
self.add_node(MastNode::new_external(mast_root))
}

/// Marks the given [`MastNodeId`] as being the root of a procedure.
///
/// # Panics
Expand Down Expand Up @@ -157,6 +221,12 @@ impl MastNodeId {
}
}

impl From<MastNodeId> for usize {
fn from(value: MastNodeId) -> Self {
value.0 as usize
}
}

impl From<MastNodeId> for u32 {
fn from(value: MastNodeId) -> Self {
value.0
Expand All @@ -179,11 +249,13 @@ impl fmt::Display for MastNodeId {
// ================================================================================================

/// Represents the types of errors that can occur when dealing with MAST forest.
#[derive(Debug, thiserror::Error)]
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum MastForestError {
#[error(
"invalid node count: MAST forest exceeds the maximum of {} nodes",
MastForest::MAX_NODES
)]
TooManyNodes,
#[error("node id: {0} is greater than or equal to forest length: {1}")]
NodeIdOverflow(MastNodeId, usize),
}
23 changes: 16 additions & 7 deletions core/src/mast/node/call_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use miden_formatting::prettier::PrettyPrint;

use crate::{
chiplets::hasher,
mast::{MastForest, MastNodeId},
mast::{MastForest, MastForestError, MastNodeId},
OPCODE_CALL, OPCODE_SYSCALL,
};

Expand Down Expand Up @@ -38,34 +38,43 @@ impl CallNode {
/// Constructors
impl CallNode {
/// Returns a new [`CallNode`] instantiated with the specified callee.
pub fn new(callee: MastNodeId, mast_forest: &MastForest) -> Self {
pub fn new(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
if usize::from(callee) >= mast_forest.nodes.len() {
return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len()));
}
let digest = {
let callee_digest = mast_forest[callee].digest();

hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::CALL_DOMAIN)
};

Self {
Ok(Self {
callee,
is_syscall: false,
digest,
}
})
}

/// Returns a new [`CallNode`] instantiated with the specified callee and marked as a kernel
/// call.
pub fn new_syscall(callee: MastNodeId, mast_forest: &MastForest) -> Self {
pub fn new_syscall(
callee: MastNodeId,
mast_forest: &MastForest,
) -> Result<Self, MastForestError> {
if usize::from(callee) >= mast_forest.nodes.len() {
return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len()));
}
let digest = {
let callee_digest = mast_forest[callee].digest();

hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::SYSCALL_DOMAIN)
};

Self {
Ok(Self {
callee,
is_syscall: true,
digest,
}
})
}
}

Expand Down
15 changes: 12 additions & 3 deletions core/src/mast/node/join_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt};

use crate::{
chiplets::hasher,
mast::{MastForest, MastNodeId},
mast::{MastForest, MastForestError, MastNodeId},
prettier::PrettyPrint,
OPCODE_JOIN,
};
Expand All @@ -29,15 +29,24 @@ impl JoinNode {
/// Constructors
impl JoinNode {
/// Returns a new [`JoinNode`] instantiated with the specified children nodes.
pub fn new(children: [MastNodeId; 2], mast_forest: &MastForest) -> Self {
pub fn new(
children: [MastNodeId; 2],
mast_forest: &MastForest,
) -> Result<Self, MastForestError> {
let forest_len = mast_forest.nodes.len();
if usize::from(children[0]) >= forest_len {
return Err(MastForestError::NodeIdOverflow(children[0], forest_len));
} else if usize::from(children[1]) >= forest_len {
return Err(MastForestError::NodeIdOverflow(children[1], forest_len));
}
let digest = {
let left_child_hash = mast_forest[children[0]].digest();
let right_child_hash = mast_forest[children[1]].digest();

hasher::merge_in_domain(&[left_child_hash, right_child_hash], Self::DOMAIN)
};

Self { children, digest }
Ok(Self { children, digest })
}

#[cfg(test)]
Expand Down
9 changes: 6 additions & 3 deletions core/src/mast/node/loop_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use miden_formatting::prettier::PrettyPrint;

use crate::{
chiplets::hasher,
mast::{MastForest, MastNodeId},
mast::{MastForest, MastForestError, MastNodeId},
OPCODE_LOOP,
};

Expand All @@ -32,14 +32,17 @@ impl LoopNode {

/// Constructors
impl LoopNode {
pub fn new(body: MastNodeId, mast_forest: &MastForest) -> Self {
pub fn new(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
if usize::from(body) >= mast_forest.nodes.len() {
return Err(MastForestError::NodeIdOverflow(body, mast_forest.nodes.len()));
}
let digest = {
let body_hash = mast_forest[body].digest();

hasher::merge_in_domain(&[body_hash, RpoDigest::default()], Self::DOMAIN)
};

Self { body, digest }
Ok(Self { body, digest })
}
}

Expand Down
30 changes: 20 additions & 10 deletions core/src/mast/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ use crate::{
DecoratorList, Operation,
};

use super::MastForestError;

// MAST NODE
// ================================================================================================

Expand Down Expand Up @@ -64,28 +66,36 @@ impl MastNode {
left_child: MastNodeId,
right_child: MastNodeId,
mast_forest: &MastForest,
) -> Self {
Self::Join(JoinNode::new([left_child, right_child], mast_forest))
) -> Result<Self, MastForestError> {
let join = JoinNode::new([left_child, right_child], mast_forest)?;
Ok(Self::Join(join))
}

pub fn new_split(
if_branch: MastNodeId,
else_branch: MastNodeId,
mast_forest: &MastForest,
) -> Self {
Self::Split(SplitNode::new([if_branch, else_branch], mast_forest))
) -> Result<Self, MastForestError> {
let split = SplitNode::new([if_branch, else_branch], mast_forest)?;
Ok(Self::Split(split))
}

pub fn new_loop(body: MastNodeId, mast_forest: &MastForest) -> Self {
Self::Loop(LoopNode::new(body, mast_forest))
pub fn new_loop(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
let loop_node = LoopNode::new(body, mast_forest)?;
Ok(Self::Loop(loop_node))
}

pub fn new_call(callee: MastNodeId, mast_forest: &MastForest) -> Self {
Self::Call(CallNode::new(callee, mast_forest))
pub fn new_call(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
let call = CallNode::new(callee, mast_forest)?;
Ok(Self::Call(call))
}

pub fn new_syscall(callee: MastNodeId, mast_forest: &MastForest) -> Self {
Self::Call(CallNode::new_syscall(callee, mast_forest))
pub fn new_syscall(
callee: MastNodeId,
mast_forest: &MastForest,
) -> Result<Self, MastForestError> {
let syscall = CallNode::new_syscall(callee, mast_forest)?;
Ok(Self::Call(syscall))
}

pub fn new_dyn() -> Self {
Expand Down
Loading

0 comments on commit 74f1c2f

Please sign in to comment.