Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MastForest maximum node length invariant #1394

Merged
merged 4 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- Added support for immediate values for `u32and`, `u32or`, `u32xor` and `u32not` bitwise instructions (#1362).
- Optimized `std::sys::truncate_stuck` procedure (#1384).
- Add serialization/deserialization for `MastForest` (#1370)
- Introduce `MastForestError` to enforce `MastForest` node count invariant (#1394)

#### Changed

Expand Down
14 changes: 7 additions & 7 deletions assembly/src/assembler/basic_block_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::{
};
use alloc::{borrow::Borrow, string::ToString, vec::Vec};
use vm_core::{
mast::{MastNode, MastNodeId},
mast::{MastForestError, MastNode, MastNodeId},
AdviceInjector, AssemblyOp, Operation,
};

Expand Down Expand Up @@ -129,22 +129,22 @@ impl BasicBlockBuilder {
pub fn make_basic_block(
&mut self,
mast_forest_builder: &mut MastForestBuilder,
) -> Option<MastNodeId> {
) -> Result<Option<MastNodeId>, MastForestError> {
if !self.ops.is_empty() {
let ops = self.ops.drain(..).collect();
let decorators = self.decorators.drain(..).collect();

let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators);
let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node);
let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node)?;

Some(basic_block_node_id)
Ok(Some(basic_block_node_id))
} else if !self.decorators.is_empty() {
// this is a bug in the assembler. we shouldn't have decorators added without their
// associated operations
// TODO: change this to an error or allow decorators in empty span blocks
unreachable!("decorators in an empty SPAN block")
} else {
None
Ok(None)
}
}

Expand All @@ -155,10 +155,10 @@ impl BasicBlockBuilder {
/// - Operations contained in the epilogue of the builder are appended to the list of ops which
/// go into the new BASIC BLOCK node.
/// - The builder is consumed in the process.
pub fn into_basic_block(
pub fn try_into_basic_block(
mut self,
mast_forest_builder: &mut MastForestBuilder,
) -> Option<MastNodeId> {
) -> Result<Option<MastNodeId>, MastForestError> {
self.ops.append(&mut self.epilogue);
self.make_basic_block(mast_forest_builder)
}
Expand Down
45 changes: 26 additions & 19 deletions assembly/src/assembler/instruction/procedures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,37 +91,44 @@ impl Assembler {
// procedures, such that when we assemble a procedure, all
// procedures that it calls will have been assembled, and
// hence be present in the `MastForest`.
mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)
})
match mast_forest_builder.find_procedure_root(mast_root) {
Some(root) => root,
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)?
}
}
}
InvokeKind::Call => {
let callee_id =
mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| {
let callee_id = match mast_forest_builder.find_procedure_root(mast_root) {
Some(callee_id) => callee_id,
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)
});
mast_forest_builder.ensure_node(external_node)?
}
};

let call_node = MastNode::new_call(callee_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(call_node)
mast_forest_builder.ensure_node(call_node)?
}
InvokeKind::SysCall => {
let callee_id =
mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| {
let callee_id = match mast_forest_builder.find_procedure_root(mast_root) {
Some(callee_id) => callee_id,
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)
});
mast_forest_builder.ensure_node(external_node)?
}
};

let syscall_node =
MastNode::new_syscall(callee_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(syscall_node)
mast_forest_builder.ensure_node(syscall_node)?
}
}
};
Expand All @@ -134,7 +141,7 @@ impl Assembler {
&self,
mast_forest_builder: &mut MastForestBuilder,
) -> Result<Option<MastNodeId>, AssemblyError> {
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn);
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?;

Ok(Some(dyn_node_id))
}
Expand All @@ -145,10 +152,10 @@ impl Assembler {
mast_forest_builder: &mut MastForestBuilder,
) -> Result<Option<MastNodeId>, AssemblyError> {
let dyn_call_node_id = {
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn);
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?;
let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest_builder.forest());

mast_forest_builder.ensure_node(dyn_call_node)
mast_forest_builder.ensure_node(dyn_call_node)?
};

Ok(Some(dyn_call_node_id))
Expand Down
10 changes: 5 additions & 5 deletions assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use core::ops::Index;
use alloc::collections::BTreeMap;
use vm_core::{
crypto::hash::RpoDigest,
mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode},
mast::{MastForest, MastForestError, MastNode, MastNodeId, MerkleTreeNode},
};

/// Builder for a [`MastForest`].
Expand Down Expand Up @@ -44,17 +44,17 @@ impl MastForestBuilder {
/// If a [`MastNode`] which is equal to the current node was previously added, the previously
/// returned [`MastNodeId`] will be returned. This enforces this invariant that equal
/// [`MastNode`]s have equal [`MastNodeId`]s.
pub fn ensure_node(&mut self, node: MastNode) -> MastNodeId {
pub fn ensure_node(&mut self, node: MastNode) -> Result<MastNodeId, MastForestError> {
let node_digest = node.digest();

if let Some(node_id) = self.node_id_by_hash.get(&node_digest) {
// node already exists in the forest; return previously assigned id
*node_id
Ok(*node_id)
} else {
let new_node_id = self.mast_forest.add_node(node);
let new_node_id = self.mast_forest.add_node(node)?;
self.node_id_by_hash.insert(node_digest, new_node_id);

new_node_id
Ok(new_node_id)
}
}

Expand Down
26 changes: 14 additions & 12 deletions assembly/src/assembler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ impl Assembler {
mast_forest_builder,
)? {
if let Some(basic_block_id) =
basic_block_builder.make_basic_block(mast_forest_builder)
basic_block_builder.make_basic_block(mast_forest_builder)?
{
mast_node_ids.push(basic_block_id);
}
Expand All @@ -780,7 +780,7 @@ impl Assembler {
then_blk, else_blk, ..
} => {
if let Some(basic_block_id) =
basic_block_builder.make_basic_block(mast_forest_builder)
basic_block_builder.make_basic_block(mast_forest_builder)?
{
mast_node_ids.push(basic_block_id);
}
Expand All @@ -794,14 +794,14 @@ impl Assembler {
let split_node =
MastNode::new_split(then_blk, else_blk, mast_forest_builder.forest());

mast_forest_builder.ensure_node(split_node)
mast_forest_builder.ensure_node(split_node)?
};
mast_node_ids.push(split_node_id);
}

Op::Repeat { count, body, .. } => {
if let Some(basic_block_id) =
basic_block_builder.make_basic_block(mast_forest_builder)
basic_block_builder.make_basic_block(mast_forest_builder)?
{
mast_node_ids.push(basic_block_id);
}
Expand All @@ -816,7 +816,7 @@ impl Assembler {

Op::While { body, .. } => {
if let Some(basic_block_id) =
basic_block_builder.make_basic_block(mast_forest_builder)
basic_block_builder.make_basic_block(mast_forest_builder)?
{
mast_node_ids.push(basic_block_id);
}
Expand All @@ -827,22 +827,24 @@ impl Assembler {
let loop_node_id = {
let loop_node =
MastNode::new_loop(loop_body_node_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(loop_node)
mast_forest_builder.ensure_node(loop_node)?
};
mast_node_ids.push(loop_node_id);
}
}
}

if let Some(basic_block_id) = basic_block_builder.into_basic_block(mast_forest_builder) {
if let Some(basic_block_id) =
basic_block_builder.try_into_basic_block(mast_forest_builder)?
{
mast_node_ids.push(basic_block_id);
}

Ok(if mast_node_ids.is_empty() {
let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]);
mast_forest_builder.ensure_node(basic_block_node)
mast_forest_builder.ensure_node(basic_block_node)?
} else {
combine_mast_node_ids(mast_node_ids, mast_forest_builder)
combine_mast_node_ids(mast_node_ids, mast_forest_builder)?
})
}

Expand Down Expand Up @@ -882,7 +884,7 @@ struct BodyWrapper {
fn combine_mast_node_ids(
mut mast_node_ids: Vec<MastNodeId>,
mast_forest_builder: &mut MastForestBuilder,
) -> MastNodeId {
) -> Result<MastNodeId, AssemblyError> {
debug_assert!(!mast_node_ids.is_empty(), "cannot combine empty MAST node id list");

// build a binary tree of blocks joining them using JOIN blocks
Expand All @@ -901,7 +903,7 @@ fn combine_mast_node_ids(
(source_mast_node_iter.next(), source_mast_node_iter.next())
{
let join_mast_node = MastNode::new_join(left, right, mast_forest_builder.forest());
let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node);
let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node)?;

mast_node_ids.push(join_mast_node_id);
}
Expand All @@ -910,5 +912,5 @@ fn combine_mast_node_ids(
}
}

mast_node_ids.remove(0)
Ok(mast_node_ids.remove(0))
}
Loading
Loading