From e9b51f7b899dd9c2689a9f1e1109cad733dd3387 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Tue, 1 Oct 2024 13:47:38 -0400 Subject: [PATCH] feat: wallet service use JSON-RPC --- src/handlers/wallet/handler.rs | 139 +++++++++++++++++++++ src/handlers/wallet/mod.rs | 1 + src/handlers/wallet/prepare_calls.rs | 33 ++--- src/handlers/wallet/send_prepared_calls.rs | 61 ++++----- src/json_rpc/mod.rs | 61 +++++++-- src/json_rpc/tests.rs | 53 +++++--- src/lib.rs | 3 +- 7 files changed, 267 insertions(+), 84 deletions(-) create mode 100644 src/handlers/wallet/handler.rs diff --git a/src/handlers/wallet/handler.rs b/src/handlers/wallet/handler.rs new file mode 100644 index 00000000..26ed6804 --- /dev/null +++ b/src/handlers/wallet/handler.rs @@ -0,0 +1,139 @@ +use crate::error::RpcError; +use crate::json_rpc::{ + ErrorResponse, JsonRpcError, JsonRpcRequest, JsonRpcResponse, JsonRpcResult, +}; +use crate::{handlers::HANDLER_TASK_METRICS, state::AppState}; +use axum::extract::Query; +use axum::{extract::State, Json}; +use serde::Deserialize; +use std::sync::Arc; +use thiserror::Error; +use tracing::error; +use wc::future::FutureExt; + +use super::prepare_calls::{self, PrepareCallsError}; +use super::send_prepared_calls::{self, SendPreparedCallsError}; + +#[derive(Debug, Deserialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct WalletQueryParams { + pub project_id: String, +} + +pub async fn handler( + state: State>, + query: Query, + Json(request_payload): Json, +) -> Json { + handler_internal(state, query, request_payload) + .with_metrics(HANDLER_TASK_METRICS.with_name("wallet")) + .await +} + +#[tracing::instrument(skip(state), level = "debug")] +async fn handler_internal( + state: State>, + query: Query, + request: JsonRpcRequest, +) -> Json { + match handle_rpc(state, query, request.method, request.params).await { + Ok(result) => Json(JsonRpcResponse::Result(JsonRpcResult::new( + request.id, result, + ))), + Err(e) => { + if matches!(e, Error::Internal(_)) { + error!("Internal server error handling wallet RPC request: {e:?}"); + } + Json(JsonRpcResponse::Error(JsonRpcError::new( + request.id, + ErrorResponse { + code: e.to_json_rpc_error_code(), + message: e.to_string().into(), + data: None, + }, + ))) + } + } +} + +const WALLET_PREPARE_CALLS: &str = "wallet_prepareCalls"; +const WALLET_SEND_PREPARED_CALLS: &str = "wallet_sendPreparedCalls"; + +#[derive(Debug, Error)] +enum Error { + #[error("Invalid project ID: {0}")] + InvalidProjectId(RpcError), + + #[error("{WALLET_PREPARE_CALLS}: {0}")] + PrepareCalls(PrepareCallsError), + + #[error("{WALLET_SEND_PREPARED_CALLS}: {0}")] + SendPreparedCalls(SendPreparedCallsError), + + #[error("Method not found")] + MethodNotFound, + + #[error("Invalid params: {0}")] + InvalidParams(serde_json::Error), + + #[error("Internal error")] + Internal(InternalError), +} + +#[derive(Debug, Error)] +enum InternalError { + #[error("Serializing response: {0}")] + SerializeResponse(serde_json::Error), +} + +impl Error { + fn to_json_rpc_error_code(&self) -> i32 { + match self { + Error::InvalidProjectId(_) => -1, + Error::PrepareCalls(_) => -2, // TODO more specific codes + Error::SendPreparedCalls(_) => -3, // TODO more specific codes + Error::MethodNotFound => -32601, + Error::InvalidParams(_) => -32602, + Error::Internal(_) => -32000, + } + } +} + +#[tracing::instrument(skip(state), level = "debug")] +async fn handle_rpc( + state: State>, + Query(query): Query, + method: Arc, + params: serde_json::Value, +) -> Result { + let project_id = query.project_id; + state + .validate_project_access_and_quota(&project_id) + .await + // TODO refactor to differentiate between user and server errors + .map_err(Error::InvalidProjectId)?; + + match method.as_ref() { + WALLET_PREPARE_CALLS => serde_json::to_value( + &prepare_calls::handler( + state, + project_id, + serde_json::from_value(params).map_err(Error::InvalidParams)?, + ) + .await + .map_err(Error::PrepareCalls)?, + ) + .map_err(|e| Error::Internal(InternalError::SerializeResponse(e))), + WALLET_SEND_PREPARED_CALLS => serde_json::to_value( + &send_prepared_calls::handler( + state, + project_id, + serde_json::from_value(params).map_err(Error::InvalidParams)?, + ) + .await + .map_err(Error::SendPreparedCalls)?, + ) + .map_err(|e| Error::Internal(InternalError::SerializeResponse(e))), + _ => Err(Error::MethodNotFound), + } +} diff --git a/src/handlers/wallet/mod.rs b/src/handlers/wallet/mod.rs index 390b219e..9e7ee6b9 100644 --- a/src/handlers/wallet/mod.rs +++ b/src/handlers/wallet/mod.rs @@ -1,3 +1,4 @@ pub mod prepare_calls; pub mod send_prepared_calls; mod types; +pub mod handler; diff --git a/src/handlers/wallet/prepare_calls.rs b/src/handlers/wallet/prepare_calls.rs index 001461de..2c598f0c 100644 --- a/src/handlers/wallet/prepare_calls.rs +++ b/src/handlers/wallet/prepare_calls.rs @@ -13,7 +13,6 @@ use alloy::providers::{Provider, ReqwestProvider}; use alloy::sol_types::SolCall; use alloy::sol_types::SolValue; use alloy::transports::Transport; -use axum::extract::Query; use axum::{ extract::State, response::{IntoResponse, Response}, @@ -40,12 +39,6 @@ use yttrium::{ user_operation::{user_operation_hash::UserOperationHash, UserOperationV07}, }; -#[derive(Debug, Deserialize, Clone)] -#[serde(rename_all = "camelCase")] -pub struct PrepareCallsQueryParams { - pub project_id: String, -} - pub type PrepareCallsRequest = Vec; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -181,10 +174,10 @@ impl IntoResponse for PrepareCallsError { pub async fn handler( state: State>, - query: Query, - Json(request_payload): Json, -) -> Result { - handler_internal(state, query, request_payload) + project_id: String, + request: PrepareCallsRequest, +) -> Result { + handler_internal(state, project_id, request) .with_metrics(HANDLER_TASK_METRICS.with_name("wallet_prepare_calls")) .await } @@ -192,15 +185,9 @@ pub async fn handler( #[tracing::instrument(skip(state), level = "debug")] async fn handler_internal( state: State>, - query: Query, + project_id: String, request: PrepareCallsRequest, -) -> Result { - // TODO refactor to differentiate between user and server errors - state - .validate_project_access_and_quota(&query.project_id) - .await - .map_err(PrepareCallsError::InvalidProjectId)?; - +) -> Result { let mut response = Vec::with_capacity(request.len()); for request in request { let chain_id = ChainId::new_eip155(request.chain_id.to::()); @@ -235,7 +222,7 @@ async fn handler_internal( format!( "https://rpc.walletconnect.com/v1?chainId={}&projectId={}&source={}", chain_id.caip2_identifier(), - query.project_id, + project_id, MessageSource::WalletPrepareCalls, ) .parse() @@ -285,7 +272,7 @@ async fn handler_internal( format!( "https://rpc.walletconnect.com/v1/bundler?chainId={}&projectId={}&bundler=pimlico", chain_id.caip2_identifier(), - query.project_id, + project_id, ) .parse() .unwrap(), @@ -324,7 +311,7 @@ async fn handler_internal( format!( "https://rpc.walletconnect.com/v1/bundler?chainId={}&projectId={}&bundler=pimlico", chain_id.caip2_identifier(), - query.project_id, + project_id, ) .parse() .unwrap(), @@ -373,7 +360,7 @@ async fn handler_internal( }); } - Ok(Json(response).into_response()) + Ok(response) } pub fn split_permissions_context_and_check_validator( diff --git a/src/handlers/wallet/send_prepared_calls.rs b/src/handlers/wallet/send_prepared_calls.rs index 0d71b7ec..846603d4 100644 --- a/src/handlers/wallet/send_prepared_calls.rs +++ b/src/handlers/wallet/send_prepared_calls.rs @@ -39,12 +39,6 @@ use yttrium::{ user_operation::UserOperationV07, }; -#[derive(Debug, Deserialize, Clone)] -#[serde(rename_all = "camelCase")] -pub struct SendPreparedCallsQueryParams { - pub project_id: String, -} - pub type SendPreparedCallsRequest = Vec; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -74,7 +68,7 @@ pub enum SendPreparedCallsError { #[error("Invalid chain ID")] InvalidChainId, #[error("Cosign error: {0}")] - Cosign(RpcError), + Cosign(String), #[error("Permission not found")] PermissionNotFound, @@ -122,7 +116,7 @@ pub enum SendPreparedCallsInternalError { IrnNotConfigured, #[error("Cosign: {0}")] - Cosign(RpcError), + Cosign(String), #[error("Cosign unsuccessful: {0:?}")] CosignUnsuccessful(std::result::Result), @@ -177,10 +171,10 @@ impl IntoResponse for SendPreparedCallsError { pub async fn handler( state: State>, - query: Query, - Json(request_payload): Json, -) -> Result { - handler_internal(state, query, request_payload) + project_id: String, + request: SendPreparedCallsRequest, +) -> Result { + handler_internal(state, project_id, request) .with_metrics(HANDLER_TASK_METRICS.with_name("wallet_send_prepared_calls")) .await } @@ -188,15 +182,9 @@ pub async fn handler( #[tracing::instrument(skip(state), level = "debug")] async fn handler_internal( state: State>, - query: Query, + project_id: String, request: SendPreparedCallsRequest, -) -> Result { - // TODO refactor to differentiate between user and server errors - state - .validate_project_access_and_quota(&query.project_id) - .await - .map_err(SendPreparedCallsError::InvalidProjectId)?; - +) -> Result { let mut response = Vec::with_capacity(request.len()); for request in request { let chain_id = ChainId::new_eip155(request.prepared_calls.chain_id.to::()); @@ -306,12 +294,27 @@ async fn handler_internal( .await { Ok(response) => response, - Err(e) => return Ok(e.into_response()), - // if e.clone().into_response().status().is_server_error() { - // SendPreparedCallsError::InternalError(SendPreparedCallsInternalError::Cosign(e)) - // } else { - // SendPreparedCallsError::Cosign(e) - // } + Err(e) => { + let response = e.into_response(); + let status = response.status(); + let response = String::from_utf8( + to_bytes(response.into_body()) + .await + // Lazy error handling here for now. We will refactor soon to avoid all this + .unwrap_or_default() + .to_vec(), + ) + // Lazy error handling here for now. We will refactor soon to avoid all this + .unwrap_or_default(); + let e = if status.is_server_error() { + SendPreparedCallsError::InternalError( + SendPreparedCallsInternalError::Cosign(response), + ) + } else { + SendPreparedCallsError::Cosign(response) + }; + return Err(e); + } }; if !response.status().is_success() { return Err(SendPreparedCallsError::InternalError( @@ -374,7 +377,7 @@ async fn handler_internal( format!( "https://rpc.walletconnect.com/v1?chainId={}&projectId={}&source={}", chain_id.caip2_identifier(), - query.project_id, + project_id, MessageSource::WalletSendPreparedCalls, ) .parse() @@ -441,7 +444,7 @@ async fn handler_internal( format!( "https://rpc.walletconnect.com/v1/bundler?chainId={}&projectId={}&bundler=pimlico", chain_id.caip2_identifier(), - query.project_id, + project_id, ) .parse() .unwrap(), @@ -459,5 +462,5 @@ async fn handler_internal( response.push(SendPreparedCallsResponseItem { user_op_hash }); } - Ok(Json(response).into_response()) + Ok(response) } diff --git a/src/json_rpc/mod.rs b/src/json_rpc/mod.rs index 73cc8c17..5e72cc85 100644 --- a/src/json_rpc/mod.rs +++ b/src/json_rpc/mod.rs @@ -6,19 +6,21 @@ use { derive_more::{Display, From, Into}, serde::{Deserialize, Serialize}, - serde_aux::prelude::deserialize_number_from_string, std::sync::Arc, }; #[cfg(test)] mod tests; -pub const JSON_RPC_VERSION: &str = "2.0"; +pub const JSON_RPC_VERSION_STR: &str = "2.0"; + +pub static JSON_RPC_VERSION: once_cell::sync::Lazy> = + once_cell::sync::Lazy::new(|| Arc::from(JSON_RPC_VERSION_STR)); /// Represents the message ID type. -#[derive(Copy, Debug, Hash, Clone, PartialEq, Eq, Serialize, Deserialize, From, Into, Display)] +#[derive(Debug, Hash, Clone, PartialEq, Eq, Serialize, Deserialize, From, Into, Display)] #[serde(transparent)] -pub struct MessageId(#[serde(deserialize_with = "deserialize_number_from_string")] u64); +pub struct MessageId(Arc); /// Enum representing a JSON RPC Payload. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -32,13 +34,15 @@ pub enum JsonRpcPayload { /// Data structure representing a JSON RPC Request #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub struct JsonRpcRequest { +pub struct JsonRpcRequest { /// ID this message corresponds to. pub id: MessageId, /// The JSON RPC version. pub jsonrpc: Arc, /// The RPC method. pub method: Arc, + /// The RPC params. + pub params: T, } impl JsonRpcRequest { @@ -46,8 +50,20 @@ impl JsonRpcRequest { pub fn new(id: MessageId, method: Arc) -> Self { Self { id, - jsonrpc: JSON_RPC_VERSION.into(), + jsonrpc: JSON_RPC_VERSION.clone(), + method, + params: serde_json::Value::Null, + } + } +} + +impl JsonRpcRequest { + pub fn new_with_params(id: MessageId, method: Arc, params: T) -> Self { + Self { + id, + jsonrpc: JSON_RPC_VERSION.clone(), method, + params, } } } @@ -64,34 +80,53 @@ pub enum JsonRpcResponse { /// Data structure representing a JSON RPC Result. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct JsonRpcResult { +pub struct JsonRpcResult { /// ID this message corresponds to. pub id: MessageId, /// RPC version. pub jsonrpc: Arc, /// The result for the message. - pub result: serde_json::Value, + pub result: T, +} + +impl JsonRpcResult { + pub fn new(id: MessageId, result: serde_json::Value) -> Self { + Self { + id, + jsonrpc: JSON_RPC_VERSION.clone(), + result, + } + } } /// Data structure representing a JSON RPC Error. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct JsonRpcError { +pub struct JsonRpcError>> { /// ID this message corresponds to. pub id: MessageId, /// RPC version. pub jsonrpc: Arc, /// The ErrorResponse corresponding to this message. - pub error: ErrorResponse, + pub error: ErrorResponse, +} + +impl JsonRpcError { + pub fn new(id: MessageId, error: ErrorResponse) -> Self { + Self { + id, + jsonrpc: JSON_RPC_VERSION.clone(), + error, + } + } } /// Data structure representing a ErrorResponse. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct ErrorResponse { +pub struct ErrorResponse { /// Error code. pub code: i32, /// Error message. pub message: Arc, /// Error data, if any. - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option>, + pub data: T, } diff --git a/src/json_rpc/tests.rs b/src/json_rpc/tests.rs index 48936107..cd023afc 100644 --- a/src/json_rpc/tests.rs +++ b/src/json_rpc/tests.rs @@ -1,28 +1,47 @@ use super::*; #[test] -fn test_request() { - let payload: JsonRpcPayload = - JsonRpcPayload::Request(JsonRpcRequest::new(1.into(), "eth_chainId".into())); - - let serialized = serde_json::to_string(&payload).unwrap(); - +fn test_request_serialized() { assert_eq!( - &serialized, - "{\"id\":1,\"jsonrpc\":\"2.0\",\"method\":\"eth_chainId\"}" + &serde_json::to_string(&JsonRpcPayload::Request(JsonRpcRequest::new( + MessageId("1".into()), + "eth_chainId".into() + ))) + .unwrap(), + "{\"id\":\"1\",\"jsonrpc\":\"2.0\",\"method\":\"eth_chainId\"}" ); +} - let deserialized: JsonRpcPayload = serde_json::from_str(&serialized).unwrap(); - - assert_eq!(&payload, &deserialized) +#[test] +fn test_request_deserialized() { + assert_eq!( + &serde_json::from_str::( + "{\"id\":1,\"jsonrpc\":\"2.0\",\"method\":\"eth_chainId\"}" + ) + .unwrap(), + &JsonRpcPayload::Request(JsonRpcRequest::new( + MessageId("1".into()), + "eth_chainId".into(), + )), + ); + assert_eq!( + &serde_json::from_str::( + "{\"id\":\"abc\",\"jsonrpc\":\"2.0\",\"method\":\"eth_chainId\"}" + ) + .unwrap(), + &JsonRpcPayload::Request(JsonRpcRequest::new( + MessageId("abc".into()), + "eth_chainId".into(), + )), + ); } #[test] fn test_response_result() { let payload: JsonRpcPayload = JsonRpcPayload::Response(JsonRpcResponse::Result(JsonRpcResult { - id: 1.into(), - jsonrpc: Arc::from(JSON_RPC_VERSION), + id: MessageId("1".into()), + jsonrpc: JSON_RPC_VERSION.clone(), result: "some result".into(), })); @@ -30,7 +49,7 @@ fn test_response_result() { assert_eq!( &serialized, - "{\"id\":1,\"jsonrpc\":\"2.0\",\"result\":\"some result\"}" + "{\"id\":\"1\",\"jsonrpc\":\"2.0\",\"result\":\"some result\"}" ); let deserialized: JsonRpcPayload = serde_json::from_str(&serialized).unwrap(); @@ -41,8 +60,8 @@ fn test_response_result() { #[test] fn test_response_error() { let payload: JsonRpcPayload = JsonRpcPayload::Response(JsonRpcResponse::Error(JsonRpcError { - id: 1.into(), - jsonrpc: Arc::from(JSON_RPC_VERSION), + id: MessageId(1.to_string().into()), + jsonrpc: JSON_RPC_VERSION.clone(), error: ErrorResponse { code: 32, message: Arc::from("some message"), @@ -54,7 +73,7 @@ fn test_response_error() { assert_eq!( &serialized, - "{\"id\":1,\"jsonrpc\":\"2.0\",\"error\":{\"code\":32,\"message\":\"some message\"}}" + "{\"id\":1,\"jsonrpc\":\"2.0\",\"error\":{\"code\":32,\"message\":\"some message\",\"data\":null}}" ); let deserialized: JsonRpcPayload = serde_json::from_str(&serialized).unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 56d94610..fd5b5ecc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -348,8 +348,7 @@ pub async fn bootstrap(config: Config) -> RpcResult<()> { // Bundler .route("/v1/bundler", post(handlers::bundler::handler)) // Wallet - .route("/v1/wallet/prepareCalls", post(handlers::wallet::prepare_calls::handler)) - .route("/v1/wallet/sendPreparedCalls", post(handlers::wallet::send_prepared_calls::handler)) + .route("/v1/wallet", post(handlers::wallet::handler::handler)) // Health .route("/health", get(handlers::health::handler)) .route_layer(tracing_and_metrics_layer)