Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Return local nonce when getTransactionCount request is signed #151

Merged
merged 5 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions adapters/flashbots/signature.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Package flashbots provides methods for parsing the X-Flashbots-Signature header.
metachris marked this conversation as resolved.
Show resolved Hide resolved
package flashbots

import (
"errors"
"fmt"
"strings"

"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"
)

var (
ErrNoSignature = errors.New("no signature provided")
ErrInvalidSignature = errors.New("invalid signature provided")
)

func ParseSignature(header string, body []byte) (signingAddress string, err error) {
if header == "" {
return "", ErrNoSignature
}

splitSig := strings.Split(header, ":")
if len(splitSig) != 2 {
return "", ErrInvalidSignature
}

return VerifySignature(body, splitSig[0], splitSig[1])
}

func VerifySignature(body []byte, signingAddressStr, signatureStr string) (signingAddress string, err error) {
signature, err := hexutil.Decode(signatureStr)
if err != nil || len(signature) == 0 {
return "", fmt.Errorf("%w: %w", ErrInvalidSignature, err)
}

if signature[len(signature)-1] >= 27 {
signature[len(signature)-1] -= 27
}

hashedBody := crypto.Keccak256Hash(body).Hex()
messageHash := accounts.TextHash([]byte(hashedBody))
signaturePublicKeyBytes, err := crypto.Ecrecover(messageHash, signature)
if err != nil {
return "", fmt.Errorf("%w: %w", ErrInvalidSignature, err)
}

publicKey, err := crypto.UnmarshalPubkey(signaturePublicKeyBytes)
if err != nil {
return "", fmt.Errorf("%w: %w", ErrInvalidSignature, err)
}
signaturePubkey := *publicKey
signaturePubKeyAddress := crypto.PubkeyToAddress(signaturePubkey).Hex()

// case-insensitive equality check
if !strings.EqualFold(signaturePubKeyAddress, signingAddressStr) {
return "", fmt.Errorf("%w: signing address mismatch", ErrInvalidSignature)
}

signatureNoRecoverID := signature[:len(signature)-1] // remove recovery id
if !crypto.VerifySignature(signaturePublicKeyBytes, messageHash, signatureNoRecoverID) {
return "", fmt.Errorf("%w: %w", ErrInvalidSignature, err)
}

return signaturePubKeyAddress, nil
}
80 changes: 80 additions & 0 deletions adapters/flashbots/signature_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package flashbots_test

import (
"fmt"
"testing"

"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/require"

"github.com/flashbots/rpc-endpoint/adapters/flashbots"
)

func TestParseSignature(t *testing.T) {

// For most of these test cases, we first need to generate a signature
privateKey, err := crypto.GenerateKey()
require.NoError(t, err)

address := crypto.PubkeyToAddress(privateKey.PublicKey).Hex()
body := fmt.Sprintf(
`{"jsonrpc":"2.0","method":"eth_getTransactionCount","params":["%s","pending"],"id":1}`,
address,
)

signature, err := crypto.Sign(
accounts.TextHash([]byte(hexutil.Encode(crypto.Keccak256([]byte(body))))),
privateKey,
)
require.NoError(t, err)

header := fmt.Sprintf("%s:%s", address, hexutil.Encode(signature))

t.Run("header is empty", func(t *testing.T) {
_, err := flashbots.ParseSignature("", []byte{})
require.ErrorIs(t, err, flashbots.ErrNoSignature)
})

t.Run("header is valid", func(t *testing.T) {
signingAddress, err := flashbots.ParseSignature(header, []byte(body))
require.NoError(t, err)
require.Equal(t, address, signingAddress)
})

t.Run("header is invalid", func(t *testing.T) {
_, err := flashbots.ParseSignature("invalid", []byte(body))
require.ErrorIs(t, err, flashbots.ErrInvalidSignature)
})

t.Run("header has extra bytes", func(t *testing.T) {
_, err := flashbots.ParseSignature(header+"deadbeef", []byte(body))
require.ErrorIs(t, err, flashbots.ErrInvalidSignature)
})

t.Run("header has missing bytes", func(t *testing.T) {
_, err := flashbots.ParseSignature(header[:len(header)-8], []byte(body))
require.ErrorIs(t, err, flashbots.ErrInvalidSignature)
})

t.Run("body is empty", func(t *testing.T) {
_, err := flashbots.ParseSignature(header, []byte{})
require.ErrorIs(t, err, flashbots.ErrInvalidSignature)
})

t.Run("body is invalid", func(t *testing.T) {
_, err := flashbots.ParseSignature(header, []byte(`{}`))
require.ErrorIs(t, err, flashbots.ErrInvalidSignature)
})

t.Run("body has extra bytes", func(t *testing.T) {
_, err := flashbots.ParseSignature(header, []byte(body+"..."))
require.ErrorIs(t, err, flashbots.ErrInvalidSignature)
})

t.Run("body has missing bytes", func(t *testing.T) {
_, err := flashbots.ParseSignature(header, []byte(body[:len(body)-8]))
require.ErrorIs(t, err, flashbots.ErrInvalidSignature)
})
}
11 changes: 9 additions & 2 deletions server/request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,11 @@ func (r *RpcRequestHandler) process() {
r.logger = r.logger.New("rpc_method", jsonReq.Method)

// Process single request
r.processRequest(client, jsonReq, origin, referer, isWhitehatBundleCollection, whitehatBundleId, urlParams, r.req.URL.String())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if a better approach would be not to pass the full body around for possible later signature check, but instead here do a signature check if the header is present, and only pass the result around.

Downside of that approach: doing the signature check also on requests that wouldn't need it. Upside: everything is a bit simpler.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya I like that too, I'll make that change and add tests for the signatures.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored to that approach in d90fb9b.

r.processRequest(client, jsonReq, origin, referer, isWhitehatBundleCollection, whitehatBundleId, urlParams, r.req.URL.String(), body)
ryanschneider marked this conversation as resolved.
Show resolved Hide resolved
}

// processRequest handles single request
func (r *RpcRequestHandler) processRequest(client RPCProxyClient, jsonReq *types.JsonRpcRequest, origin, referer string, isWhitehatBundleCollection bool, whitehatBundleId string, urlParams URLParameters, reqURL string) {
func (r *RpcRequestHandler) processRequest(client RPCProxyClient, jsonReq *types.JsonRpcRequest, origin, referer string, isWhitehatBundleCollection bool, whitehatBundleId string, urlParams URLParameters, reqURL string, body []byte) {
var entry *database.EthSendRawTxEntry
if jsonReq.Method == "eth_sendRawTransaction" {
entry = r.requestRecord.AddEthSendRawTxEntry(uuid.New())
Expand All @@ -147,6 +147,13 @@ func (r *RpcRequestHandler) processRequest(client RPCProxyClient, jsonReq *types
}
// Handle single request
rpcReq := NewRpcRequest(r.logger, client, jsonReq, r.relaySigningKey, r.relayUrl, origin, referer, isWhitehatBundleCollection, whitehatBundleId, entry, urlParams, r.chainID, r.rpcCache)

if err := rpcReq.CheckFlashbotsSignature(r.req.Header.Get("X-Flashbots-Signature"), body); err != nil {
r.logger.Warn("[processRequest] CheckFlashbotsSignature", "error", err)
rpcReq.writeRpcError(err.Error(), types.JsonRpcInvalidRequest)
r._writeRpcResponse(rpcReq.jsonRes)
return
}
res := rpcReq.ProcessRequest()
// Write response
r._writeRpcResponse(res)
Expand Down
44 changes: 44 additions & 0 deletions server/request_intercepts.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,47 @@ func (r *RpcRequest) intercept_eth_call_to_FlashRPC_Contract() (requestFinished
r.logger.Info("Intercepted eth_call to FlashRPC contract")
return true
}

func (r *RpcRequest) intercept_signed_eth_getTransactionCount() (requestFinished bool) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would love if we could start doing unit / e2e tests for these handlers, but that seems def out of scope for now.

if r.flashbotsSigningAddress == "" {
r.logger.Info("[eth_getTransactionCount] No signature found")
return false
}

if len(r.jsonReq.Params) != 2 {
r.logger.Info("[eth_getTransactionCount] Invalid params")
return false
}

blockSpecifier, ok := r.jsonReq.Params[1].(string)
if !ok || blockSpecifier != "pending" {
r.logger.Info("[eth_getTransactionCount] non-pending blockSpecifier")
return false
}

addr, ok := r.jsonReq.Params[0].(string)
if !ok {
r.logger.Info("[eth_getTransactionCount] non-string address")
return false
}
addr = strings.ToLower(addr)
if addr != strings.ToLower(r.flashbotsSigningAddress) {
r.logger.Info("[eth_getTransactionCount] address mismatch", "addr", addr, "signingAddress", r.flashbotsSigningAddress)
return false
}

nonce, found, err := RState.GetSenderMaxNonce(addr)
if err != nil {
r.logger.Error("[eth_getTransactionCount] Redis:GetSenderMaxNonce error", "error", err)
return false
}
if !found {
r.logger.Info("[eth_getTransactionCount] No nonce found")
return false
}

r.logger.Info("[eth_getTransactionCount] intercept", "nonce", nonce)
resp := fmt.Sprintf("0x%x", nonce+1)
r.writeRpcResult(resp)
return true
}
25 changes: 24 additions & 1 deletion server/request_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strings"
"time"

"github.com/flashbots/rpc-endpoint/adapters/flashbots"
"github.com/flashbots/rpc-endpoint/application"
"github.com/flashbots/rpc-endpoint/database"

Expand Down Expand Up @@ -42,6 +43,7 @@ type RpcRequest struct {
urlParams URLParameters
chainID []byte
rpcCache *application.RpcCache
flashbotsSigningAddress string
}

func NewRpcRequest(
Expand All @@ -58,7 +60,7 @@ func NewRpcRequest(
rpcCache *application.RpcCache,
) *RpcRequest {
return &RpcRequest{
logger: logger,
logger: logger.With("method", jsonReq.Method),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

client: client,
jsonReq: jsonReq,
relaySigningKey: relaySigningKey,
Expand Down Expand Up @@ -102,6 +104,7 @@ func (r *RpcRequest) ProcessRequest() *types.JsonRpcResponse {
case r.jsonReq.Method == "eth_sendRawTransaction":
r.ethSendRawTxEntry.WhiteHatBundleId = r.whitehatBundleId
r.handle_sendRawTransaction()
case r.jsonReq.Method == "eth_getTransactionCount" && r.intercept_signed_eth_getTransactionCount():
case r.jsonReq.Method == "eth_getTransactionCount" && r.intercept_mm_eth_getTransactionCount(): // intercept if MM needs to show an error to user
case r.jsonReq.Method == "eth_call" && r.intercept_eth_call_to_FlashRPC_Contract(): // intercept if Flashbots isRPC contract
case r.jsonReq.Method == "web3_clientVersion":
Expand Down Expand Up @@ -498,3 +501,23 @@ func (r *RpcRequest) writeRpcResult(result interface{}) {
Result: resBytes,
}
}

// CheckFlashbotsSignature parses and validates the Flashbots signature if present,
// returning an error if the signature is invalid. If the signature is present and valid
// the signing address is stored in the request.
func (r *RpcRequest) CheckFlashbotsSignature(signature string, body []byte) error {
// Most requests don't have a signature, so avoid parsing it if it's empty
if signature == "" {
return nil
}
signingAddress, err := flashbots.ParseSignature(signature, body)
if err != nil {
if errors.Is(err, flashbots.ErrNoSignature) {
return nil
} else {
return err
}
}
r.flashbotsSigningAddress = signingAddress
return nil
}
Loading