Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Permit child MastNodeIds to exceed the MastNodeIds of their parents #1542

Open
wants to merge 1 commit into
base: next
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,32 @@ impl MastNodeId {
}
}

/// Returns a new [`MastNodeId`] with the provided `id`, or an error if `id` is greater or equal
/// to `node_count`.
///
/// This function can be used when deserializing an id whose corresponding node is not yet in
/// the forest and [`Self::from_u32_safe`] would fail. For instance, when deserializing the ids
/// referenced by the Join node in this forest:
///
/// ```text
/// [Join(1, 2), Block(foo), Block(bar)]
/// ```
///
/// Since it is less safe than [`Self::from_u32_safe`] and usually not needed it is not public.
pub(crate) fn from_u32_with_node_count(
id: u32,
node_count: usize,
) -> Result<Self, DeserializationError> {
if (id as usize) < node_count {
Ok(Self(id))
} else {
Err(DeserializationError::InvalidValue(format!(
"Invalid deserialized MAST node ID '{}', but {} is the number of nodes in the forest",
id, node_count,
)))
}
}

pub fn as_usize(&self) -> usize {
self.0 as usize
}
Expand Down
17 changes: 9 additions & 8 deletions core/src/mast/serialization/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ impl MastNodeInfo {

pub fn try_into_mast_node(
self,
mast_forest: &mut MastForest,
mast_forest: &MastForest,
node_count: usize,
basic_block_data_decoder: &BasicBlockDataDecoder,
) -> Result<MastNode, DeserializationError> {
match self.ty {
Expand All @@ -59,29 +60,29 @@ impl MastNodeInfo {
Ok(MastNode::Block(block))
},
MastNodeType::Join { left_child_id, right_child_id } => {
let left_child = MastNodeId::from_u32_safe(left_child_id, mast_forest)?;
let right_child = MastNodeId::from_u32_safe(right_child_id, mast_forest)?;
let left_child = MastNodeId::from_u32_with_node_count(left_child_id, node_count)?;
let right_child = MastNodeId::from_u32_with_node_count(right_child_id, node_count)?;
let join = JoinNode::new_unsafe([left_child, right_child], self.digest);
Ok(MastNode::Join(join))
},
MastNodeType::Split { if_branch_id, else_branch_id } => {
let if_branch = MastNodeId::from_u32_safe(if_branch_id, mast_forest)?;
let else_branch = MastNodeId::from_u32_safe(else_branch_id, mast_forest)?;
let if_branch = MastNodeId::from_u32_with_node_count(if_branch_id, node_count)?;
let else_branch = MastNodeId::from_u32_with_node_count(else_branch_id, node_count)?;
let split = SplitNode::new_unsafe([if_branch, else_branch], self.digest);
Ok(MastNode::Split(split))
},
MastNodeType::Loop { body_id } => {
let body_id = MastNodeId::from_u32_safe(body_id, mast_forest)?;
let body_id = MastNodeId::from_u32_with_node_count(body_id, node_count)?;
let loop_node = LoopNode::new_unsafe(body_id, self.digest);
Ok(MastNode::Loop(loop_node))
},
MastNodeType::Call { callee_id } => {
let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?;
let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
let call = CallNode::new_unsafe(callee_id, self.digest);
Ok(MastNode::Call(call))
},
MastNodeType::SysCall { callee_id } => {
let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?;
let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
let syscall = CallNode::new_syscall_unsafe(callee_id, self.digest);
Ok(MastNode::Call(syscall))
},
Expand Down
7 changes: 5 additions & 2 deletions core/src/mast/serialization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,11 @@ impl Deserializable for MastForest {
for _ in 0..node_count {
let mast_node_info = MastNodeInfo::read_from(source)?;

let node = mast_node_info
.try_into_mast_node(&mut mast_forest, &basic_block_data_decoder)?;
let node = mast_node_info.try_into_mast_node(
&mast_forest,
node_count,
&basic_block_data_decoder,
)?;

mast_forest.add_node(node).map_err(|e| {
DeserializationError::InvalidValue(format!(
Expand Down
45 changes: 45 additions & 0 deletions core/src/mast/serialization/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,51 @@ fn serialize_deserialize_all_nodes() {
assert_eq!(mast_forest, deserialized_mast_forest);
}

/// Test that a forest with a node whose child ids are larger than its own id serializes and
/// deserializes successfully.
#[test]
fn mast_forest_serialize_deserialize_with_child_ids_exceeding_parent_id() {
let mut forest = MastForest::new();
let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap();
let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap();
let zero = forest.add_block(vec![Operation::U32div], None).unwrap();
let first = forest.add_block(vec![Operation::U32add], Some(vec![(0, deco0)])).unwrap();
let second = forest.add_block(vec![Operation::U32and], Some(vec![(1, deco1)])).unwrap();
forest.add_join(first, second).unwrap();

// Move the Join node before its child nodes and remove the temporary zero node.
forest.nodes.swap_remove(zero.as_usize());

MastForest::read_from_bytes(&forest.to_bytes()).unwrap();
}

/// Test that a forest with a node whose referenced index is >= the max number of nodes in
/// the forest returns an error during deserialization.
#[test]
fn mast_forest_serialize_deserialize_with_overflowing_ids_fails() {
let mut overflow_forest = MastForest::new();
let id0 = overflow_forest.add_block(vec![Operation::Eqz], None).unwrap();
overflow_forest.add_block(vec![Operation::Eqz], None).unwrap();
let id2 = overflow_forest.add_block(vec![Operation::Eqz], None).unwrap();
let id_join = overflow_forest.add_join(id0, id2).unwrap();

let join_node = overflow_forest[id_join].clone();

// Add the Join(0, 2) to this forest which does not have a node with index 2.
let mut forest = MastForest::new();
let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap();
let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap();
forest
.add_block(vec![Operation::U32add], Some(vec![(0, deco0), (1, deco1)]))
.unwrap();
forest.add_node(join_node).unwrap();

assert_matches!(
MastForest::read_from_bytes(&forest.to_bytes()),
Err(DeserializationError::InvalidValue(msg)) if msg.contains("number of nodes")
);
}

#[test]
fn mast_forest_invalid_node_id() {
// Hydrate a forest smaller than the second
Expand Down
Loading