diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 4a3561172b..fffc1956fb 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -255,7 +255,7 @@ impl fmt::Display for MastNodeId { // ================================================================================================ /// Represents the types of errors that can occur when dealing with MAST forest. -#[derive(Debug, thiserror::Error)] +#[derive(Debug, thiserror::Error, PartialEq)] pub enum MastForestError { #[error( "invalid node count: MAST forest exceeds the maximum of {} nodes", diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index 1d3e2d90be..f2856848f5 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -3,7 +3,8 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt}; use super::*; use crate::{ - operations::Operation, AdviceInjector, AssemblyOp, DebugOptions, Decorator, SignatureKind, + mast::MastForestError, operations::Operation, AdviceInjector, AssemblyOp, DebugOptions, + Decorator, SignatureKind, }; /// If this test fails to compile, it means that `Operation` or `Decorator` was changed. Make sure @@ -321,3 +322,40 @@ fn serialize_deserialize_all_nodes() { assert_eq!(mast_forest, deserialized_mast_forest); } + +#[test] +fn mast_forest_invalid_node_id() { + // Hydrate a forest smaller than the second + let mut forest = MastForest::new(); + let first = forest.add_block(vec![Operation::U32div], None).unwrap(); + let second = forest.add_block(vec![Operation::U32div], None).unwrap(); + + // Hydrate a forest larger than the first to get an overflow MastNodeId + let mut overflow_forest = MastForest::new(); + let overflow = (0..=3) + .map(|_| overflow_forest.add_block(vec![Operation::U32div], None).unwrap()) + .last() + .unwrap(); + + // Attempt to join with invalid ids + let join = forest.add_join(overflow, second); + assert_eq!(join, Err(MastForestError::InvalidNodeId(overflow))); + let join = forest.add_join(first, overflow); + assert_eq!(join, Err(MastForestError::InvalidNodeId(overflow))); + + // Attempt to split with invalid ids + let split = forest.add_split(overflow, second); + assert_eq!(split, Err(MastForestError::InvalidNodeId(overflow))); + let split = forest.add_split(first, overflow); + assert_eq!(split, Err(MastForestError::InvalidNodeId(overflow))); + + // Attempt to loop with invalid ids + assert_eq!(forest.add_loop(overflow), Err(MastForestError::InvalidNodeId(overflow))); + + // Attempt to call with invalid ids + assert_eq!(forest.add_call(overflow), Err(MastForestError::InvalidNodeId(overflow))); + assert_eq!(forest.add_syscall(overflow), Err(MastForestError::InvalidNodeId(overflow))); + + // Validate normal operations + forest.add_join(first, second).unwrap(); +}