Skip to content

Commit

Permalink
add basic block merging threshold (#1461)
Browse files Browse the repository at this point in the history
  • Loading branch information
plafer authored Aug 21, 2024
1 parent 62a49fd commit 8030370
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 56 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## 0.11.0 (TBD)
- Assembler: Merge contiguous basic blocks (#1454)
- Assembler: Add a threshold number of operations after which we stop merging more in the same block (#1461)

#### Enhancements

Expand Down
88 changes: 67 additions & 21 deletions assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ use vm_core::{
use super::{GlobalProcedureIndex, Procedure};
use crate::AssemblyError;

// CONSTANTS
// ================================================================================================

/// Number of operation batches at which a procedure's basic block stops being merged (inlined)
/// into adjacent basic blocks: `should_merge` only merges a procedure block whose batch count is
/// strictly below this value.
const PROCEDURE_INLINING_THRESHOLD: usize = 32;

// MAST FOREST BUILDER
// ================================================================================================

Expand Down Expand Up @@ -251,20 +257,14 @@ impl MastForestBuilder {
if self[mast_node_id].is_basic_block() {
contiguous_basic_block_ids.push(mast_node_id);
} else {
if let Some(merged_basic_block_id) =
self.merge_basic_blocks(&contiguous_basic_block_ids)?
{
merged_node_ids.push(merged_basic_block_id)
}
merged_node_ids.extend(self.merge_basic_blocks(&contiguous_basic_block_ids)?);
contiguous_basic_block_ids.clear();

merged_node_ids.push(mast_node_id);
}
}

if let Some(merged_basic_block_id) = self.merge_basic_blocks(&contiguous_basic_block_ids)? {
merged_node_ids.push(merged_basic_block_id)
}
merged_node_ids.extend(self.merge_basic_blocks(&contiguous_basic_block_ids)?);

Ok(merged_node_ids)
}
Expand All @@ -277,35 +277,59 @@ impl MastForestBuilder {
fn merge_basic_blocks(
&mut self,
contiguous_basic_block_ids: &[MastNodeId],
) -> Result<Option<MastNodeId>, AssemblyError> {
) -> Result<Vec<MastNodeId>, AssemblyError> {
if contiguous_basic_block_ids.is_empty() {
return Ok(None);
return Ok(Vec::new());
}
if contiguous_basic_block_ids.len() == 1 {
return Ok(Some(contiguous_basic_block_ids[0]));
return Ok(contiguous_basic_block_ids.to_vec());
}

let mut operations: Vec<Operation> = Vec::new();
let mut decorators = DecoratorList::new();

for &basic_block_node_id in contiguous_basic_block_ids {
let mut merged_basic_blocks: Vec<MastNodeId> = Vec::new();

for &basic_block_id in contiguous_basic_block_ids {
// It is safe to unwrap here, since we already checked that all IDs in
// `contiguous_basic_block_ids` are `BasicBlockNode`s
let basic_block_node = self[basic_block_node_id].get_basic_block().unwrap();

for (op_idx, decorator) in basic_block_node.decorators() {
decorators.push((*op_idx + operations.len(), decorator.clone()));
}
for batch in basic_block_node.op_batches() {
operations.extend_from_slice(batch.ops());
let basic_block_node = self[basic_block_id].get_basic_block().unwrap().clone();

// check if the block should be merged with other blocks
if should_merge(
self.mast_forest.is_procedure_root(basic_block_id),
basic_block_node.num_op_batches(),
) {
for (op_idx, decorator) in basic_block_node.decorators() {
decorators.push((*op_idx + operations.len(), decorator.clone()));
}
for batch in basic_block_node.op_batches() {
operations.extend_from_slice(batch.ops());
}
} else {
// if we don't want to merge this block, we flush the buffer of operations into a
// new block, and add the un-merged block after it
if !operations.is_empty() {
let block_ops = core::mem::take(&mut operations);
let block_decorators = core::mem::take(&mut decorators);
let merged_basic_block_id =
self.ensure_block(block_ops, Some(block_decorators))?;

merged_basic_blocks.push(merged_basic_block_id);
}
merged_basic_blocks.push(basic_block_id);
}
}

// Mark the removed basic blocks as merged
self.merged_node_ids.extend(contiguous_basic_block_ids.iter());

let merged_basic_block = self.ensure_block(operations, Some(decorators))?;
Ok(Some(merged_basic_block))
if !operations.is_empty() || !decorators.is_empty() {
let merged_basic_block = self.ensure_block(operations, Some(decorators))?;
merged_basic_blocks.push(merged_basic_block);
}

Ok(merged_basic_blocks)
}
}

Expand Down Expand Up @@ -398,3 +422,25 @@ impl Index<MastNodeId> for MastForestBuilder {
&self.mast_forest[node_id]
}
}

// HELPER FUNCTIONS
// ================================================================================================

/// Decides whether a basic block should be merged (inlined) into its neighboring basic blocks.
///
/// The current policy is:
/// - A block that is a procedure is merged only if its number of operation batches is smaller
///   than `PROCEDURE_INLINING_THRESHOLD` (currently 32). The reasoning is based on an estimate of
///   the runtime penalty of not inlining the procedure: we assume the penalty is roughly 3 extra
///   nodes in the MAST, i.e. 3 additional hashes at runtime. Since hashing each operation batch
///   requires 1 hash, this basically means that if the runtime penalty is more than 10% we inline
///   the block, and if it is less than 10% we accept the penalty to make deserialization faster.
/// - A block that is not a procedure is always merged because: (1) if it is a large block, it is
///   likely to be unique and, thus, the original block will be orphaned and removed later; (2) if
///   it is a small block, there is a large run-time benefit for inlining it.
///
/// Returns `true` if the block should be merged, `false` if it should be kept as its own node.
fn should_merge(is_procedure: bool, num_op_batches: usize) -> bool {
    // Only procedures are subject to the size threshold; any other block is always merged.
    !is_procedure || num_op_batches < PROCEDURE_INLINING_THRESHOLD
}
13 changes: 12 additions & 1 deletion core/src/mast/node/basic_block_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ impl BasicBlockNode {
&self.op_batches
}

/// Returns how many operation batches this basic block contains.
pub fn num_op_batches(&self) -> usize {
    self.op_batches().len()
}

/// Returns the total number of operation groups in this basic block.
///
/// The number of operation groups is computed as follows:
Expand All @@ -142,6 +147,12 @@ impl BasicBlockNode {
(self.op_batches.len() - 1) * BATCH_SIZE + last_batch_num_groups.next_power_of_two()
}

/// Returns the total count of operations across all operation batches of this basic block.
///
/// # Panics
/// Panics if the total does not fit into a `u32`.
pub fn num_operations(&self) -> u32 {
    // Accumulate in `usize` first; the narrowing to `u32` is checked once at the end.
    let mut total_ops: usize = 0;
    for op_batch in self.op_batches.iter() {
        total_ops += op_batch.ops().len();
    }

    total_ops.try_into().expect("basic block contains more than 2^32 operations")
}

/// Returns a list of decorators in this basic block node.
///
/// Each decorator is accompanied by the operation index specifying the operation prior to
Expand All @@ -158,7 +169,7 @@ impl BasicBlockNode {

/// Returns the total number of operations and decorators in this basic block.
pub fn num_operations_and_decorators(&self) -> u32 {
let num_ops: usize = self.op_batches.iter().map(|batch| batch.ops().len()).sum();
let num_ops: usize = self.num_operations() as usize;
let num_decorators = self.decorators.len();

(num_ops + num_decorators)
Expand Down
4 changes: 0 additions & 4 deletions miden/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ path = "src/lib.rs"
bench = false
doctest = false

[[bench]]
name = "program_compilation"
harness = false

[[bench]]
name = "program_execution"
harness = false
Expand Down
30 changes: 0 additions & 30 deletions miden/benches/program_compilation.rs

This file was deleted.

0 comments on commit 8030370

Please sign in to comment.