diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ea06b51be..2f5eb5f46c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Added error codes support for the `mtree_verify` instruction (#1328). - Added support for immediate values for `lt`, `lte`, `gt`, `gte` comparison instructions (#1346). +- Change MAST to a table-based representation (#1349) ## 0.9.2 (2024-05-22) - `stdlib` crate only - Skip writing MASM documentation to file when building on docs.rs (#1341). diff --git a/air/src/constraints/chiplets/hasher/mod.rs b/air/src/constraints/chiplets/hasher/mod.rs index 8df5fca0cb..9e2a3a2dac 100644 --- a/air/src/constraints/chiplets/hasher/mod.rs +++ b/air/src/constraints/chiplets/hasher/mod.rs @@ -100,8 +100,8 @@ pub fn get_transition_constraint_count() -> usize { /// Enforces constraints for the hasher chiplet. /// -/// - The `hasher_flag` determines if the hasher chiplet is currently enabled. It should be -/// computed by the caller and set to `Felt::ONE` +/// - The `hasher_flag` determines if the hasher chiplet is currently enabled. It should be computed +/// by the caller and set to `Felt::ONE` /// - The `transition_flag` indicates whether this is the last row this chiplet's execution trace, /// and therefore the constraints should not be enforced. pub fn enforce_constraints>( diff --git a/air/src/constraints/chiplets/memory/tests.rs b/air/src/constraints/chiplets/memory/tests.rs index 48f096f170..2d18fc7dfc 100644 --- a/air/src/constraints/chiplets/memory/tests.rs +++ b/air/src/constraints/chiplets/memory/tests.rs @@ -106,7 +106,7 @@ enum MemoryTestDeltaType { /// - To test a valid write, the MemoryTestDeltaType must be Context or Address and the `old_values` /// and `new_values` must change. /// - To test a valid read, the `delta_type` must be Clock and the `old_values` and `new_values` -/// must be equal. +/// must be equal. fn get_constraint_evaluation( selectors: Selectors, delta_type: MemoryTestDeltaType, diff --git a/air/src/constraints/stack/field_ops/mod.rs b/air/src/constraints/stack/field_ops/mod.rs index f185e255c9..74871b1bc0 100644 --- a/air/src/constraints/stack/field_ops/mod.rs +++ b/air/src/constraints/stack/field_ops/mod.rs @@ -187,8 +187,8 @@ pub fn enforce_incr_constraints( /// in the stack with its bitwise not value. Therefore, the following constraints are /// enforced: /// - The top element should be a binary. It is enforced as a general constraint. -/// - The first element of the next frame should be a binary not of the first element of -/// the current frame. s0` + s0 = 1. +/// - The first element of the next frame should be a binary not of the first element of the current +/// frame. s0` + s0 = 1. pub fn enforce_not_constraints( frame: &EvaluationFrame, result: &mut [E], @@ -206,8 +206,8 @@ pub fn enforce_not_constraints( /// Enforces constraints of the AND operation. The AND operation computes the bitwise and of the /// first two elements in the current trace. Therefore, the following constraints are enforced: -/// - The top two element in the current frame of the stack should be binary. s0^2 - s0 = 0, -/// s1^2 - s1 = 0. The top element is binary or not is enforced as a general constraint. +/// - The top two element in the current frame of the stack should be binary. s0^2 - s0 = 0, s1^2 - +/// s1 = 0. The top element is binary or not is enforced as a general constraint. /// - The first element of the next frame should be a binary and of the first two elements in the /// current frame. s0` - s0 * s1 = 0. pub fn enforce_and_constraints( @@ -233,8 +233,8 @@ pub fn enforce_and_constraints( /// Enforces constraints of the OR operation. The OR operation computes the bitwise or of the /// first two elements in the current trace. Therefore, the following constraints are enforced: -/// - The top two element in the current frame of the stack should be binary. s0^2 - s0 = 0, -/// s1^2 - s1 = 0. The top element is binary or not is enforced as a general constraint. +/// - The top two element in the current frame of the stack should be binary. s0^2 - s0 = 0, s1^2 - +/// s1 = 0. The top element is binary or not is enforced as a general constraint. /// - The first element of the next frame should be a binary or of the first two elements in the /// current frame. s0` - ( s0 + s1 - s0 * s1 ) = 0. pub fn enforce_or_constraints( diff --git a/air/src/constraints/stack/overflow/mod.rs b/air/src/constraints/stack/overflow/mod.rs index 4cc319711c..7a54be3136 100644 --- a/air/src/constraints/stack/overflow/mod.rs +++ b/air/src/constraints/stack/overflow/mod.rs @@ -92,8 +92,7 @@ pub fn enforce_stack_depth_constraints( /// Enforces constraints on the overflow flag h0. Therefore, the following constraints /// are enforced: -/// - If overflow table has values, then, h0 should be set to ONE, otherwise it should -/// be ZERO. +/// - If overflow table has values, then, h0 should be set to ONE, otherwise it should be ZERO. pub fn enforce_overflow_flag_constraints( frame: &EvaluationFrame, result: &mut [E], @@ -107,8 +106,8 @@ pub fn enforce_overflow_flag_constraints( } /// Enforces constraints on the bookkeeping index `b1`. The following constraints are enforced: -/// - In the case of a right shift operation, the next b1 index should be updated with current -/// `clk` value. +/// - In the case of a right shift operation, the next b1 index should be updated with current `clk` +/// value. /// - In the case of a left shift operation, the last stack item should be set to ZERO when the /// depth of the stack is 16. pub fn enforce_overflow_index_constraints( diff --git a/air/src/constraints/stack/stack_manipulation/mod.rs b/air/src/constraints/stack/stack_manipulation/mod.rs index 0678fe615e..89359b6c1a 100644 --- a/air/src/constraints/stack/stack_manipulation/mod.rs +++ b/air/src/constraints/stack/stack_manipulation/mod.rs @@ -93,8 +93,8 @@ pub fn enforce_pad_constraints( /// Enforces constraints of the DUPn and MOVUPn operations. The DUPn operation copies the element /// at depth n in the stack and pushes the copy onto the stack, whereas MOVUPn opearation moves the /// element at depth n to the top of the stack. Therefore, the following constraints are enforced: -/// - The top element in the next frame should be equal to the element at depth n in the -/// current frame. s0` - sn = 0. +/// - The top element in the next frame should be equal to the element at depth n in the current +/// frame. s0` - sn = 0. pub fn enforce_dup_movup_n_constraints( frame: &EvaluationFrame, result: &mut [E], @@ -244,8 +244,8 @@ pub fn enforce_swapwx_constraints( /// Enforces constraints of the MOVDNn operation. The MOVDNn operation moves the top element /// to depth n in the stack. Therefore, the following constraints are enforced: -/// - The top element in the current frame should be equal to the element at depth n in the -/// next frame. s0 - sn` = 0. +/// - The top element in the current frame should be equal to the element at depth n in the next +/// frame. s0 - sn` = 0. pub fn enforce_movdnn_constraints( frame: &EvaluationFrame, result: &mut [E], diff --git a/air/src/constraints/stack/u32_ops/mod.rs b/air/src/constraints/stack/u32_ops/mod.rs index 5f26692fdf..3604b43873 100644 --- a/air/src/constraints/stack/u32_ops/mod.rs +++ b/air/src/constraints/stack/u32_ops/mod.rs @@ -119,8 +119,8 @@ pub fn enforce_u32split_constraints>( /// Enforces constraints of the U32ADD operation. The U32ADD operation adds the top two /// elements in the current trace of the stack. Therefore, the following constraints are /// enforced: -/// - The aggregation of limbs from the helper registers is equal to the sum of the top two -/// element in the stack. +/// - The aggregation of limbs from the helper registers is equal to the sum of the top two element +/// in the stack. pub fn enforce_u32add_constraints>( frame: &EvaluationFrame, result: &mut [E], diff --git a/air/src/options.rs b/air/src/options.rs index 2798a75d24..3062a11afe 100644 --- a/air/src/options.rs +++ b/air/src/options.rs @@ -227,8 +227,8 @@ impl ExecutionOptions { /// /// In debug mode the VM does the following: /// - Executes `debug` instructions (these are ignored in regular mode). - /// - Records additional info about program execution (e.g., keeps track of stack state at - /// every cycle of the VM) which enables stepping through the program forward and backward. + /// - Records additional info about program execution (e.g., keeps track of stack state at every + /// cycle of the VM) which enables stepping through the program forward and backward. pub fn with_debugging(mut self) -> Self { self.enable_debugging = true; self diff --git a/assembly/src/assembler/span_builder.rs b/assembly/src/assembler/basic_block_builder.rs similarity index 78% rename from assembly/src/assembler/span_builder.rs rename to assembly/src/assembler/basic_block_builder.rs index 349b5d1393..24889ef3b7 100644 --- a/assembly/src/assembler/span_builder.rs +++ b/assembly/src/assembler/basic_block_builder.rs @@ -1,8 +1,11 @@ use super::{AssemblyContext, BodyWrapper, Decorator, DecoratorList, Instruction}; use alloc::{borrow::Borrow, string::ToString, vec::Vec}; -use vm_core::{code_blocks::CodeBlock, AdviceInjector, AssemblyOp, Operation}; +use vm_core::{ + mast::{MastForest, MastNode, MastNodeId}, + AdviceInjector, AssemblyOp, Operation, +}; -// SPAN BUILDER +// BASIC BLOCK BUILDER // ================================================================================================ /// A helper struct for constructing SPAN blocks while compiling procedure bodies. @@ -13,7 +16,7 @@ use vm_core::{code_blocks::CodeBlock, AdviceInjector, AssemblyOp, Operation}; /// The same span builder can be used to construct many blocks. It is expected that when the last /// SPAN block in a procedure's body is constructed `extract_final_span_into()` will be used. #[derive(Default)] -pub struct SpanBuilder { +pub struct BasicBlockBuilder { ops: Vec, decorators: DecoratorList, epilogue: Vec, @@ -21,7 +24,7 @@ pub struct SpanBuilder { } /// Constructors -impl SpanBuilder { +impl BasicBlockBuilder { /// Returns a new [SpanBuilder] instantiated with the specified optional wrapper. /// /// If the wrapper is provided, the prologue of the wrapper is immediately appended to the @@ -41,7 +44,7 @@ impl SpanBuilder { } /// Operations -impl SpanBuilder { +impl BasicBlockBuilder { /// Adds the specified operation to the list of span operations. pub fn push_op(&mut self, op: Operation) { self.ops.push(op); @@ -64,8 +67,8 @@ impl SpanBuilder { } /// Decorators -impl SpanBuilder { - /// Add ths specified decorator to the list of span decorators. +impl BasicBlockBuilder { + /// Add the specified decorator to the list of span decorators. pub fn push_decorator(&mut self, decorator: Decorator) { self.decorators.push((self.ops.len(), decorator)); } @@ -114,34 +117,40 @@ impl SpanBuilder { } /// Span Constructors -impl SpanBuilder { - /// Creates a new SPAN block from the operations and decorators currently in this builder and - /// appends the block to the provided target. +impl BasicBlockBuilder { + /// Creates and returns a new BASIC BLOCK node from the operations and decorators currently in + /// this builder. If the builder is empty, then no node is created and `None` is returned. /// /// This consumes all operations and decorators in the builder, but does not touch the /// operations in the epilogue of the builder. - pub fn extract_span_into(&mut self, target: &mut Vec) { + pub fn make_basic_block(&mut self, mast_forest: &mut MastForest) -> Option { if !self.ops.is_empty() { let ops = self.ops.drain(..).collect(); let decorators = self.decorators.drain(..).collect(); - target.push(CodeBlock::new_span_with_decorators(ops, decorators)); + + let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators); + let basic_block_node_id = mast_forest.ensure_node(basic_block_node); + + Some(basic_block_node_id) } else if !self.decorators.is_empty() { // this is a bug in the assembler. we shouldn't have decorators added without their // associated operations // TODO: change this to an error or allow decorators in empty span blocks unreachable!("decorators in an empty SPAN block") + } else { + None } } - /// Creates a new SPAN block from the operations and decorators currently in this builder and - /// appends the block to the provided target. + /// Creates and returns a new BASIC BLOCK node from the operations and decorators currently in + /// this builder. If the builder is empty, then no node is created and `None` is returned. /// - /// The main differences from the `extract_span_int()` method above are: - /// - Operations contained in the epilogue of the span builder are appended to the list of ops - /// which go into the new SPAN block. - /// - The span builder is consumed in the process. - pub fn extract_final_span_into(mut self, target: &mut Vec) { + /// The main differences with [`Self::to_basic_block`] are: + /// - Operations contained in the epilogue of the builder are appended to the list of ops which + /// go into the new BASIC BLOCK node. + /// - The builder is consumed in the process. + pub fn into_basic_block(mut self, mast_forest: &mut MastForest) -> Option { self.ops.append(&mut self.epilogue); - self.extract_span_into(target); + self.make_basic_block(mast_forest) } } diff --git a/assembly/src/assembler/context.rs b/assembly/src/assembler/context.rs index 2ca4d53c67..aa828afbe2 100644 --- a/assembly/src/assembler/context.rs +++ b/assembly/src/assembler/context.rs @@ -6,7 +6,7 @@ use crate::{ diagnostics::SourceFile, AssemblyError, LibraryPath, RpoDigest, SourceSpan, Span, Spanned, }; -use vm_core::code_blocks::CodeBlock; +use vm_core::mast::{MastForest, MastNodeId}; // ASSEMBLY CONTEXT // ================================================================================================ @@ -168,6 +168,7 @@ impl AssemblyContext { &mut self, callee: &Procedure, inlined: bool, + mast_forest: &MastForest, ) -> Result<(), AssemblyError> { let context = self.unwrap_current_procedure_mut(); @@ -176,7 +177,7 @@ impl AssemblyContext { // If the callee is not being inlined, add it to our callset if !inlined { - context.insert_callee(callee.mast_root()); + context.insert_callee(callee.mast_root(mast_forest)); } Ok(()) @@ -264,11 +265,12 @@ impl ProcedureContext { self.visibility.is_syscall() } - pub fn into_procedure(self, code: CodeBlock) -> Box { - let procedure = Procedure::new(self.name, self.visibility, self.num_locals as u32, code) - .with_span(self.span) - .with_source_file(self.source_file) - .with_callset(self.callset); + pub fn into_procedure(self, body_node_id: MastNodeId) -> Box { + let procedure = + Procedure::new(self.name, self.visibility, self.num_locals as u32, body_node_id) + .with_span(self.span) + .with_source_file(self.source_file) + .with_callset(self.callset); Box::new(procedure) } } diff --git a/assembly/src/assembler/instruction/adv_ops.rs b/assembly/src/assembler/instruction/adv_ops.rs index a361bdf500..ec434557b4 100644 --- a/assembly/src/assembler/instruction/adv_ops.rs +++ b/assembly/src/assembler/instruction/adv_ops.rs @@ -1,4 +1,4 @@ -use super::{validate_param, SpanBuilder}; +use super::{validate_param, BasicBlockBuilder}; use crate::{ast::AdviceInjectorNode, AssemblyError, ADVICE_READ_LIMIT}; use vm_core::Operation; @@ -12,7 +12,7 @@ use vm_core::Operation; /// # Errors /// Returns an error if the specified number of values to pushed is smaller than 1 or greater /// than 16. -pub fn adv_push(span: &mut SpanBuilder, n: u8) -> Result<(), AssemblyError> { +pub fn adv_push(span: &mut BasicBlockBuilder, n: u8) -> Result<(), AssemblyError> { validate_param(n, 1..=ADVICE_READ_LIMIT)?; span.push_op_many(Operation::AdvPop, n as usize); Ok(()) @@ -22,6 +22,6 @@ pub fn adv_push(span: &mut SpanBuilder, n: u8) -> Result<(), AssemblyError> { // ================================================================================================ /// Appends advice injector decorator to the span. -pub fn adv_inject(span: &mut SpanBuilder, injector: &AdviceInjectorNode) { +pub fn adv_inject(span: &mut BasicBlockBuilder, injector: &AdviceInjectorNode) { span.push_advice_injector(injector.into()); } diff --git a/assembly/src/assembler/instruction/crypto_ops.rs b/assembly/src/assembler/instruction/crypto_ops.rs index 7f657a6273..68a5cfd791 100644 --- a/assembly/src/assembler/instruction/crypto_ops.rs +++ b/assembly/src/assembler/instruction/crypto_ops.rs @@ -1,4 +1,4 @@ -use super::SpanBuilder; +use super::BasicBlockBuilder; use vm_core::{AdviceInjector, Operation::*}; // HASHING @@ -23,7 +23,7 @@ use vm_core::{AdviceInjector, Operation::*}; /// 3. Drop D and B to achieve our result [C, ...] /// /// This operation takes 20 VM cycles. -pub(super) fn hash(span: &mut SpanBuilder) { +pub(super) fn hash(span: &mut BasicBlockBuilder) { #[rustfmt::skip] let ops = [ // add 4 elements to the stack to be used as the capacity elements for the RPO permutation @@ -69,7 +69,7 @@ pub(super) fn hash(span: &mut SpanBuilder) { /// 4. Drop F and D to return our result [E, ...]. /// /// This operation takes 16 VM cycles. -pub(super) fn hmerge(span: &mut SpanBuilder) { +pub(super) fn hmerge(span: &mut BasicBlockBuilder) { #[rustfmt::skip] let ops = [ // Add 4 elements to the stack to prepare the capacity portion for the RPO permutation @@ -110,7 +110,7 @@ pub(super) fn hmerge(span: &mut SpanBuilder) { /// - root of the tree, 4 elements. /// /// This operation takes 9 VM cycles. -pub(super) fn mtree_get(span: &mut SpanBuilder) { +pub(super) fn mtree_get(span: &mut BasicBlockBuilder) { // stack: [d, i, R, ...] // pops the value of the node we are looking for from the advice stack read_mtree_node(span); @@ -140,7 +140,7 @@ pub(super) fn mtree_get(span: &mut SpanBuilder) { /// - new root of the tree after the update, 4 elements /// /// This operation takes 29 VM cycles. -pub(super) fn mtree_set(span: &mut SpanBuilder) { +pub(super) fn mtree_set(span: &mut BasicBlockBuilder) { // stack: [d, i, R_old, V_new, ...] // stack: [V_old, R_new, ...] (29 cycles) @@ -160,7 +160,7 @@ pub(super) fn mtree_set(span: &mut SpanBuilder) { /// It is not checked whether the provided roots exist as Merkle trees in the advide providers. /// /// This operation takes 16 VM cycles. -pub(super) fn mtree_merge(span: &mut SpanBuilder) { +pub(super) fn mtree_merge(span: &mut BasicBlockBuilder) { // stack input: [R_rhs, R_lhs, ...] // stack output: [R_merged, ...] @@ -192,7 +192,7 @@ pub(super) fn mtree_merge(span: &mut SpanBuilder) { /// - new value of the node, 4 elements (only in the case of mtree_set) /// /// This operation takes 4 VM cycles. -fn read_mtree_node(span: &mut SpanBuilder) { +fn read_mtree_node(span: &mut BasicBlockBuilder) { // The stack should be arranged in the following way: [d, i, R, ...] so that the decorator // can fetch the node value from the root. In the `mtree.get` operation we have the stack in // the following format: [d, i, R], whereas in the case of `mtree.set` we would also have the @@ -210,7 +210,7 @@ fn read_mtree_node(span: &mut SpanBuilder) { /// and perform the mutation on the copied tree. /// /// This operation takes 29 VM cycles. -fn update_mtree(span: &mut SpanBuilder) { +fn update_mtree(span: &mut BasicBlockBuilder) { // stack: [d, i, R_old, V_new, ...] // output: [R_new, R_old, V_new, V_old, ...] diff --git a/assembly/src/assembler/instruction/env_ops.rs b/assembly/src/assembler/instruction/env_ops.rs index 0818602496..900b5a39af 100644 --- a/assembly/src/assembler/instruction/env_ops.rs +++ b/assembly/src/assembler/instruction/env_ops.rs @@ -1,4 +1,4 @@ -use super::{mem_ops::local_to_absolute_addr, push_felt, AssemblyContext, SpanBuilder}; +use super::{mem_ops::local_to_absolute_addr, push_felt, AssemblyContext, BasicBlockBuilder}; use crate::{AssemblyError, Felt, Spanned}; use vm_core::Operation::*; @@ -10,7 +10,7 @@ use vm_core::Operation::*; /// In cases when the immediate value is 0, `PUSH` operation is replaced with `PAD`. Also, in cases /// when immediate value is 1, `PUSH` operation is replaced with `PAD INCR` because in most cases /// this will be more efficient than doing a `PUSH`. -pub fn push_one(imm: T, span: &mut SpanBuilder) +pub fn push_one(imm: T, span: &mut BasicBlockBuilder) where T: Into, { @@ -23,7 +23,7 @@ where /// In cases when the immediate value is 0, `PUSH` operation is replaced with `PAD`. Also, in cases /// when immediate value is 1, `PUSH` operation is replaced with `PAD INCR` because in most cases /// this will be more efficient than doing a `PUSH`. -pub fn push_many(imms: &[T], span: &mut SpanBuilder) +pub fn push_many(imms: &[T], span: &mut BasicBlockBuilder) where T: Into + Copy, { @@ -39,7 +39,7 @@ where /// # Errors /// Returns an error if index is greater than the number of procedure locals. pub fn locaddr( - span: &mut SpanBuilder, + span: &mut BasicBlockBuilder, index: u16, context: &AssemblyContext, ) -> Result<(), AssemblyError> { @@ -51,7 +51,10 @@ pub fn locaddr( /// /// # Errors /// Returns an error if the instruction is being executed outside of kernel context. -pub fn caller(span: &mut SpanBuilder, context: &AssemblyContext) -> Result<(), AssemblyError> { +pub fn caller( + span: &mut BasicBlockBuilder, + context: &AssemblyContext, +) -> Result<(), AssemblyError> { let current_procedure = context.unwrap_current_procedure(); if !current_procedure.is_kernel() { return Err(AssemblyError::CallerOutsideOfKernel { diff --git a/assembly/src/assembler/instruction/ext2_ops.rs b/assembly/src/assembler/instruction/ext2_ops.rs index 9bf8470731..35ec3e8710 100644 --- a/assembly/src/assembler/instruction/ext2_ops.rs +++ b/assembly/src/assembler/instruction/ext2_ops.rs @@ -1,4 +1,4 @@ -use super::SpanBuilder; +use super::BasicBlockBuilder; use vm_core::{AdviceInjector::Ext2Inv, Operation::*}; /// Given a stack in the following initial configuration [b1, b0, a1, a0, ...] where a = (a0, a1) @@ -6,7 +6,7 @@ use vm_core::{AdviceInjector::Ext2Inv, Operation::*}; /// operations outputs the result c = (c1, c0) where c1 = a1 + b1 and c0 = a0 + b0. /// /// This operation takes 5 VM cycles. -pub fn ext2_add(span: &mut SpanBuilder) { +pub fn ext2_add(span: &mut BasicBlockBuilder) { #[rustfmt::skip] let ops = [ Swap, // [b0, b1, a1, a0, ...] @@ -23,7 +23,7 @@ pub fn ext2_add(span: &mut SpanBuilder) { /// operations outputs the result c = (c1, c0) where c1 = a1 - b1 and c0 = a0 - b0. /// /// This operation takes 7 VM cycles. -pub fn ext2_sub(span: &mut SpanBuilder) { +pub fn ext2_sub(span: &mut BasicBlockBuilder) { #[rustfmt::skip] let ops = [ Neg, // [-b1, b0, a1, a0, ...] @@ -42,7 +42,7 @@ pub fn ext2_sub(span: &mut SpanBuilder) { /// outputs the product c = (c1, c0) where c0 = a0b0 - 2(a1b1) and c1 = (a0 + a1)(b0 + b1) - a0b0 /// /// This operation takes 3 VM cycles. -pub fn ext2_mul(span: &mut SpanBuilder) { +pub fn ext2_mul(span: &mut BasicBlockBuilder) { span.push_ops([Ext2Mul, Drop, Drop]); } @@ -51,7 +51,7 @@ pub fn ext2_mul(span: &mut SpanBuilder) { /// operations outputs the result c = (c1, c0) where c = a * b^-1. /// /// This operation takes 11 VM cycles. -pub fn ext2_div(span: &mut SpanBuilder) { +pub fn ext2_div(span: &mut BasicBlockBuilder) { span.push_advice_injector(Ext2Inv); #[rustfmt::skip] let ops = [ @@ -75,7 +75,7 @@ pub fn ext2_div(span: &mut SpanBuilder) { /// [-a1, -a0, ...] /// /// This operation takes 4 VM cycles. -pub fn ext2_neg(span: &mut SpanBuilder) { +pub fn ext2_neg(span: &mut BasicBlockBuilder) { #[rustfmt::skip] let ops = [ Neg, // [a1, a0, ...] @@ -111,7 +111,7 @@ pub fn ext2_neg(span: &mut SpanBuilder) { /// assert b = (1, 0) | (1, 0) is the multiplicative identity of extension field. /// /// This operation takes 8 VM cycles. -pub fn ext2_inv(span: &mut SpanBuilder) { +pub fn ext2_inv(span: &mut BasicBlockBuilder) { span.push_advice_injector(Ext2Inv); #[rustfmt::skip] let ops = [ diff --git a/assembly/src/assembler/instruction/field_ops.rs b/assembly/src/assembler/instruction/field_ops.rs index 38db35980e..8c7fabc90e 100644 --- a/assembly/src/assembler/instruction/field_ops.rs +++ b/assembly/src/assembler/instruction/field_ops.rs @@ -1,4 +1,4 @@ -use super::{validate_param, AssemblyContext, SpanBuilder}; +use super::{validate_param, AssemblyContext, BasicBlockBuilder}; use crate::{ diagnostics::{RelatedError, Report}, AssemblyError, Felt, Span, MAX_EXP_BITS, ONE, ZERO, @@ -14,7 +14,7 @@ const TWO: Felt = Felt::new(2); /// Asserts that the top two words in the stack are equal. /// /// VM cycles: 11 cycles -pub fn assertw(span_builder: &mut SpanBuilder, err_code: u32) { +pub fn assertw(span_builder: &mut BasicBlockBuilder, err_code: u32) { span_builder.push_ops([ MovUp4, Eq, @@ -39,7 +39,7 @@ pub fn assertw(span_builder: &mut SpanBuilder, err_code: u32) { /// - else if imm = 1: INCR /// - else if imm = 2: INCR INCR /// - otherwise: PUSH(imm) ADD -pub fn add_imm(span_builder: &mut SpanBuilder, imm: Felt) { +pub fn add_imm(span_builder: &mut BasicBlockBuilder, imm: Felt) { if imm == ZERO { span_builder.push_op(Noop); } else if imm == ONE { @@ -55,7 +55,7 @@ pub fn add_imm(span_builder: &mut SpanBuilder, imm: Felt) { /// stack. Specifically, the sequences are: /// - if imm = 0: NOOP /// - otherwise: PUSH(-imm) ADD -pub fn sub_imm(span_builder: &mut SpanBuilder, imm: Felt) { +pub fn sub_imm(span_builder: &mut BasicBlockBuilder, imm: Felt) { if imm == ZERO { span_builder.push_op(Noop); } else { @@ -68,7 +68,7 @@ pub fn sub_imm(span_builder: &mut SpanBuilder, imm: Felt) { /// - if imm = 0: DROP PAD /// - else if imm = 1: NOOP /// - otherwise: PUSH(imm) MUL -pub fn mul_imm(span_builder: &mut SpanBuilder, imm: Felt) { +pub fn mul_imm(span_builder: &mut BasicBlockBuilder, imm: Felt) { if imm == ZERO { span_builder.push_ops([Drop, Pad]); } else if imm == ONE { @@ -87,7 +87,7 @@ pub fn mul_imm(span_builder: &mut SpanBuilder, imm: Felt) { /// # Errors /// Returns an error if the immediate value is ZERO. pub fn div_imm( - span_builder: &mut SpanBuilder, + span_builder: &mut BasicBlockBuilder, ctx: &mut AssemblyContext, imm: Span, ) -> Result<(), AssemblyError> { @@ -114,14 +114,14 @@ pub fn div_imm( /// top of the stack. /// /// VM cycles: 16 cycles -pub fn pow2(span_builder: &mut SpanBuilder) { +pub fn pow2(span_builder: &mut BasicBlockBuilder) { append_pow2_op(span_builder); } /// Appends relevant operations to the span_builder block for the computation of power of 2. /// /// VM cycles: 16 cycles -pub fn append_pow2_op(span_builder: &mut SpanBuilder) { +pub fn append_pow2_op(span_builder: &mut BasicBlockBuilder) { // push base 2 onto the stack: [exp, ...] -> [2, exp, ...] span_builder.push_op(Push(2_u8.into())); // introduce initial value of acc onto the stack: [2, exp, ...] -> [1, 2, exp, ...] @@ -149,7 +149,7 @@ pub fn append_pow2_op(span_builder: &mut SpanBuilder) { /// /// # Errors /// Returns an error if num_pow_bits is greater than 64. -pub fn exp(span_builder: &mut SpanBuilder, num_pow_bits: u8) -> Result<(), AssemblyError> { +pub fn exp(span_builder: &mut BasicBlockBuilder, num_pow_bits: u8) -> Result<(), AssemblyError> { validate_param(num_pow_bits, 0..=MAX_EXP_BITS)?; // arranging the stack to prepare it for expacc instruction. @@ -178,7 +178,7 @@ pub fn exp(span_builder: &mut SpanBuilder, num_pow_bits: u8) -> Result<(), Assem /// - pow = 6: 10 cycles /// - pow = 7: 12 cycles /// - pow > 7: 9 + Ceil(log2(pow)) -pub fn exp_imm(span_builder: &mut SpanBuilder, pow: Felt) -> Result<(), AssemblyError> { +pub fn exp_imm(span_builder: &mut BasicBlockBuilder, pow: Felt) -> Result<(), AssemblyError> { if pow.as_int() <= 7 { perform_exp_for_small_power(span_builder, pow.as_int()); Ok(()) @@ -210,7 +210,7 @@ pub fn exp_imm(span_builder: &mut SpanBuilder, pow: Felt) -> Result<(), Assembly /// - pow = 5: 8 cycles /// - pow = 6: 10 cycles /// - pow = 7: 12 cycles -fn perform_exp_for_small_power(span_builder: &mut SpanBuilder, pow: u64) { +fn perform_exp_for_small_power(span_builder: &mut BasicBlockBuilder, pow: u64) { match pow { 0 => { span_builder.push_op(Drop); @@ -256,7 +256,7 @@ fn perform_exp_for_small_power(span_builder: &mut SpanBuilder, pow: u64) { /// /// # Errors /// Returns an error if the logarithm argument (top stack element) equals ZERO. -pub fn ilog2(span: &mut SpanBuilder) { +pub fn ilog2(span: &mut BasicBlockBuilder) { span.push_advice_injector(AdviceInjector::ILog2); span.push_op(AdvPop); // [ilog2, n, ...] @@ -296,7 +296,7 @@ pub fn ilog2(span: &mut SpanBuilder) { /// and the provided immediate value. Specifically, the sequences are: /// - if imm = 0: EQZ /// - otherwise: PUSH(imm) EQ -pub fn eq_imm(span_builder: &mut SpanBuilder, imm: Felt) { +pub fn eq_imm(span_builder: &mut BasicBlockBuilder, imm: Felt) { if imm == ZERO { span_builder.push_op(Eqz); } else { @@ -308,7 +308,7 @@ pub fn eq_imm(span_builder: &mut SpanBuilder, imm: Felt) { /// and the provided immediate value. Specifically, the sequences are: /// - if imm = 0: EQZ NOT /// - otherwise: PUSH(imm) EQ NOT -pub fn neq_imm(span_builder: &mut SpanBuilder, imm: Felt) { +pub fn neq_imm(span_builder: &mut BasicBlockBuilder, imm: Felt) { if imm == ZERO { span_builder.push_ops([Eqz, Not]); } else { @@ -319,7 +319,7 @@ pub fn neq_imm(span_builder: &mut SpanBuilder, imm: Felt) { /// Appends a sequence of operations to check equality between two words at the top of the stack. /// /// This operation takes 15 VM cycles. -pub fn eqw(span_builder: &mut SpanBuilder) { +pub fn eqw(span_builder: &mut BasicBlockBuilder) { span_builder.push_ops([ // duplicate first pair of for comparison(4th elements of each word) in reverse order // to avoid using dup.8 after stack shifting(dup.X where X > 7, takes more VM cycles ) @@ -334,7 +334,7 @@ pub fn eqw(span_builder: &mut SpanBuilder) { /// of 1 is pushed onto the stack if a < b. Otherwise, 0 is pushed. /// /// This operation takes 14 VM cycles. -pub fn lt(span_builder: &mut SpanBuilder) { +pub fn lt(span_builder: &mut BasicBlockBuilder) { // Split both elements into high and low bits // 3 cycles split_elements(span_builder); @@ -358,9 +358,9 @@ pub fn lt(span_builder: &mut SpanBuilder) { /// (from the top). A value of 1 is pushed onto the stack if a < imm. Otherwise, 0 is pushed. /// /// This operation takes 15 VM cycles. -pub fn lt_imm(span_builder: &mut SpanBuilder, imm: Felt) { - span_builder.push_op(Push(imm)); - lt(span_builder); +pub fn lt_imm(basic_block_builder: &mut BasicBlockBuilder, imm: Felt) { + basic_block_builder.push_op(Push(imm)); + lt(basic_block_builder); } /// Appends a sequence of operations to pop the top 2 elements off the stack and do a "less @@ -368,7 +368,7 @@ pub fn lt_imm(span_builder: &mut SpanBuilder, imm: Felt) { /// A value of 1 is pushed onto the stack if a <= b. Otherwise, 0 is pushed. /// /// This operation takes 15 VM cycles. -pub fn lte(span_builder: &mut SpanBuilder) { +pub fn lte(span_builder: &mut BasicBlockBuilder) { // Split both elements into high and low bits // 3 cycles split_elements(span_builder); @@ -393,7 +393,7 @@ pub fn lte(span_builder: &mut SpanBuilder) { /// pushed. /// /// This operation takes 16 VM cycles. -pub fn lte_imm(span_builder: &mut SpanBuilder, imm: Felt) { +pub fn lte_imm(span_builder: &mut BasicBlockBuilder, imm: Felt) { span_builder.push_op(Push(imm)); lte(span_builder); } @@ -403,7 +403,7 @@ pub fn lte_imm(span_builder: &mut SpanBuilder, imm: Felt) { /// of 1 is pushed onto the stack if a > b. Otherwise, 0 is pushed. /// /// This operation takes 15 VM cycles. -pub fn gt(span_builder: &mut SpanBuilder) { +pub fn gt(span_builder: &mut BasicBlockBuilder) { // Split both elements into high and low bits // 3 cycles split_elements(span_builder); @@ -427,7 +427,7 @@ pub fn gt(span_builder: &mut SpanBuilder) { /// (from the top). A value of 1 is pushed onto the stack if a > imm. Otherwise, 0 is pushed. /// /// This operation takes 16 VM cycles. -pub fn gt_imm(span_builder: &mut SpanBuilder, imm: Felt) { +pub fn gt_imm(span_builder: &mut BasicBlockBuilder, imm: Felt) { span_builder.push_op(Push(imm)); gt(span_builder); } @@ -437,7 +437,7 @@ pub fn gt_imm(span_builder: &mut SpanBuilder, imm: Felt) { /// A value of 1 is pushed onto the stack if a >= b. Otherwise, 0 is pushed. /// /// This operation takes 16 VM cycles. -pub fn gte(span_builder: &mut SpanBuilder) { +pub fn gte(span_builder: &mut BasicBlockBuilder) { // Split both elements into high and low bits // 3 cycles split_elements(span_builder); @@ -462,7 +462,7 @@ pub fn gte(span_builder: &mut SpanBuilder) { /// pushed. /// /// This operation takes 17 VM cycles. -pub fn gte_imm(span_builder: &mut SpanBuilder, imm: Felt) { +pub fn gte_imm(span_builder: &mut BasicBlockBuilder, imm: Felt) { span_builder.push_op(Push(imm)); gte(span_builder); } @@ -470,7 +470,7 @@ pub fn gte_imm(span_builder: &mut SpanBuilder, imm: Felt) { /// Checks if the top element in the stack is an odd number or not. /// /// Vm cycles: 5 -pub fn is_odd(span_builder: &mut SpanBuilder) { +pub fn is_odd(span_builder: &mut BasicBlockBuilder) { span_builder.push_ops([U32split, Drop, Pad, Incr, U32and]); } @@ -483,7 +483,7 @@ pub fn is_odd(span_builder: &mut SpanBuilder) { /// After these operations, the stack state will be: [a_hi, a_lo, b_hi, b_lo, ...]. /// /// This operation takes 3 cycles. -fn split_elements(span_builder: &mut SpanBuilder) { +fn split_elements(span_builder: &mut BasicBlockBuilder) { // stack: [b, a, ...] => [b_hi, b_lo, a, ...] span_builder.push_op(U32split); // => [a, b_hi, b_lo, ...] @@ -502,7 +502,7 @@ fn split_elements(span_builder: &mut SpanBuilder) { /// The resulting stack after this operation is: [eq_flag, lt_flag, ...]. /// /// This operation takes 3 cycles. -fn check_lt_and_eq(span_builder: &mut SpanBuilder) { +fn check_lt_and_eq(span_builder: &mut BasicBlockBuilder) { // calculate a - b // stack: [b, a, ...] => [underflow_flag, result, ...] span_builder.push_op(U32sub); @@ -536,7 +536,7 @@ fn check_lt_and_eq(span_builder: &mut SpanBuilder) { /// - hi_flag_lt: 1 if a's high-bit values were less than b's (a_hi < b_hi); 0 otherwise /// /// This operation takes 6 cycles. -fn check_lt_high_bits(span_builder: &mut SpanBuilder) { +fn check_lt_high_bits(span_builder: &mut BasicBlockBuilder) { // reorder the stack to check a_hi < b_hi span_builder.push_op(MovUp2); @@ -558,7 +558,7 @@ fn check_lt_high_bits(span_builder: &mut SpanBuilder) { /// condition will be true if the underflow flag is set. /// /// This operation takes 3 cycles. -fn check_lt(span_builder: &mut SpanBuilder) { +fn check_lt(span_builder: &mut BasicBlockBuilder) { // calculate a - b // stack: [b, a, ...] => [underflow_flag, result, ...] span_builder.push_op(U32sub); @@ -580,7 +580,7 @@ fn check_lt(span_builder: &mut SpanBuilder) { /// - high-bit comparison flag: 1 if the lt/gt condition being checked was true; 0 otherwise /// /// This function takes 2 cycles. -fn set_result(span_builder: &mut SpanBuilder) { +fn set_result(span_builder: &mut BasicBlockBuilder) { // check if high bits are equal AND low bit comparison condition was true span_builder.push_op(And); @@ -598,7 +598,7 @@ fn set_result(span_builder: &mut SpanBuilder) { /// there was no underflow and the result is 0. /// /// This function takes 4 cycles. -fn check_lte(span_builder: &mut SpanBuilder) { +fn check_lte(span_builder: &mut BasicBlockBuilder) { // calculate a - b // stack: [b, a, ...] => [underflow_flag, result, ...] span_builder.push_op(U32sub); @@ -627,7 +627,7 @@ fn check_lte(span_builder: &mut SpanBuilder) { /// - hi_flag_gt: 1 if a's high-bit values were greater than b's (a_hi > b_hi); 0 otherwise /// /// This function takes 7 cycles. -fn check_gt_high_bits(span_builder: &mut SpanBuilder) { +fn check_gt_high_bits(span_builder: &mut BasicBlockBuilder) { // reorder the stack to check b_hi < a_hi span_builder.push_ops([Swap, MovDn2]); diff --git a/assembly/src/assembler/instruction/mem_ops.rs b/assembly/src/assembler/instruction/mem_ops.rs index e6775ffa0a..733b438185 100644 --- a/assembly/src/assembler/instruction/mem_ops.rs +++ b/assembly/src/assembler/instruction/mem_ops.rs @@ -1,4 +1,4 @@ -use super::{push_felt, push_u32_value, validate_param, AssemblyContext, SpanBuilder}; +use super::{push_felt, push_u32_value, validate_param, AssemblyContext, BasicBlockBuilder}; use crate::AssemblyError; use vm_core::{Felt, Operation::*}; @@ -20,7 +20,7 @@ use vm_core::{Felt, Operation::*}; /// Returns an error if we are reading from local memory and local memory index is greater than /// the number of procedure locals. pub fn mem_read( - span: &mut SpanBuilder, + span: &mut BasicBlockBuilder, context: &AssemblyContext, addr: Option, is_local: bool, @@ -71,7 +71,7 @@ pub fn mem_read( /// Returns an error if we are writing to local memory and local memory index is greater than /// the number of procedure locals. pub fn mem_write_imm( - span: &mut SpanBuilder, + span: &mut BasicBlockBuilder, context: &AssemblyContext, addr: u32, is_local: bool, @@ -108,7 +108,7 @@ pub fn mem_write_imm( /// # Errors /// Returns an error if index is greater than the number of procedure locals. pub fn local_to_absolute_addr( - span: &mut SpanBuilder, + span: &mut BasicBlockBuilder, index: u16, num_proc_locals: u16, ) -> Result<(), AssemblyError> { diff --git a/assembly/src/assembler/instruction/mod.rs b/assembly/src/assembler/instruction/mod.rs index 0a7a69ef2d..461c774d9c 100644 --- a/assembly/src/assembler/instruction/mod.rs +++ b/assembly/src/assembler/instruction/mod.rs @@ -1,10 +1,13 @@ use super::{ - ast::InvokeKind, Assembler, AssemblyContext, CodeBlock, Felt, Instruction, Operation, - SpanBuilder, ONE, ZERO, + ast::InvokeKind, Assembler, AssemblyContext, BasicBlockBuilder, Felt, Instruction, Operation, + ONE, ZERO, }; use crate::{diagnostics::Report, utils::bound_into_included_u64, AssemblyError}; use core::ops::RangeBounds; -use vm_core::Decorator; +use vm_core::{ + mast::{MastForest, MastNodeId}, + Decorator, +}; mod adv_ops; mod crypto_ops; @@ -22,9 +25,10 @@ impl Assembler { pub(super) fn compile_instruction( &self, instruction: &Instruction, - span_builder: &mut SpanBuilder, + span_builder: &mut BasicBlockBuilder, ctx: &mut AssemblyContext, - ) -> Result, AssemblyError> { + mast_forest: &mut MastForest, + ) -> Result, AssemblyError> { // if the assembler is in debug mode, start tracking the instruction about to be executed; // this will allow us to map the instruction to the sequence of operations which were // executed as a part of this instruction. @@ -32,7 +36,7 @@ impl Assembler { span_builder.track_instruction(instruction, ctx); } - let result = self.compile_instruction_impl(instruction, span_builder, ctx)?; + let result = self.compile_instruction_impl(instruction, span_builder, ctx, mast_forest)?; // compute and update the cycle count of the instruction which just finished executing if self.in_debug_mode() { @@ -45,9 +49,10 @@ impl Assembler { fn compile_instruction_impl( &self, instruction: &Instruction, - span_builder: &mut SpanBuilder, + span_builder: &mut BasicBlockBuilder, ctx: &mut AssemblyContext, - ) -> Result, AssemblyError> { + mast_forest: &mut MastForest, + ) -> Result, AssemblyError> { use Operation::*; match instruction { @@ -363,14 +368,20 @@ impl Assembler { Instruction::RCombBase => span_builder.push_op(RCombBase), // ----- exec/call instructions ------------------------------------------------------- - Instruction::Exec(ref callee) => return self.invoke(InvokeKind::Exec, callee, ctx), - Instruction::Call(ref callee) => return self.invoke(InvokeKind::Call, callee, ctx), + Instruction::Exec(ref callee) => { + return self.invoke(InvokeKind::Exec, callee, ctx, mast_forest) + } + Instruction::Call(ref callee) => { + return self.invoke(InvokeKind::Call, callee, ctx, mast_forest) + } Instruction::SysCall(ref callee) => { - return self.invoke(InvokeKind::SysCall, callee, ctx) + return self.invoke(InvokeKind::SysCall, callee, ctx, mast_forest) + } + Instruction::DynExec => return self.dynexec(mast_forest), + Instruction::DynCall => return self.dyncall(mast_forest), + Instruction::ProcRef(ref callee) => { + self.procref(callee, ctx, span_builder, mast_forest)? } - Instruction::DynExec => return self.dynexec(), - Instruction::DynCall => return self.dyncall(), - Instruction::ProcRef(ref callee) => self.procref(callee, ctx, span_builder)?, // ----- debug decorators ------------------------------------------------------------- Instruction::Breakpoint => { @@ -411,7 +422,7 @@ impl Assembler { /// /// When the value is 0, PUSH operation is replaced with PAD. When the value is 1, PUSH operation /// is replaced with PAD INCR because in most cases this will be more efficient than doing a PUSH. -fn push_u32_value(span_builder: &mut SpanBuilder, value: u32) { +fn push_u32_value(span_builder: &mut BasicBlockBuilder, value: u32) { use Operation::*; if value == 0 { @@ -429,7 +440,7 @@ fn push_u32_value(span_builder: &mut SpanBuilder, value: u32) { /// /// When the value is 0, PUSH operation is replaced with PAD. When the value is 1, PUSH operation /// is replaced with PAD INCR because in most cases this will be more efficient than doing a PUSH. -fn push_felt(span_builder: &mut SpanBuilder, value: Felt) { +fn push_felt(span_builder: &mut BasicBlockBuilder, value: Felt) { use Operation::*; if value == ZERO { diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 1949d75e3a..9ac2bcfdd3 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -1,10 +1,11 @@ -use super::{Assembler, AssemblyContext, CodeBlock, Operation, SpanBuilder}; +use super::{Assembler, AssemblyContext, BasicBlockBuilder, Operation}; use crate::{ ast::{InvocationTarget, InvokeKind}, AssemblyError, RpoDigest, SourceSpan, Span, Spanned, }; use smallvec::SmallVec; +use vm_core::mast::{MastForest, MastNode, MastNodeId}; /// Procedure Invocation impl Assembler { @@ -13,10 +14,11 @@ impl Assembler { kind: InvokeKind, callee: &InvocationTarget, context: &mut AssemblyContext, - ) -> Result, AssemblyError> { + mast_forest: &mut MastForest, + ) -> Result, AssemblyError> { let span = callee.span(); - let digest = self.resolve_target(kind, callee, context)?; - self.invoke_mast_root(kind, span, digest, context) + let digest = self.resolve_target(kind, callee, context, mast_forest)?; + self.invoke_mast_root(kind, span, digest, context, mast_forest) } fn invoke_mast_root( @@ -25,7 +27,8 @@ impl Assembler { span: SourceSpan, mast_root: RpoDigest, context: &mut AssemblyContext, - ) -> Result, AssemblyError> { + mast_forest: &mut MastForest, + ) -> Result, AssemblyError> { // Get the procedure from the assembler let cache = &self.procedure_cache; let current_source_file = context.unwrap_current_procedure().source_file(); @@ -65,9 +68,9 @@ impl Assembler { }) } })?; - context.register_external_call(&proc, false)?; + context.register_external_call(&proc, false, mast_forest)?; } - Some(proc) => context.register_external_call(&proc, false)?, + Some(proc) => context.register_external_call(&proc, false, mast_forest)?, None if matches!(kind, InvokeKind::SysCall) => { return Err(AssemblyError::UnknownSysCallTarget { span, @@ -78,37 +81,71 @@ impl Assembler { None => context.register_phantom_call(Span::new(span, mast_root))?, } - let block = match kind { - // For `exec`, we use a PROXY block to reflect that the root is - // conceptually inlined at this location - InvokeKind::Exec => CodeBlock::new_proxy(mast_root), - // For `call`, we just use the corresponding CALL block - InvokeKind::Call => CodeBlock::new_call(mast_root), - // For `syscall`, we just use the corresponding SYSCALL block - InvokeKind::SysCall => CodeBlock::new_syscall(mast_root), + let mast_root_node_id = { + // Note that here we rely on the fact that we topologically sorted the procedures, such + // that when we assemble a procedure, all procedures that it calls will have been + // assembled, and hence be present in the `MastForest`. We currently assume that the + // `MastForest` contains all the procedures being called; "external procedures" only + // known by digest are not currently supported. + let callee_id = mast_forest + .get_node_id_by_digest(mast_root) + .unwrap_or_else(|| panic!("MAST root {} not in MAST forest", mast_root)); + + match kind { + // For `exec`, we return the root of the procedure being exec'd, which has the + // effect of inlining it + InvokeKind::Exec => callee_id, + // For `call`, we just use the corresponding CALL block + InvokeKind::Call => { + let node = MastNode::new_call(callee_id, mast_forest); + mast_forest.ensure_node(node) + } + // For `syscall`, we just use the corresponding SYSCALL block + InvokeKind::SysCall => { + let node = MastNode::new_syscall(callee_id, mast_forest); + mast_forest.ensure_node(node) + } + } }; - Ok(Some(block)) + + Ok(Some(mast_root_node_id)) } - pub(super) fn dynexec(&self) -> Result, AssemblyError> { - // create a new DYN block for the dynamic code execution and return - Ok(Some(CodeBlock::new_dyn())) + /// Creates a new DYN block for the dynamic code execution and return. + pub(super) fn dynexec( + &self, + mast_forest: &mut MastForest, + ) -> Result, AssemblyError> { + let dyn_node_id = mast_forest.ensure_node(MastNode::Dyn); + + Ok(Some(dyn_node_id)) } - pub(super) fn dyncall(&self) -> Result, AssemblyError> { - // create a new CALL block whose target is DYN - Ok(Some(CodeBlock::new_dyncall())) + /// Creates a new CALL block whose target is DYN. + pub(super) fn dyncall( + &self, + mast_forest: &mut MastForest, + ) -> Result, AssemblyError> { + let dyn_call_node_id = { + let dyn_node_id = mast_forest.ensure_node(MastNode::Dyn); + let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest); + + mast_forest.ensure_node(dyn_call_node) + }; + + Ok(Some(dyn_call_node_id)) } pub(super) fn procref( &self, callee: &InvocationTarget, context: &mut AssemblyContext, - span_builder: &mut SpanBuilder, + span_builder: &mut BasicBlockBuilder, + mast_forest: &MastForest, ) -> Result<(), AssemblyError> { let span = callee.span(); - let digest = self.resolve_target(InvokeKind::Exec, callee, context)?; - self.procref_mast_root(span, digest, context, span_builder) + let digest = self.resolve_target(InvokeKind::Exec, callee, context, mast_forest)?; + self.procref_mast_root(span, digest, context, span_builder, mast_forest) } fn procref_mast_root( @@ -116,13 +153,14 @@ impl Assembler { span: SourceSpan, mast_root: RpoDigest, context: &mut AssemblyContext, - span_builder: &mut SpanBuilder, + span_builder: &mut BasicBlockBuilder, + mast_forest: &MastForest, ) -> Result<(), AssemblyError> { // Add the root to the callset to be able to use dynamic instructions // with the referenced procedure later let cache = &self.procedure_cache; match cache.get_by_mast_root(&mast_root) { - Some(proc) => context.register_external_call(&proc, false)?, + Some(proc) => context.register_external_call(&proc, false, mast_forest)?, None => context.register_phantom_call(Span::new(span, mast_root))?, } diff --git a/assembly/src/assembler/instruction/u32_ops.rs b/assembly/src/assembler/instruction/u32_ops.rs index 487ad73b45..e8c93c1833 100644 --- a/assembly/src/assembler/instruction/u32_ops.rs +++ b/assembly/src/assembler/instruction/u32_ops.rs @@ -1,4 +1,4 @@ -use super::{field_ops::append_pow2_op, push_u32_value, validate_param, SpanBuilder}; +use super::{field_ops::append_pow2_op, push_u32_value, validate_param, BasicBlockBuilder}; use crate::{ diagnostics::{RelatedError, Report}, AssemblyContext, AssemblyError, Span, MAX_U32_ROTATE_VALUE, MAX_U32_SHIFT_VALUE, @@ -23,7 +23,7 @@ pub enum U32OpMode { /// /// Implemented by executing DUP U32SPLIT SWAP DROP EQZ on each element in the word /// and combining the results using AND operation (total of 23 VM cycles) -pub fn u32testw(span_builder: &mut SpanBuilder) { +pub fn u32testw(span_builder: &mut BasicBlockBuilder) { #[rustfmt::skip] let ops = [ // Test the fourth element @@ -45,7 +45,7 @@ pub fn u32testw(span_builder: &mut SpanBuilder) { /// /// Implemented by executing `U32ASSERT2` on each pair of elements in the word. /// Total of 6 VM cycles. -pub fn u32assertw(span_builder: &mut SpanBuilder, err_code: Felt) { +pub fn u32assertw(span_builder: &mut BasicBlockBuilder, err_code: Felt) { #[rustfmt::skip] let ops = [ // Test the first and the second elements @@ -76,7 +76,7 @@ pub fn u32assertw(span_builder: &mut SpanBuilder, err_code: Felt) { /// - u32wrapping_add.b: 3 cycles /// - u32overflowing_add: 1 cycles /// - u32overflowing_add.b: 2 cycles -pub fn u32add(span_builder: &mut SpanBuilder, op_mode: U32OpMode, imm: Option) { +pub fn u32add(span_builder: &mut BasicBlockBuilder, op_mode: U32OpMode, imm: Option) { handle_arithmetic_operation(span_builder, U32add, op_mode, imm); } @@ -90,7 +90,7 @@ pub fn u32add(span_builder: &mut SpanBuilder, op_mode: U32OpMode, imm: Option) { +pub fn u32sub(span_builder: &mut BasicBlockBuilder, op_mode: U32OpMode, imm: Option) { handle_arithmetic_operation(span_builder, U32sub, op_mode, imm); } @@ -104,7 +104,7 @@ pub fn u32sub(span_builder: &mut SpanBuilder, op_mode: U32OpMode, imm: Option) { +pub fn u32mul(span_builder: &mut BasicBlockBuilder, op_mode: U32OpMode, imm: Option) { handle_arithmetic_operation(span_builder, U32mul, op_mode, imm); } @@ -116,7 +116,7 @@ pub fn u32mul(span_builder: &mut SpanBuilder, op_mode: U32OpMode, imm: Option>, ) -> Result<(), AssemblyError> { @@ -133,7 +133,7 @@ pub fn u32div( /// - 5 cycles if b is 1 /// - 4 cycles if b is not 1 pub fn u32mod( - span_builder: &mut SpanBuilder, + span_builder: &mut BasicBlockBuilder, ctx: &AssemblyContext, imm: Option>, ) -> Result<(), AssemblyError> { @@ -150,7 +150,7 @@ pub fn u32mod( /// - 3 cycles if b is 1 /// - 2 cycles if b is not 1 pub fn u32divmod( - span_builder: &mut SpanBuilder, + span_builder: &mut BasicBlockBuilder, ctx: &AssemblyContext, imm: Option>, ) -> Result<(), AssemblyError> { @@ -166,7 +166,7 @@ pub fn u32divmod( /// subtracting the element, flips the bits of the original value to perform a bitwise NOT. /// /// This takes 5 VM cycles. -pub fn u32not(span_builder: &mut SpanBuilder) { +pub fn u32not(span_builder: &mut BasicBlockBuilder) { #[rustfmt::skip] let ops = [ // Perform the operation @@ -189,7 +189,7 @@ pub fn u32not(span_builder: &mut SpanBuilder) { /// VM cycles per mode: /// - u32shl: 18 cycles /// - u32shl.b: 3 cycles -pub fn u32shl(span_builder: &mut SpanBuilder, imm: Option) -> Result<(), AssemblyError> { +pub fn u32shl(span_builder: &mut BasicBlockBuilder, imm: Option) -> Result<(), AssemblyError> { prepare_bitwise::(span_builder, imm)?; if imm != Some(0) { span_builder.push_ops([U32mul, Drop]); @@ -205,7 +205,7 @@ pub fn u32shl(span_builder: &mut SpanBuilder, imm: Option) -> Result<(), Ass /// VM cycles per mode: /// - u32shr: 18 cycles /// - u32shr.b: 3 cycles -pub fn u32shr(span_builder: &mut SpanBuilder, imm: Option) -> Result<(), AssemblyError> { +pub fn u32shr(span_builder: &mut BasicBlockBuilder, imm: Option) -> Result<(), AssemblyError> { prepare_bitwise::(span_builder, imm)?; if imm != Some(0) { span_builder.push_ops([U32div, Drop]); @@ -221,7 +221,7 @@ pub fn u32shr(span_builder: &mut SpanBuilder, imm: Option) -> Result<(), Ass /// VM cycles per mode: /// - u32rotl: 18 cycles /// - u32rotl.b: 3 cycles -pub fn u32rotl(span_builder: &mut SpanBuilder, imm: Option) -> Result<(), AssemblyError> { +pub fn u32rotl(span_builder: &mut BasicBlockBuilder, imm: Option) -> Result<(), AssemblyError> { prepare_bitwise::(span_builder, imm)?; if imm != Some(0) { span_builder.push_ops([U32mul, Add]); @@ -237,7 +237,7 @@ pub fn u32rotl(span_builder: &mut SpanBuilder, imm: Option) -> Result<(), As /// VM cycles per mode: /// - u32rotr: 22 cycles /// - u32rotr.b: 3 cycles -pub fn u32rotr(span_builder: &mut SpanBuilder, imm: Option) -> Result<(), AssemblyError> { +pub fn u32rotr(span_builder: &mut BasicBlockBuilder, imm: Option) -> Result<(), AssemblyError> { match imm { Some(0) => { // if rotation is performed by 0, do nothing (Noop) @@ -260,7 +260,7 @@ pub fn u32rotr(span_builder: &mut SpanBuilder, imm: Option) -> Result<(), As /// Translates u32popcnt assembly instructions to VM operations. /// /// This operation takes 33 cycles. -pub fn u32popcnt(span_builder: &mut SpanBuilder) { +pub fn u32popcnt(span_builder: &mut BasicBlockBuilder) { #[rustfmt::skip] let ops = [ // i = i - ((i >> 1) & 0x55555555); @@ -297,7 +297,7 @@ pub fn u32popcnt(span_builder: &mut SpanBuilder) { /// provider). /// /// This operation takes 37 VM cycles. -pub fn u32clz(span: &mut SpanBuilder) { +pub fn u32clz(span: &mut BasicBlockBuilder) { span.push_advice_injector(AdviceInjector::U32Clz); span.push_op(AdvPop); // [clz, n, ...] @@ -309,7 +309,7 @@ pub fn u32clz(span: &mut SpanBuilder) { /// provider). /// /// This operation takes 34 VM cycles. -pub fn u32ctz(span: &mut SpanBuilder) { +pub fn u32ctz(span: &mut BasicBlockBuilder) { span.push_advice_injector(AdviceInjector::U32Ctz); span.push_op(AdvPop); // [ctz, n, ...] @@ -321,7 +321,7 @@ pub fn u32ctz(span: &mut SpanBuilder) { /// provider). /// /// This operation takes 36 VM cycles. -pub fn u32clo(span: &mut SpanBuilder) { +pub fn u32clo(span: &mut BasicBlockBuilder) { span.push_advice_injector(AdviceInjector::U32Clo); span.push_op(AdvPop); // [clo, n, ...] @@ -333,7 +333,7 @@ pub fn u32clo(span: &mut SpanBuilder) { /// provider). /// /// This operation takes 33 VM cycles. -pub fn u32cto(span: &mut SpanBuilder) { +pub fn u32cto(span: &mut BasicBlockBuilder) { span.push_advice_injector(AdviceInjector::U32Cto); span.push_op(AdvPop); // [cto, n, ...] @@ -346,7 +346,7 @@ pub fn u32cto(span: &mut SpanBuilder) { /// - Overflowing: does not check if the inputs are u32 values; overflow or underflow bits are /// pushed onto the stack. fn handle_arithmetic_operation( - span_builder: &mut SpanBuilder, + span_builder: &mut BasicBlockBuilder, op: Operation, op_mode: U32OpMode, imm: Option, @@ -366,7 +366,7 @@ fn handle_arithmetic_operation( /// Handles common parts of u32div, u32mod, and u32divmod operations, including handling of /// immediate parameters. fn handle_division( - span_builder: &mut SpanBuilder, + span_builder: &mut BasicBlockBuilder, ctx: &AssemblyContext, imm: Option>, ) -> Result<(), AssemblyError> { @@ -394,7 +394,7 @@ fn handle_division( /// Mutate the first two elements of the stack from `[b, a, ..]` into `[2^b, a, ..]`, with `b` /// either as a provided immediate value, or as an element that already exists in the stack. fn prepare_bitwise( - span_builder: &mut SpanBuilder, + span_builder: &mut BasicBlockBuilder, imm: Option, ) -> Result<(), AssemblyError> { match imm { @@ -449,7 +449,7 @@ fn prepare_bitwise( /// `[clz, n, ... ] -> [clz, ... ]` /// /// VM cycles: 36 -fn calculate_clz(span: &mut SpanBuilder) { +fn calculate_clz(span: &mut BasicBlockBuilder) { // [clz, n, ...] #[rustfmt::skip] let ops_group_1 = [ @@ -524,7 +524,7 @@ fn calculate_clz(span: &mut SpanBuilder) { /// `[clo, n, ... ] -> [clo, ... ]` /// /// VM cycles: 35 -fn calculate_clo(span: &mut SpanBuilder) { +fn calculate_clo(span: &mut BasicBlockBuilder) { // [clo, n, ...] #[rustfmt::skip] let ops_group_1 = [ @@ -599,7 +599,7 @@ fn calculate_clo(span: &mut SpanBuilder) { /// `[ctz, n, ... ] -> [ctz, ... ]` /// /// VM cycles: 33 -fn calculate_ctz(span: &mut SpanBuilder) { +fn calculate_ctz(span: &mut BasicBlockBuilder) { // [ctz, n, ...] #[rustfmt::skip] let ops_group_1 = [ @@ -673,7 +673,7 @@ fn calculate_ctz(span: &mut SpanBuilder) { /// `[cto, n, ... ] -> [cto, ... ]` /// /// VM cycles: 32 -fn calculate_cto(span: &mut SpanBuilder) { +fn calculate_cto(span: &mut BasicBlockBuilder) { // [cto, n, ...] #[rustfmt::skip] let ops_group_1 = [ @@ -717,14 +717,14 @@ fn calculate_cto(span: &mut SpanBuilder) { /// Translates u32lt assembly instructions to VM operations. /// /// This operation takes 3 cycles. -pub fn u32lt(span_builder: &mut SpanBuilder) { +pub fn u32lt(span_builder: &mut BasicBlockBuilder) { compute_lt(span_builder); } /// Translates u32lte assembly instructions to VM operations. /// /// This operation takes 5 cycles. -pub fn u32lte(span_builder: &mut SpanBuilder) { +pub fn u32lte(span_builder: &mut BasicBlockBuilder) { // Compute the lt with reversed number to get a gt check span_builder.push_op(Swap); compute_lt(span_builder); @@ -736,7 +736,7 @@ pub fn u32lte(span_builder: &mut SpanBuilder) { /// Translates u32gt assembly instructions to VM operations. /// /// This operation takes 4 cycles. -pub fn u32gt(span_builder: &mut SpanBuilder) { +pub fn u32gt(span_builder: &mut BasicBlockBuilder) { // Reverse the numbers so we can get a gt check. span_builder.push_op(Swap); @@ -746,7 +746,7 @@ pub fn u32gt(span_builder: &mut SpanBuilder) { /// Translates u32gte assembly instructions to VM operations. /// /// This operation takes 4 cycles. -pub fn u32gte(span_builder: &mut SpanBuilder) { +pub fn u32gte(span_builder: &mut BasicBlockBuilder) { compute_lt(span_builder); // Flip the final results to get the gte results. @@ -760,7 +760,7 @@ pub fn u32gte(span_builder: &mut SpanBuilder) { /// Then we finally drop the top element to keep the min. /// /// This operation takes 8 cycles. -pub fn u32min(span_builder: &mut SpanBuilder) { +pub fn u32min(span_builder: &mut BasicBlockBuilder) { compute_max_and_min(span_builder); // Drop the max and keep the min @@ -774,7 +774,7 @@ pub fn u32min(span_builder: &mut SpanBuilder) { /// Then we finally drop the 2nd element to keep the max. /// /// This operation takes 9 cycles. -pub fn u32max(span_builder: &mut SpanBuilder) { +pub fn u32max(span_builder: &mut BasicBlockBuilder) { compute_max_and_min(span_builder); // Drop the min and keep the max @@ -786,7 +786,7 @@ pub fn u32max(span_builder: &mut SpanBuilder) { /// Inserts the VM operations to check if the second element is less than /// the top element. This takes 3 cycles. -fn compute_lt(span_builder: &mut SpanBuilder) { +fn compute_lt(span_builder: &mut BasicBlockBuilder) { span_builder.push_ops([ U32sub, Swap, Drop, // Perform the operations ]) @@ -795,7 +795,7 @@ fn compute_lt(span_builder: &mut SpanBuilder) { /// Duplicate the top two elements in the stack and determine the min and max between them. /// /// The maximum number will be at the top of the stack and minimum will be at the 2nd index. -fn compute_max_and_min(span_builder: &mut SpanBuilder) { +fn compute_max_and_min(span_builder: &mut BasicBlockBuilder) { // Copy top two elements of the stack. span_builder.push_ops([Dup1, Dup1]); diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 09d21b2f76..4c7bf1dace 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -9,16 +9,18 @@ use crate::{ RpoDigest, Spanned, ONE, ZERO, }; use alloc::{boxed::Box, sync::Arc, vec::Vec}; +use miette::miette; use vm_core::{ - code_blocks::CodeBlock, CodeBlockTable, Decorator, DecoratorList, Kernel, Operation, Program, + mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, + Decorator, DecoratorList, Kernel, Operation, Program, }; +mod basic_block_builder; mod context; mod id; mod instruction; mod module_graph; mod procedure; -mod span_builder; #[cfg(test)] mod tests; @@ -27,9 +29,9 @@ pub use self::id::{GlobalProcedureIndex, ModuleIndex}; pub(crate) use self::module_graph::ProcedureCache; pub use self::procedure::Procedure; +use self::basic_block_builder::BasicBlockBuilder; use self::context::ProcedureContext; use self::module_graph::{CallerInfo, ModuleGraph, ResolvedTarget}; -use self::span_builder::SpanBuilder; // ARTIFACT KIND // ================================================================================================ @@ -83,7 +85,9 @@ pub enum ArtifactKind { /// procedures, build the assembler with them first, using the various builder methods on /// [Assembler], e.g. [Assembler::with_module], [Assembler::with_library], etc. Then, call /// [Assembler::compile] or [Assembler::compile_ast] to get your compiled program. +#[derive(Clone)] pub struct Assembler { + mast_forest: MastForest, /// The global [ModuleGraph] for this assembler. All new [AssemblyContext]s inherit this graph /// as a baseline. module_graph: Box, @@ -100,6 +104,7 @@ pub struct Assembler { impl Default for Assembler { fn default() -> Self { Self { + mast_forest: Default::default(), module_graph: Default::default(), procedure_cache: Default::default(), warnings_as_errors: false, @@ -116,10 +121,14 @@ impl Assembler { Self::default() } - /// Start building an [Assembler] with the given [Kernel]. - pub fn with_kernel(kernel: Kernel) -> Self { + /// Start building an [`Assembler`] with the given [`Kernel`] and the [`MastForest`] that was + /// used to compile the kernel. + pub fn with_kernel(kernel: Kernel, mast_forest: MastForest) -> Self { let mut assembler = Self::new(); + assembler.module_graph.set_kernel(None, kernel); + assembler.mast_forest = mast_forest; + assembler } @@ -131,8 +140,14 @@ impl Assembler { let mut assembler = Self::new(); let opts = CompileOptions::for_kernel(); let module = module.compile_with_options(opts)?; - let (kernel_index, kernel) = assembler.assemble_kernel_module(module)?; + + let mut mast_forest = MastForest::new(); + + let (kernel_index, kernel) = assembler.assemble_kernel_module(module, &mut mast_forest)?; assembler.module_graph.set_kernel(Some(kernel_index), kernel); + mast_forest.set_kernel(assembler.module_graph.kernel().clone()); + + assembler.mast_forest = mast_forest; Ok(assembler) } @@ -289,12 +304,6 @@ impl Assembler { self.allow_phantom_calls } - #[cfg(any(test, feature = "testing"))] - #[doc(hidden)] - pub fn procedure_cache(&self) -> &ProcedureCache { - &self.procedure_cache - } - #[cfg(any(test, feature = "testing"))] #[doc(hidden)] pub fn module_graph(&self) -> &ModuleGraph { @@ -304,25 +313,37 @@ impl Assembler { /// Compilation/Assembly impl Assembler { - /// Compiles the provided module into a [Program]. The resulting program can be executed - /// on Miden VM. + /// Compiles the provided module into a [`MastForest`]. /// /// # Errors /// /// Returns an error if parsing or compilation of the specified program fails. - pub fn assemble(&mut self, source: impl Compile) -> Result { + pub fn assemble(self, source: impl Compile) -> Result { let mut context = AssemblyContext::default(); context.set_warnings_as_errors(self.warnings_as_errors); self.assemble_in_context(source, &mut context) } + /// Compiles the provided module into a [`Program`]. The resulting program can be executed on + /// Miden VM. + /// + /// # Errors + /// + /// Returns an error if parsing or compilation of the specified program fails, or if the source + /// doesn't have an entrypoint. + pub fn assemble_program(self, source: impl Compile) -> Result { + let mast_forest = self.assemble(source)?; + + mast_forest.try_into().map_err(|program_err| miette!("{program_err}")) + } + /// Like [Assembler::compile], but also takes an [AssemblyContext] to configure the assembler. pub fn assemble_in_context( - &mut self, + self, source: impl Compile, context: &mut AssemblyContext, - ) -> Result { + ) -> Result { let opts = CompileOptions { warnings_as_errors: context.warnings_as_errors(), ..CompileOptions::default() @@ -339,10 +360,10 @@ impl Assembler { /// Returns an error if parsing or compilation of the specified program fails, or the options /// are invalid. pub fn assemble_with_options( - &mut self, + self, source: impl Compile, options: CompileOptions, - ) -> Result { + ) -> Result { let mut context = AssemblyContext::default(); context.set_warnings_as_errors(options.warnings_as_errors); @@ -353,17 +374,32 @@ impl Assembler { /// to configure the assembler. #[instrument("assemble_with_opts_in_context", skip_all)] pub fn assemble_with_options_in_context( - &mut self, + self, + source: impl Compile, + options: CompileOptions, + context: &mut AssemblyContext, + ) -> Result { + self.assemble_with_options_in_context_impl(source, options, context) + } + + /// Implementation of [`Self::assemble_with_options_in_context`] which doesn't consume `self`. + /// + /// The main purpose of this separation is to enable some tests to access the assembler state + /// after assembly. + fn assemble_with_options_in_context_impl( + mut self, source: impl Compile, options: CompileOptions, context: &mut AssemblyContext, - ) -> Result { + ) -> Result { if options.kind != ModuleKind::Executable { return Err(Report::msg( "invalid compile options: assemble_with_opts_in_context requires that the kind be 'executable'", )); } + let mut mast_forest = core::mem::take(&mut self.mast_forest); + let program = source.compile_with_options(CompileOptions { // Override the module name so that we always compile the executable // module as #exec @@ -392,7 +428,9 @@ impl Assembler { }) .ok_or(SemanticAnalysisError::MissingEntrypoint)?; - self.compile_program(entrypoint, context) + self.compile_program(entrypoint, context, &mut mast_forest)?; + + Ok(mast_forest) } /// Compile and assembles all procedures in the specified module, adding them to the procedure @@ -438,9 +476,16 @@ impl Assembler { // Recompute graph with the provided module, and start assembly let module_id = self.module_graph.add_module(module)?; self.module_graph.recompute()?; - self.assemble_graph(context)?; - self.get_module_exports(module_id) + let mut mast_forest = core::mem::take(&mut self.mast_forest); + + self.assemble_graph(context, &mut mast_forest)?; + let exported_procedure_digests = self.get_module_exports(module_id, &mast_forest); + + // Reassign the mast_forest to the assembler for use is a future program assembly + self.mast_forest = mast_forest; + + exported_procedure_digests } /// Compiles the given kernel module, returning both the compiled kernel and its index in the @@ -448,6 +493,7 @@ impl Assembler { fn assemble_kernel_module( &mut self, module: Box, + mast_forest: &mut MastForest, ) -> Result<(ModuleIndex, Kernel), Report> { if !module.is_kernel() { return Err(Report::msg(format!("expected kernel module, got {}", module.kind()))); @@ -469,8 +515,8 @@ impl Assembler { module: kernel_index, index: ProcedureIndex::new(index), }; - let compiled = self.compile_subgraph(gid, false, &mut context)?; - kernel.push(compiled.code().hash()); + let compiled = self.compile_subgraph(gid, false, &mut context, mast_forest)?; + kernel.push(compiled.mast_root(mast_forest)); } Kernel::new(&kernel) @@ -481,7 +527,11 @@ impl Assembler { /// Get the set of procedure roots for all exports of the given module /// /// Returns an error if the provided Miden Assembly is invalid. - fn get_module_exports(&mut self, module: ModuleIndex) -> Result, Report> { + fn get_module_exports( + &mut self, + module: ModuleIndex, + mast_forest: &MastForest, + ) -> Result, Report> { assert!(self.module_graph.contains_module(module), "invalid module index"); let mut exports = Vec::new(); @@ -519,7 +569,8 @@ impl Assembler { } }); - exports.push(proc.code().hash()); + let proc_code_node = &mast_forest[proc.body_node_id()]; + exports.push(proc_code_node.digest()); } Ok(exports) @@ -527,43 +578,39 @@ impl Assembler { /// Compile the provided [Module] into a [Program]. /// + /// Ensures that the [`MastForest`] entrypoint is set to the entrypoint of the program. + /// /// Returns an error if the provided Miden Assembly is invalid. fn compile_program( &mut self, entrypoint: GlobalProcedureIndex, context: &mut AssemblyContext, - ) -> Result { + mast_forest: &mut MastForest, + ) -> Result<(), Report> { // Raise an error if we are called with an invalid entrypoint assert!(self.module_graph[entrypoint].name().is_main()); // Compile the module graph rooted at the entrypoint - let entry = self.compile_subgraph(entrypoint, true, context)?; - - // Construct the code block table by taking the call set of the - // executable entrypoint and adding the code blocks of all those - // procedures to the table. - let mut code_blocks = CodeBlockTable::default(); - for callee in entry.callset().iter() { - let code_block = self - .procedure_cache - .get_by_mast_root(callee) - .map(|p| p.code().clone()) - .ok_or(AssemblyError::UndefinedCallSetProcedure { digest: *callee })?; - code_blocks.insert(code_block); - } + let entry_procedure = self.compile_subgraph(entrypoint, true, context, mast_forest)?; - let body = entry.code().clone(); - Ok(Program::with_kernel(body, self.module_graph.kernel().clone(), code_blocks)) + mast_forest.set_entrypoint(entry_procedure.body_node_id()); + + Ok(()) } /// Compile all of the uncompiled procedures in the module graph, placing them /// in the procedure cache once compiled. /// /// Returns an error if any of the provided Miden Assembly is invalid. - fn assemble_graph(&mut self, context: &mut AssemblyContext) -> Result<(), Report> { + fn assemble_graph( + &mut self, + context: &mut AssemblyContext, + mast_forest: &mut MastForest, + ) -> Result<(), Report> { let mut worklist = self.module_graph.topological_sort().to_vec(); assert!(!worklist.is_empty()); - self.process_graph_worklist(&mut worklist, context, None).map(|_| ()) + self.process_graph_worklist(&mut worklist, context, None, mast_forest) + .map(|_| ()) } /// Compile the uncompiled procedure in the module graph which are members of the subgraph @@ -575,6 +622,7 @@ impl Assembler { root: GlobalProcedureIndex, is_entrypoint: bool, context: &mut AssemblyContext, + mast_forest: &mut MastForest, ) -> Result, Report> { let mut worklist = self.module_graph.topological_sort_from_root(root).map_err(|cycle| { let iter = cycle.into_node_ids(); @@ -590,9 +638,9 @@ impl Assembler { assert!(!worklist.is_empty()); let compiled = if is_entrypoint { - self.process_graph_worklist(&mut worklist, context, Some(root))? + self.process_graph_worklist(&mut worklist, context, Some(root), mast_forest)? } else { - let _ = self.process_graph_worklist(&mut worklist, context, None)?; + let _ = self.process_graph_worklist(&mut worklist, context, None, mast_forest)?; self.procedure_cache.get(root) }; @@ -604,6 +652,7 @@ impl Assembler { worklist: &mut Vec, context: &mut AssemblyContext, entrypoint: Option, + mast_forest: &mut MastForest, ) -> Result>, Report> { // Process the topological ordering in reverse order (bottom-up), so that // each procedure is compiled with all of its dependencies fully compiled @@ -611,7 +660,8 @@ impl Assembler { while let Some(procedure_gid) = worklist.pop() { // If we have already compiled this procedure, do not recompile if let Some(proc) = self.procedure_cache.get(procedure_gid) { - self.module_graph.register_mast_root(procedure_gid, proc.mast_root())?; + self.module_graph + .register_mast_root(procedure_gid, proc.mast_root(mast_forest))?; continue; } let is_entry = entrypoint == Some(procedure_gid); @@ -631,17 +681,17 @@ impl Assembler { .with_source_file(ast.source_file()); // Compile this procedure - let procedure = self.compile_procedure(pctx, context)?; + let procedure = self.compile_procedure(pctx, context, mast_forest)?; // Cache the compiled procedure, unless it's the program entrypoint if is_entry { compiled_entrypoint = Some(Arc::from(procedure)); } else { // Make the MAST root available to all dependents - let digest = procedure.mast_root(); + let digest = procedure.mast_root(mast_forest); self.module_graph.register_mast_root(procedure_gid, digest)?; - self.procedure_cache.insert(procedure_gid, Arc::from(procedure))?; + self.procedure_cache.insert(procedure_gid, Arc::from(procedure), mast_forest)?; } } @@ -653,6 +703,7 @@ impl Assembler { &self, procedure: ProcedureContext, context: &mut AssemblyContext, + mast_forest: &mut MastForest, ) -> Result, Report> { // Make sure the current procedure context is available during codegen let gid = procedure.id(); @@ -670,9 +721,9 @@ impl Assembler { prologue: vec![Operation::Push(num_locals), Operation::FmpUpdate], epilogue: vec![Operation::Push(-num_locals), Operation::FmpUpdate], }; - self.compile_body(proc.iter(), context, Some(wrapper))? + self.compile_body(proc.iter(), context, Some(wrapper), mast_forest)? } else { - self.compile_body(proc.iter(), context, None)? + self.compile_body(proc.iter(), context, None, mast_forest)? }; let pctx = context.take_current_procedure().unwrap(); @@ -684,69 +735,103 @@ impl Assembler { body: I, context: &mut AssemblyContext, wrapper: Option, - ) -> Result + mast_forest: &mut MastForest, + ) -> Result where I: Iterator, { use ast::Op; - let mut blocks: Vec = Vec::new(); - let mut span = SpanBuilder::new(wrapper); + let mut mast_node_ids: Vec = Vec::new(); + let mut basic_block_builder = BasicBlockBuilder::new(wrapper); for op in body { match op { Op::Inst(inst) => { - if let Some(block) = self.compile_instruction(inst, &mut span, context)? { - span.extract_span_into(&mut blocks); - blocks.push(block); + if let Some(mast_node_id) = self.compile_instruction( + inst, + &mut basic_block_builder, + context, + mast_forest, + )? { + if let Some(basic_block_id) = + basic_block_builder.make_basic_block(mast_forest) + { + mast_node_ids.push(basic_block_id); + } + + mast_node_ids.push(mast_node_id); } } Op::If { then_blk, else_blk, .. } => { - span.extract_span_into(&mut blocks); + if let Some(basic_block_id) = basic_block_builder.make_basic_block(mast_forest) + { + mast_node_ids.push(basic_block_id); + } - let then_blk = self.compile_body(then_blk.iter(), context, None)?; + let then_blk = + self.compile_body(then_blk.iter(), context, None, mast_forest)?; // else is an exception because it is optional; hence, will have to be replaced // by noop span let else_blk = if else_blk.is_empty() { - CodeBlock::new_span(vec![Operation::Noop]) + let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]); + mast_forest.ensure_node(basic_block_node) } else { - self.compile_body(else_blk.iter(), context, None)? + self.compile_body(else_blk.iter(), context, None, mast_forest)? }; - let block = CodeBlock::new_split(then_blk, else_blk); + let split_node_id = { + let split_node = MastNode::new_split(then_blk, else_blk, mast_forest); - blocks.push(block); + mast_forest.ensure_node(split_node) + }; + mast_node_ids.push(split_node_id); } Op::Repeat { count, body, .. } => { - span.extract_span_into(&mut blocks); + if let Some(basic_block_id) = basic_block_builder.make_basic_block(mast_forest) + { + mast_node_ids.push(basic_block_id); + } - let block = self.compile_body(body.iter(), context, None)?; + let repeat_node_id = + self.compile_body(body.iter(), context, None, mast_forest)?; for _ in 0..*count { - blocks.push(block.clone()); + mast_node_ids.push(repeat_node_id); } } Op::While { body, .. } => { - span.extract_span_into(&mut blocks); + if let Some(basic_block_id) = basic_block_builder.make_basic_block(mast_forest) + { + mast_node_ids.push(basic_block_id); + } - let block = self.compile_body(body.iter(), context, None)?; - let block = CodeBlock::new_loop(block); + let loop_body_node_id = + self.compile_body(body.iter(), context, None, mast_forest)?; - blocks.push(block); + let loop_node_id = { + let loop_node = MastNode::new_loop(loop_body_node_id, mast_forest); + mast_forest.ensure_node(loop_node) + }; + mast_node_ids.push(loop_node_id); } } } - span.extract_final_span_into(&mut blocks); - Ok(if blocks.is_empty() { - CodeBlock::new_span(vec![Operation::Noop]) + if let Some(basic_block_id) = basic_block_builder.into_basic_block(mast_forest) { + mast_node_ids.push(basic_block_id); + } + + Ok(if mast_node_ids.is_empty() { + let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]); + mast_forest.ensure_node(basic_block_node) } else { - combine_blocks(blocks) + combine_mast_node_ids(mast_node_ids, mast_forest) }) } @@ -755,6 +840,7 @@ impl Assembler { kind: InvokeKind, target: &InvocationTarget, context: &AssemblyContext, + mast_forest: &MastForest, ) -> Result { let current_proc = context.unwrap_current_procedure(); let caller = CallerInfo { @@ -769,7 +855,7 @@ impl Assembler { ResolvedTarget::Exact { gid } | ResolvedTarget::Resolved { gid, .. } => Ok(self .procedure_cache .get(gid) - .map(|p| p.mast_root()) + .map(|p| p.mast_root(mast_forest)) .expect("expected callee to have been compiled already")), } } @@ -782,71 +868,36 @@ struct BodyWrapper { epilogue: Vec, } -fn combine_blocks(mut blocks: Vec) -> CodeBlock { - debug_assert!(!blocks.is_empty(), "cannot combine empty block list"); - // merge consecutive Span blocks. - let mut merged_blocks: Vec = Vec::with_capacity(blocks.len()); - // Keep track of all the consecutive Span blocks and are merged together when - // there is a discontinuity. - let mut contiguous_spans: Vec = Vec::new(); - - blocks.drain(0..).for_each(|block| { - if block.is_span() { - contiguous_spans.push(block); - } else { - if !contiguous_spans.is_empty() { - merged_blocks.push(combine_spans(&mut contiguous_spans)); - } - merged_blocks.push(block); - } - }); - if !contiguous_spans.is_empty() { - merged_blocks.push(combine_spans(&mut contiguous_spans)); - } +fn combine_mast_node_ids( + mut mast_node_ids: Vec, + mast_forest: &mut MastForest, +) -> MastNodeId { + debug_assert!(!mast_node_ids.is_empty(), "cannot combine empty MAST node id list"); // build a binary tree of blocks joining them using JOIN blocks - let mut blocks = merged_blocks; - while blocks.len() > 1 { - let last_block = if blocks.len() % 2 == 0 { None } else { blocks.pop() }; + while mast_node_ids.len() > 1 { + let last_mast_node_id = if mast_node_ids.len() % 2 == 0 { + None + } else { + mast_node_ids.pop() + }; - let mut source_blocks = Vec::new(); - core::mem::swap(&mut blocks, &mut source_blocks); + let mut source_mast_node_ids = Vec::new(); + core::mem::swap(&mut mast_node_ids, &mut source_mast_node_ids); + + let mut source_mast_node_iter = source_mast_node_ids.drain(0..); + while let (Some(left), Some(right)) = + (source_mast_node_iter.next(), source_mast_node_iter.next()) + { + let join_mast_node = MastNode::new_join(left, right, mast_forest); + let join_mast_node_id = mast_forest.ensure_node(join_mast_node); - let mut source_block_iter = source_blocks.drain(0..); - while let (Some(left), Some(right)) = (source_block_iter.next(), source_block_iter.next()) { - blocks.push(CodeBlock::new_join([left, right])); + mast_node_ids.push(join_mast_node_id); } - if let Some(block) = last_block { - blocks.push(block); + if let Some(mast_node_id) = last_mast_node_id { + mast_node_ids.push(mast_node_id); } } - debug_assert!(!blocks.is_empty(), "no blocks"); - blocks.remove(0) -} - -/// Combines a vector of SPAN blocks into a single SPAN block. -/// -/// # Panics -/// Panics if any of the provided blocks is not a SPAN block. -fn combine_spans(spans: &mut Vec) -> CodeBlock { - if spans.len() == 1 { - return spans.remove(0); - } - - let mut ops = Vec::::new(); - let mut decorators = DecoratorList::new(); - spans.drain(0..).for_each(|block| { - if let CodeBlock::Span(span) = block { - for decorator in span.decorators() { - decorators.push((decorator.0 + ops.len(), decorator.1.clone())); - } - for batch in span.op_batches() { - ops.extend_from_slice(batch.ops()); - } - } else { - panic!("CodeBlock was expected to be a Span Block, got {block:?}."); - } - }); - CodeBlock::new_span_with_decorators(ops, decorators) + mast_node_ids.remove(0) } diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 4b1e48a7aa..6c7456d486 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -18,9 +18,9 @@ use alloc::{ vec::Vec, }; use core::ops::Index; +use vm_core::Kernel; use smallvec::{smallvec, SmallVec}; -use vm_core::Kernel; use self::{ analysis::MaybeRewriteCheck, name_resolver::NameResolver, phantom::PhantomCall, diff --git a/assembly/src/assembler/module_graph/procedure_cache.rs b/assembly/src/assembler/module_graph/procedure_cache.rs index 0980c566ad..cfe72437f6 100644 --- a/assembly/src/assembler/module_graph/procedure_cache.rs +++ b/assembly/src/assembler/module_graph/procedure_cache.rs @@ -4,6 +4,7 @@ use alloc::{ vec::Vec, }; use core::{fmt, ops::Index}; +use vm_core::mast::MastForest; use crate::{ assembler::{GlobalProcedureIndex, ModuleIndex, Procedure}, @@ -33,7 +34,7 @@ use crate::{ /// As a result of this design choice, a unique [ProcedureCache] is associated with each context in /// play during compilation: the global assembler context has its own cache, and each /// [AssemblyContext] has its own cache. -#[derive(Default)] +#[derive(Default, Clone)] pub struct ProcedureCache { cache: Vec>>>, /// This is always the same length as `cache` @@ -142,12 +143,6 @@ impl ProcedureCache { self.by_mast_root.contains_key(hash) } - /// Returns an iterator over the non-empty entries in the cache - #[cfg(test)] - pub fn entries(&self) -> impl Iterator> + '_ { - self.cache.iter().flat_map(|m| m.iter().filter_map(|p| p.clone())) - } - /// Inserts the given [Procedure] into this cache, using the [GlobalProcedureIndex] as the /// cache key. /// @@ -162,8 +157,9 @@ impl ProcedureCache { &mut self, id: GlobalProcedureIndex, procedure: Arc, + mast_forest: &MastForest, ) -> Result<(), AssemblyError> { - let mast_root = procedure.mast_root(); + let mast_root = procedure.mast_root(mast_forest); // Make sure we can index to the cache slot for this procedure self.ensure_cache_slot_exists(id, procedure.path()); @@ -173,7 +169,9 @@ impl ProcedureCache { // If there is already a cache entry, but it conflicts with what we're trying to cache, // then raise an error. if let Some(cached) = self.get(id) { - if cached.mast_root() != mast_root || cached.num_locals() != procedure.num_locals() { + if cached.mast_root(mast_forest) != mast_root + || cached.num_locals() != procedure.num_locals() + { return Err(AssemblyError::ConflictingDefinitions { first: cached.fully_qualified_name().clone(), second: procedure.fully_qualified_name().clone(), @@ -194,7 +192,7 @@ impl ProcedureCache { // cache entry with the same MAST root: if let Some(cached) = self.get_by_mast_root(&mast_root) { // Sanity check - assert_eq!(cached.mast_root(), mast_root); + assert_eq!(cached.mast_root(mast_forest), mast_root); if cached.num_locals() != procedure.num_locals() { return Err(AssemblyError::ConflictingDefinitions { diff --git a/assembly/src/assembler/procedure.rs b/assembly/src/assembler/procedure.rs index 0d0ba911df..88224396da 100644 --- a/assembly/src/assembler/procedure.rs +++ b/assembly/src/assembler/procedure.rs @@ -5,7 +5,7 @@ use crate::{ diagnostics::SourceFile, LibraryPath, RpoDigest, SourceSpan, Spanned, }; -use vm_core::code_blocks::CodeBlock; +use vm_core::mast::{MastForest, MastNodeId, MerkleTreeNode}; pub type CallSet = BTreeSet; @@ -28,8 +28,8 @@ pub struct Procedure { path: FullyQualifiedProcedureName, visibility: Visibility, num_locals: u32, - /// The MAST for this procedure - code: CodeBlock, + /// The MAST node id for the root of this procedure + body_node_id: MastNodeId, /// The set of MAST roots called by this procedure callset: CallSet, } @@ -40,7 +40,7 @@ impl Procedure { path: FullyQualifiedProcedureName, visibility: Visibility, num_locals: u32, - code: CodeBlock, + body_node_id: MastNodeId, ) -> Self { Self { span: SourceSpan::default(), @@ -48,7 +48,7 @@ impl Procedure { path, visibility, num_locals, - code, + body_node_id, callset: Default::default(), } } @@ -103,13 +103,14 @@ impl Procedure { } /// Returns the root of this procedure's MAST. - pub fn mast_root(&self) -> RpoDigest { - self.code.hash() + pub fn mast_root(&self, mast_forest: &MastForest) -> RpoDigest { + let body_node = &mast_forest[self.body_node_id]; + body_node.digest() } - /// Returns a reference to the MAST of this procedure. - pub fn code(&self) -> &CodeBlock { - &self.code + /// Returns a reference to the MAST node ID of this procedure. + pub fn body_node_id(&self) -> MastNodeId { + self.body_node_id } /// Returns a reference to a set of all procedures (identified by their MAST roots) which may diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index 0f8a3dd785..5f99c93aa3 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -1,7 +1,9 @@ use alloc::{boxed::Box, vec::Vec}; +use vm_core::mast::{MastForest, MastNode, MerkleTreeNode}; -use super::{combine_blocks, Assembler, CodeBlock, Library, Operation}; +use super::{Assembler, Library, Operation}; use crate::{ + assembler::combine_mast_node_ids, ast::{Module, ModuleKind}, LibraryNamespace, Version, }; @@ -60,22 +62,26 @@ fn nested_blocks() { } } - let mut assembler = Assembler::with_kernel_from_module(KERNEL) + let assembler = Assembler::with_kernel_from_module(KERNEL) .unwrap() .with_library(&DummyLibrary::default()) .unwrap(); - // the assembler should have a single kernel proc in its cache before the compilation of the - // source - assert_eq!(assembler.procedure_cache().len(), 1); + // The expected `MastForest` for the program (that we will build by hand) + let mut expected_mast_forest = MastForest::new(); // fetch the kernel digest and store into a syscall block - let syscall = assembler - .procedure_cache() - .entries() - .next() - .map(|p| CodeBlock::new_syscall(p.mast_root())) - .unwrap(); + // + // Note: this assumes the current internal implementation detail that `assembler.mast_forest` + // contains the MAST nodes for the kernel after a call to + // `Assembler::with_kernel_from_module()`. + let syscall_foo_node_id = { + let kernel_foo_node = MastNode::new_basic_block(vec![Operation::Add]); + let kernel_foo_node_id = expected_mast_forest.ensure_node(kernel_foo_node); + + let syscall_node = MastNode::new_syscall(kernel_foo_node_id, &expected_mast_forest); + expected_mast_forest.ensure_node(syscall_node) + }; let program = r#" use.foo::bar @@ -115,36 +121,93 @@ fn nested_blocks() { let program = assembler.assemble(program).unwrap(); - let exec_bar = assembler - .procedure_cache() - .get_by_name(&"#exec::bar".parse().unwrap()) - .map(|p| CodeBlock::new_proxy(p.code().hash())) - .unwrap(); - - let exec_foo_bar_baz = assembler - .procedure_cache() - .get_by_name(&"foo::bar::baz".parse().unwrap()) - .map(|p| CodeBlock::new_proxy(p.code().hash())) - .unwrap(); - - let before = CodeBlock::new_span(vec![Operation::Push(2u32.into())]); - - let r#true = CodeBlock::new_span(vec![Operation::Push(3u32.into())]); - let r#false = CodeBlock::new_span(vec![Operation::Push(5u32.into())]); - let r#if = CodeBlock::new_split(r#true, r#false); - - let r#true = CodeBlock::new_span(vec![Operation::Push(7u32.into())]); - let r#false = CodeBlock::new_span(vec![Operation::Push(11u32.into())]); - let r#true = CodeBlock::new_split(r#true, r#false); - - let r#while = - CodeBlock::new_join([exec_bar, CodeBlock::new_span(vec![Operation::Push(23u32.into())])]); - let r#while = CodeBlock::new_loop(r#while); - let span = CodeBlock::new_span(vec![Operation::Push(13u32.into())]); - let r#false = CodeBlock::new_join([span, r#while]); - let nested = CodeBlock::new_split(r#true, r#false); - - let combined = combine_blocks(vec![before, r#if, nested, exec_foo_bar_baz, syscall]); - - assert_eq!(combined.hash(), program.hash()); + let exec_bar_node_id = { + // bar procedure + let basic_block_1 = MastNode::new_basic_block(vec![Operation::Push(17_u32.into())]); + let basic_block_1_id = expected_mast_forest.ensure_node(basic_block_1); + + // Basic block representing the `foo` procedure + let basic_block_2 = MastNode::new_basic_block(vec![Operation::Push(19_u32.into())]); + let basic_block_2_id = expected_mast_forest.ensure_node(basic_block_2); + + let join_node = + MastNode::new_join(basic_block_1_id, basic_block_2_id, &expected_mast_forest); + expected_mast_forest.ensure_node(join_node) + }; + + let exec_foo_bar_baz_node_id = { + // basic block representing foo::bar.baz procedure + let basic_block = MastNode::new_basic_block(vec![Operation::Push(29_u32.into())]); + expected_mast_forest.ensure_node(basic_block) + }; + + let before = { + let before_node = MastNode::new_basic_block(vec![Operation::Push(2u32.into())]); + expected_mast_forest.ensure_node(before_node) + }; + + let r#true1 = { + let r#true_node = MastNode::new_basic_block(vec![Operation::Push(3u32.into())]); + expected_mast_forest.ensure_node(r#true_node) + }; + let r#false1 = { + let r#false_node = MastNode::new_basic_block(vec![Operation::Push(5u32.into())]); + expected_mast_forest.ensure_node(r#false_node) + }; + let r#if1 = { + let r#if_node = MastNode::new_split(r#true1, r#false1, &expected_mast_forest); + expected_mast_forest.ensure_node(r#if_node) + }; + + let r#true3 = { + let r#true_node = MastNode::new_basic_block(vec![Operation::Push(7u32.into())]); + expected_mast_forest.ensure_node(r#true_node) + }; + let r#false3 = { + let r#false_node = MastNode::new_basic_block(vec![Operation::Push(11u32.into())]); + expected_mast_forest.ensure_node(r#false_node) + }; + let r#true2 = { + let r#if_node = MastNode::new_split(r#true3, r#false3, &expected_mast_forest); + expected_mast_forest.ensure_node(r#if_node) + }; + + let r#while = { + let push_basic_block_id = { + let push_basic_block = MastNode::new_basic_block(vec![Operation::Push(23u32.into())]); + expected_mast_forest.ensure_node(push_basic_block) + }; + let body_node_id = { + let body_node = + MastNode::new_join(exec_bar_node_id, push_basic_block_id, &expected_mast_forest); + + expected_mast_forest.ensure_node(body_node) + }; + + let loop_node = MastNode::new_loop(body_node_id, &expected_mast_forest); + expected_mast_forest.ensure_node(loop_node) + }; + let push_13_basic_block_id = { + let node = MastNode::new_basic_block(vec![Operation::Push(13u32.into())]); + expected_mast_forest.ensure_node(node) + }; + + let r#false2 = { + let node = MastNode::new_join(push_13_basic_block_id, r#while, &expected_mast_forest); + expected_mast_forest.ensure_node(node) + }; + let nested = { + let node = MastNode::new_split(r#true2, r#false2, &expected_mast_forest); + expected_mast_forest.ensure_node(node) + }; + + let combined_node_id = combine_mast_node_ids( + vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id], + &mut expected_mast_forest, + ); + expected_mast_forest.set_entrypoint(combined_node_id); + + let combined_node = &expected_mast_forest[combined_node_id]; + + assert_eq!(combined_node.digest(), program.entrypoint_digest().unwrap()); } diff --git a/assembly/src/ast/module.rs b/assembly/src/ast/module.rs index 1a5b6af251..af6544d0af 100644 --- a/assembly/src/ast/module.rs +++ b/assembly/src/ast/module.rs @@ -46,8 +46,8 @@ pub enum ModuleKind { /// A kernel is like a library module, but is special in a few ways: /// /// * Its code always executes in the root context, so it is stateful in a way that normal - /// libraries cannot replicate. This can be used to provide core services that would otherwise - /// not be possible to implement. + /// libraries cannot replicate. This can be used to provide core services that would + /// otherwise not be possible to implement. /// /// * The procedures exported from the kernel may be the target of the `syscall` instruction, /// and in fact _must_ be called that way. diff --git a/assembly/src/ast/tests.rs b/assembly/src/ast/tests.rs index 7a6e987a36..302197a442 100644 --- a/assembly/src/ast/tests.rs +++ b/assembly/src/ast/tests.rs @@ -1053,253 +1053,3 @@ fn assert_parsing_line_unexpected_token() { r#" help: expected "begin", or "const", or "export", or "proc", or "use", or end of file, or doc comment"# ); } - -// SERIALIZATION AND DESERIALIZATION TESTS -// ================================================================================================ - -#[cfg(feature = "nope")] -mod serialization { - - #[test] - fn test_ast_program_serde_simple() { - let source = "begin push.0xabc234 push.0 assertz end"; - assert_correct_program_serialization(source, true); - } - - #[test] - fn test_ast_program_serde_local_procs() { - let source = "\ - proc.foo.1 - loc_load.0 - end - proc.bar.2 - padw - end - begin - exec.foo - exec.bar - end"; - assert_correct_program_serialization(source, true); - } - - #[test] - fn test_ast_program_serde_exported_procs() { - let source = "\ - export.foo.1 - loc_load.0 - end - export.bar.2 - padw - end"; - assert_correct_module_serialization(source, true); - } - - #[test] - fn test_ast_program_serde_control_flow() { - let source = "\ - begin - repeat.3 - push.1 - push.0.1 - end - - if.true - and - loc_store.0 - else - padw - end - - while.true - push.5.7 - u32wrapping_add - loc_store.1 - push.0 - end - - repeat.3 - push.2 - u32overflowing_mul - end - - end"; - assert_correct_program_serialization(source, true); - } - - #[test] - fn test_ast_program_serde_imports_serialized() { - let source = "\ - use.std::math::u64 - use.std::crypto::fri - - begin - push.0 - push.1 - exec.u64::wrapping_add - end"; - assert_correct_program_serialization(source, true); - } - - #[test] - fn test_ast_program_serde_imports_not_serialized() { - let source = "\ - use.std::math::u64 - use.std::crypto::fri - - begin - push.0 - push.1 - exec.u64::wrapping_add - end"; - assert_correct_program_serialization(source, false); - } - - #[test] - fn test_ast_module_serde_imports_serialized() { - let source = "\ - use.std::math::u64 - use.std::crypto::fri - - proc.foo.2 - push.0 - push.1 - exec.u64::wrapping_add - end"; - assert_correct_module_serialization(source, true); - } - - #[test] - fn test_ast_module_serde_imports_not_serialized() { - let source = "\ - use.std::math::u64 - use.std::crypto::fri - - proc.foo.2 - push.0 - push.1 - exec.u64::wrapping_add - end"; - assert_correct_module_serialization(source, false); - } - - #[test] - fn test_repeat_with_constant_count() { - let source = "\ - const.A=3 - const.B=A*3+5 - - begin - repeat.A - push.1 - end - - repeat.B - push.0 - end - end"; - - assert_correct_program_serialization(source, false); - - let nodes: Vec = vec![ - Node::Repeat { - times: 3, - body: CodeBody::new(vec![Node::Instruction(Instruction::PushU8(1))]), - }, - Node::Repeat { - times: 14, - body: CodeBody::new(vec![Node::Instruction(Instruction::PushU8(0))]), - }, - ]; - - assert_program_output(source, BTreeMap::new(), nodes); - } - - /// Clears the module's imports. - /// - /// Serialization of imports is optional, so if they are not serialized, then they have to be - /// cleared before testing for equality - fn clear_imports_module(module: &mut ModuleAst) { - module.clear_imports(); - } - - /// Clears the program's imports. - /// - /// Serialization of imports is optional, so if they are not serialized, then they have to be - /// cleared before testing for equality - fn clear_imports_program(program: &mut ProgramAst) { - program.clear_imports(); - } - - fn assert_correct_program_serialization(source: &str, serialize_imports: bool) { - let program = ProgramAst::parse(source).unwrap(); - - // assert the correct program serialization - let program_serialized = program.to_bytes(AstSerdeOptions::new(serialize_imports)); - let mut program_deserialized = - ProgramAst::from_bytes(program_serialized.as_slice()).unwrap(); - let mut clear_program = clear_procs_loc_program(program.clone()); - if !serialize_imports { - clear_imports_program(&mut clear_program); - } - assert_eq!(clear_program, program_deserialized); - - // assert the correct locations serialization - let mut locations = Vec::new(); - program.write_source_locations(&mut locations); - - // assert empty locations - { - let mut locations = program_deserialized.source_locations(); - let start = locations.next().unwrap(); - assert_eq!(start, &SourceLocation::default()); - assert!(locations.next().is_none()); - } - - program_deserialized - .load_source_locations(&mut SliceReader::new(&locations)) - .unwrap(); - - let program_deserialized = if !serialize_imports { - program_deserialized.with_import_info(program.import_info().clone()) - } else { - program_deserialized - }; - - assert_eq!(program, program_deserialized); - } - - fn assert_correct_module_serialization(source: &str, serialize_imports: bool) { - let module = ModuleAst::parse(source).unwrap(); - let module_serialized = module.to_bytes(AstSerdeOptions::new(serialize_imports)); - let mut module_deserialized = ModuleAst::from_bytes(module_serialized.as_slice()).unwrap(); - let mut clear_module = clear_procs_loc_module(module.clone()); - if !serialize_imports { - clear_imports_module(&mut clear_module); - } - assert_eq!(clear_module, module_deserialized); - - // assert the correct locations serialization - let mut locations = Vec::new(); - module.write_source_locations(&mut locations); - - // assert module locations are empty - module_deserialized.procs().iter().for_each(|m| { - let mut locations = m.source_locations(); - let start = locations.next().unwrap(); - assert_eq!(start, &SourceLocation::default()); - assert!(locations.next().is_none()); - }); - - module_deserialized - .load_source_locations(&mut SliceReader::new(&locations)) - .unwrap(); - - module_deserialized = if !serialize_imports { - module_deserialized.with_import_info(module.import_info().clone()) - } else { - module_deserialized - }; - - assert_eq!(module, module_deserialized); - } -} diff --git a/assembly/src/testing.rs b/assembly/src/testing.rs index 990f2c2ec5..4bc3e20310 100644 --- a/assembly/src/testing.rs +++ b/assembly/src/testing.rs @@ -1,6 +1,6 @@ use crate::{ - assembler::{Assembler, AssemblyContext, ProcedureCache}, - ast::{Form, FullyQualifiedProcedureName, Module, ModuleKind}, + assembler::{Assembler, AssemblyContext}, + ast::{Form, Module, ModuleKind}, diagnostics::{ reporting::{set_hook, ReportHandlerOpts}, Report, SourceFile, @@ -13,7 +13,7 @@ use crate::diagnostics::reporting::set_panic_hook; use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec}; use core::fmt; -use vm_core::{utils::DisplayHex, Program}; +use vm_core::Program; /// Represents a pattern for matching text abstractly /// for use in asserting contents of complex diagnostics @@ -308,7 +308,10 @@ impl TestContext { /// module represented in `source`. #[track_caller] pub fn assemble(&mut self, source: impl Compile) -> Result { - self.assembler.assemble(source) + self.assembler + .clone() + .assemble(source) + .map(|mast_forest| mast_forest.try_into().unwrap()) } /// Compile a module from `source`, with the fully-qualified name `path`, to MAST, returning @@ -329,32 +332,4 @@ impl TestContext { }; self.assembler.assemble_module(module, options, &mut context) } - - /// Get a reference to the [ProcedureCache] of the [Assembler] constructed by this context. - pub fn procedure_cache(&self) -> &ProcedureCache { - self.assembler.procedure_cache() - } - - /// Display the MAST root associated with `name` in the procedure cache of the [Assembler] - /// constructed by this context. - /// - /// It is expected that the module containing `name` was previously compiled by the assembler, - /// and is thus in the cache. This function will panic if that is not the case. - pub fn display_digest_from_cache( - &self, - name: &FullyQualifiedProcedureName, - ) -> impl fmt::Display { - self.procedure_cache() - .get_by_name(name) - .map(|p| p.code().hash()) - .map(DisplayDigest) - .unwrap_or_else(|| panic!("procedure '{}' is not in the procedure cache", name)) - } -} - -struct DisplayDigest(RpoDigest); -impl fmt::Display for DisplayDigest { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:#x}", DisplayHex(self.0.as_bytes().as_slice())) - } } diff --git a/assembly/src/tests.rs b/assembly/src/tests.rs index 6481a9cddf..309dade35a 100644 --- a/assembly/src/tests.rs +++ b/assembly/src/tests.rs @@ -39,7 +39,7 @@ fn simple_instructions() -> TestResult { let program = context.assemble(source)?; let expected = "\ begin - span pad eqz assert(0) end + basic_block pad eqz assert(0) end end"; assert_str_eq!(format!("{program}"), expected); @@ -47,7 +47,7 @@ end"; let program = context.assemble(source)?; let expected = "\ begin - span push(10) push(50) push(2) u32madd drop end + basic_block push(10) push(50) push(2) u32madd drop end end"; assert_str_eq!(format!("{program}"), expected); @@ -55,7 +55,7 @@ end"; let program = context.assemble(source)?; let expected = "\ begin - span push(10) push(50) push(2) u32add3 drop end + basic_block push(10) push(50) push(2) u32add3 drop end end"; assert_str_eq!(format!("{program}"), expected); Ok(()) @@ -68,7 +68,7 @@ fn empty_program() -> TestResult { let mut context = TestContext::default(); let source = source_file!("begin end"); let program = context.assemble(source)?; - let expected = "begin span noop end end"; + let expected = "begin basic_block noop end end"; assert_eq!(expected, format!("{}", program)); Ok(()) } @@ -83,9 +83,9 @@ fn empty_if() -> TestResult { let expected = "\ begin if.true - span noop end + basic_block noop end else - span noop end + basic_block noop end end end"; assert_str_eq!(format!("{}", program), expected); @@ -102,7 +102,7 @@ fn empty_while() -> TestResult { let expected = "\ begin while.true - span noop end + basic_block noop end end end"; assert_str_eq!(format!("{}", program), expected); @@ -118,27 +118,27 @@ fn empty_repeat() -> TestResult { let program = context.assemble(source)?; let expected = "\ begin - span noop noop noop noop noop end + basic_block noop noop noop noop noop end end"; assert_str_eq!(format!("{}", program), expected); Ok(()) } #[test] -fn single_span() -> TestResult { +fn single_basic_block() -> TestResult { let mut context = TestContext::default(); let source = source_file!("begin push.1 push.2 add end"); let program = context.assemble(source)?; let expected = "\ begin - span pad incr push(2) add end + basic_block pad incr push(2) add end end"; assert_str_eq!(format!("{program}"), expected); Ok(()) } #[test] -fn span_and_simple_if() -> TestResult { +fn basic_block_and_simple_if() -> TestResult { let mut context = TestContext::default(); // if with else @@ -147,11 +147,11 @@ fn span_and_simple_if() -> TestResult { let expected = "\ begin join - span push(2) push(3) end + basic_block push(2) push(3) end if.true - span add end + basic_block add end else - span mul end + basic_block mul end end end end"; @@ -163,11 +163,11 @@ end"; let expected = "\ begin join - span push(2) push(3) end + basic_block push(2) push(3) end if.true - span add end + basic_block add end else - span noop end + basic_block noop end end end end"; @@ -226,8 +226,8 @@ fn simple_main_call() -> TestResult { #[test] fn call_without_path() -> TestResult { let mut context = TestContext::default(); + // compile first module - //context.add_module_from_source( context.assemble_module( "account_code1".parse().unwrap(), source_file!( @@ -246,7 +246,6 @@ fn call_without_path() -> TestResult { //--------------------------------------------------------------------------------------------- // compile second module - //context.add_module_from_source( context.assemble_module( "account_code2".parse().unwrap(), source_file!( @@ -400,7 +399,7 @@ fn simple_constant() -> TestResult { ); let expected = "\ begin - span push(7) end + basic_block push(7) end end"; let program = context.assemble(source)?; assert_str_eq!(format!("{program}"), expected); @@ -419,7 +418,7 @@ fn multiple_constants_push() -> TestResult { ); let expected = "\ begin - span push(21) push(64) push(44) push(72) end + basic_block push(21) push(64) push(44) push(72) end end"; let program = context.assemble(source)?; assert_str_eq!(format!("{program}"), expected); @@ -438,7 +437,7 @@ fn constant_numeric_expression() -> TestResult { ); let expected = "\ begin - span push(26) end + basic_block push(26) end end"; let program = context.assemble(source)?; assert_str_eq!(format!("{program}"), expected); @@ -459,7 +458,7 @@ fn constant_alphanumeric_expression() -> TestResult { ); let expected = "\ begin - span push(21) end + basic_block push(21) end end"; let program = context.assemble(source)?; assert_str_eq!(format!("{program}"), expected); @@ -478,7 +477,7 @@ fn constant_hexadecimal_value() -> TestResult { ); let expected = "\ begin - span push(255) end + basic_block push(255) end end"; let program = context.assemble(source)?; assert_str_eq!(format!("{program}"), expected); @@ -497,7 +496,7 @@ fn constant_field_division() -> TestResult { ); let expected = "\ begin - span push(2) end + basic_block push(2) end end"; let program = context.assemble(source)?; assert_str_eq!(format!("{program}"), expected); @@ -908,7 +907,7 @@ fn assert_with_code() -> TestResult { let expected = "\ begin - span assert(0) assert(1) assert(2) end + basic_block assert(0) assert(1) assert(2) end end"; assert_str_eq!(format!("{program}"), expected); Ok(()) @@ -932,7 +931,7 @@ fn assertz_with_code() -> TestResult { let expected = "\ begin - span eqz assert(0) eqz assert(1) eqz assert(2) end + basic_block eqz assert(0) eqz assert(1) eqz assert(2) end end"; assert_str_eq!(format!("{program}"), expected); Ok(()) @@ -956,7 +955,7 @@ fn assert_eq_with_code() -> TestResult { let expected = "\ begin - span eq assert(0) eq assert(1) eq assert(2) end + basic_block eq assert(0) eq assert(1) eq assert(2) end end"; assert_str_eq!(format!("{program}"), expected); Ok(()) @@ -980,7 +979,7 @@ fn assert_eqw_with_code() -> TestResult { let expected = "\ begin - span + basic_block movup4 eq assert(0) @@ -1038,7 +1037,7 @@ fn u32assert_with_code() -> TestResult { let expected = "\ begin - span + basic_block pad u32assert2(0) drop @@ -1072,7 +1071,7 @@ fn u32assert2_with_code() -> TestResult { let expected = "\ begin - span u32assert2(0) u32assert2(1) u32assert2(2) end + basic_block u32assert2(0) u32assert2(1) u32assert2(2) end end"; assert_str_eq!(format!("{program}"), expected); Ok(()) @@ -1096,7 +1095,7 @@ fn u32assertw_with_code() -> TestResult { let expected = "\ begin - span + basic_block u32assert2(0) movup3 movup3 @@ -1140,7 +1139,7 @@ fn mtree_verify_with_code() -> TestResult { let expected = "\ begin - span mpverify(0) mpverify(1) mpverify(2) end + basic_block mpverify(0) mpverify(1) mpverify(2) end end"; assert_str_eq!(format!("{program}"), expected); Ok(()) @@ -1170,26 +1169,32 @@ fn nested_control_blocks() -> TestResult { begin join join - span push(2) push(3) end + basic_block push(2) push(3) end if.true join - span add end + basic_block add end while.true - span push(7) push(11) add end + basic_block push(7) push(11) add end end end else join - span mul push(8) push(8) end - if.true - span mul end - else - span noop end + join + basic_block mul end + basic_block push(8) end + end + join + basic_block push(8) end + if.true + basic_block mul end + else + basic_block noop end + end end end end end - span push(3) add end + basic_block push(3) add end end end"; assert_str_eq!(format!("{program}"), expected); @@ -1205,38 +1210,17 @@ fn program_with_one_procedure() -> TestResult { let source = source_file!("proc.foo push.3 push.7 mul end begin push.2 push.3 add exec.foo end"); let program = context.assemble(source)?; - let foo = context.display_digest_from_cache(&"#exec::foo".parse().unwrap()); - let expected = format!( - "\ + let expected = "\ begin join - span push(2) push(3) add end - proxy.{foo} + basic_block push(2) push(3) add end + basic_block push(3) push(7) mul end end -end" - ); +end"; assert_str_eq!(format!("{program}"), expected); Ok(()) } -// TODO(pauls): Do we want to support this in the surface MASM syntax? -#[test] -#[ignore] -fn program_with_one_empty_procedure() -> TestResult { - let mut context = TestContext::default(); - let source = source_file!("proc.foo end begin exec.foo end"); - let program = context.assemble(source)?; - let foo = context.display_digest_from_cache(&"#exec::foo".parse().unwrap()); - let expected = format!( - "\ -begin - proxy.{foo} -end" - ); - assert_str_eq!(format!("{}", program), expected); - Ok(()) -} - #[test] fn program_with_nested_procedure() -> TestResult { let mut context = TestContext::default(); @@ -1247,26 +1231,28 @@ fn program_with_nested_procedure() -> TestResult { begin push.2 push.4 add exec.foo push.11 exec.bar sub end" ); let program = context.assemble(source)?; - let foo = context.display_digest_from_cache(&"#exec::foo".parse().unwrap()); - let bar = context.display_digest_from_cache(&"#exec::bar".parse().unwrap()); - let expected = format!( - "\ + let expected = "\ begin join join join - span push(2) push(4) add end - proxy.{foo} + basic_block push(2) push(4) add end + basic_block push(3) push(7) mul end end join - span push(11) end - proxy.{bar} + basic_block push(11) end + join + join + basic_block push(5) end + basic_block push(3) push(7) mul end + end + basic_block add end + end end end - span neg add end + basic_block neg add end end -end" - ); +end"; assert_str_eq!(format!("{program}"), expected); Ok(()) } @@ -1288,16 +1274,27 @@ fn program_with_proc_locals() -> TestResult { end" ); let program = context.assemble(source)?; - let foo = context.display_digest_from_cache(&"#exec::foo".parse().unwrap()); - let expected = format!( - "\ + let expected = "\ begin join - span push(4) push(3) push(2) end - proxy.{foo} + basic_block push(4) push(3) push(2) end + basic_block + push(1) + fmpupdate + pad + fmpadd + mstore + drop + add + pad + fmpadd + mload + mul + push(18446744069414584320) + fmpupdate + end end -end" - ); +end"; assert_str_eq!(format!("{program}"), expected); Ok(()) } @@ -1407,6 +1404,8 @@ fn program_with_invalid_rpo_digest_call() { ); } +/// Phantom calls are currently not implemented. Re-enable this test once they are implemented. +#[ignore] #[test] fn program_with_phantom_mast_call() -> TestResult { let mut context = TestContext::default(); @@ -1416,7 +1415,7 @@ fn program_with_phantom_mast_call() -> TestResult { let ast = context.parse_program(source)?; // phantom calls not allowed - let mut assembler = Assembler::default().with_debug_mode(true); + let assembler = Assembler::default().with_debug_mode(true); let mut context = AssemblyContext::for_program(ast.path()).with_phantom_calls(false); let err = assembler @@ -1434,6 +1433,7 @@ fn program_with_phantom_mast_call() -> TestResult { ); // phantom calls allowed + let assembler = Assembler::default().with_debug_mode(true); let mut context = AssemblyContext::for_program(ast.path()).with_phantom_calls(true); assembler.assemble_in_context(ast, &mut context)?; Ok(()) @@ -1469,25 +1469,42 @@ fn program_with_one_import_and_hex_call() -> TestResult { begin push.4 push.3 exec.u256::iszero_unsafe - call.0xc2545da99d3a1f3f38d957c7893c44d78998d8ea8b11aba7e22c8c2b2a213dae + call.0x20234ee941e53a15886e733cc8e041198c6e90d2a16ea18ce1030e8c3596dd38 end"# )); let program = context.assemble(source)?; - let iszero_unsafe = - context.display_digest_from_cache(&"dummy::math::u256::iszero_unsafe".parse().unwrap()); - let expected = format!( - "\ + let expected = "\ begin join join - span push(4) push(3) end - proxy.{iszero_unsafe} + basic_block push(4) push(3) end + join + join + join + basic_block eqz end + basic_block swap eqz and end + end + join + basic_block swap eqz and end + basic_block swap eqz and end + end + end + join + join + basic_block swap eqz and end + basic_block swap eqz and end + end + join + basic_block swap eqz and end + basic_block swap eqz and end + end + end + end end - call.0xc2545da99d3a1f3f38d957c7893c44d78998d8ea8b11aba7e22c8c2b2a213dae + call.0x20234ee941e53a15886e733cc8e041198c6e90d2a16ea18ce1030e8c3596dd38 end -end" - ); +end"; assert_str_eq!(format!("{program}"), expected); Ok(()) } @@ -1602,22 +1619,16 @@ fn program_with_reexported_proc_in_same_library() -> TestResult { end"# )); let program = context.assemble(source)?; - let checked_eqz = - context.display_digest_from_cache(&"dummy1::math::u64::checked_eqz".parse().unwrap()); - let notchecked_eqz = - context.display_digest_from_cache(&"dummy1::math::u64::unchecked_eqz".parse().unwrap()); - let expected = format!( - "\ + let expected = "\ begin join join - span push(4) push(3) end - proxy.{checked_eqz} + basic_block push(4) push(3) end + basic_block u32assert2(0) eqz swap eqz and end end - proxy.{notchecked_eqz} + basic_block eqz swap eqz and end end -end" - ); +end"; assert_str_eq!(format!("{program}"), expected); Ok(()) } @@ -1674,22 +1685,16 @@ fn program_with_reexported_proc_in_another_library() -> TestResult { )); let program = context.assemble(source)?; - let checked_eqz = - context.display_digest_from_cache(&"dummy2::math::u64::checked_eqz".parse().unwrap()); - let notchecked_eqz = - context.display_digest_from_cache(&"dummy2::math::u64::unchecked_eqz".parse().unwrap()); - let expected = format!( - "\ + let expected = "\ begin join join - span push(4) push(3) end - proxy.{checked_eqz} + basic_block push(4) push(3) end + basic_block u32assert2(0) eqz swap eqz and end end - proxy.{notchecked_eqz} + basic_block eqz swap eqz and end end -end" - ); +end"; assert_str_eq!(format!("{program}"), expected); // when the re-exported proc is part of a different library and the library is not passed to @@ -1745,17 +1750,24 @@ fn module_alias() -> TestResult { ); let program = context.assemble(source)?; - let checked_add = - context.display_digest_from_cache(&"dummy::math::u64::checked_add".parse().unwrap()); - let expected = format!( - "\ + let expected = "\ begin join - span pad incr pad push(2) pad end - proxy.{checked_add} + basic_block pad incr pad push(2) pad end + basic_block + swap + movup3 + u32assert2(0) + u32add + movup3 + movup3 + u32assert2(0) + u32add3 + eqz + assert(0) + end end -end" - ); +end"; assert_str_eq!(format!("{program}"), expected); // --- invalid module alias ----------------------------------------------- @@ -1889,7 +1901,7 @@ fn comment_simple() -> TestResult { let program = context.assemble(source)?; let expected = "\ begin - span pad incr push(2) add end + basic_block pad incr push(2) add end end"; assert_str_eq!(format!("{program}"), expected); Ok(()) @@ -1918,26 +1930,32 @@ fn comment_in_nested_control_blocks() -> TestResult { begin join join - span pad incr push(2) end + basic_block pad incr push(2) end if.true join - span add end + basic_block add end while.true - span push(7) push(11) add end + basic_block push(7) push(11) add end end end else join - span mul push(8) push(8) end - if.true - span mul end - else - span noop end + join + basic_block mul end + basic_block push(8) end + end + join + basic_block push(8) end + if.true + basic_block mul end + else + basic_block noop end + end end end end end - span push(3) add end + basic_block push(3) add end end end"; assert_str_eq!(format!("{program}"), expected); @@ -1951,7 +1969,7 @@ fn comment_before_program() -> TestResult { let program = context.assemble(source)?; let expected = "\ begin - span pad incr push(2) add end + basic_block pad incr push(2) add end end"; assert_str_eq!(format!("{program}"), expected); Ok(()) @@ -1964,7 +1982,7 @@ fn comment_after_program() -> TestResult { let program = context.assemble(source)?; let expected = "\ begin - span pad incr push(2) add end + basic_block pad incr push(2) add end end"; assert_str_eq!(format!("{program}"), expected); Ok(()) diff --git a/core/src/errors.rs b/core/src/errors.rs index 5e4d0428e1..a3c01446c6 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -40,3 +40,12 @@ pub enum KernelError { #[error("kernel can have at most {0} procedures, received {1}")] TooManyProcedures(usize, usize), } + +// PROGRAM ERROR +// ================================================================================================ + +#[derive(Clone, Debug, thiserror::Error)] +pub enum ProgramError { + #[error("tried to create a program from a MAST forest with no entrypoint")] + NoEntrypoint, +} diff --git a/core/src/kernel.rs b/core/src/kernel.rs new file mode 100644 index 0000000000..0f225822de --- /dev/null +++ b/core/src/kernel.rs @@ -0,0 +1,73 @@ +use crate::{ + errors::KernelError, + utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}, +}; +use alloc::vec::Vec; +use miden_crypto::hash::rpo::RpoDigest; + +/// A list of procedure hashes defining a VM kernel. +/// +/// The internally-stored list always has a consistent order, regardless of the order of procedure +/// list used to instantiate a kernel. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct Kernel(Vec); + +pub const MAX_KERNEL_PROCEDURES: usize = u8::MAX as usize; + +impl Kernel { + /// Returns a new [Kernel] instantiated with the specified procedure hashes. + pub fn new(proc_hashes: &[RpoDigest]) -> Result { + if proc_hashes.len() > MAX_KERNEL_PROCEDURES { + Err(KernelError::TooManyProcedures(MAX_KERNEL_PROCEDURES, proc_hashes.len())) + } else { + let mut hashes = proc_hashes.to_vec(); + hashes.sort_by_key(|v| v.as_bytes()); // ensure consistent order + + let duplicated = hashes.windows(2).any(|data| data[0] == data[1]); + + if duplicated { + Err(KernelError::DuplicatedProcedures) + } else { + Ok(Self(hashes)) + } + } + } + + /// Returns true if this kernel does not contain any procedures. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Returns true if a procedure with the specified hash belongs to this kernel. + pub fn contains_proc(&self, proc_hash: RpoDigest) -> bool { + self.0.binary_search(&proc_hash).is_ok() + } + + /// Returns a list of procedure hashes contained in this kernel. + pub fn proc_hashes(&self) -> &[RpoDigest] { + &self.0 + } +} + +// this is required by AIR as public inputs will be serialized with the proof +impl Serializable for Kernel { + fn write_into(&self, target: &mut W) { + debug_assert!(self.0.len() <= MAX_KERNEL_PROCEDURES); + target.write_usize(self.0.len()); + target.write_many(&self.0) + } +} + +impl Deserializable for Kernel { + fn read_from(source: &mut R) -> Result { + let len = source.read_usize()?; + if len > MAX_KERNEL_PROCEDURES { + return Err(DeserializationError::InvalidValue(format!( + "Number of kernel procedures can not be more than {}, but {} was provided", + MAX_KERNEL_PROCEDURES, len + ))); + } + let kernel = source.read_many::(len)?; + Ok(Self(kernel)) + } +} diff --git a/core/src/lib.rs b/core/src/lib.rs index aa905a4b43..6422eb5725 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -52,6 +52,12 @@ assertion failed: `(left matches right)` pub mod chiplets; pub mod errors; +mod program; +pub use program::{Program, ProgramInfo}; + +mod kernel; +pub use kernel::Kernel; + pub use miden_crypto::{Word, EMPTY_WORD, ONE, WORD_SIZE, ZERO}; pub mod crypto { pub mod merkle { @@ -82,6 +88,8 @@ pub mod crypto { } } +pub mod mast; + pub use math::{ fields::{f64::BaseElement as Felt, QuadExtension}, polynom, ExtensionOf, FieldElement, StarkField, ToElements, @@ -91,9 +99,6 @@ pub mod prettier { pub use miden_formatting::{prettier::*, pretty_via_display, pretty_via_to_string}; } -mod program; -pub use program::{blocks as code_blocks, CodeBlockTable, Kernel, Program, ProgramInfo}; - mod operations; pub use operations::{ AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorIterator, DecoratorList, diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs new file mode 100644 index 0000000000..070e36ef40 --- /dev/null +++ b/core/src/mast/mod.rs @@ -0,0 +1,155 @@ +use core::{fmt, ops::Index}; + +use alloc::{collections::BTreeMap, vec::Vec}; +use miden_crypto::hash::rpo::RpoDigest; + +mod node; +pub use node::{ + get_span_op_group_count, BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, MastNode, + OpBatch, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, +}; + +use crate::Kernel; + +#[cfg(test)] +mod tests; + +/// Encapsulates the behavior that a [`MastNode`] (and all its variants) is expected to have. +pub trait MerkleTreeNode { + fn digest(&self) -> RpoDigest; + fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a; +} + +/// An opaque handle to a [`MastNode`] in some [`MastForest`]. It is the responsibility of the user +/// to use a given [`MastNodeId`] with the corresponding [`MastForest`]. +/// +/// Note that since a [`MastForest`] enforces the invariant that equal [`MastNode`]s MUST have equal +/// [`MastNodeId`]s, [`MastNodeId`] equality can be used to determine equality of the underlying +/// [`MastNode`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct MastNodeId(u32); + +impl fmt::Display for MastNodeId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "MastNodeId({})", self.0) + } +} + +// MAST FOREST +// =============================================================================================== + +#[derive(Clone, Debug, Default)] +pub struct MastForest { + /// All of the blocks local to the trees comprising the MAST forest + nodes: Vec, + node_id_by_hash: BTreeMap, + + /// The "entrypoint", when set, is the root of the entire forest, i.e. a path exists from this + /// node to all other roots in the forest. This corresponds to the executable entry point. + /// Whether or not the entrypoint is set distinguishes a MAST which is executable, versus a + /// MAST which represents a library. + entrypoint: Option, + kernel: Kernel, +} + +/// Constructors +impl MastForest { + /// Creates a new empty [`MastForest`]. + pub fn new() -> Self { + Self::default() + } +} + +/// Mutators +impl MastForest { + /// Adds a node to the forest, and returns the [`MastNodeId`] associated with it. + /// + /// If a [`MastNode`] which is equal to the current node was previously added, the previously + /// returned [`MastNodeId`] will be returned. This enforces this invariant that equal + /// [`MastNode`]s have equal [`MastNodeId`]s. + pub fn ensure_node(&mut self, node: MastNode) -> MastNodeId { + let node_digest = node.digest(); + + if let Some(node_id) = self.node_id_by_hash.get(&node_digest) { + // node already exists in the forest; return previously assigned id + *node_id + } else { + let new_node_id = + MastNodeId(self.nodes.len().try_into().expect( + "invalid node id: exceeded maximum number of nodes in a single forest", + )); + + self.node_id_by_hash.insert(node_digest, new_node_id); + self.nodes.push(node); + + new_node_id + } + } + + /// Sets the kernel for this forest. + /// + /// The kernel MUST have been compiled using this [`MastForest`]; that is, all kernel procedures + /// must be present in this forest. + pub fn set_kernel(&mut self, kernel: Kernel) { + #[cfg(debug_assertions)] + for proc_hash in kernel.proc_hashes() { + assert!(self.node_id_by_hash.contains_key(proc_hash)); + } + + self.kernel = kernel; + } + + /// Sets the entrypoint for this forest. + pub fn set_entrypoint(&mut self, entrypoint: MastNodeId) { + self.entrypoint = Some(entrypoint); + } +} + +/// Public accessors +impl MastForest { + /// Returns the kernel associated with this forest. + pub fn kernel(&self) -> &Kernel { + &self.kernel + } + + /// Returns the entrypoint associated with this forest, if any. + pub fn entrypoint(&self) -> Option { + self.entrypoint + } + + /// A convenience method that provides the hash of the entrypoint, if any. + pub fn entrypoint_digest(&self) -> Option { + self.entrypoint.map(|entrypoint| self[entrypoint].digest()) + } + + /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if valid, or else + /// `None`. + /// + /// This is the faillible version of indexing (e.g. `mast_forest[node_id]`). + #[inline(always)] + pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> { + let idx = node_id.0 as usize; + + self.nodes.get(idx) + } + + /// Returns the [`MastNodeId`] associated with a given digest, if any. + /// + /// That is, every [`MastNode`] hashes to some digest. If there exists a [`MastNode`] in the + /// forest that hashes to this digest, then its id is returned. + #[inline(always)] + pub fn get_node_id_by_digest(&self, digest: RpoDigest) -> Option { + self.node_id_by_hash.get(&digest).copied() + } +} + +impl Index for MastForest { + type Output = MastNode; + + #[inline(always)] + fn index(&self, node_id: MastNodeId) -> &Self::Output { + let idx = node_id.0 as usize; + + &self.nodes[idx] + } +} diff --git a/core/src/program/blocks/span_block.rs b/core/src/mast/node/basic_block_node.rs similarity index 87% rename from core/src/program/blocks/span_block.rs rename to core/src/mast/node/basic_block_node.rs index c7ae5c8caa..2d4b4a0313 100644 --- a/core/src/program/blocks/span_block.rs +++ b/core/src/mast/node/basic_block_node.rs @@ -1,10 +1,15 @@ -use alloc::vec::Vec; use core::fmt; +use alloc::vec::Vec; +use miden_crypto::{hash::rpo::RpoDigest, Felt, ZERO}; +use miden_formatting::prettier::PrettyPrint; use winter_utils::flatten_slice_elements; -use super::{hasher, Digest, Felt, Operation}; -use crate::{DecoratorIterator, DecoratorList, ZERO}; +use crate::{ + chiplets::hasher, + mast::{MastForest, MerkleTreeNode}, + DecoratorIterator, DecoratorList, Operation, +}; // CONSTANTS // ================================================================================================ @@ -15,20 +20,18 @@ pub const GROUP_SIZE: usize = 9; /// Maximum number of groups per batch. pub const BATCH_SIZE: usize = 8; -/// Maximum number of operations which can fit into a single operation batch. -const MAX_OPS_PER_BATCH: usize = GROUP_SIZE * BATCH_SIZE; - -// SPAN BLOCK +// BASIC BLOCK NODE // ================================================================================================ + /// Block for a linear sequence of operations (i.e., no branching or loops). /// /// Executes its operations in order. Fails if any of the operations fails. /// -/// A span is composed of operation batches, operation batches are composed of operation groups, -/// operation groups encode the VM's operations and immediate values. These values are created -/// according to these rules: +/// A basic block is composed of operation batches, operation batches are composed of operation +/// groups, operation groups encode the VM's operations and immediate values. These values are +/// created according to these rules: /// -/// - A span contains one or more batches. +/// - A basic block contains one or more batches. /// - A batch contains exactly 8 groups. /// - A group contains exactly 9 operations or 1 immediate value. /// - NOOPs are used to fill a group or batch when necessary. @@ -43,28 +46,33 @@ const MAX_OPS_PER_BATCH: usize = GROUP_SIZE * BATCH_SIZE; /// - Second batch: First group with the last push opcode and 8 zero-paddings packed together, /// followed by one immediate and 6 padding groups. /// -/// The hash of a span block is: +/// The hash of a basic block is: /// -/// > hash(batches, domain=SPAN_DOMAIN) +/// > hash(batches, domain=BASIC_BLOCK_DOMAIN) /// -/// Where `batches` is the concatenation of each `batch` in the span, and each batch is 8 field -/// elements (512 bits). -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Span { +/// Where `batches` is the concatenation of each `batch` in the basic block, and each batch is 8 +/// field elements (512 bits). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BasicBlockNode { + /// The primitive operations contained in this basic block. + /// + /// The operations are broken up into batches of 8 groups, with each group containing up to 9 + /// operations, or a single immediates. Thus the maximum size of each batch is 72 operations. + /// Multiple batches are used for blocks consisting of more than 72 operations. op_batches: Vec, - hash: Digest, + digest: RpoDigest, decorators: DecoratorList, } -impl Span { - // CONSTANTS - // -------------------------------------------------------------------------------------------- - /// The domain of the span block (used for control block hashing). +/// Constants +impl BasicBlockNode { + /// The domain of the basic block node (used for control block hashing). pub const DOMAIN: Felt = ZERO; +} - // CONSTRUCTOR - // -------------------------------------------------------------------------------------------- - /// Returns a new [Span] block instantiated with the specified operations. +/// Constructors +impl BasicBlockNode { + /// Returns a new [`BasicBlockNode`] instantiated with the specified operations. /// /// # Errors (TODO) /// Returns an error if: @@ -75,7 +83,7 @@ impl Span { Self::with_decorators(operations, DecoratorList::new()) } - /// Returns a new [Span] block instantiated with the specified operations and decorators. + /// Returns a new [`BasicBlockNode`] instantiated with the specified operations and decorators. /// /// # Errors (TODO) /// Returns an error if: @@ -88,84 +96,73 @@ impl Span { #[cfg(debug_assertions)] validate_decorators(&operations, &decorators); - let (op_batches, hash) = batch_ops(operations); + let (op_batches, digest) = batch_ops(operations); Self { op_batches, - hash, + digest, decorators, } } +} - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - - /// Returns a hash of this code block. - pub fn hash(&self) -> Digest { - self.hash - } - - /// Returns list of operation batches contained in this span block. +/// Public accessors +impl BasicBlockNode { pub fn op_batches(&self) -> &[OpBatch] { &self.op_batches } - // SPAN MUTATORS - // -------------------------------------------------------------------------------------------- - - /// Returns a new [Span] block instantiated with operations from this block repeated the - /// specified number of times. - #[must_use] - pub fn replicate(&self, num_copies: usize) -> Self { - let own_ops = self.get_ops(); - let own_decorators = &self.decorators; - let mut ops = Vec::with_capacity(own_ops.len() * num_copies); - let mut decorators = DecoratorList::new(); - - for i in 0..num_copies { - // replicate decorators of a span block - for decorator in own_decorators { - decorators.push((own_ops.len() * i + decorator.0, decorator.1.clone())) - } - ops.extend_from_slice(&own_ops); - } - Self::with_decorators(ops, decorators) + /// Returns a [`DecoratorIterator`] which allows us to iterate through the decorator list of + /// this basic block node while executing operation batches of this basic block node. + pub fn decorator_iter(&self) -> DecoratorIterator { + DecoratorIterator::new(&self.decorators) } - /// Returns a list of decorators in this span block + /// Returns a list of decorators in this basic block node. pub fn decorators(&self) -> &DecoratorList { &self.decorators } +} - /// Returns a [DecoratorIterator] which allows us to iterate through the decorator list of this - /// span block while executing operation batches of this span block - pub fn decorator_iter(&self) -> DecoratorIterator { - DecoratorIterator::new(&self.decorators) +impl MerkleTreeNode for BasicBlockNode { + fn digest(&self) -> RpoDigest { + self.digest } - // HELPER METHODS - // -------------------------------------------------------------------------------------------- + fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + self + } +} - /// Returns a list of operations contained in this span block. - fn get_ops(&self) -> Vec { - let mut ops = Vec::with_capacity(self.op_batches.len() * MAX_OPS_PER_BATCH); - for batch in self.op_batches.iter() { - ops.extend_from_slice(&batch.ops); +/// Checks if a given decorators list is valid (only checked in debug mode) +/// - Assert the decorator list is in ascending order. +/// - Assert the last op index in decorator list is less than or equal to the number of operations. +#[cfg(debug_assertions)] +fn validate_decorators(operations: &[Operation], decorators: &DecoratorList) { + if !decorators.is_empty() { + // check if decorator list is sorted + for i in 0..(decorators.len() - 1) { + debug_assert!(decorators[i + 1].0 >= decorators[i].0, "unsorted decorators list"); } - ops + // assert the last index in decorator list is less than operations vector length + debug_assert!( + operations.len() >= decorators.last().expect("empty decorators list").0, + "last op index in decorator list should be less than or equal to the number of ops" + ); } } -impl crate::prettier::PrettyPrint for Span { +impl PrettyPrint for BasicBlockNode { + #[rustfmt::skip] fn render(&self) -> crate::prettier::Document { use crate::prettier::*; - // e.g. `span a b c end` - let single_line = const_text("span") + // e.g. `basic_block a b c end` + let single_line = const_text("basic_block") + const_text(" ") + self .op_batches .iter() - .flat_map(|batch| batch.ops.iter()) + .flat_map(|batch| batch.ops().iter()) .map(|p| p.render()) .reduce(|acc, doc| acc + const_text(" ") + doc) .unwrap_or_default() @@ -173,20 +170,21 @@ impl crate::prettier::PrettyPrint for Span { + const_text("end"); // e.g. ` - // span + // basic_block // a // b // c // end // ` + let multi_line = indent( 4, - const_text("span") + const_text("basic_block") + nl() + self .op_batches .iter() - .flat_map(|batch| batch.ops.iter()) + .flat_map(|batch| batch.ops().iter()) .map(|p| p.render()) .reduce(|acc, doc| acc + nl() + doc) .unwrap_or_default(), @@ -197,7 +195,7 @@ impl crate::prettier::PrettyPrint for Span { } } -impl fmt::Display for Span { +impl fmt::Display for BasicBlockNode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use crate::prettier::PrettyPrint; self.pretty_print(f) @@ -380,7 +378,7 @@ impl OpBatchAccumulator { /// up to 9 operations per group, and 8 groups per batch). /// /// After the operations have been grouped, computes the hash of the block. -fn batch_ops(ops: Vec) -> (Vec, Digest) { +fn batch_ops(ops: Vec) -> (Vec, RpoDigest) { let mut batch_acc = OpBatchAccumulator::new(); let mut batches = Vec::::new(); let mut batch_groups = Vec::<[Felt; BATCH_SIZE]>::new(); @@ -429,24 +427,6 @@ pub fn get_span_op_group_count(op_batches: &[OpBatch]) -> usize { (op_batches.len() - 1) * BATCH_SIZE + last_batch_num_groups.next_power_of_two() } -/// Checks if a given decorators list is valid (only checked in debug mode) -/// - Assert the decorator list is in ascending order. -/// - Assert the last op index in decorator list is less than or equal to the number of operations. -#[cfg(debug_assertions)] -fn validate_decorators(operations: &[Operation], decorators: &DecoratorList) { - if !decorators.is_empty() { - // check if decorator list is sorted - for i in 0..(decorators.len() - 1) { - debug_assert!(decorators[i + 1].0 >= decorators[i].0, "unsorted decorators list"); - } - // assert the last index in decorator list is less than operations vector length - debug_assert!( - operations.len() >= decorators.last().expect("empty decorators list").0, - "last op index in decorator list should be less than or equal to the number of ops" - ); - } -} - // TESTS // ================================================================================================ diff --git a/core/src/mast/node/call_node.rs b/core/src/mast/node/call_node.rs new file mode 100644 index 0000000000..c2183ea84d --- /dev/null +++ b/core/src/mast/node/call_node.rs @@ -0,0 +1,128 @@ +use core::fmt; + +use miden_crypto::{hash::rpo::RpoDigest, Felt}; +use miden_formatting::prettier::PrettyPrint; + +use crate::{chiplets::hasher, Operation}; + +use crate::mast::{MastForest, MastNodeId, MerkleTreeNode}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CallNode { + callee: MastNodeId, + is_syscall: bool, + digest: RpoDigest, +} + +/// Constants +impl CallNode { + /// The domain of the call block (used for control block hashing). + pub const CALL_DOMAIN: Felt = Felt::new(Operation::Call.op_code() as u64); + /// The domain of the syscall block (used for control block hashing). + pub const SYSCALL_DOMAIN: Felt = Felt::new(Operation::SysCall.op_code() as u64); +} + +/// Constructors +impl CallNode { + /// Returns a new [`CallNode`] instantiated with the specified callee. + pub fn new(callee: MastNodeId, mast_forest: &MastForest) -> Self { + let digest = { + let callee_digest = mast_forest[callee].digest(); + + hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::CALL_DOMAIN) + }; + + Self { + callee, + is_syscall: false, + digest, + } + } + + /// Returns a new [`CallNode`] instantiated with the specified callee and marked as a kernel + /// call. + pub fn new_syscall(callee: MastNodeId, mast_forest: &MastForest) -> Self { + let digest = { + let callee_digest = mast_forest[callee].digest(); + + hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::SYSCALL_DOMAIN) + }; + + Self { + callee, + is_syscall: true, + digest, + } + } +} + +impl CallNode { + pub fn callee(&self) -> MastNodeId { + self.callee + } + + pub fn is_syscall(&self) -> bool { + self.is_syscall + } + + /// Returns the domain of the call node. + pub fn domain(&self) -> Felt { + if self.is_syscall() { + Self::SYSCALL_DOMAIN + } else { + Self::CALL_DOMAIN + } + } + + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { + CallNodePrettyPrint { + call_node: self, + mast_forest, + } + } +} + +impl MerkleTreeNode for CallNode { + fn digest(&self) -> RpoDigest { + self.digest + } + + fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + CallNodePrettyPrint { + call_node: self, + mast_forest, + } + } +} + +struct CallNodePrettyPrint<'a> { + call_node: &'a CallNode, + mast_forest: &'a MastForest, +} + +impl<'a> PrettyPrint for CallNodePrettyPrint<'a> { + #[rustfmt::skip] + fn render(&self) -> crate::prettier::Document { + use crate::prettier::*; + use miden_formatting::hex::ToHex; + + let callee_digest = self.mast_forest[self.call_node.callee].digest(); + + let doc = if self.call_node.is_syscall { + const_text("syscall") + } else { + const_text("call") + }; + doc + const_text(".") + text(callee_digest.as_bytes().to_hex_with_prefix()) + } +} + +impl<'a> fmt::Display for CallNodePrettyPrint<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use crate::prettier::PrettyPrint; + self.pretty_print(f) + } +} diff --git a/core/src/mast/node/dyn_node.rs b/core/src/mast/node/dyn_node.rs new file mode 100644 index 0000000000..c298a03ade --- /dev/null +++ b/core/src/mast/node/dyn_node.rs @@ -0,0 +1,66 @@ +use core::fmt; + +use miden_crypto::{hash::rpo::RpoDigest, Felt}; + +use crate::{ + mast::{MastForest, MerkleTreeNode}, + Operation, +}; + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct DynNode; + +/// Constants +impl DynNode { + /// The domain of the Dyn block (used for control block hashing). + pub const DOMAIN: Felt = Felt::new(Operation::Dyn.op_code() as u64); +} + +impl MerkleTreeNode for DynNode { + fn digest(&self) -> RpoDigest { + // The Dyn node is represented by a constant, which is set to be the hash of two empty + // words ([ZERO, ZERO, ZERO, ZERO]) with a domain value of `DYN_DOMAIN`, i.e. + // hasher::merge_in_domain(&[Digest::default(), Digest::default()], DynNode::DOMAIN) + RpoDigest::new([ + Felt::new(8115106948140260551), + Felt::new(13491227816952616836), + Felt::new(15015806788322198710), + Felt::new(16575543461540527115), + ]) + } + + fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + self + } +} + +impl crate::prettier::PrettyPrint for DynNode { + fn render(&self) -> crate::prettier::Document { + use crate::prettier::*; + const_text("dyn") + } +} + +impl fmt::Display for DynNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use miden_formatting::prettier::PrettyPrint; + self.pretty_print(f) + } +} + +#[cfg(test)] +mod tests { + use miden_crypto::hash::rpo::Rpo256; + + use super::*; + + /// Ensures that the hash of `DynNode` is indeed the hash of 2 empty words, in the `DynNode` + /// domain. + #[test] + pub fn test_dyn_node_digest() { + assert_eq!( + DynNode.digest(), + Rpo256::merge_in_domain(&[RpoDigest::default(), RpoDigest::default()], DynNode::DOMAIN) + ); + } +} diff --git a/core/src/mast/node/join_node.rs b/core/src/mast/node/join_node.rs new file mode 100644 index 0000000000..3c4a712655 --- /dev/null +++ b/core/src/mast/node/join_node.rs @@ -0,0 +1,99 @@ +use core::fmt; + +use miden_crypto::{hash::rpo::RpoDigest, Felt}; + +use crate::{chiplets::hasher, prettier::PrettyPrint, Operation}; + +use crate::mast::{MastForest, MastNodeId, MerkleTreeNode}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct JoinNode { + children: [MastNodeId; 2], + digest: RpoDigest, +} + +/// Constants +impl JoinNode { + /// The domain of the join block (used for control block hashing). + pub const DOMAIN: Felt = Felt::new(Operation::Join.op_code() as u64); +} + +/// Constructors +impl JoinNode { + /// Returns a new [`JoinNode`] instantiated with the specified children nodes. + pub fn new(children: [MastNodeId; 2], mast_forest: &MastForest) -> Self { + let digest = { + let left_child_hash = mast_forest[children[0]].digest(); + let right_child_hash = mast_forest[children[1]].digest(); + + hasher::merge_in_domain(&[left_child_hash, right_child_hash], Self::DOMAIN) + }; + + Self { children, digest } + } + + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { + JoinNodePrettyPrint { + join_node: self, + mast_forest, + } + } +} + +/// Accessors +impl JoinNode { + pub fn first(&self) -> MastNodeId { + self.children[0] + } + + pub fn second(&self) -> MastNodeId { + self.children[1] + } +} + +impl MerkleTreeNode for JoinNode { + fn digest(&self) -> RpoDigest { + self.digest + } + + fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + JoinNodePrettyPrint { + join_node: self, + mast_forest, + } + } +} + +struct JoinNodePrettyPrint<'a> { + join_node: &'a JoinNode, + mast_forest: &'a MastForest, +} + +impl<'a> PrettyPrint for JoinNodePrettyPrint<'a> { + #[rustfmt::skip] + fn render(&self) -> crate::prettier::Document { + use crate::prettier::*; + + let first_child = self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest); + let second_child = self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest); + + indent( + 4, + const_text("join") + + nl() + + first_child.render() + + nl() + + second_child.render(), + ) + nl() + const_text("end") + } +} + +impl<'a> fmt::Display for JoinNodePrettyPrint<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use crate::prettier::PrettyPrint; + self.pretty_print(f) + } +} diff --git a/core/src/mast/node/loop_node.rs b/core/src/mast/node/loop_node.rs new file mode 100644 index 0000000000..fc63b11367 --- /dev/null +++ b/core/src/mast/node/loop_node.rs @@ -0,0 +1,84 @@ +use core::fmt; + +use miden_crypto::{hash::rpo::RpoDigest, Felt}; +use miden_formatting::prettier::PrettyPrint; + +use crate::{chiplets::hasher, Operation}; + +use crate::mast::{MastForest, MastNodeId, MerkleTreeNode}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LoopNode { + body: MastNodeId, + digest: RpoDigest, +} + +/// Constants +impl LoopNode { + /// The domain of the loop node (used for control block hashing). + pub const DOMAIN: Felt = Felt::new(Operation::Loop.op_code() as u64); +} + +/// Constructors +impl LoopNode { + pub fn new(body: MastNodeId, mast_forest: &MastForest) -> Self { + let digest = { + let body_hash = mast_forest[body].digest(); + + hasher::merge_in_domain(&[body_hash, RpoDigest::default()], Self::DOMAIN) + }; + + Self { body, digest } + } + + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { + LoopNodePrettyPrint { + loop_node: self, + mast_forest, + } + } +} + +impl LoopNode { + pub fn body(&self) -> MastNodeId { + self.body + } +} + +impl MerkleTreeNode for LoopNode { + fn digest(&self) -> RpoDigest { + self.digest + } + + fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + LoopNodePrettyPrint { + loop_node: self, + mast_forest, + } + } +} + +struct LoopNodePrettyPrint<'a> { + loop_node: &'a LoopNode, + mast_forest: &'a MastForest, +} + +impl<'a> crate::prettier::PrettyPrint for LoopNodePrettyPrint<'a> { + fn render(&self) -> crate::prettier::Document { + use crate::prettier::*; + + let loop_body = self.mast_forest[self.loop_node.body].to_pretty_print(self.mast_forest); + + indent(4, const_text("while.true") + nl() + loop_body.render()) + nl() + const_text("end") + } +} + +impl<'a> fmt::Display for LoopNodePrettyPrint<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use crate::prettier::PrettyPrint; + self.pretty_print(f) + } +} diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs new file mode 100644 index 0000000000..1fc8275194 --- /dev/null +++ b/core/src/mast/node/mod.rs @@ -0,0 +1,190 @@ +mod basic_block_node; +use core::fmt; + +use alloc::{boxed::Box, vec::Vec}; +pub use basic_block_node::{ + get_span_op_group_count, BasicBlockNode, OpBatch, BATCH_SIZE as OP_BATCH_SIZE, + GROUP_SIZE as OP_GROUP_SIZE, +}; + +mod call_node; +pub use call_node::CallNode; + +mod dyn_node; +pub use dyn_node::DynNode; + +mod join_node; +pub use join_node::JoinNode; + +mod split_node; +use miden_crypto::{hash::rpo::RpoDigest, Felt}; +use miden_formatting::prettier::{Document, PrettyPrint}; +pub use split_node::SplitNode; + +mod loop_node; +pub use loop_node::LoopNode; + +use crate::{ + mast::{MastForest, MastNodeId, MerkleTreeNode}, + DecoratorList, Operation, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum MastNode { + Block(BasicBlockNode), + Join(JoinNode), + Split(SplitNode), + Loop(LoopNode), + Call(CallNode), + Dyn, +} + +/// Constructors +impl MastNode { + pub fn new_basic_block(operations: Vec) -> Self { + Self::Block(BasicBlockNode::new(operations)) + } + + pub fn new_basic_block_with_decorators( + operations: Vec, + decorators: DecoratorList, + ) -> Self { + Self::Block(BasicBlockNode::with_decorators(operations, decorators)) + } + + pub fn new_join( + left_child: MastNodeId, + right_child: MastNodeId, + mast_forest: &MastForest, + ) -> Self { + Self::Join(JoinNode::new([left_child, right_child], mast_forest)) + } + + pub fn new_split( + if_branch: MastNodeId, + else_branch: MastNodeId, + mast_forest: &MastForest, + ) -> Self { + Self::Split(SplitNode::new([if_branch, else_branch], mast_forest)) + } + + pub fn new_loop(body: MastNodeId, mast_forest: &MastForest) -> Self { + Self::Loop(LoopNode::new(body, mast_forest)) + } + + pub fn new_call(callee: MastNodeId, mast_forest: &MastForest) -> Self { + Self::Call(CallNode::new(callee, mast_forest)) + } + + pub fn new_syscall(callee: MastNodeId, mast_forest: &MastForest) -> Self { + Self::Call(CallNode::new_syscall(callee, mast_forest)) + } + + pub fn new_dynexec() -> Self { + Self::Dyn + } + + pub fn new_dyncall(dyn_node_id: MastNodeId, mast_forest: &MastForest) -> Self { + Self::Call(CallNode::new(dyn_node_id, mast_forest)) + } +} + +/// Public accessors +impl MastNode { + pub fn is_basic_block(&self) -> bool { + matches!(self, Self::Block(_)) + } + + pub(crate) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { + match self { + MastNode::Block(basic_block_node) => { + MastNodePrettyPrint::new(Box::new(basic_block_node)) + } + MastNode::Join(join_node) => { + MastNodePrettyPrint::new(Box::new(join_node.to_pretty_print(mast_forest))) + } + MastNode::Split(split_node) => { + MastNodePrettyPrint::new(Box::new(split_node.to_pretty_print(mast_forest))) + } + MastNode::Loop(loop_node) => { + MastNodePrettyPrint::new(Box::new(loop_node.to_pretty_print(mast_forest))) + } + MastNode::Call(call_node) => { + MastNodePrettyPrint::new(Box::new(call_node.to_pretty_print(mast_forest))) + } + MastNode::Dyn => MastNodePrettyPrint::new(Box::new(DynNode)), + } + } + + pub fn domain(&self) -> Felt { + match self { + MastNode::Block(_) => BasicBlockNode::DOMAIN, + MastNode::Join(_) => JoinNode::DOMAIN, + MastNode::Split(_) => SplitNode::DOMAIN, + MastNode::Loop(_) => LoopNode::DOMAIN, + MastNode::Call(call_node) => call_node.domain(), + MastNode::Dyn => DynNode::DOMAIN, + } + } +} + +impl MerkleTreeNode for MastNode { + fn digest(&self) -> RpoDigest { + match self { + MastNode::Block(node) => node.digest(), + MastNode::Join(node) => node.digest(), + MastNode::Split(node) => node.digest(), + MastNode::Loop(node) => node.digest(), + MastNode::Call(node) => node.digest(), + MastNode::Dyn => DynNode.digest(), + } + } + + fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + match self { + MastNode::Block(node) => MastNodeDisplay::new(node.to_display(mast_forest)), + MastNode::Join(node) => MastNodeDisplay::new(node.to_display(mast_forest)), + MastNode::Split(node) => MastNodeDisplay::new(node.to_display(mast_forest)), + MastNode::Loop(node) => MastNodeDisplay::new(node.to_display(mast_forest)), + MastNode::Call(node) => MastNodeDisplay::new(node.to_display(mast_forest)), + MastNode::Dyn => MastNodeDisplay::new(DynNode.to_display(mast_forest)), + } + } +} + +struct MastNodePrettyPrint<'a> { + node_pretty_print: Box, +} + +impl<'a> MastNodePrettyPrint<'a> { + pub fn new(node_pretty_print: Box) -> Self { + Self { node_pretty_print } + } +} + +impl<'a> PrettyPrint for MastNodePrettyPrint<'a> { + fn render(&self) -> Document { + self.node_pretty_print.render() + } +} + +struct MastNodeDisplay<'a> { + node_display: Box, +} + +impl<'a> MastNodeDisplay<'a> { + pub fn new(node: impl fmt::Display + 'a) -> Self { + Self { + node_display: Box::new(node), + } + } +} + +impl<'a> fmt::Display for MastNodeDisplay<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.node_display.fmt(f) + } +} diff --git a/core/src/mast/node/split_node.rs b/core/src/mast/node/split_node.rs new file mode 100644 index 0000000000..ca87501fe3 --- /dev/null +++ b/core/src/mast/node/split_node.rs @@ -0,0 +1,96 @@ +use core::fmt; + +use miden_crypto::{hash::rpo::RpoDigest, Felt}; +use miden_formatting::prettier::PrettyPrint; + +use crate::{chiplets::hasher, Operation}; + +use crate::mast::{MastForest, MastNodeId, MerkleTreeNode}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SplitNode { + branches: [MastNodeId; 2], + digest: RpoDigest, +} + +/// Constants +impl SplitNode { + /// The domain of the split node (used for control block hashing). + pub const DOMAIN: Felt = Felt::new(Operation::Split.op_code() as u64); +} + +/// Constructors +impl SplitNode { + pub fn new(branches: [MastNodeId; 2], mast_forest: &MastForest) -> Self { + let digest = { + let if_branch_hash = mast_forest[branches[0]].digest(); + let else_branch_hash = mast_forest[branches[1]].digest(); + + hasher::merge_in_domain(&[if_branch_hash, else_branch_hash], Self::DOMAIN) + }; + + Self { branches, digest } + } +} + +/// Public accessors +impl SplitNode { + pub fn on_true(&self) -> MastNodeId { + self.branches[0] + } + + pub fn on_false(&self) -> MastNodeId { + self.branches[1] + } +} + +impl SplitNode { + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { + SplitNodePrettyPrint { + split_node: self, + mast_forest, + } + } +} + +impl MerkleTreeNode for SplitNode { + fn digest(&self) -> RpoDigest { + self.digest + } + + fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl core::fmt::Display + 'a { + SplitNodePrettyPrint { + split_node: self, + mast_forest, + } + } +} + +struct SplitNodePrettyPrint<'a> { + split_node: &'a SplitNode, + mast_forest: &'a MastForest, +} + +impl<'a> PrettyPrint for SplitNodePrettyPrint<'a> { + #[rustfmt::skip] + fn render(&self) -> crate::prettier::Document { + use crate::prettier::*; + + let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest); + let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest); + + let mut doc = indent(4, const_text("if.true") + nl() + true_branch.render()) + nl(); + doc += indent(4, const_text("else") + nl() + false_branch.render()); + doc + nl() + const_text("end") + } +} + +impl<'a> fmt::Display for SplitNodePrettyPrint<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use crate::prettier::PrettyPrint; + self.pretty_print(f) + } +} diff --git a/core/src/program/tests.rs b/core/src/mast/tests.rs similarity index 72% rename from core/src/program/tests.rs rename to core/src/mast/tests.rs index 3e5e0a5329..5c4e54e738 100644 --- a/core/src/program/tests.rs +++ b/core/src/mast/tests.rs @@ -1,14 +1,19 @@ -use super::{blocks::Dyn, Deserializable, Digest, Felt, Kernel, ProgramInfo, Serializable}; -use crate::{chiplets::hasher, Word}; +use crate::{ + chiplets::hasher, + mast::{DynNode, Kernel, MerkleTreeNode}, + ProgramInfo, Word, +}; use alloc::vec::Vec; +use miden_crypto::{hash::rpo::RpoDigest, Felt}; use proptest::prelude::*; use rand_utils::prng_array; +use winter_utils::{Deserializable, Serializable}; #[test] fn dyn_hash_is_correct() { let expected_constant = - hasher::merge_in_domain(&[Digest::default(), Digest::default()], Dyn::DOMAIN); - assert_eq!(expected_constant, Dyn::new().hash()); + hasher::merge_in_domain(&[RpoDigest::default(), RpoDigest::default()], DynNode::DOMAIN); + assert_eq!(expected_constant, DynNode.digest()); } proptest! { @@ -18,7 +23,7 @@ proptest! { ref seed in any::<[u8; 32]>() ) { let program_hash = digest_from_seed(*seed); - let kernel: Vec = (0..kernel_count) + let kernel: Vec = (0..kernel_count) .scan(*seed, |seed, _| { *seed = prng_array(*seed); Some(digest_from_seed(*seed)) @@ -35,7 +40,7 @@ proptest! { // HELPER FUNCTIONS // -------------------------------------------------------------------------------------------- -fn digest_from_seed(seed: [u8; 32]) -> Digest { +fn digest_from_seed(seed: [u8; 32]) -> RpoDigest { let mut digest = Word::default(); digest.iter_mut().enumerate().for_each(|(i, d)| { *d = <[u8; 8]>::try_from(&seed[i * 8..(i + 1) * 8]) diff --git a/core/src/program.rs b/core/src/program.rs new file mode 100644 index 0000000000..73e9fcf4be --- /dev/null +++ b/core/src/program.rs @@ -0,0 +1,211 @@ +use core::{fmt, ops::Index}; + +use alloc::vec::Vec; +use miden_crypto::{hash::rpo::RpoDigest, Felt, WORD_SIZE}; +use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; + +use crate::{ + errors::ProgramError, + mast::{MastForest, MastNode, MastNodeId}, + utils::ToElements, +}; + +use super::Kernel; + +// PROGRAM +// =============================================================================================== + +#[derive(Clone, Debug)] +pub struct Program { + mast_forest: MastForest, +} + +/// Constructors +impl Program { + pub fn new(mast_forest: MastForest) -> Result { + if mast_forest.entrypoint().is_some() { + Ok(Self { mast_forest }) + } else { + Err(ProgramError::NoEntrypoint) + } + } +} + +/// Public accessors +impl Program { + /// Returns the underlying [`MastForest`]. + pub fn mast_forest(&self) -> &MastForest { + &self.mast_forest + } + + /// Returns the kernel associated with this program. + pub fn kernel(&self) -> &Kernel { + self.mast_forest.kernel() + } + + /// Returns the entrypoint associated with this program. + pub fn entrypoint(&self) -> MastNodeId { + self.mast_forest.entrypoint().unwrap() + } + + /// A convenience method that provides the hash of the entrypoint. + pub fn hash(&self) -> RpoDigest { + self.mast_forest.entrypoint_digest().unwrap() + } + + /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if valid, or else + /// `None`. + /// + /// This is the faillible version of indexing (e.g. `program[node_id]`). + #[inline(always)] + pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> { + self.mast_forest.get_node_by_id(node_id) + } + + /// Returns the [`MastNodeId`] associated with a given digest, if any. + /// + /// That is, every [`MastNode`] hashes to some digest. If there exists a [`MastNode`] in the + /// forest that hashes to this digest, then its id is returned. + #[inline(always)] + pub fn get_node_id_by_digest(&self, digest: RpoDigest) -> Option { + self.mast_forest.get_node_id_by_digest(digest) + } +} + +impl Index for Program { + type Output = MastNode; + + fn index(&self, node_id: MastNodeId) -> &Self::Output { + &self.mast_forest[node_id] + } +} + +impl crate::prettier::PrettyPrint for Program { + fn render(&self) -> crate::prettier::Document { + use crate::prettier::*; + let entrypoint = self[self.entrypoint()].to_pretty_print(&self.mast_forest); + + indent(4, const_text("begin") + nl() + entrypoint.render()) + nl() + const_text("end") + } +} + +impl fmt::Display for Program { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use crate::prettier::PrettyPrint; + self.pretty_print(f) + } +} + +impl TryFrom for Program { + type Error = ProgramError; + + fn try_from(mast_forest: MastForest) -> Result { + Self::new(mast_forest) + } +} + +impl From for MastForest { + fn from(program: Program) -> Self { + program.mast_forest + } +} + +// PROGRAM INFO +// =============================================================================================== + +/// A program information set consisting of its MAST root and set of kernel procedure roots used +/// for its compilation. +/// +/// This will be used as public inputs of the proof so we bind its verification to the kernel and +/// root used to execute the program. This way, we extend the correctness of the proof to the +/// security guarantees provided by the kernel. We also allow the user to easily prove the +/// membership of a given kernel procedure for a given proof, without compromising its +/// zero-knowledge properties. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct ProgramInfo { + program_hash: RpoDigest, + kernel: Kernel, +} + +impl ProgramInfo { + // CONSTRUCTORS + // -------------------------------------------------------------------------------------------- + + /// Creates a new instance of a program info. + pub const fn new(program_hash: RpoDigest, kernel: Kernel) -> Self { + Self { + program_hash, + kernel, + } + } + + // PUBLIC ACCESSORS + // -------------------------------------------------------------------------------------------- + + /// Returns the program hash computed from its code block root. + pub const fn program_hash(&self) -> &RpoDigest { + &self.program_hash + } + + /// Returns the program kernel used during the compilation. + pub const fn kernel(&self) -> &Kernel { + &self.kernel + } + + /// Returns the list of procedures of the kernel used during the compilation. + pub fn kernel_procedures(&self) -> &[RpoDigest] { + self.kernel.proc_hashes() + } +} + +impl From for ProgramInfo { + fn from(program: Program) -> Self { + let program_hash = program.hash(); + let kernel = program.kernel().clone(); + + Self { + program_hash, + kernel, + } + } +} + +// SERIALIZATION +// ------------------------------------------------------------------------------------------------ + +impl Serializable for ProgramInfo { + fn write_into(&self, target: &mut W) { + self.program_hash.write_into(target); + self.kernel.write_into(target); + } +} + +impl Deserializable for ProgramInfo { + fn read_from(source: &mut R) -> Result { + let program_hash = source.read()?; + let kernel = source.read()?; + Ok(Self { + program_hash, + kernel, + }) + } +} + +// TO ELEMENTS +// ------------------------------------------------------------------------------------------------ + +impl ToElements for ProgramInfo { + fn to_elements(&self) -> Vec { + let num_kernel_proc_elements = self.kernel.proc_hashes().len() * WORD_SIZE; + let mut result = Vec::with_capacity(WORD_SIZE + num_kernel_proc_elements); + + // append program hash elements + result.extend_from_slice(self.program_hash.as_elements()); + + // append kernel procedure hash elements + for proc_hash in self.kernel.proc_hashes() { + result.extend_from_slice(proc_hash.as_elements()); + } + result + } +} diff --git a/core/src/program/blocks/call_block.rs b/core/src/program/blocks/call_block.rs deleted file mode 100644 index 2cffa54cd9..0000000000 --- a/core/src/program/blocks/call_block.rs +++ /dev/null @@ -1,101 +0,0 @@ -use super::{hasher, Digest, Felt, Operation}; -use core::fmt; - -// CALL BLOCK -// ================================================================================================ -/// Block for a function call. -/// -/// Executes the function referenced by `fn_hash`. Fails if the body is unavailable to the VM, or -/// if the execution of the call fails. -/// -/// The hash of a call block is computed as: -/// -/// > hash(fn_hash || padding, domain=CALL_DOMAIN) -/// > hash(fn_hash || padding, domain=SYSCALL_DOMAIN) # when a syscall is used -/// -/// Where `fn_hash` is 4 field elements (256 bits), and `padding` is 4 ZERO elements (256 bits). -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Call { - hash: Digest, - fn_hash: Digest, - is_syscall: bool, -} - -impl Call { - // CONSTANTS - // -------------------------------------------------------------------------------------------- - /// The domain of the call block (used for control block hashing). - pub const CALL_DOMAIN: Felt = Felt::new(Operation::Call.op_code() as u64); - /// The domain of the syscall block (used for control block hashing). - pub const SYSCALL_DOMAIN: Felt = Felt::new(Operation::SysCall.op_code() as u64); - - // CONSTRUCTOR - // -------------------------------------------------------------------------------------------- - /// Returns a new [Call] block instantiated with the specified function body hash. - pub fn new(fn_hash: Digest) -> Self { - let hash = hasher::merge_in_domain(&[fn_hash, Digest::default()], Self::CALL_DOMAIN); - Self { - hash, - fn_hash, - is_syscall: false, - } - } - - /// Returns a new [Call] block instantiated with the specified function body hash and marked - /// as a kernel call. - pub fn new_syscall(fn_hash: Digest) -> Self { - let hash = hasher::merge_in_domain(&[fn_hash, Digest::default()], Self::SYSCALL_DOMAIN); - Self { - hash, - fn_hash, - is_syscall: true, - } - } - - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - - /// Returns a hash of this code block. - pub fn hash(&self) -> Digest { - self.hash - } - - /// Returns a hash of the function to be called by this block. - pub fn fn_hash(&self) -> Digest { - self.fn_hash - } - - /// Returns true if this call block corresponds to a kernel call. - pub fn is_syscall(&self) -> bool { - self.is_syscall - } - - /// Returns the domain of the call block - pub fn domain(&self) -> Felt { - match self.is_syscall() { - true => Self::SYSCALL_DOMAIN, - false => Self::CALL_DOMAIN, - } - } -} - -impl crate::prettier::PrettyPrint for Call { - fn render(&self) -> crate::prettier::Document { - use crate::prettier::*; - use miden_formatting::hex::ToHex; - - let doc = if self.is_syscall { - const_text("syscall") - } else { - const_text("call") - }; - doc + const_text(".") + text(self.fn_hash.as_bytes().to_hex_with_prefix()) - } -} - -impl fmt::Display for Call { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use crate::prettier::PrettyPrint; - self.pretty_print(f) - } -} diff --git a/core/src/program/blocks/dyn_block.rs b/core/src/program/blocks/dyn_block.rs deleted file mode 100644 index 8571abf69d..0000000000 --- a/core/src/program/blocks/dyn_block.rs +++ /dev/null @@ -1,77 +0,0 @@ -use miden_formatting::prettier::PrettyPrint; - -use super::{Digest, Felt, Operation}; -use core::fmt; - -// CONSTANTS -// ================================================================================================ - -/// The Dyn block is represented by a constant, which is set to be the hash of two empty words -/// ([ZERO, ZERO, ZERO, ZERO]) with a domain value of `DYN_DOMAIN`, i.e. -/// hasher::merge_in_domain(&[Digest::default(), Digest::default()], Dyn::DOMAIN) -const DYN_CONSTANT: Digest = Digest::new([ - Felt::new(8115106948140260551), - Felt::new(13491227816952616836), - Felt::new(15015806788322198710), - Felt::new(16575543461540527115), -]); - -// Dyn BLOCK -// ================================================================================================ -/// Block for dynamic code where the target is specified by the stack. -/// -/// Executes the code block referenced by the hash on top of the stack. Fails if the body is -/// unavailable to the VM, or if the execution of the dynamically-specified code block fails. -/// -/// The child of a Dyn block (the target specified by the stack) is always dynamic and does not -/// affect the representation of the Dyn block. Therefore all Dyn blocks are represented by the same -/// constant (rather than by unique hashes), which is computed as an RPO hash of two empty words -/// ([ZERO, ZERO, ZERO, ZERO]) with a domain value of `DYN_DOMAIN`. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Dyn {} - -impl Dyn { - // CONSTANTS - // -------------------------------------------------------------------------------------------- - /// The domain of the Dyn block (used for control block hashing). - pub const DOMAIN: Felt = Felt::new(Operation::Dyn.op_code() as u64); - - // CONSTRUCTOR - // -------------------------------------------------------------------------------------------- - /// Returns a new [Dyn] block instantiated with the specified function body hash. - pub fn new() -> Self { - Self {} - } - - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - - /// Returns a hash of this code block. - pub fn hash(&self) -> Digest { - Self::dyn_hash() - } - - /// Returns a hash of this code block. - pub fn dyn_hash() -> Digest { - DYN_CONSTANT - } -} - -impl Default for Dyn { - fn default() -> Self { - Self::new() - } -} - -impl crate::prettier::PrettyPrint for Dyn { - fn render(&self) -> crate::prettier::Document { - use crate::prettier::*; - const_text("dyn") - } -} - -impl fmt::Display for Dyn { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.pretty_print(f) - } -} diff --git a/core/src/program/blocks/join_block.rs b/core/src/program/blocks/join_block.rs deleted file mode 100644 index b68ca75043..0000000000 --- a/core/src/program/blocks/join_block.rs +++ /dev/null @@ -1,83 +0,0 @@ -use alloc::boxed::Box; -use core::fmt; - -use super::{hasher, CodeBlock, Digest, Felt, Operation}; - -// JOIN BLOCKS -// ================================================================================================ -/// Block for sequential execution of two sub-blocks. -/// -/// Executes left sub-block then the right sub-block. Fails if either of the sub-block execution -/// fails. -/// -/// The hash of a join block is computed as: -/// -/// > hash(left_block_hash || right_block_hash, domain=JOIN_DOMAIN) -/// -/// Where `left_block_hash` and `right_block_hash` are 4 field elements (256 bits) each. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Join { - body: Box<[CodeBlock; 2]>, - hash: Digest, -} - -impl Join { - // CONSTANTS - // -------------------------------------------------------------------------------------------- - /// The domain of the join block (used for control block hashing). - pub const DOMAIN: Felt = Felt::new(Operation::Join.op_code() as u64); - - // CONSTRUCTOR - // -------------------------------------------------------------------------------------------- - /// Returns a new [Join] block instantiated with the specified code blocks. - pub fn new(body: [CodeBlock; 2]) -> Self { - let hash = hasher::merge_in_domain(&[body[0].hash(), body[1].hash()], Self::DOMAIN); - Self { - body: Box::new(body), - hash, - } - } - - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - - /// Returns a hash of this code block. - pub fn hash(&self) -> Digest { - self.hash - } - - /// Returns a reference to the code block which is to be executed first when this join block - /// is executed. - pub fn first(&self) -> &CodeBlock { - &self.body[0] - } - - /// Returns a reference to the code block which is to be executed second when this join block - /// is executed. - pub fn second(&self) -> &CodeBlock { - &self.body[1] - } -} - -impl crate::prettier::PrettyPrint for Join { - #[rustfmt::skip] - fn render(&self) -> crate::prettier::Document { - use crate::prettier::*; - - indent( - 4, - const_text("join") - + nl() - + self.body[0].render() - + nl() - + self.body[1].render(), - ) + nl() + const_text("end") - } -} - -impl fmt::Display for Join { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use crate::prettier::PrettyPrint; - self.pretty_print(f) - } -} diff --git a/core/src/program/blocks/loop_block.rs b/core/src/program/blocks/loop_block.rs deleted file mode 100644 index 999c8cc89a..0000000000 --- a/core/src/program/blocks/loop_block.rs +++ /dev/null @@ -1,67 +0,0 @@ -use alloc::boxed::Box; -use core::fmt; - -use super::{hasher, CodeBlock, Digest, Felt, Operation}; - -// LOOP BLOCK -// ================================================================================================ -/// Block for a conditional loop. -/// -/// Executes the loop body while the value on the top of the stack is `1`, stops when `0`. Fails if -/// the top of the stack is neither `1` nor `0`, or if the execution of the body fails. -/// -/// The hash of a loop block is: -/// -/// > hash(body_hash || padding, domain=LOOP_DOMAIN) -/// -/// Where `body_hash` is 4 field elements (256 bits), and `padding` is 4 ZERO elements (256 bits). -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Loop { - body: Box, - hash: Digest, -} - -impl Loop { - // CONSTANTS - // -------------------------------------------------------------------------------------------- - /// The domain of the loop block (used for control block hashing). - pub const DOMAIN: Felt = Felt::new(Operation::Loop.op_code() as u64); - - // CONSTRUCTOR - // -------------------------------------------------------------------------------------------- - /// Returns a new [Loop] bock instantiated with the specified body. - pub fn new(body: CodeBlock) -> Self { - let hash = hasher::merge_in_domain(&[body.hash(), Digest::default()], Self::DOMAIN); - Self { - body: Box::new(body), - hash, - } - } - - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - - /// Returns a hash of this code block. - pub fn hash(&self) -> Digest { - self.hash - } - - /// Returns a reference to the code block which represents the body of the loop. - pub fn body(&self) -> &CodeBlock { - &self.body - } -} - -impl crate::prettier::PrettyPrint for Loop { - fn render(&self) -> crate::prettier::Document { - use crate::prettier::*; - indent(4, const_text("while.true") + nl() + self.body.render()) + nl() + const_text("end") - } -} - -impl fmt::Display for Loop { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use crate::prettier::PrettyPrint; - self.pretty_print(f) - } -} diff --git a/core/src/program/blocks/mod.rs b/core/src/program/blocks/mod.rs deleted file mode 100644 index 945f5c8eaf..0000000000 --- a/core/src/program/blocks/mod.rs +++ /dev/null @@ -1,148 +0,0 @@ -use super::{hasher, Digest, Felt, Operation}; -use crate::DecoratorList; -use alloc::vec::Vec; -use core::fmt; - -mod call_block; -mod dyn_block; -mod join_block; -mod loop_block; -mod proxy_block; -mod span_block; -mod split_block; - -pub use call_block::Call; -pub use dyn_block::Dyn; -pub use join_block::Join; -pub use loop_block::Loop; -pub use proxy_block::Proxy; -pub use span_block::{ - get_span_op_group_count, OpBatch, Span, BATCH_SIZE as OP_BATCH_SIZE, - GROUP_SIZE as OP_GROUP_SIZE, -}; -pub use split_block::Split; - -// PROGRAM BLOCK -// ================================================================================================ -/// TODO: add comments -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum CodeBlock { - Span(Span), - Join(Join), - Split(Split), - Loop(Loop), - Call(Call), - Dyn(Dyn), - Proxy(Proxy), -} - -impl CodeBlock { - // CONSTRUCTORS - // -------------------------------------------------------------------------------------------- - - /// Returns a new Span block instantiated with the provided operations. - pub fn new_span(operations: Vec) -> Self { - Self::Span(Span::new(operations)) - } - - /// Returns a new Span block instantiated with the provided operations and decorator list. - pub fn new_span_with_decorators(operations: Vec, decorators: DecoratorList) -> Self { - Self::Span(Span::with_decorators(operations, decorators)) - } - - /// TODO: add comments - pub fn new_join(blocks: [CodeBlock; 2]) -> Self { - Self::Join(Join::new(blocks)) - } - - /// TODO: add comments - pub fn new_split(t_branch: CodeBlock, f_branch: CodeBlock) -> Self { - Self::Split(Split::new(t_branch, f_branch)) - } - - /// TODO: add comments - pub fn new_loop(body: CodeBlock) -> Self { - Self::Loop(Loop::new(body)) - } - - /// TODO: add comments - pub fn new_call(fn_hash: Digest) -> Self { - Self::Call(Call::new(fn_hash)) - } - - /// TODO: add comments - pub fn new_syscall(fn_hash: Digest) -> Self { - Self::Call(Call::new_syscall(fn_hash)) - } - - /// TODO: add comments - pub fn new_dyn() -> Self { - Self::Dyn(Dyn::new()) - } - - /// TODO: add comments - pub fn new_dyncall() -> Self { - Self::Call(Call::new(Dyn::dyn_hash())) - } - - /// TODO: add comments - pub fn new_proxy(code_hash: Digest) -> Self { - Self::Proxy(Proxy::new(code_hash)) - } - - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - - /// Returns true if this code block is a [Span] block. - pub fn is_span(&self) -> bool { - matches!(self, CodeBlock::Span(_)) - } - - /// Returns a hash of this code block. - pub fn hash(&self) -> Digest { - match self { - CodeBlock::Span(block) => block.hash(), - CodeBlock::Join(block) => block.hash(), - CodeBlock::Split(block) => block.hash(), - CodeBlock::Loop(block) => block.hash(), - CodeBlock::Call(block) => block.hash(), - CodeBlock::Dyn(block) => block.hash(), - CodeBlock::Proxy(block) => block.hash(), - } - } - - /// Returns the domain of the code block - pub fn domain(&self) -> Felt { - match self { - CodeBlock::Call(block) => block.domain(), - CodeBlock::Dyn(_) => Dyn::DOMAIN, - CodeBlock::Join(_) => Join::DOMAIN, - CodeBlock::Loop(_) => Loop::DOMAIN, - CodeBlock::Span(_) => Span::DOMAIN, - CodeBlock::Split(_) => Split::DOMAIN, - CodeBlock::Proxy(_) => panic!("Can't fetch `domain` for a `Proxy` block!"), - } - } -} - -impl crate::prettier::PrettyPrint for CodeBlock { - fn render(&self) -> crate::prettier::Document { - match self { - Self::Span(block) => block.render(), - Self::Join(block) => block.render(), - Self::Split(block) => block.render(), - Self::Loop(block) => block.render(), - Self::Call(block) => block.render(), - Self::Dyn(block) => block.render(), - Self::Proxy(block) => block.render(), - } - } -} - -impl fmt::Display for CodeBlock { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use crate::prettier::PrettyPrint; - - self.pretty_print(f) - } -} diff --git a/core/src/program/blocks/proxy_block.rs b/core/src/program/blocks/proxy_block.rs deleted file mode 100644 index 84a45cf1c3..0000000000 --- a/core/src/program/blocks/proxy_block.rs +++ /dev/null @@ -1,43 +0,0 @@ -use super::Digest; -use core::fmt; - -// PROXY BLOCK -// ================================================================================================ -/// Block for a unknown function call. -/// -/// Proxy blocks are used to verify the integrity of a program's hash while keeping parts -/// of the program secret. Fails if executed. -/// -/// Hash of a proxy block is not computed but is rather defined at instantiation time. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Proxy { - hash: Digest, -} - -impl Proxy { - /// Returns a new [Proxy] block instantiated with the specified code hash. - pub fn new(code_hash: Digest) -> Self { - Self { hash: code_hash } - } - - /// Returns a hash of this code block. - pub fn hash(&self) -> Digest { - self.hash - } -} - -impl crate::prettier::PrettyPrint for Proxy { - fn render(&self) -> crate::prettier::Document { - use crate::prettier::*; - use miden_formatting::hex::ToHex; - - const_text("proxy") + const_text(".") + text(self.hash.as_bytes().to_hex_with_prefix()) - } -} - -impl fmt::Display for Proxy { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use crate::prettier::PrettyPrint; - self.pretty_print(f) - } -} diff --git a/core/src/program/blocks/split_block.rs b/core/src/program/blocks/split_block.rs deleted file mode 100644 index 738f3ceffe..0000000000 --- a/core/src/program/blocks/split_block.rs +++ /dev/null @@ -1,77 +0,0 @@ -use alloc::boxed::Box; -use core::fmt; - -use super::{hasher, CodeBlock, Digest, Felt, Operation}; - -// SPLIT BLOCK -// ================================================================================================ -/// Block for conditional execution. -/// -/// Executes the first branch if the top of the stack is `1` or the second branch if `0`. Fails if -/// the top of the stack is neither `1` or `0` or if the branch execution fails. -/// -/// The hash of a split block is: -/// -/// > hash(true_branch_hash || false_branch_hash, domain=SPLIT_DOMAIN) -/// -/// Where `true_branch_hash` and `false_branch_hash` are 4 field elements (256 bits) each. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Split { - branches: Box<[CodeBlock; 2]>, - hash: Digest, -} - -impl Split { - // CONSTANTS - // -------------------------------------------------------------------------------------------- - /// The domain of the split block (used for control block hashing). - pub const DOMAIN: Felt = Felt::new(Operation::Split.op_code() as u64); - - // CONSTRUCTOR - // -------------------------------------------------------------------------------------------- - /// Returns a new [Split] block instantiated with the specified true and false branches. - pub fn new(t_branch: CodeBlock, f_branch: CodeBlock) -> Self { - let hash = hasher::merge_in_domain(&[t_branch.hash(), f_branch.hash()], Self::DOMAIN); - Self { - branches: Box::new([t_branch, f_branch]), - hash, - } - } - - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - - /// Returns a hash of this code block. - pub fn hash(&self) -> Digest { - self.hash - } - - /// Returns a reference to the code block which is to be executed when the top of the stack - /// is `1`. - pub fn on_true(&self) -> &CodeBlock { - &self.branches[0] - } - - /// Returns a reference to the code block which is to be executed when the top of the stack - /// is `0`. - pub fn on_false(&self) -> &CodeBlock { - &self.branches[1] - } -} - -impl crate::prettier::PrettyPrint for Split { - fn render(&self) -> crate::prettier::Document { - use crate::prettier::*; - - let mut doc = indent(4, const_text("if.true") + nl() + self.branches[0].render()) + nl(); - doc += indent(4, const_text("else") + nl() + self.branches[1].render()); - doc + nl() + const_text("end") - } -} - -impl fmt::Display for Split { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use crate::prettier::PrettyPrint; - self.pretty_print(f) - } -} diff --git a/core/src/program/info.rs b/core/src/program/info.rs deleted file mode 100644 index 083fcb0c23..0000000000 --- a/core/src/program/info.rs +++ /dev/null @@ -1,106 +0,0 @@ -use super::{ - super::{ToElements, WORD_SIZE}, - ByteReader, ByteWriter, Deserializable, DeserializationError, Digest, Felt, Kernel, Program, - Serializable, -}; -use alloc::vec::Vec; - -// PROGRAM INFO -// ================================================================================================ - -/// A program information set consisting of its MAST root and set of kernel procedure roots used -/// for its compilation. -/// -/// This will be used as public inputs of the proof so we bind its verification to the kernel and -/// root used to execute the program. This way, we extend the correctness of the proof to the -/// security guarantees provided by the kernel. We also allow the user to easily prove the -/// membership of a given kernel procedure for a given proof, without compromising its -/// zero-knowledge properties. -#[derive(Debug, Clone, Default, PartialEq, Eq)] -pub struct ProgramInfo { - program_hash: Digest, - kernel: Kernel, -} - -impl ProgramInfo { - // CONSTRUCTORS - // -------------------------------------------------------------------------------------------- - - /// Creates a new instance of a program info. - pub const fn new(program_hash: Digest, kernel: Kernel) -> Self { - Self { - program_hash, - kernel, - } - } - - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - - /// Returns the program hash computed from its code block root. - pub const fn program_hash(&self) -> &Digest { - &self.program_hash - } - - /// Returns the program kernel used during the compilation. - pub const fn kernel(&self) -> &Kernel { - &self.kernel - } - - /// Returns the list of procedures of the kernel used during the compilation. - pub fn kernel_procedures(&self) -> &[Digest] { - self.kernel.proc_hashes() - } -} - -impl From for ProgramInfo { - fn from(program: Program) -> Self { - let Program { root, kernel, .. } = program; - let program_hash = root.hash(); - - Self { - program_hash, - kernel, - } - } -} - -// SERIALIZATION -// ------------------------------------------------------------------------------------------------ - -impl Serializable for ProgramInfo { - fn write_into(&self, target: &mut W) { - self.program_hash.write_into(target); - self.kernel.write_into(target); - } -} - -impl Deserializable for ProgramInfo { - fn read_from(source: &mut R) -> Result { - let program_hash = source.read()?; - let kernel = source.read()?; - Ok(Self { - program_hash, - kernel, - }) - } -} - -// TO ELEMENTS -// ------------------------------------------------------------------------------------------------ - -impl ToElements for ProgramInfo { - fn to_elements(&self) -> Vec { - let num_kernel_proc_elements = self.kernel.proc_hashes().len() * WORD_SIZE; - let mut result = Vec::with_capacity(WORD_SIZE + num_kernel_proc_elements); - - // append program hash elements - result.extend_from_slice(self.program_hash.as_elements()); - - // append kernel procedure hash elements - for proc_hash in self.kernel.proc_hashes() { - result.extend_from_slice(proc_hash.as_elements()); - } - result - } -} diff --git a/core/src/program/mod.rs b/core/src/program/mod.rs deleted file mode 100644 index c343b0ebcb..0000000000 --- a/core/src/program/mod.rs +++ /dev/null @@ -1,195 +0,0 @@ -use super::{ - chiplets::hasher::{self, Digest}, - errors, Felt, Operation, -}; -use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; -use alloc::{collections::BTreeMap, vec::Vec}; -use core::fmt; - -pub mod blocks; -use blocks::CodeBlock; - -mod info; -pub use info::ProgramInfo; - -#[cfg(test)] -mod tests; - -// PROGRAM -// ================================================================================================ -/// A program which can be executed by the VM. -/// -/// A program is described by a Merkelized Abstract Syntax Tree (MAST), where each node is a -/// [CodeBlock]. Internal nodes describe control flow semantics of the program, while leaf nodes -/// contain linear sequences of instructions which contain no control flow. -#[derive(Clone, Debug)] -pub struct Program { - root: CodeBlock, - kernel: Kernel, - cb_table: CodeBlockTable, -} - -impl Program { - // CONSTRUCTORS - // -------------------------------------------------------------------------------------------- - /// Instantiates a new [Program] from the specified code block. - pub fn new(root: CodeBlock) -> Self { - Self::with_kernel(root, Kernel::default(), CodeBlockTable::default()) - } - - /// Instantiates a new [Program] from the specified code block and associated code block table. - pub fn with_kernel(root: CodeBlock, kernel: Kernel, cb_table: CodeBlockTable) -> Self { - Self { - root, - kernel, - cb_table, - } - } - - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - - /// Returns the root code block of this program. - pub fn root(&self) -> &CodeBlock { - &self.root - } - - /// Returns a hash of this program. - pub fn hash(&self) -> Digest { - self.root.hash() - } - - /// Returns a kernel for this program. - pub fn kernel(&self) -> &Kernel { - &self.kernel - } - - /// Returns code block table for this program. - pub fn cb_table(&self) -> &CodeBlockTable { - &self.cb_table - } -} - -impl crate::prettier::PrettyPrint for Program { - fn render(&self) -> crate::prettier::Document { - use crate::prettier::*; - - indent(4, const_text("begin") + nl() + self.root.render()) + nl() + const_text("end") - } -} - -impl fmt::Display for Program { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use crate::prettier::PrettyPrint; - self.pretty_print(f) - } -} - -// CODE BLOCK TABLE -// ================================================================================================ - -/// A map of code block hashes to their underlying code blocks. -/// -/// This table is used to hold code blocks which are referenced from the program MAST but are -/// actually not a part of the MAST itself. Thus, for example, multiple nodes in the MAST can -/// reference the same code block in the table. -#[derive(Clone, Debug, Default)] -pub struct CodeBlockTable(BTreeMap<[u8; 32], CodeBlock>); - -impl CodeBlockTable { - /// Returns a code block for the specified hash, or None if the code block is not present - /// in this table. - pub fn get(&self, hash: Digest) -> Option<&CodeBlock> { - let key: [u8; 32] = hash.into(); - self.0.get(&key) - } - - /// Returns true if a code block with the specified hash is present in this table. - pub fn has(&self, hash: Digest) -> bool { - let key: [u8; 32] = hash.into(); - self.0.contains_key(&key) - } - - /// Inserts the provided code block into this table. - pub fn insert(&mut self, block: CodeBlock) { - let key: [u8; 32] = block.hash().into(); - self.0.insert(key, block); - } - - /// Returns true if this code block table is empty. - pub fn is_empty(&self) -> bool { - self.0.is_empty() - } -} - -// KERNEL -// ================================================================================================ - -/// A list of procedure hashes defining a VM kernel. -/// -/// The internally-stored list always has a consistent order, regardless of the order of procedure -/// list used to instantiate a kernel. -#[derive(Debug, Clone, Default, PartialEq, Eq)] -pub struct Kernel(Vec); - -pub const MAX_KERNEL_PROCEDURES: usize = u8::MAX as usize; - -impl Kernel { - /// Returns a new [Kernel] instantiated with the specified procedure hashes. - pub fn new(proc_hashes: &[Digest]) -> Result { - if proc_hashes.len() > MAX_KERNEL_PROCEDURES { - Err(errors::KernelError::TooManyProcedures(MAX_KERNEL_PROCEDURES, proc_hashes.len())) - } else { - let mut hashes = proc_hashes.to_vec(); - hashes.sort_by_key(|v| v.as_bytes()); // ensure consistent order - - let duplicated = hashes.windows(2).any(|data| data[0] == data[1]); - - if duplicated { - Err(errors::KernelError::DuplicatedProcedures) - } else { - Ok(Self(hashes)) - } - } - } - - /// Returns true if this kernel does not contain any procedures. - pub fn is_empty(&self) -> bool { - self.0.is_empty() - } - - /// Returns true if a procedure with the specified hash belongs to this kernel. - pub fn contains_proc(&self, proc_hash: Digest) -> bool { - // linear search here is OK because we expect the kernels to have a relatively small number - // of procedures (e.g., under 100) - self.0.iter().any(|&h| h == proc_hash) - } - - /// Returns a list of procedure hashes contained in this kernel. - pub fn proc_hashes(&self) -> &[Digest] { - &self.0 - } -} - -// this is required by AIR as public inputs will be serialized with the proof -impl Serializable for Kernel { - fn write_into(&self, target: &mut W) { - debug_assert!(self.0.len() <= MAX_KERNEL_PROCEDURES); - target.write_usize(self.0.len()); - target.write_many(&self.0) - } -} - -impl Deserializable for Kernel { - fn read_from(source: &mut R) -> Result { - let len = source.read_usize()?; - if len > MAX_KERNEL_PROCEDURES { - return Err(DeserializationError::InvalidValue(format!( - "Number of kernel procedures can not be more than {}, but {} was provided", - MAX_KERNEL_PROCEDURES, len - ))); - } - let kernel = source.read_many::(len)?; - Ok(Self(kernel)) - } -} diff --git a/miden/README.md b/miden/README.md index e7057e05cb..80e87ad5e9 100644 --- a/miden/README.md +++ b/miden/README.md @@ -50,14 +50,14 @@ The `execute_iter()` function takes similar arguments (but without the `options` For example: ```rust -use miden_vm::{Assembler, execute, execute_iter, DefaultHost, StackInputs}; +use miden_vm::{Assembler, execute, execute_iter, DefaultHost, Program, StackInputs}; use processor::ExecutionOptions; // instantiate the assembler let mut assembler = Assembler::default(); // compile Miden assembly source code into a program -let program = assembler.assemble("begin push.3 push.5 add end").unwrap(); +let program = assembler.assemble_program("begin push.3 push.5 add end").unwrap(); // use an empty list as initial stack let stack_inputs = StackInputs::default(); @@ -99,13 +99,13 @@ If the program is executed successfully, the function returns a tuple with 2 ele Here is a simple example of executing a program which pushes two numbers onto the stack and computes their sum: ```rust -use miden_vm::{Assembler, DefaultHost, ProvingOptions, prove, StackInputs}; +use miden_vm::{Assembler, DefaultHost, ProvingOptions, Program, prove, StackInputs}; // instantiate the assembler let mut assembler = Assembler::default(); // this is our program, we compile it from assembly code -let program = assembler.assemble("begin push.3 push.5 add end").unwrap(); +let program = assembler.assemble_program("begin push.3 push.5 add end").unwrap(); // let's execute it and generate a STARK proof let (outputs, proof) = prove( @@ -177,7 +177,7 @@ add // stack state: 3 2 Notice that except for the first 2 operations which initialize the stack, the sequence of `swap dup.1 add` operations repeats over and over. In fact, we can repeat these operations an arbitrary number of times to compute an arbitrary Fibonacci number. In Rust, it would look like this (this is actually a simplified version of the example in [fibonacci.rs](src/examples/src/fibonacci.rs)): ```rust -use miden_vm::{Assembler, DefaultHost, ProvingOptions, StackInputs}; +use miden_vm::{Assembler, DefaultHost, Program, ProvingOptions, StackInputs}; // set the number of terms to compute let n = 50; @@ -193,7 +193,7 @@ let source = format!( n - 1 ); let mut assembler = Assembler::default(); -let program = assembler.assemble(&source).unwrap(); +let program = assembler.assemble_program(&source).unwrap(); // initialize a default host (with an empty advice provider) let host = DefaultHost::default(); diff --git a/miden/benches/program_compilation.rs b/miden/benches/program_compilation.rs index 31480708f8..95fa9b38ff 100644 --- a/miden/benches/program_compilation.rs +++ b/miden/benches/program_compilation.rs @@ -15,7 +15,7 @@ fn program_compilation(c: &mut Criterion) { exec.sha256::hash_2to1 end"; bench.iter(|| { - let mut assembler = Assembler::default() + let assembler = Assembler::default() .with_library(&StdLibrary::default()) .expect("failed to load stdlib"); assembler.assemble(source).expect("Failed to compile test source.") diff --git a/miden/benches/program_execution.rs b/miden/benches/program_execution.rs index 1bb954910f..1d191986ab 100644 --- a/miden/benches/program_execution.rs +++ b/miden/benches/program_execution.rs @@ -1,6 +1,6 @@ use criterion::{criterion_group, criterion_main, Criterion}; -use miden_vm::{execute, Assembler, DefaultHost, StackInputs}; -use processor::ExecutionOptions; +use miden_vm::{Assembler, DefaultHost, StackInputs}; +use processor::{execute, ExecutionOptions, Program}; use std::time::Duration; use stdlib::StdLibrary; @@ -15,10 +15,14 @@ fn program_execution(c: &mut Criterion) { begin exec.sha256::hash_2to1 end"; - let mut assembler = Assembler::default() + let assembler = Assembler::default() .with_library(&StdLibrary::default()) .expect("failed to load stdlib"); - let program = assembler.assemble(source).expect("Failed to compile test source."); + let program: Program = assembler + .assemble(source) + .expect("Failed to compile test source.") + .try_into() + .expect("test source has no entrypoint."); bench.iter(|| { execute( &program, diff --git a/miden/src/cli/data.rs b/miden/src/cli/data.rs index fbe0b90411..1d797b0c7b 100644 --- a/miden/src/cli/data.rs +++ b/miden/src/cli/data.rs @@ -419,8 +419,9 @@ impl ProgramFile { .with_libraries(libraries.into_iter()) .wrap_err("Failed to load libraries")?; - let program = - assembler.assemble(self.ast.as_ref()).wrap_err("Failed to compile program")?; + let program: Program = assembler + .assemble_program(self.ast.as_ref()) + .wrap_err("Failed to compile program")?; Ok(program) } diff --git a/miden/src/examples/blake3.rs b/miden/src/examples/blake3.rs index 2e87cadbcc..7bae853923 100644 --- a/miden/src/examples/blake3.rs +++ b/miden/src/examples/blake3.rs @@ -47,7 +47,7 @@ fn generate_blake3_program(n: usize) -> Program { Assembler::default() .with_library(&StdLibrary::default()) .unwrap() - .assemble(program) + .assemble_program(program) .unwrap() } diff --git a/miden/src/examples/fibonacci.rs b/miden/src/examples/fibonacci.rs index f81dcb32ff..caa177abe8 100644 --- a/miden/src/examples/fibonacci.rs +++ b/miden/src/examples/fibonacci.rs @@ -39,7 +39,7 @@ fn generate_fibonacci_program(n: usize) -> Program { n - 1 ); - Assembler::default().assemble(program).unwrap() + Assembler::default().assemble_program(program).unwrap() } /// Computes the `n`-th term of Fibonacci sequence diff --git a/miden/src/repl/mod.rs b/miden/src/repl/mod.rs index aa39436d98..aea32e62ea 100644 --- a/miden/src/repl/mod.rs +++ b/miden/src/repl/mod.rs @@ -1,5 +1,5 @@ use assembly::{Assembler, Library, MaslLibrary}; -use miden_vm::{math::Felt, DefaultHost, StackInputs, Word}; +use miden_vm::{math::Felt, DefaultHost, Program, StackInputs, Word}; use processor::ContextId; use rustyline::{error::ReadlineError, DefaultEditor}; use std::{collections::BTreeSet, path::PathBuf}; @@ -293,7 +293,7 @@ fn execute( .with_libraries(provided_libraries.iter()) .map_err(|err| format!("{err}"))?; - let program = assembler.assemble(program).map_err(|err| format!("{err}"))?; + let program = assembler.assemble_program(program).map_err(|err| format!("{err}"))?; let stack_inputs = StackInputs::default(); let host = DefaultHost::default(); diff --git a/miden/src/tools/mod.rs b/miden/src/tools/mod.rs index 0028b2c4e1..fce3d34dc8 100644 --- a/miden/src/tools/mod.rs +++ b/miden/src/tools/mod.rs @@ -2,7 +2,7 @@ use super::cli::InputFile; use assembly::diagnostics::{IntoDiagnostic, Report, WrapErr}; use clap::Parser; use core::fmt; -use miden_vm::{Assembler, DefaultHost, Host, Operation, StackInputs}; +use miden_vm::{Assembler, DefaultHost, Host, Operation, Program, StackInputs}; use processor::{AsmOpInfo, TraceLenSummary}; use std::{fs, path::PathBuf}; use stdlib::StdLibrary; @@ -216,7 +216,7 @@ where let program = Assembler::default() .with_debug_mode(true) .with_library(&StdLibrary::default())? - .assemble(program)?; + .assemble_program(program)?; let mut execution_details = ExecutionDetails::default(); let vm_state_iterator = processor::execute_iter(&program, stack_inputs, host); diff --git a/miden/tests/integration/cli/cli_test.rs b/miden/tests/integration/cli/cli_test.rs index f25ad2fd25..34b42bd7aa 100644 --- a/miden/tests/integration/cli/cli_test.rs +++ b/miden/tests/integration/cli/cli_test.rs @@ -22,9 +22,9 @@ fn cli_run() -> Result<(), Box> { .arg("-n") .arg("1") .arg("-m") - .arg("4096") + .arg("8192") .arg("-e") - .arg("4096"); + .arg("8192"); let output = cmd.unwrap(); diff --git a/miden/tests/integration/operations/decorators/asmop.rs b/miden/tests/integration/operations/decorators/asmop.rs index 444f5630fd..8dfcaba749 100644 --- a/miden/tests/integration/operations/decorators/asmop.rs +++ b/miden/tests/integration/operations/decorators/asmop.rs @@ -136,10 +136,20 @@ fn asmop_repeat_test() { VmStatePartial { clk: 1, asmop: None, - op: Some(Operation::Span), + op: Some(Operation::Join), }, VmStatePartial { clk: 2, + asmop: None, + op: Some(Operation::Join), + }, + VmStatePartial { + clk: 3, + asmop: None, + op: Some(Operation::Span), + }, + VmStatePartial { + clk: 4, asmop: Some(AsmOpInfo::new( AssemblyOp::new("#exec::#main".to_string(), 2, "push.1".to_string(), false), 1, @@ -147,7 +157,7 @@ fn asmop_repeat_test() { op: Some(Operation::Pad), }, VmStatePartial { - clk: 3, + clk: 5, asmop: Some(AsmOpInfo::new( AssemblyOp::new("#exec::#main".to_string(), 2, "push.1".to_string(), false), 2, @@ -155,7 +165,7 @@ fn asmop_repeat_test() { op: Some(Operation::Incr), }, VmStatePartial { - clk: 4, + clk: 6, asmop: Some(AsmOpInfo::new( AssemblyOp::new("#exec::#main".to_string(), 1, "push.2".to_string(), false), 1, @@ -163,15 +173,26 @@ fn asmop_repeat_test() { op: Some(Operation::Push(Felt::new(2))), }, VmStatePartial { - clk: 5, + clk: 7, asmop: Some(AsmOpInfo::new( AssemblyOp::new("#exec::#main".to_string(), 1, "add".to_string(), false), 1, )), op: Some(Operation::Add), }, + // End first Span VmStatePartial { - clk: 6, + clk: 8, + asmop: None, + op: Some(Operation::End), + }, + VmStatePartial { + clk: 9, + asmop: None, + op: Some(Operation::Span), + }, + VmStatePartial { + clk: 10, asmop: Some(AsmOpInfo::new( AssemblyOp::new("#exec::#main".to_string(), 2, "push.1".to_string(), false), 1, @@ -179,7 +200,7 @@ fn asmop_repeat_test() { op: Some(Operation::Pad), }, VmStatePartial { - clk: 7, + clk: 11, asmop: Some(AsmOpInfo::new( AssemblyOp::new("#exec::#main".to_string(), 2, "push.1".to_string(), false), 2, @@ -187,7 +208,7 @@ fn asmop_repeat_test() { op: Some(Operation::Incr), }, VmStatePartial { - clk: 8, + clk: 12, asmop: Some(AsmOpInfo::new( AssemblyOp::new("#exec::#main".to_string(), 1, "push.2".to_string(), false), 1, @@ -195,15 +216,32 @@ fn asmop_repeat_test() { op: Some(Operation::Push(Felt::new(2))), }, VmStatePartial { - clk: 9, + clk: 13, asmop: Some(AsmOpInfo::new( AssemblyOp::new("#exec::#main".to_string(), 1, "add".to_string(), false), 1, )), op: Some(Operation::Add), }, + // End second Span VmStatePartial { - clk: 10, + clk: 14, + asmop: None, + op: Some(Operation::End), + }, + // End first Join + VmStatePartial { + clk: 15, + asmop: None, + op: Some(Operation::End), + }, + VmStatePartial { + clk: 16, + asmop: None, + op: Some(Operation::Span), + }, + VmStatePartial { + clk: 17, asmop: Some(AsmOpInfo::new( AssemblyOp::new("#exec::#main".to_string(), 2, "push.1".to_string(), false), 1, @@ -211,7 +249,7 @@ fn asmop_repeat_test() { op: Some(Operation::Pad), }, VmStatePartial { - clk: 11, + clk: 18, asmop: Some(AsmOpInfo::new( AssemblyOp::new("#exec::#main".to_string(), 2, "push.1".to_string(), false), 2, @@ -219,7 +257,7 @@ fn asmop_repeat_test() { op: Some(Operation::Incr), }, VmStatePartial { - clk: 12, + clk: 19, asmop: Some(AsmOpInfo::new( AssemblyOp::new("#exec::#main".to_string(), 1, "push.2".to_string(), false), 1, @@ -227,30 +265,22 @@ fn asmop_repeat_test() { op: Some(Operation::Push(Felt::new(2))), }, VmStatePartial { - clk: 13, + clk: 20, asmop: Some(AsmOpInfo::new( AssemblyOp::new("#exec::#main".to_string(), 1, "add".to_string(), false), 1, )), op: Some(Operation::Add), }, + // End Span VmStatePartial { - clk: 14, - asmop: None, - op: Some(Operation::Noop), - }, - VmStatePartial { - clk: 15, + clk: 21, asmop: None, - op: Some(Operation::Noop), - }, - VmStatePartial { - clk: 16, - asmop: None, - op: Some(Operation::Noop), + op: Some(Operation::End), }, + // End second Join VmStatePartial { - clk: 17, + clk: 22, asmop: None, op: Some(Operation::End), }, diff --git a/miden/tests/integration/operations/decorators/events.rs b/miden/tests/integration/operations/decorators/events.rs index 6c2815f839..d1d6397327 100644 --- a/miden/tests/integration/operations/decorators/events.rs +++ b/miden/tests/integration/operations/decorators/events.rs @@ -1,6 +1,6 @@ use super::TestHost; use assembly::Assembler; -use processor::ExecutionOptions; +use processor::{ExecutionOptions, Program}; #[test] fn test_event_handling() { @@ -13,7 +13,11 @@ fn test_event_handling() { end"; // compile and execute program - let program = Assembler::default().assemble(source).unwrap(); + let program: Program = Assembler::default() + .assemble(source) + .unwrap() + .try_into() + .expect("test source has no entrypoint."); let mut host = TestHost::default(); processor::execute(&program, Default::default(), &mut host, Default::default()).unwrap(); @@ -33,7 +37,11 @@ fn test_trace_handling() { end"; // compile program - let program = Assembler::default().assemble(source).unwrap(); + let program: Program = Assembler::default() + .assemble(source) + .unwrap() + .try_into() + .expect("test source has no entrypoint."); let mut host = TestHost::default(); // execute program with disabled tracing diff --git a/miden/tests/integration/operations/io_ops/env_ops.rs b/miden/tests/integration/operations/io_ops/env_ops.rs index 8716d6e42d..c8acfbd41e 100644 --- a/miden/tests/integration/operations/io_ops/env_ops.rs +++ b/miden/tests/integration/operations/io_ops/env_ops.rs @@ -1,6 +1,9 @@ use processor::FMP_MIN; use test_utils::{build_op_test, build_test, StackInputs, Test, Word, STACK_TOP_SIZE}; -use vm_core::{code_blocks::CodeBlock, Operation}; +use vm_core::{ + mast::{MastForest, MastNode, MerkleTreeNode}, + Operation, +}; // SDEPTH INSTRUCTION // ================================================================================================ @@ -155,9 +158,13 @@ fn caller() { } fn build_bar_hash() -> [u64; 4] { - let foo_root = CodeBlock::new_span(vec![Operation::Caller]); - let bar_root = CodeBlock::new_syscall(foo_root.hash()); - let bar_hash: Word = bar_root.hash().into(); + let mut mast_forest = MastForest::new(); + + let foo_root = MastNode::new_basic_block(vec![Operation::Caller]); + let foo_root_id = mast_forest.ensure_node(foo_root); + + let bar_root = MastNode::new_syscall(foo_root_id, &mast_forest); + let bar_hash: Word = bar_root.digest().into(); [ bar_hash[0].as_int(), bar_hash[1].as_int(), diff --git a/processor/src/chiplets/hasher/mod.rs b/processor/src/chiplets/hasher/mod.rs index 941617f510..88d08512b8 100644 --- a/processor/src/chiplets/hasher/mod.rs +++ b/processor/src/chiplets/hasher/mod.rs @@ -119,7 +119,7 @@ impl Hasher { /// /// The returned tuple also contains the row address of the execution trace at which the hash /// computation started. - pub(super) fn hash_span_block( + pub(super) fn hash_basic_block( &mut self, op_batches: &[OpBatch], expected_hash: Digest, diff --git a/processor/src/chiplets/hasher/tests.rs b/processor/src/chiplets/hasher/tests.rs index dca1baa5ca..4462ad6bef 100644 --- a/processor/src/chiplets/hasher/tests.rs +++ b/processor/src/chiplets/hasher/tests.rs @@ -11,8 +11,8 @@ use miden_air::trace::chiplets::hasher::{ use test_utils::rand::rand_array; use vm_core::{ chiplets::hasher, - code_blocks::CodeBlock, crypto::merkle::{MerkleTree, NodeIndex}, + mast::{MastForest, MastNode, MerkleTreeNode}, Operation, ONE, ZERO, }; @@ -246,48 +246,58 @@ fn hash_memoization_control_blocks() { // / \ // Split1 Split2 (memoized) - let t_branch = CodeBlock::new_span(vec![Operation::Push(ZERO)]); - let f_branch = CodeBlock::new_span(vec![Operation::Push(ONE)]); - let split1_block = CodeBlock::new_split(t_branch.clone(), f_branch.clone()); - let split2_block = CodeBlock::new_split(t_branch.clone(), f_branch.clone()); - let join_block = CodeBlock::new_join([split1_block.clone(), split2_block.clone()]); + let mut mast_forest = MastForest::new(); + + let t_branch = MastNode::new_basic_block(vec![Operation::Push(ZERO)]); + let t_branch_id = mast_forest.ensure_node(t_branch.clone()); + + let f_branch = MastNode::new_basic_block(vec![Operation::Push(ONE)]); + let f_branch_id = mast_forest.ensure_node(f_branch.clone()); + + let split1 = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); + let split1_id = mast_forest.ensure_node(split1.clone()); + + let split2 = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); + let split2_id = mast_forest.ensure_node(split2.clone()); + + let join_node = MastNode::new_join(split1_id, split2_id, &mast_forest); + let _join_node_id = mast_forest.ensure_node(join_node.clone()); let mut hasher = Hasher::default(); - let h1: [Felt; DIGEST_LEN] = split1_block - .hash() + let h1: [Felt; DIGEST_LEN] = split1 + .digest() .as_elements() .try_into() .expect("Could not convert slice to array"); - let h2: [Felt; DIGEST_LEN] = split2_block - .hash() + let h2: [Felt; DIGEST_LEN] = split2 + .digest() .as_elements() .try_into() .expect("Could not convert slice to array"); - let expected_hash = join_block.hash(); + let expected_hash = join_node.digest(); // builds the trace of the join block. - let (_, final_state) = hasher.hash_control_block(h1, h2, join_block.domain(), expected_hash); + let (_, final_state) = hasher.hash_control_block(h1, h2, join_node.domain(), expected_hash); // make sure the hash of the final state is the same as the expected hash. assert_eq!(Digest::new(final_state), expected_hash); let h1: [Felt; DIGEST_LEN] = t_branch - .hash() + .digest() .as_elements() .try_into() .expect("Could not convert slice to array"); let h2: [Felt; DIGEST_LEN] = f_branch - .hash() + .digest() .as_elements() .try_into() .expect("Could not convert slice to array"); - let expected_hash = split1_block.hash(); + let expected_hash = split1.digest(); // builds the hash execution trace of the first split block from scratch. - let (addr, final_state) = - hasher.hash_control_block(h1, h2, split1_block.domain(), expected_hash); + let (addr, final_state) = hasher.hash_control_block(h1, h2, split1.domain(), expected_hash); let first_block_final_state = final_state; @@ -299,21 +309,20 @@ fn hash_memoization_control_blocks() { let end_row = hasher.trace_len() - 1; let h1: [Felt; DIGEST_LEN] = t_branch - .hash() + .digest() .as_elements() .try_into() .expect("Could not convert slice to array"); let h2: [Felt; DIGEST_LEN] = f_branch - .hash() + .digest() .as_elements() .try_into() .expect("Could not convert slice to array"); - let expected_hash = split2_block.hash(); + let expected_hash = split2.digest(); // builds the hash execution trace of the second split block by copying it from the trace of // the first split block. - let (addr, final_state) = - hasher.hash_control_block(h1, h2, split2_block.domain(), expected_hash); + let (addr, final_state) = hasher.hash_control_block(h1, h2, split2.domain(), expected_hash); // make sure the hash of the final state of the second split block is the same as the expected // hash. @@ -337,14 +346,15 @@ fn hash_memoization_control_blocks() { } #[test] -fn hash_memoization_span_blocks() { - // --- span block with 1 batch ---------------------------------------------------------------- - let span_block = CodeBlock::new_span(vec![Operation::Push(Felt::new(10)), Operation::Drop]); +fn hash_memoization_basic_blocks() { + // --- basic block with 1 batch ---------------------------------------------------------------- + let basic_block = + MastNode::new_basic_block(vec![Operation::Push(Felt::new(10)), Operation::Drop]); - hash_memoization_span_blocks_check(span_block); + hash_memoization_basic_blocks_check(basic_block); - // --- span block with multiple batches ------------------------------------------------------- - let span_block = CodeBlock::new_span(vec![ + // --- basic block with multiple batches ------------------------------------------------------- + let basic_block = MastNode::new_basic_block(vec![ Operation::Push(ONE), Operation::Push(Felt::new(2)), Operation::Push(Felt::new(3)), @@ -383,43 +393,55 @@ fn hash_memoization_span_blocks() { Operation::Drop, ]); - hash_memoization_span_blocks_check(span_block); + hash_memoization_basic_blocks_check(basic_block); } -fn hash_memoization_span_blocks_check(span_block: CodeBlock) { - // Join block with a join and span block as children. The span child of the first join - // child block is the same as the span child of root join block. Here the hash execution - // trace of the second span block is built by copying the trace built for the first same - // span block. +fn hash_memoization_basic_blocks_check(basic_block: MastNode) { + // Join block with a join and basic block as children. The child of the first join + // child node is the same as the basic block child of root join node. Here the hash execution + // trace of the second basic block is built by copying the trace built for the first same + // basic block. // Join1 // / \ // / \ // / \ - // Join2 Span2 (memoized) + // Join2 BB2 (memoized) // / \ // / \ // / \ - // Span1 Loop + // BB1 Loop - let span1_block = span_block.clone(); - let loop_body = CodeBlock::new_span(vec![Operation::Pad, Operation::Eq, Operation::Not]); - let loop_block = CodeBlock::new_loop(loop_body); - let join2_block = CodeBlock::new_join([span1_block.clone(), loop_block.clone()]); - let span2_block = span_block; - let join1_block = CodeBlock::new_join([join2_block.clone(), span2_block.clone()]); + let mut mast_forest = MastForest::new(); + + let basic_block_1 = basic_block.clone(); + let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); + + let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Eq, Operation::Not]); + let loop_body_id = mast_forest.ensure_node(loop_body); + + let loop_block = MastNode::new_loop(loop_body_id, &mast_forest); + let loop_block_id = mast_forest.ensure_node(loop_block.clone()); + + let join2_block = MastNode::new_join(basic_block_1_id, loop_block_id, &mast_forest); + let join2_block_id = mast_forest.ensure_node(join2_block.clone()); + + let basic_block_2 = basic_block; + let basic_block_2_id = mast_forest.ensure_node(basic_block_2.clone()); + + let join1_block = MastNode::new_join(join2_block_id, basic_block_2_id, &mast_forest); let mut hasher = Hasher::default(); let h1: [Felt; DIGEST_LEN] = join2_block - .hash() + .digest() .as_elements() .try_into() .expect("Could not convert slice to array"); - let h2: [Felt; DIGEST_LEN] = span2_block - .hash() + let h2: [Felt; DIGEST_LEN] = basic_block_2 + .digest() .as_elements() .try_into() .expect("Could not convert slice to array"); - let expected_hash = join1_block.hash(); + let expected_hash = join1_block.digest(); // builds the trace of the Join1 block. let (_, final_state) = hasher.hash_control_block(h1, h2, join1_block.domain(), expected_hash); @@ -427,63 +449,61 @@ fn hash_memoization_span_blocks_check(span_block: CodeBlock) { // make sure the hash of the final state of Join1 is the same as the expected hash. assert_eq!(Digest::new(final_state), expected_hash); - let h1: [Felt; DIGEST_LEN] = span1_block - .hash() + let h1: [Felt; DIGEST_LEN] = basic_block_1 + .digest() .as_elements() .try_into() .expect("Could not convert slice to array"); let h2: [Felt; DIGEST_LEN] = loop_block - .hash() + .digest() .as_elements() .try_into() .expect("Could not convert slice to array"); - let expected_hash = join2_block.hash(); + let expected_hash = join2_block.digest(); let (_, final_state) = hasher.hash_control_block(h1, h2, join2_block.domain(), expected_hash); // make sure the hash of the final state of Join2 is the same as the expected hash. assert_eq!(Digest::new(final_state), expected_hash); - let span1_block_val = if let CodeBlock::Span(span) = span1_block.clone() { - span + let basic_block_1_val = if let MastNode::Block(basic_block) = basic_block_1.clone() { + basic_block } else { unreachable!() }; - // builds the hash execution trace of the first span block from scratch. + // builds the hash execution trace of the first basic block from scratch. let (addr, final_state) = - hasher.hash_span_block(span1_block_val.op_batches(), span1_block.hash()); - - let _num_batches = span1_block_val.op_batches().len(); + hasher.hash_basic_block(basic_block_1_val.op_batches(), basic_block_1.digest()); - let first_span_block_final_state = final_state; + let first_basic_block_final_state = final_state; - // make sure the hash of the final state of Span1 block is the same as the expected hash. - let expected_hash = span1_block.hash(); + // make sure the hash of the final state of basic block 1 is the same as the expected hash. + let expected_hash = basic_block_1.digest(); assert_eq!(Digest::new(final_state), expected_hash); let start_row = addr.as_int() as usize - 1; let end_row = hasher.trace_len() - 1; - let span2_block_val = if let CodeBlock::Span(span) = span2_block.clone() { - span + let basic_block_2_val = if let MastNode::Block(basic_block) = basic_block_2.clone() { + basic_block } else { unreachable!() }; - // builds the hash execution trace of the second span block by copying the sections of the - // trace corresponding to the first span block with the same hash. + // builds the hash execution trace of the second basic block by copying the sections of the + // trace corresponding to the first basic block with the same hash. let (addr, final_state) = - hasher.hash_span_block(span2_block_val.op_batches(), span2_block.hash()); + hasher.hash_basic_block(basic_block_2_val.op_batches(), basic_block_2.digest()); - let _num_batches = span2_block_val.op_batches().len(); + let _num_batches = basic_block_2_val.op_batches().len(); - let expected_hash = span2_block.hash(); - // make sure the hash of the final state of Span2 block is the same as the expected hash. + let expected_hash = basic_block_2.digest(); + // make sure the hash of the final state of basic block 2 is the same as the expected hash. assert_eq!(Digest::new(final_state), expected_hash); - // make sure the hash of the first and second span blocks is the same. - assert_eq!(first_span_block_final_state, final_state); + // make sure the hash of the first and second basic blocks is the same. + assert_eq!(first_basic_block_final_state, final_state); let copied_start_row = addr.as_int() as usize - 1; let copied_end_row = hasher.trace_len() - 1; diff --git a/processor/src/chiplets/mod.rs b/processor/src/chiplets/mod.rs index 137edc24f3..dad7b0b04d 100644 --- a/processor/src/chiplets/mod.rs +++ b/processor/src/chiplets/mod.rs @@ -6,7 +6,7 @@ use super::{ }; use alloc::vec::Vec; use miden_air::trace::chiplets::hasher::{Digest, HasherState}; -use vm_core::{code_blocks::OpBatch, Kernel}; +use vm_core::{mast::OpBatch, Kernel}; mod bitwise; use bitwise::Bitwise; @@ -37,39 +37,39 @@ mod tests; /// chiplet selectors. /// /// The module's trace can be thought of as 5 stacked chiplet segments in the following form: -/// * Hasher segment: contains the trace and selector for the hasher chiplet * -/// This segment fills the first rows of the trace up to the length of the hasher `trace_len`. +/// * Hasher segment: contains the trace and selector for the hasher chiplet * This segment fills +/// the first rows of the trace up to the length of the hasher `trace_len`. /// - column 0: selector column with values set to ZERO /// - columns 1-17: execution trace of hash chiplet /// -/// * Bitwise segment: contains the trace and selectors for the bitwise chiplet * -/// This segment begins at the end of the hasher segment and fills the next rows of the trace for -/// the `trace_len` of the bitwise chiplet. +/// * Bitwise segment: contains the trace and selectors for the bitwise chiplet * This segment +/// begins at the end of the hasher segment and fills the next rows of the trace for the +/// `trace_len` of the bitwise chiplet. /// - column 0: selector column with values set to ONE /// - column 1: selector column with values set to ZERO /// - columns 2-14: execution trace of bitwise chiplet /// - columns 15-17: unused columns padded with ZERO /// -/// * Memory segment: contains the trace and selectors for the memory chiplet * -/// This segment begins at the end of the bitwise segment and fills the next rows of the trace for -/// the `trace_len` of the memory chiplet. +/// * Memory segment: contains the trace and selectors for the memory chiplet * This segment begins +/// at the end of the bitwise segment and fills the next rows of the trace for the `trace_len` of +/// the memory chiplet. /// - column 0-1: selector columns with values set to ONE /// - column 2: selector column with values set to ZERO /// - columns 3-14: execution trace of memory chiplet /// - columns 15-17: unused column padded with ZERO /// -/// * Kernel ROM segment: contains the trace and selectors for the kernel ROM chiplet * -/// This segment begins at the end of the memory segment and fills the next rows of the trace for -/// the `trace_len` of the kernel ROM chiplet. +/// * Kernel ROM segment: contains the trace and selectors for the kernel ROM chiplet * This segment +/// begins at the end of the memory segment and fills the next rows of the trace for the +/// `trace_len` of the kernel ROM chiplet. /// - column 0-2: selector columns with values set to ONE /// - column 3: selector column with values set to ZERO /// - columns 4-9: execution trace of kernel ROM chiplet /// - columns 10-17: unused column padded with ZERO /// -/// * Padding segment: unused * -/// This segment begins at the end of the kernel ROM segment and fills the rest of the execution -/// trace minus the number of random rows. When it finishes, the execution trace should have -/// exactly enough rows remaining for the specified number of random rows. +/// * Padding segment: unused * This segment begins at the end of the kernel ROM segment and fills +/// the rest of the execution trace minus the number of random rows. When it finishes, the +/// execution trace should have exactly enough rows remaining for the specified number of random +/// rows. /// - columns 0-3: selector columns with values set to ONE /// - columns 3-17: unused columns padded with ZERO /// @@ -251,7 +251,7 @@ impl Chiplets { /// /// It returns the row address of the execution trace at which the hash computation started. pub fn hash_span_block(&mut self, op_batches: &[OpBatch], expected_hash: Digest) -> Felt { - let (addr, result) = self.hasher.hash_span_block(op_batches, expected_hash); + let (addr, result) = self.hasher.hash_basic_block(op_batches, expected_hash); // make sure the result computed by the hasher is the same as the expected block hash debug_assert_eq!(expected_hash, result.into()); diff --git a/processor/src/chiplets/tests.rs b/processor/src/chiplets/tests.rs index c83f84b4b9..a7c6c526e7 100644 --- a/processor/src/chiplets/tests.rs +++ b/processor/src/chiplets/tests.rs @@ -1,6 +1,5 @@ use crate::{ - CodeBlock, DefaultHost, ExecutionOptions, ExecutionTrace, Kernel, Operation, Process, - StackInputs, + DefaultHost, ExecutionOptions, ExecutionTrace, Kernel, Operation, Process, StackInputs, }; use alloc::vec::Vec; use miden_air::trace::{ @@ -13,7 +12,10 @@ use miden_air::trace::{ }, CHIPLETS_RANGE, CHIPLETS_WIDTH, }; -use vm_core::{CodeBlockTable, Felt, ONE, ZERO}; +use vm_core::{ + mast::{MastForest, MastNode}, + Felt, ONE, ZERO, +}; type ChipletsTrace = [Vec; CHIPLETS_WIDTH]; @@ -114,8 +116,16 @@ fn build_trace( let stack_inputs = StackInputs::try_from_ints(stack_inputs.iter().copied()).unwrap(); let host = DefaultHost::default(); let mut process = Process::new(kernel, stack_inputs, host, ExecutionOptions::default()); - let program = CodeBlock::new_span(operations); - process.execute_code_block(&program, &CodeBlockTable::default()).unwrap(); + let program = { + let mut mast_forest = MastForest::new(); + + let basic_block = MastNode::new_basic_block(operations); + let basic_block_id = mast_forest.ensure_node(basic_block); + mast_forest.set_entrypoint(basic_block_id); + + mast_forest.try_into().unwrap() + }; + process.execute(&program).unwrap(); let (trace, _, _) = ExecutionTrace::test_finalize_trace(process); let trace_len = trace.num_rows() - ExecutionTrace::NUM_RAND_ROWS; diff --git a/processor/src/decoder/mod.rs b/processor/src/decoder/mod.rs index d5938d7a4a..72845ba549 100644 --- a/processor/src/decoder/mod.rs +++ b/processor/src/decoder/mod.rs @@ -1,6 +1,6 @@ use super::{ - Call, Dyn, ExecutionError, Felt, Host, Join, Loop, OpBatch, Operation, Process, Span, Split, - Word, EMPTY_WORD, MIN_TRACE_LEN, ONE, OP_BATCH_SIZE, ZERO, + ExecutionError, Felt, Host, OpBatch, Operation, Process, Word, EMPTY_WORD, MIN_TRACE_LEN, ONE, + ZERO, }; use alloc::vec::Vec; use miden_air::trace::{ @@ -10,7 +10,14 @@ use miden_air::trace::{ OP_BATCH_1_GROUPS, OP_BATCH_2_GROUPS, OP_BATCH_4_GROUPS, OP_BATCH_8_GROUPS, }, }; -use vm_core::{code_blocks::get_span_op_group_count, stack::STACK_TOP_SIZE, AssemblyOp}; +use vm_core::{ + mast::{ + get_span_op_group_count, BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, + MerkleTreeNode, SplitNode, OP_BATCH_SIZE, + }, + stack::STACK_TOP_SIZE, + AssemblyOp, Program, +}; mod trace; use trace::DecoderTrace; @@ -42,19 +49,39 @@ impl Process where H: Host, { - // JOIN BLOCK + // JOIN NODE // -------------------------------------------------------------------------------------------- - /// Starts decoding of a JOIN block. - pub(super) fn start_join_block(&mut self, block: &Join) -> Result<(), ExecutionError> { + /// Starts decoding of a JOIN node. + pub(super) fn start_join_node( + &mut self, + node: &JoinNode, + program: &Program, + ) -> Result<(), ExecutionError> { // use the hasher to compute the hash of the JOIN block; the row address returned by the // hasher is used as the ID of the block; the result of the hash is expected to be in // row addr + 7. - let child1_hash = block.first().hash().into(); - let child2_hash = block.second().hash().into(); - let addr = - self.chiplets - .hash_control_block(child1_hash, child2_hash, Join::DOMAIN, block.hash()); + let child1_hash = program + .get_node_by_id(node.first()) + .ok_or(ExecutionError::MastNodeNotFoundInForest { + node_id: node.first(), + })? + .digest() + .into(); + let child2_hash = program + .get_node_by_id(node.second()) + .ok_or(ExecutionError::MastNodeNotFoundInForest { + node_id: node.second(), + })? + .digest() + .into(); + + let addr = self.chiplets.hash_control_block( + child1_hash, + child2_hash, + JoinNode::DOMAIN, + node.digest(), + ); // start decoding the JOIN block; this appends a row with JOIN operation to the decoder // trace. when JOIN operation is executed, the rest of the VM state does not change @@ -62,31 +89,50 @@ where self.execute_op(Operation::Noop) } - /// Ends decoding of a JOIN block. - pub(super) fn end_join_block(&mut self, block: &Join) -> Result<(), ExecutionError> { + /// Ends decoding of a JOIN node. + pub(super) fn end_join_node(&mut self, node: &JoinNode) -> Result<(), ExecutionError> { // this appends a row with END operation to the decoder trace. when END operation is // executed the rest of the VM state does not change - self.decoder.end_control_block(block.hash().into()); + self.decoder.end_control_block(node.digest().into()); self.execute_op(Operation::Noop) } - // SPLIT BLOCK + // SPLIT NODE // -------------------------------------------------------------------------------------------- - /// Starts decoding a SPLIT block. This also pops the value from the top of the stack and + /// Starts decoding a SPLIT node. This also pops the value from the top of the stack and /// returns it. - pub(super) fn start_split_block(&mut self, block: &Split) -> Result { + pub(super) fn start_split_node( + &mut self, + node: &SplitNode, + program: &Program, + ) -> Result { let condition = self.stack.peek(); // use the hasher to compute the hash of the SPLIT block; the row address returned by the // hasher is used as the ID of the block; the result of the hash is expected to be in // row addr + 7. - let child1_hash = block.on_true().hash().into(); - let child2_hash = block.on_false().hash().into(); - let addr = - self.chiplets - .hash_control_block(child1_hash, child2_hash, Split::DOMAIN, block.hash()); + let child1_hash = program + .get_node_by_id(node.on_true()) + .ok_or(ExecutionError::MastNodeNotFoundInForest { + node_id: node.on_true(), + })? + .digest() + .into(); + let child2_hash = program + .get_node_by_id(node.on_false()) + .ok_or(ExecutionError::MastNodeNotFoundInForest { + node_id: node.on_false(), + })? + .digest() + .into(); + let addr = self.chiplets.hash_control_block( + child1_hash, + child2_hash, + SplitNode::DOMAIN, + node.digest(), + ); // start decoding the SPLIT block. this appends a row with SPLIT operation to the decoder // trace. we also pop the value off the top of the stack and return it. @@ -95,31 +141,44 @@ where Ok(condition) } - /// Ends decoding of a SPLIT block. - pub(super) fn end_split_block(&mut self, block: &Split) -> Result<(), ExecutionError> { + /// Ends decoding of a SPLIT node. + pub(super) fn end_split_node(&mut self, block: &SplitNode) -> Result<(), ExecutionError> { // this appends a row with END operation to the decoder trace. when END operation is // executed the rest of the VM state does not change - self.decoder.end_control_block(block.hash().into()); + self.decoder.end_control_block(block.digest().into()); self.execute_op(Operation::Noop) } - // LOOP BLOCK + // LOOP NODE // -------------------------------------------------------------------------------------------- - /// Starts decoding a LOOP block. This also pops the value from the top of the stack and + /// Starts decoding a LOOP node. This also pops the value from the top of the stack and /// returns it. - pub(super) fn start_loop_block(&mut self, block: &Loop) -> Result { + pub(super) fn start_loop_node( + &mut self, + node: &LoopNode, + program: &Program, + ) -> Result { let condition = self.stack.peek(); // use the hasher to compute the hash of the LOOP block; for LOOP block there is no // second child so we set the second hash to ZEROs; the row address returned by the // hasher is used as the ID of the block; the result of the hash is expected to be in // row addr + 7. - let body_hash = block.body().hash().into(); - let addr = - self.chiplets - .hash_control_block(body_hash, EMPTY_WORD, Loop::DOMAIN, block.hash()); + let body_hash = program + .get_node_by_id(node.body()) + .ok_or(ExecutionError::MastNodeNotFoundInForest { + node_id: node.body(), + })? + .digest() + .into(); + let addr = self.chiplets.hash_control_block( + body_hash, + EMPTY_WORD, + LoopNode::DOMAIN, + node.digest(), + ); // start decoding the LOOP block; this appends a row with LOOP operation to the decoder // trace, but if the value on the top of the stack is not ONE, the block is not marked @@ -133,13 +192,13 @@ where /// Ends decoding of a LOOP block. If pop_stack is set to true, this also removes the /// value at the top of the stack. - pub(super) fn end_loop_block( + pub(super) fn end_loop_node( &mut self, - block: &Loop, + node: &LoopNode, pop_stack: bool, ) -> Result<(), ExecutionError> { // this appends a row with END operation to the decoder trace. - self.decoder.end_control_block(block.hash().into()); + self.decoder.end_control_block(node.digest().into()); // if we are exiting a loop, we also need to pop the top value off the stack (and this // value must be ZERO - otherwise, we should have stayed in the loop). but, if we never @@ -156,18 +215,28 @@ where } } - // CALL BLOCK + // CALL NODE // -------------------------------------------------------------------------------------------- - /// Starts decoding of a CALL or a SYSCALL block. - pub(super) fn start_call_block(&mut self, block: &Call) -> Result<(), ExecutionError> { + /// Starts decoding of a CALL or a SYSCALL node. + pub(super) fn start_call_node( + &mut self, + node: &CallNode, + program: &Program, + ) -> Result<(), ExecutionError> { // use the hasher to compute the hash of the CALL or SYSCALL block; the row address // returned by the hasher is used as the ID of the block; the result of the hash is // expected to be in row addr + 7. - let fn_hash = block.fn_hash().into(); + let callee_hash = program + .get_node_by_id(node.callee()) + .ok_or(ExecutionError::MastNodeNotFoundInForest { + node_id: node.callee(), + })? + .digest() + .into(); let addr = self.chiplets - .hash_control_block(fn_hash, EMPTY_WORD, block.domain(), block.hash()); + .hash_control_block(callee_hash, EMPTY_WORD, node.domain(), node.digest()); // start new execution context for the operand stack. this has the effect of resetting // stack depth to 16. @@ -186,12 +255,12 @@ where next_overflow_addr, ); - if block.is_syscall() { + if node.is_syscall() { self.system.start_syscall(); - self.decoder.start_syscall(fn_hash, addr, ctx_info); + self.decoder.start_syscall(callee_hash, addr, ctx_info); } else { - self.system.start_call(fn_hash); - self.decoder.start_call(fn_hash, addr, ctx_info); + self.system.start_call(callee_hash); + self.decoder.start_call(callee_hash, addr, ctx_info); } // the rest of the VM state does not change @@ -199,7 +268,7 @@ where } /// Ends decoding of a CALL or a SYSCALL block. - pub(super) fn end_call_block(&mut self, block: &Call) -> Result<(), ExecutionError> { + pub(super) fn end_call_node(&mut self, node: &CallNode) -> Result<(), ExecutionError> { // when a CALL block ends, stack depth must be exactly 16 let stack_depth = self.stack.depth(); if stack_depth > STACK_TOP_SIZE { @@ -210,7 +279,7 @@ where // information about the execution context prior to execution of the CALL block let ctx_info = self .decoder - .end_control_block(block.hash().into()) + .end_control_block(node.digest().into()) .expect("no execution context"); // when returning from a function call or a syscall, restore the context of the system @@ -229,65 +298,71 @@ where self.execute_op(Operation::Noop) } - // DYN BLOCK + // DYN NODE // -------------------------------------------------------------------------------------------- - /// Starts decoding of a DYN block. - pub(super) fn start_dyn_block( - &mut self, - block: &Dyn, - dyn_hash: Word, - ) -> Result<(), ExecutionError> { - let addr = - self.chiplets - .hash_control_block(EMPTY_WORD, EMPTY_WORD, Dyn::DOMAIN, block.hash()); + /// Starts decoding of a DYN node. + pub(super) fn start_dyn_node(&mut self, callee_hash: Word) -> Result<(), ExecutionError> { + let addr = self.chiplets.hash_control_block( + EMPTY_WORD, + EMPTY_WORD, + DynNode::DOMAIN, + DynNode.digest(), + ); - self.decoder.start_dyn(dyn_hash, addr); + self.decoder.start_dyn(callee_hash, addr); self.execute_op(Operation::Noop) } - /// Ends decoding of a DYN block. - pub(super) fn end_dyn_block(&mut self, block: &Dyn) -> Result<(), ExecutionError> { + /// Ends decoding of a DYN node. + pub(super) fn end_dyn_node(&mut self) -> Result<(), ExecutionError> { // this appends a row with END operation to the decoder trace. when the END operation is // executed the rest of the VM state does not change - self.decoder.end_control_block(block.hash().into()); + self.decoder.end_control_block(DynNode.digest().into()); self.execute_op(Operation::Noop) } - // SPAN BLOCK + // BASIC BLOCK NODE // -------------------------------------------------------------------------------------------- - /// Starts decoding a SPAN block. - pub(super) fn start_span_block(&mut self, block: &Span) -> Result<(), ExecutionError> { + /// Starts decoding a BASIC BLOCK node. + pub(super) fn start_basic_block_node( + &mut self, + basic_block: &BasicBlockNode, + ) -> Result<(), ExecutionError> { // use the hasher to compute the hash of the SPAN block; the row address returned by the // hasher is used as the ID of the block; hash of a SPAN block is computed by sequentially // hashing operation batches. Thus, the result of the hash is expected to be in row // addr + (num_batches * 8) - 1. - let op_batches = block.op_batches(); - let addr = self.chiplets.hash_span_block(op_batches, block.hash()); + let op_batches = basic_block.op_batches(); + let addr = self.chiplets.hash_span_block(op_batches, basic_block.digest()); // start decoding the first operation batch; this also appends a row with SPAN operation // to the decoder trace. we also need the total number of operation groups so that we can // set the value of the group_count register at the beginning of the SPAN. let num_op_groups = get_span_op_group_count(op_batches); - self.decoder.start_span(&op_batches[0], Felt::new(num_op_groups as u64), addr); + self.decoder + .start_basic_block(&op_batches[0], Felt::new(num_op_groups as u64), addr); self.execute_op(Operation::Noop) } - /// Continues decoding a SPAN block by absorbing the next batch of operations. - pub(super) fn respan(&mut self, op_batch: &OpBatch) { - self.decoder.respan(op_batch); - } - - /// Ends decoding a SPAN block. - pub(super) fn end_span_block(&mut self, block: &Span) -> Result<(), ExecutionError> { + /// Ends decoding a BASIC BLOCK node. + pub(super) fn end_basic_block_node( + &mut self, + block: &BasicBlockNode, + ) -> Result<(), ExecutionError> { // this appends a row with END operation to the decoder trace. when END operation is // executed the rest of the VM state does not change - self.decoder.end_span(block.hash().into()); + self.decoder.end_basic_block(block.digest().into()); self.execute_op(Operation::Noop) } + + /// Continues decoding a SPAN block by absorbing the next batch of operations. + pub(super) fn respan(&mut self, op_batch: &OpBatch) { + self.decoder.respan(op_batch); + } } // DECODER @@ -506,7 +581,7 @@ impl Decoder { // -------------------------------------------------------------------------------------------- /// Starts decoding of a SPAN block defined by the specified operation batches. - pub fn start_span(&mut self, first_op_batch: &OpBatch, num_op_groups: Felt, addr: Felt) { + pub fn start_basic_block(&mut self, first_op_batch: &OpBatch, num_op_groups: Felt, addr: Felt) { debug_assert!(self.span_context.is_none(), "already in span"); let parent_addr = self.block_stack.push(addr, BlockType::Span, None); @@ -595,7 +670,7 @@ impl Decoder { } /// Ends decoding of a SPAN block. - pub fn end_span(&mut self, block_hash: Word) { + pub fn end_basic_block(&mut self, block_hash: Word) { // remove the block from the stack of executing blocks and add an END row to the // execution trace let block_info = self.block_stack.pop(); diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index b751ad3432..a0474cb398 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -18,8 +18,8 @@ use miden_air::trace::{ }; use test_utils::rand::rand_value; use vm_core::{ - code_blocks::{CodeBlock, Span, OP_BATCH_SIZE}, - CodeBlockTable, EMPTY_WORD, ONE, ZERO, + mast::{BasicBlockNode, MastForest, MastNode, MerkleTreeNode, OP_BATCH_SIZE}, + Program, EMPTY_WORD, ONE, ZERO, }; // CONSTANTS @@ -43,10 +43,18 @@ type DecoderTrace = [Vec; DECODER_TRACE_WIDTH]; // ================================================================================================ #[test] -fn span_block_one_group() { +fn basic_block_one_group() { let ops = vec![Operation::Pad, Operation::Add, Operation::Mul]; - let span = Span::new(ops.clone()); - let program = CodeBlock::new_span(ops.clone()); + let basic_block = BasicBlockNode::new(ops.clone()); + let program = { + let mut mast_forest = MastForest::new(); + + let basic_block_node = MastNode::Block(basic_block.clone()); + let basic_block_id = mast_forest.ensure_node(basic_block_node); + mast_forest.set_entrypoint(basic_block_id); + + Program::new(mast_forest).unwrap() + }; let (trace, trace_len) = build_trace(&[], &program); @@ -63,7 +71,7 @@ fn span_block_one_group() { check_hasher_state( &trace, vec![ - span.op_batches()[0].groups().to_vec(), // first group should contain op batch + basic_block.op_batches()[0].groups().to_vec(), // first group should contain op batch vec![build_op_group(&ops[1..])], vec![build_op_group(&ops[2..])], vec![], @@ -81,11 +89,19 @@ fn span_block_one_group() { } #[test] -fn span_block_small() { +fn basic_block_small() { let iv = [ONE, TWO]; let ops = vec![Operation::Push(iv[0]), Operation::Push(iv[1]), Operation::Add]; - let span = Span::new(ops.clone()); - let program = CodeBlock::new_span(ops.clone()); + let basic_block = BasicBlockNode::new(ops.clone()); + let program = { + let mut mast_forest = MastForest::new(); + + let basic_block_node = MastNode::Block(basic_block.clone()); + let basic_block_id = mast_forest.ensure_node(basic_block_node); + mast_forest.set_entrypoint(basic_block_id); + + Program::new(mast_forest).unwrap() + }; let (trace, trace_len) = build_trace(&[], &program); @@ -105,7 +121,7 @@ fn span_block_small() { check_hasher_state( &trace, vec![ - span.op_batches()[0].groups().to_vec(), + basic_block.op_batches()[0].groups().to_vec(), vec![build_op_group(&ops[1..])], vec![build_op_group(&ops[2..])], vec![], @@ -124,7 +140,7 @@ fn span_block_small() { } #[test] -fn span_block() { +fn basic_block() { let iv = [ONE, TWO, Felt::new(3), Felt::new(4), Felt::new(5)]; let ops = vec![ Operation::Push(iv[0]), @@ -140,8 +156,16 @@ fn span_block() { Operation::Add, Operation::Inv, ]; - let span = Span::new(ops.clone()); - let program = CodeBlock::new_span(ops.clone()); + let basic_block = BasicBlockNode::new(ops.clone()); + let program = { + let mut mast_forest = MastForest::new(); + + let basic_block_node = MastNode::Block(basic_block.clone()); + let basic_block_id = mast_forest.ensure_node(basic_block_node); + mast_forest.set_entrypoint(basic_block_id); + + Program::new(mast_forest).unwrap() + }; let (trace, trace_len) = build_trace(&[], &program); // --- check block address, op_bits, group count, op_index, and in_span columns --------------- @@ -171,7 +195,7 @@ fn span_block() { check_hasher_state( &trace, vec![ - span.op_batches()[0].groups().to_vec(), + basic_block.op_batches()[0].groups().to_vec(), vec![build_op_group(&ops[1..8])], // first group starts vec![build_op_group(&ops[2..8])], vec![build_op_group(&ops[3..8])], @@ -225,8 +249,16 @@ fn span_block_with_respan() { Operation::Add, Operation::Push(iv[8]), ]; - let span = Span::new(ops.clone()); - let program = CodeBlock::new_span(ops.clone()); + let basic_block = BasicBlockNode::new(ops.clone()); + let program = { + let mut mast_forest = MastForest::new(); + + let basic_block_node = MastNode::Block(basic_block.clone()); + let basic_block_id = mast_forest.ensure_node(basic_block_node); + mast_forest.set_entrypoint(basic_block_id); + + Program::new(mast_forest).unwrap() + }; let (trace, trace_len) = build_trace(&[], &program); // --- check block address, op_bits, group count, op_index, and in_span columns --------------- @@ -258,7 +290,7 @@ fn span_block_with_respan() { check_hasher_state( &trace, vec![ - span.op_batches()[0].groups().to_vec(), + basic_block.op_batches()[0].groups().to_vec(), vec![build_op_group(&ops[1..7])], // first group starts vec![build_op_group(&ops[2..7])], vec![build_op_group(&ops[3..7])], @@ -267,7 +299,7 @@ fn span_block_with_respan() { vec![build_op_group(&ops[6..7])], vec![], vec![], // a NOOP inserted after last PUSH - span.op_batches()[1].groups().to_vec(), + basic_block.op_batches()[1].groups().to_vec(), vec![build_op_group(&ops[8..])], // next group starts vec![build_op_group(&ops[9..])], vec![], @@ -290,10 +322,21 @@ fn span_block_with_respan() { // ================================================================================================ #[test] -fn join_block() { - let span1 = CodeBlock::new_span(vec![Operation::Mul]); - let span2 = CodeBlock::new_span(vec![Operation::Add]); - let program = CodeBlock::new_join([span1.clone(), span2.clone()]); +fn join_node() { + let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul]); + let basic_block2 = MastNode::new_basic_block(vec![Operation::Add]); + let program = { + let mut mast_forest = MastForest::new(); + + let basic_block1_id = mast_forest.ensure_node(basic_block1.clone()); + let basic_block2_id = mast_forest.ensure_node(basic_block2.clone()); + + let join_node = MastNode::new_join(basic_block1_id, basic_block2_id, &mast_forest); + let join_node_id = mast_forest.ensure_node(join_node); + mast_forest.set_entrypoint(join_node_id); + + Program::new(mast_forest).unwrap() + }; let (trace, trace_len) = build_trace(&[], &program); @@ -315,8 +358,8 @@ fn join_block() { // --- check hasher state columns ------------------------------------------------------------- // in the first row, the hasher state is set to hashes of both child nodes - let span1_hash: Word = span1.hash().into(); - let span2_hash: Word = span2.hash().into(); + let span1_hash: Word = basic_block1.digest().into(); + let span2_hash: Word = basic_block2.digest().into(); assert_eq!(span1_hash, get_hasher_state1(&trace, 0)); assert_eq!(span2_hash, get_hasher_state2(&trace, 0)); @@ -346,27 +389,38 @@ fn join_block() { // ================================================================================================ #[test] -fn split_block_true() { - let span1 = CodeBlock::new_span(vec![Operation::Mul]); - let span2 = CodeBlock::new_span(vec![Operation::Add]); - let program = CodeBlock::new_split(span1.clone(), span2.clone()); +fn split_node_true() { + let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul]); + let basic_block2 = MastNode::new_basic_block(vec![Operation::Add]); + let program = { + let mut mast_forest = MastForest::new(); + + let basic_block1_id = mast_forest.ensure_node(basic_block1.clone()); + let basic_block2_id = mast_forest.ensure_node(basic_block2.clone()); + + let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); + let split_node_id = mast_forest.ensure_node(split_node); + mast_forest.set_entrypoint(split_node_id); + + Program::new(mast_forest).unwrap() + }; let (trace, trace_len) = build_trace(&[1], &program); // --- check block address, op_bits, group count, op_index, and in_span columns --------------- - let span_addr = INIT_ADDR + EIGHT; + let basic_block_addr = INIT_ADDR + EIGHT; check_op_decoding(&trace, 0, ZERO, Operation::Split, 0, 0, 0); check_op_decoding(&trace, 1, INIT_ADDR, Operation::Span, 1, 0, 0); - check_op_decoding(&trace, 2, span_addr, Operation::Mul, 0, 0, 1); - check_op_decoding(&trace, 3, span_addr, Operation::End, 0, 0, 0); + check_op_decoding(&trace, 2, basic_block_addr, Operation::Mul, 0, 0, 1); + check_op_decoding(&trace, 3, basic_block_addr, Operation::End, 0, 0, 0); check_op_decoding(&trace, 4, INIT_ADDR, Operation::End, 0, 0, 0); check_op_decoding(&trace, 5, ZERO, Operation::Halt, 0, 0, 0); // --- check hasher state columns ------------------------------------------------------------- // in the first row, the hasher state is set to hashes of both child nodes - let span1_hash: Word = span1.hash().into(); - let span2_hash: Word = span2.hash().into(); + let span1_hash: Word = basic_block1.digest().into(); + let span2_hash: Word = basic_block2.digest().into(); assert_eq!(span1_hash, get_hasher_state1(&trace, 0)); assert_eq!(span2_hash, get_hasher_state2(&trace, 0)); @@ -389,27 +443,38 @@ fn split_block_true() { } #[test] -fn split_block_false() { - let span1 = CodeBlock::new_span(vec![Operation::Mul]); - let span2 = CodeBlock::new_span(vec![Operation::Add]); - let program = CodeBlock::new_split(span1.clone(), span2.clone()); +fn split_node_false() { + let basic_block1 = MastNode::new_basic_block(vec![Operation::Mul]); + let basic_block2 = MastNode::new_basic_block(vec![Operation::Add]); + let program = { + let mut mast_forest = MastForest::new(); + + let basic_block1_id = mast_forest.ensure_node(basic_block1.clone()); + let basic_block2_id = mast_forest.ensure_node(basic_block2.clone()); + + let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); + let split_node_id = mast_forest.ensure_node(split_node); + mast_forest.set_entrypoint(split_node_id); + + Program::new(mast_forest).unwrap() + }; let (trace, trace_len) = build_trace(&[0], &program); // --- check block address, op_bits, group count, op_index, and in_span columns --------------- - let span_addr = INIT_ADDR + EIGHT; + let basic_block_addr = INIT_ADDR + EIGHT; check_op_decoding(&trace, 0, ZERO, Operation::Split, 0, 0, 0); check_op_decoding(&trace, 1, INIT_ADDR, Operation::Span, 1, 0, 0); - check_op_decoding(&trace, 2, span_addr, Operation::Add, 0, 0, 1); - check_op_decoding(&trace, 3, span_addr, Operation::End, 0, 0, 0); + check_op_decoding(&trace, 2, basic_block_addr, Operation::Add, 0, 0, 1); + check_op_decoding(&trace, 3, basic_block_addr, Operation::End, 0, 0, 0); check_op_decoding(&trace, 4, INIT_ADDR, Operation::End, 0, 0, 0); check_op_decoding(&trace, 5, ZERO, Operation::Halt, 0, 0, 0); // --- check hasher state columns ------------------------------------------------------------- // in the first row, the hasher state is set to hashes of both child nodes - let span1_hash: Word = span1.hash().into(); - let span2_hash: Word = span2.hash().into(); + let span1_hash: Word = basic_block1.digest().into(); + let span2_hash: Word = basic_block2.digest().into(); assert_eq!(span1_hash, get_hasher_state1(&trace, 0)); assert_eq!(span2_hash, get_hasher_state2(&trace, 0)); @@ -435,9 +500,19 @@ fn split_block_false() { // ================================================================================================ #[test] -fn loop_block() { - let loop_body = CodeBlock::new_span(vec![Operation::Pad, Operation::Drop]); - let program = CodeBlock::new_loop(loop_body.clone()); +fn loop_node() { + let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop]); + let program = { + let mut mast_forest = MastForest::new(); + + let loop_body_id = mast_forest.ensure_node(loop_body.clone()); + + let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); + let loop_node_id = mast_forest.ensure_node(loop_node); + mast_forest.set_entrypoint(loop_node_id); + + Program::new(mast_forest).unwrap() + }; let (trace, trace_len) = build_trace(&[0, 1], &program); @@ -454,7 +529,7 @@ fn loop_block() { // --- check hasher state columns ------------------------------------------------------------- // in the first row, the hasher state is set to the hash of the loop's body - let loop_body_hash: Word = loop_body.hash().into(); + let loop_body_hash: Word = loop_body.digest().into(); assert_eq!(loop_body_hash, get_hasher_state1(&trace, 0)); assert_eq!(EMPTY_WORD, get_hasher_state2(&trace, 0)); @@ -479,9 +554,19 @@ fn loop_block() { } #[test] -fn loop_block_skip() { - let loop_body = CodeBlock::new_span(vec![Operation::Pad, Operation::Drop]); - let program = CodeBlock::new_loop(loop_body.clone()); +fn loop_node_skip() { + let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop]); + let program = { + let mut mast_forest = MastForest::new(); + + let loop_body_id = mast_forest.ensure_node(loop_body.clone()); + + let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); + let loop_node_id = mast_forest.ensure_node(loop_node); + mast_forest.set_entrypoint(loop_node_id); + + Program::new(mast_forest).unwrap() + }; let (trace, trace_len) = build_trace(&[0], &program); @@ -493,7 +578,7 @@ fn loop_block_skip() { // --- check hasher state columns ------------------------------------------------------------- // in the first row, the hasher state is set to the hash of the loop's body - let loop_body_hash: Word = loop_body.hash().into(); + let loop_body_hash: Word = loop_body.digest().into(); assert_eq!(loop_body_hash, get_hasher_state1(&trace, 0)); assert_eq!(EMPTY_WORD, get_hasher_state2(&trace, 0)); @@ -513,9 +598,19 @@ fn loop_block_skip() { } #[test] -fn loop_block_repeat() { - let loop_body = CodeBlock::new_span(vec![Operation::Pad, Operation::Drop]); - let program = CodeBlock::new_loop(loop_body.clone()); +fn loop_node_repeat() { + let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Drop]); + let program = { + let mut mast_forest = MastForest::new(); + + let loop_body_id = mast_forest.ensure_node(loop_body.clone()); + + let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); + let loop_node_id = mast_forest.ensure_node(loop_node); + mast_forest.set_entrypoint(loop_node_id); + + Program::new(mast_forest).unwrap() + }; let (trace, trace_len) = build_trace(&[0, 1, 1], &program); @@ -540,7 +635,7 @@ fn loop_block_repeat() { // --- check hasher state columns ------------------------------------------------------------- // in the first row, the hasher state is set to the hash of the loop's body - let loop_body_hash: Word = loop_body.hash().into(); + let loop_body_hash: Word = loop_body.digest().into(); assert_eq!(loop_body_hash, get_hasher_state1(&trace, 0)); assert_eq!(EMPTY_WORD, get_hasher_state2(&trace, 0)); @@ -591,20 +686,37 @@ fn call_block() { // stack[0] <- fmp // end - let first_span = CodeBlock::new_span(vec![ + let mut mast_forest = MastForest::new(); + + let first_basic_block = MastNode::new_basic_block(vec![ Operation::Push(TWO), Operation::FmpUpdate, Operation::Pad, ]); - let foo_root = CodeBlock::new_span(vec![Operation::Push(ONE), Operation::FmpUpdate]); - let last_span = CodeBlock::new_span(vec![Operation::FmpAdd]); + let first_basic_block_id = mast_forest.ensure_node(first_basic_block.clone()); + + let foo_root_node = MastNode::new_basic_block(vec![ + Operation::Push(ONE), Operation::FmpUpdate + ]); + let foo_root_node_id = mast_forest.ensure_node(foo_root_node.clone()); + + let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd]); + let last_basic_block_id = mast_forest.ensure_node(last_basic_block.clone()); + + let foo_call_node = MastNode::new_call(foo_root_node_id, &mast_forest); + let foo_call_node_id = mast_forest.ensure_node(foo_call_node.clone()); + + let join1_node = MastNode::new_join(first_basic_block_id, foo_call_node_id, &mast_forest); + let join1_node_id = mast_forest.ensure_node(join1_node.clone()); + + let program_root = MastNode::new_join(join1_node_id, last_basic_block_id, &mast_forest); + let program_root_id = mast_forest.ensure_node(program_root); + mast_forest.set_entrypoint(program_root_id); - let foo_call = CodeBlock::new_call(foo_root.hash()); - let join1 = CodeBlock::new_join([first_span.clone(), foo_call.clone()]); - let program = CodeBlock::new_join([join1.clone(), last_span.clone()]); + let program = Program::new(mast_forest).unwrap(); let (sys_trace, dec_trace, trace_len) = - build_call_trace(&program, foo_root.clone(), None); + build_call_trace(&program, Kernel::default()); // --- check block address, op_bits, group count, op_index, and in_span columns --------------- check_op_decoding(&dec_trace, 0, ZERO, Operation::Join, 0, 0, 0); @@ -612,14 +724,14 @@ fn call_block() { let join1_addr = INIT_ADDR + EIGHT; check_op_decoding(&dec_trace, 1, INIT_ADDR, Operation::Join, 0, 0, 0); // starting first SPAN block - let first_span_addr = join1_addr + EIGHT; + let first_basic_block_addr = join1_addr + EIGHT; check_op_decoding(&dec_trace, 2, join1_addr, Operation::Span, 2, 0, 0); - check_op_decoding(&dec_trace, 3, first_span_addr, Operation::Push(TWO), 1, 0, 1); - check_op_decoding(&dec_trace, 4, first_span_addr, Operation::FmpUpdate, 0, 1, 1); - check_op_decoding(&dec_trace, 5, first_span_addr, Operation::Pad, 0, 2, 1); - check_op_decoding(&dec_trace, 6, first_span_addr, Operation::End, 0, 0, 0); + check_op_decoding(&dec_trace, 3, first_basic_block_addr, Operation::Push(TWO), 1, 0, 1); + check_op_decoding(&dec_trace, 4, first_basic_block_addr, Operation::FmpUpdate, 0, 1, 1); + check_op_decoding(&dec_trace, 5, first_basic_block_addr, Operation::Pad, 0, 2, 1); + check_op_decoding(&dec_trace, 6, first_basic_block_addr, Operation::End, 0, 0, 0); // starting CALL block - let foo_call_addr = first_span_addr + EIGHT; + let foo_call_addr = first_basic_block_addr + EIGHT; check_op_decoding(&dec_trace, 7, join1_addr, Operation::Call, 0, 0, 0); // starting second SPAN block let foo_root_addr = foo_call_addr + EIGHT; @@ -632,24 +744,24 @@ fn call_block() { // ending internal JOIN block check_op_decoding(&dec_trace, 13, join1_addr, Operation::End, 0, 0, 0); // starting the 3rd SPAN block - let last_span_addr = foo_root_addr + EIGHT; + let last_basic_block_addr = foo_root_addr + EIGHT; check_op_decoding(&dec_trace, 14, INIT_ADDR, Operation::Span, 1, 0, 0); - check_op_decoding(&dec_trace, 15, last_span_addr, Operation::FmpAdd, 0, 0, 1); - check_op_decoding(&dec_trace, 16, last_span_addr, Operation::End, 0, 0, 0); + check_op_decoding(&dec_trace, 15, last_basic_block_addr, Operation::FmpAdd, 0, 0, 1); + check_op_decoding(&dec_trace, 16, last_basic_block_addr, Operation::End, 0, 0, 0); // ending the program check_op_decoding(&dec_trace, 17, INIT_ADDR, Operation::End, 0, 0, 0); check_op_decoding(&dec_trace, 18, ZERO, Operation::Halt, 0, 0, 0); // --- check hasher state columns ------------------------------------------------------------- // in the first row, the hasher state is set to hashes of (join1, span3) - let join1_hash: Word = join1.hash().into(); - let last_span_hash: Word = last_span.hash().into(); + let join1_hash: Word = join1_node.digest().into(); + let last_basic_block_hash: Word = last_basic_block.digest().into(); assert_eq!(join1_hash, get_hasher_state1(&dec_trace, 0)); - assert_eq!(last_span_hash, get_hasher_state2(&dec_trace, 0)); + assert_eq!(last_basic_block_hash, get_hasher_state2(&dec_trace, 0)); // in the second row, the hasher state is set to hashes of (span1, fn_block) - let first_span_hash: Word = first_span.hash().into(); - let foo_call_hash: Word = foo_call.hash().into(); + let first_span_hash: Word = first_basic_block.digest().into(); + let foo_call_hash: Word = foo_call_node.digest().into(); assert_eq!(first_span_hash, get_hasher_state1(&dec_trace, 1)); assert_eq!(foo_call_hash, get_hasher_state2(&dec_trace, 1)); @@ -657,8 +769,8 @@ fn call_block() { assert_eq!(first_span_hash, get_hasher_state1(&dec_trace, 6)); assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 6)); - // in the 7th row, we start the CALL block which hash span2 as its only child - let foo_root_hash: Word = foo_root.hash().into(); + // in the 7th row, we start the CALL block which has basic_block2 as its only child + let foo_root_hash: Word = foo_root_node.digest().into(); assert_eq!(foo_root_hash, get_hasher_state1(&dec_trace, 7)); assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 7)); @@ -676,7 +788,7 @@ fn call_block() { assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 13)); // span3 ends in the 14th row - assert_eq!(last_span_hash, get_hasher_state1(&dec_trace, 16)); + assert_eq!(last_basic_block_hash, get_hasher_state1(&dec_trace, 16)); assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 16)); // the program ends in the 17th row @@ -788,28 +900,49 @@ fn syscall_block() { // stack[0] <- fmp // end + let mut mast_forest = MastForest::new(); + // build foo procedure body - let foo_root = CodeBlock::new_span(vec![Operation::Push(THREE), Operation::FmpUpdate]); + let foo_root = MastNode::new_basic_block(vec![Operation::Push(THREE), Operation::FmpUpdate]); + let foo_root_id = mast_forest.ensure_node(foo_root.clone()); + let kernel = Kernel::new(&[foo_root.digest()]).unwrap(); + mast_forest.set_kernel(kernel.clone()); // build bar procedure body - let bar_span = CodeBlock::new_span(vec![Operation::Push(TWO), Operation::FmpUpdate]); - let foo_call = CodeBlock::new_syscall(foo_root.hash()); - let bar_root = CodeBlock::new_join([bar_span.clone(), foo_call.clone()]); + let bar_basic_block = MastNode::new_basic_block(vec![Operation::Push(TWO), Operation::FmpUpdate]); + let bar_basic_block_id = mast_forest.ensure_node(bar_basic_block.clone()); + + let foo_call_node = MastNode::new_syscall(foo_root_id, &mast_forest); + let foo_call_node_id = mast_forest.ensure_node(foo_call_node.clone()); + + let bar_root_node = MastNode::new_join(bar_basic_block_id, foo_call_node_id, &mast_forest); + let bar_root_node_id = mast_forest.ensure_node(bar_root_node.clone()); // build the program - let first_span = CodeBlock::new_span(vec![ + let first_basic_block = MastNode::new_basic_block(vec![ Operation::Push(ONE), Operation::FmpUpdate, Operation::Pad, ]); - let last_span = CodeBlock::new_span(vec![Operation::FmpAdd]); + let first_basic_block_id = mast_forest.ensure_node(first_basic_block.clone()); + + let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd]); + let last_basic_block_id = mast_forest.ensure_node(last_basic_block.clone()); + + let bar_call_node = MastNode::new_call(bar_root_node_id, &mast_forest); + let bar_call_node_id = mast_forest.ensure_node(bar_call_node.clone()); + + let inner_join_node = MastNode::new_join(first_basic_block_id, bar_call_node_id, &mast_forest); + let inner_join_node_id = mast_forest.ensure_node(inner_join_node.clone()); + + let program_root_node = MastNode::new_join(inner_join_node_id, last_basic_block_id, &mast_forest); + let program_root_node_id = mast_forest.ensure_node(program_root_node.clone()); + mast_forest.set_entrypoint(program_root_node_id); - let bar_call = CodeBlock::new_call(bar_root.hash()); - let inner_join = CodeBlock::new_join([first_span.clone(), bar_call.clone()]); - let program = CodeBlock::new_join([inner_join.clone(), last_span.clone()]); + let program = Program::new(mast_forest).unwrap(); let (sys_trace, dec_trace, trace_len) = - build_call_trace(&program, bar_root.clone(), Some(foo_root.clone())); + build_call_trace(&program, kernel); // --- check block address, op_bits, group count, op_index, and in_span columns --------------- check_op_decoding(&dec_trace, 0, ZERO, Operation::Join, 0, 0, 0); @@ -817,35 +950,35 @@ fn syscall_block() { let inner_join_addr = INIT_ADDR + EIGHT; check_op_decoding(&dec_trace, 1, INIT_ADDR, Operation::Join, 0, 0, 0); // starting first SPAN block - let first_span_addr = inner_join_addr + EIGHT; + let first_basic_block_addr = inner_join_addr + EIGHT; check_op_decoding(&dec_trace, 2, inner_join_addr, Operation::Span, 2, 0, 0); - check_op_decoding(&dec_trace, 3, first_span_addr, Operation::Push(TWO), 1, 0, 1); - check_op_decoding(&dec_trace, 4, first_span_addr, Operation::FmpUpdate, 0, 1, 1); - check_op_decoding(&dec_trace, 5, first_span_addr, Operation::Pad, 0, 2, 1); - check_op_decoding(&dec_trace, 6, first_span_addr, Operation::End, 0, 0, 0); + check_op_decoding(&dec_trace, 3, first_basic_block_addr, Operation::Push(TWO), 1, 0, 1); + check_op_decoding(&dec_trace, 4, first_basic_block_addr, Operation::FmpUpdate, 0, 1, 1); + check_op_decoding(&dec_trace, 5, first_basic_block_addr, Operation::Pad, 0, 2, 1); + check_op_decoding(&dec_trace, 6, first_basic_block_addr, Operation::End, 0, 0, 0); // starting CALL block for bar - let call_addr = first_span_addr + EIGHT; + let call_addr = first_basic_block_addr + EIGHT; check_op_decoding(&dec_trace, 7, inner_join_addr, Operation::Call, 0, 0, 0); // starting JOIN block inside bar let bar_join_addr = call_addr + EIGHT; check_op_decoding(&dec_trace, 8, call_addr, Operation::Join, 0, 0, 0); // starting SPAN block inside bar - let bar_span_addr = bar_join_addr + EIGHT; + let bar_basic_block_addr = bar_join_addr + EIGHT; check_op_decoding(&dec_trace, 9, bar_join_addr, Operation::Span, 2, 0, 0); - check_op_decoding(&dec_trace, 10, bar_span_addr, Operation::Push(ONE), 1, 0, 1); - check_op_decoding(&dec_trace, 11, bar_span_addr, Operation::FmpUpdate, 0, 1, 1); - check_op_decoding(&dec_trace, 12, bar_span_addr, Operation::End, 0, 0, 0); + check_op_decoding(&dec_trace, 10, bar_basic_block_addr, Operation::Push(ONE), 1, 0, 1); + check_op_decoding(&dec_trace, 11, bar_basic_block_addr, Operation::FmpUpdate, 0, 1, 1); + check_op_decoding(&dec_trace, 12, bar_basic_block_addr, Operation::End, 0, 0, 0); // starting SYSCALL block for bar - let syscall_addr = bar_span_addr + EIGHT; + let syscall_addr = bar_basic_block_addr + EIGHT; check_op_decoding(&dec_trace, 13, bar_join_addr, Operation::SysCall, 0, 0, 0); // starting SPAN block within syscall - let syscall_span_addr = syscall_addr + EIGHT; + let syscall_basic_block_addr = syscall_addr + EIGHT; check_op_decoding(&dec_trace, 14, syscall_addr, Operation::Span, 2, 0, 0); - check_op_decoding(&dec_trace, 15, syscall_span_addr, Operation::Push(THREE), 1, 0, 1); - check_op_decoding(&dec_trace, 16, syscall_span_addr, Operation::FmpUpdate, 0, 1, 1); - check_op_decoding(&dec_trace, 17, syscall_span_addr, Operation::End, 0, 0, 0); + check_op_decoding(&dec_trace, 15, syscall_basic_block_addr, Operation::Push(THREE), 1, 0, 1); + check_op_decoding(&dec_trace, 16, syscall_basic_block_addr, Operation::FmpUpdate, 0, 1, 1); + check_op_decoding(&dec_trace, 17, syscall_basic_block_addr, Operation::End, 0, 0, 0); // ending SYSCALL block check_op_decoding(&dec_trace, 18, syscall_addr, Operation::End, 0, 0, 0); @@ -857,10 +990,10 @@ fn syscall_block() { check_op_decoding(&dec_trace, 21, inner_join_addr, Operation::End, 0, 0, 0); // starting the last SPAN block - let last_span_addr = syscall_span_addr + EIGHT; + let last_basic_block_addr = syscall_basic_block_addr + EIGHT; check_op_decoding(&dec_trace, 22, INIT_ADDR, Operation::Span, 1, 0, 0); - check_op_decoding(&dec_trace, 23, last_span_addr, Operation::FmpAdd, 0, 0, 1); - check_op_decoding(&dec_trace, 24, last_span_addr, Operation::End, 0, 0, 0); + check_op_decoding(&dec_trace, 23, last_basic_block_addr, Operation::FmpAdd, 0, 0, 1); + check_op_decoding(&dec_trace, 24, last_basic_block_addr, Operation::End, 0, 0, 0); // ending the program check_op_decoding(&dec_trace, 25, INIT_ADDR, Operation::End, 0, 0, 0); @@ -868,14 +1001,14 @@ fn syscall_block() { // --- check hasher state columns ------------------------------------------------------------- // in the first row, the hasher state is set to hashes of (inner_join, last_span) - let inner_join_hash: Word = inner_join.hash().into(); - let last_span_hash: Word = last_span.hash().into(); + let inner_join_hash: Word = inner_join_node.digest().into(); + let last_span_hash: Word = last_basic_block.digest().into(); assert_eq!(inner_join_hash, get_hasher_state1(&dec_trace, 0)); assert_eq!(last_span_hash, get_hasher_state2(&dec_trace, 0)); // in the second row, the hasher state is set to hashes of (first_span, bar_call) - let first_span_hash: Word = first_span.hash().into(); - let bar_call_hash: Word = bar_call.hash().into(); + let first_span_hash: Word = first_basic_block.digest().into(); + let bar_call_hash: Word = bar_call_node.digest().into(); assert_eq!(first_span_hash, get_hasher_state1(&dec_trace, 1)); assert_eq!(bar_call_hash, get_hasher_state2(&dec_trace, 1)); @@ -884,13 +1017,13 @@ fn syscall_block() { assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 6)); // in the 7th row, we start the CALL block which has bar_join as its only child - let bar_root_hash: Word = bar_root.hash().into(); + let bar_root_hash: Word = bar_root_node.digest().into(); assert_eq!(bar_root_hash, get_hasher_state1(&dec_trace, 7)); assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 7)); // in the 8th row, the hasher state is set to hashes of (bar_span, foo_call) - let bar_span_hash: Word = bar_span.hash().into(); - let foo_call_hash: Word = foo_call.hash().into(); + let bar_span_hash: Word = bar_basic_block.digest().into(); + let foo_call_hash: Word = foo_call_node.digest().into(); assert_eq!(bar_span_hash, get_hasher_state1(&dec_trace, 8)); assert_eq!(foo_call_hash, get_hasher_state2(&dec_trace, 8)); @@ -899,7 +1032,7 @@ fn syscall_block() { assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 12)); // in the 13th row, we start the SYSCALL block which has foo_span as its only child - let foo_root_hash: Word = foo_root.hash().into(); + let foo_root_hash: Word = foo_root.digest().into(); assert_eq!(foo_root_hash, get_hasher_state1(&dec_trace, 13)); assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 13)); @@ -930,7 +1063,7 @@ fn syscall_block() { assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 24)); // the program ends in the 25th row - let program_hash: Word = program.hash().into(); + let program_hash: Word = program_root_node.digest().into(); assert_eq!(program_hash, get_hasher_state1(&dec_trace, 25)); assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 25)); @@ -1056,25 +1189,40 @@ fn dyn_block() { // build a dynamic block which looks like this: // push.1 add - let foo_root = CodeBlock::new_span(vec![Operation::Push(ONE), Operation::Add]); - let mul_span = CodeBlock::new_span(vec![Operation::Mul]); - let save_span = CodeBlock::new_span(vec![Operation::MovDn4]); - let join = CodeBlock::new_join([mul_span.clone(), save_span.clone()]); + let mut mast_forest = MastForest::new(); + + let foo_root_node = MastNode::new_basic_block(vec![Operation::Push(ONE), Operation::Add]); + let _foo_root_node_id = mast_forest.ensure_node(foo_root_node.clone()); + + let mul_bb_node = MastNode::new_basic_block(vec![Operation::Mul]); + let mul_bb_node_id = mast_forest.ensure_node(mul_bb_node.clone()); + + let save_bb_node = MastNode::new_basic_block(vec![Operation::MovDn4]); + let save_bb_node_id = mast_forest.ensure_node(save_bb_node.clone()); + + let join_node = MastNode::new_join(mul_bb_node_id, save_bb_node_id, &mast_forest); + let join_node_id = mast_forest.ensure_node(join_node.clone()); + // This dyn will point to foo. - let dyn_block = CodeBlock::new_dyn(); - let program = CodeBlock::new_join([join.clone(), dyn_block.clone()]); + let dyn_node = MastNode::new_dynexec(); + let dyn_node_id = mast_forest.ensure_node(dyn_node.clone()); + + let program_root_node = MastNode::new_join(join_node_id, dyn_node_id, &mast_forest); + let program_root_node_id = mast_forest.ensure_node(program_root_node.clone()); + mast_forest.set_entrypoint(program_root_node_id); + + let program = Program::new(mast_forest).unwrap(); let (trace, trace_len) = build_dyn_trace( &[ - foo_root.hash()[0].as_int(), - foo_root.hash()[1].as_int(), - foo_root.hash()[2].as_int(), - foo_root.hash()[3].as_int(), + foo_root_node.digest()[0].as_int(), + foo_root_node.digest()[1].as_int(), + foo_root_node.digest()[2].as_int(), + foo_root_node.digest()[3].as_int(), 2, 4, ], &program, - foo_root.clone(), ); // --- check block address, op_bits, group count, op_index, and in_span columns --------------- @@ -1083,26 +1231,26 @@ fn dyn_block() { let join_addr = INIT_ADDR + EIGHT; check_op_decoding(&trace, 1, INIT_ADDR, Operation::Join, 0, 0, 0); // starting first span - let mul_span_addr = join_addr + EIGHT; + let mul_basic_block_addr = join_addr + EIGHT; check_op_decoding(&trace, 2, join_addr, Operation::Span, 1, 0, 0); - check_op_decoding(&trace, 3, mul_span_addr, Operation::Mul, 0, 0, 1); - check_op_decoding(&trace, 4, mul_span_addr, Operation::End, 0, 0, 0); + check_op_decoding(&trace, 3, mul_basic_block_addr, Operation::Mul, 0, 0, 1); + check_op_decoding(&trace, 4, mul_basic_block_addr, Operation::End, 0, 0, 0); // starting second span - let save_span_addr = mul_span_addr + EIGHT; + let save_basic_block_addr = mul_basic_block_addr + EIGHT; check_op_decoding(&trace, 5, join_addr, Operation::Span, 1, 0, 0); - check_op_decoding(&trace, 6, save_span_addr, Operation::MovDn4, 0, 0, 1); - check_op_decoding(&trace, 7, save_span_addr, Operation::End, 0, 0, 0); + check_op_decoding(&trace, 6, save_basic_block_addr, Operation::MovDn4, 0, 0, 1); + check_op_decoding(&trace, 7, save_basic_block_addr, Operation::End, 0, 0, 0); // end inner join check_op_decoding(&trace, 8, join_addr, Operation::End, 0, 0, 0); // dyn check_op_decoding(&trace, 9, INIT_ADDR, Operation::Dyn, 0, 0, 0); // starting foo span - let dyn_addr = save_span_addr + EIGHT; - let add_span_addr = dyn_addr + EIGHT; + let dyn_addr = save_basic_block_addr + EIGHT; + let add_basic_block_addr = dyn_addr + EIGHT; check_op_decoding(&trace, 10, dyn_addr, Operation::Span, 2, 0, 0); - check_op_decoding(&trace, 11, add_span_addr, Operation::Push(ONE), 1, 0, 1); - check_op_decoding(&trace, 12, add_span_addr, Operation::Add, 0, 1, 1); - check_op_decoding(&trace, 13, add_span_addr, Operation::End, 0, 0, 0); + check_op_decoding(&trace, 11, add_basic_block_addr, Operation::Push(ONE), 1, 0, 1); + check_op_decoding(&trace, 12, add_basic_block_addr, Operation::Add, 0, 1, 1); + check_op_decoding(&trace, 13, add_basic_block_addr, Operation::End, 0, 0, 0); // end dyn check_op_decoding(&trace, 14, dyn_addr, Operation::End, 0, 0, 0); // end outer join @@ -1111,23 +1259,23 @@ fn dyn_block() { // --- check hasher state columns ------------------------------------------------------------- // in the first row, the hasher state is set to hashes of both child nodes - let join_hash: Word = join.hash().into(); - let dyn_hash: Word = dyn_block.hash().into(); + let join_hash: Word = join_node.digest().into(); + let dyn_hash: Word = dyn_node.digest().into(); assert_eq!(join_hash, get_hasher_state1(&trace, 0)); assert_eq!(dyn_hash, get_hasher_state2(&trace, 0)); // in the second row, the hasher set is set to hashes of both child nodes of the inner JOIN - let mul_span_hash: Word = mul_span.hash().into(); - let save_span_hash: Word = save_span.hash().into(); - assert_eq!(mul_span_hash, get_hasher_state1(&trace, 1)); - assert_eq!(save_span_hash, get_hasher_state2(&trace, 1)); + let mul_bb_node_hash: Word = mul_bb_node.digest().into(); + let save_bb_node_hash: Word = save_bb_node.digest().into(); + assert_eq!(mul_bb_node_hash, get_hasher_state1(&trace, 1)); + assert_eq!(save_bb_node_hash, get_hasher_state2(&trace, 1)); // at the end of the first SPAN, the hasher state is set to the hash of the first child - assert_eq!(mul_span_hash, get_hasher_state1(&trace, 4)); + assert_eq!(mul_bb_node_hash, get_hasher_state1(&trace, 4)); assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 4)); // at the end of the second SPAN, the hasher state is set to the hash of the second child - assert_eq!(save_span_hash, get_hasher_state1(&trace, 7)); + assert_eq!(save_bb_node_hash, get_hasher_state1(&trace, 7)); assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 7)); // at the end of the inner JOIN, the hasher set is set to the hash of the JOIN @@ -1135,7 +1283,7 @@ fn dyn_block() { assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 8)); // at the start of the DYN block, the hasher state is set to the hash of its child (foo span) - let foo_hash: Word = foo_root.hash().into(); + let foo_hash: Word = foo_root_node.digest().into(); assert_eq!(foo_hash, get_hasher_state1(&trace, 9)); assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 9)); @@ -1147,7 +1295,7 @@ fn dyn_block() { assert_eq!(dyn_hash, get_hasher_state1(&trace, 14)); // at the end of the program, the hasher state is set to the hash of the entire program - let program_hash: Word = program.hash().into(); + let program_hash: Word = program_root_node.digest().into(); assert_eq!(program_hash, get_hasher_state1(&trace, 15)); assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 15)); @@ -1165,7 +1313,15 @@ fn dyn_block() { #[test] fn set_user_op_helpers_many() { // --- user operation with 4 helper values ---------------------------------------------------- - let program = CodeBlock::new_span(vec![Operation::U32div]); + let program = { + let mut mast_forest = MastForest::new(); + + let basic_block = MastNode::new_basic_block(vec![Operation::U32div]); + let basic_block_id = mast_forest.ensure_node(basic_block); + mast_forest.set_entrypoint(basic_block_id); + + mast_forest.try_into().unwrap() + }; let a = rand_value::(); let b = rand_value::(); let (dividend, divisor) = if a > b { (a, b) } else { (b, a) }; @@ -1193,12 +1349,12 @@ fn set_user_op_helpers_many() { // HELPER FUNCTIONS // ================================================================================================ -fn build_trace(stack_inputs: &[u64], program: &CodeBlock) -> (DecoderTrace, usize) { +fn build_trace(stack_inputs: &[u64], program: &Program) -> (DecoderTrace, usize) { let stack_inputs = StackInputs::try_from_ints(stack_inputs.iter().copied()).unwrap(); let host = DefaultHost::default(); let mut process = Process::new(Kernel::default(), stack_inputs, host, ExecutionOptions::default()); - process.execute_code_block(program, &CodeBlockTable::default()).unwrap(); + process.execute(program).unwrap(); let (trace, _, _) = ExecutionTrace::test_finalize_trace(process); let trace_len = trace.num_rows() - ExecutionTrace::NUM_RAND_ROWS; @@ -1212,21 +1368,13 @@ fn build_trace(stack_inputs: &[u64], program: &CodeBlock) -> (DecoderTrace, usiz ) } -fn build_dyn_trace( - stack_inputs: &[u64], - program: &CodeBlock, - fn_block: CodeBlock, -) -> (DecoderTrace, usize) { +fn build_dyn_trace(stack_inputs: &[u64], program: &Program) -> (DecoderTrace, usize) { let stack_inputs = StackInputs::try_from_ints(stack_inputs.iter().copied()).unwrap(); let host = DefaultHost::default(); let mut process = Process::new(Kernel::default(), stack_inputs, host, ExecutionOptions::default()); - // build code block table - let mut cb_table = CodeBlockTable::default(); - cb_table.insert(fn_block); - - process.execute_code_block(program, &cb_table).unwrap(); + process.execute(program).unwrap(); let (trace, _, _) = ExecutionTrace::test_finalize_trace(process); let trace_len = trace.num_rows() - ExecutionTrace::NUM_RAND_ROWS; @@ -1240,27 +1388,12 @@ fn build_dyn_trace( ) } -fn build_call_trace( - program: &CodeBlock, - fn_block: CodeBlock, - kernel_proc: Option, -) -> (SystemTrace, DecoderTrace, usize) { - let kernel = match kernel_proc { - Some(ref proc) => Kernel::new(&[proc.hash()]).unwrap(), - None => Kernel::default(), - }; +fn build_call_trace(program: &Program, kernel: Kernel) -> (SystemTrace, DecoderTrace, usize) { let host = DefaultHost::default(); let stack_inputs = crate::StackInputs::default(); let mut process = Process::new(kernel, stack_inputs, host, ExecutionOptions::default()); - // build code block table - let mut cb_table = CodeBlockTable::default(); - cb_table.insert(fn_block); - if let Some(proc) = kernel_proc { - cb_table.insert(proc); - } - - process.execute_code_block(program, &cb_table).unwrap(); + process.execute(program).unwrap(); let (trace, _, _) = ExecutionTrace::test_finalize_trace(process); let trace_len = trace.num_rows() - ExecutionTrace::NUM_RAND_ROWS; diff --git a/processor/src/decoder/trace.rs b/processor/src/decoder/trace.rs index f06fe2e9d1..b8bda0805c 100644 --- a/processor/src/decoder/trace.rs +++ b/processor/src/decoder/trace.rs @@ -289,13 +289,13 @@ impl DecoderTrace { pub fn append_user_op( &mut self, op: Operation, - span_addr: Felt, + basic_block_addr: Felt, parent_addr: Felt, num_groups_left: Felt, group_ops_left: Felt, op_idx: Felt, ) { - self.addr_trace.push(span_addr); + self.addr_trace.push(basic_block_addr); self.append_opcode(op); self.hasher_trace[0].push(group_ops_left); diff --git a/processor/src/errors.rs b/processor/src/errors.rs index ca7e3357b8..fa5d9251ac 100644 --- a/processor/src/errors.rs +++ b/processor/src/errors.rs @@ -1,11 +1,11 @@ use super::{ crypto::MerkleError, system::{FMP_MAX, FMP_MIN}, - CodeBlock, Digest, Felt, QuadFelt, Word, + Digest, Felt, QuadFelt, Word, }; use alloc::string::String; use core::fmt::{Display, Formatter}; -use vm_core::{stack::STACK_TOP_SIZE, utils::to_hex}; +use vm_core::{mast::MastNodeId, stack::STACK_TOP_SIZE, utils::to_hex}; use winter_prover::{math::FieldElement, ProverError}; #[cfg(feature = "std")] @@ -19,10 +19,9 @@ pub enum ExecutionError { AdviceMapKeyNotFound(Word), AdviceStackReadFailed(u32), CallerNotInSyscall, - CodeBlockNotFound(Digest), CycleLimitExceeded(u32), DivideByZero(u32), - DynamicCodeBlockNotFound(Digest), + DynamicNodeNotFound(Digest), EventError(String), Ext2InttError(Ext2InttError), FailedAssertion { @@ -49,6 +48,9 @@ pub enum ExecutionError { }, LogArgumentZero(u32), MalformedSignatureKey(&'static str), + MastNodeNotFoundInForest { + node_id: MastNodeId, + }, MemoryAddressOutOfBounds(u64), MerklePathVerificationFailed { value: Word, @@ -61,11 +63,11 @@ pub enum ExecutionError { MerkleStoreUpdateFailed(MerkleError), NotBinaryValue(Felt), NotU32Value(Felt, Felt), + ProgramAlreadyExecuted, ProverError(ProverError), SmtNodeNotFound(Word), SmtNodePreImageNotValid(Word, usize), SyscallTargetNotInKernel(Digest), - UnexecutableCodeBlock(CodeBlock), } impl Display for ExecutionError { @@ -81,18 +83,11 @@ impl Display for ExecutionError { CallerNotInSyscall => { write!(f, "Instruction `caller` used outside of kernel context") } - CodeBlockNotFound(digest) => { - let hex = to_hex(digest.as_bytes()); - write!( - f, - "Failed to execute code block with root {hex}; the block could not be found" - ) - } CycleLimitExceeded(max_cycles) => { write!(f, "Exceeded the allowed number of cycles (max cycles = {max_cycles})") } DivideByZero(clk) => write!(f, "Division by zero at clock cycle {clk}"), - DynamicCodeBlockNotFound(digest) => { + DynamicNodeNotFound(digest) => { let hex = to_hex(digest.as_bytes()); write!( f, @@ -152,6 +147,9 @@ impl Display for ExecutionError { ) } MalformedSignatureKey(signature) => write!(f, "Malformed signature key: {signature}"), + MastNodeNotFoundInForest { node_id } => { + write!(f, "Malformed MAST forest, node id {node_id} doesn't exist") + } MemoryAddressOutOfBounds(addr) => { write!(f, "Memory address cannot exceed 2^32 but was {addr}") } @@ -191,14 +189,14 @@ impl Display for ExecutionError { let node_hex = to_hex(Felt::elements_as_bytes(node)); write!(f, "Invalid pre-image for node {node_hex}. Expected pre-image length to be a multiple of 8, but was {preimage_len}") } + ProgramAlreadyExecuted => { + write!(f, "a program has already been executed in this process") + } ProverError(error) => write!(f, "Proof generation failed: {error}"), SyscallTargetNotInKernel(proc) => { let hex = to_hex(proc.as_bytes()); write!(f, "Syscall failed: procedure with root {hex} was not found in the kernel") } - UnexecutableCodeBlock(block) => { - write!(f, "Execution reached unexecutable code block {block:?}") - } } } } diff --git a/processor/src/host/advice/injectors/dsa.rs b/processor/src/host/advice/injectors/dsa.rs index be11d1bd6f..e0f8056134 100644 --- a/processor/src/host/advice/injectors/dsa.rs +++ b/processor/src/host/advice/injectors/dsa.rs @@ -9,8 +9,8 @@ use super::super::{ExecutionError, Felt, Word}; /// 1. The nonce represented as 8 field elements. /// 2. The expanded public key represented as the coefficients of a polynomial of degree < 512. /// 3. The signature represented as the coefficients of a polynomial of degree < 512. -/// 4. The product of the above two polynomials in the ring of polynomials with coefficients -/// in the Miden field. +/// 4. The product of the above two polynomials in the ring of polynomials with coefficients in the +/// Miden field. /// /// # Errors /// Will return an error if either: diff --git a/processor/src/lib.rs b/processor/src/lib.rs index 42a86e0282..244cb31b4a 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -15,15 +15,19 @@ use miden_air::trace::{ }; pub use miden_air::{ExecutionOptions, ExecutionOptionsError}; pub use vm_core::{ - chiplets::hasher::Digest, crypto::merkle::SMT_DEPTH, errors::InputError, - utils::DeserializationError, AdviceInjector, AssemblyOp, Felt, Kernel, Operation, Program, - ProgramInfo, QuadExtension, StackInputs, StackOutputs, Word, EMPTY_WORD, ONE, ZERO, + chiplets::hasher::Digest, + crypto::merkle::SMT_DEPTH, + errors::InputError, + mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, + utils::DeserializationError, + AdviceInjector, AssemblyOp, Felt, Kernel, Operation, Program, ProgramInfo, QuadExtension, + StackInputs, StackOutputs, Word, EMPTY_WORD, ONE, ZERO, }; use vm_core::{ - code_blocks::{ - Call, CodeBlock, Dyn, Join, Loop, OpBatch, Span, Split, OP_BATCH_SIZE, OP_GROUP_SIZE, + mast::{ + BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, OpBatch, SplitNode, OP_GROUP_SIZE, }, - CodeBlockTable, Decorator, DecoratorIterator, FieldElement, StackTopState, + Decorator, DecoratorIterator, FieldElement, StackTopState, }; pub use winter_prover::matrix::ColMatrix; @@ -221,92 +225,89 @@ where // PROGRAM EXECUTOR // -------------------------------------------------------------------------------------------- - /// Executes the provided [Program] in this process. + /// Executes the provided [`Program`] in this process. pub fn execute(&mut self, program: &Program) -> Result { - assert_eq!(self.system.clk(), 0, "a program has already been executed in this process"); - self.execute_code_block(program.root(), program.cb_table())?; + if self.system.clk() != 0 { + return Err(ExecutionError::ProgramAlreadyExecuted); + } + + self.execute_mast_node(program.entrypoint(), program)?; Ok(self.stack.build_stack_outputs()) } - // CODE BLOCK EXECUTORS + // NODE EXECUTORS // -------------------------------------------------------------------------------------------- - /// Executes the specified [CodeBlock]. - /// - /// # Errors - /// Returns an [ExecutionError] if executing the specified block fails for any reason. - fn execute_code_block( + fn execute_mast_node( &mut self, - block: &CodeBlock, - cb_table: &CodeBlockTable, + node_id: MastNodeId, + program: &Program, ) -> Result<(), ExecutionError> { - match block { - CodeBlock::Join(block) => self.execute_join_block(block, cb_table), - CodeBlock::Split(block) => self.execute_split_block(block, cb_table), - CodeBlock::Loop(block) => self.execute_loop_block(block, cb_table), - CodeBlock::Call(block) => self.execute_call_block(block, cb_table), - CodeBlock::Dyn(block) => self.execute_dyn_block(block, cb_table), - CodeBlock::Span(block) => self.execute_span_block(block), - CodeBlock::Proxy(proxy) => match cb_table.get(proxy.hash()) { - Some(block) => self.execute_code_block(block, cb_table), - None => Err(ExecutionError::UnexecutableCodeBlock(block.clone())), - }, + let wrapper_node = &program + .get_node_by_id(node_id) + .ok_or(ExecutionError::MastNodeNotFoundInForest { node_id })?; + + match wrapper_node { + MastNode::Block(node) => self.execute_basic_block_node(node), + MastNode::Join(node) => self.execute_join_node(node, program), + MastNode::Split(node) => self.execute_split_node(node, program), + MastNode::Loop(node) => self.execute_loop_node(node, program), + MastNode::Call(node) => self.execute_call_node(node, program), + MastNode::Dyn => self.execute_dyn_node(program), } } - /// Executes the specified [Join] block. #[inline(always)] - fn execute_join_block( + fn execute_join_node( &mut self, - block: &Join, - cb_table: &CodeBlockTable, + node: &JoinNode, + program: &Program, ) -> Result<(), ExecutionError> { - self.start_join_block(block)?; + self.start_join_node(node, program)?; // execute first and then second child of the join block - self.execute_code_block(block.first(), cb_table)?; - self.execute_code_block(block.second(), cb_table)?; + self.execute_mast_node(node.first(), program)?; + self.execute_mast_node(node.second(), program)?; - self.end_join_block(block) + self.end_join_node(node) } - /// Executes the specified [Split] block. #[inline(always)] - fn execute_split_block( + fn execute_split_node( &mut self, - block: &Split, - cb_table: &CodeBlockTable, + node: &SplitNode, + program: &Program, ) -> Result<(), ExecutionError> { // start the SPLIT block; this also pops the stack and returns the popped element - let condition = self.start_split_block(block)?; + let condition = self.start_split_node(node, program)?; // execute either the true or the false branch of the split block based on the condition if condition == ONE { - self.execute_code_block(block.on_true(), cb_table)?; + self.execute_mast_node(node.on_true(), program)?; } else if condition == ZERO { - self.execute_code_block(block.on_false(), cb_table)?; + self.execute_mast_node(node.on_false(), program)?; } else { return Err(ExecutionError::NotBinaryValue(condition)); } - self.end_split_block(block) + self.end_split_node(node) } /// Executes the specified [Loop] block. #[inline(always)] - fn execute_loop_block( + fn execute_loop_node( &mut self, - block: &Loop, - cb_table: &CodeBlockTable, + node: &LoopNode, + program: &Program, ) -> Result<(), ExecutionError> { // start the LOOP block; this also pops the stack and returns the popped element - let condition = self.start_loop_block(block)?; + let condition = self.start_loop_node(node, program)?; // if the top of the stack is ONE, execute the loop body; otherwise skip the loop body if condition == ONE { // execute the loop body at least once - self.execute_code_block(block.body(), cb_table)?; + self.execute_mast_node(node.body(), program)?; // keep executing the loop body until the condition on the top of the stack is no // longer ONE; each iteration of the loop is preceded by executing REPEAT operation @@ -314,15 +315,15 @@ where while self.stack.peek() == ONE { self.decoder.repeat(); self.execute_op(Operation::Drop)?; - self.execute_code_block(block.body(), cb_table)?; + self.execute_mast_node(node.body(), program)?; } // end the LOOP block and drop the condition from the stack - self.end_loop_block(block, true) + self.end_loop_node(node, true) } else if condition == ZERO { // end the LOOP block, but don't drop the condition from the stack because it was // already dropped when we started the LOOP block - self.end_loop_block(block, false) + self.end_loop_node(node, false) } else { Err(ExecutionError::NotBinaryValue(condition)) } @@ -330,76 +331,80 @@ where /// Executes the specified [Call] block. #[inline(always)] - fn execute_call_block( + fn execute_call_node( &mut self, - block: &Call, - cb_table: &CodeBlockTable, + call_node: &CallNode, + program: &Program, ) -> Result<(), ExecutionError> { + let callee_digest = { + let callee = program.get_node_by_id(call_node.callee()).ok_or_else(|| { + ExecutionError::MastNodeNotFoundInForest { + node_id: call_node.callee(), + } + })?; + + callee.digest() + }; + // if this is a syscall, make sure the call target exists in the kernel - if block.is_syscall() { - self.chiplets.access_kernel_proc(block.fn_hash())?; + if call_node.is_syscall() { + self.chiplets.access_kernel_proc(callee_digest)?; } - self.start_call_block(block)?; + self.start_call_node(call_node, program)?; // if this is a dyncall, execute the dynamic code block - if block.fn_hash() == Dyn::dyn_hash() { - self.execute_dyn_block(&Dyn::new(), cb_table)?; + if callee_digest == DynNode.digest() { + self.execute_dyn_node(program)?; } else { - // get function body from the code block table and execute it - let fn_body = cb_table - .get(block.fn_hash()) - .ok_or_else(|| ExecutionError::CodeBlockNotFound(block.fn_hash()))?; - self.execute_code_block(fn_body, cb_table)?; + self.execute_mast_node(call_node.callee(), program)?; } - self.end_call_block(block) + self.end_call_node(call_node) } - /// Executes the specified [Dyn] block. + /// Executes the specified [DynNode] node. #[inline(always)] - fn execute_dyn_block( - &mut self, - block: &Dyn, - cb_table: &CodeBlockTable, - ) -> Result<(), ExecutionError> { + fn execute_dyn_node(&mut self, program: &Program) -> Result<(), ExecutionError> { // get target hash from the stack - let dyn_hash = self.stack.get_word(0); - self.start_dyn_block(block, dyn_hash)?; + let callee_hash = self.stack.get_word(0); + self.start_dyn_node(callee_hash)?; // get dynamic code from the code block table and execute it - let dyn_digest = dyn_hash.into(); - let dyn_code = cb_table - .get(dyn_digest) - .ok_or_else(|| ExecutionError::DynamicCodeBlockNotFound(dyn_digest))?; - self.execute_code_block(dyn_code, cb_table)?; + let callee_id = program + .get_node_id_by_digest(callee_hash.into()) + .ok_or_else(|| ExecutionError::DynamicNodeNotFound(callee_hash.into()))?; + self.execute_mast_node(callee_id, program)?; - self.end_dyn_block(block) + self.end_dyn_node() } - /// Executes the specified [Span] block. + /// Executes the specified [`BasicBlockNode`] block. #[inline(always)] - fn execute_span_block(&mut self, block: &Span) -> Result<(), ExecutionError> { - self.start_span_block(block)?; + fn execute_basic_block_node( + &mut self, + basic_block: &BasicBlockNode, + ) -> Result<(), ExecutionError> { + self.start_basic_block_node(basic_block)?; let mut op_offset = 0; - let mut decorators = block.decorator_iter(); + let mut decorators = basic_block.decorator_iter(); // execute the first operation batch - self.execute_op_batch(&block.op_batches()[0], &mut decorators, op_offset)?; - op_offset += block.op_batches()[0].ops().len(); + self.execute_op_batch(&basic_block.op_batches()[0], &mut decorators, op_offset)?; + op_offset += basic_block.op_batches()[0].ops().len(); // if the span contains more operation batches, execute them. each additional batch is // preceded by a RESPAN operation; executing RESPAN operation does not change the state // of the stack - for op_batch in block.op_batches().iter().skip(1) { + for op_batch in basic_block.op_batches().iter().skip(1) { self.respan(op_batch); self.execute_op(Operation::Noop)?; self.execute_op_batch(op_batch, &mut decorators, op_offset)?; op_offset += op_batch.ops().len(); } - self.end_span_block(block)?; + self.end_basic_block_node(basic_block)?; // execute any decorators which have not been executed during span ops execution; this // can happen for decorators appearing after all operations in a block. these decorators diff --git a/processor/src/operations/comb_ops.rs b/processor/src/operations/comb_ops.rs index c3754baca1..d88ed0a192 100644 --- a/processor/src/operations/comb_ops.rs +++ b/processor/src/operations/comb_ops.rs @@ -52,10 +52,10 @@ where /// with common denominator (x - gz). /// 4. x_addr is the memory address from which we are loading the Ti's using the MSTREAM /// instruction. - /// 5. z_addr is the memory address to the i-th OOD evaluations at z and gz - /// i.e. T_i(z):= (T_i(z)0, T_i(z)1) and T_i(gz):= (T_i(gz)0, T_i(gz)1). - /// 6. a_addr is the memory address of the i-th random element alpha_i used in batching - /// the trace polynomial quotients. + /// 5. z_addr is the memory address to the i-th OOD evaluations at z and gz i.e. T_i(z):= + /// (T_i(z)0, T_i(z)1) and T_i(gz):= (T_i(gz)0, T_i(gz)1). + /// 6. a_addr is the memory address of the i-th random element alpha_i used in batching the + /// trace polynomial quotients. /// /// The instruction also makes use of the helper registers to hold the values of T_i(z), T_i(gz) /// and alpha_i during the course of its execution. diff --git a/processor/src/trace/tests/chiplets/hasher.rs b/processor/src/trace/tests/chiplets/hasher.rs index 429442ccea..19bfd34886 100644 --- a/processor/src/trace/tests/chiplets/hasher.rs +++ b/processor/src/trace/tests/chiplets/hasher.rs @@ -1,5 +1,5 @@ use super::{ - build_span_with_respan_ops, build_trace_from_block, build_trace_from_ops_with_inputs, + build_span_with_respan_ops, build_trace_from_ops_with_inputs, build_trace_from_program, init_state_from_words, rand_array, AdviceInputs, ExecutionTrace, Felt, FieldElement, Operation, Trace, AUX_TRACE_RAND_ELEMENTS, CHIPLETS_AUX_TRACE_OFFSET, NUM_RAND_ROWS, ONE, ZERO, }; @@ -21,10 +21,10 @@ use miden_air::trace::{ }; use vm_core::{ chiplets::hasher::apply_permutation, - code_blocks::CodeBlock, crypto::merkle::{MerkleStore, MerkleTree, NodeIndex}, + mast::{MastForest, MastNode}, utils::range, - Word, + Program, Word, }; // CONSTANTS @@ -47,8 +47,17 @@ pub const DECODER_OP_BITS_RANGE: Range = #[test] #[allow(clippy::needless_range_loop)] pub fn b_chip_span() { - let program = CodeBlock::new_span(vec![Operation::Add, Operation::Mul]); - let trace = build_trace_from_block(&program, &[]); + let program = { + let mut mast_forest = MastForest::new(); + + let basic_block = MastNode::new_basic_block(vec![Operation::Add, Operation::Mul]); + let basic_block_id = mast_forest.ensure_node(basic_block); + mast_forest.set_entrypoint(basic_block_id); + + Program::new(mast_forest).unwrap() + }; + + let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); let aux_columns = trace.build_aux_trace(&alphas).unwrap(); @@ -111,9 +120,17 @@ pub fn b_chip_span() { #[test] #[allow(clippy::needless_range_loop)] pub fn b_chip_span_with_respan() { - let (ops, _) = build_span_with_respan_ops(); - let program = CodeBlock::new_span(ops); - let trace = build_trace_from_block(&program, &[]); + let program = { + let mut mast_forest = MastForest::new(); + + let (ops, _) = build_span_with_respan_ops(); + let basic_block = MastNode::new_basic_block(ops); + let basic_block_id = mast_forest.ensure_node(basic_block); + mast_forest.set_entrypoint(basic_block_id); + + Program::new(mast_forest).unwrap() + }; + let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); let aux_columns = trace.build_aux_trace(&alphas).unwrap(); @@ -197,10 +214,24 @@ pub fn b_chip_span_with_respan() { #[test] #[allow(clippy::needless_range_loop)] pub fn b_chip_merge() { - let t_branch = CodeBlock::new_span(vec![Operation::Add]); - let f_branch = CodeBlock::new_span(vec![Operation::Mul]); - let program = CodeBlock::new_split(t_branch, f_branch); - let trace = build_trace_from_block(&program, &[]); + let program = { + let mut mast_forest = MastForest::new(); + + let t_branch = MastNode::new_basic_block(vec![Operation::Add]); + let t_branch_id = mast_forest.ensure_node(t_branch); + + let f_branch = MastNode::new_basic_block(vec![Operation::Mul]); + let f_branch_id = mast_forest.ensure_node(f_branch); + + let split = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); + let split_id = mast_forest.ensure_node(split); + + mast_forest.set_entrypoint(split_id); + + Program::new(mast_forest).unwrap() + }; + + let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); let aux_columns = trace.build_aux_trace(&alphas).unwrap(); @@ -304,9 +335,17 @@ pub fn b_chip_merge() { #[test] #[allow(clippy::needless_range_loop)] pub fn b_chip_permutation() { - let program = CodeBlock::new_span(vec![Operation::HPerm]); + let program = { + let mut mast_forest = MastForest::new(); + + let basic_block = MastNode::new_basic_block(vec![Operation::HPerm]); + let basic_block_id = mast_forest.ensure_node(basic_block); + mast_forest.set_entrypoint(basic_block_id); + + Program::new(mast_forest).unwrap() + }; let stack = vec![8, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]; - let trace = build_trace_from_block(&program, &stack); + let trace = build_trace_from_program(&program, &stack); let mut hperm_state: [Felt; STATE_WIDTH] = stack .iter() diff --git a/processor/src/trace/tests/chiplets/mod.rs b/processor/src/trace/tests/chiplets/mod.rs index 9637dfda70..c95c21bf10 100644 --- a/processor/src/trace/tests/chiplets/mod.rs +++ b/processor/src/trace/tests/chiplets/mod.rs @@ -1,6 +1,6 @@ use super::{ super::{utils::build_span_with_respan_ops, Trace, NUM_RAND_ROWS}, - build_trace_from_block, build_trace_from_ops, build_trace_from_ops_with_inputs, + build_trace_from_ops, build_trace_from_ops_with_inputs, build_trace_from_program, init_state_from_words, rand_array, AdviceInputs, ExecutionTrace, Felt, FieldElement, Operation, Word, ONE, ZERO, }; diff --git a/processor/src/trace/tests/decoder.rs b/processor/src/trace/tests/decoder.rs index fb744302b1..cbb4cf7e6c 100644 --- a/processor/src/trace/tests/decoder.rs +++ b/processor/src/trace/tests/decoder.rs @@ -1,6 +1,6 @@ use super::{ super::{ - tests::{build_trace_from_block, build_trace_from_ops}, + tests::{build_trace_from_ops, build_trace_from_program}, utils::build_span_with_respan_ops, NUM_RAND_ROWS, }, @@ -15,7 +15,10 @@ use miden_air::trace::{ AUX_TRACE_RAND_ELEMENTS, }; use test_utils::rand::rand_array; -use vm_core::{code_blocks::CodeBlock, FieldElement, Operation, Word, ONE, ZERO}; +use vm_core::{ + mast::{MastForest, MastNode, MerkleTreeNode}, + FieldElement, Operation, Program, Word, ONE, ZERO, +}; // BLOCK STACK TABLE TESTS // ================================================================================================ @@ -66,11 +69,24 @@ fn decoder_p1_span_with_respan() { #[test] #[allow(clippy::needless_range_loop)] fn decoder_p1_join() { - let span1 = CodeBlock::new_span(vec![Operation::Mul]); - let span2 = CodeBlock::new_span(vec![Operation::Add]); - let program = CodeBlock::new_join([span1, span2]); + let program = { + let mut mast_forest = MastForest::new(); + + let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); + let basic_block_1_id = mast_forest.ensure_node(basic_block_1); + + let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); + let basic_block_2_id = mast_forest.ensure_node(basic_block_2); + + let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); + let join_id = mast_forest.ensure_node(join); + + mast_forest.set_entrypoint(join_id); + + Program::new(mast_forest).unwrap() + }; - let trace = build_trace_from_block(&program, &[]); + let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); let aux_columns = trace.build_aux_trace(&alphas).unwrap(); let p1 = aux_columns.get_column(P1_COL_IDX); @@ -126,11 +142,24 @@ fn decoder_p1_join() { #[test] #[allow(clippy::needless_range_loop)] fn decoder_p1_split() { - let span1 = CodeBlock::new_span(vec![Operation::Mul]); - let span2 = CodeBlock::new_span(vec![Operation::Add]); - let program = CodeBlock::new_split(span1, span2); + let program = { + let mut mast_forest = MastForest::new(); + + let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); + let basic_block_1_id = mast_forest.ensure_node(basic_block_1); + + let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); + let basic_block_2_id = mast_forest.ensure_node(basic_block_2); + + let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); + let split_id = mast_forest.ensure_node(split); + + mast_forest.set_entrypoint(split_id); + + Program::new(mast_forest).unwrap() + }; - let trace = build_trace_from_block(&program, &[1]); + let trace = build_trace_from_program(&program, &[1]); let alphas = rand_array::(); let aux_columns = trace.build_aux_trace(&alphas).unwrap(); let p1 = aux_columns.get_column(P1_COL_IDX); @@ -173,12 +202,27 @@ fn decoder_p1_split() { #[test] #[allow(clippy::needless_range_loop)] fn decoder_p1_loop_with_repeat() { - let span1 = CodeBlock::new_span(vec![Operation::Pad]); - let span2 = CodeBlock::new_span(vec![Operation::Drop]); - let body = CodeBlock::new_join([span1, span2]); - let program = CodeBlock::new_loop(body); + let program = { + let mut mast_forest = MastForest::new(); - let trace = build_trace_from_block(&program, &[0, 1, 1]); + let basic_block_1 = MastNode::new_basic_block(vec![Operation::Pad]); + let basic_block_1_id = mast_forest.ensure_node(basic_block_1); + + let basic_block_2 = MastNode::new_basic_block(vec![Operation::Drop]); + let basic_block_2_id = mast_forest.ensure_node(basic_block_2); + + let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); + let join_id = mast_forest.ensure_node(join); + + let loop_node = MastNode::new_loop(join_id, &mast_forest); + let loop_node_id = mast_forest.ensure_node(loop_node); + + mast_forest.set_entrypoint(loop_node_id); + + Program::new(mast_forest).unwrap() + }; + + let trace = build_trace_from_program(&program, &[0, 1, 1]); let alphas = rand_array::(); let aux_columns = trace.build_aux_trace(&alphas).unwrap(); let p1 = aux_columns.get_column(P1_COL_IDX); @@ -290,15 +334,24 @@ fn decoder_p1_loop_with_repeat() { #[test] #[allow(clippy::needless_range_loop)] fn decoder_p2_span_with_respan() { - let (ops, _) = build_span_with_respan_ops(); - let span = CodeBlock::new_span(ops); - let trace = build_trace_from_block(&span, &[]); + let program = { + let mut mast_forest = MastForest::new(); + + let (ops, _) = build_span_with_respan_ops(); + let basic_block = MastNode::new_basic_block(ops); + let basic_block_id = mast_forest.ensure_node(basic_block); + mast_forest.set_entrypoint(basic_block_id); + + Program::new(mast_forest).unwrap() + }; + let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); let aux_columns = trace.build_aux_trace(&alphas).unwrap(); let p2 = aux_columns.get_column(P2_COL_IDX); - let row_values = - [BlockHashTableRow::new_test(ZERO, span.hash().into(), false, false).collapse(&alphas)]; + let row_values = [ + BlockHashTableRow::new_test(ZERO, program.hash().into(), false, false).collapse(&alphas) + ]; // make sure the first entry is initialized to program hash let mut expected_value = row_values[0]; @@ -320,19 +373,31 @@ fn decoder_p2_span_with_respan() { #[test] #[allow(clippy::needless_range_loop)] fn decoder_p2_join() { - let span1 = CodeBlock::new_span(vec![Operation::Mul]); - let span2 = CodeBlock::new_span(vec![Operation::Add]); - let program = CodeBlock::new_join([span1.clone(), span2.clone()]); + let mut mast_forest = MastForest::new(); - let trace = build_trace_from_block(&program, &[]); + let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); + let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); + + let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); + let basic_block_2_id = mast_forest.ensure_node(basic_block_2.clone()); + + let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); + let join_id = mast_forest.ensure_node(join.clone()); + mast_forest.set_entrypoint(join_id); + + let program = Program::new(mast_forest).unwrap(); + + let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); let aux_columns = trace.build_aux_trace(&alphas).unwrap(); let p2 = aux_columns.get_column(P2_COL_IDX); let row_values = [ - BlockHashTableRow::new_test(ZERO, program.hash().into(), false, false).collapse(&alphas), - BlockHashTableRow::new_test(ONE, span1.hash().into(), true, false).collapse(&alphas), - BlockHashTableRow::new_test(ONE, span2.hash().into(), false, false).collapse(&alphas), + BlockHashTableRow::new_test(ZERO, join.digest().into(), false, false).collapse(&alphas), + BlockHashTableRow::new_test(ONE, basic_block_1.digest().into(), true, false) + .collapse(&alphas), + BlockHashTableRow::new_test(ONE, basic_block_2.digest().into(), false, false) + .collapse(&alphas), ]; // make sure the first entry is initialized to program hash @@ -373,18 +438,32 @@ fn decoder_p2_join() { #[test] #[allow(clippy::needless_range_loop)] fn decoder_p2_split_true() { - let span1 = CodeBlock::new_span(vec![Operation::Mul]); - let span2 = CodeBlock::new_span(vec![Operation::Add]); - let program = CodeBlock::new_split(span1.clone(), span2); + // build program + let mut mast_forest = MastForest::new(); + + let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); + let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); - let trace = build_trace_from_block(&program, &[1]); + let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); + let basic_block_2_id = mast_forest.ensure_node(basic_block_2); + + let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); + let split_id = mast_forest.ensure_node(split); + + mast_forest.set_entrypoint(split_id); + + let program = Program::new(mast_forest).unwrap(); + + // build trace from program + let trace = build_trace_from_program(&program, &[1]); let alphas = rand_array::(); let aux_columns = trace.build_aux_trace(&alphas).unwrap(); let p2 = aux_columns.get_column(P2_COL_IDX); let row_values = [ BlockHashTableRow::new_test(ZERO, program.hash().into(), false, false).collapse(&alphas), - BlockHashTableRow::new_test(ONE, span1.hash().into(), false, false).collapse(&alphas), + BlockHashTableRow::new_test(ONE, basic_block_1.digest().into(), false, false) + .collapse(&alphas), ]; // make sure the first entry is initialized to program hash @@ -417,18 +496,32 @@ fn decoder_p2_split_true() { #[test] #[allow(clippy::needless_range_loop)] fn decoder_p2_split_false() { - let span1 = CodeBlock::new_span(vec![Operation::Mul]); - let span2 = CodeBlock::new_span(vec![Operation::Add]); - let program = CodeBlock::new_split(span1, span2.clone()); + // build program + let mut mast_forest = MastForest::new(); + + let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); + let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); - let trace = build_trace_from_block(&program, &[0]); + let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); + let basic_block_2_id = mast_forest.ensure_node(basic_block_2.clone()); + + let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); + let split_id = mast_forest.ensure_node(split); + + mast_forest.set_entrypoint(split_id); + + let program = Program::new(mast_forest).unwrap(); + + // build trace from program + let trace = build_trace_from_program(&program, &[0]); let alphas = rand_array::(); let aux_columns = trace.build_aux_trace(&alphas).unwrap(); let p2 = aux_columns.get_column(P2_COL_IDX); let row_values = [ BlockHashTableRow::new_test(ZERO, program.hash().into(), false, false).collapse(&alphas), - BlockHashTableRow::new_test(ONE, span2.hash().into(), false, false).collapse(&alphas), + BlockHashTableRow::new_test(ONE, basic_block_2.digest().into(), false, false) + .collapse(&alphas), ]; // make sure the first entry is initialized to program hash @@ -461,12 +554,27 @@ fn decoder_p2_split_false() { #[test] #[allow(clippy::needless_range_loop)] fn decoder_p2_loop_with_repeat() { - let span1 = CodeBlock::new_span(vec![Operation::Pad]); - let span2 = CodeBlock::new_span(vec![Operation::Drop]); - let body = CodeBlock::new_join([span1.clone(), span2.clone()]); - let program = CodeBlock::new_loop(body.clone()); + // build program + let mut mast_forest = MastForest::new(); + + let basic_block_1 = MastNode::new_basic_block(vec![Operation::Pad]); + let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); + + let basic_block_2 = MastNode::new_basic_block(vec![Operation::Drop]); + let basic_block_2_id = mast_forest.ensure_node(basic_block_2.clone()); + + let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); + let join_id = mast_forest.ensure_node(join.clone()); + + let loop_node = MastNode::new_loop(join_id, &mast_forest); + let loop_node_id = mast_forest.ensure_node(loop_node); + + mast_forest.set_entrypoint(loop_node_id); + + let program = Program::new(mast_forest).unwrap(); - let trace = build_trace_from_block(&program, &[0, 1, 1]); + // build trace from program + let trace = build_trace_from_program(&program, &[0, 1, 1]); let alphas = rand_array::(); let aux_columns = trace.build_aux_trace(&alphas).unwrap(); let p2 = aux_columns.get_column(P2_COL_IDX); @@ -475,11 +583,15 @@ fn decoder_p2_loop_with_repeat() { let a_33 = Felt::new(33); // address of the JOIN block in the second iteration let row_values = [ BlockHashTableRow::new_test(ZERO, program.hash().into(), false, false).collapse(&alphas), - BlockHashTableRow::new_test(ONE, body.hash().into(), false, true).collapse(&alphas), - BlockHashTableRow::new_test(a_9, span1.hash().into(), true, false).collapse(&alphas), - BlockHashTableRow::new_test(a_9, span2.hash().into(), false, false).collapse(&alphas), - BlockHashTableRow::new_test(a_33, span1.hash().into(), true, false).collapse(&alphas), - BlockHashTableRow::new_test(a_33, span2.hash().into(), false, false).collapse(&alphas), + BlockHashTableRow::new_test(ONE, join.digest().into(), false, true).collapse(&alphas), + BlockHashTableRow::new_test(a_9, basic_block_1.digest().into(), true, false) + .collapse(&alphas), + BlockHashTableRow::new_test(a_9, basic_block_2.digest().into(), false, false) + .collapse(&alphas), + BlockHashTableRow::new_test(a_33, basic_block_1.digest().into(), true, false) + .collapse(&alphas), + BlockHashTableRow::new_test(a_33, basic_block_2.digest().into(), false, false) + .collapse(&alphas), ]; // make sure the first entry is initialized to program hash diff --git a/processor/src/trace/tests/mod.rs b/processor/src/trace/tests/mod.rs index 35e41e26d2..adc83f5245 100644 --- a/processor/src/trace/tests/mod.rs +++ b/processor/src/trace/tests/mod.rs @@ -6,7 +6,8 @@ use crate::{AdviceInputs, DefaultHost, ExecutionOptions, MemAdviceProvider, Stac use alloc::vec::Vec; use test_utils::rand::rand_array; use vm_core::{ - code_blocks::CodeBlock, CodeBlockTable, Kernel, Operation, StackOutputs, Word, ONE, ZERO, + mast::{MastForest, MastNode}, + Kernel, Operation, Program, StackOutputs, Word, ONE, ZERO, }; mod chiplets; @@ -19,20 +20,27 @@ mod stack; // ================================================================================================ /// Builds a sample trace by executing the provided code block against the provided stack inputs. -pub fn build_trace_from_block(program: &CodeBlock, stack_inputs: &[u64]) -> ExecutionTrace { +pub fn build_trace_from_program(program: &Program, stack_inputs: &[u64]) -> ExecutionTrace { let stack_inputs = StackInputs::try_from_ints(stack_inputs.iter().copied()).unwrap(); let host = DefaultHost::default(); let mut process = Process::new(Kernel::default(), stack_inputs, host, ExecutionOptions::default()); - process.execute_code_block(program, &CodeBlockTable::default()).unwrap(); + process.execute(program).unwrap(); ExecutionTrace::new(process, StackOutputs::default()) } /// Builds a sample trace by executing a span block containing the specified operations. This /// results in 1 additional hash cycle (8 rows) at the beginning of the hash chiplet. pub fn build_trace_from_ops(operations: Vec, stack: &[u64]) -> ExecutionTrace { - let program = CodeBlock::new_span(operations); - build_trace_from_block(&program, stack) + let mut mast_forest = MastForest::new(); + + let basic_block = MastNode::new_basic_block(operations); + let basic_block_id = mast_forest.ensure_node(basic_block); + mast_forest.set_entrypoint(basic_block_id); + + let program = Program::new(mast_forest).unwrap(); + + build_trace_from_program(&program, stack) } /// Builds a sample trace by executing a span block containing the specified operations. Unlike the @@ -47,7 +55,14 @@ pub fn build_trace_from_ops_with_inputs( let host = DefaultHost::new(advice_provider); let mut process = Process::new(Kernel::default(), stack_inputs, host, ExecutionOptions::default()); - let program = CodeBlock::new_span(operations); - process.execute_code_block(&program, &CodeBlockTable::default()).unwrap(); + + let mut mast_forest = MastForest::new(); + let basic_block = MastNode::new_basic_block(operations); + let basic_block_id = mast_forest.ensure_node(basic_block); + mast_forest.set_entrypoint(basic_block_id); + + let program = Program::new(mast_forest).unwrap(); + + process.execute(&program).unwrap(); ExecutionTrace::new(process, StackOutputs::default()) } diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 38cc68f054..9884753f71 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -16,7 +16,7 @@ use processor::{ RpxRandomCoin, WinterRandomCoin, }, math::{Felt, FieldElement}, - ExecutionTrace, + ExecutionTrace, Program, }; use tracing::instrument; use winter_prover::{ @@ -35,7 +35,7 @@ mod gpu; pub use air::{DeserializationError, ExecutionProof, FieldExtension, HashFunction, ProvingOptions}; pub use processor::{ crypto, math, utils, AdviceInputs, Digest, ExecutionError, Host, InputError, MemAdviceProvider, - Program, StackInputs, StackOutputs, Word, + StackInputs, StackOutputs, Word, }; pub use winter_prover::Proof; diff --git a/stdlib/tests/crypto/falcon.rs b/stdlib/tests/crypto/falcon.rs index 4e3693a504..8a81a97d7d 100644 --- a/stdlib/tests/crypto/falcon.rs +++ b/stdlib/tests/crypto/falcon.rs @@ -1,3 +1,4 @@ +use processor::{Program, ProgramInfo}; use rand::{thread_rng, Rng}; use assembly::{utils::Serializable, Assembler}; @@ -14,7 +15,7 @@ use test_utils::{ crypto::{rpo_falcon512::Polynomial, rpo_falcon512::SecretKey, MerkleStore}, expect_exec_error, rand::rand_vector, - FieldElement, ProgramInfo, QuadFelt, Word, WORD_SIZE, + FieldElement, QuadFelt, Word, WORD_SIZE, }; /// Modulus used for rpo falcon 512. @@ -172,7 +173,7 @@ fn test_falcon512_probabilistic_product_failure() { expect_exec_error!( test, ExecutionError::FailedAssertion { - clk: 17490, + clk: 31615, err_code: 0, err_msg: None, } @@ -198,11 +199,13 @@ fn falcon_prove_verify() { let message = rand_vector::(4).try_into().unwrap(); let (source, op_stack, _, _, advice_map) = generate_test(sk, message); - let program = Assembler::default() + let program: Program = Assembler::default() .with_library(&StdLibrary::default()) .expect("failed to load stdlib") .assemble(source) - .expect("failed to compile test source"); + .expect("failed to compile test source") + .try_into() + .expect("test source has no entrypoint"); let stack_inputs = StackInputs::try_from_ints(op_stack).expect("failed to create stack inputs"); let advice_inputs = AdviceInputs::default().with_map(advice_map); diff --git a/stdlib/tests/crypto/stark/mod.rs b/stdlib/tests/crypto/stark/mod.rs index 7a2d8f2dd6..8157770920 100644 --- a/stdlib/tests/crypto/stark/mod.rs +++ b/stdlib/tests/crypto/stark/mod.rs @@ -3,9 +3,9 @@ use verifier_recursive::{generate_advice_inputs, VerifierData}; use assembly::Assembler; use miden_air::{FieldExtension, HashFunction, PublicInputs}; -use processor::DefaultHost; +use processor::{DefaultHost, Program, ProgramInfo}; use test_utils::{ - prove, AdviceInputs, MemAdviceProvider, ProgramInfo, ProvingOptions, StackInputs, VerifierError, + prove, AdviceInputs, MemAdviceProvider, ProvingOptions, StackInputs, VerifierError, }; // Note: Changes to MidenVM may cause this test to fail when some of the assumptions documented @@ -51,7 +51,11 @@ pub fn generate_recursive_verifier_data( source: &str, stack_inputs: Vec, ) -> Result { - let program = Assembler::default().assemble(source).unwrap(); + let program: Program = Assembler::default() + .assemble(source) + .unwrap() + .try_into() + .expect("test source has no entrypoint"); let stack_inputs = StackInputs::try_from_ints(stack_inputs).unwrap(); let advice_inputs = AdviceInputs::default(); let advice_provider = MemAdviceProvider::from(advice_inputs); diff --git a/stdlib/tests/mem/mod.rs b/stdlib/tests/mem/mod.rs index b199db7042..c76677d01b 100644 --- a/stdlib/tests/mem/mod.rs +++ b/stdlib/tests/mem/mod.rs @@ -1,4 +1,4 @@ -use processor::{ContextId, DefaultHost, ProcessState}; +use processor::{ContextId, DefaultHost, ProcessState, Program}; use test_utils::{ build_expected_hash, build_expected_perm, stack_to_ints, ExecutionOptions, Process, StackInputs, ONE, ZERO, @@ -22,11 +22,15 @@ fn test_memcopy() { end "; - let mut assembler = assembly::Assembler::default() + let assembler = assembly::Assembler::default() .with_library(&StdLibrary::default()) .expect("failed to load stdlib"); - let program = assembler.assemble(source).expect("Failed to compile test source."); + let program: Program = assembler + .assemble(source) + .expect("Failed to compile test source.") + .try_into() + .expect("test source has no entrypoint."); let mut process = Process::new( program.kernel().clone(), diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 3b72a4fe44..95929813a7 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -8,6 +8,7 @@ extern crate std; // IMPORTS // ================================================================================================ +use processor::Program; #[cfg(not(target_family = "wasm"))] use proptest::prelude::{Arbitrary, Strategy}; @@ -17,7 +18,7 @@ use alloc::{ sync::Arc, vec::Vec, }; -use vm_core::chiplets::hasher::apply_permutation; +use vm_core::{chiplets::hasher::apply_permutation, ProgramInfo}; // EXPORTS // ================================================================================================ @@ -33,12 +34,12 @@ pub use processor::{ }; pub use prover::{prove, MemAdviceProvider, ProvingOptions}; pub use test_case::test_case; -pub use verifier::{verify, AcceptableOptions, ProgramInfo, VerifierError}; +pub use verifier::{verify, AcceptableOptions, VerifierError}; pub use vm_core::{ chiplets::hasher::{hash_elements, STATE_WIDTH}, stack::STACK_TOP_SIZE, utils::{collections, group_slice_elements, IntoBytes, ToElements}, - Felt, FieldElement, Program, StarkField, Word, EMPTY_WORD, ONE, WORD_SIZE, ZERO, + Felt, FieldElement, StarkField, Word, EMPTY_WORD, ONE, WORD_SIZE, ZERO, }; pub mod math { @@ -166,8 +167,8 @@ macro_rules! assert_assembler_diagnostic { /// - Proptest: run an execution test inside a proptest. /// /// Types of failure tests: -/// - Assembly error test: check that attempting to compile the given source causes an -/// AssemblyError which contains the specified substring. +/// - Assembly error test: check that attempting to compile the given source causes an AssemblyError +/// which contains the specified substring. /// - Execution error test: check that running a program compiled from the given source causes an /// ExecutionError which contains the specified substring. pub struct Test { @@ -224,7 +225,7 @@ impl Test { expected_mem: &[u64], ) { // compile the program - let program = self.compile().expect("Failed to compile test source."); + let program: Program = self.compile().expect("Failed to compile test source."); let host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); // execute the test @@ -281,7 +282,7 @@ impl Test { } else { assembly::Assembler::default() }; - let mut assembler = self + let assembler = self .add_modules .iter() .fold(assembler, |assembler, (path, source)| { @@ -296,14 +297,14 @@ impl Test { .with_libraries(self.libraries.iter()) .expect("failed to load stdlib"); - assembler.assemble(self.source.clone()) + assembler.assemble_program(self.source.clone()) } /// Compiles the test's source to a Program and executes it with the tests inputs. Returns a /// resulting execution trace or error. #[track_caller] pub fn execute(&self) -> Result { - let program = self.compile().expect("Failed to compile test source."); + let program: Program = self.compile().expect("Failed to compile test source."); let host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); processor::execute(&program, self.stack_inputs.clone(), host, ExecutionOptions::default()) } @@ -313,7 +314,7 @@ impl Test { pub fn execute_process( &self, ) -> Result>, ExecutionError> { - let program = self.compile().expect("Failed to compile test source."); + let program: Program = self.compile().expect("Failed to compile test source."); let host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); let mut process = Process::new( program.kernel().clone(), @@ -330,7 +331,7 @@ impl Test { /// is true, this function will force a failure by modifying the first output. pub fn prove_and_verify(&self, pub_inputs: Vec, test_fail: bool) { let stack_inputs = StackInputs::try_from_ints(pub_inputs).unwrap(); - let program = self.compile().expect("Failed to compile test source."); + let program: Program = self.compile().expect("Failed to compile test source."); let host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); let (mut stack_outputs, proof) = prover::prove(&program, stack_inputs.clone(), host, ProvingOptions::default()).unwrap(); @@ -349,7 +350,7 @@ impl Test { /// VmStateIterator that allows us to iterate through each clock cycle and inspect the process /// state. pub fn execute_iter(&self) -> VmStateIterator { - let program = self.compile().expect("Failed to compile test source."); + let program: Program = self.compile().expect("Failed to compile test source."); let host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); processor::execute_iter(&program, self.stack_inputs.clone(), host) }