Skip to content

Commit

Permalink
add helper methods for adding nodes to MastForest
Browse files Browse the repository at this point in the history
  • Loading branch information
sergerad committed Jul 24, 2024
1 parent 4f0dbf2 commit 67030ee
Show file tree
Hide file tree
Showing 10 changed files with 111 additions and 116 deletions.
23 changes: 6 additions & 17 deletions assembly/src/assembler/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,29 +248,18 @@ fn duplicate_nodes() {
let mut expected_mast_forest = MastForest::new();

// basic block: mul
let mul_basic_block_id = {
let node = MastNode::new_basic_block(vec![Operation::Mul]);
expected_mast_forest.add_node(node).unwrap()
};
let mul_basic_block_id = expected_mast_forest.add_block(vec![Operation::Mul], None).unwrap();

// basic block: add
let add_basic_block_id = {
let node = MastNode::new_basic_block(vec![Operation::Add]);
expected_mast_forest.add_node(node).unwrap()
};
let add_basic_block_id = expected_mast_forest.add_block(vec![Operation::Add], None).unwrap();

// inner split: `if.true add else mul end`
let inner_split_id = {
let node =
MastNode::new_split(add_basic_block_id, mul_basic_block_id, &expected_mast_forest);
expected_mast_forest.add_node(node).unwrap()
};
let inner_split_id =
expected_mast_forest.add_split(add_basic_block_id, mul_basic_block_id).unwrap();

// root: outer split
let root_id = {
let node = MastNode::new_split(mul_basic_block_id, inner_split_id, &expected_mast_forest);
expected_mast_forest.add_node(node).unwrap()
};
let root_id = expected_mast_forest.add_split(mul_basic_block_id, inner_split_id).unwrap();

expected_mast_forest.make_root(root_id);

let expected_program = Program::new(expected_mast_forest, root_id);
Expand Down
59 changes: 59 additions & 0 deletions core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub use node::{
};
use winter_utils::DeserializationError;

use crate::{DecoratorList, Operation};

mod serialization;

#[cfg(test)]
Expand Down Expand Up @@ -60,6 +62,63 @@ impl MastForest {
Ok(new_node_id)
}

/// Adds a basic block node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_block(
&mut self,
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> Result<MastNodeId, MastForestError> {
match decorators {
Some(decorators) => {
self.add_node(MastNode::new_basic_block_with_decorators(operations, decorators))
}
None => self.add_node(MastNode::new_basic_block(operations)),
}
}

/// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_join(
&mut self,
left_child: MastNodeId,
right_child: MastNodeId,
) -> Result<MastNodeId, MastForestError> {
self.add_node(MastNode::new_join(left_child, right_child, self))
}

/// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_split(
&mut self,
if_branch: MastNodeId,
else_branch: MastNodeId,
) -> Result<MastNodeId, MastForestError> {
self.add_node(MastNode::new_split(if_branch, else_branch, self))
}

/// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, MastForestError> {
self.add_node(MastNode::new_loop(body, self))
}

/// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
self.add_node(MastNode::new_call(callee, self))
}

/// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
self.add_node(MastNode::new_syscall(callee, self))
}

/// Adds a dyn node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_dyn(&mut self) -> Result<MastNodeId, MastForestError> {
self.add_node(MastNode::new_dyn())
}

/// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_external(&mut self, mast_root: RpoDigest) -> Result<MastNodeId, MastForestError> {
self.add_node(MastNode::new_external(mast_root))
}

/// Marks the given [`MastNodeId`] as being the root of a procedure.
///
/// # Panics
Expand Down
38 changes: 8 additions & 30 deletions core/src/mast/serialization/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,41 +295,19 @@ fn serialize_deserialize_all_nodes() {
(num_operations, Decorator::Trace(55)),
];

let basic_block_node = MastNode::new_basic_block_with_decorators(operations, decorators);
mast_forest.add_node(basic_block_node).unwrap()
mast_forest.add_block(operations, Some(decorators)).unwrap()
};

let call_node_id = {
let node = MastNode::new_call(basic_block_id, &mast_forest);
mast_forest.add_node(node).unwrap()
};
let call_node_id = mast_forest.add_call(basic_block_id).unwrap();

let syscall_node_id = {
let node = MastNode::new_syscall(basic_block_id, &mast_forest);
mast_forest.add_node(node).unwrap()
};
let syscall_node_id = mast_forest.add_syscall(basic_block_id).unwrap();

let loop_node_id = {
let node = MastNode::new_loop(basic_block_id, &mast_forest);
mast_forest.add_node(node).unwrap()
};
let join_node_id = {
let node = MastNode::new_join(basic_block_id, call_node_id, &mast_forest);
mast_forest.add_node(node).unwrap()
};
let split_node_id = {
let node = MastNode::new_split(basic_block_id, call_node_id, &mast_forest);
mast_forest.add_node(node).unwrap()
};
let dyn_node_id = {
let node = MastNode::new_dyn();
mast_forest.add_node(node).unwrap()
};
let loop_node_id = mast_forest.add_loop(basic_block_id).unwrap();
let join_node_id = mast_forest.add_join(basic_block_id, call_node_id).unwrap();
let split_node_id = mast_forest.add_split(basic_block_id, call_node_id).unwrap();
let dyn_node_id = mast_forest.add_dyn().unwrap();

let external_node_id = {
let node = MastNode::new_external(RpoDigest::default());
mast_forest.add_node(node).unwrap()
};
let external_node_id = mast_forest.add_external(RpoDigest::default()).unwrap();

mast_forest.make_root(join_node_id);
mast_forest.make_root(syscall_node_id);
Expand Down
3 changes: 1 addition & 2 deletions miden/tests/integration/operations/io_ops/env_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ fn caller() {
fn build_bar_hash() -> [u64; 4] {
let mut mast_forest = MastForest::new();

let foo_root = MastNode::new_basic_block(vec![Operation::Caller]);
let foo_root_id = mast_forest.add_node(foo_root).unwrap();
let foo_root_id = mast_forest.add_block(vec![Operation::Caller], None).unwrap();

let bar_root = MastNode::new_syscall(foo_root_id, &mast_forest);
let bar_hash: Word = bar_root.digest().into();
Expand Down
5 changes: 3 additions & 2 deletions processor/src/chiplets/hasher/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,9 @@ fn hash_memoization_basic_blocks_check(basic_block: MastNode) {
let basic_block_1 = basic_block.clone();
let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()).unwrap();

let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Eq, Operation::Not]);
let loop_body_id = mast_forest.add_node(loop_body).unwrap();
let loop_body_id = mast_forest
.add_block(vec![Operation::Pad, Operation::Eq, Operation::Not], None)
.unwrap();

let loop_block = MastNode::new_loop(loop_body_id, &mast_forest);
let loop_block_id = mast_forest.add_node(loop_block.clone()).unwrap();
Expand Down
3 changes: 1 addition & 2 deletions processor/src/chiplets/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ fn build_trace(
let program = {
let mut mast_forest = MastForest::new();

let basic_block = MastNode::new_basic_block(operations);
let basic_block_id = mast_forest.add_node(basic_block).unwrap();
let basic_block_id = mast_forest.add_block(operations, None).unwrap();

Program::new(mast_forest, basic_block_id)
};
Expand Down
24 changes: 8 additions & 16 deletions processor/src/decoder/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,7 @@ fn join_node() {
let basic_block1_id = mast_forest.add_node(basic_block1.clone()).unwrap();
let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap();

let join_node = MastNode::new_join(basic_block1_id, basic_block2_id, &mast_forest);
let join_node_id = mast_forest.add_node(join_node).unwrap();
let join_node_id = mast_forest.add_join(basic_block1_id, basic_block2_id).unwrap();

Program::new(mast_forest, join_node_id)
};
Expand Down Expand Up @@ -393,8 +392,7 @@ fn split_node_true() {
let basic_block1_id = mast_forest.add_node(basic_block1.clone()).unwrap();
let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap();

let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest);
let split_node_id = mast_forest.add_node(split_node).unwrap();
let split_node_id = mast_forest.add_split(basic_block1_id, basic_block2_id).unwrap();

Program::new(mast_forest, split_node_id)
};
Expand Down Expand Up @@ -446,8 +444,7 @@ fn split_node_false() {
let basic_block1_id = mast_forest.add_node(basic_block1.clone()).unwrap();
let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap();

let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest);
let split_node_id = mast_forest.add_node(split_node).unwrap();
let split_node_id = mast_forest.add_split(basic_block1_id, basic_block2_id).unwrap();

Program::new(mast_forest, split_node_id)
};
Expand Down Expand Up @@ -500,8 +497,7 @@ fn loop_node() {

let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap();

let loop_node = MastNode::new_loop(loop_body_id, &mast_forest);
let loop_node_id = mast_forest.add_node(loop_node).unwrap();
let loop_node_id = mast_forest.add_loop(loop_body_id).unwrap();

Program::new(mast_forest, loop_node_id)
};
Expand Down Expand Up @@ -553,8 +549,7 @@ fn loop_node_skip() {

let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap();

let loop_node = MastNode::new_loop(loop_body_id, &mast_forest);
let loop_node_id = mast_forest.add_node(loop_node).unwrap();
let loop_node_id = mast_forest.add_loop(loop_body_id).unwrap();

Program::new(mast_forest, loop_node_id)
};
Expand Down Expand Up @@ -596,8 +591,7 @@ fn loop_node_repeat() {

let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap();

let loop_node = MastNode::new_loop(loop_body_id, &mast_forest);
let loop_node_id = mast_forest.add_node(loop_node).unwrap();
let loop_node_id = mast_forest.add_loop(loop_body_id).unwrap();

Program::new(mast_forest, loop_node_id)
};
Expand Down Expand Up @@ -699,8 +693,7 @@ fn call_block() {
let join1_node = MastNode::new_join(first_basic_block_id, foo_call_node_id, &mast_forest);
let join1_node_id = mast_forest.add_node(join1_node.clone()).unwrap();

let program_root = MastNode::new_join(join1_node_id, last_basic_block_id, &mast_forest);
let program_root_id = mast_forest.add_node(program_root).unwrap();
let program_root_id = mast_forest.add_join(join1_node_id, last_basic_block_id).unwrap();

let program = Program::new(mast_forest, program_root_id);

Expand Down Expand Up @@ -1305,8 +1298,7 @@ fn set_user_op_helpers_many() {
let program = {
let mut mast_forest = MastForest::new();

let basic_block = MastNode::new_basic_block(vec![Operation::U32div]);
let basic_block_id = mast_forest.add_node(basic_block).unwrap();
let basic_block_id = mast_forest.add_block(vec![Operation::U32div], None).unwrap();

Program::new(mast_forest, basic_block_id)
};
Expand Down
21 changes: 8 additions & 13 deletions processor/src/trace/tests/chiplets/hasher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use miden_air::trace::{
use vm_core::{
chiplets::hasher::apply_permutation,
crypto::merkle::{MerkleStore, MerkleTree, NodeIndex},
mast::{MastForest, MastNode},
mast::MastForest,
utils::range,
Program, Word,
};
Expand Down Expand Up @@ -50,8 +50,8 @@ pub fn b_chip_span() {
let program = {
let mut mast_forest = MastForest::new();

let basic_block = MastNode::new_basic_block(vec![Operation::Add, Operation::Mul]);
let basic_block_id = mast_forest.add_node(basic_block).unwrap();
let basic_block_id =
mast_forest.add_block(vec![Operation::Add, Operation::Mul], None).unwrap();

Program::new(mast_forest, basic_block_id)
};
Expand Down Expand Up @@ -123,8 +123,7 @@ pub fn b_chip_span_with_respan() {
let mut mast_forest = MastForest::new();

let (ops, _) = build_span_with_respan_ops();
let basic_block = MastNode::new_basic_block(ops);
let basic_block_id = mast_forest.add_node(basic_block).unwrap();
let basic_block_id = mast_forest.add_block(ops, None).unwrap();

Program::new(mast_forest, basic_block_id)
};
Expand Down Expand Up @@ -215,14 +214,11 @@ pub fn b_chip_merge() {
let program = {
let mut mast_forest = MastForest::new();

let t_branch = MastNode::new_basic_block(vec![Operation::Add]);
let t_branch_id = mast_forest.add_node(t_branch).unwrap();
let t_branch_id = mast_forest.add_block(vec![Operation::Add], None).unwrap();

let f_branch = MastNode::new_basic_block(vec![Operation::Mul]);
let f_branch_id = mast_forest.add_node(f_branch).unwrap();
let f_branch_id = mast_forest.add_block(vec![Operation::Mul], None).unwrap();

let split = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest);
let split_id = mast_forest.add_node(split).unwrap();
let split_id = mast_forest.add_split(t_branch_id, f_branch_id).unwrap();

Program::new(mast_forest, split_id)
};
Expand Down Expand Up @@ -334,8 +330,7 @@ pub fn b_chip_permutation() {
let program = {
let mut mast_forest = MastForest::new();

let basic_block = MastNode::new_basic_block(vec![Operation::HPerm]);
let basic_block_id = mast_forest.add_node(basic_block).unwrap();
let basic_block_id = mast_forest.add_block(vec![Operation::HPerm], None).unwrap();

Program::new(mast_forest, basic_block_id)
};
Expand Down
Loading

0 comments on commit 67030ee

Please sign in to comment.