Skip to content

Commit

Permalink
feat: Return local nonce when getTransactionCount request is signed (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanschneider authored Aug 7, 2024
1 parent f59b440 commit cbffdfd
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 3 deletions.
67 changes: 67 additions & 0 deletions adapters/flashbots/signature.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Package flashbots provides methods for parsing the X-Flashbots-Signature header.
package flashbots

import (
"errors"
"fmt"
"strings"

"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"
)

var (
ErrNoSignature = errors.New("no signature provided")
ErrInvalidSignature = errors.New("invalid signature provided")
)

func ParseSignature(header string, body []byte) (signingAddress string, err error) {
if header == "" {
return "", ErrNoSignature
}

splitSig := strings.Split(header, ":")
if len(splitSig) != 2 {
return "", ErrInvalidSignature
}

return VerifySignature(body, splitSig[0], splitSig[1])
}

func VerifySignature(body []byte, signingAddressStr, signatureStr string) (signingAddress string, err error) {
signature, err := hexutil.Decode(signatureStr)
if err != nil || len(signature) == 0 {
return "", fmt.Errorf("%w: %w", ErrInvalidSignature, err)
}

if signature[len(signature)-1] >= 27 {
signature[len(signature)-1] -= 27
}

hashedBody := crypto.Keccak256Hash(body).Hex()
messageHash := accounts.TextHash([]byte(hashedBody))
signaturePublicKeyBytes, err := crypto.Ecrecover(messageHash, signature)
if err != nil {
return "", fmt.Errorf("%w: %w", ErrInvalidSignature, err)
}

publicKey, err := crypto.UnmarshalPubkey(signaturePublicKeyBytes)
if err != nil {
return "", fmt.Errorf("%w: %w", ErrInvalidSignature, err)
}
signaturePubkey := *publicKey
signaturePubKeyAddress := crypto.PubkeyToAddress(signaturePubkey).Hex()

// case-insensitive equality check
if !strings.EqualFold(signaturePubKeyAddress, signingAddressStr) {
return "", fmt.Errorf("%w: signing address mismatch", ErrInvalidSignature)
}

signatureNoRecoverID := signature[:len(signature)-1] // remove recovery id
if !crypto.VerifySignature(signaturePublicKeyBytes, messageHash, signatureNoRecoverID) {
return "", fmt.Errorf("%w: %w", ErrInvalidSignature, err)
}

return signaturePubKeyAddress, nil
}
80 changes: 80 additions & 0 deletions adapters/flashbots/signature_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package flashbots_test

import (
"fmt"
"testing"

"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/require"

"github.com/flashbots/rpc-endpoint/adapters/flashbots"
)

func TestParseSignature(t *testing.T) {

// For most of these test cases, we first need to generate a signature
privateKey, err := crypto.GenerateKey()
require.NoError(t, err)

address := crypto.PubkeyToAddress(privateKey.PublicKey).Hex()
body := fmt.Sprintf(
`{"jsonrpc":"2.0","method":"eth_getTransactionCount","params":["%s","pending"],"id":1}`,
address,
)

signature, err := crypto.Sign(
accounts.TextHash([]byte(hexutil.Encode(crypto.Keccak256([]byte(body))))),
privateKey,
)
require.NoError(t, err)

header := fmt.Sprintf("%s:%s", address, hexutil.Encode(signature))

t.Run("header is empty", func(t *testing.T) {
_, err := flashbots.ParseSignature("", []byte{})
require.ErrorIs(t, err, flashbots.ErrNoSignature)
})

t.Run("header is valid", func(t *testing.T) {
signingAddress, err := flashbots.ParseSignature(header, []byte(body))
require.NoError(t, err)
require.Equal(t, address, signingAddress)
})

t.Run("header is invalid", func(t *testing.T) {
_, err := flashbots.ParseSignature("invalid", []byte(body))
require.ErrorIs(t, err, flashbots.ErrInvalidSignature)
})

t.Run("header has extra bytes", func(t *testing.T) {
_, err := flashbots.ParseSignature(header+"deadbeef", []byte(body))
require.ErrorIs(t, err, flashbots.ErrInvalidSignature)
})

t.Run("header has missing bytes", func(t *testing.T) {
_, err := flashbots.ParseSignature(header[:len(header)-8], []byte(body))
require.ErrorIs(t, err, flashbots.ErrInvalidSignature)
})

t.Run("body is empty", func(t *testing.T) {
_, err := flashbots.ParseSignature(header, []byte{})
require.ErrorIs(t, err, flashbots.ErrInvalidSignature)
})

t.Run("body is invalid", func(t *testing.T) {
_, err := flashbots.ParseSignature(header, []byte(`{}`))
require.ErrorIs(t, err, flashbots.ErrInvalidSignature)
})

t.Run("body has extra bytes", func(t *testing.T) {
_, err := flashbots.ParseSignature(header, []byte(body+"..."))
require.ErrorIs(t, err, flashbots.ErrInvalidSignature)
})

t.Run("body has missing bytes", func(t *testing.T) {
_, err := flashbots.ParseSignature(header, []byte(body[:len(body)-8]))
require.ErrorIs(t, err, flashbots.ErrInvalidSignature)
})
}
11 changes: 9 additions & 2 deletions server/request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,11 @@ func (r *RpcRequestHandler) process() {
r.logger = r.logger.New("rpc_method", jsonReq.Method)

// Process single request
r.processRequest(client, jsonReq, origin, referer, isWhitehatBundleCollection, whitehatBundleId, urlParams, r.req.URL.String())
r.processRequest(client, jsonReq, origin, referer, isWhitehatBundleCollection, whitehatBundleId, urlParams, r.req.URL.String(), body)
}

// processRequest handles single request
func (r *RpcRequestHandler) processRequest(client RPCProxyClient, jsonReq *types.JsonRpcRequest, origin, referer string, isWhitehatBundleCollection bool, whitehatBundleId string, urlParams URLParameters, reqURL string) {
func (r *RpcRequestHandler) processRequest(client RPCProxyClient, jsonReq *types.JsonRpcRequest, origin, referer string, isWhitehatBundleCollection bool, whitehatBundleId string, urlParams URLParameters, reqURL string, body []byte) {
var entry *database.EthSendRawTxEntry
if jsonReq.Method == "eth_sendRawTransaction" {
entry = r.requestRecord.AddEthSendRawTxEntry(uuid.New())
Expand All @@ -147,6 +147,13 @@ func (r *RpcRequestHandler) processRequest(client RPCProxyClient, jsonReq *types
}
// Handle single request
rpcReq := NewRpcRequest(r.logger, client, jsonReq, r.relaySigningKey, r.relayUrl, origin, referer, isWhitehatBundleCollection, whitehatBundleId, entry, urlParams, r.chainID, r.rpcCache)

if err := rpcReq.CheckFlashbotsSignature(r.req.Header.Get("X-Flashbots-Signature"), body); err != nil {
r.logger.Warn("[processRequest] CheckFlashbotsSignature", "error", err)
rpcReq.writeRpcError(err.Error(), types.JsonRpcInvalidRequest)
r._writeRpcResponse(rpcReq.jsonRes)
return
}
res := rpcReq.ProcessRequest()
// Write response
r._writeRpcResponse(res)
Expand Down
44 changes: 44 additions & 0 deletions server/request_intercepts.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,47 @@ func (r *RpcRequest) intercept_eth_call_to_FlashRPC_Contract() (requestFinished
r.logger.Info("Intercepted eth_call to FlashRPC contract")
return true
}

func (r *RpcRequest) intercept_signed_eth_getTransactionCount() (requestFinished bool) {
if r.flashbotsSigningAddress == "" {
r.logger.Info("[eth_getTransactionCount] No signature found")
return false
}

if len(r.jsonReq.Params) != 2 {
r.logger.Info("[eth_getTransactionCount] Invalid params")
return false
}

blockSpecifier, ok := r.jsonReq.Params[1].(string)
if !ok || blockSpecifier != "pending" {
r.logger.Info("[eth_getTransactionCount] non-pending blockSpecifier")
return false
}

addr, ok := r.jsonReq.Params[0].(string)
if !ok {
r.logger.Info("[eth_getTransactionCount] non-string address")
return false
}
addr = strings.ToLower(addr)
if addr != strings.ToLower(r.flashbotsSigningAddress) {
r.logger.Info("[eth_getTransactionCount] address mismatch", "addr", addr, "signingAddress", r.flashbotsSigningAddress)
return false
}

nonce, found, err := RState.GetSenderMaxNonce(addr)
if err != nil {
r.logger.Error("[eth_getTransactionCount] Redis:GetSenderMaxNonce error", "error", err)
return false
}
if !found {
r.logger.Info("[eth_getTransactionCount] No nonce found")
return false
}

r.logger.Info("[eth_getTransactionCount] intercept", "nonce", nonce)
resp := fmt.Sprintf("0x%x", nonce+1)
r.writeRpcResult(resp)
return true
}
25 changes: 24 additions & 1 deletion server/request_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strings"
"time"

"github.com/flashbots/rpc-endpoint/adapters/flashbots"
"github.com/flashbots/rpc-endpoint/application"
"github.com/flashbots/rpc-endpoint/database"

Expand Down Expand Up @@ -42,6 +43,7 @@ type RpcRequest struct {
urlParams URLParameters
chainID []byte
rpcCache *application.RpcCache
flashbotsSigningAddress string
}

func NewRpcRequest(
Expand All @@ -58,7 +60,7 @@ func NewRpcRequest(
rpcCache *application.RpcCache,
) *RpcRequest {
return &RpcRequest{
logger: logger,
logger: logger.With("method", jsonReq.Method),
client: client,
jsonReq: jsonReq,
relaySigningKey: relaySigningKey,
Expand Down Expand Up @@ -102,6 +104,7 @@ func (r *RpcRequest) ProcessRequest() *types.JsonRpcResponse {
case r.jsonReq.Method == "eth_sendRawTransaction":
r.ethSendRawTxEntry.WhiteHatBundleId = r.whitehatBundleId
r.handle_sendRawTransaction()
case r.jsonReq.Method == "eth_getTransactionCount" && r.intercept_signed_eth_getTransactionCount():
case r.jsonReq.Method == "eth_getTransactionCount" && r.intercept_mm_eth_getTransactionCount(): // intercept if MM needs to show an error to user
case r.jsonReq.Method == "eth_call" && r.intercept_eth_call_to_FlashRPC_Contract(): // intercept if Flashbots isRPC contract
case r.jsonReq.Method == "web3_clientVersion":
Expand Down Expand Up @@ -498,3 +501,23 @@ func (r *RpcRequest) writeRpcResult(result interface{}) {
Result: resBytes,
}
}

// CheckFlashbotsSignature parses and validates the Flashbots signature if present,
// returning an error if the signature is invalid. If the signature is present and valid
// the signing address is stored in the request.
func (r *RpcRequest) CheckFlashbotsSignature(signature string, body []byte) error {
// Most requests don't have a signature, so avoid parsing it if it's empty
if signature == "" {
return nil
}
signingAddress, err := flashbots.ParseSignature(signature, body)
if err != nil {
if errors.Is(err, flashbots.ErrNoSignature) {
return nil
} else {
return err
}
}
r.flashbotsSigningAddress = signingAddress
return nil
}

0 comments on commit cbffdfd

Please sign in to comment.