From fad447f56ff68cbcc5603849071ed99e85e01ad0 Mon Sep 17 00:00:00 2001 From: sergerad Date: Sat, 17 Aug 2024 20:58:35 +1200 Subject: [PATCH] replace basic block constructor and add MastNodeError --- assembly/src/assembler/mast_forest_builder.rs | 8 +--- assembly/src/errors.rs | 4 +- core/src/mast/mod.rs | 20 +++++--- core/src/mast/node/basic_block_node/mod.rs | 45 +++++++++--------- core/src/mast/node/basic_block_node/tests.rs | 2 +- core/src/mast/node/mod.rs | 15 +++--- processor/src/chiplets/hasher/tests.rs | 12 +++-- processor/src/decoder/tests.rs | 47 ++++++++++--------- processor/src/trace/tests/decoder.rs | 14 +++--- 9 files changed, 85 insertions(+), 82 deletions(-) diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs index 317420684..bca1934ce 100644 --- a/assembly/src/assembler/mast_forest_builder.rs +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -170,12 +170,8 @@ impl MastForestBuilder { operations: Vec, decorators: Option, ) -> Result { - match decorators { - Some(decorators) => { - self.ensure_node(MastNode::new_basic_block_with_decorators(operations, decorators)) - }, - None => self.ensure_node(MastNode::new_basic_block(operations)), - } + let block = MastNode::new_basic_block(operations, decorators)?; + self.ensure_node(block) } /// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it. diff --git a/assembly/src/errors.rs b/assembly/src/errors.rs index 8c4a3d22c..cc79a2d4c 100644 --- a/assembly/src/errors.rs +++ b/assembly/src/errors.rs @@ -1,6 +1,6 @@ use alloc::{string::String, sync::Arc, vec::Vec}; -use vm_core::mast::MastForestError; +use vm_core::mast::{MastForestError, MastNodeError}; use crate::{ ast::QualifiedProcedureName, @@ -76,6 +76,8 @@ pub enum AssemblyError { Other(#[from] RelatedError), #[error(transparent)] Forest(#[from] MastForestError), + #[error(transparent)] + Node(#[from] MastNodeError), } impl From for AssemblyError { diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index b594c0875..f59d90ad7 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -68,12 +68,8 @@ impl MastForest { operations: Vec, decorators: Option, ) -> Result { - match decorators { - Some(decorators) => { - self.add_node(MastNode::new_basic_block_with_decorators(operations, decorators)) - }, - None => self.add_node(MastNode::new_basic_block(operations)), - } + let block = MastNode::new_basic_block(operations, decorators)?; + self.add_node(block) } /// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it. @@ -271,4 +267,16 @@ pub enum MastForestError { TooManyNodes, #[error("node id: {0} is greater than or equal to forest length: {1}")] NodeIdOverflow(MastNodeId, usize), + #[error(transparent)] + Node(#[from] MastNodeError), +} + +// MAST NODE ERROR +// ================================================================================================ + +/// Represents the types of errors that can occur when dealing with MAST nodes directly. +#[derive(Debug, thiserror::Error, PartialEq)] +pub enum MastNodeError { + #[error("basic block cannot be created from an empty list of operations")] + EmptyBasicBlock, } diff --git a/core/src/mast/node/basic_block_node/mod.rs b/core/src/mast/node/basic_block_node/mod.rs index 7e487458d..35bf7cd44 100644 --- a/core/src/mast/node/basic_block_node/mod.rs +++ b/core/src/mast/node/basic_block_node/mod.rs @@ -5,7 +5,9 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt, ZERO}; use miden_formatting::prettier::PrettyPrint; use winter_utils::flatten_slice_elements; -use crate::{chiplets::hasher, Decorator, DecoratorIterator, DecoratorList, Operation}; +use crate::{ + chiplets::hasher, mast::MastNodeError, Decorator, DecoratorIterator, DecoratorList, Operation, +}; mod op_batch; pub use op_batch::OpBatch; @@ -77,15 +79,27 @@ impl BasicBlockNode { // ------------------------------------------------------------------------------------------------ /// Constructors impl BasicBlockNode { - /// Returns a new [`BasicBlockNode`] instantiated with the specified operations. + /// Returns a new [`BasicBlockNode`] instantiated with the specified operations and decorators. /// - /// # Errors (TODO) /// Returns an error if: /// - `operations` vector is empty. - /// - `operations` vector contains any number of system operations. - pub fn new(operations: Vec) -> Self { - assert!(!operations.is_empty()); // TODO: return error - Self::with_decorators(operations, DecoratorList::new()) + pub fn new( + operations: Vec, + decorators: Option, + ) -> Result { + if operations.is_empty() { + return Err(MastNodeError::EmptyBasicBlock); + } + + // None is equivalent to an empty list of decorators moving forward. + let decorators = decorators.unwrap_or_default(); + + // Validate decorators list (only in debug mode). + #[cfg(debug_assertions)] + validate_decorators(&operations, &decorators); + + let (op_batches, digest) = batch_and_hash_ops(operations); + Ok(Self { op_batches, digest, decorators }) } /// Returns a new [`BasicBlockNode`] from values that are assumed to be correct. @@ -98,23 +112,6 @@ impl BasicBlockNode { let (op_batches, _) = batch_ops(operations); Self { op_batches, digest, decorators } } - - /// Returns a new [`BasicBlockNode`] instantiated with the specified operations and decorators. - /// - /// # Errors (TODO) - /// Returns an error if: - /// - `operations` vector is empty. - /// - `operations` vector contains any number of system operations. - pub fn with_decorators(operations: Vec, decorators: DecoratorList) -> Self { - assert!(!operations.is_empty()); // TODO: return error - - // validate decorators list (only in debug mode) - #[cfg(debug_assertions)] - validate_decorators(&operations, &decorators); - - let (op_batches, digest) = batch_and_hash_ops(operations); - Self { op_batches, digest, decorators } - } } // ------------------------------------------------------------------------------------------------ diff --git a/core/src/mast/node/basic_block_node/tests.rs b/core/src/mast/node/basic_block_node/tests.rs index 1826be788..b44b3f1fd 100644 --- a/core/src/mast/node/basic_block_node/tests.rs +++ b/core/src/mast/node/basic_block_node/tests.rs @@ -306,7 +306,7 @@ fn operation_or_decorator_iterator() { (4, Decorator::Event(4)), ]; - let node = BasicBlockNode::with_decorators(operations, decorators); + let node = BasicBlockNode::new(operations, Some(decorators)).unwrap(); let mut iterator = node.iter(); diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index 57f4327e1..b7c80da09 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -27,7 +27,7 @@ pub use split_node::SplitNode; mod loop_node; pub use loop_node::LoopNode; -use super::MastForestError; +use super::{MastForestError, MastNodeError}; use crate::{ mast::{MastForest, MastNodeId}, DecoratorList, Operation, @@ -50,15 +50,12 @@ pub enum MastNode { // ------------------------------------------------------------------------------------------------ /// Constructors impl MastNode { - pub fn new_basic_block(operations: Vec) -> Self { - Self::Block(BasicBlockNode::new(operations)) - } - - pub fn new_basic_block_with_decorators( + pub fn new_basic_block( operations: Vec, - decorators: DecoratorList, - ) -> Self { - Self::Block(BasicBlockNode::with_decorators(operations, decorators)) + decorators: Option, + ) -> Result { + let block = BasicBlockNode::new(operations, decorators)?; + Ok(Self::Block(block)) } pub fn new_join( diff --git a/processor/src/chiplets/hasher/tests.rs b/processor/src/chiplets/hasher/tests.rs index 190d5ca85..1d90d3951 100644 --- a/processor/src/chiplets/hasher/tests.rs +++ b/processor/src/chiplets/hasher/tests.rs @@ -249,10 +249,10 @@ fn hash_memoization_control_blocks() { let mut mast_forest = MastForest::new(); - let t_branch = MastNode::new_basic_block(vec![Operation::Push(ZERO)]); + let t_branch = MastNode::new_basic_block(vec![Operation::Push(ZERO)], None).unwrap(); let t_branch_id = mast_forest.add_node(t_branch.clone()).unwrap(); - let f_branch = MastNode::new_basic_block(vec![Operation::Push(ONE)]); + let f_branch = MastNode::new_basic_block(vec![Operation::Push(ONE)], None).unwrap(); let f_branch_id = mast_forest.add_node(f_branch.clone()).unwrap(); let split1 = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest).unwrap(); @@ -350,12 +350,13 @@ fn hash_memoization_control_blocks() { fn hash_memoization_basic_blocks() { // --- basic block with 1 batch ---------------------------------------------------------------- let basic_block = - MastNode::new_basic_block(vec![Operation::Push(Felt::new(10)), Operation::Drop]); + MastNode::new_basic_block(vec![Operation::Push(Felt::new(10)), Operation::Drop], None) + .unwrap(); hash_memoization_basic_blocks_check(basic_block); // --- basic block with multiple batches ------------------------------------------------------- - let basic_block = MastNode::new_basic_block(vec![ + let ops = vec![ Operation::Push(ONE), Operation::Push(Felt::new(2)), Operation::Push(Felt::new(3)), @@ -392,7 +393,8 @@ fn hash_memoization_basic_blocks() { Operation::Drop, Operation::Drop, Operation::Drop, - ]); + ]; + let basic_block = MastNode::new_basic_block(ops, None).unwrap(); hash_memoization_basic_blocks_check(basic_block); } diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 6e7d74228..f17f94d1b 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -47,7 +47,7 @@ type DecoderTrace = [Vec; DECODER_TRACE_WIDTH]; #[test] fn basic_block_one_group() { let ops = vec![Operation::Pad, Operation::Add, Operation::Mul]; - let basic_block = BasicBlockNode::new(ops.clone()); + let basic_block = BasicBlockNode::new(ops.clone(), None).unwrap(); let program = { let mut mast_forest = MastForest::new(); @@ -93,7 +93,7 @@ fn basic_block_one_group() { fn basic_block_small() { let iv = [ONE, TWO]; let ops = vec![Operation::Push(iv[0]), Operation::Push(iv[1]), Operation::Add]; - let basic_block = BasicBlockNode::new(ops.clone()); + let basic_block = BasicBlockNode::new(ops.clone(), None).unwrap(); let program = { let mut mast_forest = MastForest::new(); @@ -156,7 +156,7 @@ fn basic_block() { Operation::Add, Operation::Inv, ]; - let basic_block = BasicBlockNode::new(ops.clone()); + let basic_block = BasicBlockNode::new(ops.clone(), None).unwrap(); let program = { let mut mast_forest = MastForest::new(); @@ -248,7 +248,7 @@ fn span_block_with_respan() { Operation::Add, Operation::Push(iv[8]), ]; - let basic_block = BasicBlockNode::new(ops.clone()); + let basic_block = BasicBlockNode::new(ops.clone(), None).unwrap(); let program = { let mut mast_forest = MastForest::new(); @@ -321,8 +321,8 @@ fn span_block_with_respan() { #[test] fn join_node() { - let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block2 = MastNode::new_basic_block(vec![Operation::Add]); + let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul], None).unwrap(); + let basic_block2 = MastNode::new_basic_block(vec![Operation::Add], None).unwrap(); let program = { let mut mast_forest = MastForest::new(); @@ -386,8 +386,8 @@ fn join_node() { #[test] fn split_node_true() { - let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block2 = MastNode::new_basic_block(vec![Operation::Add]); + let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul], None).unwrap(); + let basic_block2 = MastNode::new_basic_block(vec![Operation::Add], None).unwrap(); let program = { let mut mast_forest = MastForest::new(); @@ -438,8 +438,8 @@ fn split_node_true() { #[test] fn split_node_false() { - let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block2 = MastNode::new_basic_block(vec![Operation::Add]); + let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul], None).unwrap(); + let basic_block2 = MastNode::new_basic_block(vec![Operation::Add], None).unwrap(); let program = { let mut mast_forest = MastForest::new(); @@ -493,7 +493,7 @@ fn split_node_false() { #[test] fn loop_node() { - let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop]); + let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop], None).unwrap(); let program = { let mut mast_forest = MastForest::new(); @@ -545,7 +545,7 @@ fn loop_node() { #[test] fn loop_node_skip() { - let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop]); + let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop], None).unwrap(); let program = { let mut mast_forest = MastForest::new(); @@ -587,7 +587,7 @@ fn loop_node_skip() { #[test] fn loop_node_repeat() { - let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop]); + let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop], None).unwrap(); let program = { let mut mast_forest = MastForest::new(); @@ -678,15 +678,15 @@ fn call_block() { Operation::Push(TWO), Operation::FmpUpdate, Operation::Pad, - ]); + ], None).unwrap(); let first_basic_block_id = mast_forest.add_node(first_basic_block.clone()).unwrap(); let foo_root_node = MastNode::new_basic_block(vec![ Operation::Push(ONE), Operation::FmpUpdate - ]); + ], None).unwrap(); let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()).unwrap(); - let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd]); + let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd], None).unwrap(); let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()).unwrap(); let foo_call_node = MastNode::new_call(foo_root_node_id, &mast_forest).unwrap(); @@ -887,13 +887,13 @@ fn syscall_block() { let mut mast_forest = MastForest::new(); // build foo procedure body - let foo_root = MastNode::new_basic_block(vec![Operation::Push(THREE), Operation::FmpUpdate]); + let foo_root = MastNode::new_basic_block(vec![Operation::Push(THREE), Operation::FmpUpdate], None).unwrap(); let foo_root_id = mast_forest.add_node(foo_root.clone()).unwrap(); mast_forest.make_root(foo_root_id); let kernel = Kernel::new(&[foo_root.digest()]).unwrap(); // build bar procedure body - let bar_basic_block = MastNode::new_basic_block(vec![Operation::Push(TWO), Operation::FmpUpdate]); + let bar_basic_block = MastNode::new_basic_block(vec![Operation::Push(TWO), Operation::FmpUpdate], None).unwrap(); let bar_basic_block_id = mast_forest.add_node(bar_basic_block.clone()).unwrap(); let foo_call_node = MastNode::new_syscall(foo_root_id, &mast_forest).unwrap(); @@ -908,10 +908,10 @@ fn syscall_block() { Operation::Push(ONE), Operation::FmpUpdate, Operation::Pad, - ]); + ], None).unwrap(); let first_basic_block_id = mast_forest.add_node(first_basic_block.clone()).unwrap(); - let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd]); + let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd], None).unwrap(); let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()).unwrap(); let bar_call_node = MastNode::new_call(bar_root_node_id, &mast_forest).unwrap(); @@ -1175,14 +1175,15 @@ fn dyn_block() { let mut mast_forest = MastForest::new(); - let foo_root_node = MastNode::new_basic_block(vec![Operation::Push(ONE), Operation::Add]); + let foo_root_node = + MastNode::new_basic_block(vec![Operation::Push(ONE), Operation::Add], None).unwrap(); let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()).unwrap(); mast_forest.make_root(foo_root_node_id); - let mul_bb_node = MastNode::new_basic_block(vec![Operation::Mul]); + let mul_bb_node = MastNode::new_basic_block(vec![Operation::Mul], None).unwrap(); let mul_bb_node_id = mast_forest.add_node(mul_bb_node.clone()).unwrap(); - let save_bb_node = MastNode::new_basic_block(vec![Operation::MovDn4]); + let save_bb_node = MastNode::new_basic_block(vec![Operation::MovDn4], None).unwrap(); let save_bb_node_id = mast_forest.add_node(save_bb_node.clone()).unwrap(); let join_node = MastNode::new_join(mul_bb_node_id, save_bb_node_id, &mast_forest).unwrap(); diff --git a/processor/src/trace/tests/decoder.rs b/processor/src/trace/tests/decoder.rs index 56ff5cb96..1213b05dd 100644 --- a/processor/src/trace/tests/decoder.rs +++ b/processor/src/trace/tests/decoder.rs @@ -358,10 +358,10 @@ fn decoder_p2_span_with_respan() { fn decoder_p2_join() { let mut mast_forest = MastForest::new(); - let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); + let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul], None).unwrap(); let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()).unwrap(); - let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); + let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add], None).unwrap(); let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()).unwrap(); let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest).unwrap(); @@ -423,7 +423,7 @@ fn decoder_p2_split_true() { // build program let mut mast_forest = MastForest::new(); - let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); + let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul], None).unwrap(); let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()).unwrap(); let basic_block_2_id = mast_forest.add_block(vec![Operation::Add], None).unwrap(); @@ -477,10 +477,10 @@ fn decoder_p2_split_false() { // build program let mut mast_forest = MastForest::new(); - let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); + let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul], None).unwrap(); let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()).unwrap(); - let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); + let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add], None).unwrap(); let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()).unwrap(); let split_id = mast_forest.add_split(basic_block_1_id, basic_block_2_id).unwrap(); @@ -532,10 +532,10 @@ fn decoder_p2_loop_with_repeat() { // build program let mut mast_forest = MastForest::new(); - let basic_block_1 = MastNode::new_basic_block(vec![Operation::Pad]); + let basic_block_1 = MastNode::new_basic_block(vec![Operation::Pad], None).unwrap(); let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()).unwrap(); - let basic_block_2 = MastNode::new_basic_block(vec![Operation::Drop]); + let basic_block_2 = MastNode::new_basic_block(vec![Operation::Drop], None).unwrap(); let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()).unwrap(); let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest).unwrap();