Skip to content

Commit

Permalink
replace basic block constructor and add MastNodeError
Browse files Browse the repository at this point in the history
  • Loading branch information
sergerad committed Aug 17, 2024
1 parent e7bcf21 commit fad447f
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 82 deletions.
8 changes: 2 additions & 6 deletions assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,8 @@ impl MastForestBuilder {
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> Result<MastNodeId, AssemblyError> {
match decorators {
Some(decorators) => {
self.ensure_node(MastNode::new_basic_block_with_decorators(operations, decorators))
},
None => self.ensure_node(MastNode::new_basic_block(operations)),
}
let block = MastNode::new_basic_block(operations, decorators)?;
self.ensure_node(block)
}

/// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it.
Expand Down
4 changes: 3 additions & 1 deletion assembly/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::{string::String, sync::Arc, vec::Vec};

use vm_core::mast::MastForestError;
use vm_core::mast::{MastForestError, MastNodeError};

use crate::{
ast::QualifiedProcedureName,
Expand Down Expand Up @@ -76,6 +76,8 @@ pub enum AssemblyError {
Other(#[from] RelatedError),
#[error(transparent)]
Forest(#[from] MastForestError),
#[error(transparent)]
Node(#[from] MastNodeError),
}

impl From<Report> for AssemblyError {
Expand Down
20 changes: 14 additions & 6 deletions core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,8 @@ impl MastForest {
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> Result<MastNodeId, MastForestError> {
match decorators {
Some(decorators) => {
self.add_node(MastNode::new_basic_block_with_decorators(operations, decorators))
},
None => self.add_node(MastNode::new_basic_block(operations)),
}
let block = MastNode::new_basic_block(operations, decorators)?;
self.add_node(block)
}

/// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it.
Expand Down Expand Up @@ -271,4 +267,16 @@ pub enum MastForestError {
TooManyNodes,
#[error("node id: {0} is greater than or equal to forest length: {1}")]
NodeIdOverflow(MastNodeId, usize),
#[error(transparent)]
Node(#[from] MastNodeError),
}

// MAST NODE ERROR
// ================================================================================================

/// Represents the types of errors that can occur when dealing with MAST nodes directly.
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum MastNodeError {
#[error("basic block cannot be created from an empty list of operations")]
EmptyBasicBlock,
}
45 changes: 21 additions & 24 deletions core/src/mast/node/basic_block_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt, ZERO};
use miden_formatting::prettier::PrettyPrint;
use winter_utils::flatten_slice_elements;

use crate::{chiplets::hasher, Decorator, DecoratorIterator, DecoratorList, Operation};
use crate::{
chiplets::hasher, mast::MastNodeError, Decorator, DecoratorIterator, DecoratorList, Operation,
};

mod op_batch;
pub use op_batch::OpBatch;
Expand Down Expand Up @@ -77,15 +79,27 @@ impl BasicBlockNode {
// ------------------------------------------------------------------------------------------------
/// Constructors
impl BasicBlockNode {
/// Returns a new [`BasicBlockNode`] instantiated with the specified operations.
/// Returns a new [`BasicBlockNode`] instantiated with the specified operations and decorators.
///
/// # Errors (TODO)
/// Returns an error if:
/// - `operations` vector is empty.
/// - `operations` vector contains any number of system operations.
pub fn new(operations: Vec<Operation>) -> Self {
assert!(!operations.is_empty()); // TODO: return error
Self::with_decorators(operations, DecoratorList::new())
pub fn new(
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> Result<Self, MastNodeError> {
if operations.is_empty() {
return Err(MastNodeError::EmptyBasicBlock);
}

// None is equivalent to an empty list of decorators moving forward.
let decorators = decorators.unwrap_or_default();

// Validate decorators list (only in debug mode).
#[cfg(debug_assertions)]
validate_decorators(&operations, &decorators);

let (op_batches, digest) = batch_and_hash_ops(operations);
Ok(Self { op_batches, digest, decorators })
}

/// Returns a new [`BasicBlockNode`] from values that are assumed to be correct.
Expand All @@ -98,23 +112,6 @@ impl BasicBlockNode {
let (op_batches, _) = batch_ops(operations);
Self { op_batches, digest, decorators }
}

/// Returns a new [`BasicBlockNode`] instantiated with the specified operations and decorators.
///
/// # Errors (TODO)
/// Returns an error if:
/// - `operations` vector is empty.
/// - `operations` vector contains any number of system operations.
pub fn with_decorators(operations: Vec<Operation>, decorators: DecoratorList) -> Self {
assert!(!operations.is_empty()); // TODO: return error

// validate decorators list (only in debug mode)
#[cfg(debug_assertions)]
validate_decorators(&operations, &decorators);

let (op_batches, digest) = batch_and_hash_ops(operations);
Self { op_batches, digest, decorators }
}
}

// ------------------------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion core/src/mast/node/basic_block_node/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ fn operation_or_decorator_iterator() {
(4, Decorator::Event(4)),
];

let node = BasicBlockNode::with_decorators(operations, decorators);
let node = BasicBlockNode::new(operations, Some(decorators)).unwrap();

let mut iterator = node.iter();

Expand Down
15 changes: 6 additions & 9 deletions core/src/mast/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub use split_node::SplitNode;
mod loop_node;
pub use loop_node::LoopNode;

use super::MastForestError;
use super::{MastForestError, MastNodeError};
use crate::{
mast::{MastForest, MastNodeId},
DecoratorList, Operation,
Expand All @@ -50,15 +50,12 @@ pub enum MastNode {
// ------------------------------------------------------------------------------------------------
/// Constructors
impl MastNode {
pub fn new_basic_block(operations: Vec<Operation>) -> Self {
Self::Block(BasicBlockNode::new(operations))
}

pub fn new_basic_block_with_decorators(
pub fn new_basic_block(
operations: Vec<Operation>,
decorators: DecoratorList,
) -> Self {
Self::Block(BasicBlockNode::with_decorators(operations, decorators))
decorators: Option<DecoratorList>,
) -> Result<Self, MastNodeError> {
let block = BasicBlockNode::new(operations, decorators)?;
Ok(Self::Block(block))
}

pub fn new_join(
Expand Down
12 changes: 7 additions & 5 deletions processor/src/chiplets/hasher/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,10 @@ fn hash_memoization_control_blocks() {

let mut mast_forest = MastForest::new();

let t_branch = MastNode::new_basic_block(vec![Operation::Push(ZERO)]);
let t_branch = MastNode::new_basic_block(vec![Operation::Push(ZERO)], None).unwrap();
let t_branch_id = mast_forest.add_node(t_branch.clone()).unwrap();

let f_branch = MastNode::new_basic_block(vec![Operation::Push(ONE)]);
let f_branch = MastNode::new_basic_block(vec![Operation::Push(ONE)], None).unwrap();
let f_branch_id = mast_forest.add_node(f_branch.clone()).unwrap();

let split1 = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest).unwrap();
Expand Down Expand Up @@ -350,12 +350,13 @@ fn hash_memoization_control_blocks() {
fn hash_memoization_basic_blocks() {
// --- basic block with 1 batch ----------------------------------------------------------------
let basic_block =
MastNode::new_basic_block(vec![Operation::Push(Felt::new(10)), Operation::Drop]);
MastNode::new_basic_block(vec![Operation::Push(Felt::new(10)), Operation::Drop], None)
.unwrap();

hash_memoization_basic_blocks_check(basic_block);

// --- basic block with multiple batches -------------------------------------------------------
let basic_block = MastNode::new_basic_block(vec![
let ops = vec![
Operation::Push(ONE),
Operation::Push(Felt::new(2)),
Operation::Push(Felt::new(3)),
Expand Down Expand Up @@ -392,7 +393,8 @@ fn hash_memoization_basic_blocks() {
Operation::Drop,
Operation::Drop,
Operation::Drop,
]);
];
let basic_block = MastNode::new_basic_block(ops, None).unwrap();

hash_memoization_basic_blocks_check(basic_block);
}
Expand Down
47 changes: 24 additions & 23 deletions processor/src/decoder/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type DecoderTrace = [Vec<Felt>; DECODER_TRACE_WIDTH];
#[test]
fn basic_block_one_group() {
let ops = vec![Operation::Pad, Operation::Add, Operation::Mul];
let basic_block = BasicBlockNode::new(ops.clone());
let basic_block = BasicBlockNode::new(ops.clone(), None).unwrap();
let program = {
let mut mast_forest = MastForest::new();

Expand Down Expand Up @@ -93,7 +93,7 @@ fn basic_block_one_group() {
fn basic_block_small() {
let iv = [ONE, TWO];
let ops = vec![Operation::Push(iv[0]), Operation::Push(iv[1]), Operation::Add];
let basic_block = BasicBlockNode::new(ops.clone());
let basic_block = BasicBlockNode::new(ops.clone(), None).unwrap();
let program = {
let mut mast_forest = MastForest::new();

Expand Down Expand Up @@ -156,7 +156,7 @@ fn basic_block() {
Operation::Add,
Operation::Inv,
];
let basic_block = BasicBlockNode::new(ops.clone());
let basic_block = BasicBlockNode::new(ops.clone(), None).unwrap();
let program = {
let mut mast_forest = MastForest::new();

Expand Down Expand Up @@ -248,7 +248,7 @@ fn span_block_with_respan() {
Operation::Add,
Operation::Push(iv[8]),
];
let basic_block = BasicBlockNode::new(ops.clone());
let basic_block = BasicBlockNode::new(ops.clone(), None).unwrap();
let program = {
let mut mast_forest = MastForest::new();

Expand Down Expand Up @@ -321,8 +321,8 @@ fn span_block_with_respan() {

#[test]
fn join_node() {
let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul]);
let basic_block2 = MastNode::new_basic_block(vec![Operation::Add]);
let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul], None).unwrap();
let basic_block2 = MastNode::new_basic_block(vec![Operation::Add], None).unwrap();
let program = {
let mut mast_forest = MastForest::new();

Expand Down Expand Up @@ -386,8 +386,8 @@ fn join_node() {

#[test]
fn split_node_true() {
let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul]);
let basic_block2 = MastNode::new_basic_block(vec![Operation::Add]);
let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul], None).unwrap();
let basic_block2 = MastNode::new_basic_block(vec![Operation::Add], None).unwrap();
let program = {
let mut mast_forest = MastForest::new();

Expand Down Expand Up @@ -438,8 +438,8 @@ fn split_node_true() {

#[test]
fn split_node_false() {
let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul]);
let basic_block2 = MastNode::new_basic_block(vec![Operation::Add]);
let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul], None).unwrap();
let basic_block2 = MastNode::new_basic_block(vec![Operation::Add], None).unwrap();
let program = {
let mut mast_forest = MastForest::new();

Expand Down Expand Up @@ -493,7 +493,7 @@ fn split_node_false() {

#[test]
fn loop_node() {
let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop]);
let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop], None).unwrap();
let program = {
let mut mast_forest = MastForest::new();

Expand Down Expand Up @@ -545,7 +545,7 @@ fn loop_node() {

#[test]
fn loop_node_skip() {
let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop]);
let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop], None).unwrap();
let program = {
let mut mast_forest = MastForest::new();

Expand Down Expand Up @@ -587,7 +587,7 @@ fn loop_node_skip() {

#[test]
fn loop_node_repeat() {
let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop]);
let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop], None).unwrap();
let program = {
let mut mast_forest = MastForest::new();

Expand Down Expand Up @@ -678,15 +678,15 @@ fn call_block() {
Operation::Push(TWO),
Operation::FmpUpdate,
Operation::Pad,
]);
], None).unwrap();
let first_basic_block_id = mast_forest.add_node(first_basic_block.clone()).unwrap();

let foo_root_node = MastNode::new_basic_block(vec![
Operation::Push(ONE), Operation::FmpUpdate
]);
], None).unwrap();
let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()).unwrap();

let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd]);
let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd], None).unwrap();
let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()).unwrap();

let foo_call_node = MastNode::new_call(foo_root_node_id, &mast_forest).unwrap();
Expand Down Expand Up @@ -887,13 +887,13 @@ fn syscall_block() {
let mut mast_forest = MastForest::new();

// build foo procedure body
let foo_root = MastNode::new_basic_block(vec![Operation::Push(THREE), Operation::FmpUpdate]);
let foo_root = MastNode::new_basic_block(vec![Operation::Push(THREE), Operation::FmpUpdate], None).unwrap();
let foo_root_id = mast_forest.add_node(foo_root.clone()).unwrap();
mast_forest.make_root(foo_root_id);
let kernel = Kernel::new(&[foo_root.digest()]).unwrap();

// build bar procedure body
let bar_basic_block = MastNode::new_basic_block(vec![Operation::Push(TWO), Operation::FmpUpdate]);
let bar_basic_block = MastNode::new_basic_block(vec![Operation::Push(TWO), Operation::FmpUpdate], None).unwrap();
let bar_basic_block_id = mast_forest.add_node(bar_basic_block.clone()).unwrap();

let foo_call_node = MastNode::new_syscall(foo_root_id, &mast_forest).unwrap();
Expand All @@ -908,10 +908,10 @@ fn syscall_block() {
Operation::Push(ONE),
Operation::FmpUpdate,
Operation::Pad,
]);
], None).unwrap();
let first_basic_block_id = mast_forest.add_node(first_basic_block.clone()).unwrap();

let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd]);
let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd], None).unwrap();
let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()).unwrap();

let bar_call_node = MastNode::new_call(bar_root_node_id, &mast_forest).unwrap();
Expand Down Expand Up @@ -1175,14 +1175,15 @@ fn dyn_block() {

let mut mast_forest = MastForest::new();

let foo_root_node = MastNode::new_basic_block(vec![Operation::Push(ONE), Operation::Add]);
let foo_root_node =
MastNode::new_basic_block(vec![Operation::Push(ONE), Operation::Add], None).unwrap();
let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()).unwrap();
mast_forest.make_root(foo_root_node_id);

let mul_bb_node = MastNode::new_basic_block(vec![Operation::Mul]);
let mul_bb_node = MastNode::new_basic_block(vec![Operation::Mul], None).unwrap();
let mul_bb_node_id = mast_forest.add_node(mul_bb_node.clone()).unwrap();

let save_bb_node = MastNode::new_basic_block(vec![Operation::MovDn4]);
let save_bb_node = MastNode::new_basic_block(vec![Operation::MovDn4], None).unwrap();
let save_bb_node_id = mast_forest.add_node(save_bb_node.clone()).unwrap();

let join_node = MastNode::new_join(mul_bb_node_id, save_bb_node_id, &mast_forest).unwrap();
Expand Down
Loading

0 comments on commit fad447f

Please sign in to comment.