diff --git a/CHANGELOG.md b/CHANGELOG.md index 23c149e88c..0f27e0ed8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Added error codes support for the `mtree_verify` instruction (#1328). - Added support for immediate values for `lt`, `lte`, `gt`, `gte` comparison instructions (#1346). - Change MAST to a table-based representation (#1349) +- Introduce `MastForestStore` (#1359) - Adjusted prover's metal acceleration code to work with 0.9 versions of the crates (#1357) - Added support for immediate values for `u32lt`, `u32lte`, `u32gt`, `u32gte`, `u32min` and `u32max` comparison instructions (#1358). - Added support for the `nop` instruction, which corresponds to the VM opcode of the same name, and has the same semantics. This is implemented for use by compilers primarily. diff --git a/assembly/src/assembler/basic_block_builder.rs b/assembly/src/assembler/basic_block_builder.rs index 24889ef3b7..dfbdbd84c2 100644 --- a/assembly/src/assembler/basic_block_builder.rs +++ b/assembly/src/assembler/basic_block_builder.rs @@ -1,7 +1,10 @@ -use super::{AssemblyContext, BodyWrapper, Decorator, DecoratorList, Instruction}; +use super::{ + mast_forest_builder::MastForestBuilder, AssemblyContext, BodyWrapper, Decorator, DecoratorList, + Instruction, +}; use alloc::{borrow::Borrow, string::ToString, vec::Vec}; use vm_core::{ - mast::{MastForest, MastNode, MastNodeId}, + mast::{MastNode, MastNodeId}, AdviceInjector, AssemblyOp, Operation, }; @@ -123,13 +126,16 @@ impl BasicBlockBuilder { /// /// This consumes all operations and decorators in the builder, but does not touch the /// operations in the epilogue of the builder. 
- pub fn make_basic_block(&mut self, mast_forest: &mut MastForest) -> Option { + pub fn make_basic_block( + &mut self, + mast_forest_builder: &mut MastForestBuilder, + ) -> Option { if !self.ops.is_empty() { let ops = self.ops.drain(..).collect(); let decorators = self.decorators.drain(..).collect(); let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators); - let basic_block_node_id = mast_forest.ensure_node(basic_block_node); + let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node); Some(basic_block_node_id) } else if !self.decorators.is_empty() { @@ -149,8 +155,11 @@ impl BasicBlockBuilder { /// - Operations contained in the epilogue of the builder are appended to the list of ops which /// go into the new BASIC BLOCK node. /// - The builder is consumed in the process. - pub fn into_basic_block(mut self, mast_forest: &mut MastForest) -> Option { + pub fn into_basic_block( + mut self, + mast_forest_builder: &mut MastForestBuilder, + ) -> Option { self.ops.append(&mut self.epilogue); - self.make_basic_block(mast_forest) + self.make_basic_block(mast_forest_builder) } } diff --git a/assembly/src/assembler/instruction/mod.rs b/assembly/src/assembler/instruction/mod.rs index 71619bbecd..48f01ce203 100644 --- a/assembly/src/assembler/instruction/mod.rs +++ b/assembly/src/assembler/instruction/mod.rs @@ -1,13 +1,10 @@ use super::{ - ast::InvokeKind, Assembler, AssemblyContext, BasicBlockBuilder, Felt, Instruction, Operation, - ONE, ZERO, + ast::InvokeKind, mast_forest_builder::MastForestBuilder, Assembler, AssemblyContext, + BasicBlockBuilder, Felt, Instruction, Operation, ONE, ZERO, }; use crate::{diagnostics::Report, utils::bound_into_included_u64, AssemblyError}; use core::ops::RangeBounds; -use vm_core::{ - mast::{MastForest, MastNodeId}, - Decorator, -}; +use vm_core::{mast::MastNodeId, Decorator}; mod adv_ops; mod crypto_ops; @@ -27,7 +24,7 @@ impl Assembler { instruction: &Instruction, span_builder: &mut 
BasicBlockBuilder, ctx: &mut AssemblyContext, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { // if the assembler is in debug mode, start tracking the instruction about to be executed; // this will allow us to map the instruction to the sequence of operations which were @@ -36,7 +33,8 @@ impl Assembler { span_builder.track_instruction(instruction, ctx); } - let result = self.compile_instruction_impl(instruction, span_builder, ctx, mast_forest)?; + let result = + self.compile_instruction_impl(instruction, span_builder, ctx, mast_forest_builder)?; // compute and update the cycle count of the instruction which just finished executing if self.in_debug_mode() { @@ -51,7 +49,7 @@ impl Assembler { instruction: &Instruction, span_builder: &mut BasicBlockBuilder, ctx: &mut AssemblyContext, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { use Operation::*; @@ -376,18 +374,18 @@ impl Assembler { // ----- exec/call instructions ------------------------------------------------------- Instruction::Exec(ref callee) => { - return self.invoke(InvokeKind::Exec, callee, ctx, mast_forest) + return self.invoke(InvokeKind::Exec, callee, ctx, mast_forest_builder) } Instruction::Call(ref callee) => { - return self.invoke(InvokeKind::Call, callee, ctx, mast_forest) + return self.invoke(InvokeKind::Call, callee, ctx, mast_forest_builder) } Instruction::SysCall(ref callee) => { - return self.invoke(InvokeKind::SysCall, callee, ctx, mast_forest) + return self.invoke(InvokeKind::SysCall, callee, ctx, mast_forest_builder) } - Instruction::DynExec => return self.dynexec(mast_forest), - Instruction::DynCall => return self.dyncall(mast_forest), + Instruction::DynExec => return self.dynexec(mast_forest_builder), + Instruction::DynCall => return self.dyncall(mast_forest_builder), Instruction::ProcRef(ref callee) => { - self.procref(callee, ctx, span_builder, mast_forest)? 
+ self.procref(callee, ctx, span_builder, mast_forest_builder.forest())? } // ----- debug decorators ------------------------------------------------------------- diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 9ac2bcfdd3..9894a82092 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -1,5 +1,6 @@ use super::{Assembler, AssemblyContext, BasicBlockBuilder, Operation}; use crate::{ + assembler::mast_forest_builder::MastForestBuilder, ast::{InvocationTarget, InvokeKind}, AssemblyError, RpoDigest, SourceSpan, Span, Spanned, }; @@ -14,11 +15,11 @@ impl Assembler { kind: InvokeKind, callee: &InvocationTarget, context: &mut AssemblyContext, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { let span = callee.span(); - let digest = self.resolve_target(kind, callee, context, mast_forest)?; - self.invoke_mast_root(kind, span, digest, context, mast_forest) + let digest = self.resolve_target(kind, callee, context, mast_forest_builder.forest())?; + self.invoke_mast_root(kind, span, digest, context, mast_forest_builder) } fn invoke_mast_root( @@ -27,7 +28,7 @@ impl Assembler { span: SourceSpan, mast_root: RpoDigest, context: &mut AssemblyContext, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { // Get the procedure from the assembler let cache = &self.procedure_cache; @@ -68,13 +69,15 @@ impl Assembler { }) } })?; - context.register_external_call(&proc, false, mast_forest)?; + context.register_external_call(&proc, false, mast_forest_builder.forest())?; + } + Some(proc) => { + context.register_external_call(&proc, false, mast_forest_builder.forest())? 
} - Some(proc) => context.register_external_call(&proc, false, mast_forest)?, None if matches!(kind, InvokeKind::SysCall) => { return Err(AssemblyError::UnknownSysCallTarget { span, - source_file: current_source_file, + source_file: current_source_file.clone(), callee: mast_root, }); } @@ -82,28 +85,43 @@ impl Assembler { } let mast_root_node_id = { - // Note that here we rely on the fact that we topologically sorted the procedures, such - // that when we assemble a procedure, all procedures that it calls will have been - // assembled, and hence be present in the `MastForest`. We currently assume that the - // `MastForest` contains all the procedures being called; "external procedures" only - // known by digest are not currently supported. - let callee_id = mast_forest - .get_node_id_by_digest(mast_root) - .unwrap_or_else(|| panic!("MAST root {} not in MAST forest", mast_root)); - match kind { - // For `exec`, we return the root of the procedure being exec'd, which has the - // effect of inlining it - InvokeKind::Exec => callee_id, - // For `call`, we just use the corresponding CALL block + InvokeKind::Exec => { + // Note that here we rely on the fact that we topologically sorted the + // procedures, such that when we assemble a procedure, all + // procedures that it calls will have been assembled, and + // hence be present in the `MastForest`. + mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| { + // If the MAST root called isn't known to us, make it an external + // reference. + let external_node = MastNode::new_external(mast_root); + mast_forest_builder.ensure_node(external_node) + }) + } InvokeKind::Call => { - let node = MastNode::new_call(callee_id, mast_forest); - mast_forest.ensure_node(node) + let callee_id = + mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| { + // If the MAST root called isn't known to us, make it an external + // reference. 
+ let external_node = MastNode::new_external(mast_root); + mast_forest_builder.ensure_node(external_node) + }); + + let call_node = MastNode::new_call(callee_id, mast_forest_builder.forest()); + mast_forest_builder.ensure_node(call_node) } - // For `syscall`, we just use the corresponding SYSCALL block InvokeKind::SysCall => { - let node = MastNode::new_syscall(callee_id, mast_forest); - mast_forest.ensure_node(node) + let callee_id = + mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| { + // If the MAST root called isn't known to us, make it an external + // reference. + let external_node = MastNode::new_external(mast_root); + mast_forest_builder.ensure_node(external_node) + }); + + let syscall_node = + MastNode::new_syscall(callee_id, mast_forest_builder.forest()); + mast_forest_builder.ensure_node(syscall_node) } } }; @@ -114,9 +132,9 @@ impl Assembler { /// Creates a new DYN block for the dynamic code execution and return. pub(super) fn dynexec( &self, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { - let dyn_node_id = mast_forest.ensure_node(MastNode::Dyn); + let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn); Ok(Some(dyn_node_id)) } @@ -124,13 +142,13 @@ impl Assembler { /// Creates a new CALL block whose target is DYN. 
pub(super) fn dyncall( &self, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { let dyn_call_node_id = { - let dyn_node_id = mast_forest.ensure_node(MastNode::Dyn); - let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest); + let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn); + let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest_builder.forest()); - mast_forest.ensure_node(dyn_call_node) + mast_forest_builder.ensure_node(dyn_call_node) }; Ok(Some(dyn_call_node_id)) diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs new file mode 100644 index 0000000000..39c42c388b --- /dev/null +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -0,0 +1,74 @@ +use core::ops::Index; + +use alloc::collections::BTreeMap; +use vm_core::{ + crypto::hash::RpoDigest, + mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, +}; + +/// Builder for a [`MastForest`]. +#[derive(Clone, Debug, Default)] +pub struct MastForestBuilder { + mast_forest: MastForest, + node_id_by_hash: BTreeMap, +} + +impl MastForestBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn build(self) -> MastForest { + self.mast_forest + } +} + +/// Accessors +impl MastForestBuilder { + /// Returns the underlying [`MastForest`] being built + pub fn forest(&self) -> &MastForest { + &self.mast_forest + } + + /// Returns the [`MastNodeId`] of the procedure associated with a given digest, if any. + #[inline(always)] + pub fn find_procedure_root(&self, digest: RpoDigest) -> Option { + self.mast_forest.find_procedure_root(digest) + } +} + +/// Mutators +impl MastForestBuilder { + /// Adds a node to the forest, and returns the [`MastNodeId`] associated with it. + /// + /// If a [`MastNode`] which is equal to the current node was previously added, the previously + /// returned [`MastNodeId`] will be returned. 
This enforces the invariant that equal + /// [`MastNode`]s have equal [`MastNodeId`]s. + pub fn ensure_node(&mut self, node: MastNode) -> MastNodeId { + let node_digest = node.digest(); + + if let Some(node_id) = self.node_id_by_hash.get(&node_digest) { + // node already exists in the forest; return previously assigned id + *node_id + } else { + let new_node_id = self.mast_forest.add_node(node); + self.node_id_by_hash.insert(node_digest, new_node_id); + + new_node_id + } + } + + /// Marks the given [`MastNodeId`] as being the root of a procedure. + pub fn make_root(&mut self, new_root_id: MastNodeId) { + self.mast_forest.make_root(new_root_id) + } +} + +impl Index for MastForestBuilder { + type Output = MastNode; + + #[inline(always)] + fn index(&self, node_id: MastNodeId) -> &Self::Output { + &self.mast_forest[node_id] + } +} diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 85a7d29907..a63f7ffb25 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -9,7 +9,7 @@ use crate::{ RpoDigest, Spanned, ONE, ZERO, }; use alloc::{boxed::Box, sync::Arc, vec::Vec}; -use miette::miette; +use mast_forest_builder::MastForestBuilder; use vm_core::{ mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, Decorator, DecoratorList, Kernel, Operation, Program, @@ -19,6 +19,7 @@ mod basic_block_builder; mod context; mod id; mod instruction; +mod mast_forest_builder; mod module_graph; mod procedure; #[cfg(test)] @@ -87,7 +88,7 @@ pub enum ArtifactKind { /// [Assembler::compile] or [Assembler::compile_ast] to get your compiled program. #[derive(Clone)] pub struct Assembler { - mast_forest: MastForest, + mast_forest_builder: MastForestBuilder, /// The global [ModuleGraph] for this assembler. All new [AssemblyContext]s inherit this graph /// as a baseline. 
module_graph: Box, @@ -104,7 +105,7 @@ pub struct Assembler { impl Default for Assembler { fn default() -> Self { Self { - mast_forest: Default::default(), + mast_forest_builder: Default::default(), module_graph: Default::default(), procedure_cache: Default::default(), warnings_as_errors: false, @@ -121,13 +122,11 @@ impl Assembler { Self::default() } - /// Start building an [`Assembler`] with the given [`Kernel`] and the [`MastForest`] that was - /// used to compile the kernel. - pub fn with_kernel(kernel: Kernel, mast_forest: MastForest) -> Self { + /// Start building an [`Assembler`] with the given [`Kernel`]. + pub fn with_kernel(kernel: Kernel) -> Self { let mut assembler = Self::new(); assembler.module_graph.set_kernel(None, kernel); - assembler.mast_forest = mast_forest; assembler } @@ -141,13 +140,13 @@ impl Assembler { let opts = CompileOptions::for_kernel(); let module = module.compile_with_options(opts)?; - let mut mast_forest = MastForest::new(); + let mut mast_forest_builder = MastForestBuilder::new(); - let (kernel_index, kernel) = assembler.assemble_kernel_module(module, &mut mast_forest)?; + let (kernel_index, kernel) = + assembler.assemble_kernel_module(module, &mut mast_forest_builder)?; assembler.module_graph.set_kernel(Some(kernel_index), kernel); - mast_forest.set_kernel(assembler.module_graph.kernel().clone()); - assembler.mast_forest = mast_forest; + assembler.mast_forest_builder = mast_forest_builder; Ok(assembler) } @@ -313,18 +312,6 @@ impl Assembler { /// Compilation/Assembly impl Assembler { - /// Compiles the provided module into a [`MastForest`]. - /// - /// # Errors - /// - /// Returns an error if parsing or compilation of the specified program fails. - pub fn assemble(self, source: impl Compile) -> Result { - let mut context = AssemblyContext::default(); - context.set_warnings_as_errors(self.warnings_as_errors); - - self.assemble_in_context(source, &mut context) - } - /// Compiles the provided module into a [`Program`]. 
The resulting program can be executed on /// Miden VM. /// @@ -332,10 +319,11 @@ impl Assembler { /// /// Returns an error if parsing or compilation of the specified program fails, or if the source /// doesn't have an entrypoint. - pub fn assemble_program(self, source: impl Compile) -> Result { - let mast_forest = self.assemble(source)?; + pub fn assemble(self, source: impl Compile) -> Result { + let mut context = AssemblyContext::default(); + context.set_warnings_as_errors(self.warnings_as_errors); - mast_forest.try_into().map_err(|program_err| miette!("{program_err}")) + self.assemble_in_context(source, &mut context) } /// Like [Assembler::compile], but also takes an [AssemblyContext] to configure the assembler. @@ -343,7 +331,7 @@ impl Assembler { self, source: impl Compile, context: &mut AssemblyContext, - ) -> Result { + ) -> Result { let opts = CompileOptions { warnings_as_errors: context.warnings_as_errors(), ..CompileOptions::default() @@ -363,7 +351,7 @@ impl Assembler { self, source: impl Compile, options: CompileOptions, - ) -> Result { + ) -> Result { let mut context = AssemblyContext::default(); context.set_warnings_as_errors(options.warnings_as_errors); @@ -378,7 +366,7 @@ impl Assembler { source: impl Compile, options: CompileOptions, context: &mut AssemblyContext, - ) -> Result { + ) -> Result { self.assemble_with_options_in_context_impl(source, options, context) } @@ -391,14 +379,14 @@ impl Assembler { source: impl Compile, options: CompileOptions, context: &mut AssemblyContext, - ) -> Result { + ) -> Result { if options.kind != ModuleKind::Executable { return Err(Report::msg( "invalid compile options: assemble_with_opts_in_context requires that the kind be 'executable'", )); } - let mut mast_forest = core::mem::take(&mut self.mast_forest); + let mast_forest_builder = core::mem::take(&mut self.mast_forest_builder); let program = source.compile_with_options(CompileOptions { // Override the module name so that we always compile the executable @@ 
-428,9 +416,7 @@ impl Assembler { }) .ok_or(SemanticAnalysisError::MissingEntrypoint)?; - self.compile_program(entrypoint, context, &mut mast_forest)?; - - Ok(mast_forest) + self.compile_program(entrypoint, context, mast_forest_builder) } /// Compile and assembles all procedures in the specified module, adding them to the procedure @@ -477,13 +463,14 @@ impl Assembler { let module_id = self.module_graph.add_module(module)?; self.module_graph.recompute()?; - let mut mast_forest = core::mem::take(&mut self.mast_forest); + let mut mast_forest_builder = core::mem::take(&mut self.mast_forest_builder); - self.assemble_graph(context, &mut mast_forest)?; - let exported_procedure_digests = self.get_module_exports(module_id, &mast_forest); + self.assemble_graph(context, &mut mast_forest_builder)?; + let exported_procedure_digests = + self.get_module_exports(module_id, mast_forest_builder.forest()); // Reassign the mast_forest to the assembler for use is a future program assembly - self.mast_forest = mast_forest; + self.mast_forest_builder = mast_forest_builder; exported_procedure_digests } @@ -493,7 +480,7 @@ impl Assembler { fn assemble_kernel_module( &mut self, module: Box, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result<(ModuleIndex, Kernel), Report> { if !module.is_kernel() { return Err(Report::msg(format!("expected kernel module, got {}", module.kind()))); @@ -515,8 +502,8 @@ impl Assembler { module: kernel_index, index: ProcedureIndex::new(index), }; - let compiled = self.compile_subgraph(gid, false, &mut context, mast_forest)?; - kernel.push(compiled.mast_root(mast_forest)); + let compiled = self.compile_subgraph(gid, false, &mut context, mast_forest_builder)?; + kernel.push(compiled.mast_root(mast_forest_builder.forest())); } Kernel::new(&kernel) @@ -600,17 +587,20 @@ impl Assembler { &mut self, entrypoint: GlobalProcedureIndex, context: &mut AssemblyContext, - mast_forest: &mut MastForest, - ) -> Result<(), Report> { + mut 
mast_forest_builder: MastForestBuilder, + ) -> Result { // Raise an error if we are called with an invalid entrypoint assert!(self.module_graph[entrypoint].name().is_main()); // Compile the module graph rooted at the entrypoint - let entry_procedure = self.compile_subgraph(entrypoint, true, context, mast_forest)?; - - mast_forest.set_entrypoint(entry_procedure.body_node_id()); + let entry_procedure = + self.compile_subgraph(entrypoint, true, context, &mut mast_forest_builder)?; - Ok(()) + Ok(Program::with_kernel( + mast_forest_builder.build(), + entry_procedure.body_node_id(), + self.module_graph.kernel().clone(), + )) } /// Compile all of the uncompiled procedures in the module graph, placing them @@ -620,11 +610,11 @@ impl Assembler { fn assemble_graph( &mut self, context: &mut AssemblyContext, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result<(), Report> { let mut worklist = self.module_graph.topological_sort().to_vec(); assert!(!worklist.is_empty()); - self.process_graph_worklist(&mut worklist, context, None, mast_forest) + self.process_graph_worklist(&mut worklist, context, None, mast_forest_builder) .map(|_| ()) } @@ -637,7 +627,7 @@ impl Assembler { root: GlobalProcedureIndex, is_entrypoint: bool, context: &mut AssemblyContext, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, Report> { let mut worklist = self.module_graph.topological_sort_from_root(root).map_err(|cycle| { let iter = cycle.into_node_ids(); @@ -653,9 +643,10 @@ impl Assembler { assert!(!worklist.is_empty()); let compiled = if is_entrypoint { - self.process_graph_worklist(&mut worklist, context, Some(root), mast_forest)? + self.process_graph_worklist(&mut worklist, context, Some(root), mast_forest_builder)? 
} else { - let _ = self.process_graph_worklist(&mut worklist, context, None, mast_forest)?; + let _ = + self.process_graph_worklist(&mut worklist, context, None, mast_forest_builder)?; self.procedure_cache.get(root) }; @@ -667,7 +658,7 @@ impl Assembler { worklist: &mut Vec, context: &mut AssemblyContext, entrypoint: Option, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result>, Report> { // Process the topological ordering in reverse order (bottom-up), so that // each procedure is compiled with all of its dependencies fully compiled @@ -675,8 +666,10 @@ impl Assembler { while let Some(procedure_gid) = worklist.pop() { // If we have already compiled this procedure, do not recompile if let Some(proc) = self.procedure_cache.get(procedure_gid) { - self.module_graph - .register_mast_root(procedure_gid, proc.mast_root(mast_forest))?; + self.module_graph.register_mast_root( + procedure_gid, + proc.mast_root(mast_forest_builder.forest()), + )?; continue; } let is_entry = entrypoint == Some(procedure_gid); @@ -696,17 +689,21 @@ impl Assembler { .with_source_file(ast.source_file()); // Compile this procedure - let procedure = self.compile_procedure(pctx, context, mast_forest)?; + let procedure = self.compile_procedure(pctx, context, mast_forest_builder)?; // Cache the compiled procedure, unless it's the program entrypoint if is_entry { compiled_entrypoint = Some(Arc::from(procedure)); } else { // Make the MAST root available to all dependents - let digest = procedure.mast_root(mast_forest); + let digest = procedure.mast_root(mast_forest_builder.forest()); self.module_graph.register_mast_root(procedure_gid, digest)?; - self.procedure_cache.insert(procedure_gid, Arc::from(procedure), mast_forest)?; + self.procedure_cache.insert( + procedure_gid, + Arc::from(procedure), + mast_forest_builder.forest(), + )?; } } @@ -718,7 +715,7 @@ impl Assembler { &self, procedure: ProcedureContext, context: &mut AssemblyContext, - mast_forest: &mut 
MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, Report> { // Make sure the current procedure context is available during codegen let gid = procedure.id(); @@ -726,7 +723,7 @@ impl Assembler { context.set_current_procedure(procedure); let proc = self.module_graph[gid].unwrap_procedure(); - let code = if num_locals > 0 { + let proc_body_root = if num_locals > 0 { // for procedures with locals, we need to update fmp register before and after the // procedure body is executed. specifically: // - to allocate procedure locals we need to increment fmp by the number of locals @@ -736,13 +733,15 @@ impl Assembler { prologue: vec![Operation::Push(num_locals), Operation::FmpUpdate], epilogue: vec![Operation::Push(-num_locals), Operation::FmpUpdate], }; - self.compile_body(proc.iter(), context, Some(wrapper), mast_forest)? + self.compile_body(proc.iter(), context, Some(wrapper), mast_forest_builder)? } else { - self.compile_body(proc.iter(), context, None, mast_forest)? + self.compile_body(proc.iter(), context, None, mast_forest_builder)? }; + mast_forest_builder.make_root(proc_body_root); + let pctx = context.take_current_procedure().unwrap(); - Ok(pctx.into_procedure(code)) + Ok(pctx.into_procedure(proc_body_root)) } fn compile_body<'a, I>( @@ -750,7 +749,7 @@ impl Assembler { body: I, context: &mut AssemblyContext, wrapper: Option, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result where I: Iterator, @@ -767,10 +766,10 @@ impl Assembler { inst, &mut basic_block_builder, context, - mast_forest, + mast_forest_builder, )? { if let Some(basic_block_id) = - basic_block_builder.make_basic_block(mast_forest) + basic_block_builder.make_basic_block(mast_forest_builder) { mast_node_ids.push(basic_block_id); } @@ -782,32 +781,35 @@ impl Assembler { Op::If { then_blk, else_blk, .. 
} => { - if let Some(basic_block_id) = basic_block_builder.make_basic_block(mast_forest) + if let Some(basic_block_id) = + basic_block_builder.make_basic_block(mast_forest_builder) { mast_node_ids.push(basic_block_id); } let then_blk = - self.compile_body(then_blk.iter(), context, None, mast_forest)?; + self.compile_body(then_blk.iter(), context, None, mast_forest_builder)?; let else_blk = - self.compile_body(else_blk.iter(), context, None, mast_forest)?; + self.compile_body(else_blk.iter(), context, None, mast_forest_builder)?; let split_node_id = { - let split_node = MastNode::new_split(then_blk, else_blk, mast_forest); + let split_node = + MastNode::new_split(then_blk, else_blk, mast_forest_builder.forest()); - mast_forest.ensure_node(split_node) + mast_forest_builder.ensure_node(split_node) }; mast_node_ids.push(split_node_id); } Op::Repeat { count, body, .. } => { - if let Some(basic_block_id) = basic_block_builder.make_basic_block(mast_forest) + if let Some(basic_block_id) = + basic_block_builder.make_basic_block(mast_forest_builder) { mast_node_ids.push(basic_block_id); } let repeat_node_id = - self.compile_body(body.iter(), context, None, mast_forest)?; + self.compile_body(body.iter(), context, None, mast_forest_builder)?; for _ in 0..*count { mast_node_ids.push(repeat_node_id); @@ -815,32 +817,34 @@ impl Assembler { } Op::While { body, .. 
} => { - if let Some(basic_block_id) = basic_block_builder.make_basic_block(mast_forest) + if let Some(basic_block_id) = + basic_block_builder.make_basic_block(mast_forest_builder) { mast_node_ids.push(basic_block_id); } let loop_body_node_id = - self.compile_body(body.iter(), context, None, mast_forest)?; + self.compile_body(body.iter(), context, None, mast_forest_builder)?; let loop_node_id = { - let loop_node = MastNode::new_loop(loop_body_node_id, mast_forest); - mast_forest.ensure_node(loop_node) + let loop_node = + MastNode::new_loop(loop_body_node_id, mast_forest_builder.forest()); + mast_forest_builder.ensure_node(loop_node) }; mast_node_ids.push(loop_node_id); } } } - if let Some(basic_block_id) = basic_block_builder.into_basic_block(mast_forest) { + if let Some(basic_block_id) = basic_block_builder.into_basic_block(mast_forest_builder) { mast_node_ids.push(basic_block_id); } Ok(if mast_node_ids.is_empty() { let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]); - mast_forest.ensure_node(basic_block_node) + mast_forest_builder.ensure_node(basic_block_node) } else { - combine_mast_node_ids(mast_node_ids, mast_forest) + combine_mast_node_ids(mast_node_ids, mast_forest_builder) }) } @@ -879,7 +883,7 @@ struct BodyWrapper { fn combine_mast_node_ids( mut mast_node_ids: Vec, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> MastNodeId { debug_assert!(!mast_node_ids.is_empty(), "cannot combine empty MAST node id list"); @@ -898,8 +902,8 @@ fn combine_mast_node_ids( while let (Some(left), Some(right)) = (source_mast_node_iter.next(), source_mast_node_iter.next()) { - let join_mast_node = MastNode::new_join(left, right, mast_forest); - let join_mast_node_id = mast_forest.ensure_node(join_mast_node); + let join_mast_node = MastNode::new_join(left, right, mast_forest_builder.forest()); + let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node); mast_node_ids.push(join_mast_node_id); } diff --git 
a/assembly/src/assembler/module_graph/name_resolver.rs b/assembly/src/assembler/module_graph/name_resolver.rs index d23ea7ea10..6070a52e18 100644 --- a/assembly/src/assembler/module_graph/name_resolver.rs +++ b/assembly/src/assembler/module_graph/name_resolver.rs @@ -395,7 +395,8 @@ impl<'a> NameResolver<'a> { if let Some(id) = self.graph.get_procedure_index_by_digest(digest) { break Ok(id); } - // This is a phantom procedure - we know its root, but do not have its definition + // This is a phantom procedure - we know its root, but do not have its + // definition break Err(AssemblyError::Failed { labels: vec![ RelatedLabel::error("undefined procedure") diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index 90649b715e..ee68d98b27 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -1,12 +1,14 @@ use alloc::{boxed::Box, vec::Vec}; +use pretty_assertions::assert_eq; use vm_core::{ assert_matches, - mast::{MastForest, MastNode, MerkleTreeNode}, + mast::{MastForest, MastNode}, + Program, }; use super::{Assembler, Library, Operation}; use crate::{ - assembler::combine_mast_node_ids, + assembler::{combine_mast_node_ids, mast_forest_builder::MastForestBuilder}, ast::{Module, ModuleKind}, LibraryNamespace, Version, }; @@ -71,7 +73,7 @@ fn nested_blocks() { .unwrap(); // The expected `MastForest` for the program (that we will build by hand) - let mut expected_mast_forest = MastForest::new(); + let mut expected_mast_forest_builder = MastForestBuilder::new(); // fetch the kernel digest and store into a syscall block // @@ -80,10 +82,11 @@ fn nested_blocks() { // `Assembler::with_kernel_from_module()`. 
let syscall_foo_node_id = { let kernel_foo_node = MastNode::new_basic_block(vec![Operation::Add]); - let kernel_foo_node_id = expected_mast_forest.ensure_node(kernel_foo_node); + let kernel_foo_node_id = expected_mast_forest_builder.ensure_node(kernel_foo_node); - let syscall_node = MastNode::new_syscall(kernel_foo_node_id, &expected_mast_forest); - expected_mast_forest.ensure_node(syscall_node) + let syscall_node = + MastNode::new_syscall(kernel_foo_node_id, expected_mast_forest_builder.forest()); + expected_mast_forest_builder.ensure_node(syscall_node) }; let program = r#" @@ -127,92 +130,183 @@ fn nested_blocks() { let exec_bar_node_id = { // bar procedure let basic_block_1 = MastNode::new_basic_block(vec![Operation::Push(17_u32.into())]); - let basic_block_1_id = expected_mast_forest.ensure_node(basic_block_1); + let basic_block_1_id = expected_mast_forest_builder.ensure_node(basic_block_1); // Basic block representing the `foo` procedure let basic_block_2 = MastNode::new_basic_block(vec![Operation::Push(19_u32.into())]); - let basic_block_2_id = expected_mast_forest.ensure_node(basic_block_2); - - let join_node = - MastNode::new_join(basic_block_1_id, basic_block_2_id, &expected_mast_forest); - expected_mast_forest.ensure_node(join_node) + let basic_block_2_id = expected_mast_forest_builder.ensure_node(basic_block_2); + + let join_node = MastNode::new_join( + basic_block_1_id, + basic_block_2_id, + expected_mast_forest_builder.forest(), + ); + expected_mast_forest_builder.ensure_node(join_node) }; let exec_foo_bar_baz_node_id = { // basic block representing foo::bar.baz procedure let basic_block = MastNode::new_basic_block(vec![Operation::Push(29_u32.into())]); - expected_mast_forest.ensure_node(basic_block) + expected_mast_forest_builder.ensure_node(basic_block) }; let before = { let before_node = MastNode::new_basic_block(vec![Operation::Push(2u32.into())]); - expected_mast_forest.ensure_node(before_node) + 
expected_mast_forest_builder.ensure_node(before_node) }; let r#true1 = { let r#true_node = MastNode::new_basic_block(vec![Operation::Push(3u32.into())]); - expected_mast_forest.ensure_node(r#true_node) + expected_mast_forest_builder.ensure_node(r#true_node) }; let r#false1 = { let r#false_node = MastNode::new_basic_block(vec![Operation::Push(5u32.into())]); - expected_mast_forest.ensure_node(r#false_node) + expected_mast_forest_builder.ensure_node(r#false_node) }; let r#if1 = { - let r#if_node = MastNode::new_split(r#true1, r#false1, &expected_mast_forest); - expected_mast_forest.ensure_node(r#if_node) + let r#if_node = + MastNode::new_split(r#true1, r#false1, expected_mast_forest_builder.forest()); + expected_mast_forest_builder.ensure_node(r#if_node) }; let r#true3 = { let r#true_node = MastNode::new_basic_block(vec![Operation::Push(7u32.into())]); - expected_mast_forest.ensure_node(r#true_node) + expected_mast_forest_builder.ensure_node(r#true_node) }; let r#false3 = { let r#false_node = MastNode::new_basic_block(vec![Operation::Push(11u32.into())]); - expected_mast_forest.ensure_node(r#false_node) + expected_mast_forest_builder.ensure_node(r#false_node) }; let r#true2 = { - let r#if_node = MastNode::new_split(r#true3, r#false3, &expected_mast_forest); - expected_mast_forest.ensure_node(r#if_node) + let r#if_node = + MastNode::new_split(r#true3, r#false3, expected_mast_forest_builder.forest()); + expected_mast_forest_builder.ensure_node(r#if_node) }; let r#while = { let push_basic_block_id = { let push_basic_block = MastNode::new_basic_block(vec![Operation::Push(23u32.into())]); - expected_mast_forest.ensure_node(push_basic_block) + expected_mast_forest_builder.ensure_node(push_basic_block) }; let body_node_id = { - let body_node = - MastNode::new_join(exec_bar_node_id, push_basic_block_id, &expected_mast_forest); + let body_node = MastNode::new_join( + exec_bar_node_id, + push_basic_block_id, + expected_mast_forest_builder.forest(), + ); - 
expected_mast_forest.ensure_node(body_node) + expected_mast_forest_builder.ensure_node(body_node) }; - let loop_node = MastNode::new_loop(body_node_id, &expected_mast_forest); - expected_mast_forest.ensure_node(loop_node) + let loop_node = MastNode::new_loop(body_node_id, expected_mast_forest_builder.forest()); + expected_mast_forest_builder.ensure_node(loop_node) }; let push_13_basic_block_id = { let node = MastNode::new_basic_block(vec![Operation::Push(13u32.into())]); - expected_mast_forest.ensure_node(node) + expected_mast_forest_builder.ensure_node(node) }; let r#false2 = { - let node = MastNode::new_join(push_13_basic_block_id, r#while, &expected_mast_forest); - expected_mast_forest.ensure_node(node) + let node = MastNode::new_join( + push_13_basic_block_id, + r#while, + expected_mast_forest_builder.forest(), + ); + expected_mast_forest_builder.ensure_node(node) }; let nested = { - let node = MastNode::new_split(r#true2, r#false2, &expected_mast_forest); - expected_mast_forest.ensure_node(node) + let node = MastNode::new_split(r#true2, r#false2, expected_mast_forest_builder.forest()); + expected_mast_forest_builder.ensure_node(node) }; let combined_node_id = combine_mast_node_ids( vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id], - &mut expected_mast_forest, + &mut expected_mast_forest_builder, ); - expected_mast_forest.set_entrypoint(combined_node_id); - let combined_node = &expected_mast_forest[combined_node_id]; + let expected_program = Program::new(expected_mast_forest_builder.build(), combined_node_id); + assert_eq!(expected_program.hash(), program.hash()); + + // also check that the program has the right number of procedures + assert_eq!(program.num_procedures(), 5); +} + +/// Ensures that only a single copy of procedures sharing the same MAST root is added to the MAST +/// forest. 
+#[test] +fn duplicate_procedure() { + let assembler = Assembler::new(); + + let program_source = r#" + proc.foo + add + mul + end + + proc.bar + add + mul + end + + begin + # specific impl irrelevant + exec.foo + exec.bar + end + "#; + + let program = assembler.assemble(program_source).unwrap(); + assert_eq!(program.num_procedures(), 2); +} + +/// Ensures that equal MAST nodes don't get added twice to a MAST forest +#[test] +fn duplicate_nodes() { + let assembler = Assembler::new(); + + let program_source = r#" + begin + if.true + mul + else + if.true add else mul end + end + end + "#; + + let program = assembler.assemble(program_source).unwrap(); + + let mut expected_mast_forest = MastForest::new(); + + // basic block: mul + let mul_basic_block_id = { + let node = MastNode::new_basic_block(vec![Operation::Mul]); + expected_mast_forest.add_node(node) + }; + + // basic block: add + let add_basic_block_id = { + let node = MastNode::new_basic_block(vec![Operation::Add]); + expected_mast_forest.add_node(node) + }; + + // inner split: `if.true add else mul end` + let inner_split_id = { + let node = + MastNode::new_split(add_basic_block_id, mul_basic_block_id, &expected_mast_forest); + expected_mast_forest.add_node(node) + }; + + // root: outer split + let root_id = { + let node = MastNode::new_split(mul_basic_block_id, inner_split_id, &expected_mast_forest); + expected_mast_forest.add_node(node) + }; + expected_mast_forest.make_root(root_id); + + let expected_program = Program::new(expected_mast_forest, root_id); - assert_eq!(combined_node.digest(), program.entrypoint_digest().unwrap()); + assert_eq!(program, expected_program); } #[test] diff --git a/assembly/src/testing.rs b/assembly/src/testing.rs index 4bc3e20310..a9b99051a8 100644 --- a/assembly/src/testing.rs +++ b/assembly/src/testing.rs @@ -308,10 +308,7 @@ impl TestContext { /// module represented in `source`. 
#[track_caller] pub fn assemble(&mut self, source: impl Compile) -> Result { - self.assembler - .clone() - .assemble(source) - .map(|mast_forest| mast_forest.try_into().unwrap()) + self.assembler.clone().assemble(source) } /// Compile a module from `source`, with the fully-qualified name `path`, to MAST, returning diff --git a/core/src/errors.rs b/core/src/errors.rs index a3c01446c6..5e4d0428e1 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -40,12 +40,3 @@ pub enum KernelError { #[error("kernel can have at most {0} procedures, received {1}")] TooManyProcedures(usize, usize), } - -// PROGRAM ERROR -// ================================================================================================ - -#[derive(Clone, Debug, thiserror::Error)] -pub enum ProgramError { - #[error("tried to create a program from a MAST forest with no entrypoint")] - NoEntrypoint, -} diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 070e36ef40..5371d4c35f 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -1,6 +1,6 @@ use core::{fmt, ops::Index}; -use alloc::{collections::BTreeMap, vec::Vec}; +use alloc::vec::Vec; use miden_crypto::hash::rpo::RpoDigest; mod node; @@ -9,8 +9,6 @@ pub use node::{ OpBatch, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, }; -use crate::Kernel; - #[cfg(test)] mod tests; @@ -23,9 +21,9 @@ pub trait MerkleTreeNode { /// An opaque handle to a [`MastNode`] in some [`MastForest`]. It is the responsibility of the user /// to use a given [`MastNodeId`] with the corresponding [`MastForest`]. /// -/// Note that since a [`MastForest`] enforces the invariant that equal [`MastNode`]s MUST have equal -/// [`MastNodeId`]s, [`MastNodeId`] equality can be used to determine equality of the underlying -/// [`MastNode`]. +/// Note that the [`MastForest`] does *not* ensure that equal [`MastNode`]s have equal +/// [`MastNodeId`] handles. Hence, [`MastNodeId`] equality must not be used to test for equality of +/// the underlying [`MastNode`]. 
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct MastNodeId(u32); @@ -38,18 +36,17 @@ impl fmt::Display for MastNodeId { // MAST FOREST // =============================================================================================== -#[derive(Clone, Debug, Default)] +/// Represents one or more procedures, represented as a collection of [`MastNode`]s. +/// +/// A [`MastForest`] does not have an entrypoint, and hence is not executable. A [`crate::Program`] +/// can be built from a [`MastForest`] to specify an entrypoint. +#[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct MastForest { - /// All of the blocks local to the trees comprising the MAST forest + /// All of the nodes local to the trees comprising the MAST forest. nodes: Vec, - node_id_by_hash: BTreeMap, - - /// The "entrypoint", when set, is the root of the entire forest, i.e. a path exists from this - /// node to all other roots in the forest. This corresponds to the executable entry point. - /// Whether or not the entrypoint is set distinguishes a MAST which is executable, versus a - /// MAST which represents a library. - entrypoint: Option, - kernel: Kernel, + + /// Roots of procedures defined within this MAST forest. + roots: Vec, } /// Constructors @@ -62,66 +59,38 @@ impl MastForest { /// Mutators impl MastForest { - /// Adds a node to the forest, and returns the [`MastNodeId`] associated with it. + /// Adds a node to the forest, and returns the associated [`MastNodeId`]. /// - /// If a [`MastNode`] which is equal to the current node was previously added, the previously - /// returned [`MastNodeId`] will be returned. This enforces this invariant that equal - /// [`MastNode`]s have equal [`MastNodeId`]s. 
- pub fn ensure_node(&mut self, node: MastNode) -> MastNodeId { - let node_digest = node.digest(); - - if let Some(node_id) = self.node_id_by_hash.get(&node_digest) { - // node already exists in the forest; return previously assigned id - *node_id - } else { - let new_node_id = - MastNodeId(self.nodes.len().try_into().expect( - "invalid node id: exceeded maximum number of nodes in a single forest", - )); - - self.node_id_by_hash.insert(node_digest, new_node_id); - self.nodes.push(node); - - new_node_id - } + /// Adding two duplicate nodes will result in two distinct returned [`MastNodeId`]s. + pub fn add_node(&mut self, node: MastNode) -> MastNodeId { + let new_node_id = MastNodeId( + self.nodes + .len() + .try_into() + .expect("invalid node id: exceeded maximum number of nodes in a single forest"), + ); + + self.nodes.push(node); + + new_node_id } - /// Sets the kernel for this forest. + /// Marks the given [`MastNodeId`] as being the root of a procedure. /// - /// The kernel MUST have been compiled using this [`MastForest`]; that is, all kernel procedures - /// must be present in this forest. - pub fn set_kernel(&mut self, kernel: Kernel) { - #[cfg(debug_assertions)] - for proc_hash in kernel.proc_hashes() { - assert!(self.node_id_by_hash.contains_key(proc_hash)); + /// # Panics + /// - if `new_root_id`'s internal index is greater than or equal to the number of nodes in this + /// forest (i.e. it clearly doesn't belong to this MAST forest). + pub fn make_root(&mut self, new_root_id: MastNodeId) { + assert!((new_root_id.0 as usize) < self.nodes.len()); + + if !self.roots.contains(&new_root_id) { + self.roots.push(new_root_id); } - - self.kernel = kernel; - } - - /// Sets the entrypoint for this forest. - pub fn set_entrypoint(&mut self, entrypoint: MastNodeId) { - self.entrypoint = Some(entrypoint); } } /// Public accessors impl MastForest { - /// Returns the kernel associated with this forest. 
- pub fn kernel(&self) -> &Kernel { - &self.kernel - } - - /// Returns the entrypoint associated with this forest, if any. - pub fn entrypoint(&self) -> Option { - self.entrypoint - } - - /// A convenience method that provides the hash of the entrypoint, if any. - pub fn entrypoint_digest(&self) -> Option { - self.entrypoint.map(|entrypoint| self[entrypoint].digest()) - } - /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if valid, or else /// `None`. /// @@ -133,13 +102,23 @@ impl MastForest { self.nodes.get(idx) } - /// Returns the [`MastNodeId`] associated with a given digest, if any. - /// - /// That is, every [`MastNode`] hashes to some digest. If there exists a [`MastNode`] in the - /// forest that hashes to this digest, then its id is returned. + /// Returns the [`MastNodeId`] of the procedure associated with a given digest, if any. #[inline(always)] - pub fn get_node_id_by_digest(&self, digest: RpoDigest) -> Option { - self.node_id_by_hash.get(&digest).copied() + pub fn find_procedure_root(&self, digest: RpoDigest) -> Option { + self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied() + } + + /// Returns an iterator over the digest of the procedures in this MAST forest. + pub fn procedure_roots(&self) -> impl Iterator + '_ { + self.roots.iter().map(|&root_id| self[root_id].digest()) + } + + /// Returns the number of procedures in this MAST forest. + pub fn num_procedures(&self) -> u32 { + self.roots + .len() + .try_into() + .expect("MAST forest contains more than 2^32 procedures.") } } diff --git a/core/src/mast/node/external.rs b/core/src/mast/node/external.rs new file mode 100644 index 0000000000..c0b8ff10a3 --- /dev/null +++ b/core/src/mast/node/external.rs @@ -0,0 +1,49 @@ +use crate::mast::{MastForest, MerkleTreeNode}; +use core::fmt; +use miden_crypto::hash::rpo::RpoDigest; + +/// Node for referencing procedures not present in a given [`MastForest`] (hence "external"). 
+/// +/// External nodes can be used to verify the integrity of a program's hash while keeping parts of +/// the program secret. They also allow a program to refer to a well-known procedure that was not +/// compiled with the program (e.g. a procedure in the standard library). +/// +/// The hash of an external node is the hash of the procedure it represents, such that an external +/// node can be swapped with the actual subtree that it represents without changing the MAST root. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ExternalNode { + digest: RpoDigest, +} + +impl ExternalNode { + /// Returns a new [`ExternalNode`] instantiated with the specified procedure hash. + pub fn new(procedure_hash: RpoDigest) -> Self { + Self { + digest: procedure_hash, + } + } +} + +impl MerkleTreeNode for ExternalNode { + fn digest(&self) -> RpoDigest { + self.digest + } + fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + self + } +} + +impl crate::prettier::PrettyPrint for ExternalNode { + fn render(&self) -> crate::prettier::Document { + use crate::prettier::*; + use miden_formatting::hex::ToHex; + const_text("external") + const_text(".") + text(self.digest.as_bytes().to_hex_with_prefix()) + } +} + +impl fmt::Display for ExternalNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use crate::prettier::PrettyPrint; + self.pretty_print(f) + } +} diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index 1fc8275194..2bf0836cf3 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -13,6 +13,9 @@ pub use call_node::CallNode; mod dyn_node; pub use dyn_node::DynNode; +mod external; +pub use external::ExternalNode; + mod join_node; pub use join_node::JoinNode; @@ -37,6 +40,7 @@ pub enum MastNode { Loop(LoopNode), Call(CallNode), Dyn, + External(ExternalNode), } /// Constructors @@ -87,6 +91,10 @@ impl MastNode { pub fn new_dyncall(dyn_node_id: MastNodeId, mast_forest: &MastForest) -> Self { 
Self::Call(CallNode::new(dyn_node_id, mast_forest)) } + + pub fn new_external(mast_root: RpoDigest) -> Self { + Self::External(ExternalNode::new(mast_root)) + } } /// Public accessors @@ -116,6 +124,7 @@ impl MastNode { MastNodePrettyPrint::new(Box::new(call_node.to_pretty_print(mast_forest))) } MastNode::Dyn => MastNodePrettyPrint::new(Box::new(DynNode)), + MastNode::External(external_node) => MastNodePrettyPrint::new(Box::new(external_node)), } } @@ -127,6 +136,7 @@ impl MastNode { MastNode::Loop(_) => LoopNode::DOMAIN, MastNode::Call(call_node) => call_node.domain(), MastNode::Dyn => DynNode::DOMAIN, + MastNode::External(_) => panic!("Can't fetch domain for an `External` node."), } } } @@ -140,6 +150,7 @@ impl MerkleTreeNode for MastNode { MastNode::Loop(node) => node.digest(), MastNode::Call(node) => node.digest(), MastNode::Dyn => DynNode.digest(), + MastNode::External(node) => node.digest(), } } @@ -151,6 +162,7 @@ impl MerkleTreeNode for MastNode { MastNode::Loop(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Call(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Dyn => MastNodeDisplay::new(DynNode.to_display(mast_forest)), + MastNode::External(node) => MastNodeDisplay::new(node.to_display(mast_forest)), } } } diff --git a/core/src/mast/tests.rs b/core/src/mast/tests.rs index 5c4e54e738..da43d1b5b4 100644 --- a/core/src/mast/tests.rs +++ b/core/src/mast/tests.rs @@ -1,7 +1,7 @@ use crate::{ chiplets::hasher, - mast::{DynNode, Kernel, MerkleTreeNode}, - ProgramInfo, Word, + mast::{DynNode, MerkleTreeNode}, + Kernel, ProgramInfo, Word, }; use alloc::vec::Vec; use miden_crypto::{hash::rpo::RpoDigest, Felt}; diff --git a/core/src/program.rs b/core/src/program.rs index 73e9fcf4be..b055bf3313 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -5,8 +5,7 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt, WORD_SIZE}; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; 
use crate::{ - errors::ProgramError, - mast::{MastForest, MastNode, MastNodeId}, + mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, utils::ToElements, }; @@ -15,18 +14,42 @@ use super::Kernel; // PROGRAM // =============================================================================================== -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct Program { mast_forest: MastForest, + /// The "entrypoint" is the node where execution of the program begins. + entrypoint: MastNodeId, + kernel: Kernel, } /// Constructors impl Program { - pub fn new(mast_forest: MastForest) -> Result { - if mast_forest.entrypoint().is_some() { - Ok(Self { mast_forest }) - } else { - Err(ProgramError::NoEntrypoint) + /// Construct a new [`Program`] from the given MAST forest and entrypoint. The kernel is assumed + /// to be empty. + /// + /// # Panics: + /// - if `entrypoint` is not a valid node id in `mast_forest` + pub fn new(mast_forest: MastForest, entrypoint: MastNodeId) -> Self { + assert!(mast_forest.get_node_by_id(entrypoint).is_some()); + + Self { + mast_forest, + entrypoint, + kernel: Kernel::default(), + } + } + + /// Construct a new [`Program`] from the given MAST forest, entrypoint, and kernel. + /// + /// # Panics: + /// - if `entrypoint` is not a valid node id in `mast_forest` + pub fn with_kernel(mast_forest: MastForest, entrypoint: MastNodeId, kernel: Kernel) -> Self { + assert!(mast_forest.get_node_by_id(entrypoint).is_some()); + + Self { + mast_forest, + entrypoint, + kernel, } } } @@ -40,17 +63,19 @@ impl Program { /// Returns the kernel associated with this program. pub fn kernel(&self) -> &Kernel { - self.mast_forest.kernel() + &self.kernel } /// Returns the entrypoint associated with this program. pub fn entrypoint(&self) -> MastNodeId { - self.mast_forest.entrypoint().unwrap() + self.entrypoint } - /// A convenience method that provides the hash of the entrypoint. + /// Returns the hash of the program's entrypoint. 
+ /// + /// Equivalently, returns the hash of the root of the entrypoint procedure. pub fn hash(&self) -> RpoDigest { - self.mast_forest.entrypoint_digest().unwrap() + self.mast_forest[self.entrypoint].digest() } /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if valid, or else @@ -62,13 +87,15 @@ impl Program { self.mast_forest.get_node_by_id(node_id) } - /// Returns the [`MastNodeId`] associated with a given digest, if any. - /// - /// That is, every [`MastNode`] hashes to some digest. If there exists a [`MastNode`] in the - /// forest that hashes to this digest, then its id is returned. + /// Returns the [`MastNodeId`] of the procedure root associated with a given digest, if any. #[inline(always)] - pub fn get_node_id_by_digest(&self, digest: RpoDigest) -> Option { - self.mast_forest.get_node_id_by_digest(digest) + pub fn find_procedure_root(&self, digest: RpoDigest) -> Option { + self.mast_forest.find_procedure_root(digest) + } + + /// Returns the number of procedures in this program. 
+ pub fn num_procedures(&self) -> u32 { + self.mast_forest.num_procedures() } } @@ -96,14 +123,6 @@ impl fmt::Display for Program { } } -impl TryFrom for Program { - type Error = ProgramError; - - fn try_from(mast_forest: MastForest) -> Result { - Self::new(mast_forest) - } -} - impl From for MastForest { fn from(program: Program) -> Self { program.mast_forest diff --git a/miden/README.md b/miden/README.md index 80e87ad5e9..7176467b56 100644 --- a/miden/README.md +++ b/miden/README.md @@ -57,7 +57,7 @@ use processor::ExecutionOptions; let mut assembler = Assembler::default(); // compile Miden assembly source code into a program -let program = assembler.assemble_program("begin push.3 push.5 add end").unwrap(); +let program = assembler.assemble("begin push.3 push.5 add end").unwrap(); // use an empty list as initial stack let stack_inputs = StackInputs::default(); @@ -105,7 +105,7 @@ use miden_vm::{Assembler, DefaultHost, ProvingOptions, Program, prove, StackInpu let mut assembler = Assembler::default(); // this is our program, we compile it from assembly code -let program = assembler.assemble_program("begin push.3 push.5 add end").unwrap(); +let program = assembler.assemble("begin push.3 push.5 add end").unwrap(); // let's execute it and generate a STARK proof let (outputs, proof) = prove( @@ -193,7 +193,7 @@ let source = format!( n - 1 ); let mut assembler = Assembler::default(); -let program = assembler.assemble_program(&source).unwrap(); +let program = assembler.assemble(&source).unwrap(); // initialize a default host (with an empty advice provider) let host = DefaultHost::default(); diff --git a/miden/benches/program_execution.rs b/miden/benches/program_execution.rs index 1d191986ab..57719f31b2 100644 --- a/miden/benches/program_execution.rs +++ b/miden/benches/program_execution.rs @@ -18,11 +18,7 @@ fn program_execution(c: &mut Criterion) { let assembler = Assembler::default() .with_library(&StdLibrary::default()) .expect("failed to load stdlib"); - let 
program: Program = assembler - .assemble(source) - .expect("Failed to compile test source.") - .try_into() - .expect("test source has no entrypoint."); + let program: Program = assembler.assemble(source).expect("Failed to compile test source."); bench.iter(|| { execute( &program, diff --git a/miden/src/cli/data.rs b/miden/src/cli/data.rs index 1d797b0c7b..d68308ee4c 100644 --- a/miden/src/cli/data.rs +++ b/miden/src/cli/data.rs @@ -419,9 +419,8 @@ impl ProgramFile { .with_libraries(libraries.into_iter()) .wrap_err("Failed to load libraries")?; - let program: Program = assembler - .assemble_program(self.ast.as_ref()) - .wrap_err("Failed to compile program")?; + let program: Program = + assembler.assemble(self.ast.as_ref()).wrap_err("Failed to compile program")?; Ok(program) } diff --git a/miden/src/examples/blake3.rs b/miden/src/examples/blake3.rs index 7bae853923..2e87cadbcc 100644 --- a/miden/src/examples/blake3.rs +++ b/miden/src/examples/blake3.rs @@ -47,7 +47,7 @@ fn generate_blake3_program(n: usize) -> Program { Assembler::default() .with_library(&StdLibrary::default()) .unwrap() - .assemble_program(program) + .assemble(program) .unwrap() } diff --git a/miden/src/examples/fibonacci.rs b/miden/src/examples/fibonacci.rs index 76c827f710..7bd6555c52 100644 --- a/miden/src/examples/fibonacci.rs +++ b/miden/src/examples/fibonacci.rs @@ -39,7 +39,7 @@ fn generate_fibonacci_program(n: usize) -> Program { n - 1 ); - Assembler::default().assemble_program(program).unwrap() + Assembler::default().assemble(program).unwrap() } /// Computes the `n`-th term of Fibonacci sequence @@ -57,22 +57,27 @@ fn compute_fibonacci(n: usize) -> Felt { // EXAMPLE TESTER // ================================================================================================ -#[test] -fn test_fib_example() { - let example = get_example(16); - super::test_example(example, false); -} +#[cfg(test)] +mod tests { + use super::*; + use crate::examples::{test_example, test_example_with_options}; + use 
prover::ProvingOptions; -#[test] -fn test_fib_example_fail() { - let example = get_example(16); - super::test_example(example, true); -} + #[test] + fn test_fib_example() { + let example = get_example(16); + test_example(example, false); + } -#[test] -fn test_fib_example_rpo() { - use miden_vm::ProvingOptions; + #[test] + fn test_fib_example_fail() { + let example = get_example(16); + test_example(example, true); + } - let example = get_example(16); - super::test_example_with_options(example, false, ProvingOptions::with_96_bit_security(true)); + #[test] + fn test_fib_example_rpo() { + let example = get_example(16); + test_example_with_options(example, false, ProvingOptions::with_96_bit_security(true)); + } } diff --git a/miden/src/repl/mod.rs b/miden/src/repl/mod.rs index 074b6f7e25..aa39436d98 100644 --- a/miden/src/repl/mod.rs +++ b/miden/src/repl/mod.rs @@ -293,7 +293,7 @@ fn execute( .with_libraries(provided_libraries.iter()) .map_err(|err| format!("{err}"))?; - let program = assembler.assemble_program(program).map_err(|err| format!("{err}"))?; + let program = assembler.assemble(program).map_err(|err| format!("{err}"))?; let stack_inputs = StackInputs::default(); let host = DefaultHost::default(); diff --git a/miden/src/tools/mod.rs b/miden/src/tools/mod.rs index 60e299d00d..0028b2c4e1 100644 --- a/miden/src/tools/mod.rs +++ b/miden/src/tools/mod.rs @@ -216,7 +216,7 @@ where let program = Assembler::default() .with_debug_mode(true) .with_library(&StdLibrary::default())? 
- .assemble_program(program)?; + .assemble(program)?; let mut execution_details = ExecutionDetails::default(); let vm_state_iterator = processor::execute_iter(&program, stack_inputs, host); diff --git a/miden/tests/integration/operations/decorators/events.rs b/miden/tests/integration/operations/decorators/events.rs index d1d6397327..c9385fe46f 100644 --- a/miden/tests/integration/operations/decorators/events.rs +++ b/miden/tests/integration/operations/decorators/events.rs @@ -13,11 +13,7 @@ fn test_event_handling() { end"; // compile and execute program - let program: Program = Assembler::default() - .assemble(source) - .unwrap() - .try_into() - .expect("test source has no entrypoint."); + let program: Program = Assembler::default().assemble(source).unwrap(); let mut host = TestHost::default(); processor::execute(&program, Default::default(), &mut host, Default::default()).unwrap(); @@ -37,11 +33,7 @@ fn test_trace_handling() { end"; // compile program - let program: Program = Assembler::default() - .assemble(source) - .unwrap() - .try_into() - .expect("test source has no entrypoint."); + let program: Program = Assembler::default().assemble(source).unwrap(); let mut host = TestHost::default(); // execute program with disabled tracing diff --git a/miden/tests/integration/operations/decorators/mod.rs b/miden/tests/integration/operations/decorators/mod.rs index 7212cedd2f..7c5fd8db9f 100644 --- a/miden/tests/integration/operations/decorators/mod.rs +++ b/miden/tests/integration/operations/decorators/mod.rs @@ -1,6 +1,8 @@ +use std::sync::Arc; + use processor::{ - AdviceExtractor, AdviceProvider, ExecutionError, Host, HostResponse, MemAdviceProvider, - ProcessState, + AdviceExtractor, AdviceProvider, ExecutionError, Host, HostResponse, MastForest, + MemAdviceProvider, ProcessState, }; use vm_core::AdviceInjector; @@ -60,4 +62,9 @@ impl Host for TestHost { self.trace_handler.push(trace_id); Ok(HostResponse::None) } + + fn get_mast_forest(&self, _node_digest: 
&prover::Digest) -> Option> { + // Empty MAST forest store + None + } } diff --git a/miden/tests/integration/operations/io_ops/env_ops.rs b/miden/tests/integration/operations/io_ops/env_ops.rs index c8acfbd41e..b0736ea74e 100644 --- a/miden/tests/integration/operations/io_ops/env_ops.rs +++ b/miden/tests/integration/operations/io_ops/env_ops.rs @@ -161,7 +161,7 @@ fn build_bar_hash() -> [u64; 4] { let mut mast_forest = MastForest::new(); let foo_root = MastNode::new_basic_block(vec![Operation::Caller]); - let foo_root_id = mast_forest.ensure_node(foo_root); + let foo_root_id = mast_forest.add_node(foo_root); let bar_root = MastNode::new_syscall(foo_root_id, &mast_forest); let bar_hash: Word = bar_root.digest().into(); diff --git a/processor/src/chiplets/hasher/tests.rs b/processor/src/chiplets/hasher/tests.rs index 4462ad6bef..912aeadc28 100644 --- a/processor/src/chiplets/hasher/tests.rs +++ b/processor/src/chiplets/hasher/tests.rs @@ -249,19 +249,19 @@ fn hash_memoization_control_blocks() { let mut mast_forest = MastForest::new(); let t_branch = MastNode::new_basic_block(vec![Operation::Push(ZERO)]); - let t_branch_id = mast_forest.ensure_node(t_branch.clone()); + let t_branch_id = mast_forest.add_node(t_branch.clone()); let f_branch = MastNode::new_basic_block(vec![Operation::Push(ONE)]); - let f_branch_id = mast_forest.ensure_node(f_branch.clone()); + let f_branch_id = mast_forest.add_node(f_branch.clone()); let split1 = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); - let split1_id = mast_forest.ensure_node(split1.clone()); + let split1_id = mast_forest.add_node(split1.clone()); let split2 = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); - let split2_id = mast_forest.ensure_node(split2.clone()); + let split2_id = mast_forest.add_node(split2.clone()); let join_node = MastNode::new_join(split1_id, split2_id, &mast_forest); - let _join_node_id = mast_forest.ensure_node(join_node.clone()); + let _join_node_id = 
mast_forest.add_node(join_node.clone()); let mut hasher = Hasher::default(); let h1: [Felt; DIGEST_LEN] = split1 @@ -414,19 +414,19 @@ fn hash_memoization_basic_blocks_check(basic_block: MastNode) { let mut mast_forest = MastForest::new(); let basic_block_1 = basic_block.clone(); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Eq, Operation::Not]); - let loop_body_id = mast_forest.ensure_node(loop_body); + let loop_body_id = mast_forest.add_node(loop_body); let loop_block = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_block_id = mast_forest.ensure_node(loop_block.clone()); + let loop_block_id = mast_forest.add_node(loop_block.clone()); let join2_block = MastNode::new_join(basic_block_1_id, loop_block_id, &mast_forest); - let join2_block_id = mast_forest.ensure_node(join2_block.clone()); + let join2_block_id = mast_forest.add_node(join2_block.clone()); let basic_block_2 = basic_block; - let basic_block_2_id = mast_forest.ensure_node(basic_block_2.clone()); + let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()); let join1_block = MastNode::new_join(join2_block_id, basic_block_2_id, &mast_forest); diff --git a/processor/src/chiplets/tests.rs b/processor/src/chiplets/tests.rs index a7c6c526e7..8353376989 100644 --- a/processor/src/chiplets/tests.rs +++ b/processor/src/chiplets/tests.rs @@ -14,7 +14,7 @@ use miden_air::trace::{ }; use vm_core::{ mast::{MastForest, MastNode}, - Felt, ONE, ZERO, + Felt, Program, ONE, ZERO, }; type ChipletsTrace = [Vec; CHIPLETS_WIDTH]; @@ -120,10 +120,9 @@ fn build_trace( let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(operations); - let basic_block_id = mast_forest.ensure_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); + let basic_block_id = mast_forest.add_node(basic_block); - 
mast_forest.try_into().unwrap() + Program::new(mast_forest, basic_block_id) }; process.execute(&program).unwrap(); diff --git a/processor/src/decoder/mod.rs b/processor/src/decoder/mod.rs index 72845ba549..5d66138df4 100644 --- a/processor/src/decoder/mod.rs +++ b/processor/src/decoder/mod.rs @@ -12,11 +12,11 @@ use miden_air::trace::{ }; use vm_core::{ mast::{ - get_span_op_group_count, BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, + get_span_op_group_count, BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, MastForest, MerkleTreeNode, SplitNode, OP_BATCH_SIZE, }, stack::STACK_TOP_SIZE, - AssemblyOp, Program, + AssemblyOp, }; mod trace; @@ -56,7 +56,7 @@ where pub(super) fn start_join_node( &mut self, node: &JoinNode, - program: &Program, + program: &MastForest, ) -> Result<(), ExecutionError> { // use the hasher to compute the hash of the JOIN block; the row address returned by the // hasher is used as the ID of the block; the result of the hash is expected to be in @@ -106,7 +106,7 @@ where pub(super) fn start_split_node( &mut self, node: &SplitNode, - program: &Program, + program: &MastForest, ) -> Result { let condition = self.stack.peek(); @@ -158,7 +158,7 @@ where pub(super) fn start_loop_node( &mut self, node: &LoopNode, - program: &Program, + program: &MastForest, ) -> Result { let condition = self.stack.peek(); @@ -222,7 +222,7 @@ where pub(super) fn start_call_node( &mut self, node: &CallNode, - program: &Program, + program: &MastForest, ) -> Result<(), ExecutionError> { // use the hasher to compute the hash of the CALL or SYSCALL block; the row address // returned by the hasher is used as the ID of the block; the result of the hash is diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index a0474cb398..598db1abdf 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -50,10 +50,9 @@ fn basic_block_one_group() { let mut mast_forest = MastForest::new(); let basic_block_node = 
MastNode::Block(basic_block.clone()); - let basic_block_id = mast_forest.ensure_node(basic_block_node); - mast_forest.set_entrypoint(basic_block_id); + let basic_block_id = mast_forest.add_node(basic_block_node); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -97,10 +96,9 @@ fn basic_block_small() { let mut mast_forest = MastForest::new(); let basic_block_node = MastNode::Block(basic_block.clone()); - let basic_block_id = mast_forest.ensure_node(basic_block_node); - mast_forest.set_entrypoint(basic_block_id); + let basic_block_id = mast_forest.add_node(basic_block_node); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -161,10 +159,9 @@ fn basic_block() { let mut mast_forest = MastForest::new(); let basic_block_node = MastNode::Block(basic_block.clone()); - let basic_block_id = mast_forest.ensure_node(basic_block_node); - mast_forest.set_entrypoint(basic_block_id); + let basic_block_id = mast_forest.add_node(basic_block_node); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -254,10 +251,9 @@ fn span_block_with_respan() { let mut mast_forest = MastForest::new(); let basic_block_node = MastNode::Block(basic_block.clone()); - let basic_block_id = mast_forest.ensure_node(basic_block_node); - mast_forest.set_entrypoint(basic_block_id); + let basic_block_id = mast_forest.add_node(basic_block_node); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -328,14 +324,13 @@ fn join_node() { let program = { let mut mast_forest = MastForest::new(); - let basic_block1_id = mast_forest.ensure_node(basic_block1.clone()); - let basic_block2_id = mast_forest.ensure_node(basic_block2.clone()); + let basic_block1_id = 
mast_forest.add_node(basic_block1.clone()); + let basic_block2_id = mast_forest.add_node(basic_block2.clone()); let join_node = MastNode::new_join(basic_block1_id, basic_block2_id, &mast_forest); - let join_node_id = mast_forest.ensure_node(join_node); - mast_forest.set_entrypoint(join_node_id); + let join_node_id = mast_forest.add_node(join_node); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, join_node_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -395,14 +390,13 @@ fn split_node_true() { let program = { let mut mast_forest = MastForest::new(); - let basic_block1_id = mast_forest.ensure_node(basic_block1.clone()); - let basic_block2_id = mast_forest.ensure_node(basic_block2.clone()); + let basic_block1_id = mast_forest.add_node(basic_block1.clone()); + let basic_block2_id = mast_forest.add_node(basic_block2.clone()); let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); - let split_node_id = mast_forest.ensure_node(split_node); - mast_forest.set_entrypoint(split_node_id); + let split_node_id = mast_forest.add_node(split_node); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, split_node_id) }; let (trace, trace_len) = build_trace(&[1], &program); @@ -449,14 +443,13 @@ fn split_node_false() { let program = { let mut mast_forest = MastForest::new(); - let basic_block1_id = mast_forest.ensure_node(basic_block1.clone()); - let basic_block2_id = mast_forest.ensure_node(basic_block2.clone()); + let basic_block1_id = mast_forest.add_node(basic_block1.clone()); + let basic_block2_id = mast_forest.add_node(basic_block2.clone()); let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); - let split_node_id = mast_forest.ensure_node(split_node); - mast_forest.set_entrypoint(split_node_id); + let split_node_id = mast_forest.add_node(split_node); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, split_node_id) }; let (trace, trace_len) = 
build_trace(&[0], &program); @@ -505,13 +498,12 @@ fn loop_node() { let program = { let mut mast_forest = MastForest::new(); - let loop_body_id = mast_forest.ensure_node(loop_body.clone()); + let loop_body_id = mast_forest.add_node(loop_body.clone()); let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_node_id = mast_forest.ensure_node(loop_node); - mast_forest.set_entrypoint(loop_node_id); + let loop_node_id = mast_forest.add_node(loop_node); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, loop_node_id) }; let (trace, trace_len) = build_trace(&[0, 1], &program); @@ -559,13 +551,12 @@ fn loop_node_skip() { let program = { let mut mast_forest = MastForest::new(); - let loop_body_id = mast_forest.ensure_node(loop_body.clone()); + let loop_body_id = mast_forest.add_node(loop_body.clone()); let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_node_id = mast_forest.ensure_node(loop_node); - mast_forest.set_entrypoint(loop_node_id); + let loop_node_id = mast_forest.add_node(loop_node); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, loop_node_id) }; let (trace, trace_len) = build_trace(&[0], &program); @@ -603,13 +594,12 @@ fn loop_node_repeat() { let program = { let mut mast_forest = MastForest::new(); - let loop_body_id = mast_forest.ensure_node(loop_body.clone()); + let loop_body_id = mast_forest.add_node(loop_body.clone()); let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_node_id = mast_forest.ensure_node(loop_node); - mast_forest.set_entrypoint(loop_node_id); + let loop_node_id = mast_forest.add_node(loop_node); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, loop_node_id) }; let (trace, trace_len) = build_trace(&[0, 1, 1], &program); @@ -693,27 +683,26 @@ fn call_block() { Operation::FmpUpdate, Operation::Pad, ]); - let first_basic_block_id = mast_forest.ensure_node(first_basic_block.clone()); + let first_basic_block_id = 
mast_forest.add_node(first_basic_block.clone()); let foo_root_node = MastNode::new_basic_block(vec![ Operation::Push(ONE), Operation::FmpUpdate ]); - let foo_root_node_id = mast_forest.ensure_node(foo_root_node.clone()); + let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()); let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd]); - let last_basic_block_id = mast_forest.ensure_node(last_basic_block.clone()); + let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()); let foo_call_node = MastNode::new_call(foo_root_node_id, &mast_forest); - let foo_call_node_id = mast_forest.ensure_node(foo_call_node.clone()); + let foo_call_node_id = mast_forest.add_node(foo_call_node.clone()); let join1_node = MastNode::new_join(first_basic_block_id, foo_call_node_id, &mast_forest); - let join1_node_id = mast_forest.ensure_node(join1_node.clone()); + let join1_node_id = mast_forest.add_node(join1_node.clone()); let program_root = MastNode::new_join(join1_node_id, last_basic_block_id, &mast_forest); - let program_root_id = mast_forest.ensure_node(program_root); - mast_forest.set_entrypoint(program_root_id); + let program_root_id = mast_forest.add_node(program_root); - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, program_root_id); let (sys_trace, dec_trace, trace_len) = build_call_trace(&program, Kernel::default()); @@ -904,19 +893,20 @@ fn syscall_block() { // build foo procedure body let foo_root = MastNode::new_basic_block(vec![Operation::Push(THREE), Operation::FmpUpdate]); - let foo_root_id = mast_forest.ensure_node(foo_root.clone()); + let foo_root_id = mast_forest.add_node(foo_root.clone()); + mast_forest.make_root(foo_root_id); let kernel = Kernel::new(&[foo_root.digest()]).unwrap(); - mast_forest.set_kernel(kernel.clone()); // build bar procedure body let bar_basic_block = MastNode::new_basic_block(vec![Operation::Push(TWO), Operation::FmpUpdate]); - let bar_basic_block_id = 
mast_forest.ensure_node(bar_basic_block.clone()); + let bar_basic_block_id = mast_forest.add_node(bar_basic_block.clone()); let foo_call_node = MastNode::new_syscall(foo_root_id, &mast_forest); - let foo_call_node_id = mast_forest.ensure_node(foo_call_node.clone()); + let foo_call_node_id = mast_forest.add_node(foo_call_node.clone()); let bar_root_node = MastNode::new_join(bar_basic_block_id, foo_call_node_id, &mast_forest); - let bar_root_node_id = mast_forest.ensure_node(bar_root_node.clone()); + let bar_root_node_id = mast_forest.add_node(bar_root_node.clone()); + mast_forest.make_root(bar_root_node_id); // build the program let first_basic_block = MastNode::new_basic_block(vec![ @@ -924,22 +914,21 @@ fn syscall_block() { Operation::FmpUpdate, Operation::Pad, ]); - let first_basic_block_id = mast_forest.ensure_node(first_basic_block.clone()); + let first_basic_block_id = mast_forest.add_node(first_basic_block.clone()); let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd]); - let last_basic_block_id = mast_forest.ensure_node(last_basic_block.clone()); + let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()); let bar_call_node = MastNode::new_call(bar_root_node_id, &mast_forest); - let bar_call_node_id = mast_forest.ensure_node(bar_call_node.clone()); + let bar_call_node_id = mast_forest.add_node(bar_call_node.clone()); let inner_join_node = MastNode::new_join(first_basic_block_id, bar_call_node_id, &mast_forest); - let inner_join_node_id = mast_forest.ensure_node(inner_join_node.clone()); + let inner_join_node_id = mast_forest.add_node(inner_join_node.clone()); let program_root_node = MastNode::new_join(inner_join_node_id, last_basic_block_id, &mast_forest); - let program_root_node_id = mast_forest.ensure_node(program_root_node.clone()); - mast_forest.set_entrypoint(program_root_node_id); + let program_root_node_id = mast_forest.add_node(program_root_node.clone()); - let program = Program::new(mast_forest).unwrap(); + let 
program = Program::with_kernel(mast_forest, program_root_node_id, kernel.clone()); let (sys_trace, dec_trace, trace_len) = build_call_trace(&program, kernel); @@ -1192,26 +1181,26 @@ fn dyn_block() { let mut mast_forest = MastForest::new(); let foo_root_node = MastNode::new_basic_block(vec![Operation::Push(ONE), Operation::Add]); - let _foo_root_node_id = mast_forest.ensure_node(foo_root_node.clone()); + let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()); + mast_forest.make_root(foo_root_node_id); let mul_bb_node = MastNode::new_basic_block(vec![Operation::Mul]); - let mul_bb_node_id = mast_forest.ensure_node(mul_bb_node.clone()); + let mul_bb_node_id = mast_forest.add_node(mul_bb_node.clone()); let save_bb_node = MastNode::new_basic_block(vec![Operation::MovDn4]); - let save_bb_node_id = mast_forest.ensure_node(save_bb_node.clone()); + let save_bb_node_id = mast_forest.add_node(save_bb_node.clone()); let join_node = MastNode::new_join(mul_bb_node_id, save_bb_node_id, &mast_forest); - let join_node_id = mast_forest.ensure_node(join_node.clone()); + let join_node_id = mast_forest.add_node(join_node.clone()); // This dyn will point to foo. 
let dyn_node = MastNode::new_dynexec(); - let dyn_node_id = mast_forest.ensure_node(dyn_node.clone()); + let dyn_node_id = mast_forest.add_node(dyn_node.clone()); let program_root_node = MastNode::new_join(join_node_id, dyn_node_id, &mast_forest); - let program_root_node_id = mast_forest.ensure_node(program_root_node.clone()); - mast_forest.set_entrypoint(program_root_node_id); + let program_root_node_id = mast_forest.add_node(program_root_node.clone()); - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, program_root_node_id); let (trace, trace_len) = build_dyn_trace( &[ @@ -1317,10 +1306,9 @@ fn set_user_op_helpers_many() { let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(vec![Operation::U32div]); - let basic_block_id = mast_forest.ensure_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); + let basic_block_id = mast_forest.add_node(basic_block); - mast_forest.try_into().unwrap() + Program::new(mast_forest, basic_block_id) }; let a = rand_value::(); let b = rand_value::(); diff --git a/processor/src/errors.rs b/processor/src/errors.rs index fa5d9251ac..203384367f 100644 --- a/processor/src/errors.rs +++ b/processor/src/errors.rs @@ -48,9 +48,15 @@ pub enum ExecutionError { }, LogArgumentZero(u32), MalformedSignatureKey(&'static str), + MalformedMastForestInHost { + root_digest: Digest, + }, MastNodeNotFoundInForest { node_id: MastNodeId, }, + MastForestNotFound { + root_digest: Digest, + }, MemoryAddressOutOfBounds(u64), MerklePathVerificationFailed { value: Word, @@ -147,9 +153,18 @@ impl Display for ExecutionError { ) } MalformedSignatureKey(signature) => write!(f, "Malformed signature key: {signature}"), + MalformedMastForestInHost { root_digest } => { + write!(f, "Malformed host: MAST forest indexed by procedure root {} doesn't contain that root", root_digest) + } MastNodeNotFoundInForest { node_id } => { write!(f, "Malformed MAST forest, node id {node_id} doesn't 
exist") } + MastForestNotFound { root_digest } => { + write!( + f, + "No MAST forest contains the following procedure root digest: {root_digest}" + ) + } MemoryAddressOutOfBounds(addr) => { write!(f, "Memory address cannot exceed 2^32 but was {addr}") } diff --git a/processor/src/host/mast_forest_store.rs b/processor/src/host/mast_forest_store.rs new file mode 100644 index 0000000000..eb2ae42055 --- /dev/null +++ b/processor/src/host/mast_forest_store.rs @@ -0,0 +1,38 @@ +use alloc::{collections::BTreeMap, sync::Arc}; +use vm_core::{crypto::hash::RpoDigest, mast::MastForest}; + +/// A set of [`MastForest`]s available to the prover that programs may refer to (by means of an +/// [`ExternalNode`]). +/// +/// For example, a program's kernel and standard library would most likely not be compiled directly +/// with the program, and instead be provided separately to the prover. This has the benefit of +/// reducing program binary size. The store could also be much more complex, such as accessing a +/// centralized registry of [`MastForest`]s when it doesn't find one locally. +pub trait MastForestStore { + /// Returns a [`MastForest`] which is guaranteed to contain a procedure with the provided + /// procedure hash as one of its procedure, if any. + fn get(&self, procedure_hash: &RpoDigest) -> Option>; +} + +/// A simple [`MastForestStore`] where all known [`MastForest`]s are held in memory. +#[derive(Debug, Default, Clone)] +pub struct MemMastForestStore { + mast_forests: BTreeMap>, +} + +impl MemMastForestStore { + /// Inserts all the procedures of the provided MAST forest in the store. 
+ pub fn insert(&mut self, mast_forest: MastForest) { + let mast_forest = Arc::new(mast_forest); + + for root in mast_forest.procedure_roots() { + self.mast_forests.insert(root, mast_forest.clone()); + } + } +} + +impl MastForestStore for MemMastForestStore { + fn get(&self, procedure_hash: &RpoDigest) -> Option> { + self.mast_forests.get(procedure_hash).cloned() + } +} diff --git a/processor/src/host/mod.rs b/processor/src/host/mod.rs index 9642b22e28..d6bfe9a79d 100644 --- a/processor/src/host/mod.rs +++ b/processor/src/host/mod.rs @@ -1,6 +1,11 @@ use super::{ExecutionError, Felt, ProcessState}; use crate::MemAdviceProvider; -use vm_core::{crypto::merkle::MerklePath, AdviceInjector, DebugOptions, Word}; +use alloc::sync::Arc; +use vm_core::{ + crypto::{hash::RpoDigest, merkle::MerklePath}, + mast::MastForest, + AdviceInjector, DebugOptions, Word, +}; pub(super) mod advice; use advice::{AdviceExtractor, AdviceProvider}; @@ -8,6 +13,9 @@ use advice::{AdviceExtractor, AdviceProvider}; #[cfg(feature = "std")] mod debug; +mod mast_forest_store; +pub use mast_forest_store::{MastForestStore, MemMastForestStore}; + // HOST TRAIT // ================================================================================================ @@ -25,19 +33,23 @@ pub trait Host { // -------------------------------------------------------------------------------------------- /// Returns the requested advice, specified by [AdviceExtractor], from the host to the VM. - fn get_advice( + fn get_advice( &mut self, - process: &S, + process: &P, extractor: AdviceExtractor, ) -> Result; /// Sets the requested advice, specified by [AdviceInjector], on the host. - fn set_advice( + fn set_advice( &mut self, - process: &S, + process: &P, injector: AdviceInjector, ) -> Result; + /// Returns MAST forest corresponding to the specified digest, or None if the MAST forest for + /// this digest could not be found in this [Host]. 
+ fn get_mast_forest(&self, node_digest: &RpoDigest) -> Option>; + // PROVIDED METHODS // -------------------------------------------------------------------------------------------- @@ -182,6 +194,10 @@ where H::set_advice(self, process, injector) } + fn get_mast_forest(&self, node_digest: &RpoDigest) -> Option> { + H::get_mast_forest(self, node_digest) + } + fn on_debug( &mut self, process: &S, @@ -266,19 +282,31 @@ impl From for Felt { /// A default [Host] implementation that provides the essential functionality required by the VM. pub struct DefaultHost { adv_provider: A, + store: MemMastForestStore, } impl Default for DefaultHost { fn default() -> Self { Self { adv_provider: MemAdviceProvider::default(), + store: MemMastForestStore::default(), } } } -impl DefaultHost { +impl DefaultHost +where + A: AdviceProvider, +{ pub fn new(adv_provider: A) -> Self { - Self { adv_provider } + Self { + adv_provider, + store: MemMastForestStore::default(), + } + } + + pub fn load_mast_forest(&mut self, mast_forest: MastForest) { + self.store.insert(mast_forest) } #[cfg(any(test, feature = "internals"))] @@ -296,20 +324,27 @@ impl DefaultHost { } } -impl Host for DefaultHost { - fn get_advice( +impl Host for DefaultHost +where + A: AdviceProvider, +{ + fn get_advice( &mut self, - process: &S, + process: &P, extractor: AdviceExtractor, ) -> Result { self.adv_provider.get_advice(process, &extractor) } - fn set_advice( + fn set_advice( &mut self, - process: &S, + process: &P, injector: AdviceInjector, ) -> Result { self.adv_provider.set_advice(process, &injector) } + + fn get_mast_forest(&self, node_digest: &RpoDigest) -> Option> { + self.store.get(node_digest) + } } diff --git a/processor/src/lib.rs b/processor/src/lib.rs index 244cb31b4a..fc6b136a2b 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -53,7 +53,7 @@ pub use host::{ AdviceExtractor, AdviceInputs, AdviceMap, AdviceProvider, AdviceSource, MemAdviceProvider, RecAdviceProvider, }, - DefaultHost, Host, 
HostResponse, + DefaultHost, Host, HostResponse, MastForestStore, MemMastForestStore, }; mod chiplets; @@ -231,7 +231,7 @@ where return Err(ExecutionError::ProgramAlreadyExecuted); } - self.execute_mast_node(program.entrypoint(), program)?; + self.execute_mast_node(program.entrypoint(), program.mast_forest())?; Ok(self.stack.build_stack_outputs()) } @@ -242,7 +242,7 @@ where fn execute_mast_node( &mut self, node_id: MastNodeId, - program: &Program, + program: &MastForest, ) -> Result<(), ExecutionError> { let wrapper_node = &program .get_node_by_id(node_id) @@ -255,6 +255,24 @@ where MastNode::Loop(node) => self.execute_loop_node(node, program), MastNode::Call(node) => self.execute_call_node(node, program), MastNode::Dyn => self.execute_dyn_node(program), + MastNode::External(external_node) => { + let mast_forest = + self.host.borrow().get_mast_forest(&external_node.digest()).ok_or_else( + || ExecutionError::MastForestNotFound { + root_digest: external_node.digest(), + }, + )?; + + // We temporarily limit the parts of the program that can be called externally to + // procedure roots, even though MAST doesn't have that restriction. 
+ let root_id = mast_forest.find_procedure_root(external_node.digest()).ok_or( + ExecutionError::MalformedMastForestInHost { + root_digest: external_node.digest(), + }, + )?; + + self.execute_mast_node(root_id, &mast_forest) + } } } @@ -262,7 +280,7 @@ where fn execute_join_node( &mut self, node: &JoinNode, - program: &Program, + program: &MastForest, ) -> Result<(), ExecutionError> { self.start_join_node(node, program)?; @@ -277,7 +295,7 @@ where fn execute_split_node( &mut self, node: &SplitNode, - program: &Program, + program: &MastForest, ) -> Result<(), ExecutionError> { // start the SPLIT block; this also pops the stack and returns the popped element let condition = self.start_split_node(node, program)?; @@ -299,7 +317,7 @@ where fn execute_loop_node( &mut self, node: &LoopNode, - program: &Program, + program: &MastForest, ) -> Result<(), ExecutionError> { // start the LOOP block; this also pops the stack and returns the popped element let condition = self.start_loop_node(node, program)?; @@ -334,7 +352,7 @@ where fn execute_call_node( &mut self, call_node: &CallNode, - program: &Program, + program: &MastForest, ) -> Result<(), ExecutionError> { let callee_digest = { let callee = program.get_node_by_id(call_node.callee()).ok_or_else(|| { @@ -365,14 +383,14 @@ where /// Executes the specified [DynNode] node. 
#[inline(always)] - fn execute_dyn_node(&mut self, program: &Program) -> Result<(), ExecutionError> { + fn execute_dyn_node(&mut self, program: &MastForest) -> Result<(), ExecutionError> { // get target hash from the stack let callee_hash = self.stack.get_word(0); self.start_dyn_node(callee_hash)?; // get dynamic code from the code block table and execute it let callee_id = program - .get_node_id_by_digest(callee_hash.into()) + .find_procedure_root(callee_hash.into()) .ok_or_else(|| ExecutionError::DynamicNodeNotFound(callee_hash.into()))?; self.execute_mast_node(callee_id, program)?; diff --git a/processor/src/operations/mod.rs b/processor/src/operations/mod.rs index 94aa107c89..cb677f8ac3 100644 --- a/processor/src/operations/mod.rs +++ b/processor/src/operations/mod.rs @@ -173,68 +173,72 @@ where } #[cfg(test)] -impl Process> { - // TEST METHODS - // -------------------------------------------------------------------------------------------- - - /// Instantiates a new blank process for testing purposes. The stack in the process is - /// initialized with the provided values. - fn new_dummy(stack_inputs: super::StackInputs) -> Self { - let host = super::DefaultHost::default(); - let mut process = - Self::new(Kernel::default(), stack_inputs, host, super::ExecutionOptions::default()); - process.execute_op(Operation::Noop).unwrap(); - process - } +pub mod testing { + use super::*; + use miden_air::ExecutionOptions; + use vm_core::StackInputs; + + use crate::{AdviceInputs, DefaultHost, MemAdviceProvider}; + + impl Process> { + /// Instantiates a new blank process for testing purposes. The stack in the process is + /// initialized with the provided values. + pub fn new_dummy(stack_inputs: StackInputs) -> Self { + let host = DefaultHost::default(); + let mut process = + Self::new(Kernel::default(), stack_inputs, host, ExecutionOptions::default()); + process.execute_op(Operation::Noop).unwrap(); + process + } - /// Instantiates a new blank process for testing purposes. 
- fn new_dummy_with_empty_stack() -> Self { - let stack = super::StackInputs::default(); - Self::new_dummy(stack) - } + /// Instantiates a new blank process for testing purposes. + pub fn new_dummy_with_empty_stack() -> Self { + let stack = StackInputs::default(); + Self::new_dummy(stack) + } - /// Instantiates a new process with an advice stack for testing purposes. - fn new_dummy_with_advice_stack(advice_stack: &[u64]) -> Self { - let stack_inputs = super::StackInputs::default(); - let advice_inputs = super::AdviceInputs::default() - .with_stack_values(advice_stack.iter().copied()) - .unwrap(); - let advice_provider = super::MemAdviceProvider::from(advice_inputs); - let host = super::DefaultHost::new(advice_provider); - let mut process = - Self::new(Kernel::default(), stack_inputs, host, super::ExecutionOptions::default()); - process.execute_op(Operation::Noop).unwrap(); - process - } + /// Instantiates a new process with an advice stack for testing purposes. + pub fn new_dummy_with_advice_stack(advice_stack: &[u64]) -> Self { + let stack_inputs = StackInputs::default(); + let advice_inputs = + AdviceInputs::default().with_stack_values(advice_stack.iter().copied()).unwrap(); + let advice_provider = MemAdviceProvider::from(advice_inputs); + let host = DefaultHost::new(advice_provider); + let mut process = + Self::new(Kernel::default(), stack_inputs, host, ExecutionOptions::default()); + process.execute_op(Operation::Noop).unwrap(); + process + } - /// Instantiates a new blank process with one decoder trace row for testing purposes. This - /// allows for setting helpers in the decoder when executing operations during tests. - fn new_dummy_with_decoder_helpers_and_empty_stack() -> Self { - let stack_inputs = super::StackInputs::default(); - Self::new_dummy_with_decoder_helpers(stack_inputs) - } + /// Instantiates a new blank process with one decoder trace row for testing purposes. 
This + /// allows for setting helpers in the decoder when executing operations during tests. + pub fn new_dummy_with_decoder_helpers_and_empty_stack() -> Self { + let stack_inputs = StackInputs::default(); + Self::new_dummy_with_decoder_helpers(stack_inputs) + } - /// Instantiates a new blank process with one decoder trace row for testing purposes. This - /// allows for setting helpers in the decoder when executing operations during tests. - /// - /// The stack in the process is initialized with the provided values. - fn new_dummy_with_decoder_helpers(stack_inputs: super::StackInputs) -> Self { - let advice_inputs = super::AdviceInputs::default(); - Self::new_dummy_with_inputs_and_decoder_helpers(stack_inputs, advice_inputs) - } + /// Instantiates a new blank process with one decoder trace row for testing purposes. This + /// allows for setting helpers in the decoder when executing operations during tests. + /// + /// The stack in the process is initialized with the provided values. + pub fn new_dummy_with_decoder_helpers(stack_inputs: StackInputs) -> Self { + let advice_inputs = AdviceInputs::default(); + Self::new_dummy_with_inputs_and_decoder_helpers(stack_inputs, advice_inputs) + } - /// Instantiates a new process having Program inputs along with one decoder trace row - /// for testing purposes. - fn new_dummy_with_inputs_and_decoder_helpers( - stack_inputs: super::StackInputs, - advice_inputs: super::AdviceInputs, - ) -> Self { - let advice_provider = super::MemAdviceProvider::from(advice_inputs); - let host = super::DefaultHost::new(advice_provider); - let mut process = - Self::new(Kernel::default(), stack_inputs, host, super::ExecutionOptions::default()); - process.decoder.add_dummy_trace_row(); - process.execute_op(Operation::Noop).unwrap(); - process + /// Instantiates a new process having Program inputs along with one decoder trace row + /// for testing purposes. 
+ pub fn new_dummy_with_inputs_and_decoder_helpers( + stack_inputs: StackInputs, + advice_inputs: AdviceInputs, + ) -> Self { + let advice_provider = MemAdviceProvider::from(advice_inputs); + let host = DefaultHost::new(advice_provider); + let mut process = + Self::new(Kernel::default(), stack_inputs, host, ExecutionOptions::default()); + process.decoder.add_dummy_trace_row(); + process.execute_op(Operation::Noop).unwrap(); + process + } } } diff --git a/processor/src/trace/tests/chiplets/hasher.rs b/processor/src/trace/tests/chiplets/hasher.rs index 19bfd34886..992cee60d4 100644 --- a/processor/src/trace/tests/chiplets/hasher.rs +++ b/processor/src/trace/tests/chiplets/hasher.rs @@ -51,10 +51,9 @@ pub fn b_chip_span() { let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(vec![Operation::Add, Operation::Mul]); - let basic_block_id = mast_forest.ensure_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); + let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let trace = build_trace_from_program(&program, &[]); @@ -125,10 +124,9 @@ pub fn b_chip_span_with_respan() { let (ops, _) = build_span_with_respan_ops(); let basic_block = MastNode::new_basic_block(ops); - let basic_block_id = mast_forest.ensure_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); + let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let trace = build_trace_from_program(&program, &[]); @@ -218,17 +216,15 @@ pub fn b_chip_merge() { let mut mast_forest = MastForest::new(); let t_branch = MastNode::new_basic_block(vec![Operation::Add]); - let t_branch_id = mast_forest.ensure_node(t_branch); + let t_branch_id = mast_forest.add_node(t_branch); let f_branch = MastNode::new_basic_block(vec![Operation::Mul]); - let f_branch_id = mast_forest.ensure_node(f_branch); + let 
f_branch_id = mast_forest.add_node(f_branch); let split = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); - let split_id = mast_forest.ensure_node(split); + let split_id = mast_forest.add_node(split); - mast_forest.set_entrypoint(split_id); - - Program::new(mast_forest).unwrap() + Program::new(mast_forest, split_id) }; let trace = build_trace_from_program(&program, &[]); @@ -339,10 +335,9 @@ pub fn b_chip_permutation() { let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(vec![Operation::HPerm]); - let basic_block_id = mast_forest.ensure_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); + let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let stack = vec![8, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]; let trace = build_trace_from_program(&program, &stack); diff --git a/processor/src/trace/tests/decoder.rs b/processor/src/trace/tests/decoder.rs index cbb4cf7e6c..3b20428865 100644 --- a/processor/src/trace/tests/decoder.rs +++ b/processor/src/trace/tests/decoder.rs @@ -73,17 +73,15 @@ fn decoder_p1_join() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1); + let basic_block_1_id = mast_forest.add_node(basic_block_1); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.ensure_node(basic_block_2); + let basic_block_2_id = mast_forest.add_node(basic_block_2); let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); - let join_id = mast_forest.ensure_node(join); + let join_id = mast_forest.add_node(join); - mast_forest.set_entrypoint(join_id); - - Program::new(mast_forest).unwrap() + Program::new(mast_forest, join_id) }; let trace = build_trace_from_program(&program, &[]); @@ -146,17 +144,15 @@ fn decoder_p1_split() { let mut 
mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1); + let basic_block_1_id = mast_forest.add_node(basic_block_1); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.ensure_node(basic_block_2); + let basic_block_2_id = mast_forest.add_node(basic_block_2); let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); - let split_id = mast_forest.ensure_node(split); - - mast_forest.set_entrypoint(split_id); + let split_id = mast_forest.add_node(split); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, split_id) }; let trace = build_trace_from_program(&program, &[1]); @@ -206,20 +202,18 @@ fn decoder_p1_loop_with_repeat() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Pad]); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1); + let basic_block_1_id = mast_forest.add_node(basic_block_1); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Drop]); - let basic_block_2_id = mast_forest.ensure_node(basic_block_2); + let basic_block_2_id = mast_forest.add_node(basic_block_2); let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); - let join_id = mast_forest.ensure_node(join); + let join_id = mast_forest.add_node(join); let loop_node = MastNode::new_loop(join_id, &mast_forest); - let loop_node_id = mast_forest.ensure_node(loop_node); - - mast_forest.set_entrypoint(loop_node_id); + let loop_node_id = mast_forest.add_node(loop_node); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, loop_node_id) }; let trace = build_trace_from_program(&program, &[0, 1, 1]); @@ -339,10 +333,9 @@ fn decoder_p2_span_with_respan() { let (ops, _) = build_span_with_respan_ops(); let basic_block = MastNode::new_basic_block(ops); - let basic_block_id = 
mast_forest.ensure_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); + let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); @@ -376,16 +369,15 @@ fn decoder_p2_join() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.ensure_node(basic_block_2.clone()); + let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()); let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); - let join_id = mast_forest.ensure_node(join.clone()); - mast_forest.set_entrypoint(join_id); + let join_id = mast_forest.add_node(join.clone()); - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, join_id); let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); @@ -442,17 +434,15 @@ fn decoder_p2_split_true() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.ensure_node(basic_block_2); + let basic_block_2_id = mast_forest.add_node(basic_block_2); let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); - let split_id = mast_forest.ensure_node(split); + let split_id = mast_forest.add_node(split); - mast_forest.set_entrypoint(split_id); - - let program = 
Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, split_id); // build trace from program let trace = build_trace_from_program(&program, &[1]); @@ -500,17 +490,15 @@ fn decoder_p2_split_false() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.ensure_node(basic_block_2.clone()); + let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()); let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); - let split_id = mast_forest.ensure_node(split); - - mast_forest.set_entrypoint(split_id); + let split_id = mast_forest.add_node(split); - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, split_id); // build trace from program let trace = build_trace_from_program(&program, &[0]); @@ -558,20 +546,18 @@ fn decoder_p2_loop_with_repeat() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Pad]); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Drop]); - let basic_block_2_id = mast_forest.ensure_node(basic_block_2.clone()); + let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()); let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); - let join_id = mast_forest.ensure_node(join.clone()); + let join_id = mast_forest.add_node(join.clone()); let loop_node = MastNode::new_loop(join_id, &mast_forest); - let loop_node_id = mast_forest.ensure_node(loop_node); - - mast_forest.set_entrypoint(loop_node_id); + let loop_node_id = 
mast_forest.add_node(loop_node); - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, loop_node_id); // build trace from program let trace = build_trace_from_program(&program, &[0, 1, 1]); diff --git a/processor/src/trace/tests/mod.rs b/processor/src/trace/tests/mod.rs index adc83f5245..19c11defbc 100644 --- a/processor/src/trace/tests/mod.rs +++ b/processor/src/trace/tests/mod.rs @@ -35,10 +35,9 @@ pub fn build_trace_from_ops(operations: Vec, stack: &[u64]) -> Execut let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(operations); - let basic_block_id = mast_forest.ensure_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); + let basic_block_id = mast_forest.add_node(basic_block); - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, basic_block_id); build_trace_from_program(&program, stack) } @@ -58,10 +57,9 @@ pub fn build_trace_from_ops_with_inputs( let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(operations); - let basic_block_id = mast_forest.ensure_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); + let basic_block_id = mast_forest.add_node(basic_block); - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, basic_block_id); process.execute(&program).unwrap(); ExecutionTrace::new(process, StackOutputs::default()) diff --git a/prover/src/gpu/metal/mod.rs b/prover/src/gpu/metal/mod.rs index 7547c18aa4..602f878378 100644 --- a/prover/src/gpu/metal/mod.rs +++ b/prover/src/gpu/metal/mod.rs @@ -91,8 +91,9 @@ where // if we will fill the entire segment, we allocate uninitialized memory unsafe { page_aligned_uninit_vector(domain_size) } } else { - // but if some columns in the segment will remain unfilled, we allocate memory initialized - // to zeros to make sure we don't end up with memory with undefined values + // but if some columns in the segment 
will remain unfilled, we allocate memory + // initialized to zeros to make sure we don't end up with memory with + // undefined values vec![[E::BaseField::ZERO; N]; domain_size] }; diff --git a/stdlib/tests/crypto/falcon.rs b/stdlib/tests/crypto/falcon.rs index 8a81a97d7d..8a89c6ae83 100644 --- a/stdlib/tests/crypto/falcon.rs +++ b/stdlib/tests/crypto/falcon.rs @@ -203,9 +203,7 @@ fn falcon_prove_verify() { .with_library(&StdLibrary::default()) .expect("failed to load stdlib") .assemble(source) - .expect("failed to compile test source") - .try_into() - .expect("test source has no entrypoint"); + .expect("failed to compile test source"); let stack_inputs = StackInputs::try_from_ints(op_stack).expect("failed to create stack inputs"); let advice_inputs = AdviceInputs::default().with_map(advice_map); diff --git a/stdlib/tests/crypto/stark/mod.rs b/stdlib/tests/crypto/stark/mod.rs index 8157770920..ac29770e3b 100644 --- a/stdlib/tests/crypto/stark/mod.rs +++ b/stdlib/tests/crypto/stark/mod.rs @@ -51,11 +51,7 @@ pub fn generate_recursive_verifier_data( source: &str, stack_inputs: Vec, ) -> Result { - let program: Program = Assembler::default() - .assemble(source) - .unwrap() - .try_into() - .expect("test source has no entrypoint"); + let program: Program = Assembler::default().assemble(source).unwrap(); let stack_inputs = StackInputs::try_from_ints(stack_inputs).unwrap(); let advice_inputs = AdviceInputs::default(); let advice_provider = MemAdviceProvider::from(advice_inputs); diff --git a/stdlib/tests/mem/mod.rs b/stdlib/tests/mem/mod.rs index c76677d01b..1dbef6927b 100644 --- a/stdlib/tests/mem/mod.rs +++ b/stdlib/tests/mem/mod.rs @@ -26,11 +26,7 @@ fn test_memcopy() { .with_library(&StdLibrary::default()) .expect("failed to load stdlib"); - let program: Program = assembler - .assemble(source) - .expect("Failed to compile test source.") - .try_into() - .expect("test source has no entrypoint."); + let program: Program = assembler.assemble(source).expect("Failed to 
compile test source."); let mut process = Process::new( program.kernel().clone(), diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 95929813a7..ce8b28f172 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -297,7 +297,7 @@ impl Test { .with_libraries(self.libraries.iter()) .expect("failed to load stdlib"); - assembler.assemble_program(self.source.clone()) + assembler.assemble(self.source.clone()) } /// Compiles the test's source to a Program and executes it with the tests inputs. Returns a