From 64c74014d00580baeeca2837bd2e8ecf06bc4a2e Mon Sep 17 00:00:00 2001 From: Serge Radinovich <47865535+sergerad@users.noreply.github.com> Date: Sun, 21 Jul 2024 06:41:50 +1200 Subject: [PATCH] feat: add more functions for ensuring nodes via MastForestBuilder (#1404) --- CHANGELOG.md | 9 +- assembly/src/assembler/basic_block_builder.rs | 5 +- .../src/assembler/instruction/procedures.rs | 26 ++-- assembly/src/assembler/mast_forest_builder.rs | 62 ++++++++- assembly/src/assembler/mod.rs | 27 ++-- assembly/src/assembler/tests.rs | 128 +++++++----------- core/src/mast/node/mod.rs | 2 +- core/src/mast/serialization/info.rs | 2 +- core/src/mast/serialization/tests.rs | 2 +- processor/src/decoder/tests.rs | 2 +- 10 files changed, 139 insertions(+), 126 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 72392ea618..b8751ad390 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,8 @@ - Added error codes support for the `mtree_verify` instruction (#1328). - Added support for immediate values for `lt`, `lte`, `gt`, `gte` comparison instructions (#1346). -- Change MAST to a table-based representation (#1349) -- Introduce `MastForestStore` (#1359) +- Changed MAST to a table-based representation (#1349) +- Introduced `MastForestStore` (#1359) - Adjusted prover's metal acceleration code to work with 0.9 versions of the crates (#1357) - Added support for immediate values for `u32lt`, `u32lte`, `u32gt`, `u32gte`, `u32min` and `u32max` comparison instructions (#1358). - Added support for the `nop` instruction, which corresponds to the VM opcode of the same name, and has the same semantics. This is implemented for use by compilers primarily. @@ -16,9 +16,10 @@ - Added support for immediate values for `u32and`, `u32or`, `u32xor` and `u32not` bitwise instructions (#1362). - Optimized `std::sys::truncate_stuck` procedure (#1384). - Updated CI and Makefile to standardise it accross Miden repositories (#1342). -- Add serialization/deserialization for `MastForest` (#1370) +- Added serialization/deserialization for `MastForest` (#1370) - Updated CI to support `CHANGELOG.md` modification checking and `no changelog` label (#1406) -- Introduce `MastForestError` to enforce `MastForest` node count invariant (#1394) +- Introduced `MastForestError` to enforce `MastForest` node count invariant (#1394) +- Added functions to `MastForestBuilder` to allow ensuring of nodes with fewer LOC (#1404) #### Changed diff --git a/assembly/src/assembler/basic_block_builder.rs b/assembly/src/assembler/basic_block_builder.rs index 2545d4eea8..5263faeff2 100644 --- a/assembly/src/assembler/basic_block_builder.rs +++ b/assembly/src/assembler/basic_block_builder.rs @@ -4,7 +4,7 @@ use super::{ }; use alloc::{borrow::Borrow, string::ToString, vec::Vec}; use vm_core::{ - mast::{MastForestError, MastNode, MastNodeId}, + mast::{MastForestError, MastNodeId}, AdviceInjector, AssemblyOp, Operation, }; @@ -134,8 +134,7 @@ impl BasicBlockBuilder { let ops = self.ops.drain(..).collect(); let decorators = self.decorators.drain(..).collect(); - let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators); - let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node)?; + let basic_block_node_id = mast_forest_builder.ensure_block(ops, Some(decorators))?; Ok(Some(basic_block_node_id)) } else if !self.decorators.is_empty() { diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index d6a96f0f1b..a86474fce7 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -6,7 +6,7 @@ use crate::{ }; use smallvec::SmallVec; -use vm_core::mast::{MastForest, MastNode, MastNodeId}; +use vm_core::mast::{MastForest, MastNodeId}; /// Procedure Invocation impl Assembler { @@ -96,8 +96,7 @@ impl Assembler { None => { // If the MAST root called isn't known to us, make it an external // reference. - let external_node = MastNode::new_external(mast_root); - mast_forest_builder.ensure_node(external_node)? + mast_forest_builder.ensure_external(mast_root)? } } } @@ -107,13 +106,11 @@ impl Assembler { None => { // If the MAST root called isn't known to us, make it an external // reference. - let external_node = MastNode::new_external(mast_root); - mast_forest_builder.ensure_node(external_node)? + mast_forest_builder.ensure_external(mast_root)? } }; - let call_node = MastNode::new_call(callee_id, mast_forest_builder.forest()); - mast_forest_builder.ensure_node(call_node)? + mast_forest_builder.ensure_call(callee_id)? } InvokeKind::SysCall => { let callee_id = match mast_forest_builder.find_procedure_root(mast_root) { @@ -121,14 +118,11 @@ impl Assembler { None => { // If the MAST root called isn't known to us, make it an external // reference. - let external_node = MastNode::new_external(mast_root); - mast_forest_builder.ensure_node(external_node)? + mast_forest_builder.ensure_external(mast_root)? } }; - let syscall_node = - MastNode::new_syscall(callee_id, mast_forest_builder.forest()); - mast_forest_builder.ensure_node(syscall_node)? + mast_forest_builder.ensure_syscall(callee_id)? } } }; @@ -141,7 +135,7 @@ impl Assembler { &self, mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { - let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?; + let dyn_node_id = mast_forest_builder.ensure_dyn()?; Ok(Some(dyn_node_id)) } @@ -152,10 +146,8 @@ impl Assembler { mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { let dyn_call_node_id = { - let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?; - let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest_builder.forest()); - - mast_forest_builder.ensure_node(dyn_call_node)? + let dyn_node_id = mast_forest_builder.ensure_dyn()?; + mast_forest_builder.ensure_call(dyn_node_id)? }; Ok(Some(dyn_call_node_id)) diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs index a5854c8b01..0a1a44d037 100644 --- a/assembly/src/assembler/mast_forest_builder.rs +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -1,9 +1,10 @@ use core::ops::Index; -use alloc::collections::BTreeMap; +use alloc::{collections::BTreeMap, vec::Vec}; use vm_core::{ crypto::hash::RpoDigest, mast::{MastForest, MastForestError, MastNode, MastNodeId, MerkleTreeNode}, + DecoratorList, Operation, }; /// Builder for a [`MastForest`]. @@ -44,7 +45,7 @@ impl MastForestBuilder { /// If a [`MastNode`] which is equal to the current node was previously added, the previously /// returned [`MastNodeId`] will be returned. This enforces this invariant that equal /// [`MastNode`]s have equal [`MastNodeId`]s. - pub fn ensure_node(&mut self, node: MastNode) -> Result { + fn ensure_node(&mut self, node: MastNode) -> Result { let node_digest = node.digest(); if let Some(node_id) = self.node_id_by_hash.get(&node_digest) { @@ -58,6 +59,63 @@ impl MastForestBuilder { } } + /// Adds a basic block node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_block( + &mut self, + operations: Vec, + decorators: Option, + ) -> Result { + match decorators { + Some(decorators) => { + self.ensure_node(MastNode::new_basic_block_with_decorators(operations, decorators)) + } + None => self.ensure_node(MastNode::new_basic_block(operations)), + } + } + + /// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_join( + &mut self, + left_child: MastNodeId, + right_child: MastNodeId, + ) -> Result { + self.ensure_node(MastNode::new_join(left_child, right_child, &self.mast_forest)) + } + + /// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_split( + &mut self, + if_branch: MastNodeId, + else_branch: MastNodeId, + ) -> Result { + self.ensure_node(MastNode::new_split(if_branch, else_branch, &self.mast_forest)) + } + + /// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_loop(&mut self, body: MastNodeId) -> Result { + self.ensure_node(MastNode::new_loop(body, &self.mast_forest)) + } + + /// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_call(&mut self, callee: MastNodeId) -> Result { + self.ensure_node(MastNode::new_call(callee, &self.mast_forest)) + } + + /// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_syscall(&mut self, callee: MastNodeId) -> Result { + self.ensure_node(MastNode::new_syscall(callee, &self.mast_forest)) + } + + /// Adds a dynexec node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_dyn(&mut self) -> Result { + self.ensure_node(MastNode::new_dyn()) + } + + /// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_external(&mut self, mast_root: RpoDigest) -> Result { + self.ensure_node(MastNode::new_external(mast_root)) + } + /// Marks the given [`MastNodeId`] as being the root of a procedure. pub fn make_root(&mut self, new_root_id: MastNodeId) { self.mast_forest.make_root(new_root_id) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index b6bb051d30..cfdfbfe592 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -11,7 +11,7 @@ use crate::{ use alloc::{boxed::Box, sync::Arc, vec::Vec}; use mast_forest_builder::MastForestBuilder; use vm_core::{ - mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId, MerkleTreeNode}, Decorator, DecoratorList, Kernel, Operation, Program, }; @@ -792,12 +792,9 @@ impl Assembler { let else_blk = self.compile_body(else_blk.iter(), context, None, mast_forest_builder)?; - let split_node_id = { - let split_node = - MastNode::new_split(then_blk, else_blk, mast_forest_builder.forest()); - - mast_forest_builder.ensure_node(split_node).map_err(AssemblyError::from)? - }; + let split_node_id = mast_forest_builder + .ensure_split(then_blk, else_blk) + .map_err(AssemblyError::from)?; mast_node_ids.push(split_node_id); } @@ -828,11 +825,9 @@ impl Assembler { let loop_body_node_id = self.compile_body(body.iter(), context, None, mast_forest_builder)?; - let loop_node_id = { - let loop_node = - MastNode::new_loop(loop_body_node_id, mast_forest_builder.forest()); - mast_forest_builder.ensure_node(loop_node).map_err(AssemblyError::from)? - }; + let loop_node_id = mast_forest_builder + .ensure_loop(loop_body_node_id) + .map_err(AssemblyError::from)?; mast_node_ids.push(loop_node_id); } } @@ -846,8 +841,9 @@ impl Assembler { } Ok(if mast_node_ids.is_empty() { - let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]); - mast_forest_builder.ensure_node(basic_block_node).map_err(AssemblyError::from)? + mast_forest_builder + .ensure_block(vec![Operation::Noop], None) + .map_err(AssemblyError::from)? } else { combine_mast_node_ids(mast_node_ids, mast_forest_builder)? }) @@ -907,8 +903,7 @@ fn combine_mast_node_ids( while let (Some(left), Some(right)) = (source_mast_node_iter.next(), source_mast_node_iter.next()) { - let join_mast_node = MastNode::new_join(left, right, mast_forest_builder.forest()); - let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node)?; + let join_mast_node_id = mast_forest_builder.ensure_join(left, right)?; mast_node_ids.push(join_mast_node_id); } diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index e627d985ab..c557276b2c 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -81,12 +81,10 @@ fn nested_blocks() { // contains the MAST nodes for the kernel after a call to // `Assembler::with_kernel_from_module()`. let syscall_foo_node_id = { - let kernel_foo_node = MastNode::new_basic_block(vec![Operation::Add]); - let kernel_foo_node_id = expected_mast_forest_builder.ensure_node(kernel_foo_node).unwrap(); + let kernel_foo_node_id = + expected_mast_forest_builder.ensure_block(vec![Operation::Add], None).unwrap(); - let syscall_node = - MastNode::new_syscall(kernel_foo_node_id, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(syscall_node).unwrap() + expected_mast_forest_builder.ensure_syscall(kernel_foo_node_id).unwrap() }; let program = r#" @@ -129,95 +127,65 @@ fn nested_blocks() { let exec_bar_node_id = { // bar procedure - let basic_block_1 = MastNode::new_basic_block(vec![Operation::Push(17_u32.into())]); - let basic_block_1_id = expected_mast_forest_builder.ensure_node(basic_block_1).unwrap(); + let basic_block_1_id = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(17_u32.into())], None) + .unwrap(); // Basic block representing the `foo` procedure - let basic_block_2 = MastNode::new_basic_block(vec![Operation::Push(19_u32.into())]); - let basic_block_2_id = expected_mast_forest_builder.ensure_node(basic_block_2).unwrap(); - - let join_node = MastNode::new_join( - basic_block_1_id, - basic_block_2_id, - expected_mast_forest_builder.forest(), - ); - expected_mast_forest_builder.ensure_node(join_node).unwrap() - }; + let basic_block_2_id = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(19_u32.into())], None) + .unwrap(); - let exec_foo_bar_baz_node_id = { - // basic block representing foo::bar.baz procedure - let basic_block = MastNode::new_basic_block(vec![Operation::Push(29_u32.into())]); - expected_mast_forest_builder.ensure_node(basic_block).unwrap() + expected_mast_forest_builder + .ensure_join(basic_block_1_id, basic_block_2_id) + .unwrap() }; - let before = { - let before_node = MastNode::new_basic_block(vec![Operation::Push(2u32.into())]); - expected_mast_forest_builder.ensure_node(before_node).unwrap() - }; + // basic block representing foo::bar.baz procedure + let exec_foo_bar_baz_node_id = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(29_u32.into())], None) + .unwrap(); - let r#true1 = { - let r#true_node = MastNode::new_basic_block(vec![Operation::Push(3u32.into())]); - expected_mast_forest_builder.ensure_node(r#true_node).unwrap() - }; - let r#false1 = { - let r#false_node = MastNode::new_basic_block(vec![Operation::Push(5u32.into())]); - expected_mast_forest_builder.ensure_node(r#false_node).unwrap() - }; - let r#if1 = { - let r#if_node = - MastNode::new_split(r#true1, r#false1, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(r#if_node).unwrap() - }; + let before = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(2u32.into())], None) + .unwrap(); - let r#true3 = { - let r#true_node = MastNode::new_basic_block(vec![Operation::Push(7u32.into())]); - expected_mast_forest_builder.ensure_node(r#true_node).unwrap() - }; - let r#false3 = { - let r#false_node = MastNode::new_basic_block(vec![Operation::Push(11u32.into())]); - expected_mast_forest_builder.ensure_node(r#false_node).unwrap() - }; - let r#true2 = { - let r#if_node = - MastNode::new_split(r#true3, r#false3, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(r#if_node).unwrap() - }; + let r#true1 = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(3u32.into())], None) + .unwrap(); + let r#false1 = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(5u32.into())], None) + .unwrap(); + let r#if1 = expected_mast_forest_builder.ensure_split(r#true1, r#false1).unwrap(); + + let r#true3 = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(7u32.into())], None) + .unwrap(); + let r#false3 = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(11u32.into())], None) + .unwrap(); + let r#true2 = expected_mast_forest_builder.ensure_split(r#true3, r#false3).unwrap(); let r#while = { let push_basic_block_id = { - let push_basic_block = MastNode::new_basic_block(vec![Operation::Push(23u32.into())]); - expected_mast_forest_builder.ensure_node(push_basic_block).unwrap() - }; - let body_node_id = { - let body_node = MastNode::new_join( - exec_bar_node_id, - push_basic_block_id, - expected_mast_forest_builder.forest(), - ); - - expected_mast_forest_builder.ensure_node(body_node).unwrap() + expected_mast_forest_builder + .ensure_block(vec![Operation::Push(23u32.into())], None) + .unwrap() }; + let body_node_id = expected_mast_forest_builder + .ensure_join(exec_bar_node_id, push_basic_block_id) + .unwrap(); - let loop_node = MastNode::new_loop(body_node_id, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(loop_node).unwrap() - }; - let push_13_basic_block_id = { - let node = MastNode::new_basic_block(vec![Operation::Push(13u32.into())]); - expected_mast_forest_builder.ensure_node(node).unwrap() + expected_mast_forest_builder.ensure_loop(body_node_id).unwrap() }; + let push_13_basic_block_id = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(13u32.into())], None) + .unwrap(); - let r#false2 = { - let node = MastNode::new_join( - push_13_basic_block_id, - r#while, - expected_mast_forest_builder.forest(), - ); - expected_mast_forest_builder.ensure_node(node).unwrap() - }; - let nested = { - let node = MastNode::new_split(r#true2, r#false2, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(node).unwrap() - }; + let r#false2 = expected_mast_forest_builder + .ensure_join(push_13_basic_block_id, r#while) + .unwrap(); + let nested = expected_mast_forest_builder.ensure_split(r#true2, r#false2).unwrap(); let combined_node_id = combine_mast_node_ids( vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id], diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index 31cb297309..aab440d858 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -84,7 +84,7 @@ impl MastNode { Self::Call(CallNode::new_syscall(callee, mast_forest)) } - pub fn new_dynexec() -> Self { + pub fn new_dyn() -> Self { Self::Dyn } diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index b72e1701e0..5b5a6aabdb 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -77,7 +77,7 @@ impl MastNodeInfo { Ok(MastNode::new_syscall(callee_id, mast_forest)) } - MastNodeType::Dyn => Ok(MastNode::new_dynexec()), + MastNodeType::Dyn => Ok(MastNode::new_dyn()), MastNodeType::External => Ok(MastNode::new_external(self.digest)), }?; diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index 69d746ec38..e07e92296d 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -322,7 +322,7 @@ fn serialize_deserialize_all_nodes() { mast_forest.add_node(node).unwrap() }; let dyn_node_id = { - let node = MastNode::new_dynexec(); + let node = MastNode::new_dyn(); mast_forest.add_node(node).unwrap() }; diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 0cbdbd8f82..390a162e9e 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -1194,7 +1194,7 @@ fn dyn_block() { let join_node_id = mast_forest.add_node(join_node.clone()).unwrap(); // This dyn will point to foo. - let dyn_node = MastNode::new_dynexec(); + let dyn_node = MastNode::new_dyn(); let dyn_node_id = mast_forest.add_node(dyn_node.clone()).unwrap(); let program_root_node = MastNode::new_join(join_node_id, dyn_node_id, &mast_forest);