From 64c74014d00580baeeca2837bd2e8ecf06bc4a2e Mon Sep 17 00:00:00 2001 From: Serge Radinovich <47865535+sergerad@users.noreply.github.com> Date: Sun, 21 Jul 2024 06:41:50 +1200 Subject: [PATCH 1/7] feat: add more functions for ensuring nodes via MastForestBuilder (#1404) --- CHANGELOG.md | 9 +- assembly/src/assembler/basic_block_builder.rs | 5 +- .../src/assembler/instruction/procedures.rs | 26 ++-- assembly/src/assembler/mast_forest_builder.rs | 62 ++++++++- assembly/src/assembler/mod.rs | 27 ++-- assembly/src/assembler/tests.rs | 128 +++++++----------- core/src/mast/node/mod.rs | 2 +- core/src/mast/serialization/info.rs | 2 +- core/src/mast/serialization/tests.rs | 2 +- processor/src/decoder/tests.rs | 2 +- 10 files changed, 139 insertions(+), 126 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 72392ea618..b8751ad390 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,8 @@ - Added error codes support for the `mtree_verify` instruction (#1328). - Added support for immediate values for `lt`, `lte`, `gt`, `gte` comparison instructions (#1346). -- Change MAST to a table-based representation (#1349) -- Introduce `MastForestStore` (#1359) +- Changed MAST to a table-based representation (#1349) +- Introduced `MastForestStore` (#1359) - Adjusted prover's metal acceleration code to work with 0.9 versions of the crates (#1357) - Added support for immediate values for `u32lt`, `u32lte`, `u32gt`, `u32gte`, `u32min` and `u32max` comparison instructions (#1358). - Added support for the `nop` instruction, which corresponds to the VM opcode of the same name, and has the same semantics. This is implemented for use by compilers primarily. @@ -16,9 +16,10 @@ - Added support for immediate values for `u32and`, `u32or`, `u32xor` and `u32not` bitwise instructions (#1362). - Optimized `std::sys::truncate_stuck` procedure (#1384). - Updated CI and Makefile to standardise it accross Miden repositories (#1342). 
-- Add serialization/deserialization for `MastForest` (#1370) +- Added serialization/deserialization for `MastForest` (#1370) - Updated CI to support `CHANGELOG.md` modification checking and `no changelog` label (#1406) -- Introduce `MastForestError` to enforce `MastForest` node count invariant (#1394) +- Introduced `MastForestError` to enforce `MastForest` node count invariant (#1394) +- Added functions to `MastForestBuilder` to allow ensuring of nodes with fewer LOC (#1404) #### Changed diff --git a/assembly/src/assembler/basic_block_builder.rs b/assembly/src/assembler/basic_block_builder.rs index 2545d4eea8..5263faeff2 100644 --- a/assembly/src/assembler/basic_block_builder.rs +++ b/assembly/src/assembler/basic_block_builder.rs @@ -4,7 +4,7 @@ use super::{ }; use alloc::{borrow::Borrow, string::ToString, vec::Vec}; use vm_core::{ - mast::{MastForestError, MastNode, MastNodeId}, + mast::{MastForestError, MastNodeId}, AdviceInjector, AssemblyOp, Operation, }; @@ -134,8 +134,7 @@ impl BasicBlockBuilder { let ops = self.ops.drain(..).collect(); let decorators = self.decorators.drain(..).collect(); - let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators); - let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node)?; + let basic_block_node_id = mast_forest_builder.ensure_block(ops, Some(decorators))?; Ok(Some(basic_block_node_id)) } else if !self.decorators.is_empty() { diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index d6a96f0f1b..a86474fce7 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -6,7 +6,7 @@ use crate::{ }; use smallvec::SmallVec; -use vm_core::mast::{MastForest, MastNode, MastNodeId}; +use vm_core::mast::{MastForest, MastNodeId}; /// Procedure Invocation impl Assembler { @@ -96,8 +96,7 @@ impl Assembler { None => { // If the MAST root called isn't known to us, make it an 
external // reference. - let external_node = MastNode::new_external(mast_root); - mast_forest_builder.ensure_node(external_node)? + mast_forest_builder.ensure_external(mast_root)? } } } @@ -107,13 +106,11 @@ impl Assembler { None => { // If the MAST root called isn't known to us, make it an external // reference. - let external_node = MastNode::new_external(mast_root); - mast_forest_builder.ensure_node(external_node)? + mast_forest_builder.ensure_external(mast_root)? } }; - let call_node = MastNode::new_call(callee_id, mast_forest_builder.forest()); - mast_forest_builder.ensure_node(call_node)? + mast_forest_builder.ensure_call(callee_id)? } InvokeKind::SysCall => { let callee_id = match mast_forest_builder.find_procedure_root(mast_root) { @@ -121,14 +118,11 @@ impl Assembler { None => { // If the MAST root called isn't known to us, make it an external // reference. - let external_node = MastNode::new_external(mast_root); - mast_forest_builder.ensure_node(external_node)? + mast_forest_builder.ensure_external(mast_root)? } }; - let syscall_node = - MastNode::new_syscall(callee_id, mast_forest_builder.forest()); - mast_forest_builder.ensure_node(syscall_node)? + mast_forest_builder.ensure_syscall(callee_id)? } } }; @@ -141,7 +135,7 @@ impl Assembler { &self, mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { - let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?; + let dyn_node_id = mast_forest_builder.ensure_dyn()?; Ok(Some(dyn_node_id)) } @@ -152,10 +146,8 @@ impl Assembler { mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { let dyn_call_node_id = { - let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?; - let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest_builder.forest()); - - mast_forest_builder.ensure_node(dyn_call_node)? + let dyn_node_id = mast_forest_builder.ensure_dyn()?; + mast_forest_builder.ensure_call(dyn_node_id)? 
}; Ok(Some(dyn_call_node_id)) diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs index a5854c8b01..0a1a44d037 100644 --- a/assembly/src/assembler/mast_forest_builder.rs +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -1,9 +1,10 @@ use core::ops::Index; -use alloc::collections::BTreeMap; +use alloc::{collections::BTreeMap, vec::Vec}; use vm_core::{ crypto::hash::RpoDigest, mast::{MastForest, MastForestError, MastNode, MastNodeId, MerkleTreeNode}, + DecoratorList, Operation, }; /// Builder for a [`MastForest`]. @@ -44,7 +45,7 @@ impl MastForestBuilder { /// If a [`MastNode`] which is equal to the current node was previously added, the previously /// returned [`MastNodeId`] will be returned. This enforces this invariant that equal /// [`MastNode`]s have equal [`MastNodeId`]s. - pub fn ensure_node(&mut self, node: MastNode) -> Result { + fn ensure_node(&mut self, node: MastNode) -> Result { let node_digest = node.digest(); if let Some(node_id) = self.node_id_by_hash.get(&node_digest) { @@ -58,6 +59,63 @@ impl MastForestBuilder { } } + /// Adds a basic block node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_block( + &mut self, + operations: Vec, + decorators: Option, + ) -> Result { + match decorators { + Some(decorators) => { + self.ensure_node(MastNode::new_basic_block_with_decorators(operations, decorators)) + } + None => self.ensure_node(MastNode::new_basic_block(operations)), + } + } + + /// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_join( + &mut self, + left_child: MastNodeId, + right_child: MastNodeId, + ) -> Result { + self.ensure_node(MastNode::new_join(left_child, right_child, &self.mast_forest)) + } + + /// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it. 
+ pub fn ensure_split( + &mut self, + if_branch: MastNodeId, + else_branch: MastNodeId, + ) -> Result { + self.ensure_node(MastNode::new_split(if_branch, else_branch, &self.mast_forest)) + } + + /// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_loop(&mut self, body: MastNodeId) -> Result { + self.ensure_node(MastNode::new_loop(body, &self.mast_forest)) + } + + /// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_call(&mut self, callee: MastNodeId) -> Result { + self.ensure_node(MastNode::new_call(callee, &self.mast_forest)) + } + + /// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_syscall(&mut self, callee: MastNodeId) -> Result { + self.ensure_node(MastNode::new_syscall(callee, &self.mast_forest)) + } + + /// Adds a dynexec node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_dyn(&mut self) -> Result { + self.ensure_node(MastNode::new_dyn()) + } + + /// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_external(&mut self, mast_root: RpoDigest) -> Result { + self.ensure_node(MastNode::new_external(mast_root)) + } + /// Marks the given [`MastNodeId`] as being the root of a procedure. 
pub fn make_root(&mut self, new_root_id: MastNodeId) { self.mast_forest.make_root(new_root_id) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index b6bb051d30..cfdfbfe592 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -11,7 +11,7 @@ use crate::{ use alloc::{boxed::Box, sync::Arc, vec::Vec}; use mast_forest_builder::MastForestBuilder; use vm_core::{ - mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId, MerkleTreeNode}, Decorator, DecoratorList, Kernel, Operation, Program, }; @@ -792,12 +792,9 @@ impl Assembler { let else_blk = self.compile_body(else_blk.iter(), context, None, mast_forest_builder)?; - let split_node_id = { - let split_node = - MastNode::new_split(then_blk, else_blk, mast_forest_builder.forest()); - - mast_forest_builder.ensure_node(split_node).map_err(AssemblyError::from)? - }; + let split_node_id = mast_forest_builder + .ensure_split(then_blk, else_blk) + .map_err(AssemblyError::from)?; mast_node_ids.push(split_node_id); } @@ -828,11 +825,9 @@ impl Assembler { let loop_body_node_id = self.compile_body(body.iter(), context, None, mast_forest_builder)?; - let loop_node_id = { - let loop_node = - MastNode::new_loop(loop_body_node_id, mast_forest_builder.forest()); - mast_forest_builder.ensure_node(loop_node).map_err(AssemblyError::from)? - }; + let loop_node_id = mast_forest_builder + .ensure_loop(loop_body_node_id) + .map_err(AssemblyError::from)?; mast_node_ids.push(loop_node_id); } } @@ -846,8 +841,9 @@ impl Assembler { } Ok(if mast_node_ids.is_empty() { - let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]); - mast_forest_builder.ensure_node(basic_block_node).map_err(AssemblyError::from)? + mast_forest_builder + .ensure_block(vec![Operation::Noop], None) + .map_err(AssemblyError::from)? } else { combine_mast_node_ids(mast_node_ids, mast_forest_builder)? 
}) @@ -907,8 +903,7 @@ fn combine_mast_node_ids( while let (Some(left), Some(right)) = (source_mast_node_iter.next(), source_mast_node_iter.next()) { - let join_mast_node = MastNode::new_join(left, right, mast_forest_builder.forest()); - let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node)?; + let join_mast_node_id = mast_forest_builder.ensure_join(left, right)?; mast_node_ids.push(join_mast_node_id); } diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index e627d985ab..c557276b2c 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -81,12 +81,10 @@ fn nested_blocks() { // contains the MAST nodes for the kernel after a call to // `Assembler::with_kernel_from_module()`. let syscall_foo_node_id = { - let kernel_foo_node = MastNode::new_basic_block(vec![Operation::Add]); - let kernel_foo_node_id = expected_mast_forest_builder.ensure_node(kernel_foo_node).unwrap(); + let kernel_foo_node_id = + expected_mast_forest_builder.ensure_block(vec![Operation::Add], None).unwrap(); - let syscall_node = - MastNode::new_syscall(kernel_foo_node_id, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(syscall_node).unwrap() + expected_mast_forest_builder.ensure_syscall(kernel_foo_node_id).unwrap() }; let program = r#" @@ -129,95 +127,65 @@ fn nested_blocks() { let exec_bar_node_id = { // bar procedure - let basic_block_1 = MastNode::new_basic_block(vec![Operation::Push(17_u32.into())]); - let basic_block_1_id = expected_mast_forest_builder.ensure_node(basic_block_1).unwrap(); + let basic_block_1_id = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(17_u32.into())], None) + .unwrap(); // Basic block representing the `foo` procedure - let basic_block_2 = MastNode::new_basic_block(vec![Operation::Push(19_u32.into())]); - let basic_block_2_id = expected_mast_forest_builder.ensure_node(basic_block_2).unwrap(); - - let join_node = MastNode::new_join( - 
basic_block_1_id, - basic_block_2_id, - expected_mast_forest_builder.forest(), - ); - expected_mast_forest_builder.ensure_node(join_node).unwrap() - }; + let basic_block_2_id = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(19_u32.into())], None) + .unwrap(); - let exec_foo_bar_baz_node_id = { - // basic block representing foo::bar.baz procedure - let basic_block = MastNode::new_basic_block(vec![Operation::Push(29_u32.into())]); - expected_mast_forest_builder.ensure_node(basic_block).unwrap() + expected_mast_forest_builder + .ensure_join(basic_block_1_id, basic_block_2_id) + .unwrap() }; - let before = { - let before_node = MastNode::new_basic_block(vec![Operation::Push(2u32.into())]); - expected_mast_forest_builder.ensure_node(before_node).unwrap() - }; + // basic block representing foo::bar.baz procedure + let exec_foo_bar_baz_node_id = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(29_u32.into())], None) + .unwrap(); - let r#true1 = { - let r#true_node = MastNode::new_basic_block(vec![Operation::Push(3u32.into())]); - expected_mast_forest_builder.ensure_node(r#true_node).unwrap() - }; - let r#false1 = { - let r#false_node = MastNode::new_basic_block(vec![Operation::Push(5u32.into())]); - expected_mast_forest_builder.ensure_node(r#false_node).unwrap() - }; - let r#if1 = { - let r#if_node = - MastNode::new_split(r#true1, r#false1, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(r#if_node).unwrap() - }; + let before = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(2u32.into())], None) + .unwrap(); - let r#true3 = { - let r#true_node = MastNode::new_basic_block(vec![Operation::Push(7u32.into())]); - expected_mast_forest_builder.ensure_node(r#true_node).unwrap() - }; - let r#false3 = { - let r#false_node = MastNode::new_basic_block(vec![Operation::Push(11u32.into())]); - expected_mast_forest_builder.ensure_node(r#false_node).unwrap() - }; - let r#true2 = { - let r#if_node = - 
MastNode::new_split(r#true3, r#false3, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(r#if_node).unwrap() - }; + let r#true1 = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(3u32.into())], None) + .unwrap(); + let r#false1 = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(5u32.into())], None) + .unwrap(); + let r#if1 = expected_mast_forest_builder.ensure_split(r#true1, r#false1).unwrap(); + + let r#true3 = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(7u32.into())], None) + .unwrap(); + let r#false3 = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(11u32.into())], None) + .unwrap(); + let r#true2 = expected_mast_forest_builder.ensure_split(r#true3, r#false3).unwrap(); let r#while = { let push_basic_block_id = { - let push_basic_block = MastNode::new_basic_block(vec![Operation::Push(23u32.into())]); - expected_mast_forest_builder.ensure_node(push_basic_block).unwrap() - }; - let body_node_id = { - let body_node = MastNode::new_join( - exec_bar_node_id, - push_basic_block_id, - expected_mast_forest_builder.forest(), - ); - - expected_mast_forest_builder.ensure_node(body_node).unwrap() + expected_mast_forest_builder + .ensure_block(vec![Operation::Push(23u32.into())], None) + .unwrap() }; + let body_node_id = expected_mast_forest_builder + .ensure_join(exec_bar_node_id, push_basic_block_id) + .unwrap(); - let loop_node = MastNode::new_loop(body_node_id, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(loop_node).unwrap() - }; - let push_13_basic_block_id = { - let node = MastNode::new_basic_block(vec![Operation::Push(13u32.into())]); - expected_mast_forest_builder.ensure_node(node).unwrap() + expected_mast_forest_builder.ensure_loop(body_node_id).unwrap() }; + let push_13_basic_block_id = expected_mast_forest_builder + .ensure_block(vec![Operation::Push(13u32.into())], None) + .unwrap(); - let r#false2 = { - let node = 
MastNode::new_join( - push_13_basic_block_id, - r#while, - expected_mast_forest_builder.forest(), - ); - expected_mast_forest_builder.ensure_node(node).unwrap() - }; - let nested = { - let node = MastNode::new_split(r#true2, r#false2, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(node).unwrap() - }; + let r#false2 = expected_mast_forest_builder + .ensure_join(push_13_basic_block_id, r#while) + .unwrap(); + let nested = expected_mast_forest_builder.ensure_split(r#true2, r#false2).unwrap(); let combined_node_id = combine_mast_node_ids( vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id], diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index 31cb297309..aab440d858 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -84,7 +84,7 @@ impl MastNode { Self::Call(CallNode::new_syscall(callee, mast_forest)) } - pub fn new_dynexec() -> Self { + pub fn new_dyn() -> Self { Self::Dyn } diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index b72e1701e0..5b5a6aabdb 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -77,7 +77,7 @@ impl MastNodeInfo { Ok(MastNode::new_syscall(callee_id, mast_forest)) } - MastNodeType::Dyn => Ok(MastNode::new_dynexec()), + MastNodeType::Dyn => Ok(MastNode::new_dyn()), MastNodeType::External => Ok(MastNode::new_external(self.digest)), }?; diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index 69d746ec38..e07e92296d 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -322,7 +322,7 @@ fn serialize_deserialize_all_nodes() { mast_forest.add_node(node).unwrap() }; let dyn_node_id = { - let node = MastNode::new_dynexec(); + let node = MastNode::new_dyn(); mast_forest.add_node(node).unwrap() }; diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 0cbdbd8f82..390a162e9e 
100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -1194,7 +1194,7 @@ fn dyn_block() { let join_node_id = mast_forest.add_node(join_node.clone()).unwrap(); // This dyn will point to foo. - let dyn_node = MastNode::new_dynexec(); + let dyn_node = MastNode::new_dyn(); let dyn_node_id = mast_forest.add_node(dyn_node.clone()).unwrap(); let program_root_node = MastNode::new_join(join_node_id, dyn_node_id, &mast_forest); From 2c22ac19d0ca5d52147bc0c144f683b44337e35c Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Sat, 20 Jul 2024 11:54:02 -0700 Subject: [PATCH 2/7] chore: add section separators --- core/src/mast/mod.rs | 140 ++++++++++++++++++++------------------ core/src/mast/node/mod.rs | 11 +++ 2 files changed, 84 insertions(+), 67 deletions(-) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index b997c9f051..31ad5f1f0d 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -21,59 +21,6 @@ pub trait MerkleTreeNode { fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a; } -// MAST NODE ID -// ================================================================================================ - -/// An opaque handle to a [`MastNode`] in some [`MastForest`]. It is the responsibility of the user -/// to use a given [`MastNodeId`] with the corresponding [`MastForest`]. -/// -/// Note that the [`MastForest`] does *not* ensure that equal [`MastNode`]s have equal -/// [`MastNodeId`] handles. Hence, [`MastNodeId`] equality must not be used to test for equality of -/// the underlying [`MastNode`]. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct MastNodeId(u32); - -impl MastNodeId { - /// Returns a new `MastNodeId` with the provided inner value, or an error if the provided - /// `value` is greater than the number of nodes in the forest. - /// - /// For use in deserialization. 
- pub fn from_u32_safe( - value: u32, - mast_forest: &MastForest, - ) -> Result { - if (value as usize) < mast_forest.nodes.len() { - Ok(Self(value)) - } else { - Err(DeserializationError::InvalidValue(format!( - "Invalid deserialized MAST node ID '{}', but only {} nodes in the forest", - value, - mast_forest.nodes.len(), - ))) - } - } -} - -impl fmt::Display for MastNodeId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "MastNodeId({})", self.0) - } -} - -impl Serializable for MastNodeId { - fn write_into(&self, target: &mut W) { - self.0.write_into(target) - } -} - -impl Deserializable for MastNodeId { - fn read_from(source: &mut R) -> Result { - let inner = source.read_u32()?; - - Ok(Self(inner)) - } -} - // MAST FOREST // ================================================================================================ @@ -87,19 +34,7 @@ pub enum MastForestError { TooManyNodes, } -/// Represents one or more procedures, represented as a collection of [`MastNode`]s. -/// -/// A [`MastForest`] does not have an entrypoint, and hence is not executable. A [`crate::Program`] -/// can be built from a [`MastForest`] to specify an entrypoint. -#[derive(Clone, Debug, Default, PartialEq, Eq)] -pub struct MastForest { - /// All of the nodes local to the trees comprising the MAST forest. - nodes: Vec, - - /// Roots of procedures defined within this MAST forest. - roots: Vec, -} - +// ------------------------------------------------------------------------------------------------ /// Constructors impl MastForest { /// Creates a new empty [`MastForest`]. @@ -108,7 +43,8 @@ impl MastForest { } } -/// Mutators +// ------------------------------------------------------------------------------------------------ +/// State mutators impl MastForest { /// The maximum number of nodes that can be stored in a single MAST forest. 
const MAX_NODES: usize = (1 << 30) - 1; @@ -141,6 +77,7 @@ impl MastForest { } } +// ------------------------------------------------------------------------------------------------ /// Public accessors impl MastForest { /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if valid, or else @@ -184,3 +121,72 @@ impl Index for MastForest { &self.nodes[idx] } } + +// MAST NODE ID +// ================================================================================================ + +/// An opaque handle to a [`MastNode`] in some [`MastForest`]. It is the responsibility of the user +/// to use a given [`MastNodeId`] with the corresponding [`MastForest`]. +/// +/// Note that the [`MastForest`] does *not* ensure that equal [`MastNode`]s have equal +/// [`MastNodeId`] handles. Hence, [`MastNodeId`] equality must not be used to test for equality of +/// the underlying [`MastNode`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct MastNodeId(u32); + +impl MastNodeId { + /// Returns a new `MastNodeId` with the provided inner value, or an error if the provided + /// `value` is greater than the number of nodes in the forest. + /// + /// For use in deserialization. 
+ pub fn from_u32_safe( + value: u32, + mast_forest: &MastForest, + ) -> Result { + if (value as usize) < mast_forest.nodes.len() { + Ok(Self(value)) + } else { + Err(DeserializationError::InvalidValue(format!( + "Invalid deserialized MAST node ID '{}', but only {} nodes in the forest", + value, + mast_forest.nodes.len(), + ))) + } + } +} + +impl fmt::Display for MastNodeId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "MastNodeId({})", self.0) + } +} + +impl Serializable for MastNodeId { + fn write_into(&self, target: &mut W) { + self.0.write_into(target) + } +} + +impl Deserializable for MastNodeId { + fn read_from(source: &mut R) -> Result { + let inner = source.read_u32()?; + + Ok(Self(inner)) + } +} + +// MAST FOREST ERROR +// ================================================================================================ + +/// Represents one or more procedures, represented as a collection of [`MastNode`]s. +/// +/// A [`MastForest`] does not have an entrypoint, and hence is not executable. A [`crate::Program`] +/// can be built from a [`MastForest`] to specify an entrypoint. +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct MastForest { + /// All of the nodes local to the trees comprising the MAST forest. + nodes: Vec, + + /// Roots of procedures defined within this MAST forest. 
+ roots: Vec, +} diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index aab440d858..a1c52bb211 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -32,6 +32,9 @@ use crate::{ DecoratorList, Operation, }; +// MAST NODE +// ================================================================================================ + #[derive(Debug, Clone, PartialEq, Eq)] pub enum MastNode { Block(BasicBlockNode), @@ -43,6 +46,7 @@ pub enum MastNode { External(ExternalNode), } +// ------------------------------------------------------------------------------------------------ /// Constructors impl MastNode { pub fn new_basic_block(operations: Vec) -> Self { @@ -93,6 +97,7 @@ impl MastNode { } } +// ------------------------------------------------------------------------------------------------ /// Public accessors impl MastNode { pub fn is_basic_block(&self) -> bool { @@ -137,6 +142,9 @@ impl MastNode { } } +// ------------------------------------------------------------------------------------------------ +// MerkleTreeNode impl + impl MerkleTreeNode for MastNode { fn digest(&self) -> RpoDigest { match self { @@ -163,6 +171,9 @@ impl MerkleTreeNode for MastNode { } } +// PRETTY PRINTING +// ================================================================================================ + struct MastNodePrettyPrint<'a> { node_pretty_print: Box, } From b10567bbc2f9dc74acd6d29b6978e1e46049f0f3 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare <43513081+bobbinth@users.noreply.github.com> Date: Mon, 22 Jul 2024 14:39:23 -0700 Subject: [PATCH 3/7] chore: minor MAST module cleanup (#1407) * refactor: remove MerkleTreeNode trait * chore: added comments and section separators * fix: remove implicit serialization of MastNodeId --- assembly/src/assembler/mast_forest_builder.rs | 2 +- assembly/src/assembler/mod.rs | 2 +- assembly/src/assembler/procedure.rs | 2 +- core/src/mast/mod.rs | 70 ++--- core/src/mast/node/basic_block_node/mod.rs | 286 
++++-------------- .../mast/node/basic_block_node/op_batch.rs | 175 +++++++++++ core/src/mast/node/call_node.rs | 54 +++- core/src/mast/node/dyn_node.rs | 37 ++- core/src/mast/node/external.rs | 18 +- core/src/mast/node/join_node.rs | 46 ++- core/src/mast/node/loop_node.rs | 53 +++- core/src/mast/node/mod.rs | 19 +- core/src/mast/node/split_node.rs | 47 ++- core/src/mast/serialization/info.rs | 2 +- core/src/mast/serialization/mod.rs | 10 +- core/src/mast/tests.rs | 6 +- core/src/program.rs | 2 +- .../integration/operations/io_ops/env_ops.rs | 2 +- processor/src/chiplets/hasher/tests.rs | 2 +- processor/src/decoder/mod.rs | 5 +- processor/src/decoder/tests.rs | 2 +- processor/src/lib.rs | 2 +- processor/src/trace/tests/decoder.rs | 2 +- 23 files changed, 481 insertions(+), 365 deletions(-) create mode 100644 core/src/mast/node/basic_block_node/op_batch.rs diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs index 0a1a44d037..2d0bacbb3e 100644 --- a/assembly/src/assembler/mast_forest_builder.rs +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -3,7 +3,7 @@ use core::ops::Index; use alloc::{collections::BTreeMap, vec::Vec}; use vm_core::{ crypto::hash::RpoDigest, - mast::{MastForest, MastForestError, MastNode, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastForestError, MastNode, MastNodeId}, DecoratorList, Operation, }; diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index cfdfbfe592..2ef43cdb03 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -11,7 +11,7 @@ use crate::{ use alloc::{boxed::Box, sync::Arc, vec::Vec}; use mast_forest_builder::MastForestBuilder; use vm_core::{ - mast::{MastForest, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId}, Decorator, DecoratorList, Kernel, Operation, Program, }; diff --git a/assembly/src/assembler/procedure.rs b/assembly/src/assembler/procedure.rs index 88224396da..15675f5b35 100644 --- 
a/assembly/src/assembler/procedure.rs +++ b/assembly/src/assembler/procedure.rs @@ -5,7 +5,7 @@ use crate::{ diagnostics::SourceFile, LibraryPath, RpoDigest, SourceSpan, Spanned, }; -use vm_core::mast::{MastForest, MastNodeId, MerkleTreeNode}; +use vm_core::mast::{MastForest, MastNodeId}; pub type CallSet = BTreeSet; diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 31ad5f1f0d..20d1ed9309 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -5,33 +5,30 @@ use miden_crypto::hash::rpo::RpoDigest; mod node; pub use node::{ - get_span_op_group_count, BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, - MastNode, OpBatch, OperationOrDecorator, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, + BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, MastNode, OpBatch, + OperationOrDecorator, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, }; -use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +use winter_utils::DeserializationError; mod serialization; #[cfg(test)] mod tests; -/// Encapsulates the behavior that a [`MastNode`] (and all its variants) is expected to have. -pub trait MerkleTreeNode { - fn digest(&self) -> RpoDigest; - fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a; -} - // MAST FOREST // ================================================================================================ -/// Represents the types of errors that can occur when dealing with MAST forest. -#[derive(Debug, thiserror::Error)] -pub enum MastForestError { - #[error( - "invalid node count: MAST forest exceeds the maximum of {} nodes", - MastForest::MAX_NODES - )] - TooManyNodes, +/// Represents one or more procedures, represented as a collection of [`MastNode`]s. +/// +/// A [`MastForest`] does not have an entrypoint, and hence is not executable. A [`crate::Program`] +/// can be built from a [`MastForest`] to specify an entrypoint. 
+#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct MastForest { + /// All of the nodes local to the trees comprising the MAST forest. + nodes: Vec, + + /// Roots of procedures defined within this MAST forest. + roots: Vec, } // ------------------------------------------------------------------------------------------------ @@ -155,38 +152,33 @@ impl MastNodeId { } } -impl fmt::Display for MastNodeId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "MastNodeId({})", self.0) +impl From for u32 { + fn from(value: MastNodeId) -> Self { + value.0 } } -impl Serializable for MastNodeId { - fn write_into(&self, target: &mut W) { - self.0.write_into(target) +impl From<&MastNodeId> for u32 { + fn from(value: &MastNodeId) -> Self { + value.0 } } -impl Deserializable for MastNodeId { - fn read_from(source: &mut R) -> Result { - let inner = source.read_u32()?; - - Ok(Self(inner)) +impl fmt::Display for MastNodeId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "MastNodeId({})", self.0) } } // MAST FOREST ERROR // ================================================================================================ -/// Represents one or more procedures, represented as a collection of [`MastNode`]s. -/// -/// A [`MastForest`] does not have an entrypoint, and hence is not executable. A [`crate::Program`] -/// can be built from a [`MastForest`] to specify an entrypoint. -#[derive(Clone, Debug, Default, PartialEq, Eq)] -pub struct MastForest { - /// All of the nodes local to the trees comprising the MAST forest. - nodes: Vec, - - /// Roots of procedures defined within this MAST forest. - roots: Vec, +/// Represents the types of errors that can occur when dealing with MAST forest. 
+#[derive(Debug, thiserror::Error)] +pub enum MastForestError { + #[error( + "invalid node count: MAST forest exceeds the maximum of {} nodes", + MastForest::MAX_NODES + )] + TooManyNodes, } diff --git a/core/src/mast/node/basic_block_node/mod.rs b/core/src/mast/node/basic_block_node/mod.rs index a8d07ab2de..a98c5f7785 100644 --- a/core/src/mast/node/basic_block_node/mod.rs +++ b/core/src/mast/node/basic_block_node/mod.rs @@ -5,11 +5,11 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt, ZERO}; use miden_formatting::prettier::PrettyPrint; use winter_utils::flatten_slice_elements; -use crate::{ - chiplets::hasher, - mast::{MastForest, MerkleTreeNode}, - Decorator, DecoratorIterator, DecoratorList, Operation, -}; +use crate::{chiplets::hasher, Decorator, DecoratorIterator, DecoratorList, Operation}; + +mod op_batch; +pub use op_batch::OpBatch; +use op_batch::OpBatchAccumulator; #[cfg(test)] mod tests; @@ -67,12 +67,14 @@ pub struct BasicBlockNode { decorators: DecoratorList, } +// ------------------------------------------------------------------------------------------------ /// Constants impl BasicBlockNode { /// The domain of the basic block node (used for control block hashing). pub const DOMAIN: Felt = ZERO; } +// ------------------------------------------------------------------------------------------------ /// Constructors impl BasicBlockNode { /// Returns a new [`BasicBlockNode`] instantiated with the specified operations. @@ -108,25 +110,39 @@ impl BasicBlockNode { } } +// ------------------------------------------------------------------------------------------------ /// Public accessors impl BasicBlockNode { - pub fn num_operations_and_decorators(&self) -> u32 { - let num_ops: usize = self.op_batches.iter().map(|batch| batch.ops().len()).sum(); - let num_decorators = self.decorators.len(); - - (num_ops + num_decorators) - .try_into() - .expect("basic block contains more than 2^32 operations and decorators") + /// Returns a commitment to this basic block. 
+ pub fn digest(&self) -> RpoDigest { + self.digest } + /// Returns a reference to the operation batches in this basic block. pub fn op_batches(&self) -> &[OpBatch] { &self.op_batches } - /// Returns an iterator over all operations and decorator, in the order in which they appear in - /// the program. - pub fn iter(&self) -> impl Iterator { - OperationOrDecoratorIterator::new(self) + /// Returns the total number of operation groups in this basic block. + /// + /// The number of operation groups is computed as follows: + /// - For all batches but the last one we set the number of groups to 8, regardless of the + /// actual number of groups in the batch. The reason for this is that when operation batches + /// are concatenated together each batch contributes 8 elements to the hash. + /// - For the last batch, we take the number of actual groups and round it up to the next power + /// of two. The reason for rounding is that the VM always executes a number of operation + /// groups which is a power of two. + pub fn num_op_groups(&self) -> usize { + let last_batch_num_groups = self.op_batches.last().expect("no last group").num_groups(); + (self.op_batches.len() - 1) * BATCH_SIZE + last_batch_num_groups.next_power_of_two() + } + + /// Returns a list of decorators in this basic block node. + /// + /// Each decorator is accompanied by the operation index specifying the operation prior to + /// which the decorator should be executed. + pub fn decorators(&self) -> &DecoratorList { + &self.decorators } /// Returns a [`DecoratorIterator`] which allows us to iterate through the decorator list of @@ -135,39 +151,25 @@ impl BasicBlockNode { DecoratorIterator::new(&self.decorators) } - /// Returns a list of decorators in this basic block node. - pub fn decorators(&self) -> &DecoratorList { - &self.decorators - } -} + /// Returns the total number of operations and decorators in this basic block. 
+ pub fn num_operations_and_decorators(&self) -> u32 { + let num_ops: usize = self.op_batches.iter().map(|batch| batch.ops().len()).sum(); + let num_decorators = self.decorators.len(); -impl MerkleTreeNode for BasicBlockNode { - fn digest(&self) -> RpoDigest { - self.digest + (num_ops + num_decorators) + .try_into() + .expect("basic block contains more than 2^32 operations and decorators") } - fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { - self + /// Returns an iterator over all operations and decorator, in the order in which they appear in + /// the program. + pub fn iter(&self) -> impl Iterator { + OperationOrDecoratorIterator::new(self) } } -/// Checks if a given decorators list is valid (only checked in debug mode) -/// - Assert the decorator list is in ascending order. -/// - Assert the last op index in decorator list is less than or equal to the number of operations. -#[cfg(debug_assertions)] -fn validate_decorators(operations: &[Operation], decorators: &DecoratorList) { - if !decorators.is_empty() { - // check if decorator list is sorted - for i in 0..(decorators.len() - 1) { - debug_assert!(decorators[i + 1].0 >= decorators[i].0, "unsorted decorators list"); - } - // assert the last index in decorator list is less than operations vector length - debug_assert!( - operations.len() >= decorators.last().expect("empty decorators list").0, - "last op index in decorator list should be less than or equal to the number of ops" - ); - } -} +// PRETTY PRINTING +// ================================================================================================ impl PrettyPrint for BasicBlockNode { #[rustfmt::skip] @@ -291,175 +293,6 @@ impl<'a> Iterator for OperationOrDecoratorIterator<'a> { } } -// OPERATION BATCH -// ================================================================================================ - -/// A batch of operations in a span block. 
-/// -/// An operation batch consists of up to 8 operation groups, with each group containing up to 9 -/// operations or a single immediate value. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct OpBatch { - ops: Vec, - groups: [Felt; BATCH_SIZE], - op_counts: [usize; BATCH_SIZE], - num_groups: usize, -} - -impl OpBatch { - /// Returns a list of operations contained in this batch. - pub fn ops(&self) -> &[Operation] { - &self.ops - } - - /// Returns a list of operation groups contained in this batch. - /// - /// Each group is represented by a single field element. - pub fn groups(&self) -> &[Felt; BATCH_SIZE] { - &self.groups - } - - /// Returns the number of non-decorator operations for each operation group. - /// - /// Number of operations for groups containing immediate values is set to 0. - pub fn op_counts(&self) -> &[usize; BATCH_SIZE] { - &self.op_counts - } - - /// Returns the number of groups in this batch. - pub fn num_groups(&self) -> usize { - self.num_groups - } -} - -/// An accumulator used in construction of operation batches. -struct OpBatchAccumulator { - /// A list of operations in this batch, including decorators. - ops: Vec, - /// Values of operation groups, including immediate values. - groups: [Felt; BATCH_SIZE], - /// Number of non-decorator operations in each operation group. Operation count for groups - /// with immediate values is set to 0. - op_counts: [usize; BATCH_SIZE], - /// Value of the currently active op group. - group: u64, - /// Index of the next opcode in the current group. - op_idx: usize, - /// index of the current group in the batch. - group_idx: usize, - // Index of the next free group in the batch. - next_group_idx: usize, -} - -impl OpBatchAccumulator { - /// Returns a blank [OpBatchAccumulator]. 
- pub fn new() -> Self { - Self { - ops: Vec::new(), - groups: [ZERO; BATCH_SIZE], - op_counts: [0; BATCH_SIZE], - group: 0, - op_idx: 0, - group_idx: 0, - next_group_idx: 1, - } - } - - /// Returns true if this accumulator does not contain any operations. - pub fn is_empty(&self) -> bool { - self.ops.is_empty() - } - - /// Returns true if this accumulator can accept the specified operation. - /// - /// An accumulator may not be able accept an operation for the following reasons: - /// - There is no more space in the underlying batch (e.g., the 8th group of the batch already - /// contains 9 operations). - /// - There is no space for the immediate value carried by the operation (e.g., the 8th group is - /// only partially full, but we are trying to add a PUSH operation). - /// - The alignment rules require that the operation overflows into the next group, and if this - /// happens, there will be no space for the operation or its immediate value. - pub fn can_accept_op(&self, op: Operation) -> bool { - if op.imm_value().is_some() { - // an operation carrying an immediate value cannot be the last one in a group; so, we - // check if we need to move the operation to the next group. in either case, we need - // to make sure there is enough space for the immediate value as well. - if self.op_idx < GROUP_SIZE - 1 { - self.next_group_idx < BATCH_SIZE - } else { - self.next_group_idx + 1 < BATCH_SIZE - } - } else { - // check if there is space for the operation in the current group, or if there isn't, - // whether we can add another group - self.op_idx < GROUP_SIZE || self.next_group_idx < BATCH_SIZE - } - } - - /// Adds the specified operation to this accumulator. It is expected that the specified - /// operation is not a decorator and that (can_accept_op())[OpBatchAccumulator::can_accept_op] - /// is called before this function to make sure that the specified operation can be added to - /// the accumulator. 
- pub fn add_op(&mut self, op: Operation) { - // if the group is full, finalize it and start a new group - if self.op_idx == GROUP_SIZE { - self.finalize_op_group(); - } - - // for operations with immediate values, we need to do a few more things - if let Some(imm) = op.imm_value() { - // since an operation with an immediate value cannot be the last one in a group, if - // the operation would be the last one in the group, we need to start a new group - if self.op_idx == GROUP_SIZE - 1 { - self.finalize_op_group(); - } - - // save the immediate value at the next group index and advance the next group pointer - self.groups[self.next_group_idx] = imm; - self.next_group_idx += 1; - } - - // add the opcode to the group and increment the op index pointer - let opcode = op.op_code() as u64; - self.group |= opcode << (Operation::OP_BITS * self.op_idx); - self.ops.push(op); - self.op_idx += 1; - } - - /// Convert the accumulator into an [OpBatch]. - pub fn into_batch(mut self) -> OpBatch { - // make sure the last group gets added to the group array; we also check the op_idx to - // handle the case when a group contains a single NOOP operation. - if self.group != 0 || self.op_idx != 0 { - self.groups[self.group_idx] = Felt::new(self.group); - self.op_counts[self.group_idx] = self.op_idx; - } - - OpBatch { - ops: self.ops, - groups: self.groups, - op_counts: self.op_counts, - num_groups: self.next_group_idx, - } - } - - // HELPER METHODS - // -------------------------------------------------------------------------------------------- - - /// Saves the current group into the group array, advances current and next group pointers, - /// and resets group content. 
- fn finalize_op_group(&mut self) { - self.groups[self.group_idx] = Felt::new(self.group); - self.op_counts[self.group_idx] = self.op_idx; - - self.group_idx = self.next_group_idx; - self.next_group_idx = self.group_idx + 1; - - self.op_idx = 0; - self.group = 0; - } -} - // HELPER FUNCTIONS // ================================================================================================ @@ -501,17 +334,20 @@ fn batch_ops(ops: Vec) -> (Vec, RpoDigest) { (batches, hash) } -/// Returns the total number of operation groups in a span defined by the provides list of -/// operation batches. -/// -/// Then number of operation groups is computed as follows: -/// - For all batches but the last one we set the number of groups to 8, regardless of the actual -/// number of groups in the batch. The reason for this is that when operation batches are -/// concatenated together each batch contributes 8 elements to the hash. -/// - For the last batch, we take the number of actual batches and round it up to the next power of -/// two. The reason for rounding is that the VM always executes a number of operation groups which -/// is a power of two. -pub fn get_span_op_group_count(op_batches: &[OpBatch]) -> usize { - let last_batch_num_groups = op_batches.last().expect("no last group").num_groups(); - (op_batches.len() - 1) * BATCH_SIZE + last_batch_num_groups.next_power_of_two() +/// Checks if a given decorators list is valid (only checked in debug mode) +/// - Assert the decorator list is in ascending order. +/// - Assert the last op index in decorator list is less than or equal to the number of operations. 
+#[cfg(debug_assertions)] +fn validate_decorators(operations: &[Operation], decorators: &DecoratorList) { + if !decorators.is_empty() { + // check if decorator list is sorted + for i in 0..(decorators.len() - 1) { + debug_assert!(decorators[i + 1].0 >= decorators[i].0, "unsorted decorators list"); + } + // assert the last index in decorator list is less than operations vector length + debug_assert!( + operations.len() >= decorators.last().expect("empty decorators list").0, + "last op index in decorator list should be less than or equal to the number of ops" + ); + } } diff --git a/core/src/mast/node/basic_block_node/op_batch.rs b/core/src/mast/node/basic_block_node/op_batch.rs new file mode 100644 index 0000000000..f9e064f15a --- /dev/null +++ b/core/src/mast/node/basic_block_node/op_batch.rs @@ -0,0 +1,175 @@ +use super::{Felt, Operation, BATCH_SIZE, GROUP_SIZE, ZERO}; + +use alloc::vec::Vec; + +// OPERATION BATCH +// ================================================================================================ + +/// A batch of operations in a span block. +/// +/// An operation batch consists of up to 8 operation groups, with each group containing up to 9 +/// operations or a single immediate value. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct OpBatch { + pub(super) ops: Vec, + pub(super) groups: [Felt; BATCH_SIZE], + pub(super) op_counts: [usize; BATCH_SIZE], + pub(super) num_groups: usize, +} + +impl OpBatch { + /// Returns a list of operations contained in this batch. + pub fn ops(&self) -> &[Operation] { + &self.ops + } + + /// Returns a list of operation groups contained in this batch. + /// + /// Each group is represented by a single field element. + pub fn groups(&self) -> &[Felt; BATCH_SIZE] { + &self.groups + } + + /// Returns the number of non-decorator operations for each operation group. + /// + /// Number of operations for groups containing immediate values is set to 0. 
+ pub fn op_counts(&self) -> &[usize; BATCH_SIZE] { + &self.op_counts + } + + /// Returns the number of groups in this batch. + pub fn num_groups(&self) -> usize { + self.num_groups + } +} + +// OPERATION BATCH ACCUMULATOR +// ================================================================================================ + +/// An accumulator used in construction of operation batches. +pub(super) struct OpBatchAccumulator { + /// A list of operations in this batch, including decorators. + ops: Vec, + /// Values of operation groups, including immediate values. + groups: [Felt; BATCH_SIZE], + /// Number of non-decorator operations in each operation group. Operation count for groups + /// with immediate values is set to 0. + op_counts: [usize; BATCH_SIZE], + /// Value of the currently active op group. + group: u64, + /// Index of the next opcode in the current group. + op_idx: usize, + /// index of the current group in the batch. + group_idx: usize, + // Index of the next free group in the batch. + next_group_idx: usize, +} + +impl OpBatchAccumulator { + /// Returns a blank [OpBatchAccumulator]. + pub fn new() -> Self { + Self { + ops: Vec::new(), + groups: [ZERO; BATCH_SIZE], + op_counts: [0; BATCH_SIZE], + group: 0, + op_idx: 0, + group_idx: 0, + next_group_idx: 1, + } + } + + /// Returns true if this accumulator does not contain any operations. + pub fn is_empty(&self) -> bool { + self.ops.is_empty() + } + + /// Returns true if this accumulator can accept the specified operation. + /// + /// An accumulator may not be able to accept an operation for the following reasons: + /// - There is no more space in the underlying batch (e.g., the 8th group of the batch already + /// contains 9 operations). + /// - There is no space for the immediate value carried by the operation (e.g., the 8th group is + /// only partially full, but we are trying to add a PUSH operation). 
+ /// - The alignment rules require that the operation overflows into the next group, and if this + /// happens, there will be no space for the operation or its immediate value. + pub fn can_accept_op(&self, op: Operation) -> bool { + if op.imm_value().is_some() { + // an operation carrying an immediate value cannot be the last one in a group; so, we + // check if we need to move the operation to the next group. in either case, we need + // to make sure there is enough space for the immediate value as well. + if self.op_idx < GROUP_SIZE - 1 { + self.next_group_idx < BATCH_SIZE + } else { + self.next_group_idx + 1 < BATCH_SIZE + } + } else { + // check if there is space for the operation in the current group, or if there isn't, + // whether we can add another group + self.op_idx < GROUP_SIZE || self.next_group_idx < BATCH_SIZE + } + } + + /// Adds the specified operation to this accumulator. It is expected that the specified + /// operation is not a decorator and that (can_accept_op())[OpBatchAccumulator::can_accept_op] + /// is called before this function to make sure that the specified operation can be added to + /// the accumulator. 
+ pub fn add_op(&mut self, op: Operation) { + // if the group is full, finalize it and start a new group + if self.op_idx == GROUP_SIZE { + self.finalize_op_group(); + } + + // for operations with immediate values, we need to do a few more things + if let Some(imm) = op.imm_value() { + // since an operation with an immediate value cannot be the last one in a group, if + // the operation would be the last one in the group, we need to start a new group + if self.op_idx == GROUP_SIZE - 1 { + self.finalize_op_group(); + } + + // save the immediate value at the next group index and advance the next group pointer + self.groups[self.next_group_idx] = imm; + self.next_group_idx += 1; + } + + // add the opcode to the group and increment the op index pointer + let opcode = op.op_code() as u64; + self.group |= opcode << (Operation::OP_BITS * self.op_idx); + self.ops.push(op); + self.op_idx += 1; + } + + /// Convert the accumulator into an [OpBatch]. + pub fn into_batch(mut self) -> OpBatch { + // make sure the last group gets added to the group array; we also check the op_idx to + // handle the case when a group contains a single NOOP operation. + if self.group != 0 || self.op_idx != 0 { + self.groups[self.group_idx] = Felt::new(self.group); + self.op_counts[self.group_idx] = self.op_idx; + } + + OpBatch { + ops: self.ops, + groups: self.groups, + op_counts: self.op_counts, + num_groups: self.next_group_idx, + } + } + + // HELPER METHODS + // -------------------------------------------------------------------------------------------- + + /// Saves the current group into the group array, advances current and next group pointers, + /// and resets group content. 
+ pub(super) fn finalize_op_group(&mut self) { + self.groups[self.group_idx] = Felt::new(self.group); + self.op_counts[self.group_idx] = self.op_idx; + + self.group_idx = self.next_group_idx; + self.next_group_idx = self.group_idx + 1; + + self.op_idx = 0; + self.group = 0; + } +} diff --git a/core/src/mast/node/call_node.rs b/core/src/mast/node/call_node.rs index 1f00d74c00..ca9c720195 100644 --- a/core/src/mast/node/call_node.rs +++ b/core/src/mast/node/call_node.rs @@ -5,10 +5,19 @@ use miden_formatting::prettier::PrettyPrint; use crate::{ chiplets::hasher, - mast::{MastForest, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId}, OPCODE_CALL, OPCODE_SYSCALL, }; +// CALL NODE +// ================================================================================================ + +/// A Call node describes a function call such that the callee is executed in a different execution +/// context from the currently executing code. +/// +/// A call node can be of two types: +/// - A simple call: the callee is executed in the new user context. +/// - A syscall: the callee is executed in the root context. #[derive(Debug, Clone, PartialEq, Eq)] pub struct CallNode { callee: MastNodeId, @@ -16,6 +25,7 @@ pub struct CallNode { digest: RpoDigest, } +//------------------------------------------------------------------------------------------------- /// Constants impl CallNode { /// The domain of the call block (used for control block hashing). @@ -24,6 +34,7 @@ impl CallNode { pub const SYSCALL_DOMAIN: Felt = Felt::new(OPCODE_SYSCALL as u64); } +//------------------------------------------------------------------------------------------------- /// Constructors impl CallNode { /// Returns a new [`CallNode`] instantiated with the specified callee. @@ -58,16 +69,42 @@ impl CallNode { } } +//------------------------------------------------------------------------------------------------- +/// Public accessors impl CallNode { + /// Returns a commitment to this Call node. 
+ /// + /// The commitment is computed as a hash of the callee and an empty word ([ZERO; 4]) in the + /// domain defined by either [Self::CALL_DOMAIN] or [Self::SYSCALL_DOMAIN], depending on + /// whether the node represents a simple call or a syscall - i.e.,: + /// ``` + /// # use miden_core::mast::CallNode; + /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}}; + /// # let callee_digest = Digest::default(); + /// Hasher::merge_in_domain(&[callee_digest, Digest::default()], CallNode::CALL_DOMAIN); + /// ``` + /// or + /// ``` + /// # use miden_core::mast::CallNode; + /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}}; + /// # let callee_digest = Digest::default(); + /// Hasher::merge_in_domain(&[callee_digest, Digest::default()], CallNode::SYSCALL_DOMAIN); + /// ``` + pub fn digest(&self) -> RpoDigest { + self.digest + } + + /// Returns the ID of the node to be invoked by this call node. pub fn callee(&self) -> MastNodeId { self.callee } + /// Returns true if this call node represents a syscall. pub fn is_syscall(&self) -> bool { self.is_syscall } - /// Returns the domain of the call node. + /// Returns the domain of this call node. 
pub fn domain(&self) -> Felt { if self.is_syscall() { Self::SYSCALL_DOMAIN @@ -75,7 +112,12 @@ impl CallNode { Self::CALL_DOMAIN } } +} + +// PRETTY PRINTING +// ================================================================================================ +impl CallNode { pub(super) fn to_pretty_print<'a>( &'a self, mast_forest: &'a MastForest, @@ -85,14 +127,8 @@ impl CallNode { mast_forest, } } -} - -impl MerkleTreeNode for CallNode { - fn digest(&self) -> RpoDigest { - self.digest - } - fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { CallNodePrettyPrint { call_node: self, mast_forest, diff --git a/core/src/mast/node/dyn_node.rs b/core/src/mast/node/dyn_node.rs index 83c46f68fe..934a8fec2d 100644 --- a/core/src/mast/node/dyn_node.rs +++ b/core/src/mast/node/dyn_node.rs @@ -2,11 +2,12 @@ use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; -use crate::{ - mast::{MastForest, MerkleTreeNode}, - OPCODE_DYN, -}; +use crate::OPCODE_DYN; +// DYN NODE +// ================================================================================================ + +/// A Dyn node specifies that the node to be executed next is defined dynamically via the stack. #[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct DynNode; @@ -16,11 +17,19 @@ impl DynNode { pub const DOMAIN: Felt = Felt::new(OPCODE_DYN as u64); } -impl MerkleTreeNode for DynNode { - fn digest(&self) -> RpoDigest { - // The Dyn node is represented by a constant, which is set to be the hash of two empty - // words ([ZERO, ZERO, ZERO, ZERO]) with a domain value of `DYN_DOMAIN`, i.e. - // hasher::merge_in_domain(&[Digest::default(), Digest::default()], DynNode::DOMAIN) +/// Public accessors +impl DynNode { + /// Returns a commitment to a Dyn node. 
+ /// + /// The commitment is computed by hashing two empty words ([ZERO; 4]) in the domain defined + /// by [Self::DOMAIN], i.e.: + /// + /// ``` + /// # use miden_core::mast::DynNode; + /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}}; + /// Hasher::merge_in_domain(&[Digest::default(), Digest::default()], DynNode::DOMAIN); + /// ``` + pub fn digest(&self) -> RpoDigest { RpoDigest::new([ Felt::new(8115106948140260551), Felt::new(13491227816952616836), @@ -28,12 +37,11 @@ impl MerkleTreeNode for DynNode { Felt::new(16575543461540527115), ]) } - - fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { - self - } } +// PRETTY PRINTING +// ================================================================================================ + impl crate::prettier::PrettyPrint for DynNode { fn render(&self) -> crate::prettier::Document { use crate::prettier::*; @@ -48,6 +56,9 @@ impl fmt::Display for DynNode { } } +// TESTS +// ================================================================================================ + #[cfg(test)] mod tests { use miden_crypto::hash::rpo::Rpo256; diff --git a/core/src/mast/node/external.rs b/core/src/mast/node/external.rs index c0b8ff10a3..e0d15922ff 100644 --- a/core/src/mast/node/external.rs +++ b/core/src/mast/node/external.rs @@ -1,7 +1,10 @@ -use crate::mast::{MastForest, MerkleTreeNode}; +use crate::mast::MastForest; use core::fmt; use miden_crypto::hash::rpo::RpoDigest; +// EXTERNAL NODE +// ================================================================================================ + /// Node for referencing procedures not present in a given [`MastForest`] (hence "external"). 
/// /// External nodes can be used to verify the integrity of a program's hash while keeping parts of @@ -24,11 +27,18 @@ impl ExternalNode { } } -impl MerkleTreeNode for ExternalNode { - fn digest(&self) -> RpoDigest { +impl ExternalNode { + /// Returns the commitment to the MAST node referenced by this external node. + pub fn digest(&self) -> RpoDigest { self.digest } - fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { +} + +// PRETTY PRINTING +// ================================================================================================ + +impl ExternalNode { + pub(super) fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { self } } diff --git a/core/src/mast/node/join_node.rs b/core/src/mast/node/join_node.rs index 1cd5322c0f..5f802873dd 100644 --- a/core/src/mast/node/join_node.rs +++ b/core/src/mast/node/join_node.rs @@ -4,11 +4,16 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt}; use crate::{ chiplets::hasher, - mast::{MastForest, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId}, prettier::PrettyPrint, OPCODE_JOIN, }; +// JOIN NODE +// ================================================================================================ + +/// A Join node describes sequential execution. When the VM encounters a Join node, it executes the +/// first child first and the second child second. #[derive(Debug, Clone, PartialEq, Eq)] pub struct JoinNode { children: [MastNodeId; 2], @@ -41,35 +46,50 @@ impl JoinNode { } } -/// Accessors +/// Public accessors impl JoinNode { + /// Returns a commitment to this Join node. 
+ /// + /// The commitment is computed as a hash of the `first` and `second` child node in the domain + /// defined by [Self::DOMAIN] - i.e.,: + /// ``` + /// # use miden_core::mast::JoinNode; + /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}}; + /// # let first_child_digest = Digest::default(); + /// # let second_child_digest = Digest::default(); + /// Hasher::merge_in_domain(&[first_child_digest, second_child_digest], JoinNode::DOMAIN); + /// ``` + pub fn digest(&self) -> RpoDigest { + self.digest + } + + /// Returns the ID of the node that is to be executed first. pub fn first(&self) -> MastNodeId { self.children[0] } + /// Returns the ID of the node that is to be executed after the execution of the program + /// defined by the first node completes. pub fn second(&self) -> MastNodeId { self.children[1] } } +// PRETTY PRINTING +// ================================================================================================ + impl JoinNode { - pub(super) fn to_pretty_print<'a>( - &'a self, - mast_forest: &'a MastForest, - ) -> impl PrettyPrint + 'a { + pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { JoinNodePrettyPrint { join_node: self, mast_forest, } } -} - -impl MerkleTreeNode for JoinNode { - fn digest(&self) -> RpoDigest { - self.digest - } - fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { JoinNodePrettyPrint { join_node: self, mast_forest, diff --git a/core/src/mast/node/loop_node.rs b/core/src/mast/node/loop_node.rs index 1554181b46..aec1b0b451 100644 --- a/core/src/mast/node/loop_node.rs +++ b/core/src/mast/node/loop_node.rs @@ -5,10 +5,19 @@ use miden_formatting::prettier::PrettyPrint; use crate::{ chiplets::hasher, - mast::{MastForest, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId}, OPCODE_LOOP, }; +// LOOP NODE +// 
================================================================================================ + +/// A Loop node defines condition-controlled iterative execution. When the VM encounters a Loop +/// node, it will keep executing the body of the loop as long as the top of the stack is `1`. +/// +/// The loop is exited when at the end of executing the loop body the top of the stack is `0`. +/// If the top of the stack is neither `0` nor `1` when the condition is checked, the execution +/// fails. #[derive(Debug, Clone, PartialEq, Eq)] pub struct LoopNode { body: MastNodeId, @@ -32,30 +41,44 @@ impl LoopNode { Self { body, digest } } - - pub(super) fn to_pretty_print<'a>( - &'a self, - mast_forest: &'a MastForest, - ) -> impl PrettyPrint + 'a { - LoopNodePrettyPrint { - loop_node: self, - mast_forest, - } - } } impl LoopNode { + /// Returns a commitment to this Loop node. + /// + /// The commitment is computed as a hash of the loop body and an empty word ([ZERO; 4]) in + /// the domain defined by [Self::DOMAIN] - i.e.,: + /// ``` + /// # use miden_core::mast::LoopNode; + /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}}; + /// # let body_digest = Digest::default(); + /// Hasher::merge_in_domain(&[body_digest, Digest::default()], LoopNode::DOMAIN); + /// ``` + pub fn digest(&self) -> RpoDigest { + self.digest + } + + /// Returns the ID of the node representing the body of the loop. 
pub fn body(&self) -> MastNodeId { self.body } } -impl MerkleTreeNode for LoopNode { - fn digest(&self) -> RpoDigest { - self.digest +// PRETTY PRINTING +// ================================================================================================ + +impl LoopNode { + pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + LoopNodePrettyPrint { + loop_node: self, + mast_forest, + } } - fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { LoopNodePrettyPrint { loop_node: self, mast_forest, diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index a1c52bb211..d84dcf2916 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -3,8 +3,8 @@ use core::fmt; use alloc::{boxed::Box, vec::Vec}; pub use basic_block_node::{ - get_span_op_group_count, BasicBlockNode, OpBatch, OperationOrDecorator, - BATCH_SIZE as OP_BATCH_SIZE, GROUP_SIZE as OP_GROUP_SIZE, + BasicBlockNode, OpBatch, OperationOrDecorator, BATCH_SIZE as OP_BATCH_SIZE, + GROUP_SIZE as OP_GROUP_SIZE, }; mod call_node; @@ -28,7 +28,7 @@ mod loop_node; pub use loop_node::LoopNode; use crate::{ - mast::{MastForest, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId}, DecoratorList, Operation, }; @@ -140,13 +140,8 @@ impl MastNode { MastNode::External(_) => panic!("Can't fetch domain for an `External` node."), } } -} - -// ------------------------------------------------------------------------------------------------ -// MerkleTreeNode impl -impl MerkleTreeNode for MastNode { - fn digest(&self) -> RpoDigest { + pub fn digest(&self) -> RpoDigest { match self { MastNode::Block(node) => node.digest(), MastNode::Join(node) => node.digest(), @@ -158,14 +153,14 @@ impl MerkleTreeNode for MastNode { } } - fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + pub 
fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { match self { - MastNode::Block(node) => MastNodeDisplay::new(node.to_display(mast_forest)), + MastNode::Block(node) => MastNodeDisplay::new(node), MastNode::Join(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Split(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Loop(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Call(node) => MastNodeDisplay::new(node.to_display(mast_forest)), - MastNode::Dyn => MastNodeDisplay::new(DynNode.to_display(mast_forest)), + MastNode::Dyn => MastNodeDisplay::new(DynNode), MastNode::External(node) => MastNodeDisplay::new(node.to_display(mast_forest)), } } diff --git a/core/src/mast/node/split_node.rs b/core/src/mast/node/split_node.rs index 600186a9e7..f754735e35 100644 --- a/core/src/mast/node/split_node.rs +++ b/core/src/mast/node/split_node.rs @@ -5,10 +5,19 @@ use miden_formatting::prettier::PrettyPrint; use crate::{ chiplets::hasher, - mast::{MastForest, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId}, OPCODE_SPLIT, }; +// SPLIT NODE +// ================================================================================================ + +/// A Split node defines conditional execution. When the VM encounters a Split node it executes +/// either the `on_true` child or `on_false` child. +/// +/// Which child is executed is determined based on the top of the stack. If the value is `1`, then +/// the `on_true` child is executed. If the value is `0`, then the `on_false` child is executed. If +/// the value is neither `0` nor `1`, the execution fails. #[derive(Debug, Clone, PartialEq, Eq)] pub struct SplitNode { branches: [MastNodeId; 2], @@ -42,33 +51,47 @@ impl SplitNode { /// Public accessors impl SplitNode { + /// Returns a commitment to this Split node. 
+ /// + /// The commitment is computed as a hash of the `on_true` and `on_false` child nodes in the + /// domain defined by [Self::DOMAIN] - i.e.: + /// ``` + /// # use miden_core::mast::SplitNode; + /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}}; + /// # let on_true_digest = Digest::default(); + /// # let on_false_digest = Digest::default(); + /// Hasher::merge_in_domain(&[on_true_digest, on_false_digest], SplitNode::DOMAIN); + /// ``` + pub fn digest(&self) -> RpoDigest { + self.digest + } + + /// Returns the ID of the node which is to be executed if the top of the stack is `1`. pub fn on_true(&self) -> MastNodeId { self.branches[0] } + /// Returns the ID of the node which is to be executed if the top of the stack is `0`. pub fn on_false(&self) -> MastNodeId { self.branches[1] } } +// PRETTY PRINTING +// ================================================================================================ + impl SplitNode { - pub(super) fn to_pretty_print<'a>( - &'a self, - mast_forest: &'a MastForest, - ) -> impl PrettyPrint + 'a { + pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { SplitNodePrettyPrint { split_node: self, mast_forest, } } -} - -impl MerkleTreeNode for SplitNode { - fn digest(&self) -> RpoDigest { - self.digest - } - fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl core::fmt::Display + 'a { + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { SplitNodePrettyPrint { split_node: self, mast_forest, diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 5b5a6aabdb..4646e9a20b 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -1,7 +1,7 @@ use miden_crypto::hash::rpo::RpoDigest; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; -use crate::mast::{MastForest, MastNode, MastNodeId, 
MerkleTreeNode}; +use crate::mast::{MastForest, MastNode, MastNodeId}; use super::{basic_block_data_decoder::BasicBlockDataDecoder, DataOffset}; diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index e23cdbd2fb..40b32c0923 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -54,7 +54,8 @@ impl Serializable for MastForest { target.write_usize(self.nodes.len()); // roots - self.roots.write_into(target); + let roots: Vec = self.roots.iter().map(u32::from).collect(); + roots.write_into(target); // Prepare MAST node infos, but don't store them yet. We store them at the end to make // deserialization more efficient. @@ -102,11 +103,8 @@ impl Deserializable for MastForest { } let node_count = source.read_usize()?; - - let roots: Vec = Deserializable::read_from(source)?; - + let roots: Vec = Deserializable::read_from(source)?; let strings: Vec = Deserializable::read_from(source)?; - let data: Vec = Deserializable::read_from(source)?; let basic_block_data_decoder = BasicBlockDataDecoder::new(&data, &strings); @@ -128,6 +126,8 @@ impl Deserializable for MastForest { } for root in roots { + // make sure the root is valid in the context of the MAST forest + let root = MastNodeId::from_u32_safe(root, &mast_forest)?; mast_forest.make_root(root); } diff --git a/core/src/mast/tests.rs b/core/src/mast/tests.rs index da43d1b5b4..1ca87f0d86 100644 --- a/core/src/mast/tests.rs +++ b/core/src/mast/tests.rs @@ -1,8 +1,4 @@ -use crate::{ - chiplets::hasher, - mast::{DynNode, MerkleTreeNode}, - Kernel, ProgramInfo, Word, -}; +use crate::{chiplets::hasher, mast::DynNode, Kernel, ProgramInfo, Word}; use alloc::vec::Vec; use miden_crypto::{hash::rpo::RpoDigest, Felt}; use proptest::prelude::*; diff --git a/core/src/program.rs b/core/src/program.rs index b055bf3313..66419229b1 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -5,7 +5,7 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt, WORD_SIZE}; use 
winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use crate::{ - mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNode, MastNodeId}, utils::ToElements, }; diff --git a/miden/tests/integration/operations/io_ops/env_ops.rs b/miden/tests/integration/operations/io_ops/env_ops.rs index 379a974652..5582aa13ae 100644 --- a/miden/tests/integration/operations/io_ops/env_ops.rs +++ b/miden/tests/integration/operations/io_ops/env_ops.rs @@ -1,7 +1,7 @@ use processor::FMP_MIN; use test_utils::{build_op_test, build_test, StackInputs, Test, Word, STACK_TOP_SIZE}; use vm_core::{ - mast::{MastForest, MastNode, MerkleTreeNode}, + mast::{MastForest, MastNode}, Operation, }; diff --git a/processor/src/chiplets/hasher/tests.rs b/processor/src/chiplets/hasher/tests.rs index 44f814ac58..69e70d581f 100644 --- a/processor/src/chiplets/hasher/tests.rs +++ b/processor/src/chiplets/hasher/tests.rs @@ -12,7 +12,7 @@ use test_utils::rand::rand_array; use vm_core::{ chiplets::hasher, crypto::merkle::{MerkleTree, NodeIndex}, - mast::{MastForest, MastNode, MerkleTreeNode}, + mast::{MastForest, MastNode}, Operation, ONE, ZERO, }; diff --git a/processor/src/decoder/mod.rs b/processor/src/decoder/mod.rs index 5d66138df4..36e4a6b19b 100644 --- a/processor/src/decoder/mod.rs +++ b/processor/src/decoder/mod.rs @@ -12,8 +12,7 @@ use miden_air::trace::{ }; use vm_core::{ mast::{ - get_span_op_group_count, BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, MastForest, - MerkleTreeNode, SplitNode, OP_BATCH_SIZE, + BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, MastForest, SplitNode, OP_BATCH_SIZE, }, stack::STACK_TOP_SIZE, AssemblyOp, @@ -341,7 +340,7 @@ where // start decoding the first operation batch; this also appends a row with SPAN operation // to the decoder trace. we also need the total number of operation groups so that we can // set the value of the group_count register at the beginning of the SPAN. 
- let num_op_groups = get_span_op_group_count(op_batches); + let num_op_groups = basic_block.num_op_groups(); self.decoder .start_basic_block(&op_batches[0], Felt::new(num_op_groups as u64), addr); self.execute_op(Operation::Noop) diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 390a162e9e..8e84ecc5be 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -18,7 +18,7 @@ use miden_air::trace::{ }; use test_utils::rand::rand_value; use vm_core::{ - mast::{BasicBlockNode, MastForest, MastNode, MerkleTreeNode, OP_BATCH_SIZE}, + mast::{BasicBlockNode, MastForest, MastNode, OP_BATCH_SIZE}, Program, EMPTY_WORD, ONE, ZERO, }; diff --git a/processor/src/lib.rs b/processor/src/lib.rs index a254a1b1a6..95ab09e7ed 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -18,7 +18,7 @@ pub use vm_core::{ chiplets::hasher::Digest, crypto::merkle::SMT_DEPTH, errors::InputError, - mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNode, MastNodeId}, utils::DeserializationError, AdviceInjector, AssemblyOp, Felt, Kernel, Operation, Program, ProgramInfo, QuadExtension, StackInputs, StackOutputs, Word, EMPTY_WORD, ONE, ZERO, diff --git a/processor/src/trace/tests/decoder.rs b/processor/src/trace/tests/decoder.rs index 83246bdf5d..c5acde3241 100644 --- a/processor/src/trace/tests/decoder.rs +++ b/processor/src/trace/tests/decoder.rs @@ -16,7 +16,7 @@ use miden_air::trace::{ }; use test_utils::rand::rand_array; use vm_core::{ - mast::{MastForest, MastNode, MerkleTreeNode}, + mast::{MastForest, MastNode}, FieldElement, Operation, Program, Word, ONE, ZERO, }; From f1875523bb7219e0f368067dc72cfd444709e992 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philippe=20Laferri=C3=A8re?= Date: Tue, 23 Jul 2024 16:49:00 -0400 Subject: [PATCH 4/7] feat: make `Assembler` single-use (#1409) --- CHANGELOG.md | 1 + assembly/src/assembler/basic_block_builder.rs | 8 +- assembly/src/assembler/context.rs | 259 
+++--------- assembly/src/assembler/instruction/env_ops.rs | 17 +- .../src/assembler/instruction/field_ops.rs | 7 +- assembly/src/assembler/instruction/mem_ops.rs | 13 +- assembly/src/assembler/instruction/mod.rs | 90 ++-- .../src/assembler/instruction/procedures.rs | 40 +- assembly/src/assembler/instruction/u32_ops.rs | 19 +- assembly/src/assembler/mast_forest_builder.rs | 4 - assembly/src/assembler/mod.rs | 388 ++---------------- .../src/assembler/module_graph/callgraph.rs | 15 +- assembly/src/assembler/module_graph/mod.rs | 162 +------- .../src/assembler/module_graph/phantom.rs | 45 -- .../assembler/module_graph/procedure_cache.rs | 17 - .../assembler/module_graph/rewrites/module.rs | 18 +- assembly/src/assembler/procedure.rs | 2 + assembly/src/assembler/tests.rs | 10 +- assembly/src/errors.rs | 9 - assembly/src/lib.rs | 2 +- assembly/src/testing.rs | 17 +- assembly/src/tests.rs | 28 +- miden/tests/integration/flow_control/mod.rs | 14 +- .../integration/operations/io_ops/env_ops.rs | 2 + test-utils/src/lib.rs | 4 +- 25 files changed, 249 insertions(+), 942 deletions(-) delete mode 100644 assembly/src/assembler/module_graph/phantom.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index b8751ad390..1354e4a75e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ - Updated CI to support `CHANGELOG.md` modification checking and `no changelog` label (#1406) - Introduced `MastForestError` to enforce `MastForest` node count invariant (#1394) - Added functions to `MastForestBuilder` to allow ensuring of nodes with fewer LOC (#1404) +- Made `Assembler` single-use (#1409) #### Changed diff --git a/assembly/src/assembler/basic_block_builder.rs b/assembly/src/assembler/basic_block_builder.rs index 5263faeff2..362b57e138 100644 --- a/assembly/src/assembler/basic_block_builder.rs +++ b/assembly/src/assembler/basic_block_builder.rs @@ -1,6 +1,6 @@ use super::{ - mast_forest_builder::MastForestBuilder, AssemblyContext, BodyWrapper, Decorator, DecoratorList, - Instruction, + 
context::ProcedureContext, mast_forest_builder::MastForestBuilder, BodyWrapper, Decorator, + DecoratorList, Instruction, }; use alloc::{borrow::Borrow, string::ToString, vec::Vec}; use vm_core::{ @@ -85,8 +85,8 @@ impl BasicBlockBuilder { /// /// This indicates that the provided instruction should be tracked and the cycle count for /// this instruction will be computed when the call to set_instruction_cycle_count() is made. - pub fn track_instruction(&mut self, instruction: &Instruction, ctx: &AssemblyContext) { - let context_name = ctx.unwrap_current_procedure().name().to_string(); + pub fn track_instruction(&mut self, instruction: &Instruction, proc_ctx: &ProcedureContext) { + let context_name = proc_ctx.name().to_string(); let num_cycles = 0; let op = instruction.to_string(); let should_break = instruction.should_break(); diff --git a/assembly/src/assembler/context.rs b/assembly/src/assembler/context.rs index 1cc60d225d..9f60511653 100644 --- a/assembly/src/assembler/context.rs +++ b/assembly/src/assembler/context.rs @@ -1,193 +1,18 @@ use alloc::{boxed::Box, sync::Arc}; -use super::{procedure::CallSet, ArtifactKind, GlobalProcedureIndex, Procedure}; +use super::{procedure::CallSet, GlobalProcedureIndex, Procedure}; use crate::{ ast::{FullyQualifiedProcedureName, Visibility}, diagnostics::SourceFile, - AssemblyError, LibraryPath, RpoDigest, SourceSpan, Span, Spanned, + AssemblyError, LibraryPath, RpoDigest, SourceSpan, Spanned, }; use vm_core::mast::{MastForest, MastNodeId}; -// ASSEMBLY CONTEXT -// ================================================================================================ - -/// An [AssemblyContext] is used to store configuration and state pertaining to the current -/// compilation of a module/procedure by an [crate::Assembler]. -/// -/// The context specifies context-specific configuration, the type of artifact being compiled, -/// the current module being compiled, and the current procedure being compiled. 
-/// -/// To provide a custom context, you must compile by invoking the -/// [crate::Assembler::assemble_in_context] API, which will use the provided context in place of -/// the default one generated internally by the other `compile`-like APIs. -#[derive(Default)] -pub struct AssemblyContext { - /// What kind of artifact are we assembling - kind: ArtifactKind, - /// When true, promote warning diagnostics to errors - warnings_as_errors: bool, - /// When true, this permits calls to refer to procedures which are not locally available, - /// as long as they are referenced by MAST root, and not by name. As long as the MAST for those - /// roots is present when the code is executed, this works fine. However, if the VM tries to - /// execute a program with such calls, and the MAST is not available, the program will trap. - allow_phantom_calls: bool, - /// The current procedure being compiled - current_procedure: Option, - /// The fully-qualified module path which should be compiled. - /// - /// If unset, it defaults to the module which represents the specified `kind`, i.e. if the kind - /// is executable, we compile the executable module, and so on. - /// - /// When set, the module graph is traversed from the given module only, so any code unreachable - /// from this module is not considered for compilation. - root: Option, -} - -/// Builders -impl AssemblyContext { - pub fn new(kind: ArtifactKind) -> Self { - Self { - kind, - ..Default::default() - } - } - - /// Returns a new [AssemblyContext] for a non-executable kernel modules. - pub fn for_kernel(path: &LibraryPath) -> Self { - Self::new(ArtifactKind::Kernel).with_root(path.clone()) - } - - /// Returns a new [AssemblyContext] for library modules. - pub fn for_library(path: &LibraryPath) -> Self { - Self::new(ArtifactKind::Library).with_root(path.clone()) - } - - /// Returns a new [AssemblyContext] for an executable module. 
- pub fn for_program(path: &LibraryPath) -> Self { - Self::new(ArtifactKind::Executable).with_root(path.clone()) - } - - fn with_root(mut self, path: LibraryPath) -> Self { - self.root = Some(path); - self - } - - /// When true, all warning diagnostics are promoted to errors - #[inline(always)] - pub fn set_warnings_as_errors(&mut self, yes: bool) { - self.warnings_as_errors = yes; - } - - #[inline] - pub(super) fn set_current_procedure(&mut self, context: ProcedureContext) { - self.current_procedure = Some(context); - } - - #[inline] - pub(super) fn take_current_procedure(&mut self) -> Option { - self.current_procedure.take() - } - - #[inline] - pub(super) fn unwrap_current_procedure(&self) -> &ProcedureContext { - self.current_procedure.as_ref().expect("missing current procedure context") - } - - #[inline] - pub(super) fn unwrap_current_procedure_mut(&mut self) -> &mut ProcedureContext { - self.current_procedure.as_mut().expect("missing current procedure context") - } - - /// Enables phantom calls when compiling with this context. - /// - /// # Panics - /// - /// This function will panic if you attempt to enable phantom calls for a kernel-mode context, - /// as non-local procedure calls are not allowed in kernel modules. - pub fn with_phantom_calls(mut self, allow_phantom_calls: bool) -> Self { - assert!( - !self.is_kernel() || !allow_phantom_calls, - "kernel modules cannot have phantom calls enabled" - ); - self.allow_phantom_calls = allow_phantom_calls; - self - } - - /// Returns true if this context is used for compiling a kernel. - pub fn is_kernel(&self) -> bool { - matches!(self.kind, ArtifactKind::Kernel) - } - - /// Returns true if this context is used for compiling an executable. 
- pub fn is_executable(&self) -> bool { - matches!(self.kind, ArtifactKind::Executable) - } - - /// Returns the type of artifact to produce with this context - pub fn kind(&self) -> ArtifactKind { - self.kind - } - - /// Returns true if this context treats warning diagnostics as errors - #[inline(always)] - pub fn warnings_as_errors(&self) -> bool { - self.warnings_as_errors - } - - /// Registers a "phantom" call to the procedure with the specified MAST root. - /// - /// A phantom call indicates that code for the procedure is not available. Executing a phantom - /// call will result in a runtime error. However, the VM may be able to execute a program with - /// phantom calls as long as the branches containing them are not taken. - /// - /// # Errors - /// Returns an error if phantom calls are not allowed in this assembly context. - pub fn register_phantom_call( - &mut self, - mast_root: Span, - ) -> Result<(), AssemblyError> { - if !self.allow_phantom_calls { - let source_file = self.unwrap_current_procedure().source_file().clone(); - let (span, digest) = mast_root.into_parts(); - Err(AssemblyError::PhantomCallsNotAllowed { - span, - source_file, - digest, - }) - } else { - Ok(()) - } - } - - /// Registers a call to an externally-defined procedure which we have previously compiled. - /// - /// The call set of the callee is added to the call set of the procedure we are currently - /// compiling, to reflect that all of the code reachable from the callee is by extension - /// reachable by the caller. 
- pub fn register_external_call( - &mut self, - callee: &Procedure, - inlined: bool, - mast_forest: &MastForest, - ) -> Result<(), AssemblyError> { - let context = self.unwrap_current_procedure_mut(); - - // If we call the callee, it's callset is by extension part of our callset - context.extend_callset(callee.callset().iter().cloned()); - - // If the callee is not being inlined, add it to our callset - if !inlined { - context.insert_callee(callee.mast_root(mast_forest)); - } - - Ok(()) - } -} - // PROCEDURE CONTEXT // ================================================================================================ -pub(super) struct ProcedureContext { +/// Information about a procedure currently being compiled. +pub struct ProcedureContext { span: SourceSpan, source_file: Option>, gid: GlobalProcedureIndex, @@ -197,8 +22,10 @@ pub(super) struct ProcedureContext { callset: CallSet, } +// ------------------------------------------------------------------------------------------------ +/// Constructors impl ProcedureContext { - pub(super) fn new( + pub fn new( gid: GlobalProcedureIndex, name: FullyQualifiedProcedureName, visibility: Visibility, @@ -214,36 +41,25 @@ impl ProcedureContext { } } - pub(super) fn with_span(mut self, span: SourceSpan) -> Self { - self.span = span; + pub fn with_num_locals(mut self, num_locals: u16) -> Self { + self.num_locals = num_locals; self } - pub(super) fn with_source_file(mut self, source_file: Option>) -> Self { - self.source_file = source_file; + pub fn with_span(mut self, span: SourceSpan) -> Self { + self.span = span; self } - pub(super) fn with_num_locals(mut self, num_locals: u16) -> Self { - self.num_locals = num_locals; + pub fn with_source_file(mut self, source_file: Option>) -> Self { + self.source_file = source_file; self } +} - pub(crate) fn insert_callee(&mut self, callee: RpoDigest) { - self.callset.insert(callee); - } - - pub(crate) fn extend_callset(&mut self, callees: I) - where - I: IntoIterator, - { - 
self.callset.extend(callees); - } - - pub fn num_locals(&self) -> u16 { - self.num_locals - } - +// ------------------------------------------------------------------------------------------------ +/// Public accessors +impl ProcedureContext { pub fn id(&self) -> GlobalProcedureIndex { self.gid } @@ -252,6 +68,10 @@ impl ProcedureContext { &self.name } + pub fn num_locals(&self) -> u16 { + self.num_locals + } + #[allow(unused)] pub fn module(&self) -> &LibraryPath { &self.name.module @@ -264,6 +84,43 @@ impl ProcedureContext { pub fn is_kernel(&self) -> bool { self.visibility.is_syscall() } +} + +// ------------------------------------------------------------------------------------------------ +/// State mutators +impl ProcedureContext { + pub fn insert_callee(&mut self, callee: RpoDigest) { + self.callset.insert(callee); + } + + pub fn extend_callset(&mut self, callees: I) + where + I: IntoIterator, + { + self.callset.extend(callees); + } + + /// Registers a call to an externally-defined procedure which we have previously compiled. + /// + /// The call set of the callee is added to the call set of the procedure we are currently + /// compiling, to reflect that all of the code reachable from the callee is by extension + /// reachable by the caller. 
+ pub fn register_external_call( + &mut self, + callee: &Procedure, + inlined: bool, + mast_forest: &MastForest, + ) -> Result<(), AssemblyError> { + // If we call the callee, its callset is by extension part of our callset + self.extend_callset(callee.callset().iter().cloned()); + + // If the callee is not being inlined, add it to our callset + if !inlined { + self.insert_callee(callee.mast_root(mast_forest)); + } + + Ok(()) + } pub fn into_procedure(self, body_node_id: MastNodeId) -> Box { let procedure = diff --git a/assembly/src/assembler/instruction/env_ops.rs b/assembly/src/assembler/instruction/env_ops.rs index 900b5a39af..c9d455957e 100644 --- a/assembly/src/assembler/instruction/env_ops.rs +++ b/assembly/src/assembler/instruction/env_ops.rs @@ -1,5 +1,5 @@ -use super::{mem_ops::local_to_absolute_addr, push_felt, AssemblyContext, BasicBlockBuilder}; -use crate::{AssemblyError, Felt, Spanned}; +use super::{mem_ops::local_to_absolute_addr, push_felt, BasicBlockBuilder}; +use crate::{assembler::context::ProcedureContext, AssemblyError, Felt, Spanned}; use vm_core::Operation::*; // CONSTANT INPUTS @@ -41,9 +41,9 @@ where pub fn locaddr( span: &mut BasicBlockBuilder, index: u16, - context: &AssemblyContext, + proc_ctx: &ProcedureContext, ) -> Result<(), AssemblyError> { - local_to_absolute_addr(span, index, context.unwrap_current_procedure().num_locals()) + local_to_absolute_addr(span, index, proc_ctx.num_locals()) } /// Appends CALLER operation to the span which puts the hash of the function which initiated the @@ -53,13 +53,12 @@ pub fn locaddr( /// Returns an error if the instruction is being executed outside of kernel context. 
pub fn caller( span: &mut BasicBlockBuilder, - context: &AssemblyContext, + proc_ctx: &ProcedureContext, ) -> Result<(), AssemblyError> { - let current_procedure = context.unwrap_current_procedure(); - if !current_procedure.is_kernel() { + if !proc_ctx.is_kernel() { return Err(AssemblyError::CallerOutsideOfKernel { - span: current_procedure.span(), - source_file: current_procedure.source_file(), + span: proc_ctx.span(), + source_file: proc_ctx.source_file(), }); } span.push_op(Caller); diff --git a/assembly/src/assembler/instruction/field_ops.rs b/assembly/src/assembler/instruction/field_ops.rs index 5833f735d8..963280e283 100644 --- a/assembly/src/assembler/instruction/field_ops.rs +++ b/assembly/src/assembler/instruction/field_ops.rs @@ -1,5 +1,6 @@ -use super::{validate_param, AssemblyContext, BasicBlockBuilder}; +use super::{validate_param, BasicBlockBuilder}; use crate::{ + assembler::context::ProcedureContext, diagnostics::{RelatedError, Report}, AssemblyError, Felt, Span, MAX_EXP_BITS, ONE, ZERO, }; @@ -88,11 +89,11 @@ pub fn mul_imm(span_builder: &mut BasicBlockBuilder, imm: Felt) { /// Returns an error if the immediate value is ZERO. 
pub fn div_imm( span_builder: &mut BasicBlockBuilder, - ctx: &mut AssemblyContext, + proc_ctx: &mut ProcedureContext, imm: Span, ) -> Result<(), AssemblyError> { if imm == ZERO { - let source_file = ctx.unwrap_current_procedure().source_file(); + let source_file = proc_ctx.source_file(); let error = Report::new(crate::parser::ParsingError::DivisionByZero { span: imm.span() }); return Err(if let Some(source_file) = source_file { AssemblyError::Other(RelatedError::new(error.with_source_code(source_file))) diff --git a/assembly/src/assembler/instruction/mem_ops.rs b/assembly/src/assembler/instruction/mem_ops.rs index 635ac01943..8b54188014 100644 --- a/assembly/src/assembler/instruction/mem_ops.rs +++ b/assembly/src/assembler/instruction/mem_ops.rs @@ -1,5 +1,5 @@ -use super::{push_felt, push_u32_value, validate_param, AssemblyContext, BasicBlockBuilder}; -use crate::{diagnostics::Report, AssemblyError}; +use super::{push_felt, push_u32_value, validate_param, BasicBlockBuilder}; +use crate::{assembler::context::ProcedureContext, diagnostics::Report, AssemblyError}; use alloc::string::ToString; use vm_core::{Felt, Operation::*}; @@ -22,7 +22,7 @@ use vm_core::{Felt, Operation::*}; /// the number of procedure locals. pub fn mem_read( span: &mut BasicBlockBuilder, - context: &AssemblyContext, + proc_ctx: &ProcedureContext, addr: Option, is_local: bool, is_single: bool, @@ -30,7 +30,7 @@ pub fn mem_read( // if the address was provided as an immediate value, put it onto the stack if let Some(addr) = addr { if is_local { - let num_locals = context.unwrap_current_procedure().num_locals(); + let num_locals = proc_ctx.num_locals(); local_to_absolute_addr(span, addr as u16, num_locals)?; } else { push_u32_value(span, addr); @@ -73,14 +73,13 @@ pub fn mem_read( /// the number of procedure locals. 
pub fn mem_write_imm( span: &mut BasicBlockBuilder, - context: &AssemblyContext, + proc_ctx: &ProcedureContext, addr: u32, is_local: bool, is_single: bool, ) -> Result<(), AssemblyError> { if is_local { - let num_locals = context.unwrap_current_procedure().num_locals(); - local_to_absolute_addr(span, addr as u16, num_locals)?; + local_to_absolute_addr(span, addr as u16, proc_ctx.num_locals())?; } else { push_u32_value(span, addr); } diff --git a/assembly/src/assembler/instruction/mod.rs b/assembly/src/assembler/instruction/mod.rs index 973be8241e..071509cc82 100644 --- a/assembly/src/assembler/instruction/mod.rs +++ b/assembly/src/assembler/instruction/mod.rs @@ -1,5 +1,5 @@ use super::{ - ast::InvokeKind, mast_forest_builder::MastForestBuilder, Assembler, AssemblyContext, + ast::InvokeKind, context::ProcedureContext, mast_forest_builder::MastForestBuilder, Assembler, BasicBlockBuilder, Felt, Instruction, Operation, ONE, ZERO, }; use crate::{diagnostics::Report, utils::bound_into_included_u64, AssemblyError}; @@ -23,18 +23,22 @@ impl Assembler { &self, instruction: &Instruction, span_builder: &mut BasicBlockBuilder, - ctx: &mut AssemblyContext, + proc_ctx: &mut ProcedureContext, mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { // if the assembler is in debug mode, start tracking the instruction about to be executed; // this will allow us to map the instruction to the sequence of operations which were // executed as a part of this instruction. 
if self.in_debug_mode() { - span_builder.track_instruction(instruction, ctx); + span_builder.track_instruction(instruction, proc_ctx); } - let result = - self.compile_instruction_impl(instruction, span_builder, ctx, mast_forest_builder)?; + let result = self.compile_instruction_impl( + instruction, + span_builder, + proc_ctx, + mast_forest_builder, + )?; // compute and update the cycle count of the instruction which just finished executing if self.in_debug_mode() { @@ -48,7 +52,7 @@ impl Assembler { &self, instruction: &Instruction, span_builder: &mut BasicBlockBuilder, - ctx: &mut AssemblyContext, + proc_ctx: &mut ProcedureContext, mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { use Operation::*; @@ -80,7 +84,7 @@ impl Assembler { Instruction::MulImm(imm) => field_ops::mul_imm(span_builder, imm.expect_value()), Instruction::Div => span_builder.push_ops([Inv, Mul]), Instruction::DivImm(imm) => { - field_ops::div_imm(span_builder, ctx, imm.expect_spanned_value())?; + field_ops::div_imm(span_builder, proc_ctx, imm.expect_spanned_value())?; } Instruction::Neg => span_builder.push_op(Neg), Instruction::Inv => span_builder.push_op(Inv), @@ -166,17 +170,17 @@ impl Assembler { Instruction::U32OverflowingMadd => span_builder.push_op(U32madd), Instruction::U32WrappingMadd => span_builder.push_ops([U32madd, Drop]), - Instruction::U32Div => u32_ops::u32div(span_builder, ctx, None)?, + Instruction::U32Div => u32_ops::u32div(span_builder, proc_ctx, None)?, Instruction::U32DivImm(v) => { - u32_ops::u32div(span_builder, ctx, Some(v.expect_spanned_value()))? + u32_ops::u32div(span_builder, proc_ctx, Some(v.expect_spanned_value()))? } - Instruction::U32Mod => u32_ops::u32mod(span_builder, ctx, None)?, + Instruction::U32Mod => u32_ops::u32mod(span_builder, proc_ctx, None)?, Instruction::U32ModImm(v) => { - u32_ops::u32mod(span_builder, ctx, Some(v.expect_spanned_value()))? + u32_ops::u32mod(span_builder, proc_ctx, Some(v.expect_spanned_value()))? 
} - Instruction::U32DivMod => u32_ops::u32divmod(span_builder, ctx, None)?, + Instruction::U32DivMod => u32_ops::u32divmod(span_builder, proc_ctx, None)?, Instruction::U32DivModImm(v) => { - u32_ops::u32divmod(span_builder, ctx, Some(v.expect_spanned_value()))? + u32_ops::u32divmod(span_builder, proc_ctx, Some(v.expect_spanned_value()))? } Instruction::U32And => span_builder.push_op(U32and), Instruction::U32Or => span_builder.push_ops([Dup1, Dup1, U32and, Neg, Add, Add]), @@ -307,42 +311,54 @@ impl Assembler { Instruction::PushU32List(imms) => env_ops::push_many(imms, span_builder), Instruction::PushFeltList(imms) => env_ops::push_many(imms, span_builder), Instruction::Sdepth => span_builder.push_op(SDepth), - Instruction::Caller => env_ops::caller(span_builder, ctx)?, + Instruction::Caller => env_ops::caller(span_builder, proc_ctx)?, Instruction::Clk => span_builder.push_op(Clk), Instruction::AdvPipe => span_builder.push_op(Pipe), Instruction::AdvPush(n) => adv_ops::adv_push(span_builder, n.expect_value())?, Instruction::AdvLoadW => span_builder.push_op(AdvPopW), Instruction::MemStream => span_builder.push_op(MStream), - Instruction::Locaddr(v) => env_ops::locaddr(span_builder, v.expect_value(), ctx)?, - Instruction::MemLoad => mem_ops::mem_read(span_builder, ctx, None, false, true)?, + Instruction::Locaddr(v) => env_ops::locaddr(span_builder, v.expect_value(), proc_ctx)?, + Instruction::MemLoad => mem_ops::mem_read(span_builder, proc_ctx, None, false, true)?, Instruction::MemLoadImm(v) => { - mem_ops::mem_read(span_builder, ctx, Some(v.expect_value()), false, true)? + mem_ops::mem_read(span_builder, proc_ctx, Some(v.expect_value()), false, true)? } - Instruction::MemLoadW => mem_ops::mem_read(span_builder, ctx, None, false, false)?, + Instruction::MemLoadW => mem_ops::mem_read(span_builder, proc_ctx, None, false, false)?, Instruction::MemLoadWImm(v) => { - mem_ops::mem_read(span_builder, ctx, Some(v.expect_value()), false, false)? 
- } - Instruction::LocLoad(v) => { - mem_ops::mem_read(span_builder, ctx, Some(v.expect_value() as u32), true, true)? - } - Instruction::LocLoadW(v) => { - mem_ops::mem_read(span_builder, ctx, Some(v.expect_value() as u32), true, false)? - } + mem_ops::mem_read(span_builder, proc_ctx, Some(v.expect_value()), false, false)? + } + Instruction::LocLoad(v) => mem_ops::mem_read( + span_builder, + proc_ctx, + Some(v.expect_value() as u32), + true, + true, + )?, + Instruction::LocLoadW(v) => mem_ops::mem_read( + span_builder, + proc_ctx, + Some(v.expect_value() as u32), + true, + false, + )?, Instruction::MemStore => span_builder.push_ops([MStore, Drop]), Instruction::MemStoreW => span_builder.push_ops([MStoreW]), Instruction::MemStoreImm(v) => { - mem_ops::mem_write_imm(span_builder, ctx, v.expect_value(), false, true)? + mem_ops::mem_write_imm(span_builder, proc_ctx, v.expect_value(), false, true)? } Instruction::MemStoreWImm(v) => { - mem_ops::mem_write_imm(span_builder, ctx, v.expect_value(), false, false)? + mem_ops::mem_write_imm(span_builder, proc_ctx, v.expect_value(), false, false)? } Instruction::LocStore(v) => { - mem_ops::mem_write_imm(span_builder, ctx, v.expect_value() as u32, true, true)? - } - Instruction::LocStoreW(v) => { - mem_ops::mem_write_imm(span_builder, ctx, v.expect_value() as u32, true, false)? + mem_ops::mem_write_imm(span_builder, proc_ctx, v.expect_value() as u32, true, true)? 
} + Instruction::LocStoreW(v) => mem_ops::mem_write_imm( + span_builder, + proc_ctx, + v.expect_value() as u32, + true, + false, + )?, Instruction::AdvInject(injector) => adv_ops::adv_inject(span_builder, injector), @@ -364,25 +380,25 @@ impl Assembler { // ----- exec/call instructions ------------------------------------------------------- Instruction::Exec(ref callee) => { - return self.invoke(InvokeKind::Exec, callee, ctx, mast_forest_builder) + return self.invoke(InvokeKind::Exec, callee, proc_ctx, mast_forest_builder) } Instruction::Call(ref callee) => { - return self.invoke(InvokeKind::Call, callee, ctx, mast_forest_builder) + return self.invoke(InvokeKind::Call, callee, proc_ctx, mast_forest_builder) } Instruction::SysCall(ref callee) => { - return self.invoke(InvokeKind::SysCall, callee, ctx, mast_forest_builder) + return self.invoke(InvokeKind::SysCall, callee, proc_ctx, mast_forest_builder) } Instruction::DynExec => return self.dynexec(mast_forest_builder), Instruction::DynCall => return self.dyncall(mast_forest_builder), Instruction::ProcRef(ref callee) => { - self.procref(callee, ctx, span_builder, mast_forest_builder.forest())? + self.procref(callee, proc_ctx, span_builder, mast_forest_builder.forest())? 
} // ----- debug decorators ------------------------------------------------------------- Instruction::Breakpoint => { if self.in_debug_mode() { span_builder.push_op(Noop); - span_builder.track_instruction(instruction, ctx); + span_builder.track_instruction(instruction, proc_ctx); } } diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index a86474fce7..0fc8394bf3 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -1,8 +1,8 @@ -use super::{Assembler, AssemblyContext, BasicBlockBuilder, Operation}; +use super::{Assembler, BasicBlockBuilder, Operation}; use crate::{ - assembler::mast_forest_builder::MastForestBuilder, + assembler::{context::ProcedureContext, mast_forest_builder::MastForestBuilder}, ast::{InvocationTarget, InvokeKind}, - AssemblyError, RpoDigest, SourceSpan, Span, Spanned, + AssemblyError, RpoDigest, SourceSpan, Spanned, }; use smallvec::SmallVec; @@ -14,12 +14,12 @@ impl Assembler { &self, kind: InvokeKind, callee: &InvocationTarget, - context: &mut AssemblyContext, + proc_ctx: &mut ProcedureContext, mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { let span = callee.span(); - let digest = self.resolve_target(kind, callee, context, mast_forest_builder.forest())?; - self.invoke_mast_root(kind, span, digest, context, mast_forest_builder) + let digest = self.resolve_target(kind, callee, proc_ctx, mast_forest_builder.forest())?; + self.invoke_mast_root(kind, span, digest, proc_ctx, mast_forest_builder) } fn invoke_mast_root( @@ -27,15 +27,15 @@ impl Assembler { kind: InvokeKind, span: SourceSpan, mast_root: RpoDigest, - context: &mut AssemblyContext, + proc_ctx: &mut ProcedureContext, mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { // Get the procedure from the assembler let cache = &self.procedure_cache; - let current_source_file = 
context.unwrap_current_procedure().source_file(); + let current_source_file = proc_ctx.source_file(); // If the procedure is cached, register the call to ensure the callset - // is updated correctly. Otherwise, register a phantom call. + // is updated correctly. match cache.get_by_mast_root(&mast_root) { Some(proc) if matches!(kind, InvokeKind::SysCall) => { // Verify if this is a syscall, that the callee is a kernel procedure @@ -69,10 +69,10 @@ impl Assembler { }) } })?; - context.register_external_call(&proc, false, mast_forest_builder.forest())?; + proc_ctx.register_external_call(&proc, false, mast_forest_builder.forest())?; } Some(proc) => { - context.register_external_call(&proc, false, mast_forest_builder.forest())? + proc_ctx.register_external_call(&proc, false, mast_forest_builder.forest())? } None if matches!(kind, InvokeKind::SysCall) => { return Err(AssemblyError::UnknownSysCallTarget { @@ -81,7 +81,7 @@ impl Assembler { callee: mast_root, }); } - None => context.register_phantom_call(Span::new(span, mast_root))?, + None => (), } let mast_root_node_id = { @@ -156,29 +156,25 @@ impl Assembler { pub(super) fn procref( &self, callee: &InvocationTarget, - context: &mut AssemblyContext, + proc_ctx: &mut ProcedureContext, span_builder: &mut BasicBlockBuilder, mast_forest: &MastForest, ) -> Result<(), AssemblyError> { - let span = callee.span(); - let digest = self.resolve_target(InvokeKind::Exec, callee, context, mast_forest)?; - self.procref_mast_root(span, digest, context, span_builder, mast_forest) + let digest = self.resolve_target(InvokeKind::Exec, callee, proc_ctx, mast_forest)?; + self.procref_mast_root(digest, proc_ctx, span_builder, mast_forest) } fn procref_mast_root( &self, - span: SourceSpan, mast_root: RpoDigest, - context: &mut AssemblyContext, + proc_ctx: &mut ProcedureContext, span_builder: &mut BasicBlockBuilder, mast_forest: &MastForest, ) -> Result<(), AssemblyError> { // Add the root to the callset to be able to use dynamic instructions // 
with the referenced procedure later - let cache = &self.procedure_cache; - match cache.get_by_mast_root(&mast_root) { - Some(proc) => context.register_external_call(&proc, false, mast_forest)?, - None => context.register_phantom_call(Span::new(span, mast_root))?, + if let Some(proc) = self.procedure_cache.get_by_mast_root(&mast_root) { + proc_ctx.register_external_call(&proc, false, mast_forest)?; } // Create an array with `Push` operations containing root elements diff --git a/assembly/src/assembler/instruction/u32_ops.rs b/assembly/src/assembler/instruction/u32_ops.rs index 4a3223cfa3..fc938c6370 100644 --- a/assembly/src/assembler/instruction/u32_ops.rs +++ b/assembly/src/assembler/instruction/u32_ops.rs @@ -1,7 +1,8 @@ use super::{field_ops::append_pow2_op, push_u32_value, validate_param, BasicBlockBuilder}; use crate::{ + assembler::context::ProcedureContext, diagnostics::{RelatedError, Report}, - AssemblyContext, AssemblyError, Span, MAX_U32_ROTATE_VALUE, MAX_U32_SHIFT_VALUE, + AssemblyError, Span, MAX_U32_ROTATE_VALUE, MAX_U32_SHIFT_VALUE, }; use vm_core::{ AdviceInjector, Felt, @@ -116,10 +117,10 @@ pub fn u32mul(span_builder: &mut BasicBlockBuilder, op_mode: U32OpMode, imm: Opt /// - 3 cycles if b is not 1 pub fn u32div( span_builder: &mut BasicBlockBuilder, - ctx: &AssemblyContext, + proc_ctx: &ProcedureContext, imm: Option>, ) -> Result<(), AssemblyError> { - handle_division(span_builder, ctx, imm)?; + handle_division(span_builder, proc_ctx, imm)?; span_builder.push_op(Drop); Ok(()) } @@ -133,10 +134,10 @@ pub fn u32div( /// - 4 cycles if b is not 1 pub fn u32mod( span_builder: &mut BasicBlockBuilder, - ctx: &AssemblyContext, + proc_ctx: &ProcedureContext, imm: Option>, ) -> Result<(), AssemblyError> { - handle_division(span_builder, ctx, imm)?; + handle_division(span_builder, proc_ctx, imm)?; span_builder.push_ops([Swap, Drop]); Ok(()) } @@ -150,10 +151,10 @@ pub fn u32mod( /// - 2 cycles if b is not 1 pub fn u32divmod( span_builder: &mut 
BasicBlockBuilder, - ctx: &AssemblyContext, + proc_ctx: &ProcedureContext, imm: Option>, ) -> Result<(), AssemblyError> { - handle_division(span_builder, ctx, imm) + handle_division(span_builder, proc_ctx, imm) } // BITWISE OPERATIONS @@ -366,12 +367,12 @@ fn handle_arithmetic_operation( /// immediate parameters. fn handle_division( span_builder: &mut BasicBlockBuilder, - ctx: &AssemblyContext, + proc_ctx: &ProcedureContext, imm: Option>, ) -> Result<(), AssemblyError> { if let Some(imm) = imm { if imm == 0 { - let source_file = ctx.unwrap_current_procedure().source_file(); + let source_file = proc_ctx.source_file(); let error = Report::new(crate::parser::ParsingError::DivisionByZero { span: imm.span() }); return if let Some(source_file) = source_file { diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs index 2d0bacbb3e..359867255c 100644 --- a/assembly/src/assembler/mast_forest_builder.rs +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -15,10 +15,6 @@ pub struct MastForestBuilder { } impl MastForestBuilder { - pub fn new() -> Self { - Self::default() - } - pub fn build(self) -> MastForest { self.mast_forest } diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 2ef43cdb03..b23ae6d36c 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -1,9 +1,8 @@ use crate::{ ast::{ - self, AliasTarget, Export, FullyQualifiedProcedureName, Instruction, InvocationTarget, - InvokeKind, ModuleKind, ProcedureIndex, + self, FullyQualifiedProcedureName, Instruction, InvocationTarget, InvokeKind, ModuleKind, }, - diagnostics::{tracing::instrument, Report}, + diagnostics::Report, sema::SemanticAnalysisError, AssemblyError, Compile, CompileOptions, Felt, Library, LibraryNamespace, LibraryPath, RpoDigest, Spanned, ONE, ZERO, @@ -25,7 +24,6 @@ mod procedure; #[cfg(test)] mod tests; -pub use self::context::AssemblyContext; pub use self::id::{GlobalProcedureIndex, 
ModuleIndex}; pub(crate) use self::module_graph::ProcedureCache; pub use self::procedure::Procedure; @@ -34,33 +32,6 @@ use self::basic_block_builder::BasicBlockBuilder; use self::context::ProcedureContext; use self::module_graph::{CallerInfo, ModuleGraph, ResolvedTarget}; -// ARTIFACT KIND -// ================================================================================================ - -/// Represents the type of artifact produced by an [Assembler]. -#[derive(Default, Copy, Clone, PartialEq, Eq, Hash)] -#[non_exhaustive] -pub enum ArtifactKind { - /// Produce an executable program. - /// - /// This is the default artifact produced by the assembler, and is the only artifact that is - /// useful on its own. - #[default] - Executable, - /// Produce a MAST library - /// - /// The assembler will produce MAST in binary form which can be packaged and distributed. - /// These artifacts can then be loaded by the VM with an executable program that references - /// the contents of the library, without having to compile them together. - Library, - /// Produce a MAST kernel module - /// - /// The assembler will produce MAST for a kernel module, which is essentially the same as - /// [crate::Library], however additional constraints are imposed on compilation to ensure that - /// the produced kernel is valid. - Kernel, -} - // ASSEMBLER // ================================================================================================ @@ -72,9 +43,6 @@ pub enum ArtifactKind { /// Depending on your needs, there are multiple ways of using the assembler, and whether or not you /// want to provide a custom kernel. /// -/// By default, an empty kernel is provided. However, you may provide your own using -/// [Assembler::with_kernel] or [Assembler::with_kernel_from_module]. -/// ///
/// Programs compiled with an empty kernel cannot use the `syscall` instruction. ///
@@ -84,36 +52,20 @@ pub enum ArtifactKind { /// procedures, build the assembler with them first, using the various builder methods on /// [Assembler], e.g. [Assembler::with_module], [Assembler::with_library], etc. Then, call /// [Assembler::assemble] to get your compiled program. -#[derive(Clone)] +#[derive(Clone, Default)] pub struct Assembler { - mast_forest_builder: MastForestBuilder, - /// The global [ModuleGraph] for this assembler. All new [AssemblyContext]s inherit this graph - /// as a baseline. - module_graph: Box, + /// The global [ModuleGraph] for this assembler. + module_graph: ModuleGraph, /// The global procedure cache for this assembler. procedure_cache: ProcedureCache, /// Whether to treat warning diagnostics as errors warnings_as_errors: bool, /// Whether the assembler enables extra debugging information. in_debug_mode: bool, - /// Whether the assembler allows unknown invocation targets in compiled code. - allow_phantom_calls: bool, } -impl Default for Assembler { - fn default() -> Self { - Self { - mast_forest_builder: Default::default(), - module_graph: Default::default(), - procedure_cache: Default::default(), - warnings_as_errors: false, - in_debug_mode: false, - allow_phantom_calls: true, - } - } -} - -/// Builder +// ------------------------------------------------------------------------------------------------ +/// Constructors impl Assembler { /// Start building an [Assembler] pub fn new() -> Self { @@ -129,26 +81,6 @@ impl Assembler { assembler } - /// Start building an [Assembler], with a kernel given by compiling the given source module. - /// - /// # Errors - /// Returns an error if compiling kernel source results in an error. 
- pub fn with_kernel_from_module(module: impl Compile) -> Result { - let mut assembler = Self::new(); - let opts = CompileOptions::for_kernel(); - let module = module.compile_with_options(opts)?; - - let mut mast_forest_builder = MastForestBuilder::new(); - - let (kernel_index, kernel) = - assembler.assemble_kernel_module(module, &mut mast_forest_builder)?; - assembler.module_graph.set_kernel(Some(kernel_index), kernel); - - assembler.mast_forest_builder = mast_forest_builder; - - Ok(assembler) - } - /// Sets the default behavior of this assembler with regard to warning diagnostics. /// /// When true, any warning diagnostics that are emitted will be promoted to errors. @@ -163,12 +95,6 @@ impl Assembler { self } - /// Sets whether to allow phantom calls. - pub fn with_phantom_calls(mut self, yes: bool) -> Self { - self.allow_phantom_calls = yes; - self - } - /// Adds `module` to the module graph of the assembler. /// /// The given module must be a library module, or an error will be returned. @@ -277,7 +203,8 @@ impl Assembler { } } -/// Queries +// ------------------------------------------------------------------------------------------------ +/// Public Accessors impl Assembler { /// Returns true if this assembler promotes warning diagnostics as errors by default. pub fn warnings_as_errors(&self) -> bool { @@ -296,11 +223,6 @@ impl Assembler { self.module_graph.kernel() } - /// Returns true if this assembler was instantiated with phantom calls enabled. - pub fn allow_phantom_calls(&self) -> bool { - self.allow_phantom_calls - } - #[cfg(any(test, feature = "testing"))] #[doc(hidden)] pub fn module_graph(&self) -> &ModuleGraph { @@ -308,6 +230,7 @@ impl Assembler { } } +// ------------------------------------------------------------------------------------------------ /// Compilation/Assembly impl Assembler { /// Compiles the provided module into a [`Program`]. 
The resulting program can be executed on @@ -318,23 +241,12 @@ impl Assembler { /// Returns an error if parsing or compilation of the specified program fails, or if the source /// doesn't have an entrypoint. pub fn assemble(self, source: impl Compile) -> Result { - let mut context = AssemblyContext::default(); - context.set_warnings_as_errors(self.warnings_as_errors); - - self.assemble_in_context(source, &mut context) - } - - /// Like [Assembler::assemble], but also takes an [AssemblyContext] to configure the assembler. - pub fn assemble_in_context( - self, - source: impl Compile, - context: &mut AssemblyContext, - ) -> Result { let opts = CompileOptions { - warnings_as_errors: context.warnings_as_errors(), + warnings_as_errors: self.warnings_as_errors, ..CompileOptions::default() }; - self.assemble_with_options_in_context(source, opts, context) + + self.assemble_with_options(source, opts) } /// Compiles the provided module into a [Program] using the provided options. @@ -345,38 +257,10 @@ impl Assembler { /// /// Returns an error if parsing or compilation of the specified program fails, or the options /// are invalid. - pub fn assemble_with_options( - self, - source: impl Compile, - options: CompileOptions, - ) -> Result { - let mut context = AssemblyContext::default(); - context.set_warnings_as_errors(options.warnings_as_errors); - - self.assemble_with_options_in_context(source, options, &mut context) - } - - /// Like [Assembler::assemble_with_options], but additionally uses the provided - /// [AssemblyContext] to configure the assembler. - #[instrument("assemble_with_opts_in_context", skip_all)] - pub fn assemble_with_options_in_context( - self, - source: impl Compile, - options: CompileOptions, - context: &mut AssemblyContext, - ) -> Result { - self.assemble_with_options_in_context_impl(source, options, context) - } - - /// Implementation of [`Self::assemble_with_options_in_context`] which doesn't consume `self`. 
- /// - /// The main purpose of this separation is to enable some tests to access the assembler state - /// after assembly. - fn assemble_with_options_in_context_impl( + fn assemble_with_options( mut self, source: impl Compile, options: CompileOptions, - context: &mut AssemblyContext, ) -> Result { if options.kind != ModuleKind::Executable { return Err(Report::msg( @@ -384,23 +268,16 @@ impl Assembler { )); } - let mast_forest_builder = core::mem::take(&mut self.mast_forest_builder); + let mast_forest_builder = MastForestBuilder::default(); let program = source.compile_with_options(CompileOptions { // Override the module name so that we always compile the executable - // module as #exec + // module as #exe path: Some(LibraryPath::from(LibraryNamespace::Exec)), ..options })?; assert!(program.is_executable()); - // Remove any previously compiled executable module and clean up graph - let prev_program = self.module_graph.find_module_index(program.path()); - if let Some(module_index) = prev_program { - self.module_graph.remove_module(module_index); - self.procedure_cache.remove_module(module_index); - } - // Recompute graph with executable module, and start compiling let module_index = self.module_graph.add_module(program)?; self.module_graph.recompute()?; @@ -414,166 +291,7 @@ impl Assembler { }) .ok_or(SemanticAnalysisError::MissingEntrypoint)?; - self.compile_program(entrypoint, context, mast_forest_builder) - } - - /// Compile and assembles all procedures in the specified module, adding them to the procedure - /// cache. - /// - /// Returns a vector of procedure digests for all exported procedures in the module. - /// - /// The provided context is used to determine what type of module to assemble, i.e. either - /// a kernel or library module. 
- pub fn assemble_module( - &mut self, - module: impl Compile, - options: CompileOptions, - context: &mut AssemblyContext, - ) -> Result, Report> { - match context.kind() { - _ if options.kind.is_executable() => { - return Err(Report::msg( - "invalid compile options: expected configuration for library or kernel module ", - )) - } - ArtifactKind::Executable => { - return Err(Report::msg( - "invalid context: expected context configured for library or kernel modules", - )) - } - ArtifactKind::Kernel if !options.kind.is_kernel() => { - return Err(Report::msg( - "invalid context: cannot assemble a kernel from a module compiled as a library", - )) - } - ArtifactKind::Library if !options.kind.is_library() => { - return Err(Report::msg( - "invalid context: cannot assemble a library from a module compiled as a kernel", - )) - } - ArtifactKind::Kernel | ArtifactKind::Library => (), - } - - // Compile module - let module = module.compile_with_options(options)?; - - // Recompute graph with the provided module, and start assembly - let module_id = self.module_graph.add_module(module)?; - self.module_graph.recompute()?; - - let mut mast_forest_builder = core::mem::take(&mut self.mast_forest_builder); - - self.assemble_graph(context, &mut mast_forest_builder)?; - let exported_procedure_digests = - self.get_module_exports(module_id, mast_forest_builder.forest()); - - // Reassign the mast_forest to the assembler for use is a future program assembly - self.mast_forest_builder = mast_forest_builder; - - exported_procedure_digests - } - - /// Compiles the given kernel module, returning both the compiled kernel and its index in the - /// graph. 
- fn assemble_kernel_module( - &mut self, - module: Box, - mast_forest_builder: &mut MastForestBuilder, - ) -> Result<(ModuleIndex, Kernel), Report> { - if !module.is_kernel() { - return Err(Report::msg(format!("expected kernel module, got {}", module.kind()))); - } - - let mut context = AssemblyContext::for_kernel(module.path()); - context.set_warnings_as_errors(self.warnings_as_errors); - - let kernel_index = self.module_graph.add_module(module)?; - self.module_graph.recompute()?; - let kernel_module = self.module_graph[kernel_index].clone(); - let mut kernel = Vec::new(); - for (index, _syscall) in kernel_module - .procedures() - .enumerate() - .filter(|(_, p)| p.visibility().is_syscall()) - { - let gid = GlobalProcedureIndex { - module: kernel_index, - index: ProcedureIndex::new(index), - }; - let compiled = self.compile_subgraph(gid, false, &mut context, mast_forest_builder)?; - kernel.push(compiled.mast_root(mast_forest_builder.forest())); - } - - Kernel::new(&kernel) - .map(|kernel| (kernel_index, kernel)) - .map_err(|err| Report::new(AssemblyError::Kernel(err))) - } - - /// Get the set of procedure roots for all exports of the given module - /// - /// Returns an error if the provided Miden Assembly is invalid. 
- fn get_module_exports( - &mut self, - module: ModuleIndex, - mast_forest: &MastForest, - ) -> Result, Report> { - assert!(self.module_graph.contains_module(module), "invalid module index"); - - let mut exports = Vec::new(); - for (index, procedure) in self.module_graph[module].procedures().enumerate() { - // Only add exports to the code block table, locals will - // be added if they are in the call graph rooted at those - // procedures - if !procedure.visibility().is_exported() { - continue; - } - let gid = match procedure { - Export::Procedure(_) => GlobalProcedureIndex { - module, - index: ProcedureIndex::new(index), - }, - Export::Alias(ref alias) => { - match alias.target() { - AliasTarget::MastRoot(digest) => { - self.procedure_cache.contains_mast_root(digest) - .unwrap_or_else(|| { - panic!( - "compilation apparently succeeded, but did not find a \ - entry in the procedure cache for alias '{}', i.e. '{}'", - alias.name(), - digest - ); - }) - } - AliasTarget::Path(ref name)=> { - self.module_graph.find(alias.source_file(), name)? - } - } - } - }; - let proc = self.procedure_cache.get(gid).unwrap_or_else(|| match procedure { - Export::Procedure(ref proc) => { - panic!( - "compilation apparently succeeded, but did not find a \ - entry in the procedure cache for '{}'", - proc.name() - ) - } - Export::Alias(ref alias) => { - panic!( - "compilation apparently succeeded, but did not find a \ - entry in the procedure cache for alias '{}', i.e. '{}'", - alias.name(), - alias.target() - ); - } - }); - - let proc_code_node = &mast_forest[proc.body_node_id()]; - exports.push(proc_code_node.digest()); - } - - Ok(exports) + self.compile_program(entrypoint, mast_forest_builder) } /// Compile the provided [Module] into a [Program]. @@ -582,17 +300,15 @@ impl Assembler { /// /// Returns an error if the provided Miden Assembly is invalid. 
fn compile_program( - &mut self, + mut self, entrypoint: GlobalProcedureIndex, - context: &mut AssemblyContext, mut mast_forest_builder: MastForestBuilder, ) -> Result { // Raise an error if we are called with an invalid entrypoint assert!(self.module_graph[entrypoint].name().is_main()); // Compile the module graph rooted at the entrypoint - let entry_procedure = - self.compile_subgraph(entrypoint, true, context, &mut mast_forest_builder)?; + let entry_procedure = self.compile_subgraph(entrypoint, true, &mut mast_forest_builder)?; Ok(Program::with_kernel( mast_forest_builder.build(), @@ -601,21 +317,6 @@ impl Assembler { )) } - /// Compile all of the uncompiled procedures in the module graph, placing them - /// in the procedure cache once compiled. - /// - /// Returns an error if any of the provided Miden Assembly is invalid. - fn assemble_graph( - &mut self, - context: &mut AssemblyContext, - mast_forest_builder: &mut MastForestBuilder, - ) -> Result<(), Report> { - let mut worklist = self.module_graph.topological_sort().to_vec(); - assert!(!worklist.is_empty()); - self.process_graph_worklist(&mut worklist, context, None, mast_forest_builder) - .map(|_| ()) - } - /// Compile the uncompiled procedure in the module graph which are members of the subgraph /// rooted at `root`, placing them in the procedure cache once compiled. /// @@ -624,7 +325,6 @@ impl Assembler { &mut self, root: GlobalProcedureIndex, is_entrypoint: bool, - context: &mut AssemblyContext, mast_forest_builder: &mut MastForestBuilder, ) -> Result, Report> { let mut worklist = self.module_graph.topological_sort_from_root(root).map_err(|cycle| { @@ -641,10 +341,9 @@ impl Assembler { assert!(!worklist.is_empty()); let compiled = if is_entrypoint { - self.process_graph_worklist(&mut worklist, context, Some(root), mast_forest_builder)? + self.process_graph_worklist(&mut worklist, Some(root), mast_forest_builder)? 
} else { - let _ = - self.process_graph_worklist(&mut worklist, context, None, mast_forest_builder)?; + let _ = self.process_graph_worklist(&mut worklist, None, mast_forest_builder)?; self.procedure_cache.get(root) }; @@ -654,7 +353,6 @@ impl Assembler { fn process_graph_worklist( &mut self, worklist: &mut Vec, - context: &mut AssemblyContext, entrypoint: Option, mast_forest_builder: &mut MastForestBuilder, ) -> Result>, Report> { @@ -687,7 +385,10 @@ impl Assembler { .with_source_file(ast.source_file()); // Compile this procedure - let procedure = self.compile_procedure(pctx, context, mast_forest_builder)?; + let procedure = self.compile_procedure(pctx, mast_forest_builder)?; + + // register the procedure in the MAST forest + mast_forest_builder.make_root(procedure.body_node_id()); // Cache the compiled procedure, unless it's the program entrypoint if is_entry { @@ -711,14 +412,12 @@ impl Assembler { /// Compiles a single Miden Assembly procedure to its MAST representation. fn compile_procedure( &self, - procedure: ProcedureContext, - context: &mut AssemblyContext, + mut proc_ctx: ProcedureContext, mast_forest_builder: &mut MastForestBuilder, ) -> Result, Report> { // Make sure the current procedure context is available during codegen - let gid = procedure.id(); - let num_locals = procedure.num_locals(); - context.set_current_procedure(procedure); + let gid = proc_ctx.id(); + let num_locals = proc_ctx.num_locals(); let proc = self.module_graph[gid].unwrap_procedure(); let proc_body_root = if num_locals > 0 { @@ -731,21 +430,18 @@ impl Assembler { prologue: vec![Operation::Push(num_locals), Operation::FmpUpdate], epilogue: vec![Operation::Push(-num_locals), Operation::FmpUpdate], }; - self.compile_body(proc.iter(), context, Some(wrapper), mast_forest_builder)? + self.compile_body(proc.iter(), &mut proc_ctx, Some(wrapper), mast_forest_builder)? } else { - self.compile_body(proc.iter(), context, None, mast_forest_builder)? 
+ self.compile_body(proc.iter(), &mut proc_ctx, None, mast_forest_builder)? }; - mast_forest_builder.make_root(proc_body_root); - - let pctx = context.take_current_procedure().unwrap(); - Ok(pctx.into_procedure(proc_body_root)) + Ok(proc_ctx.into_procedure(proc_body_root)) } fn compile_body<'a, I>( &self, body: I, - context: &mut AssemblyContext, + proc_ctx: &mut ProcedureContext, wrapper: Option, mast_forest_builder: &mut MastForestBuilder, ) -> Result @@ -763,7 +459,7 @@ impl Assembler { if let Some(mast_node_id) = self.compile_instruction( inst, &mut basic_block_builder, - context, + proc_ctx, mast_forest_builder, )? { if let Some(basic_block_id) = basic_block_builder @@ -788,9 +484,9 @@ impl Assembler { } let then_blk = - self.compile_body(then_blk.iter(), context, None, mast_forest_builder)?; + self.compile_body(then_blk.iter(), proc_ctx, None, mast_forest_builder)?; let else_blk = - self.compile_body(else_blk.iter(), context, None, mast_forest_builder)?; + self.compile_body(else_blk.iter(), proc_ctx, None, mast_forest_builder)?; let split_node_id = mast_forest_builder .ensure_split(then_blk, else_blk) @@ -807,7 +503,7 @@ impl Assembler { } let repeat_node_id = - self.compile_body(body.iter(), context, None, mast_forest_builder)?; + self.compile_body(body.iter(), proc_ctx, None, mast_forest_builder)?; for _ in 0..*count { mast_node_ids.push(repeat_node_id); @@ -823,7 +519,7 @@ impl Assembler { } let loop_body_node_id = - self.compile_body(body.iter(), context, None, mast_forest_builder)?; + self.compile_body(body.iter(), proc_ctx, None, mast_forest_builder)?; let loop_node_id = mast_forest_builder .ensure_loop(loop_body_node_id) @@ -853,14 +549,13 @@ impl Assembler { &self, kind: InvokeKind, target: &InvocationTarget, - context: &AssemblyContext, + proc_ctx: &ProcedureContext, mast_forest: &MastForest, ) -> Result { - let current_proc = context.unwrap_current_procedure(); let caller = CallerInfo { span: target.span(), - source_file: current_proc.source_file(), 
- module: current_proc.id().module, + source_file: proc_ctx.source_file(), + module: proc_ctx.id().module, kind, }; let resolved = self.module_graph.resolve_target(&caller, target)?; @@ -875,6 +570,9 @@ impl Assembler { } } +// HELPERS +// ================================================================================================ + /// Contains a set of operations which need to be executed before and after a sequence of AST /// nodes (i.e., code body). struct BodyWrapper { diff --git a/assembly/src/assembler/module_graph/callgraph.rs b/assembly/src/assembler/module_graph/callgraph.rs index 6209890c2f..0b688b2f77 100644 --- a/assembly/src/assembler/module_graph/callgraph.rs +++ b/assembly/src/assembler/module_graph/callgraph.rs @@ -3,7 +3,7 @@ use alloc::{ vec::Vec, }; -use crate::assembler::{GlobalProcedureIndex, ModuleIndex}; +use crate::assembler::GlobalProcedureIndex; /// Represents the inability to construct a topological ordering of the nodes in a [CallGraph] /// due to a cycle in the graph, which can happen due to recursion. @@ -80,19 +80,6 @@ impl CallGraph { callees.push(callee); } - /// Removes all edges to/from a procedure in `module` - /// - /// NOTE: If a procedure that is removed has predecessors (callers) in the graph, this will - /// remove those edges, and the graph will be incomplete and not reflect the "true" call graph. - /// In practice, we are recomputing the graph after making such modifications, so this a - /// temporary state of affairs - still, it is important to be aware of this behavior. 
- pub fn remove_edges_for_module(&mut self, module: ModuleIndex) { - for (_, out_edges) in self.nodes.iter_mut() { - out_edges.retain(|gid| gid.module != module); - } - self.nodes.retain(|gid, _| gid.module != module); - } - /// Removes the edge between `caller` and `callee` from the graph pub fn remove_edge(&mut self, caller: GlobalProcedureIndex, callee: GlobalProcedureIndex) { if let Some(out_edges) = self.nodes.get_mut(&caller) { diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 9ebb4c2157..12dafdd4df 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -2,7 +2,6 @@ mod analysis; mod callgraph; mod debug; mod name_resolver; -mod phantom; mod procedure_cache; mod rewrites; @@ -10,29 +9,19 @@ pub use self::callgraph::{CallGraph, CycleError}; pub use self::name_resolver::{CallerInfo, ResolvedTarget}; pub use self::procedure_cache::ProcedureCache; -use alloc::{ - borrow::Cow, - boxed::Box, - collections::{BTreeMap, BTreeSet}, - sync::Arc, - vec::Vec, -}; +use alloc::{boxed::Box, collections::BTreeMap, sync::Arc, vec::Vec}; use core::ops::Index; use vm_core::Kernel; use smallvec::{smallvec, SmallVec}; -use self::{ - analysis::MaybeRewriteCheck, name_resolver::NameResolver, phantom::PhantomCall, - rewrites::ModuleRewriter, -}; +use self::{analysis::MaybeRewriteCheck, name_resolver::NameResolver, rewrites::ModuleRewriter}; use super::{GlobalProcedureIndex, ModuleIndex}; use crate::{ ast::{ Export, FullyQualifiedProcedureName, InvocationTarget, Module, Procedure, ProcedureIndex, - ProcedureName, ResolvedProcedure, + ProcedureName, }, - diagnostics::{RelatedLabel, SourceFile}, AssemblyError, LibraryPath, RpoDigest, Spanned, }; @@ -62,18 +51,12 @@ pub struct ModuleGraph { roots: BTreeMap>, /// The set of procedures in this graph which have known MAST roots digests: BTreeMap, - /// The set of procedures which have no known definition in the graph, aka "phantom 
calls". - /// Since we know the hash of these functions, we can proceed with compilation, but in some - /// contexts we wish to disallow them and raise an error if any such calls are present. - /// - /// When we merge graphs, we attempt to resolve phantoms by attempting to find definitions in - /// the opposite graph. - phantoms: BTreeSet, kernel_index: Option, kernel: Kernel, } -/// Construction +// ------------------------------------------------------------------------------------------------ +/// Constructors impl ModuleGraph { /// Add `module` to the graph. /// @@ -106,57 +89,6 @@ impl ModuleGraph { Ok(module_id) } - /// Remove a module from the graph by discarding any edges involving that module. We do not - /// remove the module from the node set by default, so as to preserve the stability of indices - /// in the graph. However, we do remove the module from the set if it is the most recently - /// added module, as that matches the most common case of compiling multiple programs in a row, - /// where we discard the executable module each time. - pub fn remove_module(&mut self, index: ModuleIndex) { - use alloc::collections::btree_map::Entry; - - // If the given index is a pending module, we just remove it from the pending set and call - // it a day - let pending_offset = self.modules.len(); - if index.as_usize() >= pending_offset { - self.pending.remove(index.as_usize() - pending_offset); - return; - } - - self.callgraph.remove_edges_for_module(index); - - // We remove all nodes from the topological sort that belong to the given module. 
The - // resulting sort is still valid, but may change the next time it is computed - self.topo.retain(|gid| gid.module != index); - - // Remove any cached procedure roots for the given module - for (gid, digest) in self.digests.iter() { - if gid.module != index { - continue; - } - if let Entry::Occupied(mut entry) = self.roots.entry(*digest) { - if entry.get().iter().all(|gid| gid.module == index) { - entry.remove(); - } else { - entry.get_mut().retain(|gid| gid.module != index); - } - } - } - self.digests.retain(|gid, _| gid.module != index); - self.roots.retain(|_, gids| !gids.is_empty()); - - // Handle removing the kernel module - if self.kernel_index == Some(index) { - self.kernel_index = None; - self.kernel = Default::default(); - } - - // If the module being removed comes last in the node set, remove it from the set to avoid - // growing the set unnecessarily over time. - if index.as_usize() == self.modules.len().saturating_sub(1) { - self.modules.pop(); - } - } - fn is_pending(&self, path: &LibraryPath) -> bool { self.pending.iter().any(|m| m.path() == path) } @@ -167,6 +99,7 @@ impl ModuleGraph { } } +// ------------------------------------------------------------------------------------------------ /// Kernels impl ModuleGraph { pub(super) fn set_kernel(&mut self, kernel_index: Option, kernel: Kernel) { @@ -208,6 +141,7 @@ impl ModuleGraph { } } +// ------------------------------------------------------------------------------------------------ /// Analysis impl ModuleGraph { /// Recompute the module graph. 
@@ -293,7 +227,6 @@ impl ModuleGraph { for module in pending.iter() { resolver.push_pending(module); } - let mut phantoms = BTreeSet::default(); let mut edges = Vec::new(); let mut finished = Vec::>::new(); @@ -304,9 +237,6 @@ impl ModuleGraph { let mut rewriter = ModuleRewriter::new(&resolver); rewriter.apply(module_id, &mut module)?; - // Gather the phantom calls found while rewriting the module - phantoms.extend(rewriter.phantoms()); - for (index, procedure) in module.procedures().enumerate() { let procedure_id = ProcedureIndex::new(index); let gid = GlobalProcedureIndex { @@ -336,7 +266,6 @@ impl ModuleGraph { drop(resolver); // Extend the graph with all of the new additions - self.phantoms.extend(phantoms); self.modules.append(&mut finished); edges .into_iter() @@ -387,8 +316,6 @@ impl ModuleGraph { let mut rewriter = ModuleRewriter::new(&resolver); rewriter.apply(module_id, &mut module)?; - self.phantoms.extend(rewriter.phantoms()); - Ok(Some(Arc::from(module))) } else { Ok(None) @@ -396,18 +323,9 @@ impl ModuleGraph { } } +// ------------------------------------------------------------------------------------------------ /// Accessors/Queries impl ModuleGraph { - /// Get a slice representing the topological ordering of this graph. - /// - /// The slice is ordered such that when a node is encountered, all of its dependencies come - /// after it in the slice. Thus, by walking the slice in reverse, we visit the leaves of the - /// graph before any of the dependents of those leaves. We use this property to resolve MAST - /// roots for the entire program, bottom-up. 
- pub fn topological_sort(&self) -> &[GlobalProcedureIndex] { - self.topo.as_slice() - } - /// Compute the topological sort of the callgraph rooted at `caller` pub fn topological_sort_from_root( &self, @@ -422,11 +340,6 @@ impl ModuleGraph { self.modules.get(id.as_usize()).cloned() } - /// Fetch a [Module] by [ModuleIndex] - pub fn contains_module(&self, id: ModuleIndex) -> bool { - self.modules.get(id.as_usize()).is_some() - } - /// Fetch a [Export] by [GlobalProcedureIndex] #[allow(unused)] pub fn get_procedure(&self, id: GlobalProcedureIndex) -> Option<&Export> { @@ -540,65 +453,6 @@ impl ModuleGraph { Ok(()) } - /// Resolves a [FullyQualifiedProcedureName] to its defining [Procedure]. - pub fn find( - &self, - source_file: Option>, - name: &FullyQualifiedProcedureName, - ) -> Result { - let mut next = Cow::Borrowed(name); - let mut caller = source_file.clone(); - loop { - let module_index = self.find_module_index(&next.module).ok_or_else(|| { - AssemblyError::UndefinedModule { - span: next.span(), - source_file: caller.clone(), - path: name.module.clone(), - } - })?; - let module = &self.modules[module_index.as_usize()]; - match module.resolve(&next.name) { - Some(ResolvedProcedure::Local(index)) => { - let id = GlobalProcedureIndex { - module: module_index, - index: index.into_inner(), - }; - break Ok(id); - } - Some(ResolvedProcedure::External(fqn)) => { - // If we see that we're about to enter an infinite resolver loop because of a - // recursive alias, return an error - if name == &fqn { - break Err(AssemblyError::RecursiveAlias { - source_file: caller.clone(), - name: name.clone(), - }); - } - next = Cow::Owned(fqn); - caller = module.source_file(); - } - Some(ResolvedProcedure::MastRoot(ref digest)) => { - if let Some(id) = self.get_procedure_index_by_digest(digest) { - break Ok(id); - } - break Err(AssemblyError::Failed { - labels: vec![RelatedLabel::error("undefined procedure") - .with_source_file(source_file) - .with_labeled_span(next.span(), "unable 
to resolve this reference")], - }); - } - None => { - // No such procedure known to `module` - break Err(AssemblyError::Failed { - labels: vec![RelatedLabel::error("undefined procedure") - .with_source_file(source_file) - .with_labeled_span(next.span(), "unable to resolve this reference")], - }); - } - } - } - } - /// Resolve a [LibraryPath] to a [ModuleIndex] in this graph pub fn find_module_index(&self, name: &LibraryPath) -> Option { self.modules.iter().position(|m| m.path() == name).map(ModuleIndex::new) diff --git a/assembly/src/assembler/module_graph/phantom.rs b/assembly/src/assembler/module_graph/phantom.rs deleted file mode 100644 index bf4956426f..0000000000 --- a/assembly/src/assembler/module_graph/phantom.rs +++ /dev/null @@ -1,45 +0,0 @@ -use alloc::sync::Arc; - -use crate::{diagnostics::SourceFile, RpoDigest, SourceSpan, Spanned}; - -/// Represents a call to a procedure for which we do not have an implementation. -/// -/// Such calls are still valid, as at runtime they may be supplied to the VM, but we are limited -/// in how much we can reason about such procedures, so we represent them and handle them -/// explicitly. 
-#[derive(Clone)] -pub struct PhantomCall { - /// The source span associated with the call - pub span: SourceSpan, - /// The source file corresponding to `span`, if available - #[allow(dead_code)] - pub source_file: Option>, - /// The MAST root of the callee - pub callee: RpoDigest, -} - -impl Spanned for PhantomCall { - fn span(&self) -> SourceSpan { - self.span - } -} - -impl Eq for PhantomCall {} - -impl PartialEq for PhantomCall { - fn eq(&self, other: &Self) -> bool { - self.callee.eq(&other.callee) - } -} - -impl Ord for PhantomCall { - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - self.callee.cmp(&other.callee) - } -} - -impl PartialOrd for PhantomCall { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} diff --git a/assembly/src/assembler/module_graph/procedure_cache.rs b/assembly/src/assembler/module_graph/procedure_cache.rs index 02c7a6d351..71171ec3e0 100644 --- a/assembly/src/assembler/module_graph/procedure_cache.rs +++ b/assembly/src/assembler/module_graph/procedure_cache.rs @@ -137,11 +137,6 @@ impl ProcedureCache { .unwrap_or(false) } - /// Returns the [GlobalProcedureIndex] of the procedure with the given MAST root, if cached. - pub fn contains_mast_root(&self, hash: &RpoDigest) -> Option { - self.by_mast_root.get(hash).copied() - } - /// Inserts the given [Procedure] into this cache, using the [GlobalProcedureIndex] as the /// cache key. 
/// @@ -214,18 +209,6 @@ impl ProcedureCache { Ok(()) } - /// This removes any entries in the cache for procedures in `module` - pub fn remove_module(&mut self, module: ModuleIndex) { - let index = module.as_usize(); - if let Some(slots) = self.cache.get_mut(index) { - slots.clear(); - } - if let Some(path) = self.modules.get_mut(index) { - *path = None; - } - self.by_mast_root.retain(|_digest, gid| gid.module != module); - } - fn ensure_cache_slot_exists(&mut self, id: GlobalProcedureIndex, module: &LibraryPath) { let min_cache_len = id.module.as_usize() + 1; let min_module_len = id.index.as_usize() + 1; diff --git a/assembly/src/assembler/module_graph/rewrites/module.rs b/assembly/src/assembler/module_graph/rewrites/module.rs index 51145e593b..73218ddf8b 100644 --- a/assembly/src/assembler/module_graph/rewrites/module.rs +++ b/assembly/src/assembler/module_graph/rewrites/module.rs @@ -3,7 +3,7 @@ use core::ops::ControlFlow; use crate::{ assembler::{ - module_graph::{CallerInfo, NameResolver, PhantomCall}, + module_graph::{CallerInfo, NameResolver}, ModuleIndex, ResolvedTarget, }, ast::{ @@ -24,7 +24,6 @@ pub struct ModuleRewriter<'a, 'b: 'a> { resolver: &'a NameResolver<'b>, module_id: ModuleIndex, invoked: BTreeSet, - phantoms: BTreeSet, source_file: Option>, } @@ -35,7 +34,6 @@ impl<'a, 'b: 'a> ModuleRewriter<'a, 'b> { resolver, module_id: ModuleIndex::new(u16::MAX as usize), invoked: Default::default(), - phantoms: Default::default(), source_file: None, } } @@ -56,11 +54,6 @@ impl<'a, 'b: 'a> ModuleRewriter<'a, 'b> { Ok(()) } - /// Take the set of accumulated phantom calls out of this rewriter - pub fn phantoms(&mut self) -> BTreeSet { - core::mem::take(&mut self.phantoms) - } - fn rewrite_target( &mut self, kind: InvokeKind, @@ -81,14 +74,7 @@ impl<'a, 'b: 'a> ModuleRewriter<'a, 'b> { target: target.clone(), }); } - Ok(ResolvedTarget::Phantom(callee)) => { - let call = PhantomCall { - span: target.span(), - source_file: self.source_file.clone(), - callee, - 
}; - self.phantoms.insert(call); - } + Ok(ResolvedTarget::Phantom(_)) => (), Ok(ResolvedTarget::Exact { .. }) => { self.invoked.insert(Invoke { kind, diff --git a/assembly/src/assembler/procedure.rs b/assembly/src/assembler/procedure.rs index 15675f5b35..d29da8d706 100644 --- a/assembly/src/assembler/procedure.rs +++ b/assembly/src/assembler/procedure.rs @@ -72,6 +72,7 @@ impl Procedure { /// Metadata impl Procedure { /// Returns a reference to the name of this procedure + #[allow(unused)] pub fn name(&self) -> &ProcedureName { &self.path.name } @@ -93,6 +94,7 @@ impl Procedure { /// Returns a reference to the Miden Assembly source file from which this /// procedure was compiled, if available. + #[allow(unused)] pub fn source_file(&self) -> Option> { self.source_file.clone() } diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index c557276b2c..523d9072c5 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -16,6 +16,9 @@ use crate::{ // TESTS // ================================================================================================ +// TODO: Fix test after we implement the new `Assembler::add_library()` +#[ignore] +#[allow(unused)] #[test] fn nested_blocks() { const MODULE: &str = "foo::bar"; @@ -67,13 +70,10 @@ fn nested_blocks() { } } - let assembler = Assembler::with_kernel_from_module(KERNEL) - .unwrap() - .with_library(&DummyLibrary::default()) - .unwrap(); + let assembler = Assembler::new().with_library(&DummyLibrary::default()).unwrap(); // The expected `MastForest` for the program (that we will build by hand) - let mut expected_mast_forest_builder = MastForestBuilder::new(); + let mut expected_mast_forest_builder = MastForestBuilder::default(); // fetch the kernel digest and store into a syscall block // diff --git a/assembly/src/errors.rs b/assembly/src/errors.rs index 648fd851ef..df31c9e1c1 100644 --- a/assembly/src/errors.rs +++ b/assembly/src/errors.rs @@ -78,15 +78,6 @@ pub enum 
AssemblyError { #[source_code] source_file: Option>, }, - #[error("cannot call phantom procedure: phantom calls are disabled")] - #[diagnostic(help("mast root is {digest}"))] - PhantomCallsNotAllowed { - #[label("the procedure referenced here is not available")] - span: SourceSpan, - #[source_code] - source_file: Option>, - digest: RpoDigest, - }, #[error("invalid syscall: '{callee}' is not an exported kernel procedure")] #[diagnostic()] InvalidSysCallTarget { diff --git a/assembly/src/lib.rs b/assembly/src/lib.rs index 1131f60695..51078c962e 100644 --- a/assembly/src/lib.rs +++ b/assembly/src/lib.rs @@ -31,7 +31,7 @@ pub mod testing; #[cfg(test)] mod tests; -pub use self::assembler::{ArtifactKind, Assembler, AssemblyContext}; +pub use self::assembler::Assembler; pub use self::compile::{Compile, Options as CompileOptions}; pub use self::errors::AssemblyError; pub use self::library::{ diff --git a/assembly/src/testing.rs b/assembly/src/testing.rs index a9b99051a8..9518f3c617 100644 --- a/assembly/src/testing.rs +++ b/assembly/src/testing.rs @@ -1,5 +1,5 @@ use crate::{ - assembler::{Assembler, AssemblyContext}, + assembler::Assembler, ast::{Form, Module, ModuleKind}, diagnostics::{ reporting::{set_hook, ReportHandlerOpts}, @@ -316,17 +316,10 @@ impl TestContext { #[track_caller] pub fn assemble_module( &mut self, - path: LibraryPath, - module: impl Compile, + _path: LibraryPath, + _module: impl Compile, ) -> Result, Report> { - let mut context = AssemblyContext::for_library(&path); - context.set_warnings_as_errors(self.assembler.warnings_as_errors()); - - let options = CompileOptions { - path: Some(path), - warnings_as_errors: self.assembler.warnings_as_errors(), - ..CompileOptions::for_library() - }; - self.assembler.assemble_module(module, options, &mut context) + // This API will change after we implement `Assembler::add_library()` + unimplemented!() } } diff --git a/assembly/src/tests.rs b/assembly/src/tests.rs index 07cdc00917..9db5847ddc 100644 --- 
a/assembly/src/tests.rs +++ b/assembly/src/tests.rs @@ -6,7 +6,7 @@ use crate::{ diagnostics::Report, regex, source_file, testing::{Pattern, TestContext}, - Assembler, AssemblyContext, Library, LibraryNamespace, LibraryPath, MaslLibrary, Version, + Assembler, Library, LibraryNamespace, LibraryPath, MaslLibrary, Version, }; type TestResult = Result<(), Report>; @@ -278,6 +278,8 @@ fn simple_main_call() -> TestResult { Ok(()) } +// TODO: Fix test after we implement the new `Assembler::add_library()` +#[ignore] #[test] fn call_without_path() -> TestResult { let mut context = TestContext::default(); @@ -1484,8 +1486,6 @@ fn program_with_invalid_rpo_digest_call() { ); } -/// Phantom calls are currently not implemented. Re-enable this test once they are implemented. -#[ignore] #[test] fn program_with_phantom_mast_call() -> TestResult { let mut context = TestContext::default(); @@ -1494,28 +1494,8 @@ fn program_with_phantom_mast_call() -> TestResult { ); let ast = context.parse_program(source)?; - // phantom calls not allowed - let assembler = Assembler::default().with_debug_mode(true); - - let mut context = AssemblyContext::for_program(ast.path()).with_phantom_calls(false); - let err = assembler - .assemble_in_context(ast.clone(), &mut context) - .expect_err("expected compilation to fail with phantom calls"); - assert_diagnostic_lines!( - err, - "cannot call phantom procedure: phantom calls are disabled", - regex!(r#",-\[test[\d]+:1:12\]"#), - "1 | begin call.0xc2545da99d3a1f3f38d957c7893c44d78998d8ea8b11aba7e22c8c2b2a213dae end", - " : ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^|^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^", - " : `-- the procedure referenced here is not available", - " `----", - " help: mast root is 0xc2545da99d3a1f3f38d957c7893c44d78998d8ea8b11aba7e22c8c2b2a213dae" - ); - - // phantom calls allowed let assembler = Assembler::default().with_debug_mode(true); - let mut context = AssemblyContext::for_program(ast.path()).with_phantom_calls(true); - 
assembler.assemble_in_context(ast, &mut context)?; + assembler.assemble(ast)?; Ok(()) } diff --git a/miden/tests/integration/flow_control/mod.rs b/miden/tests/integration/flow_control/mod.rs index a9bf914c68..9d2954f960 100644 --- a/miden/tests/integration/flow_control/mod.rs +++ b/miden/tests/integration/flow_control/mod.rs @@ -1,5 +1,6 @@ -use assembly::{ast::ModuleKind, Assembler, AssemblyContext, LibraryPath}; +use assembly::{ast::ModuleKind, Assembler, LibraryPath}; use processor::ExecutionError; +use prover::Digest; use stdlib::StdLibrary; use test_utils::{build_test, expect_exec_error, StackInputs, Test}; @@ -186,6 +187,8 @@ fn local_fn_call_with_mem_access() { test.prove_and_verify(vec![3, 7], false); } +// TODO: Fix test after we implement the new `Assembler::add_library()` +#[ignore] #[test] fn simple_syscall() { let kernel_source = " @@ -386,6 +389,9 @@ fn simple_dyncall() { // PROCREF INSTRUCTION // ================================================================================================ +// TODO: Fix test after we implement the new `Assembler::add_library()` +#[ignore] +#[allow(unused)] #[test] fn procref() { let mut assembler = Assembler::default().with_library(&StdLibrary::default()).unwrap(); @@ -401,9 +407,11 @@ fn procref() { // obtain procedures' MAST roots by compiling them as module let module_path = "test::foo".parse::().unwrap(); - let mut context = AssemblyContext::for_library(&module_path); let opts = assembly::CompileOptions::new(ModuleKind::Library, module_path).unwrap(); - let mast_roots = assembler.assemble_module(module_source, opts, &mut context).unwrap(); + + // TODO: Fix + // let mast_roots = assembler.assemble_module(module_source, opts).unwrap(); + let mast_roots: Vec = Vec::new(); let source = " use.std::math::u64 diff --git a/miden/tests/integration/operations/io_ops/env_ops.rs b/miden/tests/integration/operations/io_ops/env_ops.rs index 5582aa13ae..367a121df9 100644 --- 
a/miden/tests/integration/operations/io_ops/env_ops.rs +++ b/miden/tests/integration/operations/io_ops/env_ops.rs @@ -126,6 +126,8 @@ fn locaddr() { // CALLER INSTRUCTION // ================================================================================================ +// TODO: Fix test after we implement the new `Assembler::add_library()` +#[ignore] #[test] fn caller() { let kernel_source = " diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index ce8b28f172..48d29d0f39 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -277,8 +277,10 @@ impl Test { /// Compiles a test's source and returns the resulting Program or Assembly error. pub fn compile(&self) -> Result { use assembly::{ast::ModuleKind, CompileOptions}; + #[allow(unused)] let assembler = if let Some(kernel) = self.kernel.as_ref() { - assembly::Assembler::with_kernel_from_module(kernel).expect("invalid kernel") + // TODO: Load in kernel after we add the new `Assembler::add_library()` + assembly::Assembler::default() } else { assembly::Assembler::default() }; From 51ab7bb202eed1ed0a4e6bb25d8601e053df4289 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare <43513081+bobbinth@users.noreply.github.com> Date: Tue, 23 Jul 2024 14:21:54 -0700 Subject: [PATCH 5/7] refactor: remove procedure cache from the assembler (#1411) --- CHANGELOG.md | 1 + assembly/src/assembler/basic_block_builder.rs | 15 +- assembly/src/assembler/context.rs | 139 ------- assembly/src/assembler/id.rs | 8 +- assembly/src/assembler/instruction/env_ops.rs | 2 +- .../src/assembler/instruction/field_ops.rs | 2 +- assembly/src/assembler/instruction/mem_ops.rs | 2 +- assembly/src/assembler/instruction/mod.rs | 6 +- .../src/assembler/instruction/procedures.rs | 35 +- assembly/src/assembler/instruction/u32_ops.rs | 2 +- assembly/src/assembler/mast_forest_builder.rs | 136 +++++-- assembly/src/assembler/mod.rs | 100 ++--- .../module_graph/analysis/rewrite_check.rs | 7 - assembly/src/assembler/module_graph/mod.rs | 37 -- 
.../assembler/module_graph/name_resolver.rs | 103 ++--- .../assembler/module_graph/procedure_cache.rs | 383 ------------------ .../assembler/module_graph/rewrites/module.rs | 9 +- assembly/src/assembler/procedure.rs | 173 +++++++- 18 files changed, 375 insertions(+), 785 deletions(-) delete mode 100644 assembly/src/assembler/context.rs delete mode 100644 assembly/src/assembler/module_graph/procedure_cache.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 1354e4a75e..53d198691c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ - Introduced `MastForestError` to enforce `MastForest` node count invariant (#1394) - Added functions to `MastForestBuilder` to allow ensuring of nodes with fewer LOC (#1404) - Make `Assembler` single-use (#1409) +- Remove `ProcedureCache` from the assembler (#1411). #### Changed diff --git a/assembly/src/assembler/basic_block_builder.rs b/assembly/src/assembler/basic_block_builder.rs index 362b57e138..21b42ab0fa 100644 --- a/assembly/src/assembler/basic_block_builder.rs +++ b/assembly/src/assembler/basic_block_builder.rs @@ -1,12 +1,11 @@ +use crate::AssemblyError; + use super::{ - context::ProcedureContext, mast_forest_builder::MastForestBuilder, BodyWrapper, Decorator, - DecoratorList, Instruction, + mast_forest_builder::MastForestBuilder, BodyWrapper, Decorator, DecoratorList, Instruction, + ProcedureContext, }; use alloc::{borrow::Borrow, string::ToString, vec::Vec}; -use vm_core::{ - mast::{MastForestError, MastNodeId}, - AdviceInjector, AssemblyOp, Operation, -}; +use vm_core::{mast::MastNodeId, AdviceInjector, AssemblyOp, Operation}; // BASIC BLOCK BUILDER // ================================================================================================ @@ -129,7 +128,7 @@ impl BasicBlockBuilder { pub fn make_basic_block( &mut self, mast_forest_builder: &mut MastForestBuilder, - ) -> Result, MastForestError> { + ) -> Result, AssemblyError> { if !self.ops.is_empty() { let ops = self.ops.drain(..).collect(); let decorators 
= self.decorators.drain(..).collect(); @@ -157,7 +156,7 @@ impl BasicBlockBuilder { pub fn try_into_basic_block( mut self, mast_forest_builder: &mut MastForestBuilder, - ) -> Result, MastForestError> { + ) -> Result, AssemblyError> { self.ops.append(&mut self.epilogue); self.make_basic_block(mast_forest_builder) } diff --git a/assembly/src/assembler/context.rs b/assembly/src/assembler/context.rs deleted file mode 100644 index 9f60511653..0000000000 --- a/assembly/src/assembler/context.rs +++ /dev/null @@ -1,139 +0,0 @@ -use alloc::{boxed::Box, sync::Arc}; - -use super::{procedure::CallSet, GlobalProcedureIndex, Procedure}; -use crate::{ - ast::{FullyQualifiedProcedureName, Visibility}, - diagnostics::SourceFile, - AssemblyError, LibraryPath, RpoDigest, SourceSpan, Spanned, -}; -use vm_core::mast::{MastForest, MastNodeId}; - -// PROCEDURE CONTEXT -// ================================================================================================ - -/// Information about a procedure currently being compiled. 
-pub struct ProcedureContext { - span: SourceSpan, - source_file: Option>, - gid: GlobalProcedureIndex, - name: FullyQualifiedProcedureName, - visibility: Visibility, - num_locals: u16, - callset: CallSet, -} - -// ------------------------------------------------------------------------------------------------ -/// Constructors -impl ProcedureContext { - pub fn new( - gid: GlobalProcedureIndex, - name: FullyQualifiedProcedureName, - visibility: Visibility, - ) -> Self { - Self { - span: name.span(), - source_file: None, - gid, - name, - visibility, - num_locals: 0, - callset: Default::default(), - } - } - - pub fn with_num_locals(mut self, num_locals: u16) -> Self { - self.num_locals = num_locals; - self - } - - pub fn with_span(mut self, span: SourceSpan) -> Self { - self.span = span; - self - } - - pub fn with_source_file(mut self, source_file: Option>) -> Self { - self.source_file = source_file; - self - } -} - -// ------------------------------------------------------------------------------------------------ -/// Public accessors -impl ProcedureContext { - pub fn id(&self) -> GlobalProcedureIndex { - self.gid - } - - pub fn name(&self) -> &FullyQualifiedProcedureName { - &self.name - } - - pub fn num_locals(&self) -> u16 { - self.num_locals - } - - #[allow(unused)] - pub fn module(&self) -> &LibraryPath { - &self.name.module - } - - pub fn source_file(&self) -> Option> { - self.source_file.clone() - } - - pub fn is_kernel(&self) -> bool { - self.visibility.is_syscall() - } -} - -// ------------------------------------------------------------------------------------------------ -/// State mutators -impl ProcedureContext { - pub fn insert_callee(&mut self, callee: RpoDigest) { - self.callset.insert(callee); - } - - pub fn extend_callset(&mut self, callees: I) - where - I: IntoIterator, - { - self.callset.extend(callees); - } - - /// Registers a call to an externally-defined procedure which we have previously compiled. 
- /// - /// The call set of the callee is added to the call set of the procedure we are currently - /// compiling, to reflect that all of the code reachable from the callee is by extension - /// reachable by the caller. - pub fn register_external_call( - &mut self, - callee: &Procedure, - inlined: bool, - mast_forest: &MastForest, - ) -> Result<(), AssemblyError> { - // If we call the callee, it's callset is by extension part of our callset - self.extend_callset(callee.callset().iter().cloned()); - - // If the callee is not being inlined, add it to our callset - if !inlined { - self.insert_callee(callee.mast_root(mast_forest)); - } - - Ok(()) - } - - pub fn into_procedure(self, body_node_id: MastNodeId) -> Box { - let procedure = - Procedure::new(self.name, self.visibility, self.num_locals as u32, body_node_id) - .with_span(self.span) - .with_source_file(self.source_file) - .with_callset(self.callset); - Box::new(procedure) - } -} - -impl Spanned for ProcedureContext { - fn span(&self) -> SourceSpan { - self.span - } -} diff --git a/assembly/src/assembler/id.rs b/assembly/src/assembler/id.rs index ce01ca8c9d..b7cfd1348a 100644 --- a/assembly/src/assembler/id.rs +++ b/assembly/src/assembler/id.rs @@ -16,10 +16,10 @@ use crate::ast::ProcedureIndex; /// /// /// In addition to the [super::ModuleGraph], these indices are also used with an instance of a -/// [super::ProcedureCache]. This is because the [super::ModuleGraph] and [super::ProcedureCache] -/// instances are paired, i.e. the [super::ModuleGraph] stores the syntax trees and call graph -/// analysis for a program, while the [super::ProcedureCache] caches the compiled -/// [super::Procedure]s for the same program, as derived from the corresponding graph. +/// [super::MastForestBuilder]. This is because the [super::ModuleGraph] and +/// [super::MastForestBuilder] instances are paired, i.e. 
the [super::ModuleGraph] stores the syntax +/// trees and call graph analysis for a program, while the [super::MastForestBuilder] caches the +/// compiled [super::Procedure]s for the same program, as derived from the corresponding graph. /// /// This is intended for use when we are doing global inter-procedural analysis on a (possibly /// growable) set of modules. It is expected that the index of a module in the set, as well as the diff --git a/assembly/src/assembler/instruction/env_ops.rs b/assembly/src/assembler/instruction/env_ops.rs index c9d455957e..f6260ca671 100644 --- a/assembly/src/assembler/instruction/env_ops.rs +++ b/assembly/src/assembler/instruction/env_ops.rs @@ -1,5 +1,5 @@ use super::{mem_ops::local_to_absolute_addr, push_felt, BasicBlockBuilder}; -use crate::{assembler::context::ProcedureContext, AssemblyError, Felt, Spanned}; +use crate::{assembler::ProcedureContext, AssemblyError, Felt, Spanned}; use vm_core::Operation::*; // CONSTANT INPUTS diff --git a/assembly/src/assembler/instruction/field_ops.rs b/assembly/src/assembler/instruction/field_ops.rs index 963280e283..2e6cd2af0e 100644 --- a/assembly/src/assembler/instruction/field_ops.rs +++ b/assembly/src/assembler/instruction/field_ops.rs @@ -1,6 +1,6 @@ use super::{validate_param, BasicBlockBuilder}; use crate::{ - assembler::context::ProcedureContext, + assembler::ProcedureContext, diagnostics::{RelatedError, Report}, AssemblyError, Felt, Span, MAX_EXP_BITS, ONE, ZERO, }; diff --git a/assembly/src/assembler/instruction/mem_ops.rs b/assembly/src/assembler/instruction/mem_ops.rs index 8b54188014..b103e51fa3 100644 --- a/assembly/src/assembler/instruction/mem_ops.rs +++ b/assembly/src/assembler/instruction/mem_ops.rs @@ -1,5 +1,5 @@ use super::{push_felt, push_u32_value, validate_param, BasicBlockBuilder}; -use crate::{assembler::context::ProcedureContext, diagnostics::Report, AssemblyError}; +use crate::{assembler::ProcedureContext, diagnostics::Report, AssemblyError}; use 
alloc::string::ToString; use vm_core::{Felt, Operation::*}; diff --git a/assembly/src/assembler/instruction/mod.rs b/assembly/src/assembler/instruction/mod.rs index 071509cc82..76fb77a393 100644 --- a/assembly/src/assembler/instruction/mod.rs +++ b/assembly/src/assembler/instruction/mod.rs @@ -1,6 +1,6 @@ use super::{ - ast::InvokeKind, context::ProcedureContext, mast_forest_builder::MastForestBuilder, Assembler, - BasicBlockBuilder, Felt, Instruction, Operation, ONE, ZERO, + ast::InvokeKind, mast_forest_builder::MastForestBuilder, Assembler, BasicBlockBuilder, Felt, + Instruction, Operation, ProcedureContext, ONE, ZERO, }; use crate::{diagnostics::Report, utils::bound_into_included_u64, AssemblyError}; use core::ops::RangeBounds; @@ -391,7 +391,7 @@ impl Assembler { Instruction::DynExec => return self.dynexec(mast_forest_builder), Instruction::DynCall => return self.dyncall(mast_forest_builder), Instruction::ProcRef(ref callee) => { - self.procref(callee, proc_ctx, span_builder, mast_forest_builder.forest())? + self.procref(callee, proc_ctx, span_builder, mast_forest_builder)? 
} // ----- debug decorators ------------------------------------------------------------- diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 0fc8394bf3..bb75937b2b 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -1,12 +1,12 @@ use super::{Assembler, BasicBlockBuilder, Operation}; use crate::{ - assembler::{context::ProcedureContext, mast_forest_builder::MastForestBuilder}, + assembler::{mast_forest_builder::MastForestBuilder, ProcedureContext}, ast::{InvocationTarget, InvokeKind}, AssemblyError, RpoDigest, SourceSpan, Spanned, }; use smallvec::SmallVec; -use vm_core::mast::{MastForest, MastNodeId}; +use vm_core::mast::MastNodeId; /// Procedure Invocation impl Assembler { @@ -18,7 +18,7 @@ impl Assembler { mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { let span = callee.span(); - let digest = self.resolve_target(kind, callee, proc_ctx, mast_forest_builder.forest())?; + let digest = self.resolve_target(kind, callee, proc_ctx, mast_forest_builder)?; self.invoke_mast_root(kind, span, digest, proc_ctx, mast_forest_builder) } @@ -31,12 +31,11 @@ impl Assembler { mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { // Get the procedure from the assembler - let cache = &self.procedure_cache; let current_source_file = proc_ctx.source_file(); // If the procedure is cached, register the call to ensure the callset // is updated correctly. - match cache.get_by_mast_root(&mast_root) { + match mast_forest_builder.find_procedure(&mast_root) { Some(proc) if matches!(kind, InvokeKind::SysCall) => { // Verify if this is a syscall, that the callee is a kernel procedure // @@ -69,11 +68,9 @@ impl Assembler { }) } })?; - proc_ctx.register_external_call(&proc, false, mast_forest_builder.forest())?; - } - Some(proc) => { - proc_ctx.register_external_call(&proc, false, mast_forest_builder.forest())? 
+ proc_ctx.register_external_call(&proc, false)?; } + Some(proc) => proc_ctx.register_external_call(&proc, false)?, None if matches!(kind, InvokeKind::SysCall) => { return Err(AssemblyError::UnknownSysCallTarget { span, @@ -91,7 +88,7 @@ impl Assembler { // procedures, such that when we assemble a procedure, all // procedures that it calls will have been assembled, and // hence be present in the `MastForest`. - match mast_forest_builder.find_procedure_root(mast_root) { + match mast_forest_builder.find_procedure_node_id(mast_root) { Some(root) => root, None => { // If the MAST root called isn't known to us, make it an external @@ -101,7 +98,7 @@ impl Assembler { } } InvokeKind::Call => { - let callee_id = match mast_forest_builder.find_procedure_root(mast_root) { + let callee_id = match mast_forest_builder.find_procedure_node_id(mast_root) { Some(callee_id) => callee_id, None => { // If the MAST root called isn't known to us, make it an external @@ -113,7 +110,7 @@ impl Assembler { mast_forest_builder.ensure_call(callee_id)? 
} InvokeKind::SysCall => { - let callee_id = match mast_forest_builder.find_procedure_root(mast_root) { + let callee_id = match mast_forest_builder.find_procedure_node_id(mast_root) { Some(callee_id) => callee_id, None => { // If the MAST root called isn't known to us, make it an external @@ -158,10 +155,11 @@ impl Assembler { callee: &InvocationTarget, proc_ctx: &mut ProcedureContext, span_builder: &mut BasicBlockBuilder, - mast_forest: &MastForest, + mast_forest_builder: &MastForestBuilder, ) -> Result<(), AssemblyError> { - let digest = self.resolve_target(InvokeKind::Exec, callee, proc_ctx, mast_forest)?; - self.procref_mast_root(digest, proc_ctx, span_builder, mast_forest) + let digest = + self.resolve_target(InvokeKind::Exec, callee, proc_ctx, mast_forest_builder)?; + self.procref_mast_root(digest, proc_ctx, span_builder, mast_forest_builder) } fn procref_mast_root( @@ -169,12 +167,13 @@ impl Assembler { mast_root: RpoDigest, proc_ctx: &mut ProcedureContext, span_builder: &mut BasicBlockBuilder, - mast_forest: &MastForest, + mast_forest_builder: &MastForestBuilder, ) -> Result<(), AssemblyError> { // Add the root to the callset to be able to use dynamic instructions // with the referenced procedure later - if let Some(proc) = self.procedure_cache.get_by_mast_root(&mast_root) { - proc_ctx.register_external_call(&proc, false, mast_forest)?; + + if let Some(proc) = mast_forest_builder.find_procedure(&mast_root) { + proc_ctx.register_external_call(&proc, false)?; } // Create an array with `Push` operations containing root elements diff --git a/assembly/src/assembler/instruction/u32_ops.rs b/assembly/src/assembler/instruction/u32_ops.rs index fc938c6370..4633d91799 100644 --- a/assembly/src/assembler/instruction/u32_ops.rs +++ b/assembly/src/assembler/instruction/u32_ops.rs @@ -1,6 +1,6 @@ use super::{field_ops::append_pow2_op, push_u32_value, validate_param, BasicBlockBuilder}; use crate::{ - assembler::context::ProcedureContext, + assembler::ProcedureContext, 
diagnostics::{RelatedError, Report}, AssemblyError, Span, MAX_U32_ROTATE_VALUE, MAX_U32_SHIFT_VALUE, }; diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs index 359867255c..035eccdf7b 100644 --- a/assembly/src/assembler/mast_forest_builder.rs +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -1,17 +1,26 @@ use core::ops::Index; -use alloc::{collections::BTreeMap, vec::Vec}; +use alloc::{collections::BTreeMap, sync::Arc, vec::Vec}; use vm_core::{ crypto::hash::RpoDigest, - mast::{MastForest, MastForestError, MastNode, MastNodeId}, + mast::{MastForest, MastNode, MastNodeId}, DecoratorList, Operation, }; +use crate::AssemblyError; + +use super::{GlobalProcedureIndex, Procedure}; + +// MAST FOREST BUILDER +// ================================================================================================ + /// Builder for a [`MastForest`]. #[derive(Clone, Debug, Default)] pub struct MastForestBuilder { mast_forest: MastForest, node_id_by_hash: BTreeMap, + procedures: BTreeMap>, + proc_gid_by_hash: BTreeMap, } impl MastForestBuilder { @@ -20,28 +29,112 @@ impl MastForestBuilder { } } -/// Accessors +// ------------------------------------------------------------------------------------------------ +/// Public accessors impl MastForestBuilder { - /// Returns the underlying [`MastForest`] being built - pub fn forest(&self) -> &MastForest { - &self.mast_forest + /// Returns a reference to the procedure with the specified [`GlobalProcedureIndex`], or None + /// if such a procedure is not present in this MAST forest builder. + #[inline(always)] + pub fn get_procedure(&self, gid: GlobalProcedureIndex) -> Option> { + self.procedures.get(&gid).cloned() } - /// Returns the [`MastNodeId`] of the procedure associated with a given digest, if any. + /// Returns a reference to the procedure with the specified MAST root, or None + /// if such a procedure is not present in this MAST forest builder. 
#[inline(always)] - pub fn find_procedure_root(&self, digest: RpoDigest) -> Option { + pub fn find_procedure(&self, mast_root: &RpoDigest) -> Option> { + self.proc_gid_by_hash.get(mast_root).and_then(|gid| self.get_procedure(*gid)) + } + + /// Returns the [`MastNodeId`] of the procedure associated with a given MAST root, or None + /// if such a procedure is not present in this MAST forest builder. + #[inline(always)] + pub fn find_procedure_node_id(&self, digest: RpoDigest) -> Option { self.mast_forest.find_procedure_root(digest) } + + /// Returns the [`MastNode`] for the provided MAST node ID, or None if a node with this ID is + /// not present in this MAST forest builder. + pub fn get_mast_node(&self, id: MastNodeId) -> Option<&MastNode> { + self.mast_forest.get_node_by_id(id) + } +} + +impl MastForestBuilder { + /// Inserts a procedure into this MAST forest builder. + /// + /// If the procedure with the same ID already exists in this forest builder, this will have + /// no effect. + pub fn insert_procedure( + &mut self, + gid: GlobalProcedureIndex, + procedure: Procedure, + ) -> Result<(), AssemblyError> { + let proc_root = self.mast_forest[procedure.body_node_id()].digest(); + + // Check if an entry is already in this cache slot. + // + // If there is already a cache entry, but it conflicts with what we're trying to cache, + // then raise an error. + if let Some(cached) = self.procedures.get(&gid) { + let cached_root = self.mast_forest[cached.body_node_id()].digest(); + if cached_root != proc_root || cached.num_locals() != procedure.num_locals() { + return Err(AssemblyError::ConflictingDefinitions { + first: cached.fully_qualified_name().clone(), + second: procedure.fully_qualified_name().clone(), + }); + } + + // The global procedure index and the MAST root resolve to an already cached version of + // this procedure, nothing to do. 
+ // + // TODO: We should emit a warning for this, because while it is not an error per se, it + // does reflect that we're doing work we don't need to be doing. However, emitting a + // warning only makes sense if this is controllable by the user, and it isn't yet + // clear whether this edge case will ever happen in practice anyway. + return Ok(()); + } + + // We don't have a cache entry yet, but we do want to make sure we don't have a conflicting + // cache entry with the same MAST root: + if let Some(cached) = self.find_procedure(&proc_root) { + if cached.num_locals() != procedure.num_locals() { + return Err(AssemblyError::ConflictingDefinitions { + first: cached.fully_qualified_name().clone(), + second: procedure.fully_qualified_name().clone(), + }); + } + + // We have a previously cached version of an equivalent procedure, just under a + // different [GlobalProcedureIndex], so insert the cached procedure into the slot for + // `gid`, but skip inserting a record in the MAST root lookup table + self.make_root(procedure.body_node_id()); + self.procedures.insert(gid, Arc::new(procedure)); + return Ok(()); + } + + self.make_root(procedure.body_node_id()); + self.proc_gid_by_hash.insert(proc_root, gid); + self.procedures.insert(gid, Arc::new(procedure)); + + Ok(()) + } + + /// Marks the given [`MastNodeId`] as being the root of a procedure. + pub fn make_root(&mut self, new_root_id: MastNodeId) { + self.mast_forest.make_root(new_root_id) + } } -/// Mutators +// ------------------------------------------------------------------------------------------------ +/// Node inserters impl MastForestBuilder { /// Adds a node to the forest, and returns the [`MastNodeId`] associated with it. /// /// If a [`MastNode`] which is equal to the current node was previously added, the previously /// returned [`MastNodeId`] will be returned. This enforces this invariant that equal /// [`MastNode`]s have equal [`MastNodeId`]s.
- fn ensure_node(&mut self, node: MastNode) -> Result { + fn ensure_node(&mut self, node: MastNode) -> Result { let node_digest = node.digest(); if let Some(node_id) = self.node_id_by_hash.get(&node_digest) { @@ -60,7 +153,7 @@ impl MastForestBuilder { &mut self, operations: Vec, decorators: Option, - ) -> Result { + ) -> Result { match decorators { Some(decorators) => { self.ensure_node(MastNode::new_basic_block_with_decorators(operations, decorators)) @@ -74,7 +167,7 @@ impl MastForestBuilder { &mut self, left_child: MastNodeId, right_child: MastNodeId, - ) -> Result { + ) -> Result { self.ensure_node(MastNode::new_join(left_child, right_child, &self.mast_forest)) } @@ -83,39 +176,34 @@ impl MastForestBuilder { &mut self, if_branch: MastNodeId, else_branch: MastNodeId, - ) -> Result { + ) -> Result { self.ensure_node(MastNode::new_split(if_branch, else_branch, &self.mast_forest)) } /// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it. - pub fn ensure_loop(&mut self, body: MastNodeId) -> Result { + pub fn ensure_loop(&mut self, body: MastNodeId) -> Result { self.ensure_node(MastNode::new_loop(body, &self.mast_forest)) } /// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it. - pub fn ensure_call(&mut self, callee: MastNodeId) -> Result { + pub fn ensure_call(&mut self, callee: MastNodeId) -> Result { self.ensure_node(MastNode::new_call(callee, &self.mast_forest)) } /// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it. - pub fn ensure_syscall(&mut self, callee: MastNodeId) -> Result { + pub fn ensure_syscall(&mut self, callee: MastNodeId) -> Result { self.ensure_node(MastNode::new_syscall(callee, &self.mast_forest)) } - /// Adds a dynexec node to the forest, and returns the [`MastNodeId`] associated with it. - pub fn ensure_dyn(&mut self) -> Result { + /// Adds a dyn node to the forest, and returns the [`MastNodeId`] associated with it. 
+ pub fn ensure_dyn(&mut self) -> Result { self.ensure_node(MastNode::new_dyn()) } /// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it. - pub fn ensure_external(&mut self, mast_root: RpoDigest) -> Result { + pub fn ensure_external(&mut self, mast_root: RpoDigest) -> Result { self.ensure_node(MastNode::new_external(mast_root)) } - - /// Marks the given [`MastNodeId`] as being the root of a procedure. - pub fn make_root(&mut self, new_root_id: MastNodeId) { - self.mast_forest.make_root(new_root_id) - } } impl Index for MastForestBuilder { diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index b23ae6d36c..c49e121e68 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -7,15 +7,11 @@ use crate::{ AssemblyError, Compile, CompileOptions, Felt, Library, LibraryNamespace, LibraryPath, RpoDigest, Spanned, ONE, ZERO, }; -use alloc::{boxed::Box, sync::Arc, vec::Vec}; +use alloc::{sync::Arc, vec::Vec}; use mast_forest_builder::MastForestBuilder; -use vm_core::{ - mast::{MastForest, MastNodeId}, - Decorator, DecoratorList, Kernel, Operation, Program, -}; +use vm_core::{mast::MastNodeId, Decorator, DecoratorList, Kernel, Operation, Program}; mod basic_block_builder; -mod context; mod id; mod instruction; mod mast_forest_builder; @@ -25,11 +21,9 @@ mod procedure; mod tests; pub use self::id::{GlobalProcedureIndex, ModuleIndex}; -pub(crate) use self::module_graph::ProcedureCache; -pub use self::procedure::Procedure; +pub use self::procedure::{Procedure, ProcedureContext}; use self::basic_block_builder::BasicBlockBuilder; -use self::context::ProcedureContext; use self::module_graph::{CallerInfo, ModuleGraph, ResolvedTarget}; // ASSEMBLER @@ -56,8 +50,6 @@ use self::module_graph::{CallerInfo, ModuleGraph, ResolvedTarget}; pub struct Assembler { /// The global [ModuleGraph] for this assembler. module_graph: ModuleGraph, - /// The global procedure cache for this assembler. 
- procedure_cache: ProcedureCache, /// Whether to treat warning diagnostics as errors warnings_as_errors: bool, /// Whether the assembler enables extra debugging information. @@ -318,7 +310,7 @@ impl Assembler { } /// Compile the uncompiled procedure in the module graph which are members of the subgraph - /// rooted at `root`, placing them in the procedure cache once compiled. + /// rooted at `root`, placing them in the MAST forest builder once compiled. /// /// Returns an error if any of the provided Miden Assembly is invalid. fn compile_subgraph( @@ -344,7 +336,7 @@ impl Assembler { self.process_graph_worklist(&mut worklist, Some(root), mast_forest_builder)? } else { let _ = self.process_graph_worklist(&mut worklist, None, mast_forest_builder)?; - self.procedure_cache.get(root) + mast_forest_builder.get_procedure(root) }; Ok(compiled.expect("compilation succeeded but root not found in cache")) @@ -361,11 +353,8 @@ impl Assembler { let mut compiled_entrypoint = None; while let Some(procedure_gid) = worklist.pop() { // If we have already compiled this procedure, do not recompile - if let Some(proc) = self.procedure_cache.get(procedure_gid) { - self.module_graph.register_mast_root( - procedure_gid, - proc.mast_root(mast_forest_builder.forest()), - )?; + if let Some(proc) = mast_forest_builder.get_procedure(procedure_gid) { + self.module_graph.register_mast_root(procedure_gid, proc.mast_root())?; continue; } let is_entry = entrypoint == Some(procedure_gid); @@ -387,22 +376,14 @@ impl Assembler { // Compile this procedure let procedure = self.compile_procedure(pctx, mast_forest_builder)?; - // register the procedure in the MAST forest - mast_forest_builder.make_root(procedure.body_node_id()); - // Cache the compiled procedure, unless it's the program entrypoint if is_entry { + mast_forest_builder.make_root(procedure.body_node_id()); compiled_entrypoint = Some(Arc::from(procedure)); } else { // Make the MAST root available to all dependents - let digest = 
procedure.mast_root(mast_forest_builder.forest()); - self.module_graph.register_mast_root(procedure_gid, digest)?; - - self.procedure_cache.insert( - procedure_gid, - Arc::from(procedure), - mast_forest_builder.forest(), - )?; + self.module_graph.register_mast_root(procedure_gid, procedure.mast_root())?; + mast_forest_builder.insert_procedure(procedure_gid, procedure)?; } } @@ -414,13 +395,13 @@ impl Assembler { &self, mut proc_ctx: ProcedureContext, mast_forest_builder: &mut MastForestBuilder, - ) -> Result, Report> { + ) -> Result { // Make sure the current procedure context is available during codegen let gid = proc_ctx.id(); let num_locals = proc_ctx.num_locals(); let proc = self.module_graph[gid].unwrap_procedure(); - let proc_body_root = if num_locals > 0 { + let proc_body_id = if num_locals > 0 { // for procedures with locals, we need to update fmp register before and after the // procedure body is executed. specifically: // - to allocate procedure locals we need to increment fmp by the number of locals @@ -435,7 +416,10 @@ impl Assembler { self.compile_body(proc.iter(), &mut proc_ctx, None, mast_forest_builder)? }; - Ok(proc_ctx.into_procedure(proc_body_root)) + let proc_body_node = mast_forest_builder + .get_mast_node(proc_body_id) + .expect("no MAST node for compiled procedure"); + Ok(proc_ctx.into_procedure(proc_body_node.digest(), proc_body_id)) } fn compile_body<'a, I>( @@ -462,9 +446,8 @@ impl Assembler { proc_ctx, mast_forest_builder, )? { - if let Some(basic_block_id) = basic_block_builder - .make_basic_block(mast_forest_builder) - .map_err(AssemblyError::from)? + if let Some(basic_block_id) = + basic_block_builder.make_basic_block(mast_forest_builder)? { mast_node_ids.push(basic_block_id); } @@ -476,9 +459,8 @@ impl Assembler { Op::If { then_blk, else_blk, .. } => { - if let Some(basic_block_id) = basic_block_builder - .make_basic_block(mast_forest_builder) - .map_err(AssemblyError::from)? 
+ if let Some(basic_block_id) = + basic_block_builder.make_basic_block(mast_forest_builder)? { mast_node_ids.push(basic_block_id); } @@ -488,16 +470,13 @@ impl Assembler { let else_blk = self.compile_body(else_blk.iter(), proc_ctx, None, mast_forest_builder)?; - let split_node_id = mast_forest_builder - .ensure_split(then_blk, else_blk) - .map_err(AssemblyError::from)?; + let split_node_id = mast_forest_builder.ensure_split(then_blk, else_blk)?; mast_node_ids.push(split_node_id); } Op::Repeat { count, body, .. } => { - if let Some(basic_block_id) = basic_block_builder - .make_basic_block(mast_forest_builder) - .map_err(AssemblyError::from)? + if let Some(basic_block_id) = + basic_block_builder.make_basic_block(mast_forest_builder)? { mast_node_ids.push(basic_block_id); } @@ -511,9 +490,8 @@ impl Assembler { } Op::While { body, .. } => { - if let Some(basic_block_id) = basic_block_builder - .make_basic_block(mast_forest_builder) - .map_err(AssemblyError::from)? + if let Some(basic_block_id) = + basic_block_builder.make_basic_block(mast_forest_builder)? { mast_node_ids.push(basic_block_id); } @@ -521,25 +499,20 @@ impl Assembler { let loop_body_node_id = self.compile_body(body.iter(), proc_ctx, None, mast_forest_builder)?; - let loop_node_id = mast_forest_builder - .ensure_loop(loop_body_node_id) - .map_err(AssemblyError::from)?; + let loop_node_id = mast_forest_builder.ensure_loop(loop_body_node_id)?; mast_node_ids.push(loop_node_id); } } } - if let Some(basic_block_id) = basic_block_builder - .try_into_basic_block(mast_forest_builder) - .map_err(AssemblyError::from)? + if let Some(basic_block_id) = + basic_block_builder.try_into_basic_block(mast_forest_builder)? { mast_node_ids.push(basic_block_id); } Ok(if mast_node_ids.is_empty() { - mast_forest_builder - .ensure_block(vec![Operation::Noop], None) - .map_err(AssemblyError::from)? + mast_forest_builder.ensure_block(vec![Operation::Noop], None)? } else { combine_mast_node_ids(mast_node_ids, mast_forest_builder)? 
}) @@ -550,7 +523,7 @@ impl Assembler { kind: InvokeKind, target: &InvocationTarget, proc_ctx: &ProcedureContext, - mast_forest: &MastForest, + mast_forest_builder: &MastForestBuilder, ) -> Result { let caller = CallerInfo { span: target.span(), @@ -560,12 +533,13 @@ impl Assembler { }; let resolved = self.module_graph.resolve_target(&caller, target)?; match resolved { - ResolvedTarget::Phantom(digest) | ResolvedTarget::Cached { digest, .. } => Ok(digest), - ResolvedTarget::Exact { gid } | ResolvedTarget::Resolved { gid, .. } => Ok(self - .procedure_cache - .get(gid) - .map(|p| p.mast_root(mast_forest)) - .expect("expected callee to have been compiled already")), + ResolvedTarget::Phantom(digest) => Ok(digest), + ResolvedTarget::Exact { gid } | ResolvedTarget::Resolved { gid, .. } => { + Ok(mast_forest_builder + .get_procedure(gid) + .map(|p| p.mast_root()) + .expect("expected callee to have been compiled already")) + } } } } diff --git a/assembly/src/assembler/module_graph/analysis/rewrite_check.rs b/assembly/src/assembler/module_graph/analysis/rewrite_check.rs index 0baddf44b2..4c097133e5 100644 --- a/assembly/src/assembler/module_graph/analysis/rewrite_check.rs +++ b/assembly/src/assembler/module_graph/analysis/rewrite_check.rs @@ -68,13 +68,6 @@ impl<'a, 'b: 'a> RewriteCheckVisitor<'a, 'b> { Ok(ResolvedTarget::Exact { .. } | ResolvedTarget::Phantom(_)) => { ControlFlow::Continue(()) } - Ok(ResolvedTarget::Cached { .. 
}) => { - if let InvocationTarget::MastRoot(_) = target { - ControlFlow::Continue(()) - } else { - ControlFlow::Break(Ok(true)) - } - } } } } diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 12dafdd4df..0ce18226dc 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -2,12 +2,10 @@ mod analysis; mod callgraph; mod debug; mod name_resolver; -mod procedure_cache; mod rewrites; pub use self::callgraph::{CallGraph, CycleError}; pub use self::name_resolver::{CallerInfo, ResolvedTarget}; -pub use self::procedure_cache::ProcedureCache; use alloc::{boxed::Box, collections::BTreeMap, sync::Arc, vec::Vec}; use core::ops::Index; @@ -49,8 +47,6 @@ pub struct ModuleGraph { /// The set of MAST roots which have procedure definitions in this graph. There can be /// multiple procedures bound to the same root due to having identical code. roots: BTreeMap>, - /// The set of procedures in this graph which have known MAST roots - digests: BTreeMap, kernel_index: Option, kernel: Kernel, } @@ -346,20 +342,6 @@ impl ModuleGraph { self.modules.get(id.module.as_usize()).and_then(|m| m.get(id.index)) } - /// Fetches a [Procedure] by [RpoDigest]. - /// - /// NOTE: This implicitly chooses the first definition for a procedure if the same digest is - /// shared for multiple definitions. - #[allow(unused)] - pub fn get_procedure_by_digest(&self, digest: &RpoDigest) -> Option<&Procedure> { - self.roots - .get(digest) - .and_then(|indices| match self.get_procedure(indices[0])? { - Export::Procedure(ref proc) => Some(proc), - Export::Alias(_) => None, - }) - } - pub fn get_procedure_index_by_digest( &self, digest: &RpoDigest, @@ -367,12 +349,6 @@ impl ModuleGraph { self.roots.get(digest).map(|indices| indices[0]) } - /// Look up the [RpoDigest] associated with the given [GlobalProcedureIndex], if one is known - /// at this point in time. 
- pub fn get_mast_root(&self, id: GlobalProcedureIndex) -> Option<&RpoDigest> { - self.digests.get(&id) - } - #[allow(unused)] pub fn callees(&self, gid: GlobalProcedureIndex) -> &[GlobalProcedureIndex] { self.callgraph.out_edges(gid) @@ -437,19 +413,6 @@ impl ModuleGraph { } } - match self.digests.entry(id) { - Entry::Occupied(ref entry) => { - assert_eq!( - entry.get(), - &digest, - "attempted to register the same procedure with different digests!" - ); - } - Entry::Vacant(entry) => { - entry.insert(digest); - } - } - Ok(()) } diff --git a/assembly/src/assembler/module_graph/name_resolver.rs b/assembly/src/assembler/module_graph/name_resolver.rs index 6070a52e18..07fb614dc6 100644 --- a/assembly/src/assembler/module_graph/name_resolver.rs +++ b/assembly/src/assembler/module_graph/name_resolver.rs @@ -48,12 +48,6 @@ pub struct CallerInfo { /// Represents the output of the [NameResolver] when it resolves a procedure name. #[derive(Debug)] pub enum ResolvedTarget { - /// The callee is available in the procedure cache, so we know its exact hash. - Cached { - digest: RpoDigest, - /// If the procedure was compiled from source, this is its identifier in the [ModuleGraph] - gid: Option, - }, /// The callee was resolved to a known procedure in the [ModuleGraph] Exact { gid: GlobalProcedureIndex }, /// The callee was resolved to a concrete procedure definition, and can be referenced as @@ -73,7 +67,6 @@ impl ResolvedTarget { pub fn into_global_id(self) -> Option { match self { ResolvedTarget::Exact { gid } | ResolvedTarget::Resolved { gid, .. } => Some(gid), - ResolvedTarget::Cached { gid, .. 
} => gid, ResolvedTarget::Phantom(_) => None, } } @@ -140,51 +133,31 @@ impl<'a> NameResolver<'a> { module: self.graph.kernel_index.unwrap(), index: index.into_inner(), }; - match self.graph.get_mast_root(gid) { - Some(digest) => Ok(ResolvedTarget::Cached { - digest: *digest, - gid: Some(gid), - }), - None => Ok(ResolvedTarget::Exact { gid }), - } + Ok(ResolvedTarget::Exact { gid }) } Some(ResolvedProcedure::Local(index)) => { let gid = GlobalProcedureIndex { module: caller.module, index: index.into_inner(), }; - match self.graph.get_mast_root(gid) { - Some(digest) => Ok(ResolvedTarget::Cached { - digest: *digest, - gid: Some(gid), - }), - None => Ok(ResolvedTarget::Exact { gid }), - } + Ok(ResolvedTarget::Exact { gid }) } Some(ResolvedProcedure::External(ref fqn)) => { let gid = self.find(caller, fqn)?; - match self.graph.get_mast_root(gid) { - Some(digest) => Ok(ResolvedTarget::Cached { - digest: *digest, - gid: Some(gid), - }), - None => { - let path = self.module_path(gid.module); - let pending_offset = self.graph.modules.len(); - let name = if gid.module.as_usize() >= pending_offset { - self.pending[gid.module.as_usize() - pending_offset] - .resolver - .get_name(gid.index) - .clone() - } else { - self.graph[gid].name().clone() - }; - Ok(ResolvedTarget::Resolved { - gid, - target: InvocationTarget::AbsoluteProcedurePath { name, path }, - }) - } - } + let path = self.module_path(gid.module); + let pending_offset = self.graph.modules.len(); + let name = if gid.module.as_usize() >= pending_offset { + self.pending[gid.module.as_usize() - pending_offset] + .resolver + .get_name(gid.index) + .clone() + } else { + self.graph[gid].name().clone() + }; + Ok(ResolvedTarget::Resolved { + gid, + target: InvocationTarget::AbsoluteProcedurePath { name, path }, + }) } Some(ResolvedProcedure::MastRoot(ref digest)) => { match self.graph.get_procedure_index_by_digest(digest) { @@ -241,28 +214,20 @@ impl<'a> NameResolver<'a> { name: name.clone(), }; let gid = self.find(caller, 
&fqn)?; - match self.graph.get_mast_root(gid) { - Some(digest) => Ok(ResolvedTarget::Cached { - digest: *digest, - gid: Some(gid), - }), - None => { - let path = self.module_path(gid.module); - let pending_offset = self.graph.modules.len(); - let name = if gid.module.as_usize() >= pending_offset { - self.pending[gid.module.as_usize() - pending_offset] - .resolver - .get_name(gid.index) - .clone() - } else { - self.graph[gid].name().clone() - }; - Ok(ResolvedTarget::Resolved { - gid, - target: InvocationTarget::AbsoluteProcedurePath { name, path }, - }) - } - } + let path = self.module_path(gid.module); + let pending_offset = self.graph.modules.len(); + let name = if gid.module.as_usize() >= pending_offset { + self.pending[gid.module.as_usize() - pending_offset] + .resolver + .get_name(gid.index) + .clone() + } else { + self.graph[gid].name().clone() + }; + Ok(ResolvedTarget::Resolved { + gid, + target: InvocationTarget::AbsoluteProcedurePath { name, path }, + }) } None => Err(AssemblyError::UndefinedModule { span: target.span(), @@ -280,13 +245,7 @@ impl<'a> NameResolver<'a> { name: name.clone(), }; let gid = self.find(caller, &fqn)?; - match self.graph.get_mast_root(gid) { - Some(digest) => Ok(ResolvedTarget::Cached { - digest: *digest, - gid: Some(gid), - }), - None => Ok(ResolvedTarget::Exact { gid }), - } + Ok(ResolvedTarget::Exact { gid }) } } } diff --git a/assembly/src/assembler/module_graph/procedure_cache.rs b/assembly/src/assembler/module_graph/procedure_cache.rs deleted file mode 100644 index 71171ec3e0..0000000000 --- a/assembly/src/assembler/module_graph/procedure_cache.rs +++ /dev/null @@ -1,383 +0,0 @@ -use alloc::{ - collections::{BTreeMap, VecDeque}, - sync::Arc, - vec::Vec, -}; -use core::{fmt, ops::Index}; -use vm_core::mast::MastForest; - -use crate::{ - assembler::{GlobalProcedureIndex, ModuleIndex, Procedure}, - ast::{FullyQualifiedProcedureName, ProcedureIndex}, - AssemblyError, LibraryPath, RpoDigest, -}; - -// PROCEDURE CACHE -// 
================================================================================================ - -/// The [ProcedureCache] is responsible for caching the MAST of compiled procedures. -/// -/// Once cached, subsequent compilations will use the cached MAST artifacts, rather than -/// recompiling the same procedures again and again. -/// -/// # Usage -/// -/// The procedure cache is intimately tied to a [ModuleGraph], which effectively acts as a cache -/// for the MASM syntax tree, and associates each procedure with a unique [GlobalProcedureIndex] -/// which acts as the cache key for the corresponding [ProcedureCache]. -/// -/// This also is how we avoid serving cached artifacts when the syntax tree of a module is modified -/// and recompiled - the old module will be removed from the [ModuleGraph] and the new version will -/// be added as a new module, getting new [GlobalProcedureIndex]s for each of its procedures as a -/// result. -/// -/// As a result of this design choice, a unique [ProcedureCache] is associated with each context in -/// play during compilation: the global assembler context has its own cache, and each -/// [AssemblyContext] has its own cache. -#[derive(Default, Clone)] -pub struct ProcedureCache { - cache: Vec>>>, - /// This is always the same length as `cache` - modules: Vec>, - by_mast_root: BTreeMap, -} - -/// When indexing by [ModuleIndex], we return the [LibraryPath] of the [Module] -/// to which that cache slot belongs. 
-impl Index for ProcedureCache { - type Output = LibraryPath; - fn index(&self, id: ModuleIndex) -> &Self::Output { - self.modules[id.as_usize()].as_ref().expect("attempted to index an empty cache") - } -} - -/// When indexing by [GlobalProcedureIndex], we return the cached [Procedure] -impl Index for ProcedureCache { - type Output = Arc; - fn index(&self, id: GlobalProcedureIndex) -> &Self::Output { - self.cache[id.module.as_usize()][id.index.as_usize()] - .as_ref() - .expect("attempted to index an empty cache slot") - } -} - -impl ProcedureCache { - /// Returns true if the cache is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns the number of procedures in the cache. - pub fn len(&self) -> usize { - self.cache.iter().map(|m| m.iter().filter_map(|p| p.as_deref()).count()).sum() - } - - /// Searches for a procedure in the cache using `predicate`. - #[allow(unused)] - pub fn find(&self, mut predicate: F) -> Option> - where - F: FnMut(&Procedure) -> bool, - { - self.cache.iter().find_map(|m| { - m.iter().filter_map(|p| p.as_ref()).find_map(|p| { - if predicate(p) { - Some(p.clone()) - } else { - None - } - }) - }) - } - - /// Searches for a procedure in the cache using `predicate`, starting from procedures with the - /// highest [ModuleIndex] to lowest. - #[allow(unused)] - pub fn rfind(&self, mut predicate: F) -> Option> - where - F: FnMut(&Procedure) -> bool, - { - self.cache.iter().rev().find_map(|m| { - m.iter().filter_map(|p| p.as_ref()).find_map(|p| { - if predicate(p) { - Some(p.clone()) - } else { - None - } - }) - }) - } - - /// Looks up a procedure by its MAST root. - pub fn get_by_mast_root(&self, digest: &RpoDigest) -> Option> { - self.by_mast_root.get(digest).copied().map(|index| self[index].clone()) - } - - /// Looks up a procedure by its fully-qualified name. - /// - /// NOTE: If a procedure with the same name is cached twice, this will return the version with - /// the highest [ModuleIndex]. 
- #[allow(unused)] - pub fn get_by_name(&self, name: &FullyQualifiedProcedureName) -> Option> { - self.rfind(|p| p.fully_qualified_name() == name) - } - - /// Returns the procedure with the given [GlobalProcedureIndex], if it is cached. - pub fn get(&self, id: GlobalProcedureIndex) -> Option> { - self.cache - .get(id.module.as_usize()) - .and_then(|m| m.get(id.index.as_usize()).and_then(|p| p.clone())) - } - - /// Returns true if the procedure with the given [GlobalProcedureIndex] is cached. - #[allow(unused)] - pub fn contains_key(&self, id: GlobalProcedureIndex) -> bool { - self.cache - .get(id.module.as_usize()) - .map(|m| m.get(id.index.as_usize()).is_some()) - .unwrap_or(false) - } - - /// Inserts the given [Procedure] into this cache, using the [GlobalProcedureIndex] as the - /// cache key. - /// - /// # Errors - /// - /// This operation will fail under the following conditions: - /// - The cache slot for the given [GlobalProcedureIndex] is occupied with a conflicting - /// definition. - /// - A procedure with the same MAST root is already in the cache, but the two procedures have - /// differing metadata (such as the number of locals, etc). - pub fn insert( - &mut self, - id: GlobalProcedureIndex, - procedure: Arc, - mast_forest: &MastForest, - ) -> Result<(), AssemblyError> { - let mast_root = procedure.mast_root(mast_forest); - - // Make sure we can index to the cache slot for this procedure - self.ensure_cache_slot_exists(id, procedure.path()); - - // Check if an entry is already in this cache slot. - // - // If there is already a cache entry, but it conflicts with what we're trying to cache, - // then raise an error. 
- if let Some(cached) = self.get(id) { - if cached.mast_root(mast_forest) != mast_root - || cached.num_locals() != procedure.num_locals() - { - return Err(AssemblyError::ConflictingDefinitions { - first: cached.fully_qualified_name().clone(), - second: procedure.fully_qualified_name().clone(), - }); - } - - // The global procedure index and the MAST root resolve to an already cached version of - // this procedure, nothing to do. - // - // TODO: We should emit a warning for this, because while it is not an error per se, it - // does reflect that we're doing work we don't need to be doing. However, emitting a - // warning only makes sense if this is controllable by the user, and it isn't yet - // clear whether this edge case will ever happen in practice anyway. - return Ok(()); - } - - // We don't have a cache entry yet, but we do want to make sure we don't have a conflicting - // cache entry with the same MAST root: - if let Some(cached) = self.get_by_mast_root(&mast_root) { - // Sanity check - assert_eq!(cached.mast_root(mast_forest), mast_root); - - if cached.num_locals() != procedure.num_locals() { - return Err(AssemblyError::ConflictingDefinitions { - first: cached.fully_qualified_name().clone(), - second: procedure.fully_qualified_name().clone(), - }); - } - - // We have a previously cached version of an equivalent procedure, just under a - // different [GlobalProcedureIndex], so insert the cached procedure into the slot for - // `id`, but skip inserting a record in the MAST root lookup table - self.cache[id.module.as_usize()][id.index.as_usize()] = Some(procedure); - return Ok(()); - } - - // This is a new entry, so record both the cache entry and the MAST root mapping - self.cache[id.module.as_usize()][id.index.as_usize()] = Some(procedure); - self.by_mast_root.insert(mast_root, id); - - Ok(()) - } - - fn ensure_cache_slot_exists(&mut self, id: GlobalProcedureIndex, module: &LibraryPath) { - let min_cache_len = id.module.as_usize() + 1; - let min_module_len = 
id.index.as_usize() + 1; - - if self.cache.len() < min_cache_len { - self.cache.resize(min_cache_len, Vec::default()); - self.modules.resize(min_cache_len, None); - } - - // If this is the first entry for this module index, record the path to the module for - // future queries - let module_name = &mut self.modules[id.module.as_usize()]; - if module_name.is_none() { - *module_name = Some(module.clone()); - } - - let module_cache = &mut self.cache[id.module.as_usize()]; - if module_cache.len() < min_module_len { - module_cache.resize(min_module_len, None); - } - } -} - -impl IntoIterator for ProcedureCache { - type Item = (GlobalProcedureIndex, Arc); - type IntoIter = IntoIter; - - fn into_iter(self) -> Self::IntoIter { - let empty = self.is_empty(); - let pos = (0, 0); - IntoIter { - empty, - pos, - cache: VecDeque::from_iter(self.cache.into_iter().map(VecDeque::from)), - } - } -} - -pub struct IntoIter { - cache: VecDeque>>>, - pos: (usize, usize), - empty: bool, -} - -impl Iterator for IntoIter { - type Item = (GlobalProcedureIndex, Arc); - - fn next(&mut self) -> Option { - if self.empty { - return None; - } - - loop { - let (module, index) = self.pos; - if let Some(slot) = self.cache[module].pop_front() { - self.pos.1 += 1; - if let Some(procedure) = slot { - let gid = GlobalProcedureIndex { - module: ModuleIndex::new(module), - index: ProcedureIndex::new(index), - }; - break Some((gid, procedure)); - } - continue; - } - - // We've reached the end of this module cache - self.cache.pop_front(); - self.pos.0 += 1; - - // Check if we've reached the end of the overall cache - if self.cache.is_empty() { - self.empty = true; - break None; - } - } - } -} - -impl fmt::Debug for ProcedureCache { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("ProcedureCache") - .field("modules", &DisplayCachedModules(self)) - .finish() - } -} - -#[doc(hidden)] -struct DisplayCachedModules<'a>(&'a ProcedureCache); - -impl<'a> fmt::Debug for 
DisplayCachedModules<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let roots = &self.0.by_mast_root; - f.debug_map() - .entries(self.0.modules.iter().enumerate().zip(self.0.cache.iter()).filter_map( - |((index, path), slots)| { - path.as_ref().map(|path| { - ( - ModuleSlot { - index, - module: path, - }, - DisplayCachedProcedures { - roots, - module: index, - slots: slots.as_slice(), - }, - ) - }) - }, - )) - .finish() - } -} - -#[doc(hidden)] -struct DisplayCachedProcedures<'a> { - roots: &'a BTreeMap, - slots: &'a [Option>], - module: usize, -} - -impl<'a> fmt::Debug for DisplayCachedProcedures<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_set() - .entries(self.slots.iter().enumerate().filter_map(|(index, p)| { - p.as_deref().map(|p| ProcedureSlot { - roots: self.roots, - module: self.module, - index, - procedure: p, - }) - })) - .finish() - } -} - -// NOTE: Clippy thinks these fields are dead because it doesn't recognize that they are used by the -// `debug_map` implementation. 
-#[derive(Debug)] -#[allow(dead_code)] -struct ModuleSlot<'a> { - index: usize, - module: &'a LibraryPath, -} - -#[doc(hidden)] -struct ProcedureSlot<'a> { - roots: &'a BTreeMap, - module: usize, - index: usize, - procedure: &'a Procedure, -} - -impl<'a> fmt::Debug for ProcedureSlot<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let id = GlobalProcedureIndex { - module: ModuleIndex::new(self.module), - index: ProcedureIndex::new(self.index), - }; - let digest = self - .roots - .iter() - .find_map(|(hash, gid)| if gid == &id { Some(hash) } else { None }) - .expect("missing root for cache entry"); - f.debug_struct("CacheEntry") - .field("index", &self.index) - .field("key", digest) - .field("procedure", self.procedure) - .finish() - } -} diff --git a/assembly/src/assembler/module_graph/rewrites/module.rs b/assembly/src/assembler/module_graph/rewrites/module.rs index 73218ddf8b..9d39307b89 100644 --- a/assembly/src/assembler/module_graph/rewrites/module.rs +++ b/assembly/src/assembler/module_graph/rewrites/module.rs @@ -11,7 +11,7 @@ use crate::{ InvocationTarget, Invoke, InvokeKind, Module, Procedure, }, diagnostics::SourceFile, - AssemblyError, Span, Spanned, + AssemblyError, Spanned, }; /// A [ModuleRewriter] handles applying all of the module-wide rewrites to a [Module] that is being @@ -67,13 +67,6 @@ impl<'a, 'b: 'a> ModuleRewriter<'a, 'b> { }; match self.resolver.resolve_target(&caller, target) { Err(err) => return ControlFlow::Break(err), - Ok(ResolvedTarget::Cached { digest, .. }) => { - *target = InvocationTarget::MastRoot(Span::new(target.span(), digest)); - self.invoked.insert(Invoke { - kind, - target: target.clone(), - }); - } Ok(ResolvedTarget::Phantom(_)) => (), Ok(ResolvedTarget::Exact { .. 
}) => { self.invoked.insert(Invoke { diff --git a/assembly/src/assembler/procedure.rs b/assembly/src/assembler/procedure.rs index d29da8d706..18ece32c0e 100644 --- a/assembly/src/assembler/procedure.rs +++ b/assembly/src/assembler/procedure.rs @@ -1,26 +1,164 @@ use alloc::{collections::BTreeSet, sync::Arc}; +use super::GlobalProcedureIndex; use crate::{ ast::{FullyQualifiedProcedureName, ProcedureName, Visibility}, diagnostics::SourceFile, - LibraryPath, RpoDigest, SourceSpan, Spanned, + AssemblyError, LibraryPath, RpoDigest, SourceSpan, Spanned, }; -use vm_core::mast::{MastForest, MastNodeId}; +use vm_core::mast::MastNodeId; pub type CallSet = BTreeSet; +// PROCEDURE CONTEXT +// ================================================================================================ + +/// Information about a procedure currently being compiled. +pub struct ProcedureContext { + gid: GlobalProcedureIndex, + span: SourceSpan, + source_file: Option>, + name: FullyQualifiedProcedureName, + visibility: Visibility, + num_locals: u16, + callset: CallSet, +} + +// ------------------------------------------------------------------------------------------------ +/// Constructors +impl ProcedureContext { + pub fn new( + gid: GlobalProcedureIndex, + name: FullyQualifiedProcedureName, + visibility: Visibility, + ) -> Self { + Self { + gid, + span: name.span(), + source_file: None, + name, + visibility, + num_locals: 0, + callset: Default::default(), + } + } + + pub fn with_num_locals(mut self, num_locals: u16) -> Self { + self.num_locals = num_locals; + self + } + + pub fn with_span(mut self, span: SourceSpan) -> Self { + self.span = span; + self + } + + pub fn with_source_file(mut self, source_file: Option>) -> Self { + self.source_file = source_file; + self + } +} + +// ------------------------------------------------------------------------------------------------ +/// Public accessors +impl ProcedureContext { + pub fn id(&self) -> GlobalProcedureIndex { + self.gid + } + + pub fn 
name(&self) -> &FullyQualifiedProcedureName { + &self.name + } + + pub fn num_locals(&self) -> u16 { + self.num_locals + } + + #[allow(unused)] + pub fn module(&self) -> &LibraryPath { + &self.name.module + } + + pub fn source_file(&self) -> Option> { + self.source_file.clone() + } + + pub fn is_kernel(&self) -> bool { + self.visibility.is_syscall() + } +} + +// ------------------------------------------------------------------------------------------------ +/// State mutators +impl ProcedureContext { + pub fn insert_callee(&mut self, callee: RpoDigest) { + self.callset.insert(callee); + } + + pub fn extend_callset(&mut self, callees: I) + where + I: IntoIterator, + { + self.callset.extend(callees); + } + + /// Registers a call to an externally-defined procedure which we have previously compiled. + /// + /// The call set of the callee is added to the call set of the procedure we are currently + /// compiling, to reflect that all of the code reachable from the callee is by extension + /// reachable by the caller. + pub fn register_external_call( + &mut self, + callee: &Procedure, + inlined: bool, + ) -> Result<(), AssemblyError> { + // If we call the callee, it's callset is by extension part of our callset + self.extend_callset(callee.callset().iter().cloned()); + + // If the callee is not being inlined, add it to our callset + if !inlined { + self.insert_callee(callee.mast_root()); + } + + Ok(()) + } + + /// Transforms this procedure context into a [Procedure]. + /// + /// The passed-in `mast_root` defines the MAST root of the procedure's body while + /// `mast_node_id` specifies the ID of the procedure's body node in the MAST forest in + /// which the procedure is defined. + /// + ///
+ /// `mast_root` and `mast_node_id` must be consistent. That is, the node located in the MAST + /// forest under `mast_node_id` must have the digest equal to the `mast_root`. + ///
+ pub fn into_procedure(self, mast_root: RpoDigest, mast_node_id: MastNodeId) -> Procedure { + Procedure::new(self.name, self.visibility, self.num_locals as u32, mast_root, mast_node_id) + .with_span(self.span) + .with_source_file(self.source_file) + .with_callset(self.callset) + } +} + +impl Spanned for ProcedureContext { + fn span(&self) -> SourceSpan { + self.span + } +} + // PROCEDURE // ================================================================================================ -/// A compiled Miden Assembly procedure, consisting of MAST and basic metadata. +/// A compiled Miden Assembly procedure, consisting of MAST info and basic metadata. /// /// Procedure metadata includes: /// -/// * Fully-qualified path of the procedure in Miden Assembly (if known). -/// * Number of procedure locals to allocate. -/// * The visibility of the procedure (e.g. public/private/syscall) -/// * The set of MAST roots invoked by this procedure. -/// * The original source span and file of the procedure (if available). +/// - Fully-qualified path of the procedure in Miden Assembly (if known). +/// - Number of procedure locals to allocate. +/// - The visibility of the procedure (e.g. public/private/syscall) +/// - The set of MAST roots invoked by this procedure. +/// - The original source span and file of the procedure (if available). #[derive(Clone, Debug)] pub struct Procedure { span: SourceSpan, @@ -28,18 +166,22 @@ pub struct Procedure { path: FullyQualifiedProcedureName, visibility: Visibility, num_locals: u32, - /// The MAST node id for the root of this procedure + /// The MAST root of the procedure. + mast_root: RpoDigest, + /// The MAST node id which resolves to the above MAST root. 
body_node_id: MastNodeId, /// The set of MAST roots called by this procedure callset: CallSet, } -/// Builder +// ------------------------------------------------------------------------------------------------ +/// Constructors impl Procedure { - pub(crate) fn new( + fn new( path: FullyQualifiedProcedureName, visibility: Visibility, num_locals: u32, + mast_root: RpoDigest, body_node_id: MastNodeId, ) -> Self { Self { @@ -48,6 +190,7 @@ impl Procedure { path, visibility, num_locals, + mast_root, body_node_id, callset: Default::default(), } @@ -69,7 +212,8 @@ impl Procedure { } } -/// Metadata +// ------------------------------------------------------------------------------------------------ +/// Public accessors impl Procedure { /// Returns a reference to the name of this procedure #[allow(unused)] @@ -105,9 +249,8 @@ impl Procedure { } /// Returns the root of this procedure's MAST. - pub fn mast_root(&self, mast_forest: &MastForest) -> RpoDigest { - let body_node = &mast_forest[self.body_node_id]; - body_node.digest() + pub fn mast_root(&self) -> RpoDigest { + self.mast_root } /// Returns a reference to the MAST node ID of this procedure. 
From 4f0dbf2ddf94be2ce24239940efc306e2d86e2d5 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Tue, 23 Jul 2024 14:50:18 -0700 Subject: [PATCH 6/7] chore: fix typos and add section separators --- .../module_graph/analysis/rewrite_check.rs | 18 ++- .../assembler/module_graph/name_resolver.rs | 134 +++++++++--------- .../assembler/module_graph/rewrites/module.rs | 3 + 3 files changed, 82 insertions(+), 73 deletions(-) diff --git a/assembly/src/assembler/module_graph/analysis/rewrite_check.rs b/assembly/src/assembler/module_graph/analysis/rewrite_check.rs index 4c097133e5..c2941a181e 100644 --- a/assembly/src/assembler/module_graph/analysis/rewrite_check.rs +++ b/assembly/src/assembler/module_graph/analysis/rewrite_check.rs @@ -11,15 +11,18 @@ use crate::{ AssemblyError, Spanned, }; +// MAYBE REWRITE CHECK +// ================================================================================================ + /// [MaybeRewriteCheck] is a simple analysis pass over a [Module], that looks for evidence that new /// information has been found that would result in at least one rewrite to the module body. /// -/// This pass is intended for modules that were already added to a [ModuleGraph], and so have been -/// rewritten at least once before. When new modules are added to the graph, the introduction of -/// those modules may allow us to resolve invocation targets that were previously unresolvable, or -/// that resolved as phantoms due to missing definitions. When that occurs, we want to go back and -/// rewrite all of the modules that can be further refined as a result of that additional -/// information. +/// This pass is intended for modules that were already added to a [super::super::ModuleGraph], and +/// so have been rewritten at least once before. When new modules are added to the graph, the +/// introduction of those modules may allow us to resolve invocation targets that were previously +/// unresolvable, or that resolved as phantoms due to missing definitions. 
When that occurs, we +/// want to go back and rewrite all of the modules that can be further refined as a result of that +/// additional information. pub struct MaybeRewriteCheck<'a, 'b: 'a> { resolver: &'a NameResolver<'b>, } @@ -44,6 +47,9 @@ impl<'a, 'b: 'a> MaybeRewriteCheck<'a, 'b> { } } +// REWRITE CHECK VISITOR +// ================================================================================================ + struct RewriteCheckVisitor<'a, 'b: 'a> { resolver: &'a NameResolver<'b>, module_id: ModuleIndex, diff --git a/assembly/src/assembler/module_graph/name_resolver.rs b/assembly/src/assembler/module_graph/name_resolver.rs index 07fb614dc6..4703ebf94b 100644 --- a/assembly/src/assembler/module_graph/name_resolver.rs +++ b/assembly/src/assembler/module_graph/name_resolver.rs @@ -120,9 +120,71 @@ impl<'a> NameResolver<'a> { }); } + /// Resolve `target`, a possibly-resolved callee identifier, to a [ResolvedTarget], using + /// `caller` as the context. + pub fn resolve_target( + &self, + caller: &CallerInfo, + target: &InvocationTarget, + ) -> Result { + match target { + InvocationTarget::MastRoot(mast_root) => { + match self.graph.get_procedure_index_by_digest(mast_root) { + None => Ok(ResolvedTarget::Phantom(mast_root.into_inner())), + Some(gid) => Ok(ResolvedTarget::Exact { gid }), + } + } + InvocationTarget::ProcedureName(ref callee) => self.resolve(caller, callee), + InvocationTarget::ProcedurePath { + ref name, + module: ref imported_module, + } => match self.resolve_import(caller, imported_module) { + Some(imported_module) => { + let fqn = FullyQualifiedProcedureName { + span: target.span(), + module: imported_module.into_inner().clone(), + name: name.clone(), + }; + let gid = self.find(caller, &fqn)?; + let path = self.module_path(gid.module); + let pending_offset = self.graph.modules.len(); + let name = if gid.module.as_usize() >= pending_offset { + self.pending[gid.module.as_usize() - pending_offset] + .resolver + .get_name(gid.index) + .clone() + } 
else { + self.graph[gid].name().clone() + }; + Ok(ResolvedTarget::Resolved { + gid, + target: InvocationTarget::AbsoluteProcedurePath { name, path }, + }) + } + None => Err(AssemblyError::UndefinedModule { + span: target.span(), + source_file: caller.source_file.clone(), + path: LibraryPath::new_from_components( + LibraryNamespace::User(imported_module.clone().into_inner()), + [], + ), + }), + }, + InvocationTarget::AbsoluteProcedurePath { ref name, ref path } => { + let fqn = FullyQualifiedProcedureName { + span: target.span(), + module: path.clone(), + name: name.clone(), + }; + let gid = self.find(caller, &fqn)?; + Ok(ResolvedTarget::Exact { gid }) + } + } + } + /// Resolver `callee` to a [ResolvedTarget], using `caller` as the context in which `callee` /// should be resolved. - pub fn resolve( + fn resolve( &self, caller: &CallerInfo, callee: &ProcedureName, @@ -175,7 +237,7 @@ impl<'a> NameResolver<'a> { /// Resolve `name`, the name of an imported module, to a [LibraryPath], using `caller` as the /// context. - pub fn resolve_import(&self, caller: &CallerInfo, name: &Ident) -> Option> { + fn resolve_import(&self, caller: &CallerInfo, name: &Ident) -> Option> { let pending_offset = self.graph.modules.len(); if caller.module.as_usize() >= pending_offset { self.pending[caller.module.as_usize() - pending_offset] @@ -188,68 +250,6 @@ impl<'a> NameResolver<'a> { } } - /// Resolve `target`, a possibly-resolved callee identifier, to a [ResolvedTarget], using - /// `caller` as the context. 
- pub fn resolve_target( - &self, - caller: &CallerInfo, - target: &InvocationTarget, - ) -> Result { - match target { - InvocationTarget::MastRoot(mast_root) => { - match self.graph.get_procedure_index_by_digest(mast_root) { - None => Ok(ResolvedTarget::Phantom(mast_root.into_inner())), - Some(gid) => Ok(ResolvedTarget::Exact { gid }), - } - } - InvocationTarget::ProcedureName(ref callee) => self.resolve(caller, callee), - InvocationTarget::ProcedurePath { - ref name, - module: ref imported_module, - } => match self.resolve_import(caller, imported_module) { - Some(imported_module) => { - let fqn = FullyQualifiedProcedureName { - span: target.span(), - module: imported_module.into_inner().clone(), - name: name.clone(), - }; - let gid = self.find(caller, &fqn)?; - let path = self.module_path(gid.module); - let pending_offset = self.graph.modules.len(); - let name = if gid.module.as_usize() >= pending_offset { - self.pending[gid.module.as_usize() - pending_offset] - .resolver - .get_name(gid.index) - .clone() - } else { - self.graph[gid].name().clone() - }; - Ok(ResolvedTarget::Resolved { - gid, - target: InvocationTarget::AbsoluteProcedurePath { name, path }, - }) - } - None => Err(AssemblyError::UndefinedModule { - span: target.span(), - source_file: caller.source_file.clone(), - path: LibraryPath::new_from_components( - LibraryNamespace::User(imported_module.clone().into_inner()), - [], - ), - }), - }, - InvocationTarget::AbsoluteProcedurePath { ref name, ref path } => { - let fqn = FullyQualifiedProcedureName { - span: target.span(), - module: path.clone(), - name: name.clone(), - }; - let gid = self.find(caller, &fqn)?; - Ok(ResolvedTarget::Exact { gid }) - } - } - } - fn resolve_local( &self, caller: &CallerInfo, @@ -278,11 +278,11 @@ impl<'a> NameResolver<'a> { } } - /// Resolve `name` to its concrete definition, returning the corresponding + /// Resolve `callee` to its concrete definition, returning the corresponding /// [GlobalProcedureIndex]. 
/// /// If an error occurs during resolution, or the name cannot be resolved, `Err` is returned. - pub fn find( + fn find( &self, caller: &CallerInfo, callee: &FullyQualifiedProcedureName, @@ -430,7 +430,7 @@ impl<'a> NameResolver<'a> { } /// Resolve a [LibraryPath] to a [ModuleIndex] in this graph - pub fn find_module_index(&self, name: &LibraryPath) -> Option { + fn find_module_index(&self, name: &LibraryPath) -> Option { self.graph .modules .iter() diff --git a/assembly/src/assembler/module_graph/rewrites/module.rs b/assembly/src/assembler/module_graph/rewrites/module.rs index 9d39307b89..4637ff2f4e 100644 --- a/assembly/src/assembler/module_graph/rewrites/module.rs +++ b/assembly/src/assembler/module_graph/rewrites/module.rs @@ -14,6 +14,9 @@ use crate::{ AssemblyError, Spanned, }; +// MODULE REWRITE CHECK +// ================================================================================================ + /// A [ModuleRewriter] handles applying all of the module-wide rewrites to a [Module] that is being /// added to a [ModuleGraph]. These rewrites include: /// From 454d34415223ad868d29a4e6e2720cd36a48875b Mon Sep 17 00:00:00 2001 From: Andrey Khmuro Date: Wed, 24 Jul 2024 12:37:17 +0300 Subject: [PATCH 7/7] Rename `internals` feature to `testing` (#1399) * refactor: rename internals feature to testing * chore: update CHANGELOG --- CHANGELOG.md | 1 + Makefile | 2 +- air/Cargo.toml | 2 +- air/src/trace/main_trace.rs | 4 ++-- processor/Cargo.toml | 2 +- processor/src/host/advice/inputs.rs | 6 +++--- processor/src/host/advice/providers.rs | 4 ++-- processor/src/host/mod.rs | 4 ++-- processor/src/lib.rs | 4 ++-- processor/src/stack/mod.rs | 2 +- processor/src/stack/trace.rs | 2 +- stdlib/Cargo.toml | 2 +- test-utils/Cargo.toml | 2 +- 13 files changed, 19 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 53d198691c..f71356af6d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ #### Changed - When using `if.(true|false) .. 
end`, the parser used to emit an empty block for the branch that was elided. The parser now emits a block containing a single `nop` instruction instead, which is equivalent to the code emitted by the assembler when lowering to MAST. +- `internals` configuration feature was renamed to `testing` (#1399). ## 0.9.2 (2024-05-22) - `stdlib` crate only diff --git a/Makefile b/Makefile index c5df854680..8ad16f81e1 100644 --- a/Makefile +++ b/Makefile @@ -51,7 +51,7 @@ mdbook: ## Generates mdbook documentation .PHONY: test test: ## Runs all tests - $(DEBUG_ASSERTIONS) cargo nextest run --cargo-profile test-release --features internals + $(DEBUG_ASSERTIONS) cargo nextest run --cargo-profile test-release --features testing # --- checking ------------------------------------------------------------------------------------ diff --git a/air/Cargo.toml b/air/Cargo.toml index 267be62ba6..47249e8f70 100644 --- a/air/Cargo.toml +++ b/air/Cargo.toml @@ -28,7 +28,7 @@ harness = false [features] default = ["std"] std = ["vm-core/std", "winter-air/std"] -internals = [] +testing = [] [dependencies] vm-core = { package = "miden-core", path = "../core", version = "0.9", default-features = false } diff --git a/air/src/trace/main_trace.rs b/air/src/trace/main_trace.rs index c25d8716c8..d51c3fcec9 100644 --- a/air/src/trace/main_trace.rs +++ b/air/src/trace/main_trace.rs @@ -19,7 +19,7 @@ use super::{ use core::ops::{Deref, Range}; use vm_core::{utils::range, Felt, Word, ONE, ZERO}; -#[cfg(any(test, feature = "internals"))] +#[cfg(any(test, feature = "testing"))] use alloc::vec::Vec; // CONSTANTS @@ -54,7 +54,7 @@ impl MainTrace { self.columns.num_rows() } - #[cfg(any(test, feature = "internals"))] + #[cfg(any(test, feature = "testing"))] pub fn get_column_range(&self, range: Range) -> Vec> { range.fold(vec![], |mut acc, col_idx| { acc.push(self.get_column(col_idx).to_vec()); diff --git a/processor/Cargo.toml b/processor/Cargo.toml index f7e31acebd..4bd83943c2 100644 --- 
a/processor/Cargo.toml +++ b/processor/Cargo.toml @@ -20,7 +20,7 @@ doctest = false [features] concurrent = ["std", "winter-prover/concurrent"] default = ["std"] -internals = ["miden-air/internals"] +testing = ["miden-air/testing"] std = ["vm-core/std", "winter-prover/std"] [dependencies] diff --git a/processor/src/host/advice/inputs.rs b/processor/src/host/advice/inputs.rs index 7e217ecea5..ffce6345f3 100644 --- a/processor/src/host/advice/inputs.rs +++ b/processor/src/host/advice/inputs.rs @@ -18,7 +18,7 @@ use vm_core::crypto::hash::RpoDigest; /// 2. Key-mapped element lists which can be pushed onto the advice stack. /// 3. Merkle store, which is used to provide nondeterministic inputs for instructions that operates /// with Merkle trees. -#[cfg(not(feature = "internals"))] +#[cfg(not(feature = "testing"))] #[derive(Clone, Debug, Default)] pub struct AdviceInputs { stack: Vec, @@ -132,10 +132,10 @@ impl AdviceInputs { } } -// INTERNALS +// TESTING // ================================================================================================ -#[cfg(feature = "internals")] +#[cfg(feature = "testing")] #[derive(Clone, Debug, Default)] pub struct AdviceInputs { pub stack: Vec, diff --git a/processor/src/host/advice/providers.rs b/processor/src/host/advice/providers.rs index e6846ef4f5..3e7cd11c52 100644 --- a/processor/src/host/advice/providers.rs +++ b/processor/src/host/advice/providers.rs @@ -243,7 +243,7 @@ impl From for MemAdviceProvider { } /// Accessors to internal data structures of the provider used for testing purposes. -#[cfg(any(test, feature = "internals"))] +#[cfg(any(test, feature = "testing"))] impl MemAdviceProvider { /// Returns the current state of the advice stack. pub fn stack(&self) -> &[Felt] { @@ -364,7 +364,7 @@ impl From for RecAdviceProvider { } /// Accessors to internal data structures of the provider used for testing purposes. 
-#[cfg(any(test, feature = "internals"))] +#[cfg(any(test, feature = "testing"))] impl RecAdviceProvider { /// Returns the current state of the advice stack. pub fn stack(&self) -> &[Felt] { diff --git a/processor/src/host/mod.rs b/processor/src/host/mod.rs index d6bfe9a79d..86a2128092 100644 --- a/processor/src/host/mod.rs +++ b/processor/src/host/mod.rs @@ -309,12 +309,12 @@ where self.store.insert(mast_forest) } - #[cfg(any(test, feature = "internals"))] + #[cfg(any(test, feature = "testing"))] pub fn advice_provider(&self) -> &A { &self.adv_provider } - #[cfg(any(test, feature = "internals"))] + #[cfg(any(test, feature = "testing"))] pub fn advice_provider_mut(&mut self) -> &mut A { &mut self.adv_provider } diff --git a/processor/src/lib.rs b/processor/src/lib.rs index 95ab09e7ed..3f5648c4ea 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -171,7 +171,7 @@ where /// However, for situations in which you want finer-grained control over those steps, you will need /// to construct an instance of [Process] using [Process::new], invoke [Process::execute], and then /// get the execution trace using [ExecutionTrace::new] using the outputs produced by execution. -#[cfg(not(any(test, feature = "internals")))] +#[cfg(not(any(test, feature = "testing")))] pub struct Process where H: Host, @@ -186,7 +186,7 @@ where enable_tracing: bool, } -#[cfg(any(test, feature = "internals"))] +#[cfg(any(test, feature = "testing"))] pub struct Process where H: Host, diff --git a/processor/src/stack/mod.rs b/processor/src/stack/mod.rs index 86419d23fe..cd65cc74a0 100644 --- a/processor/src/stack/mod.rs +++ b/processor/src/stack/mod.rs @@ -330,7 +330,7 @@ impl Stack { /// Returns state of stack item columns at the current clock cycle. This does not include stack /// values in the overflow table. 
- #[cfg(any(test, feature = "internals"))] + #[cfg(any(test, feature = "testing"))] pub fn trace_state(&self) -> [Felt; STACK_TOP_SIZE] { self.trace.get_stack_state_at(self.clk) } diff --git a/processor/src/stack/trace.rs b/processor/src/stack/trace.rs index 960fd79827..c8534df24f 100644 --- a/processor/src/stack/trace.rs +++ b/processor/src/stack/trace.rs @@ -200,7 +200,7 @@ impl StackTrace { // -------------------------------------------------------------------------------------------- /// Returns the stack trace state at the specified clock cycle. - #[cfg(any(test, feature = "internals"))] + #[cfg(any(test, feature = "testing"))] pub fn get_stack_state_at(&self, clk: u32) -> [Felt; STACK_TOP_SIZE] { let mut result = [ZERO; STACK_TOP_SIZE]; for (result, column) in result.iter_mut().zip(self.stack.iter()) { diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index 55a3849b80..b7fe8bcbdf 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -35,7 +35,7 @@ num = "0.4.1" num-bigint = "0.4" pretty_assertions = "1.4" processor = { package = "miden-processor", path = "../processor", version = "0.9", default-features = false, features = [ - "internals", + "testing", ] } rand = { version = "0.8.5", default-features = false } serde_json = "1.0" diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index 0873c88315..680a79986a 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -28,7 +28,7 @@ assembly = { package = "miden-assembly", path = "../assembly", version = "0.9", "testing", ] } processor = { package = "miden-processor", path = "../processor", version = "0.9", default-features = false, features = [ - "internals", + "testing", ] } prover = { package = "miden-prover", path = "../prover", version = "0.9", default-features = false } test-case = "3.2"