diff --git a/CHANGELOG.md b/CHANGELOG.md index 6504a8e6f5..89c5cd9821 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ - Relaxed the parser to allow one branch of an `if.(true|false)` to be empty - Added support for immediate values for `u32and`, `u32or`, `u32xor` and `u32not` bitwise instructions (#1362). - Optimized `std::sys::truncate_stuck` procedure (#1384). +- Add serialization/deserialization for `MastForest` (#1370) #### Changed diff --git a/air/src/constraints/stack/op_flags/mod.rs b/air/src/constraints/stack/op_flags/mod.rs index e8ea0b51d7..0341537447 100644 --- a/air/src/constraints/stack/op_flags/mod.rs +++ b/air/src/constraints/stack/op_flags/mod.rs @@ -840,7 +840,7 @@ impl OpFlags { /// Operation Flag of U32ASSERT2 operation. #[inline(always)] pub fn u32assert2(&self) -> E { - self.degree6_op_flags[get_op_index(Operation::U32assert2(ZERO).op_code())] + self.degree6_op_flags[get_op_index(Operation::U32assert2(0).op_code())] } /// Operation Flag of U32ADD3 operation. diff --git a/assembly/src/assembler/instruction/mod.rs b/assembly/src/assembler/instruction/mod.rs index 531323225f..973be8241e 100644 --- a/assembly/src/assembler/instruction/mod.rs +++ b/assembly/src/assembler/instruction/mod.rs @@ -119,17 +119,17 @@ impl Assembler { // ----- u32 manipulation ------------------------------------------------------------- Instruction::U32Test => span_builder.push_ops([Dup0, U32split, Swap, Drop, Eqz]), Instruction::U32TestW => u32_ops::u32testw(span_builder), - Instruction::U32Assert => span_builder.push_ops([Pad, U32assert2(ZERO), Drop]), + Instruction::U32Assert => span_builder.push_ops([Pad, U32assert2(0), Drop]), Instruction::U32AssertWithError(err_code) => { - span_builder.push_ops([Pad, U32assert2(Felt::from(err_code.expect_value())), Drop]) + span_builder.push_ops([Pad, U32assert2(err_code.expect_value()), Drop]) } - Instruction::U32Assert2 => span_builder.push_op(U32assert2(ZERO)), + Instruction::U32Assert2 => span_builder.push_op(U32assert2(0)), Instruction::U32Assert2WithError(err_code) => { - span_builder.push_op(U32assert2(Felt::from(err_code.expect_value()))) + span_builder.push_op(U32assert2(err_code.expect_value())) } - Instruction::U32AssertW => u32_ops::u32assertw(span_builder, ZERO), + Instruction::U32AssertW => u32_ops::u32assertw(span_builder, 0), Instruction::U32AssertWWithError(err_code) => { - u32_ops::u32assertw(span_builder, Felt::from(err_code.expect_value())) + u32_ops::u32assertw(span_builder, err_code.expect_value()) } Instruction::U32Cast => span_builder.push_ops([U32split, Drop]), diff --git a/assembly/src/assembler/instruction/u32_ops.rs b/assembly/src/assembler/instruction/u32_ops.rs index e3e3f355e1..4a3223cfa3 100644 --- a/assembly/src/assembler/instruction/u32_ops.rs +++ b/assembly/src/assembler/instruction/u32_ops.rs @@ -6,7 +6,6 @@ use crate::{ use vm_core::{ AdviceInjector, Felt, Operation::{self, *}, - ZERO, }; /// This enum is intended to determine the mode of operation passed to the parsing function @@ -45,7 +44,7 @@ pub fn u32testw(span_builder: &mut BasicBlockBuilder) { /// /// Implemented by executing `U32ASSERT2` on each pair of elements in the word. /// Total of 6 VM cycles. -pub fn u32assertw(span_builder: &mut BasicBlockBuilder, err_code: Felt) { +pub fn u32assertw(span_builder: &mut BasicBlockBuilder, err_code: u32) { #[rustfmt::skip] let ops = [ // Test the first and the second elements @@ -171,7 +170,7 @@ pub fn u32not(span_builder: &mut BasicBlockBuilder) { let ops = [ // Perform the operation Push(Felt::from(u32::MAX)), - U32assert2(ZERO), + U32assert2(0), Swap, U32sub, diff --git a/core/Cargo.toml b/core/Cargo.toml index 25a02a3073..d7b97938a4 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -19,19 +19,14 @@ doctest = false [features] default = ["std"] -std = [ - "miden-crypto/std", - "miden-formatting/std", - "math/std", - "winter-utils/std", - "thiserror/std", -] +std = ["miden-crypto/std", "miden-formatting/std", "math/std", "winter-utils/std", "thiserror/std"] [dependencies] math = { package = "winter-math", version = "0.9", default-features = false } -#miden-crypto = { version = "0.9", default-features = false } miden-crypto = { git = "https://github.com/0xPolygonMiden/crypto", branch = "next", default-features = false } miden-formatting = { version = "0.1", default-features = false } +num-derive = { version = "0.4", default-features = false } +num-traits = { version = "0.2", default-features = false } thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } winter-utils = { package = "winter-utils", version = "0.9", default-features = false } diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index f3e605a149..70a87c58c1 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -6,8 +6,11 @@ use miden_crypto::hash::rpo::RpoDigest; mod node; pub use node::{ get_span_op_group_count, BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, - MastNode, OpBatch, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, + MastNode, OpBatch, OperationOrDecorator, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, }; +use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; + +mod serialization; #[cfg(test)] mod tests; @@ -18,6 +21,9 @@ pub trait MerkleTreeNode { fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a; } +// MAST NODE ID +// ================================================================================================ + /// An opaque handle to a [`MastNode`] in some [`MastForest`]. It is the responsibility of the user /// to use a given [`MastNodeId`] with the corresponding [`MastForest`]. /// @@ -27,14 +33,49 @@ pub trait MerkleTreeNode { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct MastNodeId(u32); +impl MastNodeId { + /// Returns a new `MastNodeId` with the provided inner value, or an error if the provided + /// `value` is greater than the number of nodes in the forest. + /// + /// For use in deserialization. + pub fn from_u32_safe( + value: u32, + mast_forest: &MastForest, + ) -> Result { + if (value as usize) < mast_forest.nodes.len() { + Ok(Self(value)) + } else { + Err(DeserializationError::InvalidValue(format!( + "Invalid deserialized MAST node ID '{}', but only {} nodes in the forest", + value, + mast_forest.nodes.len(), + ))) + } + } +} + impl fmt::Display for MastNodeId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "MastNodeId({})", self.0) } } +impl Serializable for MastNodeId { + fn write_into(&self, target: &mut W) { + self.0.write_into(target) + } +} + +impl Deserializable for MastNodeId { + fn read_from(source: &mut R) -> Result { + let inner = source.read_u32()?; + + Ok(Self(inner)) + } +} + // MAST FOREST -// =============================================================================================== +// ================================================================================================ /// Represents one or more procedures, represented as a collection of [`MastNode`]s. /// @@ -94,7 +135,7 @@ impl MastForest { /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if valid, or else /// `None`. /// - /// This is the faillible version of indexing (e.g. `mast_forest[node_id]`). + /// This is the failable version of indexing (e.g. `mast_forest[node_id]`). #[inline(always)] pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> { let idx = node_id.0 as usize; diff --git a/core/src/mast/node/basic_block_node.rs b/core/src/mast/node/basic_block_node/mod.rs similarity index 61% rename from core/src/mast/node/basic_block_node.rs rename to core/src/mast/node/basic_block_node/mod.rs index 49c92dd260..a8d07ab2de 100644 --- a/core/src/mast/node/basic_block_node.rs +++ b/core/src/mast/node/basic_block_node/mod.rs @@ -8,9 +8,12 @@ use winter_utils::flatten_slice_elements; use crate::{ chiplets::hasher, mast::{MastForest, MerkleTreeNode}, - DecoratorIterator, DecoratorList, Operation, + Decorator, DecoratorIterator, DecoratorList, Operation, }; +#[cfg(test)] +mod tests; + // CONSTANTS // ================================================================================================ @@ -107,10 +110,25 @@ impl BasicBlockNode { /// Public accessors impl BasicBlockNode { + pub fn num_operations_and_decorators(&self) -> u32 { + let num_ops: usize = self.op_batches.iter().map(|batch| batch.ops().len()).sum(); + let num_decorators = self.decorators.len(); + + (num_ops + num_decorators) + .try_into() + .expect("basic block contains more than 2^32 operations and decorators") + } + pub fn op_batches(&self) -> &[OpBatch] { &self.op_batches } + /// Returns an iterator over all operations and decorator, in the order in which they appear in + /// the program. + pub fn iter(&self) -> impl Iterator { + OperationOrDecoratorIterator::new(self) + } + /// Returns a [`DecoratorIterator`] which allows us to iterate through the decorator list of /// this basic block node while executing operation batches of this basic block node. pub fn decorator_iter(&self) -> DecoratorIterator { @@ -202,6 +220,77 @@ impl fmt::Display for BasicBlockNode { } } +// OPERATION OR DECORATOR +// ================================================================================================ + +/// Encodes either an [`Operation`] or a [`Decorator`]. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum OperationOrDecorator<'a> { + Operation(&'a Operation), + Decorator(&'a Decorator), +} + +struct OperationOrDecoratorIterator<'a> { + node: &'a BasicBlockNode, + + /// The index of the current batch + batch_index: usize, + + /// The index of the operation in the current batch + op_index_in_batch: usize, + + /// The index of the current operation across all batches + op_index: usize, + + /// The index of the next element in `node.decorator_list`. This list is assumed to be sorted. + decorator_list_next_index: usize, +} + +impl<'a> OperationOrDecoratorIterator<'a> { + fn new(node: &'a BasicBlockNode) -> Self { + Self { + node, + batch_index: 0, + op_index_in_batch: 0, + op_index: 0, + decorator_list_next_index: 0, + } + } +} + +impl<'a> Iterator for OperationOrDecoratorIterator<'a> { + type Item = OperationOrDecorator<'a>; + + fn next(&mut self) -> Option { + // check if there's a decorator to execute + if let Some((op_index, decorator)) = + self.node.decorators.get(self.decorator_list_next_index) + { + if *op_index == self.op_index { + self.decorator_list_next_index += 1; + return Some(OperationOrDecorator::Decorator(decorator)); + } + } + + // If no decorator needs to be executed, then execute the operation + if let Some(batch) = self.node.op_batches.get(self.batch_index) { + if let Some(operation) = batch.ops.get(self.op_index_in_batch) { + self.op_index_in_batch += 1; + self.op_index += 1; + + Some(OperationOrDecorator::Operation(operation)) + } else { + self.batch_index += 1; + self.op_index_in_batch = 0; + + self.next() + } + } else { + None + } + } +} + // OPERATION BATCH // ================================================================================================ @@ -426,316 +515,3 @@ pub fn get_span_op_group_count(op_batches: &[OpBatch]) -> usize { let last_batch_num_groups = op_batches.last().expect("no last group").num_groups(); (op_batches.len() - 1) * BATCH_SIZE + last_batch_num_groups.next_power_of_two() } - -// TESTS -// ================================================================================================ - -#[cfg(test)] -mod tests { - use super::{hasher, Felt, Operation, BATCH_SIZE, ZERO}; - use crate::ONE; - - #[test] - fn batch_ops() { - // --- one operation ---------------------------------------------------------------------- - let ops = vec![Operation::Add]; - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(1, batches.len()); - - let batch = &batches[0]; - assert_eq!(ops, batch.ops); - assert_eq!(1, batch.num_groups()); - - let mut batch_groups = [ZERO; BATCH_SIZE]; - batch_groups[0] = build_group(&ops); - - assert_eq!(batch_groups, batch.groups); - assert_eq!([1_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); - assert_eq!(hasher::hash_elements(&batch_groups), hash); - - // --- two operations --------------------------------------------------------------------- - let ops = vec![Operation::Add, Operation::Mul]; - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(1, batches.len()); - - let batch = &batches[0]; - assert_eq!(ops, batch.ops); - assert_eq!(1, batch.num_groups()); - - let mut batch_groups = [ZERO; BATCH_SIZE]; - batch_groups[0] = build_group(&ops); - - assert_eq!(batch_groups, batch.groups); - assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); - assert_eq!(hasher::hash_elements(&batch_groups), hash); - - // --- one group with one immediate value ------------------------------------------------- - let ops = vec![Operation::Add, Operation::Push(Felt::new(12345678))]; - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(1, batches.len()); - - let batch = &batches[0]; - assert_eq!(ops, batch.ops); - assert_eq!(2, batch.num_groups()); - - let mut batch_groups = [ZERO; BATCH_SIZE]; - batch_groups[0] = build_group(&ops); - batch_groups[1] = Felt::new(12345678); - - assert_eq!(batch_groups, batch.groups); - assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); - assert_eq!(hasher::hash_elements(&batch_groups), hash); - - // --- one group with 7 immediate values -------------------------------------------------- - let ops = vec![ - Operation::Push(ONE), - Operation::Push(Felt::new(2)), - Operation::Push(Felt::new(3)), - Operation::Push(Felt::new(4)), - Operation::Push(Felt::new(5)), - Operation::Push(Felt::new(6)), - Operation::Push(Felt::new(7)), - Operation::Add, - ]; - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(1, batches.len()); - - let batch = &batches[0]; - assert_eq!(ops, batch.ops); - assert_eq!(8, batch.num_groups()); - - let batch_groups = [ - build_group(&ops), - ONE, - Felt::new(2), - Felt::new(3), - Felt::new(4), - Felt::new(5), - Felt::new(6), - Felt::new(7), - ]; - - assert_eq!(batch_groups, batch.groups); - assert_eq!([8_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); - assert_eq!(hasher::hash_elements(&batch_groups), hash); - - // --- two groups with 7 immediate values; the last push overflows to the second batch ---- - let ops = vec![ - Operation::Add, - Operation::Mul, - Operation::Push(ONE), - Operation::Push(Felt::new(2)), - Operation::Push(Felt::new(3)), - Operation::Push(Felt::new(4)), - Operation::Push(Felt::new(5)), - Operation::Push(Felt::new(6)), - Operation::Add, - Operation::Push(Felt::new(7)), - ]; - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(2, batches.len()); - - let batch0 = &batches[0]; - assert_eq!(ops[..9], batch0.ops); - assert_eq!(7, batch0.num_groups()); - - let batch0_groups = [ - build_group(&ops[..9]), - ONE, - Felt::new(2), - Felt::new(3), - Felt::new(4), - Felt::new(5), - Felt::new(6), - ZERO, - ]; - - assert_eq!(batch0_groups, batch0.groups); - assert_eq!([9_usize, 0, 0, 0, 0, 0, 0, 0], batch0.op_counts); - - let batch1 = &batches[1]; - assert_eq!(vec![ops[9]], batch1.ops); - assert_eq!(2, batch1.num_groups()); - - let mut batch1_groups = [ZERO; BATCH_SIZE]; - batch1_groups[0] = build_group(&[ops[9]]); - batch1_groups[1] = Felt::new(7); - - assert_eq!([1_usize, 0, 0, 0, 0, 0, 0, 0], batch1.op_counts); - assert_eq!(batch1_groups, batch1.groups); - - let all_groups = [batch0_groups, batch1_groups].concat(); - assert_eq!(hasher::hash_elements(&all_groups), hash); - - // --- immediate values in-between groups ------------------------------------------------- - let ops = vec![ - Operation::Add, - Operation::Mul, - Operation::Add, - Operation::Push(Felt::new(7)), - Operation::Add, - Operation::Add, - Operation::Push(Felt::new(11)), - Operation::Mul, - Operation::Mul, - Operation::Add, - ]; - - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(1, batches.len()); - - let batch = &batches[0]; - assert_eq!(ops, batch.ops); - assert_eq!(4, batch.num_groups()); - - let batch_groups = [ - build_group(&ops[..9]), - Felt::new(7), - Felt::new(11), - build_group(&ops[9..]), - ZERO, - ZERO, - ZERO, - ZERO, - ]; - - assert_eq!([9_usize, 0, 0, 1, 0, 0, 0, 0], batch.op_counts); - assert_eq!(batch_groups, batch.groups); - assert_eq!(hasher::hash_elements(&batch_groups), hash); - - // --- push at the end of a group is moved into the next group ---------------------------- - let ops = vec![ - Operation::Add, - Operation::Mul, - Operation::Add, - Operation::Add, - Operation::Add, - Operation::Mul, - Operation::Mul, - Operation::Add, - Operation::Push(Felt::new(11)), - ]; - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(1, batches.len()); - - let batch = &batches[0]; - assert_eq!(ops, batch.ops); - assert_eq!(3, batch.num_groups()); - - let batch_groups = [ - build_group(&ops[..8]), - build_group(&[ops[8]]), - Felt::new(11), - ZERO, - ZERO, - ZERO, - ZERO, - ZERO, - ]; - - assert_eq!(batch_groups, batch.groups); - assert_eq!([8_usize, 1, 0, 0, 0, 0, 0, 0], batch.op_counts); - assert_eq!(hasher::hash_elements(&batch_groups), hash); - - // --- push at the end of a group is moved into the next group ---------------------------- - let ops = vec![ - Operation::Add, - Operation::Mul, - Operation::Add, - Operation::Add, - Operation::Add, - Operation::Mul, - Operation::Mul, - Operation::Push(ONE), - Operation::Push(Felt::new(2)), - ]; - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(1, batches.len()); - - let batch = &batches[0]; - assert_eq!(ops, batch.ops); - assert_eq!(4, batch.num_groups()); - - let batch_groups = [ - build_group(&ops[..8]), - ONE, - build_group(&[ops[8]]), - Felt::new(2), - ZERO, - ZERO, - ZERO, - ZERO, - ]; - - assert_eq!(batch_groups, batch.groups); - assert_eq!([8_usize, 0, 1, 0, 0, 0, 0, 0], batch.op_counts); - assert_eq!(hasher::hash_elements(&batch_groups), hash); - - // --- push at the end of the 7th group overflows to the next batch ----------------------- - let ops = vec![ - Operation::Add, - Operation::Mul, - Operation::Push(ONE), - Operation::Push(Felt::new(2)), - Operation::Push(Felt::new(3)), - Operation::Push(Felt::new(4)), - Operation::Push(Felt::new(5)), - Operation::Add, - Operation::Mul, - Operation::Add, - Operation::Mul, - Operation::Add, - Operation::Mul, - Operation::Add, - Operation::Mul, - Operation::Add, - Operation::Mul, - Operation::Push(Felt::new(6)), - Operation::Pad, - ]; - - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(2, batches.len()); - - let batch0 = &batches[0]; - assert_eq!(ops[..17], batch0.ops); - assert_eq!(7, batch0.num_groups()); - - let batch0_groups = [ - build_group(&ops[..9]), - ONE, - Felt::new(2), - Felt::new(3), - Felt::new(4), - Felt::new(5), - build_group(&ops[9..17]), - ZERO, - ]; - - assert_eq!(batch0_groups, batch0.groups); - assert_eq!([9_usize, 0, 0, 0, 0, 0, 8, 0], batch0.op_counts); - - let batch1 = &batches[1]; - assert_eq!(ops[17..], batch1.ops); - assert_eq!(2, batch1.num_groups()); - - let batch1_groups = - [build_group(&ops[17..]), Felt::new(6), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO]; - assert_eq!(batch1_groups, batch1.groups); - assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch1.op_counts); - - let all_groups = [batch0_groups, batch1_groups].concat(); - assert_eq!(hasher::hash_elements(&all_groups), hash); - } - - // TEST HELPERS - // -------------------------------------------------------------------------------------------- - - fn build_group(ops: &[Operation]) -> Felt { - let mut group = 0u64; - for (i, op) in ops.iter().enumerate() { - group |= (op.op_code() as u64) << (Operation::OP_BITS * i); - } - Felt::new(group) - } -} diff --git a/core/src/mast/node/basic_block_node/tests.rs b/core/src/mast/node/basic_block_node/tests.rs new file mode 100644 index 0000000000..49a663ae33 --- /dev/null +++ b/core/src/mast/node/basic_block_node/tests.rs @@ -0,0 +1,341 @@ +use super::*; +use crate::{Decorator, ONE}; + +#[test] +fn batch_ops() { + // --- one operation ---------------------------------------------------------------------- + let ops = vec![Operation::Add]; + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(1, batches.len()); + + let batch = &batches[0]; + assert_eq!(ops, batch.ops); + assert_eq!(1, batch.num_groups()); + + let mut batch_groups = [ZERO; BATCH_SIZE]; + batch_groups[0] = build_group(&ops); + + assert_eq!(batch_groups, batch.groups); + assert_eq!([1_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); + assert_eq!(hasher::hash_elements(&batch_groups), hash); + + // --- two operations --------------------------------------------------------------------- + let ops = vec![Operation::Add, Operation::Mul]; + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(1, batches.len()); + + let batch = &batches[0]; + assert_eq!(ops, batch.ops); + assert_eq!(1, batch.num_groups()); + + let mut batch_groups = [ZERO; BATCH_SIZE]; + batch_groups[0] = build_group(&ops); + + assert_eq!(batch_groups, batch.groups); + assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); + assert_eq!(hasher::hash_elements(&batch_groups), hash); + + // --- one group with one immediate value ------------------------------------------------- + let ops = vec![Operation::Add, Operation::Push(Felt::new(12345678))]; + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(1, batches.len()); + + let batch = &batches[0]; + assert_eq!(ops, batch.ops); + assert_eq!(2, batch.num_groups()); + + let mut batch_groups = [ZERO; BATCH_SIZE]; + batch_groups[0] = build_group(&ops); + batch_groups[1] = Felt::new(12345678); + + assert_eq!(batch_groups, batch.groups); + assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); + assert_eq!(hasher::hash_elements(&batch_groups), hash); + + // --- one group with 7 immediate values -------------------------------------------------- + let ops = vec![ + Operation::Push(ONE), + Operation::Push(Felt::new(2)), + Operation::Push(Felt::new(3)), + Operation::Push(Felt::new(4)), + Operation::Push(Felt::new(5)), + Operation::Push(Felt::new(6)), + Operation::Push(Felt::new(7)), + Operation::Add, + ]; + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(1, batches.len()); + + let batch = &batches[0]; + assert_eq!(ops, batch.ops); + assert_eq!(8, batch.num_groups()); + + let batch_groups = [ + build_group(&ops), + ONE, + Felt::new(2), + Felt::new(3), + Felt::new(4), + Felt::new(5), + Felt::new(6), + Felt::new(7), + ]; + + assert_eq!(batch_groups, batch.groups); + assert_eq!([8_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); + assert_eq!(hasher::hash_elements(&batch_groups), hash); + + // --- two groups with 7 immediate values; the last push overflows to the second batch ---- + let ops = vec![ + Operation::Add, + Operation::Mul, + Operation::Push(ONE), + Operation::Push(Felt::new(2)), + Operation::Push(Felt::new(3)), + Operation::Push(Felt::new(4)), + Operation::Push(Felt::new(5)), + Operation::Push(Felt::new(6)), + Operation::Add, + Operation::Push(Felt::new(7)), + ]; + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(2, batches.len()); + + let batch0 = &batches[0]; + assert_eq!(ops[..9], batch0.ops); + assert_eq!(7, batch0.num_groups()); + + let batch0_groups = [ + build_group(&ops[..9]), + ONE, + Felt::new(2), + Felt::new(3), + Felt::new(4), + Felt::new(5), + Felt::new(6), + ZERO, + ]; + + assert_eq!(batch0_groups, batch0.groups); + assert_eq!([9_usize, 0, 0, 0, 0, 0, 0, 0], batch0.op_counts); + + let batch1 = &batches[1]; + assert_eq!(vec![ops[9]], batch1.ops); + assert_eq!(2, batch1.num_groups()); + + let mut batch1_groups = [ZERO; BATCH_SIZE]; + batch1_groups[0] = build_group(&[ops[9]]); + batch1_groups[1] = Felt::new(7); + + assert_eq!([1_usize, 0, 0, 0, 0, 0, 0, 0], batch1.op_counts); + assert_eq!(batch1_groups, batch1.groups); + + let all_groups = [batch0_groups, batch1_groups].concat(); + assert_eq!(hasher::hash_elements(&all_groups), hash); + + // --- immediate values in-between groups ------------------------------------------------- + let ops = vec![ + Operation::Add, + Operation::Mul, + Operation::Add, + Operation::Push(Felt::new(7)), + Operation::Add, + Operation::Add, + Operation::Push(Felt::new(11)), + Operation::Mul, + Operation::Mul, + Operation::Add, + ]; + + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(1, batches.len()); + + let batch = &batches[0]; + assert_eq!(ops, batch.ops); + assert_eq!(4, batch.num_groups()); + + let batch_groups = [ + build_group(&ops[..9]), + Felt::new(7), + Felt::new(11), + build_group(&ops[9..]), + ZERO, + ZERO, + ZERO, + ZERO, + ]; + + assert_eq!([9_usize, 0, 0, 1, 0, 0, 0, 0], batch.op_counts); + assert_eq!(batch_groups, batch.groups); + assert_eq!(hasher::hash_elements(&batch_groups), hash); + + // --- push at the end of a group is moved into the next group ---------------------------- + let ops = vec![ + Operation::Add, + Operation::Mul, + Operation::Add, + Operation::Add, + Operation::Add, + Operation::Mul, + Operation::Mul, + Operation::Add, + Operation::Push(Felt::new(11)), + ]; + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(1, batches.len()); + + let batch = &batches[0]; + assert_eq!(ops, batch.ops); + assert_eq!(3, batch.num_groups()); + + let batch_groups = [ + build_group(&ops[..8]), + build_group(&[ops[8]]), + Felt::new(11), + ZERO, + ZERO, + ZERO, + ZERO, + ZERO, + ]; + + assert_eq!(batch_groups, batch.groups); + assert_eq!([8_usize, 1, 0, 0, 0, 0, 0, 0], batch.op_counts); + assert_eq!(hasher::hash_elements(&batch_groups), hash); + + // --- push at the end of a group is moved into the next group ---------------------------- + let ops = vec![ + Operation::Add, + Operation::Mul, + Operation::Add, + Operation::Add, + Operation::Add, + Operation::Mul, + Operation::Mul, + Operation::Push(ONE), + Operation::Push(Felt::new(2)), + ]; + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(1, batches.len()); + + let batch = &batches[0]; + assert_eq!(ops, batch.ops); + assert_eq!(4, batch.num_groups()); + + let batch_groups = [ + build_group(&ops[..8]), + ONE, + build_group(&[ops[8]]), + Felt::new(2), + ZERO, + ZERO, + ZERO, + ZERO, + ]; + + assert_eq!(batch_groups, batch.groups); + assert_eq!([8_usize, 0, 1, 0, 0, 0, 0, 0], batch.op_counts); + assert_eq!(hasher::hash_elements(&batch_groups), hash); + + // --- push at the end of the 7th group overflows to the next batch ----------------------- + let ops = vec![ + Operation::Add, + Operation::Mul, + Operation::Push(ONE), + Operation::Push(Felt::new(2)), + Operation::Push(Felt::new(3)), + Operation::Push(Felt::new(4)), + Operation::Push(Felt::new(5)), + Operation::Add, + Operation::Mul, + Operation::Add, + Operation::Mul, + Operation::Add, + Operation::Mul, + Operation::Add, + Operation::Mul, + Operation::Add, + Operation::Mul, + Operation::Push(Felt::new(6)), + Operation::Pad, + ]; + + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(2, batches.len()); + + let batch0 = &batches[0]; + assert_eq!(ops[..17], batch0.ops); + assert_eq!(7, batch0.num_groups()); + + let batch0_groups = [ + build_group(&ops[..9]), + ONE, + Felt::new(2), + Felt::new(3), + Felt::new(4), + Felt::new(5), + build_group(&ops[9..17]), + ZERO, + ]; + + assert_eq!(batch0_groups, batch0.groups); + assert_eq!([9_usize, 0, 0, 0, 0, 0, 8, 0], batch0.op_counts); + + let batch1 = &batches[1]; + assert_eq!(ops[17..], batch1.ops); + assert_eq!(2, batch1.num_groups()); + + let batch1_groups = [build_group(&ops[17..]), Felt::new(6), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO]; + assert_eq!(batch1_groups, batch1.groups); + assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch1.op_counts); + + let all_groups = [batch0_groups, batch1_groups].concat(); + assert_eq!(hasher::hash_elements(&all_groups), hash); +} + +#[test] +fn operation_or_decorator_iterator() { + let operations = vec![Operation::Add, Operation::Mul, Operation::MovDn2, Operation::MovDn3]; + + // Note: there are 2 decorators after the last instruction + let decorators = vec![ + (0, Decorator::Event(0)), + (0, Decorator::Event(1)), + (3, Decorator::Event(2)), + (4, Decorator::Event(3)), + (4, Decorator::Event(4)), + ]; + + let node = BasicBlockNode::with_decorators(operations, decorators); + + let mut iterator = node.iter(); + + // operation index 0 + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(0)))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(1)))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Operation(&Operation::Add))); + + // operations indices 1, 2 + assert_eq!(iterator.next(), Some(OperationOrDecorator::Operation(&Operation::Mul))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Operation(&Operation::MovDn2))); + + // operation index 3 + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(2)))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Operation(&Operation::MovDn3))); + + // after last operation + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(3)))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(4)))); + assert_eq!(iterator.next(), None); +} + +// TEST HELPERS +// -------------------------------------------------------------------------------------------- + +fn build_group(ops: &[Operation]) -> Felt { + let mut group = 0u64; + for (i, op) in ops.iter().enumerate() { + group |= (op.op_code() as u64) << (Operation::OP_BITS * i); + } + Felt::new(group) +} diff --git a/core/src/mast/node/join_node.rs b/core/src/mast/node/join_node.rs index f1a5417bf9..1cd5322c0f 100644 --- a/core/src/mast/node/join_node.rs +++ b/core/src/mast/node/join_node.rs @@ -35,14 +35,9 @@ impl JoinNode { Self { children, digest } } - pub(super) fn to_pretty_print<'a>( - &'a self, - mast_forest: &'a MastForest, - ) -> impl PrettyPrint + 'a { - JoinNodePrettyPrint { - join_node: self, - mast_forest, - } + #[cfg(test)] + pub fn new_test(children: [MastNodeId; 2], digest: RpoDigest) -> Self { + Self { children, digest } } } @@ -57,6 +52,18 @@ impl JoinNode { } } +impl JoinNode { + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { + JoinNodePrettyPrint { + join_node: self, + mast_forest, + } + } +} + impl MerkleTreeNode for JoinNode { fn digest(&self) -> RpoDigest { self.digest diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index 2bf0836cf3..31cb297309 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -3,8 +3,8 @@ use core::fmt; use alloc::{boxed::Box, vec::Vec}; pub use basic_block_node::{ - get_span_op_group_count, BasicBlockNode, OpBatch, BATCH_SIZE as OP_BATCH_SIZE, - GROUP_SIZE as OP_GROUP_SIZE, + get_span_op_group_count, BasicBlockNode, OpBatch, OperationOrDecorator, + BATCH_SIZE as OP_BATCH_SIZE, GROUP_SIZE as OP_GROUP_SIZE, }; mod call_node; @@ -88,10 +88,6 @@ impl MastNode { Self::Dyn } - pub fn new_dyncall(dyn_node_id: MastNodeId, mast_forest: &MastForest) -> Self { - Self::Call(CallNode::new(dyn_node_id, mast_forest)) - } - pub fn new_external(mast_root: RpoDigest) -> Self { Self::External(ExternalNode::new(mast_root)) } diff --git a/core/src/mast/node/split_node.rs b/core/src/mast/node/split_node.rs index 907820365f..600186a9e7 100644 --- a/core/src/mast/node/split_node.rs +++ b/core/src/mast/node/split_node.rs @@ -33,6 +33,11 @@ impl SplitNode { Self { branches, digest } } + + #[cfg(test)] + pub fn new_test(branches: [MastNodeId; 2], digest: RpoDigest) -> Self { + Self { branches, digest } + } } /// Public accessors diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs new file mode 100644 index 0000000000..78a045d6c4 --- /dev/null +++ b/core/src/mast/serialization/basic_block_data_builder.rs @@ -0,0 +1,181 @@ +use alloc::{collections::BTreeMap, vec::Vec}; +use miden_crypto::hash::blake::{Blake3Digest, Blake3_256}; +use winter_utils::{ByteWriter, Serializable}; + +use crate::{ + mast::{BasicBlockNode, OperationOrDecorator}, + AdviceInjector, DebugOptions, Decorator, SignatureKind, +}; + +use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex}; + +// BASIC BLOCK DATA BUILDER +// ================================================================================================ + +/// Builds the `data` section of a serialized [`crate::mast::MastForest`]. +#[derive(Debug, Default)] +pub struct BasicBlockDataBuilder { + data: Vec, + string_table_builder: StringTableBuilder, +} + +/// Constructors +impl BasicBlockDataBuilder { + pub fn new() -> Self { + Self::default() + } +} + +/// Accessors +impl BasicBlockDataBuilder { + /// Returns the current offset into the data buffer. + pub fn get_offset(&self) -> DataOffset { + self.data.len() as DataOffset + } +} + +/// Mutators +impl BasicBlockDataBuilder { + /// Encodes a [`BasicBlockNode`] into the serialized [`crate::mast::MastForest`] data field. + pub fn encode_basic_block(&mut self, basic_block: &BasicBlockNode) { + // 2nd part of `mast_node_to_info()` (inside the match) + for op_or_decorator in basic_block.iter() { + match op_or_decorator { + OperationOrDecorator::Operation(operation) => operation.write_into(&mut self.data), + OperationOrDecorator::Decorator(decorator) => self.encode_decorator(decorator), + } + } + } + + /// Returns the serialized [`crate::mast::MastForest`] data field, as well as the string table. + pub fn into_parts(mut self) -> (Vec, Vec) { + let string_table = self.string_table_builder.into_table(&mut self.data); + (self.data, string_table) + } +} + +/// Helpers +impl BasicBlockDataBuilder { + fn encode_decorator(&mut self, decorator: &Decorator) { + // Set the first byte to the decorator discriminant. + { + let decorator_variant: EncodedDecoratorVariant = decorator.into(); + self.data.push(decorator_variant.discriminant()); + } + + // For decorators that have extra data, encode it in `data` and `strings`. + match decorator { + Decorator::Advice(advice_injector) => match advice_injector { + AdviceInjector::MapValueToStack { + include_len, + key_offset, + } => { + self.data.write_bool(*include_len); + self.data.write_usize(*key_offset); + } + AdviceInjector::HdwordToMap { domain } => { + self.data.extend(domain.as_int().to_le_bytes()) + } + + // Note: Since there is only 1 variant, we don't need to write any extra bytes. + AdviceInjector::SigToStack { kind } => match kind { + SignatureKind::RpoFalcon512 => (), + }, + AdviceInjector::MerkleNodeMerge + | AdviceInjector::MerkleNodeToStack + | AdviceInjector::UpdateMerkleNode + | AdviceInjector::U64Div + | AdviceInjector::Ext2Inv + | AdviceInjector::Ext2Intt + | AdviceInjector::SmtGet + | AdviceInjector::SmtSet + | AdviceInjector::SmtPeek + | AdviceInjector::U32Clz + | AdviceInjector::U32Ctz + | AdviceInjector::U32Clo + | AdviceInjector::U32Cto + | AdviceInjector::ILog2 + | AdviceInjector::MemToMap + | AdviceInjector::HpermToMap => (), + }, + Decorator::AsmOp(assembly_op) => { + self.data.push(assembly_op.num_cycles()); + self.data.write_bool(assembly_op.should_break()); + + // context name + { + let str_index_in_table = + self.string_table_builder.add_string(assembly_op.context_name()); + self.data.write_usize(str_index_in_table); + } + + // op + { + let str_index_in_table = self.string_table_builder.add_string(assembly_op.op()); + self.data.write_usize(str_index_in_table); + } + } + Decorator::Debug(debug_options) => match debug_options { + DebugOptions::StackTop(value) => self.data.push(*value), + DebugOptions::MemInterval(start, end) => { + self.data.extend(start.to_le_bytes()); + self.data.extend(end.to_le_bytes()); + } + DebugOptions::LocalInterval(start, second, end) => { + self.data.extend(start.to_le_bytes()); + self.data.extend(second.to_le_bytes()); + self.data.extend(end.to_le_bytes()); + } + DebugOptions::StackAll | DebugOptions::MemAll => (), + }, + Decorator::Event(value) | Decorator::Trace(value) => { + self.data.extend(value.to_le_bytes()) + } + } + } +} + +// STRING TABLE BUILDER +// ================================================================================================ + +#[derive(Debug, Default)] +struct StringTableBuilder { + table: Vec, + str_to_index: BTreeMap, StringIndex>, + strings_data: Vec, +} + +impl StringTableBuilder { + pub fn add_string(&mut self, string: &str) -> StringIndex { + if let Some(str_idx) = self.str_to_index.get(&Blake3_256::hash(string.as_bytes())) { + // return already interned string + *str_idx + } else { + // add new string to table + // NOTE: these string refs' offset will need to be shifted again in `into_table()` + let str_offset = self + .strings_data + .len() + .try_into() + .expect("strings table larger than 2^32 bytes"); + + let str_idx = self.table.len(); + + string.write_into(&mut self.strings_data); + self.table.push(str_offset); + self.str_to_index.insert(Blake3_256::hash(string.as_bytes()), str_idx); + + str_idx + } + } + + pub fn into_table(self, data: &mut Vec) -> Vec { + let table_offset: u32 = data + .len() + .try_into() + .expect("MAST forest serialization: data field longer than 2^32 bytes"); + data.extend(self.strings_data); + + self.table.into_iter().map(|str_offset| str_offset + table_offset).collect() + } +} diff --git a/core/src/mast/serialization/basic_block_data_decoder.rs b/core/src/mast/serialization/basic_block_data_decoder.rs new file mode 100644 index 0000000000..78dd77215f --- /dev/null +++ b/core/src/mast/serialization/basic_block_data_decoder.rs @@ -0,0 +1,199 @@ +use crate::{ + AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorList, Operation, SignatureKind, +}; + +use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex}; +use alloc::{string::String, vec::Vec}; +use miden_crypto::Felt; +use winter_utils::{ByteReader, Deserializable, DeserializationError, SliceReader}; + +pub struct BasicBlockDataDecoder<'a> { + data: &'a [u8], + strings: &'a [DataOffset], +} + +/// Constructors +impl<'a> BasicBlockDataDecoder<'a> { + pub fn new(data: &'a [u8], strings: &'a [DataOffset]) -> Self { + Self { data, strings } + } +} + +/// Mutators +impl<'a> BasicBlockDataDecoder<'a> { + pub fn decode_operations_and_decorators( + &self, + offset: DataOffset, + num_to_decode: u32, + ) -> Result<(Vec, DecoratorList), DeserializationError> { + let mut operations: Vec = Vec::new(); + let mut decorators: DecoratorList = Vec::new(); + + let mut data_reader = SliceReader::new(&self.data[offset as usize..]); + for _ in 0..num_to_decode { + let first_byte = data_reader.peek_u8()?; + + if first_byte & 0b1000_0000 == 0 { + // operation. + operations.push(Operation::read_from(&mut data_reader)?); + } else { + // decorator. + let decorator = self.decode_decorator(&mut data_reader)?; + decorators.push((operations.len(), decorator)); + } + } + + Ok((operations, decorators)) + } +} + +/// Helpers +impl<'a> BasicBlockDataDecoder<'a> { + fn decode_decorator( + &self, + data_reader: &mut SliceReader, + ) -> Result { + let discriminant = data_reader.read_u8()?; + + let decorator_variant = EncodedDecoratorVariant::from_discriminant(discriminant) + .ok_or_else(|| { + DeserializationError::InvalidValue(format!( + "invalid decorator variant discriminant: {discriminant}" + )) + })?; + + match decorator_variant { + EncodedDecoratorVariant::AdviceInjectorMerkleNodeMerge => { + Ok(Decorator::Advice(AdviceInjector::MerkleNodeMerge)) + } + EncodedDecoratorVariant::AdviceInjectorMerkleNodeToStack => { + Ok(Decorator::Advice(AdviceInjector::MerkleNodeToStack)) + } + EncodedDecoratorVariant::AdviceInjectorUpdateMerkleNode => { + Ok(Decorator::Advice(AdviceInjector::UpdateMerkleNode)) + } + EncodedDecoratorVariant::AdviceInjectorMapValueToStack => { + let include_len = data_reader.read_bool()?; + let key_offset = data_reader.read_usize()?; + + Ok(Decorator::Advice(AdviceInjector::MapValueToStack { + include_len, + key_offset, + })) + } + EncodedDecoratorVariant::AdviceInjectorU64Div => { + Ok(Decorator::Advice(AdviceInjector::U64Div)) + } + EncodedDecoratorVariant::AdviceInjectorExt2Inv => { + Ok(Decorator::Advice(AdviceInjector::Ext2Inv)) + } + EncodedDecoratorVariant::AdviceInjectorExt2Intt => { + Ok(Decorator::Advice(AdviceInjector::Ext2Intt)) + } + EncodedDecoratorVariant::AdviceInjectorSmtGet => { + Ok(Decorator::Advice(AdviceInjector::SmtGet)) + } + EncodedDecoratorVariant::AdviceInjectorSmtSet => { + Ok(Decorator::Advice(AdviceInjector::SmtSet)) + } + EncodedDecoratorVariant::AdviceInjectorSmtPeek => { + Ok(Decorator::Advice(AdviceInjector::SmtPeek)) + } + EncodedDecoratorVariant::AdviceInjectorU32Clz => { + Ok(Decorator::Advice(AdviceInjector::U32Clz)) + } + EncodedDecoratorVariant::AdviceInjectorU32Ctz => { + Ok(Decorator::Advice(AdviceInjector::U32Ctz)) + } + EncodedDecoratorVariant::AdviceInjectorU32Clo => { + Ok(Decorator::Advice(AdviceInjector::U32Clo)) + } + EncodedDecoratorVariant::AdviceInjectorU32Cto => { + Ok(Decorator::Advice(AdviceInjector::U32Cto)) + } + EncodedDecoratorVariant::AdviceInjectorILog2 => { + Ok(Decorator::Advice(AdviceInjector::ILog2)) + } + EncodedDecoratorVariant::AdviceInjectorMemToMap => { + Ok(Decorator::Advice(AdviceInjector::MemToMap)) + } + EncodedDecoratorVariant::AdviceInjectorHdwordToMap => { + let domain = data_reader.read_u64()?; + let domain = Felt::try_from(domain).map_err(|err| { + DeserializationError::InvalidValue(format!( + "Error when deserializing HdwordToMap decorator domain: {err}" + )) + })?; + + Ok(Decorator::Advice(AdviceInjector::HdwordToMap { domain })) + } + EncodedDecoratorVariant::AdviceInjectorHpermToMap => { + Ok(Decorator::Advice(AdviceInjector::HpermToMap)) + } + EncodedDecoratorVariant::AdviceInjectorSigToStack => { + Ok(Decorator::Advice(AdviceInjector::SigToStack { + kind: SignatureKind::RpoFalcon512, + })) + } + EncodedDecoratorVariant::AssemblyOp => { + let num_cycles = data_reader.read_u8()?; + let should_break = data_reader.read_bool()?; + + let context_name = { + let str_index_in_table = data_reader.read_usize()?; + self.read_string(str_index_in_table)? + }; + + let op = { + let str_index_in_table = data_reader.read_usize()?; + self.read_string(str_index_in_table)? + }; + + Ok(Decorator::AsmOp(AssemblyOp::new(context_name, num_cycles, op, should_break))) + } + EncodedDecoratorVariant::DebugOptionsStackAll => { + Ok(Decorator::Debug(DebugOptions::StackAll)) + } + EncodedDecoratorVariant::DebugOptionsStackTop => { + let value = data_reader.read_u8()?; + + Ok(Decorator::Debug(DebugOptions::StackTop(value))) + } + EncodedDecoratorVariant::DebugOptionsMemAll => { + Ok(Decorator::Debug(DebugOptions::MemAll)) + } + EncodedDecoratorVariant::DebugOptionsMemInterval => { + let start = data_reader.read_u32()?; + let end = data_reader.read_u32()?; + + Ok(Decorator::Debug(DebugOptions::MemInterval(start, end))) + } + EncodedDecoratorVariant::DebugOptionsLocalInterval => { + let start = data_reader.read_u16()?; + let second = data_reader.read_u16()?; + let end = data_reader.read_u16()?; + + Ok(Decorator::Debug(DebugOptions::LocalInterval(start, second, end))) + } + EncodedDecoratorVariant::Event => { + let value = data_reader.read_u32()?; + + Ok(Decorator::Event(value)) + } + EncodedDecoratorVariant::Trace => { + let value = data_reader.read_u32()?; + + Ok(Decorator::Trace(value)) + } + } + } + + fn read_string(&self, str_idx: StringIndex) -> Result { + let str_offset = self.strings.get(str_idx).copied().ok_or_else(|| { + DeserializationError::InvalidValue(format!("invalid index in strings table: {str_idx}")) + })? as usize; + + let mut reader = SliceReader::new(&self.data[str_offset..]); + reader.read() + } +} diff --git a/core/src/mast/serialization/decorator.rs b/core/src/mast/serialization/decorator.rs new file mode 100644 index 0000000000..c1d9b2f0f1 --- /dev/null +++ b/core/src/mast/serialization/decorator.rs @@ -0,0 +1,98 @@ +use num_derive::{FromPrimitive, ToPrimitive}; +use num_traits::{FromPrimitive, ToPrimitive}; + +use crate::{AdviceInjector, DebugOptions, Decorator}; + +/// Stores all the possible [`Decorator`] variants, without any associated data. +/// +/// This is effectively equivalent to a set of constants, and designed to convert between variant +/// discriminant and enum variant conveniently. +#[derive(FromPrimitive, ToPrimitive)] +#[repr(u8)] +pub enum EncodedDecoratorVariant { + AdviceInjectorMerkleNodeMerge, + AdviceInjectorMerkleNodeToStack, + AdviceInjectorUpdateMerkleNode, + AdviceInjectorMapValueToStack, + AdviceInjectorU64Div, + AdviceInjectorExt2Inv, + AdviceInjectorExt2Intt, + AdviceInjectorSmtGet, + AdviceInjectorSmtSet, + AdviceInjectorSmtPeek, + AdviceInjectorU32Clz, + AdviceInjectorU32Ctz, + AdviceInjectorU32Clo, + AdviceInjectorU32Cto, + AdviceInjectorILog2, + AdviceInjectorMemToMap, + AdviceInjectorHdwordToMap, + AdviceInjectorHpermToMap, + AdviceInjectorSigToStack, + AssemblyOp, + DebugOptionsStackAll, + DebugOptionsStackTop, + DebugOptionsMemAll, + DebugOptionsMemInterval, + DebugOptionsLocalInterval, + Event, + Trace, +} + +impl EncodedDecoratorVariant { + /// Returns the discriminant of the given decorator variant. + /// + /// To distinguish them from [`crate::Operation`] discriminants, the most significant bit of + /// decorator discriminant is always set to 1. + pub fn discriminant(&self) -> u8 { + let discriminant = self.to_u8().expect("guaranteed to fit in a `u8` due to #[repr(u8)]"); + + discriminant | 0b1000_0000 + } + + /// The inverse operation of [`Self::discriminant`]. + pub fn from_discriminant(discriminant: u8) -> Option { + Self::from_u8(discriminant & 0b0111_1111) + } +} + +impl From<&Decorator> for EncodedDecoratorVariant { + fn from(decorator: &Decorator) -> Self { + match decorator { + Decorator::Advice(advice_injector) => match advice_injector { + AdviceInjector::MerkleNodeMerge => Self::AdviceInjectorMerkleNodeMerge, + AdviceInjector::MerkleNodeToStack => Self::AdviceInjectorMerkleNodeToStack, + AdviceInjector::UpdateMerkleNode => Self::AdviceInjectorUpdateMerkleNode, + AdviceInjector::MapValueToStack { + include_len: _, + key_offset: _, + } => Self::AdviceInjectorMapValueToStack, + AdviceInjector::U64Div => Self::AdviceInjectorU64Div, + AdviceInjector::Ext2Inv => Self::AdviceInjectorExt2Inv, + AdviceInjector::Ext2Intt => Self::AdviceInjectorExt2Intt, + AdviceInjector::SmtGet => Self::AdviceInjectorSmtGet, + AdviceInjector::SmtSet => Self::AdviceInjectorSmtSet, + AdviceInjector::SmtPeek => Self::AdviceInjectorSmtPeek, + AdviceInjector::U32Clz => Self::AdviceInjectorU32Clz, + AdviceInjector::U32Ctz => Self::AdviceInjectorU32Ctz, + AdviceInjector::U32Clo => Self::AdviceInjectorU32Clo, + AdviceInjector::U32Cto => Self::AdviceInjectorU32Cto, + AdviceInjector::ILog2 => Self::AdviceInjectorILog2, + AdviceInjector::MemToMap => Self::AdviceInjectorMemToMap, + AdviceInjector::HdwordToMap { domain: _ } => Self::AdviceInjectorHdwordToMap, + AdviceInjector::HpermToMap => Self::AdviceInjectorHpermToMap, + AdviceInjector::SigToStack { kind: _ } => Self::AdviceInjectorSigToStack, + }, + Decorator::AsmOp(_) => Self::AssemblyOp, + Decorator::Debug(debug_options) => match debug_options { + DebugOptions::StackAll => Self::DebugOptionsStackAll, + DebugOptions::StackTop(_) => Self::DebugOptionsStackTop, + DebugOptions::MemAll => Self::DebugOptionsMemAll, + DebugOptions::MemInterval(_, _) => Self::DebugOptionsMemInterval, + DebugOptions::LocalInterval(_, _, _) => Self::DebugOptionsLocalInterval, + }, + Decorator::Event(_) => Self::Event, + Decorator::Trace(_) => Self::Trace, + } + } +} diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs new file mode 100644 index 0000000000..b72e1701e0 --- /dev/null +++ b/core/src/mast/serialization/info.rs @@ -0,0 +1,404 @@ +use miden_crypto::hash::rpo::RpoDigest; +use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; + +use crate::mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}; + +use super::{basic_block_data_decoder::BasicBlockDataDecoder, DataOffset}; + +// MAST NODE INFO +// ================================================================================================ + +/// Represents a serialized [`MastNode`], with some data inlined in its [`MastNodeType`]. +/// +/// The serialized representation of [`MastNodeInfo`] is guaranteed to be fixed width, so that the +/// nodes stored in the `nodes` table of the serialized [`MastForest`] can be accessed quickly by +/// index. +#[derive(Debug)] +pub struct MastNodeInfo { + ty: MastNodeType, + digest: RpoDigest, +} + +impl MastNodeInfo { + pub fn new(mast_node: &MastNode, basic_block_offset: DataOffset) -> Self { + let ty = MastNodeType::new(mast_node, basic_block_offset); + + Self { + ty, + digest: mast_node.digest(), + } + } + + pub fn try_into_mast_node( + self, + mast_forest: &MastForest, + basic_block_data_decoder: &BasicBlockDataDecoder, + ) -> Result { + let mast_node = match self.ty { + MastNodeType::Block { + offset, + len: num_operations_and_decorators, + } => { + let (operations, decorators) = basic_block_data_decoder + .decode_operations_and_decorators(offset, num_operations_and_decorators)?; + + Ok(MastNode::new_basic_block_with_decorators(operations, decorators)) + } + MastNodeType::Join { + left_child_id, + right_child_id, + } => { + let left_child = MastNodeId::from_u32_safe(left_child_id, mast_forest)?; + let right_child = MastNodeId::from_u32_safe(right_child_id, mast_forest)?; + + Ok(MastNode::new_join(left_child, right_child, mast_forest)) + } + MastNodeType::Split { + if_branch_id, + else_branch_id, + } => { + let if_branch = MastNodeId::from_u32_safe(if_branch_id, mast_forest)?; + let else_branch = MastNodeId::from_u32_safe(else_branch_id, mast_forest)?; + + Ok(MastNode::new_split(if_branch, else_branch, mast_forest)) + } + MastNodeType::Loop { body_id } => { + let body_id = MastNodeId::from_u32_safe(body_id, mast_forest)?; + + Ok(MastNode::new_loop(body_id, mast_forest)) + } + MastNodeType::Call { callee_id } => { + let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; + + Ok(MastNode::new_call(callee_id, mast_forest)) + } + MastNodeType::SysCall { callee_id } => { + let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; + + Ok(MastNode::new_syscall(callee_id, mast_forest)) + } + MastNodeType::Dyn => Ok(MastNode::new_dynexec()), + MastNodeType::External => Ok(MastNode::new_external(self.digest)), + }?; + + if mast_node.digest() == self.digest { + Ok(mast_node) + } else { + Err(DeserializationError::InvalidValue(format!( + "MastNodeInfo's digest '{}' doesn't match deserialized MastNode's digest '{}'", + self.digest, + mast_node.digest() + ))) + } + } +} + +impl Serializable for MastNodeInfo { + fn write_into(&self, target: &mut W) { + let Self { ty, digest } = self; + + ty.write_into(target); + digest.write_into(target); + } +} + +impl Deserializable for MastNodeInfo { + fn read_from(source: &mut R) -> Result { + let ty = Deserializable::read_from(source)?; + let digest = RpoDigest::read_from(source)?; + + Ok(Self { ty, digest }) + } +} + +// MAST NODE TYPE +// ================================================================================================ + +const JOIN: u8 = 0; +const SPLIT: u8 = 1; +const LOOP: u8 = 2; +const BLOCK: u8 = 3; +const CALL: u8 = 4; +const SYSCALL: u8 = 5; +const DYN: u8 = 6; +const EXTERNAL: u8 = 7; + +/// Represents the variant of a [`MastNode`], as well as any additional data. For example, for more +/// efficient decoding, and because of the frequency with which these node types appear, we directly +/// represent the child indices for `Join`, `Split`, and `Loop`, `Call` and `SysCall` inline. +/// +/// The serialized representation of the MAST node type is guaranteed to be 8 bytes, so that +/// [`MastNodeInfo`] (which contains it) can be of fixed width. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum MastNodeType { + Join { + left_child_id: u32, + right_child_id: u32, + } = JOIN, + Split { + if_branch_id: u32, + else_branch_id: u32, + } = SPLIT, + Loop { + body_id: u32, + } = LOOP, + Block { + /// Offset of the basic block in the data segment + offset: u32, + /// The number of operations and decorators in the basic block + len: u32, + } = BLOCK, + Call { + callee_id: u32, + } = CALL, + SysCall { + callee_id: u32, + } = SYSCALL, + Dyn = DYN, + External = EXTERNAL, +} + +/// Constructors +impl MastNodeType { + /// Constructs a new [`MastNodeType`] from a [`MastNode`]. + pub fn new(mast_node: &MastNode, basic_block_offset: u32) -> Self { + use MastNode::*; + + match mast_node { + Block(block_node) => { + let len = block_node.num_operations_and_decorators(); + + Self::Block { + len, + offset: basic_block_offset, + } + } + Join(join_node) => Self::Join { + left_child_id: join_node.first().0, + right_child_id: join_node.second().0, + }, + Split(split_node) => Self::Split { + if_branch_id: split_node.on_true().0, + else_branch_id: split_node.on_false().0, + }, + Loop(loop_node) => Self::Loop { + body_id: loop_node.body().0, + }, + Call(call_node) => { + if call_node.is_syscall() { + Self::SysCall { + callee_id: call_node.callee().0, + } + } else { + Self::Call { + callee_id: call_node.callee().0, + } + } + } + Dyn => Self::Dyn, + External(_) => Self::External, + } + } +} + +impl Serializable for MastNodeType { + fn write_into(&self, target: &mut W) { + let discriminant = self.discriminant() as u64; + assert!(discriminant <= 0b1111); + + let payload = match *self { + MastNodeType::Join { + left_child_id: left, + right_child_id: right, + } => Self::encode_u32_pair(left, right), + MastNodeType::Split { + if_branch_id: if_branch, + else_branch_id: else_branch, + } => Self::encode_u32_pair(if_branch, else_branch), + MastNodeType::Loop { body_id: body } => Self::encode_u32_payload(body), + MastNodeType::Block { offset, len } => Self::encode_u32_pair(offset, len), + MastNodeType::Call { callee_id } => Self::encode_u32_payload(callee_id), + MastNodeType::SysCall { callee_id } => Self::encode_u32_payload(callee_id), + MastNodeType::Dyn => 0, + MastNodeType::External => 0, + }; + + let value = (discriminant << 60) | payload; + target.write_u64(value); + } +} + +/// Serialization helpers +impl MastNodeType { + fn discriminant(&self) -> u8 { + // SAFETY: This is safe because we have given this enum a primitive representation with + // #[repr(u8)], with the first field of the underlying union-of-structs the discriminant. + // + // See the section on "accessing the numeric value of the discriminant" + // here: https://doc.rust-lang.org/std/mem/fn.discriminant.html + unsafe { *<*const _>::from(self).cast::() } + } + + /// Encodes two u32 numbers in the first 60 bits of a `u64`. + /// + /// # Panics + /// - Panics if either `left_value` or `right_value` doesn't fit in 30 bits. + fn encode_u32_pair(left_value: u32, right_value: u32) -> u64 { + assert!( + left_value.leading_zeros() >= 2, + "MastNodeType::encode_u32_pair: left value doesn't fit in 30 bits: {}", + left_value + ); + assert!( + right_value.leading_zeros() >= 2, + "MastNodeType::encode_u32_pair: right value doesn't fit in 30 bits: {}", + right_value + ); + + ((left_value as u64) << 30) | (right_value as u64) + } + + fn encode_u32_payload(payload: u32) -> u64 { + payload as u64 + } +} + +impl Deserializable for MastNodeType { + fn read_from(source: &mut R) -> Result { + let (discriminant, payload) = { + let value = source.read_u64()?; + + // 4 bits + let discriminant = (value >> 60) as u8; + // 60 bits + let payload = value & 0x0F_FF_FF_FF_FF_FF_FF_FF; + + (discriminant, payload) + }; + + match discriminant { + JOIN => { + let (left_child_id, right_child_id) = Self::decode_u32_pair(payload); + Ok(Self::Join { + left_child_id, + right_child_id, + }) + } + SPLIT => { + let (if_branch_id, else_branch_id) = Self::decode_u32_pair(payload); + Ok(Self::Split { + if_branch_id, + else_branch_id, + }) + } + LOOP => { + let body_id = Self::decode_u32_payload(payload)?; + Ok(Self::Loop { body_id }) + } + BLOCK => { + let (offset, len) = Self::decode_u32_pair(payload); + Ok(Self::Block { offset, len }) + } + CALL => { + let callee_id = Self::decode_u32_payload(payload)?; + Ok(Self::Call { callee_id }) + } + SYSCALL => { + let callee_id = Self::decode_u32_payload(payload)?; + Ok(Self::SysCall { callee_id }) + } + DYN => Ok(Self::Dyn), + EXTERNAL => Ok(Self::External), + _ => Err(DeserializationError::InvalidValue(format!( + "Invalid tag for MAST node: {discriminant}" + ))), + } + } +} + +/// Deserialization helpers +impl MastNodeType { + /// Decodes two `u32` numbers from a 60-bit payload. + fn decode_u32_pair(payload: u64) -> (u32, u32) { + let left_value = (payload >> 30) as u32; + let right_value = (payload & 0x3F_FF_FF_FF) as u32; + + (left_value, right_value) + } + + /// Decodes one `u32` number from a 60-bit payload. + /// + /// Returns an error if the payload doesn't fit in a `u32`. + pub fn decode_u32_payload(payload: u64) -> Result { + payload.try_into().map_err(|_| { + DeserializationError::InvalidValue(format!( + "Invalid payload: expected to fit in u32, but was {payload}" + )) + }) + } +} + +#[cfg(test)] +mod tests { + use alloc::vec::Vec; + + use super::*; + + #[test] + fn serialize_deserialize_60_bit_payload() { + // each child needs 30 bits + let mast_node_type = MastNodeType::Join { + left_child_id: 0x3F_FF_FF_FF, + right_child_id: 0x3F_FF_FF_FF, + }; + + let serialized = mast_node_type.to_bytes(); + let deserialized = MastNodeType::read_from_bytes(&serialized).unwrap(); + + assert_eq!(mast_node_type, deserialized); + } + + #[test] + #[should_panic] + fn serialize_large_payloads_fails_1() { + // left child needs 31 bits + let mast_node_type = MastNodeType::Join { + left_child_id: 0x4F_FF_FF_FF, + right_child_id: 0x0, + }; + + // must panic + let _serialized = mast_node_type.to_bytes(); + } + + #[test] + #[should_panic] + fn serialize_large_payloads_fails_2() { + // right child needs 31 bits + let mast_node_type = MastNodeType::Join { + left_child_id: 0x0, + right_child_id: 0x4F_FF_FF_FF, + }; + + // must panic + let _serialized = mast_node_type.to_bytes(); + } + + #[test] + fn deserialize_large_payloads_fails() { + // Serialized `CALL` with a 33-bit payload + let serialized = { + let serialized_value = ((CALL as u64) << 60) | (u32::MAX as u64 + 1_u64); + + let mut serialized_buffer: Vec = Vec::new(); + serialized_value.write_into(&mut serialized_buffer); + + serialized_buffer + }; + + let deserialized_result = MastNodeType::read_from_bytes(&serialized); + + assert_matches!(deserialized_result, Err(DeserializationError::InvalidValue(_))); + } +} diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs new file mode 100644 index 0000000000..a71c79c89a --- /dev/null +++ b/core/src/mast/serialization/mod.rs @@ -0,0 +1,135 @@ +use alloc::vec::Vec; +use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; + +use super::{MastForest, MastNode, MastNodeId}; + +mod decorator; + +mod info; +use info::MastNodeInfo; + +mod basic_block_data_builder; +use basic_block_data_builder::BasicBlockDataBuilder; + +mod basic_block_data_decoder; +use basic_block_data_decoder::BasicBlockDataDecoder; + +#[cfg(test)] +mod tests; + +// TYPE ALIASES +// ================================================================================================ + +/// Specifies an offset into the `data` section of an encoded [`MastForest`]. +type DataOffset = u32; + +/// Specifies an offset into the `strings` table of an encoded [`MastForest`] +type StringIndex = usize; + +// CONSTANTS +// ================================================================================================ + +/// Magic string for detecting that a file is binary-encoded MAST. +const MAGIC: &[u8; 5] = b"MAST\0"; + +/// The format version. +/// +/// If future modifications are made to this format, the version should be incremented by 1. A +/// version of `[255, 255, 255]` is reserved for future extensions that require extending the +/// version field itself, but should be considered invalid for now. +const VERSION: [u8; 3] = [0, 0, 0]; + +// MAST FOREST SERIALIZATION/DESERIALIZATION +// ================================================================================================ + +impl Serializable for MastForest { + fn write_into(&self, target: &mut W) { + let mut basic_block_data_builder = BasicBlockDataBuilder::new(); + + // magic & version + target.write_bytes(MAGIC); + target.write_bytes(&VERSION); + + // node count + target.write_usize(self.nodes.len()); + + // roots + self.roots.write_into(target); + + // Prepare MAST node infos, but don't store them yet. We store them at the end to make + // deserialization more efficient. + let mast_node_infos: Vec = self + .nodes + .iter() + .map(|mast_node| { + let mast_node_info = + MastNodeInfo::new(mast_node, basic_block_data_builder.get_offset()); + + if let MastNode::Block(basic_block) = mast_node { + basic_block_data_builder.encode_basic_block(basic_block); + } + + mast_node_info + }) + .collect(); + + let (data, string_table) = basic_block_data_builder.into_parts(); + + string_table.write_into(target); + data.write_into(target); + + for mast_node_info in mast_node_infos { + mast_node_info.write_into(target); + } + } +} + +impl Deserializable for MastForest { + fn read_from(source: &mut R) -> Result { + let magic: [u8; 5] = source.read_array()?; + if magic != *MAGIC { + return Err(DeserializationError::InvalidValue(format!( + "Invalid magic bytes. Expected '{:?}', got '{:?}'", + *MAGIC, magic + ))); + } + + let version: [u8; 3] = source.read_array()?; + if version != VERSION { + return Err(DeserializationError::InvalidValue(format!( + "Unsupported version. Got '{version:?}', but only '{VERSION:?}' is supported", + ))); + } + + let node_count = source.read_usize()?; + + let roots: Vec = Deserializable::read_from(source)?; + + let strings: Vec = Deserializable::read_from(source)?; + + let data: Vec = Deserializable::read_from(source)?; + + let basic_block_data_decoder = BasicBlockDataDecoder::new(&data, &strings); + + let mast_forest = { + let mut mast_forest = MastForest::new(); + + for _ in 0..node_count { + let mast_node_info = MastNodeInfo::read_from(source)?; + + let node = + mast_node_info.try_into_mast_node(&mast_forest, &basic_block_data_decoder)?; + + mast_forest.add_node(node); + } + + for root in roots { + mast_forest.make_root(root); + } + + mast_forest + }; + + Ok(mast_forest) + } +} diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs new file mode 100644 index 0000000000..4f0c56bcb8 --- /dev/null +++ b/core/src/mast/serialization/tests.rs @@ -0,0 +1,345 @@ +use alloc::string::ToString; +use miden_crypto::{hash::rpo::RpoDigest, Felt}; + +use super::*; +use crate::{ + operations::Operation, AdviceInjector, AssemblyOp, DebugOptions, Decorator, SignatureKind, +}; + +/// If this test fails to compile, it means that `Operation` or `Decorator` was changed. Make sure +/// that all tests in this file are updated accordingly. For example, if a new `Operation` variant +/// was added, make sure that you add it in the vector of operations in +/// [`serialize_deserialize_all_nodes`]. +#[test] +fn confirm_operation_and_decorator_structure() { + match Operation::Noop { + Operation::Noop => (), + Operation::Assert(_) => (), + Operation::FmpAdd => (), + Operation::FmpUpdate => (), + Operation::SDepth => (), + Operation::Caller => (), + Operation::Clk => (), + Operation::Join => (), + Operation::Split => (), + Operation::Loop => (), + Operation::Call => (), + Operation::Dyn => (), + Operation::SysCall => (), + Operation::Span => (), + Operation::End => (), + Operation::Repeat => (), + Operation::Respan => (), + Operation::Halt => (), + Operation::Add => (), + Operation::Neg => (), + Operation::Mul => (), + Operation::Inv => (), + Operation::Incr => (), + Operation::And => (), + Operation::Or => (), + Operation::Not => (), + Operation::Eq => (), + Operation::Eqz => (), + Operation::Expacc => (), + Operation::Ext2Mul => (), + Operation::U32split => (), + Operation::U32add => (), + Operation::U32assert2(_) => (), + Operation::U32add3 => (), + Operation::U32sub => (), + Operation::U32mul => (), + Operation::U32madd => (), + Operation::U32div => (), + Operation::U32and => (), + Operation::U32xor => (), + Operation::Pad => (), + Operation::Drop => (), + Operation::Dup0 => (), + Operation::Dup1 => (), + Operation::Dup2 => (), + Operation::Dup3 => (), + Operation::Dup4 => (), + Operation::Dup5 => (), + Operation::Dup6 => (), + Operation::Dup7 => (), + Operation::Dup9 => (), + Operation::Dup11 => (), + Operation::Dup13 => (), + Operation::Dup15 => (), + Operation::Swap => (), + Operation::SwapW => (), + Operation::SwapW2 => (), + Operation::SwapW3 => (), + Operation::SwapDW => (), + Operation::MovUp2 => (), + Operation::MovUp3 => (), + Operation::MovUp4 => (), + Operation::MovUp5 => (), + Operation::MovUp6 => (), + Operation::MovUp7 => (), + Operation::MovUp8 => (), + Operation::MovDn2 => (), + Operation::MovDn3 => (), + Operation::MovDn4 => (), + Operation::MovDn5 => (), + Operation::MovDn6 => (), + Operation::MovDn7 => (), + Operation::MovDn8 => (), + Operation::CSwap => (), + Operation::CSwapW => (), + Operation::Push(_) => (), + Operation::AdvPop => (), + Operation::AdvPopW => (), + Operation::MLoadW => (), + Operation::MStoreW => (), + Operation::MLoad => (), + Operation::MStore => (), + Operation::MStream => (), + Operation::Pipe => (), + Operation::HPerm => (), + Operation::MpVerify(_) => (), + Operation::MrUpdate => (), + Operation::FriE2F4 => (), + Operation::RCombBase => (), + }; + + match Decorator::Event(0) { + Decorator::Advice(advice) => match advice { + AdviceInjector::MerkleNodeMerge => (), + AdviceInjector::MerkleNodeToStack => (), + AdviceInjector::UpdateMerkleNode => (), + AdviceInjector::MapValueToStack { + include_len: _, + key_offset: _, + } => (), + AdviceInjector::U64Div => (), + AdviceInjector::Ext2Inv => (), + AdviceInjector::Ext2Intt => (), + AdviceInjector::SmtGet => (), + AdviceInjector::SmtSet => (), + AdviceInjector::SmtPeek => (), + AdviceInjector::U32Clz => (), + AdviceInjector::U32Ctz => (), + AdviceInjector::U32Clo => (), + AdviceInjector::U32Cto => (), + AdviceInjector::ILog2 => (), + AdviceInjector::MemToMap => (), + AdviceInjector::HdwordToMap { domain: _ } => (), + AdviceInjector::HpermToMap => (), + AdviceInjector::SigToStack { kind: _ } => (), + }, + Decorator::AsmOp(_) => (), + Decorator::Debug(debug_options) => match debug_options { + DebugOptions::StackAll => (), + DebugOptions::StackTop(_) => (), + DebugOptions::MemAll => (), + DebugOptions::MemInterval(_, _) => (), + DebugOptions::LocalInterval(_, _, _) => (), + }, + Decorator::Event(_) => (), + Decorator::Trace(_) => (), + }; +} + +#[test] +fn serialize_deserialize_all_nodes() { + let mut mast_forest = MastForest::new(); + + let basic_block_id = { + let operations = vec![ + Operation::Noop, + Operation::Assert(42), + Operation::FmpAdd, + Operation::FmpUpdate, + Operation::SDepth, + Operation::Caller, + Operation::Clk, + Operation::Join, + Operation::Split, + Operation::Loop, + Operation::Call, + Operation::Dyn, + Operation::SysCall, + Operation::Span, + Operation::End, + Operation::Repeat, + Operation::Respan, + Operation::Halt, + Operation::Add, + Operation::Neg, + Operation::Mul, + Operation::Inv, + Operation::Incr, + Operation::And, + Operation::Or, + Operation::Not, + Operation::Eq, + Operation::Eqz, + Operation::Expacc, + Operation::Ext2Mul, + Operation::U32split, + Operation::U32add, + Operation::U32assert2(222), + Operation::U32add3, + Operation::U32sub, + Operation::U32mul, + Operation::U32madd, + Operation::U32div, + Operation::U32and, + Operation::U32xor, + Operation::Pad, + Operation::Drop, + Operation::Dup0, + Operation::Dup1, + Operation::Dup2, + Operation::Dup3, + Operation::Dup4, + Operation::Dup5, + Operation::Dup6, + Operation::Dup7, + Operation::Dup9, + Operation::Dup11, + Operation::Dup13, + Operation::Dup15, + Operation::Swap, + Operation::SwapW, + Operation::SwapW2, + Operation::SwapW3, + Operation::SwapDW, + Operation::MovUp2, + Operation::MovUp3, + Operation::MovUp4, + Operation::MovUp5, + Operation::MovUp6, + Operation::MovUp7, + Operation::MovUp8, + Operation::MovDn2, + Operation::MovDn3, + Operation::MovDn4, + Operation::MovDn5, + Operation::MovDn6, + Operation::MovDn7, + Operation::MovDn8, + Operation::CSwap, + Operation::CSwapW, + Operation::Push(Felt::new(45)), + Operation::AdvPop, + Operation::AdvPopW, + Operation::MLoadW, + Operation::MStoreW, + Operation::MLoad, + Operation::MStore, + Operation::MStream, + Operation::Pipe, + Operation::HPerm, + Operation::MpVerify(1022), + Operation::MrUpdate, + Operation::FriE2F4, + Operation::RCombBase, + ]; + + let num_operations = operations.len(); + + let decorators = vec![ + (0, Decorator::Advice(AdviceInjector::MerkleNodeMerge)), + (0, Decorator::Advice(AdviceInjector::MerkleNodeToStack)), + (0, Decorator::Advice(AdviceInjector::UpdateMerkleNode)), + ( + 0, + Decorator::Advice(AdviceInjector::MapValueToStack { + include_len: true, + key_offset: 1023, + }), + ), + (1, Decorator::Advice(AdviceInjector::U64Div)), + (3, Decorator::Advice(AdviceInjector::Ext2Inv)), + (5, Decorator::Advice(AdviceInjector::Ext2Intt)), + (5, Decorator::Advice(AdviceInjector::SmtGet)), + (5, Decorator::Advice(AdviceInjector::SmtSet)), + (5, Decorator::Advice(AdviceInjector::SmtPeek)), + (5, Decorator::Advice(AdviceInjector::U32Clz)), + (10, Decorator::Advice(AdviceInjector::U32Ctz)), + (10, Decorator::Advice(AdviceInjector::U32Clo)), + (10, Decorator::Advice(AdviceInjector::U32Cto)), + (10, Decorator::Advice(AdviceInjector::ILog2)), + (10, Decorator::Advice(AdviceInjector::MemToMap)), + ( + 10, + Decorator::Advice(AdviceInjector::HdwordToMap { + domain: Felt::new(423), + }), + ), + (15, Decorator::Advice(AdviceInjector::HpermToMap)), + ( + 15, + Decorator::Advice(AdviceInjector::SigToStack { + kind: SignatureKind::RpoFalcon512, + }), + ), + ( + 15, + Decorator::AsmOp(AssemblyOp::new( + "context".to_string(), + 15, + "op".to_string(), + false, + )), + ), + (15, Decorator::Debug(DebugOptions::StackAll)), + (15, Decorator::Debug(DebugOptions::StackTop(255))), + (15, Decorator::Debug(DebugOptions::MemAll)), + (15, Decorator::Debug(DebugOptions::MemInterval(0, 16))), + (17, Decorator::Debug(DebugOptions::LocalInterval(1, 2, 3))), + (num_operations, Decorator::Event(45)), + (num_operations, Decorator::Trace(55)), + ]; + + let basic_block_node = MastNode::new_basic_block_with_decorators(operations, decorators); + mast_forest.add_node(basic_block_node) + }; + + let call_node_id = { + let node = MastNode::new_call(basic_block_id, &mast_forest); + mast_forest.add_node(node) + }; + + let syscall_node_id = { + let node = MastNode::new_syscall(basic_block_id, &mast_forest); + mast_forest.add_node(node) + }; + + let loop_node_id = { + let node = MastNode::new_loop(basic_block_id, &mast_forest); + mast_forest.add_node(node) + }; + let join_node_id = { + let node = MastNode::new_join(basic_block_id, call_node_id, &mast_forest); + mast_forest.add_node(node) + }; + let split_node_id = { + let node = MastNode::new_split(basic_block_id, call_node_id, &mast_forest); + mast_forest.add_node(node) + }; + let dyn_node_id = { + let node = MastNode::new_dynexec(); + mast_forest.add_node(node) + }; + + let external_node_id = { + let node = MastNode::new_external(RpoDigest::default()); + mast_forest.add_node(node) + }; + + mast_forest.make_root(join_node_id); + mast_forest.make_root(syscall_node_id); + mast_forest.make_root(loop_node_id); + mast_forest.make_root(split_node_id); + mast_forest.make_root(dyn_node_id); + mast_forest.make_root(external_node_id); + + let serialized_mast_forest = mast_forest.to_bytes(); + let deserialized_mast_forest = MastForest::read_from_bytes(&serialized_mast_forest).unwrap(); + + assert_eq!(mast_forest, deserialized_mast_forest); +} diff --git a/core/src/operations/decorators/mod.rs b/core/src/operations/decorators/mod.rs index b049cdba86..d183ada7aa 100644 --- a/core/src/operations/decorators/mod.rs +++ b/core/src/operations/decorators/mod.rs @@ -32,7 +32,7 @@ pub enum Decorator { Debug(DebugOptions), /// Emits an event to the host. Event(u32), - /// Emmits a trace to the host. + /// Emits a trace to the host. Trace(u32), } diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index 8103d4d2e8..72fed3bae4 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs @@ -276,7 +276,7 @@ pub enum Operation { /// /// The internal value specifies an error code associated with the error in case when the /// assertion fails. - U32assert2(Felt) = OPCODE_U32ASSERT2, + U32assert2(u32) = OPCODE_U32ASSERT2, /// Pops three elements off the stack, adds them together, and splits the result into upper /// and lower 32-bit values. Then pushes the result back onto the stack. @@ -735,16 +735,16 @@ impl Serializable for Operation { // For operations that have extra data, encode it in `data`. match self { - Operation::Assert(err_code) | Operation::MpVerify(err_code) => { - err_code.to_le_bytes().write_into(target) + Operation::Assert(err_code) + | Operation::MpVerify(err_code) + | Operation::U32assert2(err_code) => { + err_code.to_le_bytes().write_into(target); } - Operation::U32assert2(err_code) => err_code.as_int().write_into(target), Operation::Push(value) => value.as_int().write_into(target), // Note: we explicitly write out all the operations so that whenever we make a - // modification to the `Operation` enum, we get a compile error here. This should help - // us remember to properly encode/decode each operation variant. Remember to also fix - // deserialization! + // modification to the `Operation` enum, we get a compile error here. This + // should help us remember to properly encode/decode each operation variant. Operation::Noop | Operation::FmpAdd | Operation::FmpUpdate @@ -873,8 +873,7 @@ impl Deserializable for Operation { OPCODE_SWAPDW => Self::SwapDW, OPCODE_ASSERT => { - let err_code_le_bytes: [u8; 4] = source.read_array()?; - let err_code = u32::from_le_bytes(err_code_le_bytes); + let err_code = source.read_u32()?; Self::Assert(err_code) } OPCODE_EQ => Self::Eq, @@ -916,23 +915,16 @@ impl Deserializable for Operation { OPCODE_U32DIV => Self::U32div, OPCODE_U32SPLIT => Self::U32split, OPCODE_U32ASSERT2 => { - let err_code_le_bytes: [u8; 8] = source.read_array()?; - let err_code_u64 = u64::from_le_bytes(err_code_le_bytes); - let err_code_felt = Felt::try_from(err_code_u64).map_err(|_| { - DeserializationError::InvalidValue(format!( - "Operation associated data doesn't fit in a field element: {err_code_u64}" - )) - })?; + let err_code = source.read_u32()?; - Self::U32assert2(err_code_felt) + Self::U32assert2(err_code) } OPCODE_U32ADD3 => Self::U32add3, OPCODE_U32MADD => Self::U32madd, OPCODE_HPERM => Self::HPerm, OPCODE_MPVERIFY => { - let err_code_le_bytes: [u8; 4] = source.read_array()?; - let err_code = u32::from_le_bytes(err_code_le_bytes); + let err_code = source.read_u32()?; Self::MpVerify(err_code) } @@ -947,8 +939,7 @@ impl Deserializable for Operation { OPCODE_MRUPDATE => Self::MrUpdate, OPCODE_PUSH => { - let value_le_bytes: [u8; 8] = source.read_array()?; - let value_u64 = u64::from_le_bytes(value_le_bytes); + let value_u64 = source.read_u64()?; let value_felt = Felt::try_from(value_u64).map_err(|_| { DeserializationError::InvalidValue(format!( "Operation associated data doesn't fit in a field element: {value_u64}" diff --git a/processor/src/operations/u32_ops.rs b/processor/src/operations/u32_ops.rs index 0b4f3a0198..d525ab4a9a 100644 --- a/processor/src/operations/u32_ops.rs +++ b/processor/src/operations/u32_ops.rs @@ -28,15 +28,15 @@ where /// Pops top two element off the stack, splits them into low and high 32-bit values, checks if /// the high values are equal to 0; if they are, puts the original elements back onto the /// stack; if they are not, returns an error. - pub(super) fn op_u32assert2(&mut self, err_code: Felt) -> Result<(), ExecutionError> { + pub(super) fn op_u32assert2(&mut self, err_code: u32) -> Result<(), ExecutionError> { let a = self.stack.get(0); let b = self.stack.get(1); if a.as_int() >> 32 != 0 { - return Err(ExecutionError::NotU32Value(a, err_code)); + return Err(ExecutionError::NotU32Value(a, Felt::from(err_code))); } if b.as_int() >> 32 != 0 { - return Err(ExecutionError::NotU32Value(b, err_code)); + return Err(ExecutionError::NotU32Value(b, Felt::from(err_code))); } self.add_range_checks(Operation::U32assert2(err_code), a, b, false); @@ -280,7 +280,7 @@ mod tests { let stack = StackInputs::try_from_ints([d as u64, c as u64, b as u64, a as u64]).unwrap(); let mut process = Process::new_dummy_with_decoder_helpers(stack); - process.execute_op(Operation::U32assert2(ZERO)).unwrap(); + process.execute_op(Operation::U32assert2(0)).unwrap(); let expected = build_expected(&[a, b, c, d]); assert_eq!(expected, process.stack.trace_state()); }