From c00fef2615e889cbec7b0b43f6e67aedb81741b4 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 15 Aug 2024 10:05:46 -0400 Subject: [PATCH] feat: implement dead code elimination --- .../src/assembler/dead_code_elimination.rs | 155 ++++++++++++++++++ assembly/src/assembler/mod.rs | 31 ++-- assembly/src/tests.rs | 3 + core/src/mast/mod.rs | 20 +++ 4 files changed, 198 insertions(+), 11 deletions(-) create mode 100644 assembly/src/assembler/dead_code_elimination.rs diff --git a/assembly/src/assembler/dead_code_elimination.rs b/assembly/src/assembler/dead_code_elimination.rs new file mode 100644 index 000000000..36e15268a --- /dev/null +++ b/assembly/src/assembler/dead_code_elimination.rs @@ -0,0 +1,155 @@ +use alloc::{ + collections::{BTreeMap, BTreeSet}, + vec::Vec, +}; + +use vm_core::mast::{MastForest, MastNode, MastNodeId}; + +/// Returns a `MastForest` where all nodes that are unreachable from all procedures are removed. +/// +/// It also returns the map from old node IDs to new node IDs; or `None` if the `MastForest` was +/// unchanged. Any [`MastNodeId`] used in reference to the old [`MastForest`] should be remapped +/// using this map. +pub fn dead_code_elimination( + mast_forest: MastForest, +) -> (MastForest, Option>) { + let live_ids = compute_live_ids(&mast_forest); + if live_ids.len() == mast_forest.num_nodes() as usize { + return (mast_forest, None); + } + + let (old_nodes, old_roots) = mast_forest.into_parts(); + let (live_nodes, id_remappings) = prune_nodes(old_nodes, live_ids); + + ( + build_pruned_mast_forest(live_nodes, old_roots, &id_remappings), + Some(id_remappings), + ) +} + +/// Compute all [`MastNodeId`]s that are "live"; that is, accessed by at least one procedure in the +/// MAST forest. +fn compute_live_ids(mast_forest: &MastForest) -> BTreeSet { + let mut live_ids = BTreeSet::new(); + + for &procedure_root in mast_forest.procedure_roots() { + compute_live_ids_for_node(procedure_root, mast_forest, &mut live_ids); + } + + live_ids +} + +fn compute_live_ids_for_node( + mast_node_id: MastNodeId, + mast_forest: &MastForest, + live_ids: &mut BTreeSet, +) { + live_ids.insert(mast_node_id); + + match &mast_forest[mast_node_id] { + MastNode::Join(node) => { + compute_live_ids_for_node(node.first(), mast_forest, live_ids); + compute_live_ids_for_node(node.second(), mast_forest, live_ids); + }, + MastNode::Split(node) => { + compute_live_ids_for_node(node.on_true(), mast_forest, live_ids); + compute_live_ids_for_node(node.on_false(), mast_forest, live_ids); + }, + MastNode::Loop(node) => { + compute_live_ids_for_node(node.body(), mast_forest, live_ids); + }, + MastNode::Call(node) => { + compute_live_ids_for_node(node.callee(), mast_forest, live_ids); + }, + MastNode::Block(_) | MastNode::Dyn | MastNode::External(_) => (), + } +} + +/// Returns the set of nodes that are live, as well as the mapping from "old ID" to "new ID" for all +/// live nodes. +fn prune_nodes( + mast_nodes: Vec, + live_ids: BTreeSet, +) -> (Vec, BTreeMap) { + // Note: this allows us to safely use `usize as u32`, guaranteeing that it won't wrap around. + assert!(mast_nodes.len() < u32::MAX as usize); + + let mut pruned_nodes = Vec::with_capacity(mast_nodes.len()); + let mut id_remappings = BTreeMap::new(); + + for (old_node_index, old_node) in mast_nodes.into_iter().enumerate() { + let old_node_id: MastNodeId = (old_node_index as u32).into(); + + if live_ids.contains(&old_node_id) { + let new_node_id: MastNodeId = (pruned_nodes.len() as u32).into(); + id_remappings.insert(old_node_id, new_node_id); + + pruned_nodes.push(old_node); + } + } + + (pruned_nodes, id_remappings) +} + +/// Rewrites all [`MastNodeId`]s in the live nodes to the correct updated IDs using `id_remappings`, +/// which maps all old node IDs to new IDs. +fn build_pruned_mast_forest( + live_nodes: Vec, + old_root_ids: Vec, + id_remappings: &BTreeMap, +) -> MastForest { + let mut pruned_mast_forest = MastForest::new(); + + // Add each live node to the new MAST forest, making sure to rewrite any outdated internal + // `MastNodeId`s + for live_node in live_nodes { + match &live_node { + MastNode::Join(join_node) => { + let first_child = + id_remappings.get(&join_node.first()).copied().unwrap_or(join_node.first()); + let second_child = + id_remappings.get(&join_node.second()).copied().unwrap_or(join_node.second()); + + pruned_mast_forest.add_join(first_child, second_child).unwrap(); + }, + MastNode::Split(split_node) => { + let on_true_child = id_remappings + .get(&split_node.on_true()) + .copied() + .unwrap_or(split_node.on_true()); + let on_false_child = id_remappings + .get(&split_node.on_false()) + .copied() + .unwrap_or(split_node.on_false()); + + pruned_mast_forest.add_split(on_true_child, on_false_child).unwrap(); + }, + MastNode::Loop(loop_node) => { + let body_id = + id_remappings.get(&loop_node.body()).copied().unwrap_or(loop_node.body()); + + pruned_mast_forest.add_loop(body_id).unwrap(); + }, + MastNode::Call(call_node) => { + let callee_id = + id_remappings.get(&call_node.callee()).copied().unwrap_or(call_node.callee()); + + if call_node.is_syscall() { + pruned_mast_forest.add_syscall(callee_id).unwrap(); + } else { + pruned_mast_forest.add_call(callee_id).unwrap(); + } + }, + MastNode::Block(_) | MastNode::Dyn | MastNode::External(_) => { + pruned_mast_forest.add_node(live_node).unwrap(); + }, + } + } + + for old_root_id in old_root_ids { + let new_root_id = id_remappings.get(&old_root_id).copied().unwrap_or(old_root_id); + pruned_mast_forest.make_root(new_root_id); + } + + pruned_mast_forest +} diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 8070c4d2f..43918965a 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -1,5 +1,6 @@ use alloc::{collections::BTreeMap, sync::Arc, vec::Vec}; +use dead_code_elimination::dead_code_elimination; use mast_forest_builder::MastForestBuilder; use module_graph::{ProcedureWrapper, WrappedModule}; use vm_core::{mast::MastNodeId, DecoratorList, Felt, Kernel, Operation, Program}; @@ -14,11 +15,13 @@ use crate::{ }; mod basic_block_builder; +mod dead_code_elimination; mod id; mod instruction; mod mast_forest_builder; mod module_graph; mod procedure; + #[cfg(test)] mod tests; @@ -299,8 +302,8 @@ impl Assembler { }; // TODO: show a warning if library exports are empty? - - Ok(Library::new(mast_forest_builder.build(), exports)) + let (mast_forest, _) = dead_code_elimination(mast_forest_builder.build()); + Ok(Library::new(mast_forest, exports)) } /// Assembles the provided module into a [KernelLibrary] intended to be used as a Kernel. @@ -341,7 +344,8 @@ impl Assembler { // TODO: show a warning if library exports are empty? - let library = Library::new(mast_forest_builder.build(), exports); + let (mast_forest, _) = dead_code_elimination(mast_forest_builder.build()); + let library = Library::new(mast_forest, exports); Ok(library.try_into()?) } @@ -381,9 +385,18 @@ impl Assembler { .get_procedure(entrypoint) .expect("compilation succeeded but root not found in cache"); + let (mast_forest, id_remappings) = dead_code_elimination(mast_forest_builder.build()); + let entry_node_id = { + let old_entry_node_id = entry_procedure.body_node_id(); + + id_remappings + .map(|id_remappings| id_remappings[&old_entry_node_id]) + .unwrap_or(old_entry_node_id) + }; + Ok(Program::with_kernel( - mast_forest_builder.build(), - entry_procedure.body_node_id(), + mast_forest, + entry_node_id, self.module_graph.kernel().clone(), )) } @@ -708,7 +721,7 @@ fn merge_contiguous_basic_blocks( let mut contiguous_basic_block_ids: Vec = Vec::new(); for mast_node_id in mast_node_ids { - if mast_forest_builder.get_mast_node(mast_node_id).unwrap().is_basic_block() { + if mast_forest_builder[mast_node_id].is_basic_block() { contiguous_basic_block_ids.push(mast_node_id); } else { if let Some(merged_basic_block_id) = @@ -748,11 +761,7 @@ fn merge_basic_blocks( for &basic_block_node_id in contiguous_basic_block_ids { // It is safe to unwrap here, since we already checked that all IDs in // `contiguous_basic_block_ids` are `BasicBlockNode`s - let basic_block_node = mast_forest_builder - .get_mast_node(basic_block_node_id) - .unwrap() - .get_basic_block() - .unwrap(); + let basic_block_node = mast_forest_builder[basic_block_node_id].get_basic_block().unwrap(); for (op_idx, decorator) in basic_block_node.decorators() { decorators.push((*op_idx + operations.len(), decorator.clone())); diff --git a/assembly/src/tests.rs b/assembly/src/tests.rs index 0660b6c6f..8a726a913 100644 --- a/assembly/src/tests.rs +++ b/assembly/src/tests.rs @@ -153,6 +153,9 @@ begin basic_block mul add add add add add end end"; assert_str_eq!(format!("{}", program), expected); + + // Also ensure that dead code elimination works properly + assert_eq!(program.mast_forest().num_nodes(), 1); Ok(()) } diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index b594c0875..19e1bce25 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -180,6 +180,20 @@ impl MastForest { .try_into() .expect("MAST forest contains more than 2^32 procedures.") } + + /// Returns the number of nodes in this MAST forest. + pub fn num_nodes(&self) -> u32 { + self.nodes.len() as u32 + } +} + +/// Destructors +impl MastForest { + pub fn into_parts(self) -> (Vec, Vec) { + let Self { nodes, roots } = self; + + (nodes, roots) + } } impl Index for MastForest { @@ -252,6 +266,12 @@ impl From<&MastNodeId> for u32 { } } +impl From for MastNodeId { + fn from(value: u32) -> Self { + Self(value) + } +} + impl fmt::Display for MastNodeId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "MastNodeId({})", self.0)