From cbffdfd9f5923e29970704449ed1cd8aeffe0ca2 Mon Sep 17 00:00:00 2001 From: Ryan Schneider Date: Wed, 7 Aug 2024 09:51:41 -0700 Subject: [PATCH] feat: Return local nonce when getTransactionCount request is signed (#151) --- adapters/flashbots/signature.go | 67 +++++++++++++++++++++++ adapters/flashbots/signature_test.go | 80 ++++++++++++++++++++++++++++ server/request_handler.go | 11 +++- server/request_intercepts.go | 44 +++++++++++++++ server/request_processor.go | 25 ++++++++- 5 files changed, 224 insertions(+), 3 deletions(-) create mode 100644 adapters/flashbots/signature.go create mode 100644 adapters/flashbots/signature_test.go diff --git a/adapters/flashbots/signature.go b/adapters/flashbots/signature.go new file mode 100644 index 0000000..fe4dac1 --- /dev/null +++ b/adapters/flashbots/signature.go @@ -0,0 +1,67 @@ +// Package flashbots provides methods for parsing the X-Flashbots-Signature header. +package flashbots + +import ( + "errors" + "fmt" + "strings" + + "github.com/ethereum/go-ethereum/accounts" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/crypto" +) + +var ( + ErrNoSignature = errors.New("no signature provided") + ErrInvalidSignature = errors.New("invalid signature provided") +) + +func ParseSignature(header string, body []byte) (signingAddress string, err error) { + if header == "" { + return "", ErrNoSignature + } + + splitSig := strings.Split(header, ":") + if len(splitSig) != 2 { + return "", ErrInvalidSignature + } + + return VerifySignature(body, splitSig[0], splitSig[1]) +} + +func VerifySignature(body []byte, signingAddressStr, signatureStr string) (signingAddress string, err error) { + signature, err := hexutil.Decode(signatureStr) + if err != nil || len(signature) == 0 { + return "", fmt.Errorf("%w: %w", ErrInvalidSignature, err) + } + + if signature[len(signature)-1] >= 27 { + signature[len(signature)-1] -= 27 + } + + hashedBody := crypto.Keccak256Hash(body).Hex() + messageHash := accounts.TextHash([]byte(hashedBody)) + signaturePublicKeyBytes, err := crypto.Ecrecover(messageHash, signature) + if err != nil { + return "", fmt.Errorf("%w: %w", ErrInvalidSignature, err) + } + + publicKey, err := crypto.UnmarshalPubkey(signaturePublicKeyBytes) + if err != nil { + return "", fmt.Errorf("%w: %w", ErrInvalidSignature, err) + } + signaturePubkey := *publicKey + signaturePubKeyAddress := crypto.PubkeyToAddress(signaturePubkey).Hex() + + // case-insensitive equality check + if !strings.EqualFold(signaturePubKeyAddress, signingAddressStr) { + return "", fmt.Errorf("%w: signing address mismatch", ErrInvalidSignature) + } + + signatureNoRecoverID := signature[:len(signature)-1] // remove recovery id + if !crypto.VerifySignature(signaturePublicKeyBytes, messageHash, signatureNoRecoverID) { + return "", fmt.Errorf("%w: %w", ErrInvalidSignature, err) + } + + return signaturePubKeyAddress, nil +} diff --git a/adapters/flashbots/signature_test.go b/adapters/flashbots/signature_test.go new file mode 100644 index 0000000..c0876bb --- /dev/null +++ b/adapters/flashbots/signature_test.go @@ -0,0 +1,80 @@ +package flashbots_test + +import ( + "fmt" + "testing" + + "github.com/ethereum/go-ethereum/accounts" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/crypto" + "github.com/stretchr/testify/require" + + "github.com/flashbots/rpc-endpoint/adapters/flashbots" +) + +func TestParseSignature(t *testing.T) { + + // For most of these test cases, we first need to generate a signature + privateKey, err := crypto.GenerateKey() + require.NoError(t, err) + + address := crypto.PubkeyToAddress(privateKey.PublicKey).Hex() + body := fmt.Sprintf( + `{"jsonrpc":"2.0","method":"eth_getTransactionCount","params":["%s","pending"],"id":1}`, + address, + ) + + signature, err := crypto.Sign( + accounts.TextHash([]byte(hexutil.Encode(crypto.Keccak256([]byte(body))))), + privateKey, + ) + require.NoError(t, err) + + header := fmt.Sprintf("%s:%s", address, hexutil.Encode(signature)) + + t.Run("header is empty", func(t *testing.T) { + _, err := flashbots.ParseSignature("", []byte{}) + require.ErrorIs(t, err, flashbots.ErrNoSignature) + }) + + t.Run("header is valid", func(t *testing.T) { + signingAddress, err := flashbots.ParseSignature(header, []byte(body)) + require.NoError(t, err) + require.Equal(t, address, signingAddress) + }) + + t.Run("header is invalid", func(t *testing.T) { + _, err := flashbots.ParseSignature("invalid", []byte(body)) + require.ErrorIs(t, err, flashbots.ErrInvalidSignature) + }) + + t.Run("header has extra bytes", func(t *testing.T) { + _, err := flashbots.ParseSignature(header+"deadbeef", []byte(body)) + require.ErrorIs(t, err, flashbots.ErrInvalidSignature) + }) + + t.Run("header has missing bytes", func(t *testing.T) { + _, err := flashbots.ParseSignature(header[:len(header)-8], []byte(body)) + require.ErrorIs(t, err, flashbots.ErrInvalidSignature) + }) + + t.Run("body is empty", func(t *testing.T) { + _, err := flashbots.ParseSignature(header, []byte{}) + require.ErrorIs(t, err, flashbots.ErrInvalidSignature) + }) + + t.Run("body is invalid", func(t *testing.T) { + _, err := flashbots.ParseSignature(header, []byte(`{}`)) + require.ErrorIs(t, err, flashbots.ErrInvalidSignature) + }) + + t.Run("body has extra bytes", func(t *testing.T) { + _, err := flashbots.ParseSignature(header, []byte(body+"...")) + require.ErrorIs(t, err, flashbots.ErrInvalidSignature) + }) + + t.Run("body has missing bytes", func(t *testing.T) { + _, err := flashbots.ParseSignature(header, []byte(body[:len(body)-8])) + require.ErrorIs(t, err, flashbots.ErrInvalidSignature) + }) +} diff --git a/server/request_handler.go b/server/request_handler.go index 07723ca..64972bf 100644 --- a/server/request_handler.go +++ b/server/request_handler.go @@ -134,11 +134,11 @@ func (r *RpcRequestHandler) process() { r.logger = r.logger.New("rpc_method", jsonReq.Method) // Process single request - r.processRequest(client, jsonReq, origin, referer, isWhitehatBundleCollection, whitehatBundleId, urlParams, r.req.URL.String()) + r.processRequest(client, jsonReq, origin, referer, isWhitehatBundleCollection, whitehatBundleId, urlParams, r.req.URL.String(), body) } // processRequest handles single request -func (r *RpcRequestHandler) processRequest(client RPCProxyClient, jsonReq *types.JsonRpcRequest, origin, referer string, isWhitehatBundleCollection bool, whitehatBundleId string, urlParams URLParameters, reqURL string) { +func (r *RpcRequestHandler) processRequest(client RPCProxyClient, jsonReq *types.JsonRpcRequest, origin, referer string, isWhitehatBundleCollection bool, whitehatBundleId string, urlParams URLParameters, reqURL string, body []byte) { var entry *database.EthSendRawTxEntry if jsonReq.Method == "eth_sendRawTransaction" { entry = r.requestRecord.AddEthSendRawTxEntry(uuid.New()) @@ -147,6 +147,13 @@ func (r *RpcRequestHandler) processRequest(client RPCProxyClient, jsonReq *types } // Handle single request rpcReq := NewRpcRequest(r.logger, client, jsonReq, r.relaySigningKey, r.relayUrl, origin, referer, isWhitehatBundleCollection, whitehatBundleId, entry, urlParams, r.chainID, r.rpcCache) + + if err := rpcReq.CheckFlashbotsSignature(r.req.Header.Get("X-Flashbots-Signature"), body); err != nil { + r.logger.Warn("[processRequest] CheckFlashbotsSignature", "error", err) + rpcReq.writeRpcError(err.Error(), types.JsonRpcInvalidRequest) + r._writeRpcResponse(rpcReq.jsonRes) + return + } res := rpcReq.ProcessRequest() // Write response r._writeRpcResponse(res) diff --git a/server/request_intercepts.go b/server/request_intercepts.go index c77ad4f..366d51d 100644 --- a/server/request_intercepts.go +++ b/server/request_intercepts.go @@ -150,3 +150,47 @@ func (r *RpcRequest) intercept_eth_call_to_FlashRPC_Contract() (requestFinished r.logger.Info("Intercepted eth_call to FlashRPC contract") return true } + +func (r *RpcRequest) intercept_signed_eth_getTransactionCount() (requestFinished bool) { + if r.flashbotsSigningAddress == "" { + r.logger.Info("[eth_getTransactionCount] No signature found") + return false + } + + if len(r.jsonReq.Params) != 2 { + r.logger.Info("[eth_getTransactionCount] Invalid params") + return false + } + + blockSpecifier, ok := r.jsonReq.Params[1].(string) + if !ok || blockSpecifier != "pending" { + r.logger.Info("[eth_getTransactionCount] non-pending blockSpecifier") + return false + } + + addr, ok := r.jsonReq.Params[0].(string) + if !ok { + r.logger.Info("[eth_getTransactionCount] non-string address") + return false + } + addr = strings.ToLower(addr) + if addr != strings.ToLower(r.flashbotsSigningAddress) { + r.logger.Info("[eth_getTransactionCount] address mismatch", "addr", addr, "signingAddress", r.flashbotsSigningAddress) + return false + } + + nonce, found, err := RState.GetSenderMaxNonce(addr) + if err != nil { + r.logger.Error("[eth_getTransactionCount] Redis:GetSenderMaxNonce error", "error", err) + return false + } + if !found { + r.logger.Info("[eth_getTransactionCount] No nonce found") + return false + } + + r.logger.Info("[eth_getTransactionCount] intercept", "nonce", nonce) + resp := fmt.Sprintf("0x%x", nonce+1) + r.writeRpcResult(resp) + return true +} diff --git a/server/request_processor.go b/server/request_processor.go index 46a176b..ced8d2c 100644 --- a/server/request_processor.go +++ b/server/request_processor.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/flashbots/rpc-endpoint/adapters/flashbots" "github.com/flashbots/rpc-endpoint/application" "github.com/flashbots/rpc-endpoint/database" @@ -42,6 +43,7 @@ type RpcRequest struct { urlParams URLParameters chainID []byte rpcCache *application.RpcCache + flashbotsSigningAddress string } func NewRpcRequest( @@ -58,7 +60,7 @@ func NewRpcRequest( rpcCache *application.RpcCache, ) *RpcRequest { return &RpcRequest{ - logger: logger, + logger: logger.With("method", jsonReq.Method), client: client, jsonReq: jsonReq, relaySigningKey: relaySigningKey, @@ -102,6 +104,7 @@ func (r *RpcRequest) ProcessRequest() *types.JsonRpcResponse { case r.jsonReq.Method == "eth_sendRawTransaction": r.ethSendRawTxEntry.WhiteHatBundleId = r.whitehatBundleId r.handle_sendRawTransaction() + case r.jsonReq.Method == "eth_getTransactionCount" && r.intercept_signed_eth_getTransactionCount(): case r.jsonReq.Method == "eth_getTransactionCount" && r.intercept_mm_eth_getTransactionCount(): // intercept if MM needs to show an error to user case r.jsonReq.Method == "eth_call" && r.intercept_eth_call_to_FlashRPC_Contract(): // intercept if Flashbots isRPC contract case r.jsonReq.Method == "web3_clientVersion": @@ -498,3 +501,23 @@ func (r *RpcRequest) writeRpcResult(result interface{}) { Result: resBytes, } } + +// CheckFlashbotsSignature parses and validates the Flashbots signature if present, +// returning an error if the signature is invalid. If the signature is present and valid +// the signing address is stored in the request. +func (r *RpcRequest) CheckFlashbotsSignature(signature string, body []byte) error { + // Most requests don't have a signature, so avoid parsing it if it's empty + if signature == "" { + return nil + } + signingAddress, err := flashbots.ParseSignature(signature, body) + if err != nil { + if errors.Is(err, flashbots.ErrNoSignature) { + return nil + } else { + return err + } + } + r.flashbotsSigningAddress = signingAddress + return nil +}