Skip to content

Commit

Permalink
refactor: remove MerkleTreeNode trait
Browse files Browse the repository at this point in the history
  • Loading branch information
bobbinth committed Jul 20, 2024
1 parent 2c22ac1 commit f1c1282
Show file tree
Hide file tree
Showing 21 changed files with 109 additions and 109 deletions.
2 changes: 1 addition & 1 deletion assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use core::ops::Index;
use alloc::{collections::BTreeMap, vec::Vec};
use vm_core::{
crypto::hash::RpoDigest,
mast::{MastForest, MastForestError, MastNode, MastNodeId, MerkleTreeNode},
mast::{MastForest, MastForestError, MastNode, MastNodeId},
DecoratorList, Operation,
};

Expand Down
2 changes: 1 addition & 1 deletion assembly/src/assembler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use mast_forest_builder::MastForestBuilder;
use vm_core::{
mast::{MastForest, MastNodeId, MerkleTreeNode},
mast::{MastForest, MastNodeId},
Decorator, DecoratorList, Kernel, Operation, Program,
};

Expand Down
2 changes: 1 addition & 1 deletion assembly/src/assembler/procedure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
diagnostics::SourceFile,
LibraryPath, RpoDigest, SourceSpan, Spanned,
};
use vm_core::mast::{MastForest, MastNodeId, MerkleTreeNode};
use vm_core::mast::{MastForest, MastNodeId};

pub type CallSet = BTreeSet<RpoDigest>;

Expand Down
8 changes: 2 additions & 6 deletions core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ mod serialization;
#[cfg(test)]
mod tests;

/// Encapsulates the behavior that a [`MastNode`] (and all its variants) is expected to have.
pub trait MerkleTreeNode {
fn digest(&self) -> RpoDigest;
fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a;
}

// MAST FOREST
// ================================================================================================

Expand Down Expand Up @@ -171,6 +165,8 @@ impl Deserializable for MastNodeId {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let inner = source.read_u32()?;

// TODO: fix

Ok(Self(inner))
}
}
Expand Down
50 changes: 26 additions & 24 deletions core/src/mast/node/basic_block_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ use miden_formatting::prettier::PrettyPrint;
use winter_utils::flatten_slice_elements;

use crate::{
chiplets::hasher,
mast::{MastForest, MerkleTreeNode},
Decorator, DecoratorIterator, DecoratorList, Operation,
chiplets::hasher, mast::MastForest, Decorator, DecoratorIterator, DecoratorList, Operation,
};

#[cfg(test)]
Expand Down Expand Up @@ -67,12 +65,14 @@ pub struct BasicBlockNode {
decorators: DecoratorList,
}

// ------------------------------------------------------------------------------------------------
/// Constants
impl BasicBlockNode {
/// The domain of the basic block node (used for control block hashing).
pub const DOMAIN: Felt = ZERO;
}

// ------------------------------------------------------------------------------------------------
/// Constructors
impl BasicBlockNode {
/// Returns a new [`BasicBlockNode`] instantiated with the specified operations.
Expand Down Expand Up @@ -108,6 +108,7 @@ impl BasicBlockNode {
}
}

// ------------------------------------------------------------------------------------------------
/// Public accessors
impl BasicBlockNode {
pub fn num_operations_and_decorators(&self) -> u32 {
Expand Down Expand Up @@ -139,35 +140,18 @@ impl BasicBlockNode {
pub fn decorators(&self) -> &DecoratorList {
&self.decorators
}
}

impl MerkleTreeNode for BasicBlockNode {
fn digest(&self) -> RpoDigest {
pub fn digest(&self) -> RpoDigest {
self.digest
}

fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
pub fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
self
}
}

/// Checks if a given decorators list is valid (only checked in debug mode)
/// - Assert the decorator list is in ascending order.
/// - Assert the last op index in decorator list is less than or equal to the number of operations.
#[cfg(debug_assertions)]
fn validate_decorators(operations: &[Operation], decorators: &DecoratorList) {
if !decorators.is_empty() {
// check if decorator list is sorted
for i in 0..(decorators.len() - 1) {
debug_assert!(decorators[i + 1].0 >= decorators[i].0, "unsorted decorators list");
}
// assert the last index in decorator list is less than operations vector length
debug_assert!(
operations.len() >= decorators.last().expect("empty decorators list").0,
"last op index in decorator list should be less than or equal to the number of ops"
);
}
}
// PRETTY PRINTING
// ================================================================================================

impl PrettyPrint for BasicBlockNode {
#[rustfmt::skip]
Expand Down Expand Up @@ -515,3 +499,21 @@ pub fn get_span_op_group_count(op_batches: &[OpBatch]) -> usize {
let last_batch_num_groups = op_batches.last().expect("no last group").num_groups();
(op_batches.len() - 1) * BATCH_SIZE + last_batch_num_groups.next_power_of_two()
}

/// Checks if a given decorators list is valid (only checked in debug mode)
/// - Assert the decorator list is in ascending order.
/// - Assert the last op index in decorator list is less than or equal to the number of operations.
#[cfg(debug_assertions)]
fn validate_decorators(operations: &[Operation], decorators: &DecoratorList) {
if !decorators.is_empty() {
// check if decorator list is sorted
for i in 0..(decorators.len() - 1) {
debug_assert!(decorators[i + 1].0 >= decorators[i].0, "unsorted decorators list");
}
// assert the last index in decorator list is less than operations vector length
debug_assert!(
operations.len() >= decorators.last().expect("empty decorators list").0,
"last op index in decorator list should be less than or equal to the number of ops"
);
}
}
8 changes: 4 additions & 4 deletions core/src/mast/node/call_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use miden_formatting::prettier::PrettyPrint;

use crate::{
chiplets::hasher,
mast::{MastForest, MastNodeId, MerkleTreeNode},
mast::{MastForest, MastNodeId},
OPCODE_CALL, OPCODE_SYSCALL,
};

Expand Down Expand Up @@ -87,12 +87,12 @@ impl CallNode {
}
}

impl MerkleTreeNode for CallNode {
fn digest(&self) -> RpoDigest {
impl CallNode {
pub fn digest(&self) -> RpoDigest {
self.digest
}

fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
CallNodePrettyPrint {
call_node: self,
mast_forest,
Expand Down
11 changes: 4 additions & 7 deletions core/src/mast/node/dyn_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ use core::fmt;

use miden_crypto::{hash::rpo::RpoDigest, Felt};

use crate::{
mast::{MastForest, MerkleTreeNode},
OPCODE_DYN,
};
use crate::{mast::MastForest, OPCODE_DYN};

#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DynNode;
Expand All @@ -16,8 +13,8 @@ impl DynNode {
pub const DOMAIN: Felt = Felt::new(OPCODE_DYN as u64);
}

impl MerkleTreeNode for DynNode {
fn digest(&self) -> RpoDigest {
impl DynNode {
pub fn digest(&self) -> RpoDigest {
// The Dyn node is represented by a constant, which is set to be the hash of two empty
// words ([ZERO, ZERO, ZERO, ZERO]) with a domain value of `DYN_DOMAIN`, i.e.
// hasher::merge_in_domain(&[Digest::default(), Digest::default()], DynNode::DOMAIN)
Expand All @@ -29,7 +26,7 @@ impl MerkleTreeNode for DynNode {
])
}

fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
pub fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
self
}
}
Expand Down
8 changes: 4 additions & 4 deletions core/src/mast/node/external.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::mast::{MastForest, MerkleTreeNode};
use crate::mast::MastForest;
use core::fmt;
use miden_crypto::hash::rpo::RpoDigest;

Expand All @@ -24,11 +24,11 @@ impl ExternalNode {
}
}

impl MerkleTreeNode for ExternalNode {
fn digest(&self) -> RpoDigest {
impl ExternalNode {
pub fn digest(&self) -> RpoDigest {
self.digest
}
fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
pub fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
self
}
}
Expand Down
30 changes: 17 additions & 13 deletions core/src/mast/node/join_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt};

use crate::{
chiplets::hasher,
mast::{MastForest, MastNodeId, MerkleTreeNode},
mast::{MastForest, MastNodeId},
prettier::PrettyPrint,
OPCODE_JOIN,
};

// JOIN NODE
// ================================================================================================

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JoinNode {
children: [MastNodeId; 2],
Expand Down Expand Up @@ -41,7 +44,7 @@ impl JoinNode {
}
}

/// Accessors
/// Public accessors
impl JoinNode {
pub fn first(&self) -> MastNodeId {
self.children[0]
Expand All @@ -50,26 +53,27 @@ impl JoinNode {
pub fn second(&self) -> MastNodeId {
self.children[1]
}
}

impl JoinNode {
pub(super) fn to_pretty_print<'a>(
&'a self,
mast_forest: &'a MastForest,
) -> impl PrettyPrint + 'a {
pub fn digest(&self) -> RpoDigest {
self.digest
}

pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
JoinNodePrettyPrint {
join_node: self,
mast_forest,
}
}
}

impl MerkleTreeNode for JoinNode {
fn digest(&self) -> RpoDigest {
self.digest
}
// PRETTY PRINTING
// ================================================================================================

fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
impl JoinNode {
pub(super) fn to_pretty_print<'a>(
&'a self,
mast_forest: &'a MastForest,
) -> impl PrettyPrint + 'a {
JoinNodePrettyPrint {
join_node: self,
mast_forest,
Expand Down
36 changes: 21 additions & 15 deletions core/src/mast/node/loop_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ use miden_formatting::prettier::PrettyPrint;

use crate::{
chiplets::hasher,
mast::{MastForest, MastNodeId, MerkleTreeNode},
mast::{MastForest, MastNodeId},
OPCODE_LOOP,
};

// LOOP NODE
// ================================================================================================

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LoopNode {
body: MastNodeId,
Expand All @@ -32,30 +35,33 @@ impl LoopNode {

Self { body, digest }
}

pub(super) fn to_pretty_print<'a>(
&'a self,
mast_forest: &'a MastForest,
) -> impl PrettyPrint + 'a {
LoopNodePrettyPrint {
loop_node: self,
mast_forest,
}
}
}

impl LoopNode {
pub fn body(&self) -> MastNodeId {
self.body
}
}

impl MerkleTreeNode for LoopNode {
fn digest(&self) -> RpoDigest {
pub fn digest(&self) -> RpoDigest {
self.digest
}

fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
LoopNodePrettyPrint {
loop_node: self,
mast_forest,
}
}
}

// PRETTY PRINTING
// ================================================================================================

impl LoopNode {
pub(super) fn to_pretty_print<'a>(
&'a self,
mast_forest: &'a MastForest,
) -> impl PrettyPrint + 'a {
LoopNodePrettyPrint {
loop_node: self,
mast_forest,
Expand Down
11 changes: 3 additions & 8 deletions core/src/mast/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ mod loop_node;
pub use loop_node::LoopNode;

use crate::{
mast::{MastForest, MastNodeId, MerkleTreeNode},
mast::{MastForest, MastNodeId},
DecoratorList, Operation,
};

Expand Down Expand Up @@ -140,13 +140,8 @@ impl MastNode {
MastNode::External(_) => panic!("Can't fetch domain for an `External` node."),
}
}
}

// ------------------------------------------------------------------------------------------------
// MerkleTreeNode impl

impl MerkleTreeNode for MastNode {
fn digest(&self) -> RpoDigest {
pub fn digest(&self) -> RpoDigest {
match self {
MastNode::Block(node) => node.digest(),
MastNode::Join(node) => node.digest(),
Expand All @@ -158,7 +153,7 @@ impl MerkleTreeNode for MastNode {
}
}

fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
match self {
MastNode::Block(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
MastNode::Join(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
Expand Down
Loading

0 comments on commit f1c1282

Please sign in to comment.