From b192c61216d4dacfe19cb8609ba25f6b0d74193d Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare <43513081+bobbinth@users.noreply.github.com> Date: Tue, 27 Aug 2024 13:11:21 -0700 Subject: [PATCH] refactor: wrap MastForest in Program and Library in Arc (#1465) * refactor: wrap MastFores in Program and Library in Arc * fix: enforce that program entrypoints are procedure roots * refactor: add external nodes to MastForest for re-exports --- CHANGELOG.md | 7 + assembly/src/assembler/mast_forest_builder.rs | 82 ++++----- assembly/src/assembler/mod.rs | 42 +++-- assembly/src/assembler/procedure.rs | 4 +- assembly/src/assembler/tests.rs | 6 +- assembly/src/library/error.rs | 2 + assembly/src/library/mod.rs | 170 +++++++----------- assembly/src/library/module.rs | 4 +- assembly/src/library/tests.rs | 167 ++++++++++++++++- core/src/mast/mod.rs | 2 + core/src/program.rs | 105 ++++++----- miden/src/examples/blake3.rs | 2 +- miden/src/tools/mod.rs | 2 +- processor/src/chiplets/tests.rs | 3 +- processor/src/decoder/tests.rs | 45 +++-- processor/src/host/mast_forest_store.rs | 4 +- processor/src/host/mod.rs | 2 +- processor/src/lib.rs | 2 +- processor/src/trace/tests/chiplets/hasher.rs | 14 +- processor/src/trace/tests/decoder.rs | 33 ++-- processor/src/trace/tests/mod.rs | 6 +- stdlib/src/lib.rs | 18 +- stdlib/tests/mem/mod.rs | 2 +- test-utils/src/lib.rs | 37 ++-- 24 files changed, 448 insertions(+), 313 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f3c11b94..48e51b72b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## 0.11.0 (TBD) + +#### Changes + +- [BREAKING] Wrapped `MastForest`s in `Program` and `Library` structs in `Arc` (#1465). 
+ + ## 0.10.5 (2024-08-21) #### Enhancements diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs index e072bfff2..88615e6df 100644 --- a/assembly/src/assembler/mast_forest_builder.rs +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -1,9 +1,7 @@ use alloc::{ collections::{BTreeMap, BTreeSet}, - sync::Arc, vec::Vec, }; -use core::ops::Index; use vm_core::{ crypto::hash::RpoDigest, @@ -24,14 +22,35 @@ const PROCEDURE_INLINING_THRESHOLD: usize = 32; // ================================================================================================ /// Builder for a [`MastForest`]. +/// +/// The purpose of the builder is to ensure that the underlying MAST forest contains as little +/// information as possible needed to adequately describe the logical MAST forest. Specifically: +/// - The builder ensures that only one copy of a given node exists in the MAST forest (i.e., no two +/// nodes have the same hash). +/// - The builder tries to merge adjacent basic blocks and eliminate the source block whenever this +/// does not have an impact on other nodes in the forest. #[derive(Clone, Debug, Default)] pub struct MastForestBuilder { + /// The MAST forest being built by this builder; this MAST forest is up-to-date - i.e., all + /// nodes added to the MAST forest builder are also immediately added to the underlying MAST + /// forest. mast_forest: MastForest, + /// A map of MAST node digests to their corresponding positions in the MAST forest. It is + /// guaranteed that a given digest maps to exactly one node in the MAST forest. node_id_by_hash: BTreeMap, - procedures: BTreeMap>, - procedure_hashes: BTreeMap, + /// A map of all procedures added to the MAST forest indexed by their global procedure ID. + /// This includes all local, exported, and re-exported procedures. 
In case multiple procedures + /// with the same digest are added to the MAST forest builder, only the first procedure is + /// added to the map, and all subsequent insertions are ignored. + procedures: BTreeMap, + /// A map from procedure MAST root to its global procedure index. Similar to the `procedures` + /// map, this map contains only the first inserted procedure for procedures with the same MAST + /// root. proc_gid_by_hash: BTreeMap, - merged_node_ids: BTreeSet, + /// A set of IDs for basic blocks which have been merged into bigger basic blocks. This is + /// used as a candidate set of nodes that may be eliminated if they are not referenced by any + /// other node in the forest and are not a root of any procedure. + merged_basic_block_ids: BTreeSet, } impl MastForestBuilder { @@ -42,7 +61,7 @@ impl MastForestBuilder { /// unchanged. Any [`MastNodeId`] used in reference to the old [`MastForest`] should be remapped /// using this map. pub fn build(mut self) -> (MastForest, Option>) { - let nodes_to_remove = get_nodes_to_remove(self.merged_node_ids, &self.mast_forest); + let nodes_to_remove = get_nodes_to_remove(self.merged_basic_block_ids, &self.mast_forest); let id_remappings = self.mast_forest.remove_nodes(&nodes_to_remove); (self.mast_forest, id_remappings) @@ -109,21 +128,21 @@ impl MastForestBuilder { /// Returns a reference to the procedure with the specified [`GlobalProcedureIndex`], or None /// if such a procedure is not present in this MAST forest builder. #[inline(always)] - pub fn get_procedure(&self, gid: GlobalProcedureIndex) -> Option> { - self.procedures.get(&gid).cloned() + pub fn get_procedure(&self, gid: GlobalProcedureIndex) -> Option<&Procedure> { + self.procedures.get(&gid) } /// Returns the hash of the procedure with the specified [`GlobalProcedureIndex`], or None if /// such a procedure is not present in this MAST forest builder. 
#[inline(always)] pub fn get_procedure_hash(&self, gid: GlobalProcedureIndex) -> Option { - self.procedure_hashes.get(&gid).cloned() + self.procedures.get(&gid).map(|proc| proc.mast_root()) } /// Returns a reference to the procedure with the specified MAST root, or None /// if such a procedure is not present in this MAST forest builder. #[inline(always)] - pub fn find_procedure(&self, mast_root: &RpoDigest) -> Option> { + pub fn find_procedure(&self, mast_root: &RpoDigest) -> Option<&Procedure> { self.proc_gid_by_hash.get(mast_root).and_then(|gid| self.get_procedure(*gid)) } @@ -141,18 +160,9 @@ impl MastForestBuilder { } } +// ------------------------------------------------------------------------------------------------ +/// Procedure insertion impl MastForestBuilder { - pub fn insert_procedure_hash( - &mut self, - gid: GlobalProcedureIndex, - proc_hash: RpoDigest, - ) -> Result<(), AssemblyError> { - // TODO(plafer): Check if exists - self.procedure_hashes.insert(gid, proc_hash); - - Ok(()) - } - /// Inserts a procedure into this MAST forest builder. /// /// If the procedure with the same ID already exists in this forest builder, this will have @@ -202,19 +212,17 @@ impl MastForestBuilder { } } - self.make_root(procedure.body_node_id()); + self.mast_forest.make_root(procedure.body_node_id()); self.proc_gid_by_hash.insert(proc_root, gid); - self.insert_procedure_hash(gid, procedure.mast_root())?; - self.procedures.insert(gid, Arc::new(procedure)); + self.procedures.insert(gid, procedure); Ok(()) } +} - /// Marks the given [`MastNodeId`] as being the root of a procedure. - pub fn make_root(&mut self, new_root_id: MastNodeId) { - self.mast_forest.make_root(new_root_id) - } - +// ------------------------------------------------------------------------------------------------ +/// Joining nodes +impl MastForestBuilder { /// Builds a tree of `JOIN` operations to combine the provided MAST node IDs. 
pub fn join_nodes(&mut self, node_ids: Vec) -> Result { debug_assert!(!node_ids.is_empty(), "cannot combine empty MAST node id list"); @@ -254,7 +262,7 @@ impl MastForestBuilder { let mut contiguous_basic_block_ids: Vec = Vec::new(); for mast_node_id in node_ids { - if self[mast_node_id].is_basic_block() { + if self.mast_forest[mast_node_id].is_basic_block() { contiguous_basic_block_ids.push(mast_node_id); } else { merged_node_ids.extend(self.merge_basic_blocks(&contiguous_basic_block_ids)?); @@ -293,7 +301,8 @@ impl MastForestBuilder { for &basic_block_id in contiguous_basic_block_ids { // It is safe to unwrap here, since we already checked that all IDs in // `contiguous_basic_block_ids` are `BasicBlockNode`s - let basic_block_node = self[basic_block_id].get_basic_block().unwrap().clone(); + let basic_block_node = + self.mast_forest[basic_block_id].get_basic_block().unwrap().clone(); // check if the block should be merged with other blocks if should_merge( @@ -322,7 +331,7 @@ impl MastForestBuilder { } // Mark the removed basic blocks as merged - self.merged_node_ids.extend(contiguous_basic_block_ids.iter()); + self.merged_basic_block_ids.extend(contiguous_basic_block_ids.iter()); if !operations.is_empty() || !decorators.is_empty() { let merged_basic_block = self.ensure_block(operations, Some(decorators))?; @@ -414,15 +423,6 @@ impl MastForestBuilder { } } -impl Index for MastForestBuilder { - type Output = MastNode; - - #[inline(always)] - fn index(&self, node_id: MastNodeId) -> &Self::Output { - &self.mast_forest[node_id] - } -} - // HELPER FUNCTIONS // ================================================================================================ diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index ed67ed154..ba20fd764 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -301,7 +301,7 @@ impl Assembler { // TODO: show a warning if library exports are empty? 
let (mast_forest, _) = mast_forest_builder.build(); - Ok(Library::new(mast_forest, exports)) + Ok(Library::new(mast_forest.into(), exports)?) } /// Assembles the provided module into a [KernelLibrary] intended to be used as a Kernel. @@ -343,7 +343,7 @@ impl Assembler { // TODO: show a warning if library exports are empty? let (mast_forest, _) = mast_forest_builder.build(); - let library = Library::new(mast_forest, exports); + let library = Library::new(mast_forest.into(), exports)?; Ok(library.try_into()?) } @@ -379,21 +379,19 @@ impl Assembler { // Compile the module graph rooted at the entrypoint let mut mast_forest_builder = MastForestBuilder::default(); self.compile_subgraph(entrypoint, &mut mast_forest_builder)?; - let entry_procedure = mast_forest_builder + let entry_node_id = mast_forest_builder .get_procedure(entrypoint) - .expect("compilation succeeded but root not found in cache"); + .expect("compilation succeeded but root not found in cache") + .body_node_id(); + // in case the node IDs changed, update the entrypoint ID to the new value let (mast_forest, id_remappings) = mast_forest_builder.build(); - let entry_node_id = { - let old_entry_node_id = entry_procedure.body_node_id(); - - id_remappings - .map(|id_remappings| id_remappings[&old_entry_node_id]) - .unwrap_or(old_entry_node_id) - }; + let entry_node_id = id_remappings + .map(|id_remappings| id_remappings[&entry_node_id]) + .unwrap_or(entry_node_id); Ok(Program::with_kernel( - mast_forest, + mast_forest.into(), entry_node_id, self.module_graph.kernel().clone(), )) @@ -473,8 +471,13 @@ impl Assembler { // Compile this procedure let procedure = self.compile_procedure(pctx, mast_forest_builder)?; + // TODO: if a re-exported procedure with the same MAST root had been previously + // added to the builder, this will result in unreachable nodes added to the + // MAST forest. 
This is because while we won't insert a duplicate node for the + // procedure body node itself, all nodes that make up the procedure body would + // be added to the forest. - // Cache the compiled procedure. + // Cache the compiled procedure self.module_graph.register_mast_root(procedure_gid, procedure.mast_root())?; mast_forest_builder.insert_procedure(procedure_gid, procedure)?; }, @@ -493,15 +496,22 @@ impl Assembler { ) .with_span(proc_alias.span()); - let proc_alias_root = self.resolve_target( + let proc_mast_root = self.resolve_target( InvokeKind::ProcRef, &proc_alias.target().into(), &pctx, mast_forest_builder, )?; + + // insert external node into the MAST forest for this procedure; if a procedure + // with the same MAST root had been previously added to the builder, this will + // have no effect + let proc_node_id = mast_forest_builder.ensure_external(proc_mast_root)?; + let procedure = pctx.into_procedure(proc_mast_root, proc_node_id); + // Make the MAST root available to all dependents - self.module_graph.register_mast_root(procedure_gid, proc_alias_root)?; - mast_forest_builder.insert_procedure_hash(procedure_gid, proc_alias_root)?; + self.module_graph.register_mast_root(procedure_gid, proc_mast_root)?; + mast_forest_builder.insert_procedure(procedure_gid, procedure)?; }, } } diff --git a/assembly/src/assembler/procedure.rs b/assembly/src/assembler/procedure.rs index 325e1e44c..167a25806 100644 --- a/assembly/src/assembler/procedure.rs +++ b/assembly/src/assembler/procedure.rs @@ -93,7 +93,9 @@ impl ProcedureContext { /// /// The passed-in `mast_root` defines the MAST root of the procedure's body while /// `mast_node_id` specifies the ID of the procedure's body node in the MAST forest in - /// which the procedure is defined. + /// which the procedure is defined. Note that if the procedure is re-exported (i.e., the body + /// of the procedure is defined in some other MAST forest) `mast_node_id` will point to a + /// single `External` node. /// ///
/// `mast_root` and `mast_node_id` must be consistent. That is, the node located in the MAST diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index 0cd220b0b..74059a1ba 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -142,7 +142,9 @@ fn nested_blocks() -> Result<(), Report> { .join_nodes(vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id]) .unwrap(); - let expected_program = Program::new(expected_mast_forest_builder.build().0, combined_node_id); + let mut expected_mast_forest = expected_mast_forest_builder.build().0; + expected_mast_forest.make_root(combined_node_id); + let expected_program = Program::new(expected_mast_forest.into(), combined_node_id); assert_eq!(expected_program.hash(), program.hash()); // also check that the program has the right number of procedures (which excludes the dummy @@ -214,7 +216,7 @@ fn duplicate_nodes() { expected_mast_forest.make_root(root_id); - let expected_program = Program::new(expected_mast_forest, root_id); + let expected_program = Program::new(expected_mast_forest.into(), root_id); assert_eq!(program, expected_program); } diff --git a/assembly/src/library/error.rs b/assembly/src/library/error.rs index c8e31f79d..3df795ca3 100644 --- a/assembly/src/library/error.rs +++ b/assembly/src/library/error.rs @@ -11,4 +11,6 @@ pub enum LibraryError { InvalidKernelExport { procedure_path: QualifiedProcedureName }, #[error(transparent)] Kernel(#[from] KernelError), + #[error("invalid export: no procedure root for {procedure_path} procedure")] + NoProcedureRootForExport { procedure_path: QualifiedProcedureName }, } diff --git a/assembly/src/library/mod.rs b/assembly/src/library/mod.rs index 8ae3415bd..c7656bb97 100644 --- a/assembly/src/library/mod.rs +++ b/assembly/src/library/mod.rs @@ -1,6 +1,7 @@ use alloc::{ collections::{BTreeMap, BTreeSet}, string::{String, ToString}, + sync::Arc, vec::Vec, }; @@ -43,12 +44,12 @@ pub struct Library { /// The 
content hash of this library, formed by hashing the roots of all exports in /// lexicographical order (by digest, not procedure name) digest: RpoDigest, - /// A map between procedure paths and the corresponding procedure toots in the MAST forest. + /// A map between procedure paths and the corresponding procedure roots in the MAST forest. /// Multiple paths can map to the same root, and also, some roots may not be associated with /// any paths. - exports: BTreeMap, + exports: BTreeMap, /// The MAST forest underlying this library. - mast_forest: MastForest, + mast_forest: Arc, } impl AsRef for Library { @@ -58,47 +59,41 @@ impl AsRef for Library { } } -#[derive(Debug, Clone, PartialEq, Eq)] -#[repr(u8)] -enum Export { - /// The export is contained in the [MastForest] of this library - Local(MastNodeId), - /// The export is a re-export of an externally-defined procedure from another library - External(RpoDigest), -} - +// ------------------------------------------------------------------------------------------------ /// Constructors impl Library { /// Constructs a new [`Library`] from the provided MAST forest and a set of exports. + /// + /// # Errors + /// Returns an error if any of the specified exports do not have a corresponding procedure root + /// in the provided MAST forest. 
pub fn new( - mast_forest: MastForest, + mast_forest: Arc, exports: BTreeMap, - ) -> Self { + ) -> Result { let mut fqn_to_export = BTreeMap::new(); // convert fqn |-> mast_root map into fqn |-> mast_node_id map for (fqn, mast_root) in exports.into_iter() { - match mast_forest.find_procedure_root(mast_root) { - Some(node_id) => { - fqn_to_export.insert(fqn, Export::Local(node_id)); - }, - None => { - fqn_to_export.insert(fqn, Export::External(mast_root)); - }, + if let Some(proc_node_id) = mast_forest.find_procedure_root(mast_root) { + fqn_to_export.insert(fqn, proc_node_id); + } else { + return Err(LibraryError::NoProcedureRootForExport { procedure_path: fqn }); } } - let digest = content_hash(&fqn_to_export, &mast_forest); + let digest = compute_content_hash(&fqn_to_export, &mast_forest); - Self { + Ok(Self { digest, exports: fqn_to_export, mast_forest, - } + }) } } -/// Accessors +// ------------------------------------------------------------------------------------------------ +/// Public accessors impl Library { /// Returns the [RpoDigest] representing the content hash of this library pub fn digest(&self) -> &RpoDigest { @@ -110,8 +105,29 @@ impl Library { self.exports.keys() } - /// Returns the inner [`MastForest`]. - pub fn mast_forest(&self) -> &MastForest { + /// Returns the number of exports in this library. + pub fn num_exports(&self) -> usize { + self.exports.len() + } + + /// Returns a MAST node ID associated with the specified exported procedure. + /// + /// # Panics + /// Panics if the specified procedure is not exported from this library. + pub fn get_export_node_id(&self, proc_name: &QualifiedProcedureName) -> MastNodeId { + *self.exports.get(proc_name).expect("procedure not exported from the library") + } + + /// Returns true if the specified exported procedure is re-exported from a dependency. 
+ pub fn is_reexport(&self, proc_name: &QualifiedProcedureName) -> bool { + self.exports + .get(proc_name) + .map(|&node_id| self.mast_forest[node_id].is_external()) + .unwrap_or(false) + } + + /// Returns a reference to the inner [`MastForest`]. + pub fn mast_forest(&self) -> &Arc { &self.mast_forest } } @@ -122,17 +138,17 @@ impl Library { pub fn module_infos(&self) -> impl Iterator { let mut modules_by_path: BTreeMap = BTreeMap::new(); - for (proc_name, export) in self.exports.iter() { + for (proc_name, &proc_root_node_id) in self.exports.iter() { modules_by_path .entry(proc_name.module.clone()) .and_modify(|compiled_module| { - let proc_digest = export.digest(&self.mast_forest); + let proc_digest = self.mast_forest[proc_root_node_id].digest(); compiled_module.add_procedure(proc_name.name.clone(), proc_digest); }) .or_insert_with(|| { let mut module_info = ModuleInfo::new(proc_name.module.clone()); - let proc_digest = export.digest(&self.mast_forest); + let proc_digest = self.mast_forest[proc_root_node_id].digest(); module_info.add_procedure(proc_name.name.clone(), proc_digest); module_info @@ -143,12 +159,6 @@ impl Library { } } -impl From for MastForest { - fn from(value: Library) -> Self { - value.mast_forest - } -} - impl Serializable for Library { fn write_into(&self, target: &mut W) { let Self { digest: _, exports, mast_forest } = self; @@ -156,17 +166,17 @@ impl Serializable for Library { mast_forest.write_into(target); target.write_usize(exports.len()); - for (proc_name, export) in exports { + for (proc_name, proc_node_id) in exports { proc_name.module.write_into(target); proc_name.name.as_str().write_into(target); - export.write_into(target); + target.write_u32(proc_node_id.as_u32()); } } } impl Deserializable for Library { fn read_from(source: &mut R) -> Result { - let mast_forest = MastForest::read_from(source)?; + let mast_forest = Arc::new(MastForest::read_from(source)?); let num_exports = source.read_usize()?; let mut exports = BTreeMap::new(); @@ 
-176,22 +186,22 @@ impl Deserializable for Library { let proc_name = ProcedureName::new(proc_name) .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?; let proc_name = QualifiedProcedureName::new(proc_module, proc_name); - let export = Export::read_with_forest(source, &mast_forest)?; + let proc_node_id = MastNodeId::from_u32_safe(source.read_u32()?, &mast_forest)?; - exports.insert(proc_name, export); + exports.insert(proc_name, proc_node_id); } - let digest = content_hash(&exports, &mast_forest); + let digest = compute_content_hash(&exports, &mast_forest); Ok(Self { digest, exports, mast_forest }) } } -fn content_hash( - exports: &BTreeMap, +fn compute_content_hash( + exports: &BTreeMap, mast_forest: &MastForest, ) -> RpoDigest { - let digests = BTreeSet::from_iter(exports.values().map(|export| export.digest(mast_forest))); + let digests = BTreeSet::from_iter(exports.values().map(|&id| mast_forest[id].digest())); digests .into_iter() .reduce(|a, b| vm_core::crypto::hash::Rpo256::merge(&[a, b])) @@ -295,58 +305,6 @@ mod use_std_library { } } -impl Export { - pub fn digest(&self, mast_forest: &MastForest) -> RpoDigest { - match self { - Self::Local(node_id) => mast_forest[*node_id].digest(), - Self::External(digest) => *digest, - } - } - - fn tag(&self) -> u8 { - // SAFETY: This is safe because we have given this enum a primitive representation with - // #[repr(u8)], with the first field of the underlying union-of-structs the discriminant. 
- // - // See the section on "accessing the numeric value of the discriminant" - // here: https://doc.rust-lang.org/std/mem/fn.discriminant.html - unsafe { *<*const _>::from(self).cast::() } - } -} - -impl Serializable for Export { - fn write_into(&self, target: &mut W) { - target.write_u8(self.tag()); - match self { - Self::Local(node_id) => target.write_u32(node_id.into()), - Self::External(digest) => digest.write_into(target), - } - } -} - -impl Export { - pub fn read_with_forest( - source: &mut R, - mast_forest: &MastForest, - ) -> Result { - match source.read_u8()? { - 0 => { - let node_id = MastNodeId::from_u32_safe(source.read_u32()?, mast_forest)?; - if !mast_forest.is_procedure_root(node_id) { - return Err(DeserializationError::InvalidValue(format!( - "node with id {node_id} is not a procedure root" - ))); - } - Ok(Self::Local(node_id)) - }, - 1 => RpoDigest::read_from(source).map(Self::External), - n => Err(DeserializationError::InvalidValue(format!( - "{} is not a valid compiled library export entry", - n - ))), - } - } -} - // KERNEL LIBRARY // ================================================================================================ @@ -356,7 +314,7 @@ impl Export { /// - All exported procedures must be exported directly from the kernel namespace (i.e., `#sys`). /// - There must be at least one exported procedure. /// - The number of exported procedures cannot exceed [Kernel::MAX_NUM_PROCEDURES] (i.e., 256). -#[derive(Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct KernelLibrary { kernel: Kernel, kernel_info: ModuleInfo, @@ -376,13 +334,13 @@ impl KernelLibrary { &self.kernel } - /// Returns the inner [`MastForest`]. - pub fn mast_forest(&self) -> &MastForest { + /// Returns a reference to the inner [`MastForest`]. + pub fn mast_forest(&self) -> &Arc { self.library.mast_forest() } /// Destructures this kernel library into individual parts. 
- pub fn into_parts(self) -> (Kernel, ModuleInfo, MastForest) { + pub fn into_parts(self) -> (Kernel, ModuleInfo, Arc) { (self.kernel, self.kernel_info, self.library.mast_forest) } } @@ -400,7 +358,7 @@ impl TryFrom for KernelLibrary { let mut kernel_module = ModuleInfo::new(kernel_path.clone()); - for (proc_path, export) in library.exports.iter() { + for (proc_path, &proc_node_id) in library.exports.iter() { // make sure all procedures are exported only from the kernel root if proc_path.module != kernel_path { return Err(LibraryError::InvalidKernelExport { @@ -408,7 +366,7 @@ impl TryFrom for KernelLibrary { }); } - let proc_digest = export.digest(&library.mast_forest); + let proc_digest = library.mast_forest[proc_node_id].digest(); proc_digests.push(proc_digest); kernel_module.add_procedure(proc_path.name.clone(), proc_digest); } @@ -423,12 +381,6 @@ impl TryFrom for KernelLibrary { } } -impl From for MastForest { - fn from(value: KernelLibrary) -> Self { - value.library.mast_forest - } -} - impl Serializable for KernelLibrary { fn write_into(&self, target: &mut W) { let Self { kernel: _, kernel_info: _, library } = self; diff --git a/assembly/src/library/module.rs b/assembly/src/library/module.rs index eb7287769..ed010d0e2 100644 --- a/assembly/src/library/module.rs +++ b/assembly/src/library/module.rs @@ -9,7 +9,7 @@ use crate::{ // MODULE INFO // ================================================================================================ -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct ModuleInfo { path: LibraryPath, procedures: Vec, @@ -68,7 +68,7 @@ impl ModuleInfo { } /// Stores the name and digest of a procedure. 
-#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct ProcedureInfo { pub name: ProcedureName, pub digest: RpoDigest, diff --git a/assembly/src/library/tests.rs b/assembly/src/library/tests.rs index ef13c8a19..f4f32341f 100644 --- a/assembly/src/library/tests.rs +++ b/assembly/src/library/tests.rs @@ -1,6 +1,5 @@ -use alloc::{string::ToString, vec::Vec}; - -use vm_core::utils::SliceReader; +use alloc::string::ToString; +use core::str::FromStr; use super::*; use crate::{ @@ -19,6 +18,158 @@ macro_rules! parse_module { }}; } +// TESTS +// ================================================================================================ + +#[test] +fn library_exports() -> Result<(), Report> { + let context = TestContext::new(); + + // build the first library + let baz = r#" + export.baz1 + push.7 push.8 sub + end + "#; + let baz = parse_module!(&context, "lib1::baz", baz); + + let lib1 = Assembler::new(context.source_manager()).assemble_library([baz])?; + + // build the second library + let foo = r#" + proc.foo1 + push.1 add + end + + export.foo2 + push.2 add + exec.foo1 + end + + export.foo3 + push.3 mul + exec.foo1 + exec.foo2 + end + "#; + let foo = parse_module!(&context, "lib2::foo", foo); + + // declare bar module + let bar = r#" + use.lib1::baz + use.lib2::foo + + export.baz::baz1->bar1 + + export.foo::foo2->bar2 + + export.bar3 + exec.foo::foo2 + end + + proc.bar4 + push.1 push.2 mul + end + + export.bar5 + push.3 sub + exec.foo::foo2 + exec.bar1 + exec.bar2 + exec.bar4 + end + "#; + let bar = parse_module!(&context, "lib2::bar", bar); + let modules = [foo, bar]; + + let lib2 = Assembler::new(context.source_manager()) + .with_library(lib1)? 
+ .assemble_library(modules.iter().cloned())?; + + let foo2 = QualifiedProcedureName::from_str("lib2::foo::foo2").unwrap(); + let foo3 = QualifiedProcedureName::from_str("lib2::foo::foo3").unwrap(); + let bar1 = QualifiedProcedureName::from_str("lib2::bar::bar1").unwrap(); + let bar2 = QualifiedProcedureName::from_str("lib2::bar::bar2").unwrap(); + let bar3 = QualifiedProcedureName::from_str("lib2::bar::bar3").unwrap(); + let bar5 = QualifiedProcedureName::from_str("lib2::bar::bar5").unwrap(); + + // make sure the library exports all exported procedures + let expected_exports: BTreeSet<_> = [&foo2, &foo3, &bar1, &bar2, &bar3, &bar5].into(); + let actual_exports: BTreeSet<_> = lib2.exports().collect(); + assert_eq!(expected_exports, actual_exports); + + // make sure foo2, bar2, and bar3 map to the same MastNode + assert_eq!(lib2.get_export_node_id(&foo2), lib2.get_export_node_id(&bar2)); + assert_eq!(lib2.get_export_node_id(&foo2), lib2.get_export_node_id(&bar3)); + + // make sure there are 6 roots in the MAST (foo1, foo2, foo3, bar1, bar4, and bar5) + assert_eq!(lib2.mast_forest.num_procedures(), 6); + + // bar1 should be the only re-export + assert!(!lib2.is_reexport(&foo2)); + assert!(!lib2.is_reexport(&foo3)); + assert!(lib2.is_reexport(&bar1)); + assert!(!lib2.is_reexport(&bar2)); + assert!(!lib2.is_reexport(&bar3)); + assert!(!lib2.is_reexport(&bar5)); + + Ok(()) +} + +#[test] +fn library_procedure_collision() -> Result<(), Report> { + let context = TestContext::new(); + + // build the first library + let foo = r#" + export.foo1 + push.1 + if.true + push.1 push.2 add + else + push.1 push.2 mul + end + end + "#; + let foo = parse_module!(&context, "lib1::foo", foo); + let lib1 = Assembler::new(context.source_manager()).assemble_library([foo])?; + + // build the second library which defines the same procedure as the first one + let bar = r#" + use.lib1::foo + + export.foo::foo1->bar1 + + export.bar2 + push.1 + if.true + push.1 push.2 add + else + push.1 push.2 
mul + end + end + "#; + let bar = parse_module!(&context, "lib2::bar", bar); + let lib2 = Assembler::new(context.source_manager()) + .with_library(lib1)? + .assemble_library([bar])?; + + let bar1 = QualifiedProcedureName::from_str("lib2::bar::bar1").unwrap(); + let bar2 = QualifiedProcedureName::from_str("lib2::bar::bar2").unwrap(); + + // make sure lib2 has the expected exports (i.e., bar1 and bar2) + assert_eq!(lib2.num_exports(), 2); + assert_eq!(lib2.get_export_node_id(&bar1), lib2.get_export_node_id(&bar2)); + + // make sure only one node was added to the forest + // NOTE: the MAST forest should actually have only 1 node (external node for the re-exported + // procedure), because nodes for the local procedure nodes should be pruned from the forest, + // but this is not implemented yet + assert_eq!(lib2.mast_forest().num_nodes(), 5); + + Ok(()) +} + #[test] fn library_serialization() -> Result<(), Report> { let context = TestContext::new(); @@ -46,13 +197,11 @@ fn library_serialization() -> Result<(), Report> { let modules = [foo, bar]; // serialize/deserialize the bundle with locations - let bundle = Assembler::new(context.source_manager()) - .assemble_library(modules.iter().cloned()) - .unwrap(); + let bundle = + Assembler::new(context.source_manager()).assemble_library(modules.iter().cloned())?; - let mut bytes = Vec::new(); - bundle.write_into(&mut bytes); - let deserialized = Library::read_from(&mut SliceReader::new(&bytes)).unwrap(); + let bytes = bundle.to_bytes(); + let deserialized = Library::read_from_bytes(&bytes).unwrap(); assert_eq!(bundle, deserialized); Ok(()) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 87a6daa35..d0fb2577e 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -125,6 +125,8 @@ impl MastForest { /// Marks the given [`MastNodeId`] as being the root of a procedure. /// + /// If the specified node is already marked as a root, this will have no effect. 
+ /// /// # Panics /// - if `new_root_id`'s internal index is larger than the number of nodes in this forest (i.e. /// clearly doesn't belong to this MAST forest). diff --git a/core/src/program.rs b/core/src/program.rs index 093e0902a..385baa942 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -1,5 +1,5 @@ -use alloc::vec::Vec; -use core::{fmt, ops::Index}; +use alloc::{sync::Arc, vec::Vec}; +use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt, WORD_SIZE}; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; @@ -13,9 +13,14 @@ use crate::{ // PROGRAM // =============================================================================================== +/// An executable program for Miden VM. +/// +/// A program consists of a MAST forest, an entrypoint defining the MAST node at which the program +/// execution begins, and a definition of the kernel against which the program must be executed +/// (the kernel can be an empty kernel). #[derive(Clone, Debug, PartialEq, Eq)] pub struct Program { - mast_forest: MastForest, + mast_forest: Arc, /// The "entrypoint" is the node where execution of the program begins. entrypoint: MastNodeId, kernel: Kernel, @@ -27,38 +32,37 @@ impl Program { /// to be empty. /// /// # Panics: - /// - if `mast_forest` doesn't have an entrypoint - pub fn new(mast_forest: MastForest, entrypoint: MastNodeId) -> Self { - assert!(mast_forest.get_node_by_id(entrypoint).is_some()); - - Self { - mast_forest, - entrypoint, - kernel: Kernel::default(), - } + /// - if `mast_forest` doesn't contain the specified entrypoint. + /// - if the specified entrypoint is not a procedure root in the `mast_forest`. + pub fn new(mast_forest: Arc, entrypoint: MastNodeId) -> Self { + Self::with_kernel(mast_forest, entrypoint, Kernel::default()) } /// Construct a new [`Program`] from the given MAST forest, entrypoint, and kernel. 
/// /// # Panics: - /// - if `mast_forest` doesn't have an entrypoint - pub fn with_kernel(mast_forest: MastForest, entrypoint: MastNodeId, kernel: Kernel) -> Self { - assert!(mast_forest.get_node_by_id(entrypoint).is_some()); + /// - if `mast_forest` doesn't contain the specified entrypoint. + /// - if the specified entrypoint is not a procedure root in the `mast_forest`. + pub fn with_kernel( + mast_forest: Arc, + entrypoint: MastNodeId, + kernel: Kernel, + ) -> Self { + assert!(mast_forest.get_node_by_id(entrypoint).is_some(), "invalid entrypoint"); + assert!(mast_forest.is_procedure_root(entrypoint), "entrypoint not a procedure"); Self { mast_forest, entrypoint, kernel } } } +// ------------------------------------------------------------------------------------------------ /// Public accessors impl Program { - /// Returns the underlying [`MastForest`]. - pub fn mast_forest(&self) -> &MastForest { - &self.mast_forest - } - - /// Returns the kernel associated with this program. - pub fn kernel(&self) -> &Kernel { - &self.kernel + /// Returns the hash of the program's entrypoint. + /// + /// Equivalently, returns the hash of the root of the entrypoint procedure. + pub fn hash(&self) -> RpoDigest { + self.mast_forest[self.entrypoint].digest() } /// Returns the entrypoint associated with this program. @@ -66,17 +70,20 @@ impl Program { self.entrypoint } - /// Returns the hash of the program's entrypoint. - /// - /// Equivalently, returns the hash of the root of the entrypoint procedure. - pub fn hash(&self) -> RpoDigest { - self.mast_forest[self.entrypoint].digest() + /// Returns a reference to the underlying [`MastForest`]. + pub fn mast_forest(&self) -> &Arc { + &self.mast_forest + } + + /// Returns the kernel associated with this program. + pub fn kernel(&self) -> &Kernel { + &self.kernel } /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if valid, or else /// `None`. /// - /// This is the faillible version of indexing (e.g. 
`program[node_id]`). + /// This is the fallible version of indexing (e.g. `program[node_id]`). #[inline(always)] pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> { self.mast_forest.get_node_by_id(node_id) @@ -94,10 +101,11 @@ impl Program { } } +// ------------------------------------------------------------------------------------------------ /// Serialization +#[cfg(feature = "std")] impl Program { /// Writes this [Program] to the provided file path. - #[cfg(feature = "std")] pub fn write_to_file

(&self, path: P) -> std::io::Result<()> where P: AsRef, @@ -139,26 +147,27 @@ impl Serializable for Program { impl Deserializable for Program { fn read_from(source: &mut R) -> Result { - let mast_forest = source.read()?; + let mast_forest = Arc::new(source.read()?); let kernel = source.read()?; let entrypoint = MastNodeId::from_u32_safe(source.read_u32()?, &mast_forest)?; - Ok(Self { mast_forest, kernel, entrypoint }) - } -} - -impl Index for Program { - type Output = MastNode; + if !mast_forest.is_procedure_root(entrypoint) { + return Err(DeserializationError::InvalidValue(format!( + "entrypoint {entrypoint} is not a procedure" + ))); + } - fn index(&self, node_id: MastNodeId) -> &Self::Output { - &self.mast_forest[node_id] + Ok(Self::with_kernel(mast_forest, entrypoint, kernel)) } } +// ------------------------------------------------------------------------------------------------ +// Pretty-printing + impl crate::prettier::PrettyPrint for Program { fn render(&self) -> crate::prettier::Document { use crate::prettier::*; - let entrypoint = self[self.entrypoint()].to_pretty_print(&self.mast_forest); + let entrypoint = self.mast_forest[self.entrypoint()].to_pretty_print(&self.mast_forest); indent(4, const_text("begin") + nl() + entrypoint.render()) + nl() + const_text("end") } @@ -171,12 +180,6 @@ impl fmt::Display for Program { } } -impl From for MastForest { - fn from(program: Program) -> Self { - program.mast_forest - } -} - // PROGRAM INFO // =============================================================================================== @@ -195,17 +198,11 @@ pub struct ProgramInfo { } impl ProgramInfo { - // CONSTRUCTORS - // -------------------------------------------------------------------------------------------- - /// Creates a new instance of a program info.
pub const fn new(program_hash: RpoDigest, kernel: Kernel) -> Self { Self { program_hash, kernel } } - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - /// Returns the program hash computed from its code block root. pub const fn program_hash(&self) -> &RpoDigest { &self.program_hash @@ -231,8 +228,8 @@ impl From for ProgramInfo { } } -// SERIALIZATION // ------------------------------------------------------------------------------------------------ +// Serialization impl Serializable for ProgramInfo { fn write_into(&self, target: &mut W) { @@ -249,8 +246,8 @@ impl Deserializable for ProgramInfo { } } -// TO ELEMENTS // ------------------------------------------------------------------------------------------------ +// ToElements implementation impl ToElements for ProgramInfo { fn to_elements(&self) -> Vec { diff --git a/miden/src/examples/blake3.rs b/miden/src/examples/blake3.rs index 6db77685d..3960bd5ec 100644 --- a/miden/src/examples/blake3.rs +++ b/miden/src/examples/blake3.rs @@ -22,7 +22,7 @@ pub fn get_example(n: usize) -> Example> { ); let mut host = DefaultHost::default(); - host.load_mast_forest(StdLibrary::default().into()); + host.load_mast_forest(StdLibrary::default().mast_forest().clone()); let stack_inputs = StackInputs::try_from_ints(INITIAL_HASH_VALUE.iter().map(|&v| v as u64)).unwrap(); diff --git a/miden/src/tools/mod.rs b/miden/src/tools/mod.rs index 70501fbfb..39d9d9ea5 100644 --- a/miden/src/tools/mod.rs +++ b/miden/src/tools/mod.rs @@ -38,7 +38,7 @@ impl Analyze { // fetch the stack and program inputs from the arguments let stack_inputs = input_data.parse_stack_inputs().map_err(Report::msg)?; let mut host = DefaultHost::new(input_data.parse_advice_provider().map_err(Report::msg)?); - host.load_mast_forest(StdLibrary::default().into()); + host.load_mast_forest(StdLibrary::default().mast_forest().clone()); let execution_details: ExecutionDetails = analyze(program.as_str(), 
stack_inputs, host) .expect("Could not retrieve execution details"); diff --git a/processor/src/chiplets/tests.rs b/processor/src/chiplets/tests.rs index d5f120113..a0eab0162 100644 --- a/processor/src/chiplets/tests.rs +++ b/processor/src/chiplets/tests.rs @@ -119,8 +119,9 @@ fn build_trace( let mut mast_forest = MastForest::new(); let basic_block_id = mast_forest.add_block(operations, None).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; process.execute(&program).unwrap(); diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index f17f94d1b..1f5d7b404 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -53,8 +53,9 @@ fn basic_block_one_group() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -99,8 +100,9 @@ fn basic_block_small() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -162,8 +164,9 @@ fn basic_block() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -254,8 +257,9 @@ fn span_block_with_respan() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = 
mast_forest.add_node(basic_block_node).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -330,8 +334,9 @@ fn join_node() { let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap(); let join_node_id = mast_forest.add_join(basic_block1_id, basic_block2_id).unwrap(); + mast_forest.make_root(join_node_id); - Program::new(mast_forest, join_node_id) + Program::new(mast_forest.into(), join_node_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -395,8 +400,9 @@ fn split_node_true() { let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap(); let split_node_id = mast_forest.add_split(basic_block1_id, basic_block2_id).unwrap(); + mast_forest.make_root(split_node_id); - Program::new(mast_forest, split_node_id) + Program::new(mast_forest.into(), split_node_id) }; let (trace, trace_len) = build_trace(&[1], &program); @@ -447,8 +453,9 @@ fn split_node_false() { let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap(); let split_node_id = mast_forest.add_split(basic_block1_id, basic_block2_id).unwrap(); + mast_forest.make_root(split_node_id); - Program::new(mast_forest, split_node_id) + Program::new(mast_forest.into(), split_node_id) }; let (trace, trace_len) = build_trace(&[0], &program); @@ -498,10 +505,10 @@ fn loop_node() { let mut mast_forest = MastForest::new(); let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap(); - let loop_node_id = mast_forest.add_loop(loop_body_id).unwrap(); + mast_forest.make_root(loop_node_id); - Program::new(mast_forest, loop_node_id) + Program::new(mast_forest.into(), loop_node_id) }; let (trace, trace_len) = build_trace(&[0, 1], &program); @@ -550,10 +557,10 @@ fn loop_node_skip() { let mut mast_forest = MastForest::new(); let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap(); - let loop_node_id = 
mast_forest.add_loop(loop_body_id).unwrap(); + mast_forest.make_root(loop_node_id); - Program::new(mast_forest, loop_node_id) + Program::new(mast_forest.into(), loop_node_id) }; let (trace, trace_len) = build_trace(&[0], &program); @@ -592,10 +599,10 @@ fn loop_node_repeat() { let mut mast_forest = MastForest::new(); let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap(); - let loop_node_id = mast_forest.add_loop(loop_body_id).unwrap(); + mast_forest.make_root(loop_node_id); - Program::new(mast_forest, loop_node_id) + Program::new(mast_forest.into(), loop_node_id) }; let (trace, trace_len) = build_trace(&[0, 1, 1], &program); @@ -696,8 +703,9 @@ fn call_block() { let join1_node_id = mast_forest.add_node(join1_node.clone()).unwrap(); let program_root_id = mast_forest.add_join(join1_node_id, last_basic_block_id).unwrap(); + mast_forest.make_root(program_root_id); - let program = Program::new(mast_forest, program_root_id); + let program = Program::new(mast_forest.into(), program_root_id); let (sys_trace, dec_trace, trace_len) = build_call_trace(&program, Kernel::default()); @@ -922,8 +930,9 @@ fn syscall_block() { let program_root_node = MastNode::new_join(inner_join_node_id, last_basic_block_id, &mast_forest).unwrap(); let program_root_node_id = mast_forest.add_node(program_root_node.clone()).unwrap(); + mast_forest.make_root(program_root_node_id); - let program = Program::with_kernel(mast_forest, program_root_node_id, kernel.clone()); + let program = Program::with_kernel(mast_forest.into(), program_root_node_id, kernel.clone()); let (sys_trace, dec_trace, trace_len) = build_call_trace(&program, kernel); @@ -1195,8 +1204,9 @@ fn dyn_block() { let program_root_node = MastNode::new_join(join_node_id, dyn_node_id, &mast_forest).unwrap(); let program_root_node_id = mast_forest.add_node(program_root_node.clone()).unwrap(); + mast_forest.make_root(program_root_node_id); - let program = Program::new(mast_forest, program_root_node_id); + let program = 
Program::new(mast_forest.into(), program_root_node_id); let (trace, trace_len) = build_dyn_trace( &[ @@ -1302,8 +1312,9 @@ fn set_user_op_helpers_many() { let mut mast_forest = MastForest::new(); let basic_block_id = mast_forest.add_block(vec![Operation::U32div], None).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let a = rand_value::(); let b = rand_value::(); diff --git a/processor/src/host/mast_forest_store.rs b/processor/src/host/mast_forest_store.rs index f6d05b025..a2fd0296b 100644 --- a/processor/src/host/mast_forest_store.rs +++ b/processor/src/host/mast_forest_store.rs @@ -23,9 +23,7 @@ pub struct MemMastForestStore { impl MemMastForestStore { /// Inserts all the procedures of the provided MAST forest in the store. - pub fn insert(&mut self, mast_forest: MastForest) { - let mast_forest = Arc::new(mast_forest); - + pub fn insert(&mut self, mast_forest: Arc) { // only register the procedures which are local to this forest for proc_digest in mast_forest.local_procedure_digests() { self.mast_forests.insert(proc_digest, mast_forest.clone()); diff --git a/processor/src/host/mod.rs b/processor/src/host/mod.rs index 0fbd6ef6b..10366e7c6 100644 --- a/processor/src/host/mod.rs +++ b/processor/src/host/mod.rs @@ -316,7 +316,7 @@ where } } - pub fn load_mast_forest(&mut self, mast_forest: MastForest) { + pub fn load_mast_forest(&mut self, mast_forest: Arc) { self.store.insert(mast_forest) } diff --git a/processor/src/lib.rs b/processor/src/lib.rs index 19c8fc4e4..41028eedb 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -252,7 +252,7 @@ where return Err(ExecutionError::ProgramAlreadyExecuted); } - self.execute_mast_node(program.entrypoint(), program.mast_forest())?; + self.execute_mast_node(program.entrypoint(), &program.mast_forest().clone())?; Ok(self.stack.build_stack_outputs()) } diff --git a/processor/src/trace/tests/chiplets/hasher.rs 
b/processor/src/trace/tests/chiplets/hasher.rs index 0ed99b69d..01d6aa4bb 100644 --- a/processor/src/trace/tests/chiplets/hasher.rs +++ b/processor/src/trace/tests/chiplets/hasher.rs @@ -57,8 +57,9 @@ pub fn b_chip_span() { let basic_block_id = mast_forest.add_block(vec![Operation::Add, Operation::Mul], None).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let trace = build_trace_from_program(&program, &[]); @@ -129,8 +130,9 @@ pub fn b_chip_span_with_respan() { let (ops, _) = build_span_with_respan_ops(); let basic_block_id = mast_forest.add_block(ops, None).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let trace = build_trace_from_program(&program, &[]); @@ -220,12 +222,11 @@ pub fn b_chip_merge() { let mut mast_forest = MastForest::new(); let t_branch_id = mast_forest.add_block(vec![Operation::Add], None).unwrap(); - let f_branch_id = mast_forest.add_block(vec![Operation::Mul], None).unwrap(); - let split_id = mast_forest.add_split(t_branch_id, f_branch_id).unwrap(); + mast_forest.make_root(split_id); - Program::new(mast_forest, split_id) + Program::new(mast_forest.into(), split_id) }; let trace = build_trace_from_program(&program, &[]); @@ -336,8 +337,9 @@ pub fn b_chip_permutation() { let mut mast_forest = MastForest::new(); let basic_block_id = mast_forest.add_block(vec![Operation::HPerm], None).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let stack = vec![8, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]; let trace = build_trace_from_program(&program, &stack); diff --git a/processor/src/trace/tests/decoder.rs b/processor/src/trace/tests/decoder.rs index 1213b05dd..021102dcc 100644 --- a/processor/src/trace/tests/decoder.rs +++ b/processor/src/trace/tests/decoder.rs @@ 
-74,12 +74,11 @@ fn decoder_p1_join() { let mut mast_forest = MastForest::new(); let basic_block_1_id = mast_forest.add_block(vec![Operation::Mul], None).unwrap(); - let basic_block_2_id = mast_forest.add_block(vec![Operation::Add], None).unwrap(); - let join_id = mast_forest.add_join(basic_block_1_id, basic_block_2_id).unwrap(); + mast_forest.make_root(join_id); - Program::new(mast_forest, join_id) + Program::new(mast_forest.into(), join_id) }; let trace = build_trace_from_program(&program, &[]); @@ -142,12 +141,11 @@ fn decoder_p1_split() { let mut mast_forest = MastForest::new(); let basic_block_1_id = mast_forest.add_block(vec![Operation::Mul], None).unwrap(); - let basic_block_2_id = mast_forest.add_block(vec![Operation::Add], None).unwrap(); - let split_id = mast_forest.add_split(basic_block_1_id, basic_block_2_id).unwrap(); + mast_forest.make_root(split_id); - Program::new(mast_forest, split_id) + Program::new(mast_forest.into(), split_id) }; let trace = build_trace_from_program(&program, &[1]); @@ -197,14 +195,12 @@ fn decoder_p1_loop_with_repeat() { let mut mast_forest = MastForest::new(); let basic_block_1_id = mast_forest.add_block(vec![Operation::Pad], None).unwrap(); - let basic_block_2_id = mast_forest.add_block(vec![Operation::Drop], None).unwrap(); - let join_id = mast_forest.add_join(basic_block_1_id, basic_block_2_id).unwrap(); - let loop_node_id = mast_forest.add_loop(join_id).unwrap(); + mast_forest.make_root(loop_node_id); - Program::new(mast_forest, loop_node_id) + Program::new(mast_forest.into(), loop_node_id) }; let trace = build_trace_from_program(&program, &[0, 1, 1]); @@ -324,8 +320,9 @@ fn decoder_p2_span_with_respan() { let (ops, _) = build_span_with_respan_ops(); let basic_block_id = mast_forest.add_block(ops, None).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let trace = build_trace_from_program(&program, &[]); let alphas = 
rand_array::(); @@ -366,8 +363,9 @@ fn decoder_p2_join() { let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest).unwrap(); let join_id = mast_forest.add_node(join.clone()).unwrap(); + mast_forest.make_root(join_id); - let program = Program::new(mast_forest, join_id); + let program = Program::new(mast_forest.into(), join_id); let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); @@ -425,12 +423,11 @@ fn decoder_p2_split_true() { let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul], None).unwrap(); let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()).unwrap(); - let basic_block_2_id = mast_forest.add_block(vec![Operation::Add], None).unwrap(); - let split_id = mast_forest.add_split(basic_block_1_id, basic_block_2_id).unwrap(); + mast_forest.make_root(split_id); - let program = Program::new(mast_forest, split_id); + let program = Program::new(mast_forest.into(), split_id); // build trace from program let trace = build_trace_from_program(&program, &[1]); @@ -484,8 +481,9 @@ fn decoder_p2_split_false() { let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()).unwrap(); let split_id = mast_forest.add_split(basic_block_1_id, basic_block_2_id).unwrap(); + mast_forest.make_root(split_id); - let program = Program::new(mast_forest, split_id); + let program = Program::new(mast_forest.into(), split_id); // build trace from program let trace = build_trace_from_program(&program, &[0]); @@ -542,8 +540,9 @@ fn decoder_p2_loop_with_repeat() { let join_id = mast_forest.add_node(join.clone()).unwrap(); let loop_node_id = mast_forest.add_loop(join_id).unwrap(); + mast_forest.make_root(loop_node_id); - let program = Program::new(mast_forest, loop_node_id); + let program = Program::new(mast_forest.into(), loop_node_id); // build trace from program let trace = build_trace_from_program(&program, &[0, 1, 1]); diff --git a/processor/src/trace/tests/mod.rs b/processor/src/trace/tests/mod.rs index 
c29c4267f..42524d68f 100644 --- a/processor/src/trace/tests/mod.rs +++ b/processor/src/trace/tests/mod.rs @@ -34,8 +34,9 @@ pub fn build_trace_from_ops(operations: Vec, stack: &[u64]) -> Execut let mut mast_forest = MastForest::new(); let basic_block_id = mast_forest.add_block(operations, None).unwrap(); + mast_forest.make_root(basic_block_id); - let program = Program::new(mast_forest, basic_block_id); + let program = Program::new(mast_forest.into(), basic_block_id); build_trace_from_program(&program, stack) } @@ -55,8 +56,9 @@ pub fn build_trace_from_ops_with_inputs( let mut mast_forest = MastForest::new(); let basic_block_id = mast_forest.add_block(operations, None).unwrap(); + mast_forest.make_root(basic_block_id); - let program = Program::new(mast_forest, basic_block_id); + let program = Program::new(mast_forest.into(), basic_block_id); process.execute(&program).unwrap(); ExecutionTrace::new(process, StackOutputs::default()) diff --git a/stdlib/src/lib.rs b/stdlib/src/lib.rs index 9637aff89..b0fb4e121 100644 --- a/stdlib/src/lib.rs +++ b/stdlib/src/lib.rs @@ -2,6 +2,8 @@ extern crate alloc; +use alloc::sync::Arc; + use assembly::{mast::MastForest, utils::Deserializable, Library}; // STANDARD LIBRARY @@ -22,22 +24,20 @@ impl From for Library { } } -impl From for MastForest { - fn from(value: StdLibrary) -> Self { - value.0.into() - } -} - impl StdLibrary { + /// Serialized representation of the Miden standard library. pub const SERIALIZED: &'static [u8] = include_bytes!(concat!(env!("OUT_DIR"), "/assets/std.masl")); + + /// Returns a reference to the [MastForest] underlying the Miden standard library. 
+ pub fn mast_forest(&self) -> &Arc { + self.0.mast_forest() + } } impl Default for StdLibrary { fn default() -> Self { - let contents = - Library::read_from_bytes(Self::SERIALIZED).expect("failed to read std masl!"); - Self(contents) + Self(Library::read_from_bytes(Self::SERIALIZED).expect("failed to deserialize stdlib")) } } diff --git a/stdlib/tests/mem/mod.rs b/stdlib/tests/mem/mod.rs index 7ce4bc80d..6802a2620 100644 --- a/stdlib/tests/mem/mod.rs +++ b/stdlib/tests/mem/mod.rs @@ -31,7 +31,7 @@ fn test_memcopy() { assembler.assemble_program(source).expect("Failed to compile test source."); let mut host = DefaultHost::default(); - host.load_mast_forest(stdlib.into()); + host.load_mast_forest(stdlib.mast_forest().clone()); let mut process = Process::new( program.kernel().clone(), diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 8c2d048e2..b32d2604b 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -5,9 +5,6 @@ extern crate alloc; #[cfg(feature = "std")] extern crate std; -// IMPORTS -// ================================================================================================ - #[cfg(not(target_family = "wasm"))] use alloc::format; use alloc::{ @@ -16,16 +13,14 @@ use alloc::{ vec::Vec, }; -use assembly::Library; -// EXPORTS -// ================================================================================================ pub use assembly::{diagnostics::Report, LibraryPath, SourceFile, SourceManager}; +use assembly::{KernelLibrary, Library}; pub use pretty_assertions::{assert_eq, assert_ne, assert_str_eq}; +use processor::Program; pub use processor::{ AdviceInputs, AdviceProvider, ContextId, DefaultHost, ExecutionError, ExecutionOptions, ExecutionTrace, Process, ProcessState, StackInputs, VmStateIterator, }; -use processor::{MastForest, Program}; #[cfg(not(target_family = "wasm"))] use proptest::prelude::{Arbitrary, Strategy}; pub use prover::{prove, MemAdviceProvider, ProvingOptions}; @@ -231,7 +226,7 @@ impl Test { let 
(program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel); + host.load_mast_forest(kernel.mast_forest().clone()); } for library in &self.libraries { host.load_mast_forest(library.mast_forest().clone()); @@ -283,22 +278,26 @@ impl Test { // UTILITY METHODS // -------------------------------------------------------------------------------------------- - /// Compiles a test's source and returns the resulting Program or Assembly error. - pub fn compile(&self) -> Result<(Program, Option), Report> { + /// Compiles a test's source and returns the resulting Program together with the associated + /// kernel library (when specified). + /// + /// # Errors + /// Returns an error if the program source fails to assemble; kernel assembly failures panic. + pub fn compile(&self) -> Result<(Program, Option), Report> { use assembly::{ast::ModuleKind, Assembler, CompileOptions}; - let (assembler, compiled_kernel) = if let Some(kernel) = self.kernel_source.clone() { + let (assembler, kernel_lib) = if let Some(kernel) = self.kernel_source.clone() { let kernel_lib = Assembler::new(self.source_manager.clone()).assemble_kernel(kernel).unwrap(); - let compiled_kernel = kernel_lib.mast_forest().clone(); ( - Assembler::with_kernel(self.source_manager.clone(), kernel_lib), - Some(compiled_kernel), + Assembler::with_kernel(self.source_manager.clone(), kernel_lib.clone()), + Some(kernel_lib), ) } else { (Assembler::new(self.source_manager.clone()), None) }; + let mut assembler = self .add_modules .iter() @@ -315,7 +314,7 @@ impl Test { assembler.add_library(library).unwrap(); } - Ok((assembler.assemble_program(self.source.clone())?, compiled_kernel)) + Ok((assembler.assemble_program(self.source.clone())?, kernel_lib)) } /// Compiles the test's source to a Program and executes it with the tests inputs.
Returns a @@ -325,7 +324,7 @@ impl Test { let (program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel); + host.load_mast_forest(kernel.mast_forest().clone()); } for library in &self.libraries { host.load_mast_forest(library.mast_forest().clone()); @@ -341,7 +340,7 @@ impl Test { let (program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel); + host.load_mast_forest(kernel.mast_forest().clone()); } for library in &self.libraries { host.load_mast_forest(library.mast_forest().clone()); @@ -365,7 +364,7 @@ impl Test { let (program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel); + host.load_mast_forest(kernel.mast_forest().clone()); } for library in &self.libraries { host.load_mast_forest(library.mast_forest().clone()); @@ -390,7 +389,7 @@ impl Test { let (program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel); + host.load_mast_forest(kernel.mast_forest().clone()); } for library in &self.libraries { host.load_mast_forest(library.mast_forest().clone());