Skip to content

Commit

Permalink
feat(core): Add test cases for external nodes with decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippGackstatter committed Oct 18, 2024
1 parent e9ffd5d commit 4a97c2d
Show file tree
Hide file tree
Showing 2 changed files with 241 additions and 19 deletions.
86 changes: 67 additions & 19 deletions core/src/mast/merger/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl MastForestMerger {
/// Merges `other_forest` into the forest contained in self.
pub(crate) fn merge(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
let mut decorator_id_remapping = ForestIdMap::new(other_forest.decorators.len());
let mut node_id_remapping = ForestIdMap::new(other_forest.nodes.len());
let mut node_id_remapping = MastForestIdMap::new();

self.merge_decorators(other_forest, &mut decorator_id_remapping)?;
self.merge_nodes(other_forest, &decorator_id_remapping, &mut node_id_remapping)?;
Expand Down Expand Up @@ -104,7 +104,7 @@ impl MastForestMerger {
&mut self,
other_forest: &MastForest,
decorator_id_remapping: &ForestIdMap<DecoratorId>,
node_id_remapping: &mut ForestIdMap<MastNodeId>,
node_id_remapping: &mut MastForestIdMap,
) -> Result<(), MastForestError> {
for (merging_id, node) in other_forest.iter_nodes() {
// We need to remap the node prior to computing the EqHash.
Expand All @@ -122,7 +122,13 @@ impl MastForestMerger {
let node_eq =
EqHash::from_mast_node(&self.mast_forest, &self.hash_by_node_id, &remapped_node);

self.merge_external_nodes(merging_id, &node_eq, &remapped_node, node_id_remapping)?;
self.merge_external_nodes(
merging_id,
&node_eq,
&remapped_node,
node_id_remapping,
decorator_id_remapping,
)?;

// If an external node was previously replaced by the remapped node, this will detect
// them as duplicates here if their fingerprints match exactly and add the appropriate
Expand All @@ -144,15 +150,16 @@ impl MastForestMerger {
fn merge_roots(
&mut self,
other_forest: &MastForest,
node_id_remapping: &ForestIdMap<MastNodeId>,
node_id_remapping: &MastForestIdMap,
) -> Result<(), MastForestError> {
for root_id in other_forest.roots.iter() {
// Map the previous root to its possibly new id.
let new_root = node_id_remapping.get(root_id);
let new_root =
node_id_remapping.get(root_id).expect("all node ids should have an entry");
// This will take O(n) every time to check if the root already exists.
// We could improve this by keeping a BTreeSet<MastNodeId> of existing roots during
// merging for a faster check.
self.mast_forest.make_root(new_root);
self.mast_forest.make_root(*new_root);
}

Ok(())
Expand All @@ -162,7 +169,7 @@ impl MastForestMerger {
&mut self,
previous_id: MastNodeId,
node: MastNode,
node_id_remapping: &mut ForestIdMap<MastNodeId>,
node_id_remapping: &mut MastForestIdMap,
node_eq: EqHash,
) -> Result<(), MastForestError> {
let new_node_id = self.mast_forest.add_node(node)?;
Expand Down Expand Up @@ -200,7 +207,8 @@ impl MastForestMerger {
previous_id: MastNodeId,
node_eq: &EqHash,
remapped_node: &MastNode,
node_id_remapping: &mut ForestIdMap<MastNodeId>,
node_id_remapping: &mut MastForestIdMap,
decorator_id_remapping: &ForestIdMap<DecoratorId>,
) -> Result<(), MastForestError> {
// Handle external node in the merging forest.
if remapped_node.is_external() {
Expand All @@ -212,6 +220,7 @@ impl MastForestMerger {
*node_eq,
&remapped_node,
*referenced_node_id,
decorator_id_remapping,
node_id_remapping,
)?;
},
Expand All @@ -232,6 +241,7 @@ impl MastForestMerger {
external_node_id,
&node_eq,
&remapped_node,
decorator_id_remapping,
)?;
}
}
Expand All @@ -247,13 +257,18 @@ impl MastForestMerger {
external_node_fingerprint: EqHash,
external_node: &MastNode,
referenced_node_id: MastNodeId,
node_id_remapping: &mut ForestIdMap<MastNodeId>,
decorator_id_remapping: &ForestIdMap<DecoratorId>,
node_id_remapping: &mut MastForestIdMap,
) -> Result<(), MastForestError> {
let referenced_node = &self.mast_forest[referenced_node_id];

let map_decorators = |decorators: &[DecoratorId]| {
decorators.iter().map(|deco| decorator_id_remapping.get(deco)).collect()
};

let new_node = self.merge_external_node_decorators(
external_node.before_enter(),
external_node.after_exit(),
map_decorators(external_node.before_enter()),
map_decorators(external_node.after_exit()),
referenced_node,
);

Expand All @@ -272,6 +287,7 @@ impl MastForestMerger {
external_node_id: MastNodeId,
merging_node_fingerprint: &EqHash,
merging_node: &MastNode,
decorator_id_remapping: &ForestIdMap<DecoratorId>,
) -> Result<(), MastForestError> {
// Special case: If the fingerprints match, we can replace directly.
// Note that the id mapping for the merging node will be updated in `merge_nodes`.
Expand All @@ -283,9 +299,13 @@ impl MastForestMerger {

let external_node = &self.mast_forest[external_node_id];

let map_decorators = |decorators: &[DecoratorId]| {
decorators.iter().map(|deco| decorator_id_remapping.get(deco)).collect()
};

let replacement_node = self.merge_external_node_decorators(
external_node.before_enter(),
external_node.after_exit(),
map_decorators(external_node.before_enter()),
map_decorators(external_node.before_enter()),
merging_node,
);

Expand All @@ -302,8 +322,8 @@ impl MastForestMerger {
/// - The reference node is an external node.
fn merge_external_node_decorators(
&self,
external_node_before_enter: &[DecoratorId],
external_node_after_exit: &[DecoratorId],
external_node_before_enter: Vec<DecoratorId>,
external_node_after_exit: Vec<DecoratorId>,
reference_node: &MastNode,
) -> MastNode {
let mut node = match reference_node {
Expand All @@ -317,8 +337,10 @@ impl MastForestMerger {
_ => reference_node.clone(),
};

node.set_before_enter(external_node_before_enter.to_vec());
node.set_after_exit(external_node_after_exit.to_vec());
node.set_before_enter(external_node_before_enter);
node.set_after_exit(external_node_after_exit);

// TODO: Call remap_node here instead of passing in mapped decorators.

node
}
Expand All @@ -329,13 +351,18 @@ impl MastForestMerger {
&self,
node: &MastNode,
decorator_id_remapping: &ForestIdMap<DecoratorId>,
node_id_remapping: &ForestIdMap<MastNodeId>,
node_id_remapping: &MastForestIdMap,
) -> MastNode {
let map_decorator_id =
|decorator_id: &DecoratorId| decorator_id_remapping.get(decorator_id);
let map_decorators =
|decorators: &[DecoratorId]| decorators.iter().map(map_decorator_id).collect();
let map_node_id = |node_id: MastNodeId| node_id_remapping.get(&node_id);
let map_node_id = |node_id: MastNodeId| {
node_id_remapping
.get(&node_id)
.copied()
.expect("every node id should have an entry")
};

// Due to DFS postorder iteration all children of node's should have been inserted before
// their parents which is why we can `expect` the constructor calls here.
Expand Down Expand Up @@ -440,6 +467,27 @@ impl From<MastForestMerger> for MastForest {
// MAST FOREST ID MAP
// ================================================================================================

pub struct MastForestIdMap {
map: BTreeMap<MastNodeId, MastNodeId>,
}

impl MastForestIdMap {
pub(crate) fn new() -> Self {
Self { map: BTreeMap::new() }
}

pub(crate) fn insert(&mut self, key: MastNodeId, value: MastNodeId) {
self.map.insert(key, value);
}

pub fn get(&self, key: &MastNodeId) -> Option<&MastNodeId> {
self.map.get(key)
}
}

// MAST FOREST ID MAP
// ================================================================================================

/// A specialized map from ID -> ID meant to be used with [`DecoratorId`] or [`MastNodeId`].
///
/// When mapping Decorator or Mast Node IDs during merging, we always map all IDs of the merging
Expand Down
174 changes: 174 additions & 0 deletions core/src/mast/merger/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,177 @@ fn mast_forest_merge_decorators() {
1
);
}

impl MastForest {
fn debug_print(&self) {
for (idx, node) in self.nodes().iter().enumerate() {
std::println!("Node {idx}\n{}\n", node.to_display(self));
}
}
}

/// TODO
///
/// [External(foo)]
/// +
/// [Block(foo, [Trace(1)])]
/// =
/// [Block(foo, [Trace(1)])]
/// +
/// [External(foo)]
/// =
/// [Block(foo, [Trace(1)])]
#[test]
fn mast_forest_merge_external_node_referenced_node_has_decorator() {
let mut forest_a = MastForest::new();
let trace = Decorator::Trace(1);

// Build Forest A
let deco = forest_a.add_decorator(trace.clone()).unwrap();

let mut foo_node_a = block_foo();
foo_node_a.set_before_enter(vec![deco]);
let foo_node_digest = foo_node_a.digest();
let id_foo_a = forest_a.add_node(foo_node_a).unwrap();

forest_a.make_root(id_foo_a);

// Build Forest B
let mut forest_b = MastForest::new();
let id_external_b = forest_b.add_external(foo_node_digest).unwrap();

forest_b.make_root(id_external_b);

for merged in [forest_a.merge(&forest_b).unwrap(), forest_b.merge(&forest_a).unwrap()] {
let id_foo_a_fingerprint =
EqHash::from_mast_node(&forest_a, &BTreeMap::new(), &forest_a[id_foo_a]);
let id_external_b_fingerprint =
EqHash::from_mast_node(&forest_b, &BTreeMap::new(), &forest_b[id_external_b]);

let fingerprints: Vec<_> = merged
.nodes()
.iter()
.map(|node| EqHash::from_mast_node(&merged, &BTreeMap::new(), node))
.collect();

assert_eq!(merged.nodes.len(), 2);
assert!(fingerprints.contains(&id_foo_a_fingerprint));
assert!(fingerprints.contains(&id_external_b_fingerprint));
}
}

#[test]
fn mast_forest_merge_external_node_has_decorator() {
let mut forest_a = MastForest::new();
let trace1 = Decorator::Trace(1);
let trace2 = Decorator::Trace(2);

// Build Forest A
let deco1 = forest_a.add_decorator(trace1.clone()).unwrap();
let deco2 = forest_a.add_decorator(trace2.clone()).unwrap();

let mut external_node_a = MastNode::new_external(block_foo().digest());
external_node_a.set_before_enter(vec![deco1]);
external_node_a.set_after_exit(vec![deco2]);
let id_external_a = forest_a.add_node(external_node_a).unwrap();

forest_a.make_root(id_external_a);

// Build Forest B
let mut forest_b = MastForest::new();
let id_foo_b = forest_b.add_node(block_foo()).unwrap();

forest_b.make_root(id_foo_b);

let foo_digest = block_foo().digest();

for merged in [forest_a.merge(&forest_b).unwrap(), forest_b.merge(&forest_a).unwrap()] {
assert_eq!(merged.nodes.len(), 2);
assert_eq!(
merged
.nodes()
.iter()
.filter(|node| {
let MastNode::Block(block) = node else {
panic!("expected only blocks");
};

block.digest() == foo_digest
&& block.decorators()
== &[(0, deco1), (block.num_operations() as usize, deco2)]
})
.count(),
1
);

let id_foo_b_fingerprint =
EqHash::from_mast_node(&forest_a, &BTreeMap::new(), &forest_b[id_foo_b]);

let fingerprints: Vec<_> = merged
.nodes()
.iter()
.map(|node| EqHash::from_mast_node(&merged, &BTreeMap::new(), node))
.collect();

// Block foo should be unmodified.
assert!(fingerprints.contains(&id_foo_b_fingerprint));
}
}

#[test]
fn mast_forest_merge_external_node_and_referenced_node_have_decorators() {
let mut forest_a = MastForest::new();
let trace1 = Decorator::Trace(1);
let trace2 = Decorator::Trace(2);

// Build Forest A
let deco1_a = forest_a.add_decorator(trace1.clone()).unwrap();

let mut external_node_a = MastNode::new_external(block_foo().digest());
external_node_a.set_before_enter(vec![deco1_a]);
let id_external_a = forest_a.add_node(external_node_a).unwrap();

forest_a.make_root(id_external_a);

// Build Forest B
let mut forest_b = MastForest::new();
let deco2_b = forest_a.add_decorator(trace2.clone()).unwrap();

let mut foo_node_b = block_foo();
foo_node_b.set_before_enter(vec![deco2_b]);
let id_foo_b = forest_b.add_node(foo_node_b).unwrap();

forest_b.make_root(id_foo_b);

let foo_digest = block_foo().digest();

for merged in [forest_a.merge(&forest_b).unwrap(), forest_b.merge(&forest_a).unwrap()] {
assert_eq!(merged.nodes.len(), 2);
assert_eq!(
merged
.nodes()
.iter()
.filter(|node| {
let MastNode::Block(block) = node else {
panic!("expected only blocks");
};

block.digest() == foo_digest && block.decorators() == &[(0, deco1_a)]
})
.count(),
1
);

let id_foo_b_fingerprint =
EqHash::from_mast_node(&forest_a, &BTreeMap::new(), &forest_b[id_foo_b]);

let fingerprints: Vec<_> = merged
.nodes()
.iter()
.map(|node| EqHash::from_mast_node(&merged, &BTreeMap::new(), node))
.collect();

// Block foo should be unmodified.
assert!(fingerprints.contains(&id_foo_b_fingerprint));
}
}

0 comments on commit 4a97c2d

Please sign in to comment.