diff --git a/crates/provider/src/ext/mod.rs b/crates/provider/src/ext/mod.rs index 7b6c0ffc1bf..84dabfd4737 100644 --- a/crates/provider/src/ext/mod.rs +++ b/crates/provider/src/ext/mod.rs @@ -28,7 +28,7 @@ pub use net::NetApi; #[cfg(feature = "trace-api")] mod trace; #[cfg(feature = "trace-api")] -pub use trace::{TraceApi, TraceCallList}; +pub use trace::{TraceApi, TraceCallList, TraceRpcWithBlock}; #[cfg(feature = "rpc-api")] mod rpc; diff --git a/crates/provider/src/ext/trace.rs b/crates/provider/src/ext/trace/api.rs similarity index 65% rename from crates/provider/src/ext/trace.rs rename to crates/provider/src/ext/trace/api.rs index 334ef2a883a..4ecc9bc11c7 100644 --- a/crates/provider/src/ext/trace.rs +++ b/crates/provider/src/ext/trace/api.rs @@ -1,5 +1,6 @@ //! This module extends the Ethereum JSON-RPC provider with the Trace namespace's RPC methods. -use crate::{Provider, RpcWithBlock}; +use super::TraceRpcWithBlock; +use crate::Provider; use alloy_eips::BlockNumberOrTag; use alloy_network::Network; use alloy_primitives::TxHash; @@ -26,11 +27,10 @@ where /// # Note /// /// Not all nodes support this call. - fn trace_call<'a, 'b>( + fn trace_call( &self, - request: &'a N::TransactionRequest, - trace_type: &'b [TraceType], - ) -> RpcWithBlock; + request: N::TransactionRequest, + ) -> TraceRpcWithBlock; /// Traces multiple transactions on top of the same block, i.e. transaction `n` will be executed /// on top of the given block with all `n - 1` transaction applied first. @@ -40,10 +40,10 @@ where /// # Note /// /// Not all nodes support this call. - fn trace_call_many<'a>( + fn trace_call_many( &self, - request: TraceCallList<'a, N>, - ) -> RpcWithBlock, TraceResults>; + requests: Vec, + ) -> TraceRpcWithBlock, Vec>; /// Parity trace transaction. async fn trace_transaction( @@ -64,16 +64,12 @@ where ) -> TransportResult; /// Trace the given raw transaction. - async fn trace_raw_transaction( - &self, - data: &[u8], - trace_type: &[TraceType], - ) -> TransportResult; + fn trace_raw_transaction(&self, data: Vec) -> TraceRpcWithBlock, TraceResults>; /// Traces matching given filter. async fn trace_filter( &self, - tracer: &TraceFilter, + filter: TraceFilter, ) -> TransportResult>; /// Trace all transactions in the given block. @@ -87,18 +83,13 @@ where ) -> TransportResult>; /// Replays a transaction. - async fn trace_replay_transaction( - &self, - hash: TxHash, - trace_type: &[TraceType], - ) -> TransportResult; + fn trace_replay_transaction(&self, hash: TxHash) -> TraceRpcWithBlock; /// Replays all transactions in the given block. - async fn trace_replay_block_transactions( + fn trace_replay_block_transactions( &self, block: BlockNumberOrTag, - trace_type: &[TraceType], - ) -> TransportResult>; + ) -> TraceRpcWithBlock>; } #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] @@ -109,20 +100,18 @@ where T: Transport + Clone, P: Provider, { - fn trace_call<'a, 'b>( + fn trace_call( &self, - request: &'a ::TransactionRequest, - trace_type: &'b [TraceType], - ) -> RpcWithBlock::TransactionRequest, &'b [TraceType]), TraceResults> - { - RpcWithBlock::new(self.weak_client(), "trace_call", (request, trace_type)) + request: N::TransactionRequest, + ) -> TraceRpcWithBlock { + TraceRpcWithBlock::new(self.weak_client(), "trace_call", request) } - fn trace_call_many<'a>( + fn trace_call_many( &self, - request: TraceCallList<'a, N>, - ) -> RpcWithBlock, TraceResults> { - RpcWithBlock::new(self.weak_client(), "trace_callMany", request) + requests: Vec, + ) -> TraceRpcWithBlock, Vec> { + TraceRpcWithBlock::new(self.weak_client(), "trace_callMany", requests) } async fn trace_transaction( @@ -141,19 +130,15 @@ where self.client().request("trace_get", (hash, (Index::from(index),))).await } - async fn trace_raw_transaction( - &self, - data: &[u8], - trace_type: &[TraceType], - ) -> TransportResult { - self.client().request("trace_rawTransaction", (data, trace_type)).await + fn trace_raw_transaction(&self, data: Vec) -> TraceRpcWithBlock, TraceResults> { + TraceRpcWithBlock::new(self.weak_client(), "trace_rawTransaction", data) } async fn trace_filter( &self, - tracer: &TraceFilter, + filter: TraceFilter, ) -> TransportResult> { - self.client().request("trace_filter", (tracer,)).await + self.client().request("trace_filter", (filter,)).await } async fn trace_block( @@ -163,20 +148,15 @@ where self.client().request("trace_block", (block,)).await } - async fn trace_replay_transaction( - &self, - hash: TxHash, - trace_type: &[TraceType], - ) -> TransportResult { - self.client().request("trace_replayTransaction", (hash, trace_type)).await + fn trace_replay_transaction(&self, hash: TxHash) -> TraceRpcWithBlock { + TraceRpcWithBlock::new(self.weak_client(), "trace_replayTransaction", hash) } - async fn trace_replay_block_transactions( + fn trace_replay_block_transactions( &self, block: BlockNumberOrTag, - trace_type: &[TraceType], - ) -> TransportResult> { - self.client().request("trace_replayBlockTransactions", (block, trace_type)).await + ) -> TraceRpcWithBlock> { + TraceRpcWithBlock::new(self.weak_client(), "trace_replayBlockTransactions", block) } } diff --git a/crates/provider/src/ext/trace/mod.rs b/crates/provider/src/ext/trace/mod.rs new file mode 100644 index 00000000000..0f20d99aaa3 --- /dev/null +++ b/crates/provider/src/ext/trace/mod.rs @@ -0,0 +1,5 @@ +mod api; +mod with_block; + +pub use api::*; +pub use with_block::*; diff --git a/crates/provider/src/ext/trace/with_block.rs b/crates/provider/src/ext/trace/with_block.rs new file mode 100644 index 00000000000..7cb9ccc0af5 --- /dev/null +++ b/crates/provider/src/ext/trace/with_block.rs @@ -0,0 +1,110 @@ +use crate::{RpcWithBlock, RpcWithBlockFut}; +use alloy_json_rpc::{RpcParam, RpcReturn}; +use alloy_rpc_client::WeakClient; +use alloy_rpc_types_trace::parity::TraceType; +use alloy_transport::{Transport, TransportResult}; +use std::{borrow::Cow, collections::HashSet, future::IntoFuture, ops::Deref}; + +/// An wrapper for [`TraceRpcWithBlock`] that takes an optional [`TraceType`] parameter. By default +/// this will use "trace". +#[derive(Debug, Clone)] +pub struct TraceRpcWithBlock Output> +where + T: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, + Map: Fn(Resp) -> Output, +{ + inner: RpcWithBlock, + trace_types: HashSet, +} + +impl Deref for TraceRpcWithBlock +where + T: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, + Map: Fn(Resp) -> Output, +{ + type Target = RpcWithBlock; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl TraceRpcWithBlock +where + T: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, +{ + /// Create a new [`TraceRpcWithBlock`] instance. + pub fn new( + client: WeakClient, + method: impl Into>, + params: Params, + ) -> Self { + Self { + inner: RpcWithBlock::new(client, method, params), + trace_types: vec![TraceType::Trace].into_iter().collect(), + } + } +} + +impl TraceRpcWithBlock +where + T: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, + Map: Fn(Resp) -> Output + 'static, +{ + /// Set the trace type. + pub fn trace_type(mut self, trace_type: TraceType) -> Self { + self.trace_types.insert(trace_type); + self + } + + /// Set the trace types. + pub fn trace_types>(mut self, trace_types: I) -> Self { + self.trace_types.extend(trace_types); + self + } + + /// Set the trace type to "trace". + pub fn trace(self) -> Self { + self.trace_type(TraceType::Trace) + } + + /// Set the trace type to "vmTrace". + pub fn vm_trace(self) -> Self { + self.trace_type(TraceType::VmTrace) + } + + /// Set the trace type to "stateDiff". + pub fn state_diff(self) -> Self { + self.trace_type(TraceType::StateDiff) + } + + /// Get the trace types. + pub const fn get_trace_types(&self) -> &HashSet { + &self.trace_types + } +} + +impl IntoFuture for TraceRpcWithBlock +where + T: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, + Output: 'static, + Map: Fn(Resp) -> Output + 'static + Copy, +{ + type Output = TransportResult; + type IntoFuture = RpcWithBlockFut; + + fn into_future(self) -> Self::IntoFuture { + let inner: RpcWithBlock = self.into(); + inner.into_future() + } +} diff --git a/crates/provider/src/lib.rs b/crates/provider/src/lib.rs index 314e1ff2494..9be20928ea9 100644 --- a/crates/provider/src/lib.rs +++ b/crates/provider/src/lib.rs @@ -40,8 +40,8 @@ pub use heart::{PendingTransaction, PendingTransactionBuilder, PendingTransactio mod provider; pub use provider::{ - builder, EthCall, FilterPollerBuilder, Provider, RootProvider, RpcWithBlock, SendableTx, - WalletProvider, + builder, EthCall, FilterPollerBuilder, Provider, RootProvider, RpcWithBlock, RpcWithBlockFut, + SendableTx, WalletProvider, }; pub mod utils; diff --git a/crates/provider/src/provider/mod.rs b/crates/provider/src/provider/mod.rs index 0d8e939ed5a..7528847e787 100644 --- a/crates/provider/src/provider/mod.rs +++ b/crates/provider/src/provider/mod.rs @@ -14,4 +14,4 @@ mod wallet; pub use wallet::WalletProvider; mod with_block; -pub use with_block::RpcWithBlock; +pub use with_block::{RpcWithBlock, RpcWithBlockFut}; diff --git a/crates/provider/src/provider/with_block.rs b/crates/provider/src/provider/with_block.rs index 808b72a15f6..1e72183b09e 100644 --- a/crates/provider/src/provider/with_block.rs +++ b/crates/provider/src/provider/with_block.rs @@ -11,6 +11,13 @@ use std::{ task::Poll, }; +#[cfg(feature = "trace-api")] +use { + crate::ext::TraceRpcWithBlock, + alloy_rpc_types_trace::parity::TraceType, + std::{collections::HashSet, ops::Deref}, +}; + /// States of the #[derive(Clone)] enum States Output> @@ -26,6 +33,8 @@ where method: Cow<'static, str>, params: Params, block_id: BlockId, + #[cfg(feature = "trace-api")] + trace_types: HashSet, map: Map, }, Running(RpcCall), @@ -80,6 +89,15 @@ where cx: &mut std::task::Context<'_>, ) -> Poll> { let this = self.project(); + + #[cfg(feature = "trace-api")] + let States::Preparing { client, method, params, block_id, trace_types, map } = + std::mem::replace(this.state, States::Invalid) + else { + unreachable!("bad state") + }; + + #[cfg(not(feature = "trace-api"))] let States::Preparing { client, method, params, block_id, map } = std::mem::replace(this.state, States::Invalid) else { @@ -110,10 +128,26 @@ where // append the block id to the params if let serde_json::Value::Array(ref mut arr) = ser { arr.push(block_id); + #[cfg(feature = "trace-api")] + if !trace_types.is_empty() { + arr.push(serde_json::to_value(trace_types).map_err(RpcError::ser_err)?); + } } else if ser.is_null() { - ser = serde_json::Value::Array(vec![block_id]); + let mut arr = vec![]; + arr.push(block_id); + #[cfg(feature = "trace-api")] + if !trace_types.is_empty() { + arr.push(serde_json::to_value(trace_types).map_err(RpcError::ser_err)?); + } + ser = serde_json::Value::Array(arr); } else { - ser = serde_json::Value::Array(vec![ser, block_id]); + let mut arr = vec![ser]; + arr.push(block_id); + #[cfg(feature = "trace-api")] + if !trace_types.is_empty() { + arr.push(serde_json::to_value(trace_types).map_err(RpcError::ser_err)?); + } + ser = serde_json::Value::Array(arr); } // create the call @@ -173,6 +207,8 @@ where method: Cow<'static, str>, params: Params, block_id: BlockId, + #[cfg(feature = "trace-api")] + trace_types: HashSet, map: Map, _pd: PhantomData (Resp, Output)>, } @@ -194,6 +230,8 @@ where method: method.into(), params, block_id: Default::default(), + #[cfg(feature = "trace-api")] + trace_types: vec![TraceType::Trace].into_iter().collect(), map: std::convert::identity, _pd: PhantomData, } @@ -220,6 +258,8 @@ where method: self.method, params: self.params, block_id: self.block_id, + #[cfg(feature = "trace-api")] + trace_types: self.trace_types, map, _pd: PhantomData, } @@ -293,8 +333,34 @@ where method: self.method, params: self.params, block_id: self.block_id, + #[cfg(feature = "trace-api")] + trace_types: self.trace_types, map: self.map, }, } } } + +#[cfg(feature = "trace-api")] +impl From> + for RpcWithBlock +where + T: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, + Output: 'static, + Map: Fn(Resp) -> Output + 'static + Copy, +{ + fn from(trace_rpc: TraceRpcWithBlock) -> Self { + let rpc = trace_rpc.deref(); + Self { + client: rpc.client.clone(), + method: rpc.method.clone(), + params: rpc.params.clone(), + block_id: rpc.block_id, + trace_types: trace_rpc.get_trace_types().clone(), + map: rpc.map, + _pd: rpc._pd, + } + } +}