Skip to content

Commit

Permalink
feat(core): Implement Multi forest iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippGackstatter committed Oct 23, 2024
1 parent 1af7c77 commit 06cdb21
Show file tree
Hide file tree
Showing 5 changed files with 662 additions and 414 deletions.
297 changes: 140 additions & 157 deletions core/src/mast/merger/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use alloc::{collections::BTreeMap, vec::Vec};
use core::ops::ControlFlow;

use miden_crypto::hash::{blake::Blake3Digest, rpo::RpoDigest};

use crate::mast::{DecoratorId, EqHash, MastForest, MastForestError, MastNode, MastNodeId};
use crate::mast::{
DecoratorId, EqHash, MastForest, MastForestError, MastNode, MastNodeId, MultiMastForestNodeIter,
};

#[cfg(test)]
mod tests;
Expand All @@ -17,43 +18,84 @@ pub(crate) struct MastForestMerger {
node_id_by_hash: BTreeMap<RpoDigest, Vec<(EqHash, MastNodeId)>>,
hash_by_node_id: BTreeMap<MastNodeId, EqHash>,
decorators_by_hash: BTreeMap<Blake3Digest<32>, DecoratorId>,
decorator_id_mappings: Vec<DecoratorIdMap>,
node_id_mappings: Vec<MastForestNodeIdMap>,
}

impl MastForestMerger {
/// Creates a new merger which creates a new internal, empty forest into which other
/// [`MastForest`]s are merged.
pub(crate) fn new() -> Self {
Self {
pub(crate) fn merge<'forest>(
forests: impl IntoIterator<Item = &'forest MastForest>,
) -> Result<(MastForest, Vec<MastForestRootMap>), MastForestError> {
let forests = forests.into_iter().collect::<Vec<_>>();
let decorator_id_mappings = Vec::with_capacity(forests.len());
let node_id_mappings = vec![MastForestNodeIdMap::new(); forests.len()];

let mut merger = Self {
node_id_by_hash: BTreeMap::new(),
hash_by_node_id: BTreeMap::new(),
decorators_by_hash: BTreeMap::new(),
mast_forest: MastForest::new(),
decorator_id_mappings,
node_id_mappings,
};

merger.merge_inner(forests.clone())?;

let Self { mast_forest, node_id_mappings, .. } = merger;

let mut root_maps = Vec::new();
for (forest_idx, mapping) in node_id_mappings.into_iter().enumerate() {
let forest = forests[forest_idx];
root_maps.push(MastForestRootMap::from_node_id_map(mapping, &forest.roots));
}

Ok((mast_forest, root_maps))
}

/// Merges `other_forest` into the forest contained in self.
pub(crate) fn merge(
fn merge_inner<'forest>(
&mut self,
other_forest: &MastForest,
) -> Result<MastForestRootMap, MastForestError> {
let mut decorator_id_remapping = DecoratorIdMap::new(other_forest.decorators.len());
let mut node_id_remapping = MastForestNodeIdMap::new();
forests: Vec<&'forest MastForest>,
) -> Result<(), MastForestError> {
for other_forest in forests.iter() {
self.merge_decorators(other_forest)?;
}

self.merge_decorators(other_forest, &mut decorator_id_remapping)?;
self.merge_nodes(other_forest, &decorator_id_remapping, &mut node_id_remapping)?;
self.merge_roots(other_forest, &node_id_remapping)?;
let iterator = MultiMastForestNodeIter::new(forests.clone());
for item in iterator {
match item {
super::MultiMastForestIteratorItem::Regular { forest_idx, node_id } => {
let node = &forests[forest_idx][node_id];
self.merge_node(forest_idx, node_id, node)?;
},
super::MultiMastForestIteratorItem::ExternalNodeReplacement {
replacement_forest_idx,
replacement_mast_node_id,
replaced_forest_idx,
replaced_mast_node_id,
} => {
let mapped_replacement = self.node_id_mappings[replacement_forest_idx]
.get(&replacement_mast_node_id)
.copied()
.expect("every node should be mapped");

self.node_id_mappings[replaced_forest_idx]
.insert(replaced_mast_node_id, mapped_replacement);
},
}
}

let root_map =
MastForestRootMap::from_node_id_map(node_id_remapping, other_forest.roots.as_slice());
for (forest_idx, forest) in forests.iter().enumerate() {
self.merge_roots(forest_idx, &forest)?;
}

Ok(root_map)
Ok(())
}

fn merge_decorators(
&mut self,
other_forest: &MastForest,
decorator_id_remapping: &mut DecoratorIdMap,
) -> Result<(), MastForestError> {
fn merge_decorators(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
let mut decorator_id_remapping = DecoratorIdMap::new(other_forest.decorators.len());

for (merging_id, merging_decorator) in other_forest.decorators.iter().enumerate() {
let merging_decorator_hash = merging_decorator.eq_hash();
let new_decorator_id = if let Some(existing_decorator) =
Expand All @@ -70,54 +112,76 @@ impl MastForestMerger {
.insert(DecoratorId::new_unsafe(merging_id as u32), new_decorator_id);
}

self.decorator_id_mappings.push(decorator_id_remapping);

Ok(())
}

fn merge_nodes(
fn merge_node(
&mut self,
other_forest: &MastForest,
decorator_id_remapping: &DecoratorIdMap,
node_id_remapping: &mut MastForestNodeIdMap,
forest_idx: usize,
merging_id: MastNodeId,
node: &MastNode,
) -> Result<(), MastForestError> {
for (merging_id, node) in other_forest.iter_nodes() {
// We need to remap the node prior to computing the EqHash.
//
// This is because the EqHash computation looks up its descendants and decorators in
// the internal index, and if we were to pass the original node to that
// computation, it would look up the incorrect descendants and decorators (since the
// descendant's indices may have changed).
//
// Remapping at this point is guaranteed to be "complete", meaning all ids of children
// will be present in `node_id_remapping` since the DFS iteration guarantees
// that all children of this `node` have been processed before this node and
// their indices have been added to the mappings.
let remapped_node = self.remap_node(node, decorator_id_remapping, node_id_remapping)?;

let node_eq =
EqHash::from_mast_node(&self.mast_forest, &self.hash_by_node_id, &remapped_node);

match self.merge_external_nodes(
merging_id,
&node_eq,
&remapped_node,
node_id_remapping,
)? {
// Continue is interpreted as doing nothing.
ControlFlow::Continue(_) => (),
// Break is interpreted as continue in the loop sense.
ControlFlow::Break(_) => continue,
}
// We need to remap the node prior to computing the EqHash.
//
// This is because the EqHash computation looks up its descendants and decorators in
// the internal index, and if we were to pass the original node to that
// computation, it would look up the incorrect descendants and decorators (since the
// descendant's indices may have changed).
//
// Remapping at this point is guaranteed to be "complete", meaning all ids of children
// will be present in `node_id_remapping` since the DFS iteration guarantees
// that all children of this `node` have been processed before this node and
// their indices have been added to the mappings.
let remapped_node = self.remap_node(forest_idx, node)?;

let node_fingerprint =
EqHash::from_mast_node(&self.mast_forest, &self.hash_by_node_id, &remapped_node);

// If no node with a matching root exists, then the merging node is unique and we can add it
// to the merged forest.
let Some(matching_nodes) = self.lookup_all_nodes_by_root(&node_fingerprint.mast_root)
else {
return self.add_merged_node(forest_idx, merging_id, remapped_node, node_fingerprint);
};

// If an external node was previously replaced by the remapped node, this will detect
// them as duplicates here if their fingerprints match exactly and add the appropriate
// mapping from the merging id to the existing id.
match self.lookup_node_by_fingerprint(&node_eq) {
Some((_, existing_node_id)) => {
// We have to map any occurence of `merging_id` to `existing_node_id`.
node_id_remapping.insert(merging_id, *existing_node_id);
if remapped_node.is_external() {
// If there already is _any_ node with the same MAST root, map the merging
// external node to that existing one.
let (_, existing_external_node_id) = matching_nodes
.first()
.copied()
.expect("we should never insert empty entries in the internal index");
self.node_id_mappings[forest_idx].insert(merging_id, existing_external_node_id);
} else {
// It should never be the case that the MAST root of the merging node matches
// the referenced MAST root of an External node in the merged forest due to the
// preprocessing of external nodes.
debug_assert!(matching_nodes.into_iter().all(|(_, matching_node_id)| {
!self.mast_forest[*matching_node_id].is_external()
}));

match matching_nodes
.into_iter()
.find_map(|(matching_node_fingerprint, node_id)| {
if matching_node_fingerprint == &node_fingerprint {
Some(node_id)
} else {
None
}
})
.copied()
{
Some(matching_node_id) => {
// If a node with a matching fingerprint exists, then the merging node is a
// duplicate and we remap it to the existing node.
self.node_id_mappings[forest_idx].insert(merging_id, matching_node_id);
},
None => {
self.add_merged_node(merging_id, remapped_node, node_id_remapping, node_eq)?;
// If no node with a matching fingerprint exists, then the merging node is
// unique and we can add it to the merged forest.
self.add_merged_node(forest_idx, merging_id, remapped_node, node_fingerprint)?;
},
}
}
Expand All @@ -127,13 +191,14 @@ impl MastForestMerger {

fn merge_roots(
&mut self,
forest_idx: usize,
other_forest: &MastForest,
node_id_remapping: &MastForestNodeIdMap,
) -> Result<(), MastForestError> {
for root_id in other_forest.roots.iter() {
// Map the previous root to its possibly new id.
let new_root =
node_id_remapping.get(root_id).expect("all node ids should have an entry");
let new_root = self.node_id_mappings[forest_idx]
.get(root_id)
.expect("all node ids should have an entry");
// This will take O(n) every time to check if the root already exists.
// We could improve this by keeping a BTreeSet<MastNodeId> of existing roots during
// merging for a faster check.
Expand All @@ -145,13 +210,13 @@ impl MastForestMerger {

fn add_merged_node(
&mut self,
forest_idx: usize,
previous_id: MastNodeId,
node: MastNode,
node_id_remapping: &mut MastForestNodeIdMap,
node_eq: EqHash,
) -> Result<(), MastForestError> {
let new_node_id = self.mast_forest.add_node(node)?;
node_id_remapping.insert(previous_id, new_node_id);
self.node_id_mappings[forest_idx].insert(previous_id, new_node_id);

// We need to update the indices with the newly inserted nodes
// since the EqHash computation requires all descendants of a node
Expand All @@ -168,79 +233,23 @@ impl MastForestMerger {
Ok(())
}

/// This will handle two cases:
///
/// - The existing forest contains a node (external or non-external) with MAST root `foo` and
/// the merging External node refers to `foo`. In this case, the merging node will be mapped
/// to the existing node and dropped.
/// - The existing forest contains an External nodes with a MAST root `foo` and the non-external
/// merging node's digest is `foo`. In this case, the existing external node will be replaced
/// by the merging node.
///
/// Returns whether the caller should continue in their code path for this node or skip it.
fn merge_external_nodes(
&mut self,
previous_id: MastNodeId,
node_eq: &EqHash,
remapped_node: &MastNode,
node_id_remapping: &mut MastForestNodeIdMap,
) -> Result<ControlFlow<()>, MastForestError> {
if remapped_node.is_external() {
match self.lookup_node_by_root(&node_eq.mast_root) {
// If there already is any node with the same MAST root, map the merging external
// node to that existing one.
// This code path is also entered if the fingerprints match, so we can skip the
// general merging case by returning `Break`.
Some((_, existing_external_node_id)) => {
node_id_remapping.insert(previous_id, *existing_external_node_id);
Ok(ControlFlow::Break(()))
},
// If no duplicate for the external node exists do nothing as `merge_nodes`
// will simply add the node to the forest.
None => Ok(ControlFlow::Continue(())),
}
} else {
// Replace an external node in self with the given MAST root with the non-external
// node from the merging forest.
// Any node in the existing forest that pointed to the external node will
// have the same MAST root due to the semantics of external nodes.
match self.lookup_external_node_by_root(&node_eq.mast_root) {
Some((_, external_node_id)) => {
self.mast_forest[external_node_id] = remapped_node.clone();
node_id_remapping.insert(previous_id, external_node_id);
// The other branch of this function guarantees that no external and
// non-external node with the same MAST root exist in the
// merged forest, so if we found an external node with a
// given MAST root, it must be the only one in the merged
// forest, so we can skip the remainder of the `merge_nodes` code path.
Ok(ControlFlow::Break(()))
},
// If we did not find a matching node, we can continue in the `merge_nodes` code
// path.
None => Ok(ControlFlow::Continue(())),
}
}
}

/// Remaps a nodes' potentially contained children and decorators to their new IDs according to
/// the given maps.
fn remap_node(
&self,
node: &MastNode,
decorator_id_remapping: &DecoratorIdMap,
node_id_remapping: &MastForestNodeIdMap,
) -> Result<MastNode, MastForestError> {
fn remap_node(&self, forest_idx: usize, node: &MastNode) -> Result<MastNode, MastForestError> {
let map_decorator_id = |decorator_id: &DecoratorId| {
decorator_id_remapping.get(decorator_id).ok_or_else(|| {
MastForestError::DecoratorIdOverflow(*decorator_id, decorator_id_remapping.len())
self.decorator_id_mappings[forest_idx].get(decorator_id).ok_or_else(|| {
MastForestError::DecoratorIdOverflow(
*decorator_id,
self.decorator_id_mappings[forest_idx].len(),
)
})
};
let map_decorators = |decorators: &[DecoratorId]| -> Result<Vec<_>, MastForestError> {
decorators.iter().map(map_decorator_id).collect()
};

let map_node_id = |node_id: MastNodeId| {
node_id_remapping
self.node_id_mappings[forest_idx]
.get(&node_id)
.copied()
.expect("every node id should have an entry")
Expand Down Expand Up @@ -309,34 +318,8 @@ impl MastForestMerger {
// HELPERS
// ================================================================================================

fn lookup_node_by_fingerprint(&self, eq_hash: &EqHash) -> Option<&(EqHash, MastNodeId)> {
self.node_id_by_hash.get(&eq_hash.mast_root).and_then(|node_ids| {
node_ids.iter().find(|(node_fingerprint, _)| node_fingerprint == eq_hash)
})
}

fn lookup_node_by_root(&self, mast_root: &RpoDigest) -> Option<&(EqHash, MastNodeId)> {
self.node_id_by_hash.get(mast_root).and_then(|node_ids| node_ids.first())
}

fn lookup_external_node_by_root(&self, mast_root: &RpoDigest) -> Option<(EqHash, MastNodeId)> {
self.node_id_by_hash.get(mast_root).and_then(|ids| {
let mut iterator = ids
.iter()
.filter(|(_, node_id)| self.mast_forest[*node_id].is_external())
.copied();
let external_node = iterator.next();
// The merging implementation should guarantee that no two external nodes with the same
// MAST root exist.
debug_assert!(iterator.next().is_none());
external_node
})
}
}

impl From<MastForestMerger> for MastForest {
fn from(merger: MastForestMerger) -> Self {
merger.mast_forest
fn lookup_all_nodes_by_root(&self, mast_root: &RpoDigest) -> Option<&[(EqHash, MastNodeId)]> {
self.node_id_by_hash.get(mast_root).map(|node_ids| node_ids.as_slice())
}
}

Expand Down
Loading

0 comments on commit 06cdb21

Please sign in to comment.