Skip to content
This repository was archived by the owner on Feb 6, 2025. It is now read-only.

Commit

Permalink
feat: enable prefetch on the new engine (#164)
Browse files Browse the repository at this point in the history
Co-authored-by: Keefe Liu <[email protected]>
keefel and Keefe Liu authored Oct 25, 2024
1 parent fda577b commit 80317a4
Showing 10 changed files with 216 additions and 61 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 11 additions & 17 deletions crates/blockchain-tree/src/chain.rs
Original file line number Diff line number Diff line change
@@ -30,7 +30,6 @@ use std::{
clone::Clone,
collections::BTreeMap,
ops::{Deref, DerefMut},
sync::Arc,
time::Instant,
};

@@ -231,14 +230,9 @@ impl AppendableChain {

let initial_execution_outcome = ExecutionOutcome::from((state, block.number));

// stop the prefetch task.
if let Some(interrupt_tx) = interrupt_tx {
let _ = interrupt_tx.send(());
}

// check state root if the block extends the canonical chain __and__ if state root
// validation was requested.
if block_validation_kind.is_exhaustive() {
let result = if block_validation_kind.is_exhaustive() {
// calculate and check state root
let start = Instant::now();
let (state_root, trie_updates) = if block_attachment.is_canonical() {
@@ -283,7 +277,14 @@ impl AppendableChain {
Ok((initial_execution_outcome, trie_updates))
} else {
Ok((initial_execution_outcome, None))
}
};

// stop the prefetch task.
if let Some(interrupt_tx) = interrupt_tx {
let _ = interrupt_tx.send(());
};

result
}

/// Validate and execute the given block, and append it to this chain.
@@ -356,18 +357,11 @@ impl AppendableChain {
let (interrupt_tx, interrupt_rx) = tokio::sync::oneshot::channel();

let mut trie_prefetch = TriePrefetch::new();
let consistent_view = if let Ok(view) =
ConsistentDbView::new_with_latest_tip(externals.provider_factory.clone())
{
view
} else {
tracing::debug!("Failed to create consistent view for trie prefetch");
return (None, None)
};
let provider_factory = externals.provider_factory.clone();

tokio::spawn({
async move {
trie_prefetch.run(Arc::new(consistent_view), prefetch_rx, interrupt_rx).await;
trie_prefetch.run(provider_factory, prefetch_rx, interrupt_rx).await;
}
});

3 changes: 3 additions & 0 deletions crates/engine/service/src/service.rs
Original file line number Diff line number Diff line change
@@ -80,6 +80,7 @@ where
invalid_block_hook: Box<dyn InvalidBlockHook>,
sync_metrics_tx: MetricEventsSender,
skip_state_root_validation: bool,
enable_prefetch: bool,
) -> Self {
let engine_kind =
if chain_spec.is_optimism() { EngineApiKind::OpStack } else { EngineApiKind::Ethereum };
@@ -104,6 +105,7 @@ where
invalid_block_hook,
engine_kind,
skip_state_root_validation,
enable_prefetch,
);

let engine_handler = EngineApiRequestHandler::new(to_tree_tx, from_tree);
@@ -217,6 +219,7 @@ mod tests {
Box::new(NoopInvalidBlockHook::default()),
sync_metrics_tx,
false,
false,
);
}
}
1 change: 1 addition & 0 deletions crates/engine/tree/Cargo.toml
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@ reth-stages-api.workspace = true
reth-tasks.workspace = true
reth-trie.workspace = true
reth-trie-parallel.workspace = true
reth-trie-prefetch.workspace = true

# alloy
alloy-primitives.workspace = true
57 changes: 53 additions & 4 deletions crates/engine/tree/src/tree/mod.rs
Original file line number Diff line number Diff line change
@@ -33,7 +33,8 @@ use reth_payload_builder::PayloadBuilderHandle;
use reth_payload_primitives::{PayloadAttributes, PayloadBuilder, PayloadBuilderAttributes};
use reth_payload_validator::ExecutionPayloadValidator;
use reth_primitives::{
Block, GotExpected, Header, SealedBlock, SealedBlockWithSenders, SealedHeader,
revm_primitives::EvmState, Block, GotExpected, Header, SealedBlock, SealedBlockWithSenders,
SealedHeader,
};
use reth_provider::{
providers::ConsistentDbView, BlockReader, DatabaseProviderFactory, ExecutionOutcome,
@@ -44,6 +45,7 @@ use reth_revm::database::StateProviderDatabase;
use reth_stages_api::ControlFlow;
use reth_trie::{updates::TrieUpdates, HashedPostState, TrieInput};
use reth_trie_parallel::parallel_root::{ParallelStateRoot, ParallelStateRootError};
use reth_trie_prefetch::TriePrefetch;
use std::{
cmp::Ordering,
collections::{btree_map, hash_map, BTreeMap, VecDeque},
@@ -505,6 +507,8 @@ pub struct EngineApiTreeHandler<P, E, T: EngineTypes, Spec> {
engine_kind: EngineApiKind,
/// Flag indicating whether the state root validation should be skipped.
skip_state_root_validation: bool,
/// Flag indicating whether to enable prefetch.
enable_prefetch: bool,
}

impl<P: Debug, E: Debug, T: EngineTypes + Debug, Spec: Debug> std::fmt::Debug
@@ -527,6 +531,8 @@ impl<P: Debug, E: Debug, T: EngineTypes + Debug, Spec: Debug> std::fmt::Debug
.field("metrics", &self.metrics)
.field("invalid_block_hook", &format!("{:p}", self.invalid_block_hook))
.field("engine_kind", &self.engine_kind)
.field("skip_state_root_validation", &self.skip_state_root_validation)
.field("enable_prefetch", &self.enable_prefetch)
.finish()
}
}
@@ -555,6 +561,7 @@ where
config: TreeConfig,
engine_kind: EngineApiKind,
skip_state_root_validation: bool,
enable_prefetch: bool,
) -> Self {
let (incoming_tx, incoming) = std::sync::mpsc::channel();

@@ -577,6 +584,7 @@ where
invalid_block_hook: Box::new(NoopInvalidBlockHook),
engine_kind,
skip_state_root_validation,
enable_prefetch,
}
}

@@ -603,6 +611,7 @@ where
invalid_block_hook: Box<dyn InvalidBlockHook>,
kind: EngineApiKind,
skip_state_root_validation: bool,
enable_prefetch: bool,
) -> (Sender<FromEngine<EngineApiRequest<T>>>, UnboundedReceiver<EngineApiEvent>) {
let best_block_number = provider.best_block_number().unwrap_or(0);
let header = provider.sealed_header(best_block_number).ok().flatten().unwrap_or_default();
@@ -634,10 +643,20 @@ where
config,
kind,
skip_state_root_validation,
enable_prefetch,
);
task.set_invalid_block_hook(invalid_block_hook);
let incoming = task.incoming_tx.clone();
std::thread::Builder::new().name("Tree Task".to_string()).spawn(|| task.run()).unwrap();
std::thread::Builder::new()
.name("Tree Task".to_string())
.spawn(move || {
let runtime =
tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
runtime.block_on(async {
task.run();
});
})
.unwrap();
(incoming, outgoing)
}

@@ -2175,8 +2194,16 @@ where
}

trace!(target: "engine::tree", block=?block.num_hash(), "Executing block");
let executor =
self.executor_provider.executor(StateProviderDatabase::new(&state_provider), None);
let (prefetch_tx, interrupt_tx) =
if self.enable_prefetch && !self.skip_state_root_validation {
self.setup_prefetch()
} else {
(None, None)
};

let executor = self
.executor_provider
.executor(StateProviderDatabase::new(&state_provider), prefetch_tx);

let block_number = block.number;
let block_hash = block.hash();
@@ -2247,6 +2274,11 @@ where
state_provider.state_root_with_updates(hashed_state.clone())?
};

// stop the prefetch task.
if let Some(interrupt_tx) = interrupt_tx {
let _ = interrupt_tx.send(());
};

if state_root != block.state_root {
// call post-block hook
self.invalid_block_hook.on_invalid_block(
@@ -2575,6 +2607,22 @@ where
);
Ok(())
}

fn setup_prefetch(&self) -> (Option<UnboundedSender<EvmState>>, Option<oneshot::Sender<()>>) {
let (prefetch_tx, prefetch_rx) = tokio::sync::mpsc::unbounded_channel();
let (interrupt_tx, interrupt_rx) = oneshot::channel();

let mut trie_prefetch = TriePrefetch::new();
let provider_factory = self.provider.clone();

tokio::spawn({
async move {
trie_prefetch.run(provider_factory, prefetch_rx, interrupt_rx).await;
}
});

(Some(prefetch_tx), Some(interrupt_tx))
}
}

/// This is an error that can come from advancing persistence. Either this can be a
@@ -2728,6 +2776,7 @@ mod tests {
TreeConfig::default(),
EngineApiKind::Ethereum,
false,
false,
);

let block_builder = TestBlockBuilder::default().with_chain_spec((*chain_spec).clone());
2 changes: 2 additions & 0 deletions crates/node/builder/src/launch/engine.rs
Original file line number Diff line number Diff line change
@@ -237,6 +237,7 @@ where
ctx.invalid_block_hook()?,
ctx.sync_metrics_tx(),
ctx.node_config().skip_state_root_validation,
ctx.node_config().enable_prefetch,
);
eth_service
}
@@ -271,6 +272,7 @@ where
ctx.invalid_block_hook()?,
ctx.sync_metrics_tx(),
ctx.node_config().skip_state_root_validation,
ctx.node_config().enable_prefetch,
);
eth_service
}
8 changes: 7 additions & 1 deletion crates/trie/parallel/src/parallel_root.rs
Original file line number Diff line number Diff line change
@@ -148,14 +148,17 @@ where
hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
);

let account_tree_start = std::time::Instant::now();
let mut hash_builder = HashBuilder::default().with_updates(retain_updates);
let mut account_rlp = Vec::with_capacity(128);
while let Some(node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
match node {
TrieElement::Branch(node) => {
tracker.inc_branch();
hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
}
TrieElement::Leaf(hashed_address, account) => {
tracker.inc_leaf();
let (storage_root, _, updates) = match storage_roots.remove(&hashed_address) {
Some(rx) => rx.recv().map_err(|_| {
ParallelStateRootError::StorageRoot(StorageRootError::Database(
@@ -199,15 +202,18 @@ where
prefix_sets.destroyed_accounts,
);

let account_tree_duration = account_tree_start.elapsed();
let stats = tracker.finish();

#[cfg(feature = "metrics")]
self.metrics.record_state_trie(stats);

trace!(
debug!(
target: "trie::parallel_state_root",
%root,
duration = ?stats.duration(),
account_tree_duration = ?account_tree_duration,
storage_trees_duration = ?(stats.duration() - account_tree_duration),
branches_added = stats.branches_added(),
leaves_added = stats.leaves_added(),
missed_leaves = stats.missed_leaves(),
7 changes: 5 additions & 2 deletions crates/trie/prefetch/src/lib.rs
Original file line number Diff line number Diff line change
@@ -7,8 +7,11 @@
)]
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]

pub use prefetch::TriePrefetch;
pub use reth_trie_parallel::StorageRootTargets;

/// Trie prefetch stats.
pub mod stats;

/// Implementation of trie prefetch.
mod prefetch;
pub use prefetch::TriePrefetch;
pub mod prefetch;
125 changes: 88 additions & 37 deletions crates/trie/prefetch/src/prefetch.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::{collections::HashMap, sync::Arc};

use alloy_primitives::B256;
use rayon::prelude::*;
use reth_execution_errors::StorageRootError;
@@ -16,14 +18,15 @@ use reth_trie::{
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use reth_trie_parallel::{parallel_root::ParallelStateRootError, StorageRootTargets};
use std::{collections::HashMap, sync::Arc};
use thiserror::Error;
use tokio::{
sync::{mpsc::UnboundedReceiver, oneshot::Receiver},
sync::{mpsc::UnboundedReceiver, oneshot::Receiver, Mutex},
task::JoinSet,
};
use tracing::{debug, trace};

use crate::stats::TriePrefetchTracker;

/// Prefetch trie storage when executing transactions.
#[derive(Debug, Clone)]
pub struct TriePrefetch {
@@ -56,31 +59,53 @@ impl TriePrefetch {
/// Run the prefetching task.
pub async fn run<Factory>(
&mut self,
consistent_view: Arc<ConsistentDbView<Factory>>,
provider_factory: Factory,
mut prefetch_rx: UnboundedReceiver<EvmState>,
mut interrupt_rx: Receiver<()>,
) where
Factory: DatabaseProviderFactory<Provider: BlockReader> + Send + Sync + 'static,
Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
{
let mut join_set = JoinSet::new();
let arc_tracker = Arc::new(Mutex::new(TriePrefetchTracker::default()));

loop {
tokio::select! {
state = prefetch_rx.recv() => {
if let Some(state) = state {
let consistent_view = Arc::clone(&consistent_view);
let hashed_state = self.deduplicate_and_update_cached(state);

let self_clone = Arc::new(self.clone());
let consistent_view = ConsistentDbView::new_with_latest_tip(provider_factory.clone()).unwrap();
let hashed_state_clone = hashed_state.clone();
let arc_tracker_clone = Arc::clone(&arc_tracker);
join_set.spawn(async move {
if let Err(e) = self_clone.prefetch_accounts::<Factory>(consistent_view, hashed_state_clone, arc_tracker_clone).await {
debug!(target: "trie::trie_prefetch", ?e, "Error while prefetching account trie storage");
};
});

let self_clone = Arc::new(self.clone());
let consistent_view = ConsistentDbView::new_with_latest_tip(provider_factory.clone()).unwrap();
join_set.spawn(async move {
if let Err(e) = self_clone.prefetch_once(consistent_view, hashed_state).await {
debug!(target: "trie::trie_prefetch", ?e, "Error while prefetching trie storage");
if let Err(e) = self_clone.prefetch_storages::<Factory>(consistent_view, hashed_state).await {
debug!(target: "trie::trie_prefetch", ?e, "Error while prefetching storage trie storage");
};
});
}
}

_ = &mut interrupt_rx => {
debug!(target: "trie::trie_prefetch", "Interrupted trie prefetch task. Unprocessed tx {:?}", prefetch_rx.len());
let stat = arc_tracker.lock().await.finish();
debug!(
target: "trie::trie_prefetch",
unprocessed_tx = prefetch_rx.len(),
accounts_cached = self.cached_accounts.len(),
storages_cached = self.cached_storages.len(),
branches_prefetched = stat.branches_prefetched(),
leaves_prefetched = stat.leaves_prefetched(),
"trie prefetch interrupted"
);

join_set.abort_all();
return
}
@@ -95,9 +120,7 @@ impl TriePrefetch {

// deduplicate accounts if their keys are not present in storages
for (address, account) in &hashed_state.accounts {
if !hashed_state.storages.contains_key(address) &&
!self.cached_accounts.contains_key(address)
{
if !self.cached_accounts.contains_key(address) {
self.cached_accounts.insert(*address, true);
new_hashed_state.accounts.insert(*address, *account);
}
@@ -138,14 +161,15 @@ impl TriePrefetch {
new_hashed_state
}

/// Prefetch trie storage for the given hashed state.
pub async fn prefetch_once<Factory>(
/// Prefetch account trie nodes for the given hashed state.
pub async fn prefetch_accounts<Factory>(
self: Arc<Self>,
consistent_view: Arc<ConsistentDbView<Factory>>,
consistent_view: ConsistentDbView<Factory>,
hashed_state: HashedPostState,
arc_prefetch_tracker: Arc<Mutex<TriePrefetchTracker>>,
) -> Result<(), TriePrefetchError>
where
Factory: DatabaseProviderFactory<Provider: BlockReader> + Send + Sync + 'static,
Factory: DatabaseProviderFactory<Provider: BlockReader>,
{
let mut tracker = TrieTracker::default();

@@ -156,28 +180,10 @@ impl TriePrefetch {
);
let hashed_state_sorted = hashed_state.into_sorted();

trace!(target: "trie::trie_prefetch", "start prefetching trie storages");
// Solely for marking storage roots, so precise calculations are not necessary.
let mut storage_roots = storage_root_targets
.into_par_iter()
.map(|(hashed_address, prefix_set)| {
let provider_ro = consistent_view.provider_ro()?;
let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_ro.tx_ref());
let hashed_cursor_factory = HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
&hashed_state_sorted,
);
let storage_root_result = StorageRoot::new_hashed(
trie_cursor_factory,
hashed_cursor_factory,
hashed_address,
#[cfg(feature = "metrics")]
self.metrics.clone(),
)
.with_prefix_set(prefix_set)
.prefetch();

Ok((hashed_address, storage_root_result?))
})
.map(|(hashed_address, _)| Ok((hashed_address, 1)))
.collect::<Result<HashMap<_, _>, ParallelStateRootError>>()?;

trace!(target: "trie::trie_prefetch", "prefetching account tries");
@@ -205,6 +211,7 @@ impl TriePrefetch {
tracker.inc_branch();
}
TrieElement::Leaf(hashed_address, _) => {
tracker.inc_leaf();
match storage_roots.remove(&hashed_address) {
Some(result) => result,
// Since we do not store all intermediate nodes in the database, there might
@@ -220,17 +227,19 @@ impl TriePrefetch {
.ok()
.unwrap_or_default(),
};
tracker.inc_leaf();
}
}
}

let stats = tracker.finish();
let mut prefetch_tracker = arc_prefetch_tracker.lock().await;
prefetch_tracker.inc_branches(stats.branches_added());
prefetch_tracker.inc_leaves(stats.leaves_added());

#[cfg(feature = "metrics")]
self.metrics.record(stats);

trace!(
debug!(
target: "trie::trie_prefetch",
duration = ?stats.duration(),
branches_added = stats.branches_added(),
@@ -240,6 +249,48 @@ impl TriePrefetch {

Ok(())
}

/// Prefetch storage trie nodes for the given hashed state.
pub async fn prefetch_storages<Factory>(
self: Arc<Self>,
consistent_view: ConsistentDbView<Factory>,
hashed_state: HashedPostState,
) -> Result<(), TriePrefetchError>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>,
{
let prefix_sets = hashed_state.construct_prefix_sets().freeze();
let storage_root_targets = StorageRootTargets::new(
hashed_state.accounts.keys().copied(),
prefix_sets.storage_prefix_sets,
);
let hashed_state_sorted = hashed_state.into_sorted();

storage_root_targets
.into_par_iter()
.map(|(hashed_address, prefix_set)| {
let provider_ro = consistent_view.provider_ro()?;
let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_ro.tx_ref());
let hashed_cursor_factory = HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
&hashed_state_sorted,
);
let storage_root_result = StorageRoot::new_hashed(
trie_cursor_factory,
hashed_cursor_factory,
hashed_address,
#[cfg(feature = "metrics")]
self.metrics.clone(),
)
.with_prefix_set(prefix_set)
.prefetch();

Ok((hashed_address, storage_root_result?))
})
.collect::<Result<HashMap<_, _>, ParallelStateRootError>>()?;

Ok(())
}
}

/// Error during prefetching trie storage.
45 changes: 45 additions & 0 deletions crates/trie/prefetch/src/stats.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/// Trie stats.
#[derive(Clone, Copy, Debug)]
pub struct TriePrefetchStats {
branches_prefetched: u64,
leaves_prefetched: u64,
}

impl TriePrefetchStats {
/// The number of added branch nodes for which we prefetched.
pub const fn branches_prefetched(&self) -> u64 {
self.branches_prefetched
}

/// The number of added leaf nodes for which we prefetched.
pub const fn leaves_prefetched(&self) -> u64 {
self.leaves_prefetched
}
}

/// Trie metrics tracker.
#[derive(Default, Debug, Clone, Copy)]
pub struct TriePrefetchTracker {
branches_prefetched: u64,
leaves_prefetched: u64,
}

impl TriePrefetchTracker {
/// Increment the number of branches prefetched.
pub fn inc_branches(&mut self, num: u64) {
self.branches_prefetched += num;
}

/// Increment the number of leaves prefetched.
pub fn inc_leaves(&mut self, num: u64) {
self.leaves_prefetched += num;
}

/// Called when prefetch is finished to return trie prefetch statistics.
pub const fn finish(self) -> TriePrefetchStats {
TriePrefetchStats {
branches_prefetched: self.branches_prefetched,
leaves_prefetched: self.leaves_prefetched,
}
}
}

0 comments on commit 80317a4

Please sign in to comment.