From 803037011fbb32b5c7b487a4f95feb483ba091ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philippe=20Laferri=C3=A8re?= Date: Wed, 21 Aug 2024 13:09:30 -0400 Subject: [PATCH] add basic block merging threshold (#1461) --- CHANGELOG.md | 1 + assembly/src/assembler/mast_forest_builder.rs | 88 ++++++++++++++----- core/src/mast/node/basic_block_node/mod.rs | 13 ++- miden/Cargo.toml | 4 - miden/benches/program_compilation.rs | 30 ------- 5 files changed, 80 insertions(+), 56 deletions(-) delete mode 100644 miden/benches/program_compilation.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 52186b53b..fd11f9b73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## 0.11.0 (TBD) - Assembler: Merge contiguous basic blocks (#1454) +- Assembler: Add a threshold number of operations after which we stop merging more in the same block (#1461) #### Enhancements diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs index 4848ca121..e072bfff2 100644 --- a/assembly/src/assembler/mast_forest_builder.rs +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -14,6 +14,12 @@ use vm_core::{ use super::{GlobalProcedureIndex, Procedure}; use crate::AssemblyError; +// CONSTANTS +// ================================================================================================ + +/// Constant that decides how many operation batches disqualify a procedure from inlining. +const PROCEDURE_INLINING_THRESHOLD: usize = 32; + // MAST FOREST BUILDER // ================================================================================================ @@ -251,20 +257,14 @@ impl MastForestBuilder { if self[mast_node_id].is_basic_block() { contiguous_basic_block_ids.push(mast_node_id); } else { - if let Some(merged_basic_block_id) = - self.merge_basic_blocks(&contiguous_basic_block_ids)? 
- { - merged_node_ids.push(merged_basic_block_id) - } + merged_node_ids.extend(self.merge_basic_blocks(&contiguous_basic_block_ids)?); contiguous_basic_block_ids.clear(); merged_node_ids.push(mast_node_id); } } - if let Some(merged_basic_block_id) = self.merge_basic_blocks(&contiguous_basic_block_ids)? { - merged_node_ids.push(merged_basic_block_id) - } + merged_node_ids.extend(self.merge_basic_blocks(&contiguous_basic_block_ids)?); Ok(merged_node_ids) } @@ -277,35 +277,59 @@ impl MastForestBuilder { fn merge_basic_blocks( &mut self, contiguous_basic_block_ids: &[MastNodeId], - ) -> Result<Option<MastNodeId>, AssemblyError> { + ) -> Result<Vec<MastNodeId>, AssemblyError> { if contiguous_basic_block_ids.is_empty() { - return Ok(None); + return Ok(Vec::new()); } if contiguous_basic_block_ids.len() == 1 { - return Ok(Some(contiguous_basic_block_ids[0])); + return Ok(contiguous_basic_block_ids.to_vec()); } let mut operations: Vec<Operation> = Vec::new(); let mut decorators = DecoratorList::new(); - for &basic_block_node_id in contiguous_basic_block_ids { + let mut merged_basic_blocks: Vec<MastNodeId> = Vec::new(); + + for &basic_block_id in contiguous_basic_block_ids { // It is safe to unwrap here, since we already checked that all IDs in // `contiguous_basic_block_ids` are `BasicBlockNode`s - let basic_block_node = self[basic_block_node_id].get_basic_block().unwrap(); - - for (op_idx, decorator) in basic_block_node.decorators() { - decorators.push((*op_idx + operations.len(), decorator.clone())); - } - for batch in basic_block_node.op_batches() { - operations.extend_from_slice(batch.ops()); + let basic_block_node = self[basic_block_id].get_basic_block().unwrap().clone(); + + // check if the block should be merged with other blocks + if should_merge( + self.mast_forest.is_procedure_root(basic_block_id), + basic_block_node.num_op_batches(), + ) { + for (op_idx, decorator) in basic_block_node.decorators() { + decorators.push((*op_idx + operations.len(), decorator.clone())); + } + for batch in basic_block_node.op_batches() { +
operations.extend_from_slice(batch.ops()); + } + } else { + // if we don't want to merge this block, we flush the buffer of operations into a + // new block, and add the un-merged block after it + if !operations.is_empty() { + let block_ops = core::mem::take(&mut operations); + let block_decorators = core::mem::take(&mut decorators); + let merged_basic_block_id = + self.ensure_block(block_ops, Some(block_decorators))?; + + merged_basic_blocks.push(merged_basic_block_id); + } + merged_basic_blocks.push(basic_block_id); } } // Mark the removed basic blocks as merged self.merged_node_ids.extend(contiguous_basic_block_ids.iter()); - let merged_basic_block = self.ensure_block(operations, Some(decorators))?; - Ok(Some(merged_basic_block)) + if !operations.is_empty() || !decorators.is_empty() { + let merged_basic_block = self.ensure_block(operations, Some(decorators))?; + merged_basic_blocks.push(merged_basic_block); + } + + Ok(merged_basic_blocks) } } @@ -398,3 +422,25 @@ impl Index<MastNodeId> for MastForestBuilder { &self.mast_forest[node_id] } } + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Determines if we want to merge a block with other blocks. Currently, this works as follows: +/// - If the block is a procedure, we merge it only if the number of operation batches is smaller +/// than the threshold (currently set at 32). The reasoning is based on an estimate of the +/// runtime penalty of not inlining the procedure. We assume that this penalty is roughly 3 extra +/// nodes in the MAST and so would require 3 additional hashes at runtime. Since hashing each +/// operation batch requires 1 hash, this basically implies that if the runtime penalty is more +/// than 10%, we inline the block, but if it is less than 10% we accept the penalty to make +/// deserialization faster.
+/// - If the block is not a procedure, we always merge it because: (1) if it is a large block, it is +/// likely to be unique and, thus, the original block will be orphaned and removed later; (2) if +/// it is a small block, there is a large run-time benefit for inlining it. +fn should_merge(is_procedure: bool, num_op_batches: usize) -> bool { + if is_procedure { + num_op_batches < PROCEDURE_INLINING_THRESHOLD + } else { + true + } +} diff --git a/core/src/mast/node/basic_block_node/mod.rs b/core/src/mast/node/basic_block_node/mod.rs index a049c9ffa..44e0a1835 100644 --- a/core/src/mast/node/basic_block_node/mod.rs +++ b/core/src/mast/node/basic_block_node/mod.rs @@ -128,6 +128,11 @@ impl BasicBlockNode { &self.op_batches } + /// Returns the number of operation batches in this basic block. + pub fn num_op_batches(&self) -> usize { + self.op_batches.len() + } + /// Returns the total number of operation groups in this basic block. /// /// Then number of operation groups is computed as follows: @@ -142,6 +147,12 @@ impl BasicBlockNode { (self.op_batches.len() - 1) * BATCH_SIZE + last_batch_num_groups.next_power_of_two() } + /// Returns the number of operations in this basic block. + pub fn num_operations(&self) -> u32 { + let num_ops: usize = self.op_batches.iter().map(|batch| batch.ops().len()).sum(); + num_ops.try_into().expect("basic block contains more than 2^32 operations") + } + /// Returns a list of decorators in this basic block node. /// /// Each decorator is accompanied by the operation index specifying the operation prior to @@ -158,7 +169,7 @@ impl BasicBlockNode { /// Returns the total number of operations and decorators in this basic block. 
pub fn num_operations_and_decorators(&self) -> u32 { - let num_ops: usize = self.op_batches.iter().map(|batch| batch.ops().len()).sum(); + let num_ops: usize = self.num_operations() as usize; let num_decorators = self.decorators.len(); (num_ops + num_decorators) diff --git a/miden/Cargo.toml b/miden/Cargo.toml index 70d675245..1f5cafd9e 100644 --- a/miden/Cargo.toml +++ b/miden/Cargo.toml @@ -25,10 +25,6 @@ path = "src/lib.rs" bench = false doctest = false -[[bench]] -name = "program_compilation" -harness = false - [[bench]] name = "program_execution" harness = false diff --git a/miden/benches/program_compilation.rs b/miden/benches/program_compilation.rs deleted file mode 100644 index e7a18b897..000000000 --- a/miden/benches/program_compilation.rs +++ /dev/null @@ -1,30 +0,0 @@ -use std::time::Duration; - -use assembly::Assembler; -use criterion::{criterion_group, criterion_main, Criterion}; -use stdlib::StdLibrary; - -fn program_compilation(c: &mut Criterion) { - let mut group = c.benchmark_group("program_compilation"); - group.measurement_time(Duration::from_secs(10)); - - let stdlib = StdLibrary::default(); - group.bench_function("sha256", |bench| { - let source = " - use.std::crypto::hashes::sha256 - - begin - exec.sha256::hash_2to1 - end"; - bench.iter(|| { - let mut assembler = Assembler::default(); - assembler.add_library(&stdlib).expect("failed to load stdlib"); - assembler.assemble_program(source).expect("Failed to compile test source.") - }); - }); - - group.finish(); -} - -criterion_group!(sha256_group, program_compilation); -criterion_main!(sha256_group);