Skip to content

Commit

Permalink
enforce MAX_NODES in MastForest and return errors in add_node and enf…
Browse files Browse the repository at this point in the history
…orce_node
  • Loading branch information
sergerad committed Jul 16, 2024
1 parent 781fc73 commit c235257
Show file tree
Hide file tree
Showing 16 changed files with 180 additions and 156 deletions.
2 changes: 1 addition & 1 deletion assembly/src/assembler/basic_block_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ impl BasicBlockBuilder {
let decorators = self.decorators.drain(..).collect();

let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators);
let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node);
let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node).unwrap();

Some(basic_block_node_id)
} else if !self.decorators.is_empty() {
Expand Down
38 changes: 22 additions & 16 deletions assembly/src/assembler/instruction/procedures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,37 +91,43 @@ impl Assembler {
// procedures, such that when we assemble a procedure, all
// procedures that it calls will have been assembled, and
// hence be present in the `MastForest`.
mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| {
if let Some(root) = mast_forest_builder.find_procedure_root(mast_root) {
root
} else {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)
})
mast_forest_builder.ensure_node(external_node)?
}
}
InvokeKind::Call => {
let callee_id =
mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| {
let callee_id = match mast_forest_builder.find_procedure_root(mast_root) {
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)
});
mast_forest_builder.ensure_node(external_node)?
}
Some(callee_id) => callee_id,
};

let call_node = MastNode::new_call(callee_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(call_node)
mast_forest_builder.ensure_node(call_node)?
}
InvokeKind::SysCall => {
let callee_id =
mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| {
let callee_id = match mast_forest_builder.find_procedure_root(mast_root) {
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)
});
mast_forest_builder.ensure_node(external_node)?
}
Some(callee_id) => callee_id,
};

let syscall_node =
MastNode::new_syscall(callee_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(syscall_node)
mast_forest_builder.ensure_node(syscall_node)?
}
}
};
Expand All @@ -134,7 +140,7 @@ impl Assembler {
&self,
mast_forest_builder: &mut MastForestBuilder,
) -> Result<Option<MastNodeId>, AssemblyError> {
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn);
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?;

Ok(Some(dyn_node_id))
}
Expand All @@ -145,10 +151,10 @@ impl Assembler {
mast_forest_builder: &mut MastForestBuilder,
) -> Result<Option<MastNodeId>, AssemblyError> {
let dyn_call_node_id = {
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn);
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?;
let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest_builder.forest());

mast_forest_builder.ensure_node(dyn_call_node)
mast_forest_builder.ensure_node(dyn_call_node)?
};

Ok(Some(dyn_call_node_id))
Expand Down
10 changes: 5 additions & 5 deletions assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use core::ops::Index;
use alloc::collections::BTreeMap;
use vm_core::{
crypto::hash::RpoDigest,
mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode},
mast::{MastForest, MastForestError, MastNode, MastNodeId, MerkleTreeNode},
};

/// Builder for a [`MastForest`].
Expand Down Expand Up @@ -44,17 +44,17 @@ impl MastForestBuilder {
/// If a [`MastNode`] which is equal to the current node was previously added, the previously
/// returned [`MastNodeId`] will be returned. This enforces this invariant that equal
/// [`MastNode`]s have equal [`MastNodeId`]s.
pub fn ensure_node(&mut self, node: MastNode) -> MastNodeId {
pub fn ensure_node(&mut self, node: MastNode) -> Result<MastNodeId, MastForestError> {
let node_digest = node.digest();

if let Some(node_id) = self.node_id_by_hash.get(&node_digest) {
// node already exists in the forest; return previously assigned id
*node_id
Ok(*node_id)
} else {
let new_node_id = self.mast_forest.add_node(node);
let new_node_id = self.mast_forest.add_node(node)?;
self.node_id_by_hash.insert(node_digest, new_node_id);

new_node_id
Ok(new_node_id)
}
}

Expand Down
14 changes: 7 additions & 7 deletions assembly/src/assembler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ impl Assembler {
let split_node =
MastNode::new_split(then_blk, else_blk, mast_forest_builder.forest());

mast_forest_builder.ensure_node(split_node)
mast_forest_builder.ensure_node(split_node).map_err(AssemblyError::from)?
};
mast_node_ids.push(split_node_id);
}
Expand Down Expand Up @@ -829,7 +829,7 @@ impl Assembler {
let loop_node_id = {
let loop_node =
MastNode::new_loop(loop_body_node_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(loop_node)
mast_forest_builder.ensure_node(loop_node).map_err(AssemblyError::from)?
};
mast_node_ids.push(loop_node_id);
}
Expand All @@ -842,9 +842,9 @@ impl Assembler {

Ok(if mast_node_ids.is_empty() {
let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]);
mast_forest_builder.ensure_node(basic_block_node)
mast_forest_builder.ensure_node(basic_block_node).map_err(AssemblyError::from)?
} else {
combine_mast_node_ids(mast_node_ids, mast_forest_builder)
combine_mast_node_ids(mast_node_ids, mast_forest_builder)?
})
}

Expand Down Expand Up @@ -884,7 +884,7 @@ struct BodyWrapper {
fn combine_mast_node_ids(
mut mast_node_ids: Vec<MastNodeId>,
mast_forest_builder: &mut MastForestBuilder,
) -> MastNodeId {
) -> Result<MastNodeId, AssemblyError> {
debug_assert!(!mast_node_ids.is_empty(), "cannot combine empty MAST node id list");

// build a binary tree of blocks joining them using JOIN blocks
Expand All @@ -903,7 +903,7 @@ fn combine_mast_node_ids(
(source_mast_node_iter.next(), source_mast_node_iter.next())
{
let join_mast_node = MastNode::new_join(left, right, mast_forest_builder.forest());
let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node);
let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node)?;

mast_node_ids.push(join_mast_node_id);
}
Expand All @@ -912,5 +912,5 @@ fn combine_mast_node_ids(
}
}

mast_node_ids.remove(0)
Ok(mast_node_ids.remove(0))
}
49 changes: 25 additions & 24 deletions assembly/src/assembler/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ fn nested_blocks() {
// `Assembler::with_kernel_from_module()`.
let syscall_foo_node_id = {
let kernel_foo_node = MastNode::new_basic_block(vec![Operation::Add]);
let kernel_foo_node_id = expected_mast_forest_builder.ensure_node(kernel_foo_node);
let kernel_foo_node_id = expected_mast_forest_builder.ensure_node(kernel_foo_node).unwrap();

let syscall_node =
MastNode::new_syscall(kernel_foo_node_id, expected_mast_forest_builder.forest());
expected_mast_forest_builder.ensure_node(syscall_node)
expected_mast_forest_builder.ensure_node(syscall_node).unwrap()
};

let program = r#"
Expand Down Expand Up @@ -130,63 +130,63 @@ fn nested_blocks() {
let exec_bar_node_id = {
// bar procedure
let basic_block_1 = MastNode::new_basic_block(vec![Operation::Push(17_u32.into())]);
let basic_block_1_id = expected_mast_forest_builder.ensure_node(basic_block_1);
let basic_block_1_id = expected_mast_forest_builder.ensure_node(basic_block_1).unwrap();

// Basic block representing the `foo` procedure
let basic_block_2 = MastNode::new_basic_block(vec![Operation::Push(19_u32.into())]);
let basic_block_2_id = expected_mast_forest_builder.ensure_node(basic_block_2);
let basic_block_2_id = expected_mast_forest_builder.ensure_node(basic_block_2).unwrap();

let join_node = MastNode::new_join(
basic_block_1_id,
basic_block_2_id,
expected_mast_forest_builder.forest(),
);
expected_mast_forest_builder.ensure_node(join_node)
expected_mast_forest_builder.ensure_node(join_node).unwrap()
};

let exec_foo_bar_baz_node_id = {
// basic block representing foo::bar.baz procedure
let basic_block = MastNode::new_basic_block(vec![Operation::Push(29_u32.into())]);
expected_mast_forest_builder.ensure_node(basic_block)
expected_mast_forest_builder.ensure_node(basic_block).unwrap()
};

let before = {
let before_node = MastNode::new_basic_block(vec![Operation::Push(2u32.into())]);
expected_mast_forest_builder.ensure_node(before_node)
expected_mast_forest_builder.ensure_node(before_node).unwrap()
};

let r#true1 = {
let r#true_node = MastNode::new_basic_block(vec![Operation::Push(3u32.into())]);
expected_mast_forest_builder.ensure_node(r#true_node)
expected_mast_forest_builder.ensure_node(r#true_node).unwrap()
};
let r#false1 = {
let r#false_node = MastNode::new_basic_block(vec![Operation::Push(5u32.into())]);
expected_mast_forest_builder.ensure_node(r#false_node)
expected_mast_forest_builder.ensure_node(r#false_node).unwrap()
};
let r#if1 = {
let r#if_node =
MastNode::new_split(r#true1, r#false1, expected_mast_forest_builder.forest());
expected_mast_forest_builder.ensure_node(r#if_node)
expected_mast_forest_builder.ensure_node(r#if_node).unwrap()
};

let r#true3 = {
let r#true_node = MastNode::new_basic_block(vec![Operation::Push(7u32.into())]);
expected_mast_forest_builder.ensure_node(r#true_node)
expected_mast_forest_builder.ensure_node(r#true_node).unwrap()
};
let r#false3 = {
let r#false_node = MastNode::new_basic_block(vec![Operation::Push(11u32.into())]);
expected_mast_forest_builder.ensure_node(r#false_node)
expected_mast_forest_builder.ensure_node(r#false_node).unwrap()
};
let r#true2 = {
let r#if_node =
MastNode::new_split(r#true3, r#false3, expected_mast_forest_builder.forest());
expected_mast_forest_builder.ensure_node(r#if_node)
expected_mast_forest_builder.ensure_node(r#if_node).unwrap()
};

let r#while = {
let push_basic_block_id = {
let push_basic_block = MastNode::new_basic_block(vec![Operation::Push(23u32.into())]);
expected_mast_forest_builder.ensure_node(push_basic_block)
expected_mast_forest_builder.ensure_node(push_basic_block).unwrap()
};
let body_node_id = {
let body_node = MastNode::new_join(
Expand All @@ -195,15 +195,15 @@ fn nested_blocks() {
expected_mast_forest_builder.forest(),
);

expected_mast_forest_builder.ensure_node(body_node)
expected_mast_forest_builder.ensure_node(body_node).unwrap()
};

let loop_node = MastNode::new_loop(body_node_id, expected_mast_forest_builder.forest());
expected_mast_forest_builder.ensure_node(loop_node)
expected_mast_forest_builder.ensure_node(loop_node).unwrap()
};
let push_13_basic_block_id = {
let node = MastNode::new_basic_block(vec![Operation::Push(13u32.into())]);
expected_mast_forest_builder.ensure_node(node)
expected_mast_forest_builder.ensure_node(node).unwrap()
};

let r#false2 = {
Expand All @@ -212,17 +212,18 @@ fn nested_blocks() {
r#while,
expected_mast_forest_builder.forest(),
);
expected_mast_forest_builder.ensure_node(node)
expected_mast_forest_builder.ensure_node(node).unwrap()
};
let nested = {
let node = MastNode::new_split(r#true2, r#false2, expected_mast_forest_builder.forest());
expected_mast_forest_builder.ensure_node(node)
expected_mast_forest_builder.ensure_node(node).unwrap()
};

let combined_node_id = combine_mast_node_ids(
vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id],
&mut expected_mast_forest_builder,
);
)
.unwrap();

let expected_program = Program::new(expected_mast_forest_builder.build(), combined_node_id);
assert_eq!(expected_program.hash(), program.hash());
Expand Down Expand Up @@ -281,26 +282,26 @@ fn duplicate_nodes() {
// basic block: mul
let mul_basic_block_id = {
let node = MastNode::new_basic_block(vec![Operation::Mul]);
expected_mast_forest.add_node(node)
expected_mast_forest.add_node(node).unwrap()
};

// basic block: add
let add_basic_block_id = {
let node = MastNode::new_basic_block(vec![Operation::Add]);
expected_mast_forest.add_node(node)
expected_mast_forest.add_node(node).unwrap()
};

// inner split: `if.true add else mul end`
let inner_split_id = {
let node =
MastNode::new_split(add_basic_block_id, mul_basic_block_id, &expected_mast_forest);
expected_mast_forest.add_node(node)
expected_mast_forest.add_node(node).unwrap()
};

// root: outer split
let root_id = {
let node = MastNode::new_split(mul_basic_block_id, inner_split_id, &expected_mast_forest);
expected_mast_forest.add_node(node)
expected_mast_forest.add_node(node).unwrap()
};
expected_mast_forest.make_root(root_id);

Expand Down
4 changes: 4 additions & 0 deletions assembly/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use alloc::{string::String, sync::Arc, vec::Vec};
use vm_core::mast::MastForestError;

use crate::{
ast::{FullyQualifiedProcedureName, ProcedureName},
Expand Down Expand Up @@ -135,7 +136,10 @@ pub enum AssemblyError {
#[error(transparent)]
#[diagnostic(transparent)]
Other(#[from] RelatedError),
#[error(transparent)]
Forest(#[from] MastForestError),
}

impl From<Report> for AssemblyError {
fn from(report: Report) -> Self {
Self::Other(RelatedError::new(report))
Expand Down
27 changes: 19 additions & 8 deletions core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ impl Deserializable for MastNodeId {
// MAST FOREST
// ================================================================================================

/// Represents the types of errors that can occur when dealing with MAST forest.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum MastForestError {
#[error(
"invalid node count: MAST forest exceeds the maximum of {} nodes",
MastForest::MAX_NODES
)]
InvalidLength,
}

/// Represents one or more procedures, represented as a collection of [`MastNode`]s.
///
/// A [`MastForest`] does not have an entrypoint, and hence is not executable. A [`crate::Program`]
Expand All @@ -100,20 +110,21 @@ impl MastForest {

/// Mutators
impl MastForest {
/// The maximum number of nodes that can be stored in a single MAST forest.
const MAX_NODES: usize = 1 << 30;

/// Adds a node to the forest, and returns the associated [`MastNodeId`].
///
/// Adding two duplicate nodes will result in two distinct returned [`MastNodeId`]s.
pub fn add_node(&mut self, node: MastNode) -> MastNodeId {
let new_node_id = MastNodeId(
self.nodes
.len()
.try_into()
.expect("invalid node id: exceeded maximum number of nodes in a single forest"),
);
pub fn add_node(&mut self, node: MastNode) -> Result<MastNodeId, MastForestError> {
if self.nodes.len() == Self::MAX_NODES {
return Err(MastForestError::InvalidLength);
}

let new_node_id = MastNodeId(self.nodes.len() as u32);
self.nodes.push(node);

new_node_id
Ok(new_node_id)
}

/// Marks the given [`MastNodeId`] as being the root of a procedure.
Expand Down
Loading

0 comments on commit c235257

Please sign in to comment.