Skip to content

Commit

Permalink
feat: add unsafe constructors to nodes for deserialization (#1453)
Browse files Browse the repository at this point in the history
  • Loading branch information
sergerad authored Aug 18, 2024
1 parent 60ccf23 commit fd32a69
Show file tree
Hide file tree
Showing 17 changed files with 188 additions and 129 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@

## 0.11.0 (TBD)

#### Enhancements

- Updated `MastForest::read_from` to deserialize without computing node hashes unnecessarily (#1453).

#### Changes

- Added `new_unsafe()` constructors to MAST node types which do not compute node hashes (#1453).
- Consolidated `BasicBlockNode` constructors and converted assert flow to `MastForestError::EmptyBasicBlock` (#1453).

## 0.10.3 (2024-08-12)

#### Enhancements
Expand Down
8 changes: 2 additions & 6 deletions assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,8 @@ impl MastForestBuilder {
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> Result<MastNodeId, AssemblyError> {
match decorators {
Some(decorators) => {
self.ensure_node(MastNode::new_basic_block_with_decorators(operations, decorators))
},
None => self.ensure_node(MastNode::new_basic_block(operations)),
}
let block = MastNode::new_basic_block(operations, decorators)?;
self.ensure_node(block)
}

/// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it.
Expand Down
10 changes: 4 additions & 6 deletions core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,8 @@ impl MastForest {
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> Result<MastNodeId, MastForestError> {
match decorators {
Some(decorators) => {
self.add_node(MastNode::new_basic_block_with_decorators(operations, decorators))
},
None => self.add_node(MastNode::new_basic_block(operations)),
}
let block = MastNode::new_basic_block(operations, decorators)?;
self.add_node(block)
}

/// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it.
Expand Down Expand Up @@ -271,4 +267,6 @@ pub enum MastForestError {
TooManyNodes,
#[error("node id: {0} is greater than or equal to forest length: {1}")]
NodeIdOverflow(MastNodeId, usize),
#[error("basic block cannot be created from an empty list of operations")]
EmptyBasicBlock,
}
79 changes: 47 additions & 32 deletions core/src/mast/node/basic_block_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt, ZERO};
use miden_formatting::prettier::PrettyPrint;
use winter_utils::flatten_slice_elements;

use crate::{chiplets::hasher, Decorator, DecoratorIterator, DecoratorList, Operation};
use crate::{
chiplets::hasher, mast::MastForestError, Decorator, DecoratorIterator, DecoratorList, Operation,
};

mod op_batch;
pub use op_batch::OpBatch;
Expand Down Expand Up @@ -77,31 +79,38 @@ impl BasicBlockNode {
// ------------------------------------------------------------------------------------------------
/// Constructors
impl BasicBlockNode {
/// Returns a new [`BasicBlockNode`] instantiated with the specified operations.
///
/// # Errors (TODO)
/// Returns an error if:
/// - `operations` vector is empty.
/// - `operations` vector contains any number of system operations.
pub fn new(operations: Vec<Operation>) -> Self {
assert!(!operations.is_empty()); // TODO: return error
Self::with_decorators(operations, DecoratorList::new())
}

/// Returns a new [`BasicBlockNode`] instantiated with the specified operations and decorators.
///
/// # Errors (TODO)
/// Returns an error if:
/// - `operations` vector is empty.
/// - `operations` vector contains any number of system operations.
pub fn with_decorators(operations: Vec<Operation>, decorators: DecoratorList) -> Self {
assert!(!operations.is_empty()); // TODO: return error
pub fn new(
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> Result<Self, MastForestError> {
if operations.is_empty() {
return Err(MastForestError::EmptyBasicBlock);
}

// validate decorators list (only in debug mode)
// None is equivalent to an empty list of decorators moving forward.
let decorators = decorators.unwrap_or_default();

// Validate decorators list (only in debug mode).
#[cfg(debug_assertions)]
validate_decorators(&operations, &decorators);

let (op_batches, digest) = batch_ops(operations);
let (op_batches, digest) = batch_and_hash_ops(operations);
Ok(Self { op_batches, digest, decorators })
}

/// Returns a new [`BasicBlockNode`] from values that are assumed to be correct.
/// Should only be used when the source of the inputs is trusted (e.g. deserialization).
pub fn new_unsafe(
operations: Vec<Operation>,
decorators: DecoratorList,
digest: RpoDigest,
) -> Self {
assert!(!operations.is_empty());
let (op_batches, _) = batch_ops(operations);
Self { op_batches, digest, decorators }
}
}
Expand Down Expand Up @@ -292,18 +301,29 @@ impl<'a> Iterator for OperationOrDecoratorIterator<'a> {
// HELPER FUNCTIONS
// ================================================================================================

/// Groups the provided operations into batches and computes the hash of the block.
fn batch_and_hash_ops(ops: Vec<Operation>) -> (Vec<OpBatch>, RpoDigest) {
// Group the operations into batches.
let (batches, batch_groups) = batch_ops(ops);

// Compute the hash of all operation groups.
let op_groups = &flatten_slice_elements(&batch_groups);
let hash = hasher::hash_elements(op_groups);

(batches, hash)
}

/// Groups the provided operations into batches as described in the docs for this module (i.e.,
/// up to 9 operations per group, and 8 groups per batch).
///
/// After the operations have been grouped, computes the hash of the block.
fn batch_ops(ops: Vec<Operation>) -> (Vec<OpBatch>, RpoDigest) {
let mut batch_acc = OpBatchAccumulator::new();
/// Returns a list of operation batches and a list of operation groups.
fn batch_ops(ops: Vec<Operation>) -> (Vec<OpBatch>, Vec<[Felt; BATCH_SIZE]>) {
let mut batches = Vec::<OpBatch>::new();
let mut batch_acc = OpBatchAccumulator::new();
let mut batch_groups = Vec::<[Felt; BATCH_SIZE]>::new();

for op in ops {
// if the operation cannot be accepted into the current accumulator, add the contents of
// the accumulator to the list of batches and start a new accumulator
// If the operation cannot be accepted into the current accumulator, add the contents of
// the accumulator to the list of batches and start a new accumulator.
if !batch_acc.can_accept_op(op) {
let batch = batch_acc.into_batch();
batch_acc = OpBatchAccumulator::new();
Expand All @@ -312,22 +332,17 @@ fn batch_ops(ops: Vec<Operation>) -> (Vec<OpBatch>, RpoDigest) {
batches.push(batch);
}

// add the operation to the accumulator
// Add the operation to the accumulator.
batch_acc.add_op(op);
}

// make sure we finished processing the last batch
// Make sure we finished processing the last batch.
if !batch_acc.is_empty() {
let batch = batch_acc.into_batch();
batch_groups.push(*batch.groups());
batches.push(batch);
}

// compute the hash of all operation groups
let op_groups = &flatten_slice_elements(&batch_groups);
let hash = hasher::hash_elements(op_groups);

(batches, hash)
(batches, batch_groups)
}

/// Checks if a given decorators list is valid (only checked in debug mode)
Expand Down
20 changes: 10 additions & 10 deletions core/src/mast/node/basic_block_node/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{Decorator, ONE};
fn batch_ops() {
// --- one operation ----------------------------------------------------------------------
let ops = vec![Operation::Add];
let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(1, batches.len());

let batch = &batches[0];
Expand All @@ -21,7 +21,7 @@ fn batch_ops() {

// --- two operations ---------------------------------------------------------------------
let ops = vec![Operation::Add, Operation::Mul];
let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(1, batches.len());

let batch = &batches[0];
Expand All @@ -37,7 +37,7 @@ fn batch_ops() {

// --- one group with one immediate value -------------------------------------------------
let ops = vec![Operation::Add, Operation::Push(Felt::new(12345678))];
let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(1, batches.len());

let batch = &batches[0];
Expand All @@ -63,7 +63,7 @@ fn batch_ops() {
Operation::Push(Felt::new(7)),
Operation::Add,
];
let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(1, batches.len());

let batch = &batches[0];
Expand Down Expand Up @@ -98,7 +98,7 @@ fn batch_ops() {
Operation::Add,
Operation::Push(Felt::new(7)),
];
let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(2, batches.len());

let batch0 = &batches[0];
Expand Down Expand Up @@ -147,7 +147,7 @@ fn batch_ops() {
Operation::Add,
];

let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(1, batches.len());

let batch = &batches[0];
Expand Down Expand Up @@ -181,7 +181,7 @@ fn batch_ops() {
Operation::Add,
Operation::Push(Felt::new(11)),
];
let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(1, batches.len());

let batch = &batches[0];
Expand Down Expand Up @@ -215,7 +215,7 @@ fn batch_ops() {
Operation::Push(ONE),
Operation::Push(Felt::new(2)),
];
let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(1, batches.len());

let batch = &batches[0];
Expand Down Expand Up @@ -260,7 +260,7 @@ fn batch_ops() {
Operation::Pad,
];

let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(2, batches.len());

let batch0 = &batches[0];
Expand Down Expand Up @@ -306,7 +306,7 @@ fn operation_or_decorator_iterator() {
(4, Decorator::Event(4)),
];

let node = BasicBlockNode::with_decorators(operations, decorators);
let node = BasicBlockNode::new(operations, Some(decorators)).unwrap();

let mut iterator = node.iter();

Expand Down
12 changes: 12 additions & 0 deletions core/src/mast/node/call_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ impl CallNode {
Ok(Self { callee, is_syscall: false, digest })
}

/// Returns a new [`CallNode`] from values that are assumed to be correct.
/// Should only be used when the source of the inputs is trusted (e.g. deserialization).
pub fn new_unsafe(callee: MastNodeId, digest: RpoDigest) -> Self {
Self { callee, is_syscall: false, digest }
}

/// Returns a new [`CallNode`] instantiated with the specified callee and marked as a kernel
/// call.
pub fn new_syscall(
Expand All @@ -68,6 +74,12 @@ impl CallNode {

Ok(Self { callee, is_syscall: true, digest })
}

/// Returns a new syscall [`CallNode`] from values that are assumed to be correct.
/// Should only be used when the source of the inputs is trusted (e.g. deserialization).
pub fn new_syscall_unsafe(callee: MastNodeId, digest: RpoDigest) -> Self {
Self { callee, is_syscall: true, digest }
}
}

//-------------------------------------------------------------------------------------------------
Expand Down
5 changes: 3 additions & 2 deletions core/src/mast/node/join_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ impl JoinNode {
Ok(Self { children, digest })
}

#[cfg(test)]
pub fn new_test(children: [MastNodeId; 2], digest: RpoDigest) -> Self {
/// Returns a new [`JoinNode`] from values that are assumed to be correct.
/// Should only be used when the source of the inputs is trusted (e.g. deserialization).
pub fn new_unsafe(children: [MastNodeId; 2], digest: RpoDigest) -> Self {
Self { children, digest }
}
}
Expand Down
7 changes: 7 additions & 0 deletions core/src/mast/node/loop_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ impl LoopNode {

/// Constructors
impl LoopNode {
/// Returns a new [`LoopNode`] instantiated with the specified body node.
pub fn new(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
if body.as_usize() >= mast_forest.nodes.len() {
return Err(MastForestError::NodeIdOverflow(body, mast_forest.nodes.len()));
Expand All @@ -44,6 +45,12 @@ impl LoopNode {

Ok(Self { body, digest })
}

/// Returns a new [`LoopNode`] from values that are assumed to be correct.
/// Should only be used when the source of the inputs is trusted (e.g. deserialization).
pub fn new_unsafe(body: MastNodeId, digest: RpoDigest) -> Self {
Self { body, digest }
}
}

impl LoopNode {
Expand Down
13 changes: 5 additions & 8 deletions core/src/mast/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,12 @@ pub enum MastNode {
// ------------------------------------------------------------------------------------------------
/// Constructors
impl MastNode {
pub fn new_basic_block(operations: Vec<Operation>) -> Self {
Self::Block(BasicBlockNode::new(operations))
}

pub fn new_basic_block_with_decorators(
pub fn new_basic_block(
operations: Vec<Operation>,
decorators: DecoratorList,
) -> Self {
Self::Block(BasicBlockNode::with_decorators(operations, decorators))
decorators: Option<DecoratorList>,
) -> Result<Self, MastForestError> {
let block = BasicBlockNode::new(operations, decorators)?;
Ok(Self::Block(block))
}

pub fn new_join(
Expand Down
5 changes: 3 additions & 2 deletions core/src/mast/node/split_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ impl SplitNode {
Ok(Self { branches, digest })
}

#[cfg(test)]
pub fn new_test(branches: [MastNodeId; 2], digest: RpoDigest) -> Self {
/// Returns a new [`SplitNode`] from values that are assumed to be correct.
/// Should only be used when the source of the inputs is trusted (e.g. deserialization).
pub fn new_unsafe(branches: [MastNodeId; 2], digest: RpoDigest) -> Self {
Self { branches, digest }
}
}
Expand Down
Loading

0 comments on commit fd32a69

Please sign in to comment.