Skip to content

Commit

Permalink
refactor: wrap MastForest in Program and Library in Arc (#1465)
Browse files Browse the repository at this point in the history
* refactor: wrap MastForest in Program and Library in Arc
* fix: enforce that program entrypoints are procedure roots
* refactor: add external nodes to MastForest for re-exports
  • Loading branch information
bobbinth authored Aug 27, 2024
1 parent d7c8933 commit b192c61
Show file tree
Hide file tree
Showing 24 changed files with 448 additions and 313 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## 0.11.0 (TBD)

#### Changes

- [BREAKING] Wrapped `MastForest`s in `Program` and `Library` structs in `Arc` (#1465).


## 0.10.5 (2024-08-21)

#### Enhancements
Expand Down
82 changes: 41 additions & 41 deletions assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
sync::Arc,
vec::Vec,
};
use core::ops::Index;

use vm_core::{
crypto::hash::RpoDigest,
Expand All @@ -24,14 +22,35 @@ const PROCEDURE_INLINING_THRESHOLD: usize = 32;
// ================================================================================================

/// Builder for a [`MastForest`].
///
/// The purpose of the builder is to ensure that the underlying MAST forest contains only the
/// minimal information needed to adequately describe the logical MAST forest. Specifically:
/// - The builder ensures that only one copy of a given node exists in the MAST forest (i.e., no two
/// nodes have the same hash).
/// - The builder tries to merge adjacent basic blocks and eliminate the source block whenever this
/// does not have an impact on other nodes in the forest.
#[derive(Clone, Debug, Default)]
pub struct MastForestBuilder {
/// The MAST forest being built by this builder; this MAST forest is up-to-date - i.e., all
/// nodes added to the MAST forest builder are also immediately added to the underlying MAST
/// forest.
mast_forest: MastForest,
/// A map of MAST node digests to their corresponding positions in the MAST forest. It is
/// guaranteed that a given digest maps to exactly one node in the MAST forest.
node_id_by_hash: BTreeMap<RpoDigest, MastNodeId>,
procedures: BTreeMap<GlobalProcedureIndex, Arc<Procedure>>,
procedure_hashes: BTreeMap<GlobalProcedureIndex, RpoDigest>,
/// A map of all procedures added to the MAST forest indexed by their global procedure ID.
/// This includes all local, exported, and re-exported procedures. In case multiple procedures
/// with the same digest are added to the MAST forest builder, only the first procedure is
/// added to the map, and all subsequent insertions are ignored.
procedures: BTreeMap<GlobalProcedureIndex, Procedure>,
/// A map from procedure MAST root to its global procedure index. Similar to the `procedures`
/// map, this map contains only the first inserted procedure for procedures with the same MAST
/// root.
proc_gid_by_hash: BTreeMap<RpoDigest, GlobalProcedureIndex>,
merged_node_ids: BTreeSet<MastNodeId>,
/// A set of IDs for basic blocks which have been merged into bigger basic blocks. This is
/// used as a candidate set of nodes that may be eliminated if they are not referenced by any
/// other node in the forest and are not a root of any procedure.
merged_basic_block_ids: BTreeSet<MastNodeId>,
}

impl MastForestBuilder {
Expand All @@ -42,7 +61,7 @@ impl MastForestBuilder {
/// unchanged. Any [`MastNodeId`] used in reference to the old [`MastForest`] should be remapped
/// using this map.
pub fn build(mut self) -> (MastForest, Option<BTreeMap<MastNodeId, MastNodeId>>) {
let nodes_to_remove = get_nodes_to_remove(self.merged_node_ids, &self.mast_forest);
let nodes_to_remove = get_nodes_to_remove(self.merged_basic_block_ids, &self.mast_forest);
let id_remappings = self.mast_forest.remove_nodes(&nodes_to_remove);

(self.mast_forest, id_remappings)
Expand Down Expand Up @@ -109,21 +128,21 @@ impl MastForestBuilder {
/// Returns a reference to the procedure with the specified [`GlobalProcedureIndex`], or None
/// if such a procedure is not present in this MAST forest builder.
#[inline(always)]
pub fn get_procedure(&self, gid: GlobalProcedureIndex) -> Option<Arc<Procedure>> {
self.procedures.get(&gid).cloned()
pub fn get_procedure(&self, gid: GlobalProcedureIndex) -> Option<&Procedure> {
self.procedures.get(&gid)
}

/// Returns the hash of the procedure with the specified [`GlobalProcedureIndex`], or None if
/// such a procedure is not present in this MAST forest builder.
#[inline(always)]
pub fn get_procedure_hash(&self, gid: GlobalProcedureIndex) -> Option<RpoDigest> {
self.procedure_hashes.get(&gid).cloned()
self.procedures.get(&gid).map(|proc| proc.mast_root())
}

/// Returns a reference to the procedure with the specified MAST root, or None
/// if such a procedure is not present in this MAST forest builder.
#[inline(always)]
pub fn find_procedure(&self, mast_root: &RpoDigest) -> Option<Arc<Procedure>> {
pub fn find_procedure(&self, mast_root: &RpoDigest) -> Option<&Procedure> {
self.proc_gid_by_hash.get(mast_root).and_then(|gid| self.get_procedure(*gid))
}

Expand All @@ -141,18 +160,9 @@ impl MastForestBuilder {
}
}

// ------------------------------------------------------------------------------------------------
/// Procedure insertion
impl MastForestBuilder {
pub fn insert_procedure_hash(
&mut self,
gid: GlobalProcedureIndex,
proc_hash: RpoDigest,
) -> Result<(), AssemblyError> {
// TODO(plafer): Check if exists
self.procedure_hashes.insert(gid, proc_hash);

Ok(())
}

/// Inserts a procedure into this MAST forest builder.
///
/// If the procedure with the same ID already exists in this forest builder, this will have
Expand Down Expand Up @@ -202,19 +212,17 @@ impl MastForestBuilder {
}
}

self.make_root(procedure.body_node_id());
self.mast_forest.make_root(procedure.body_node_id());
self.proc_gid_by_hash.insert(proc_root, gid);
self.insert_procedure_hash(gid, procedure.mast_root())?;
self.procedures.insert(gid, Arc::new(procedure));
self.procedures.insert(gid, procedure);

Ok(())
}
}

/// Marks the given [`MastNodeId`] as being the root of a procedure.
pub fn make_root(&mut self, new_root_id: MastNodeId) {
self.mast_forest.make_root(new_root_id)
}

// ------------------------------------------------------------------------------------------------
/// Joining nodes
impl MastForestBuilder {
/// Builds a tree of `JOIN` operations to combine the provided MAST node IDs.
pub fn join_nodes(&mut self, node_ids: Vec<MastNodeId>) -> Result<MastNodeId, AssemblyError> {
debug_assert!(!node_ids.is_empty(), "cannot combine empty MAST node id list");
Expand Down Expand Up @@ -254,7 +262,7 @@ impl MastForestBuilder {
let mut contiguous_basic_block_ids: Vec<MastNodeId> = Vec::new();

for mast_node_id in node_ids {
if self[mast_node_id].is_basic_block() {
if self.mast_forest[mast_node_id].is_basic_block() {
contiguous_basic_block_ids.push(mast_node_id);
} else {
merged_node_ids.extend(self.merge_basic_blocks(&contiguous_basic_block_ids)?);
Expand Down Expand Up @@ -293,7 +301,8 @@ impl MastForestBuilder {
for &basic_block_id in contiguous_basic_block_ids {
// It is safe to unwrap here, since we already checked that all IDs in
// `contiguous_basic_block_ids` are `BasicBlockNode`s
let basic_block_node = self[basic_block_id].get_basic_block().unwrap().clone();
let basic_block_node =
self.mast_forest[basic_block_id].get_basic_block().unwrap().clone();

// check if the block should be merged with other blocks
if should_merge(
Expand Down Expand Up @@ -322,7 +331,7 @@ impl MastForestBuilder {
}

// Mark the removed basic blocks as merged
self.merged_node_ids.extend(contiguous_basic_block_ids.iter());
self.merged_basic_block_ids.extend(contiguous_basic_block_ids.iter());

if !operations.is_empty() || !decorators.is_empty() {
let merged_basic_block = self.ensure_block(operations, Some(decorators))?;
Expand Down Expand Up @@ -414,15 +423,6 @@ impl MastForestBuilder {
}
}

impl Index<MastNodeId> for MastForestBuilder {
type Output = MastNode;

#[inline(always)]
fn index(&self, node_id: MastNodeId) -> &Self::Output {
&self.mast_forest[node_id]
}
}

// HELPER FUNCTIONS
// ================================================================================================

Expand Down
42 changes: 26 additions & 16 deletions assembly/src/assembler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ impl Assembler {

// TODO: show a warning if library exports are empty?
let (mast_forest, _) = mast_forest_builder.build();
Ok(Library::new(mast_forest, exports))
Ok(Library::new(mast_forest.into(), exports)?)
}

/// Assembles the provided module into a [KernelLibrary] intended to be used as a Kernel.
Expand Down Expand Up @@ -343,7 +343,7 @@ impl Assembler {
// TODO: show a warning if library exports are empty?

let (mast_forest, _) = mast_forest_builder.build();
let library = Library::new(mast_forest, exports);
let library = Library::new(mast_forest.into(), exports)?;
Ok(library.try_into()?)
}

Expand Down Expand Up @@ -379,21 +379,19 @@ impl Assembler {
// Compile the module graph rooted at the entrypoint
let mut mast_forest_builder = MastForestBuilder::default();
self.compile_subgraph(entrypoint, &mut mast_forest_builder)?;
let entry_procedure = mast_forest_builder
let entry_node_id = mast_forest_builder
.get_procedure(entrypoint)
.expect("compilation succeeded but root not found in cache");
.expect("compilation succeeded but root not found in cache")
.body_node_id();

// in case the node IDs changed, update the entrypoint ID to the new value
let (mast_forest, id_remappings) = mast_forest_builder.build();
let entry_node_id = {
let old_entry_node_id = entry_procedure.body_node_id();

id_remappings
.map(|id_remappings| id_remappings[&old_entry_node_id])
.unwrap_or(old_entry_node_id)
};
let entry_node_id = id_remappings
.map(|id_remappings| id_remappings[&entry_node_id])
.unwrap_or(entry_node_id);

Ok(Program::with_kernel(
mast_forest,
mast_forest.into(),
entry_node_id,
self.module_graph.kernel().clone(),
))
Expand Down Expand Up @@ -473,8 +471,13 @@ impl Assembler {

// Compile this procedure
let procedure = self.compile_procedure(pctx, mast_forest_builder)?;
// TODO: if a re-exported procedure with the same MAST root had been previously
// added to the builder, this will result in unreachable nodes added to the
// MAST forest. This is because while we won't insert a duplicate node for the
// procedure body node itself, all nodes that make up the procedure body would
// be added to the forest.

// Cache the compiled procedure.
// Cache the compiled procedure
self.module_graph.register_mast_root(procedure_gid, procedure.mast_root())?;
mast_forest_builder.insert_procedure(procedure_gid, procedure)?;
},
Expand All @@ -493,15 +496,22 @@ impl Assembler {
)
.with_span(proc_alias.span());

let proc_alias_root = self.resolve_target(
let proc_mast_root = self.resolve_target(
InvokeKind::ProcRef,
&proc_alias.target().into(),
&pctx,
mast_forest_builder,
)?;

// insert external node into the MAST forest for this procedure; if a procedure
// with the same MAST root had been previously added to the builder, this will
// have no effect
let proc_node_id = mast_forest_builder.ensure_external(proc_mast_root)?;
let procedure = pctx.into_procedure(proc_mast_root, proc_node_id);

// Make the MAST root available to all dependents
self.module_graph.register_mast_root(procedure_gid, proc_alias_root)?;
mast_forest_builder.insert_procedure_hash(procedure_gid, proc_alias_root)?;
self.module_graph.register_mast_root(procedure_gid, proc_mast_root)?;
mast_forest_builder.insert_procedure(procedure_gid, procedure)?;
},
}
}
Expand Down
4 changes: 3 additions & 1 deletion assembly/src/assembler/procedure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ impl ProcedureContext {
///
/// The passed-in `mast_root` defines the MAST root of the procedure's body while
/// `mast_node_id` specifies the ID of the procedure's body node in the MAST forest in
/// which the procedure is defined.
/// which the procedure is defined. Note that if the procedure is re-exported (i.e., the body
/// of the procedure is defined in some other MAST forest) `mast_node_id` will point to a
/// single `External` node.
///
/// <div class="warning">
/// `mast_root` and `mast_node_id` must be consistent. That is, the node located in the MAST
Expand Down
6 changes: 4 additions & 2 deletions assembly/src/assembler/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ fn nested_blocks() -> Result<(), Report> {
.join_nodes(vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id])
.unwrap();

let expected_program = Program::new(expected_mast_forest_builder.build().0, combined_node_id);
let mut expected_mast_forest = expected_mast_forest_builder.build().0;
expected_mast_forest.make_root(combined_node_id);
let expected_program = Program::new(expected_mast_forest.into(), combined_node_id);
assert_eq!(expected_program.hash(), program.hash());

// also check that the program has the right number of procedures (which excludes the dummy
Expand Down Expand Up @@ -214,7 +216,7 @@ fn duplicate_nodes() {

expected_mast_forest.make_root(root_id);

let expected_program = Program::new(expected_mast_forest, root_id);
let expected_program = Program::new(expected_mast_forest.into(), root_id);

assert_eq!(program, expected_program);
}
Expand Down
2 changes: 2 additions & 0 deletions assembly/src/library/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ pub enum LibraryError {
InvalidKernelExport { procedure_path: QualifiedProcedureName },
#[error(transparent)]
Kernel(#[from] KernelError),
#[error("invalid export: no procedure root for {procedure_path} procedure")]
NoProcedureRootForExport { procedure_path: QualifiedProcedureName },
}
Loading

0 comments on commit b192c61

Please sign in to comment.