Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decode the method name from calldata automatically #121

Merged
merged 4 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ jobs:
strategy:
matrix:
test:
- ./core/taskengine
- ./core/taskengine/trigger
- ./core/taskengine/macros
- ./pkg/timekeeper
- ./pkg/graphql
- ./pkg/erc4337/preset
- ./aggregator/
- aggregator
- core/taskengine
- core/taskengine/trigger
- core/taskengine/macros
- pkg/timekeeper
- pkg/graphql
- pkg/byte4
- pkg/erc4337/preset

steps:
- uses: actions/checkout@v4
Expand All @@ -35,7 +36,7 @@ jobs:
CONTROLLER_PRIVATE_KEY: "${{ secrets.CONTROLLER_PRIVATE_KEY }}"

run: |
cd ${{ matrix.test }}
cd ./${{ matrix.test }}
go test . -v

publish-dev-build:
Expand Down
13 changes: 10 additions & 3 deletions core/taskengine/vm_runner_contract_read.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethclient"

"github.com/AvaProtocol/ap-avs/pkg/byte4"
avsproto "github.com/AvaProtocol/ap-avs/protobuf"
)

Expand Down Expand Up @@ -73,15 +74,21 @@ func (r *ContractReadProcessor) Execute(stepID string, node *avsproto.ContractRe
return s, err
}

// Unpack the output
result, err := parsedABI.Unpack(node.Method, output)
// Unpack the output by parsing the 4byte from calldata, compare with the right method in ABI
method, err := byte4.GetMethodFromCalldata(parsedABI, common.FromHex(node.CallData))
if err != nil {
s.Success = false
s.Error = fmt.Errorf("error detect method from ABI: %w", err).Error()
return s, err
}
result, err := parsedABI.Unpack(method.Name, output)
if err != nil {
s.Success = false
s.Error = fmt.Errorf("error decode result: %w", err).Error()
return s, err
}

log.WriteString(fmt.Sprintf("Call %s on %s at %s", node.Method, node.ContractAddress, time.Now()))
log.WriteString(fmt.Sprintf("Call %s on %s at %s", method.Name, node.ContractAddress, time.Now()))
s.Log = log.String()
outputData, err := json.Marshal(result)
s.OutputData = string(outputData)
Expand Down
2 changes: 0 additions & 2 deletions core/taskengine/vm_runner_contract_read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ func TestContractReadSimpleReturn(t *testing.T) {
ContractAddress: "0x1c7d4b196cb0c7b01d743fbc6116a902379c7238",
CallData: "0x70a08231000000000000000000000000ce289bb9fb0a9591317981223cbe33d5dc42268d",
ContractAbi: `[{"inputs":[{"internalType":"address","name":"account","type":"address"}],"name":"balanceOf","outputs":[{"internalType":"uint256","name":"","type":"uint256"}],"stateMutability":"view","type":"function"}]`,
Method: "balanceOf",
}
nodes := []*avsproto.TaskNode{
&avsproto.TaskNode{
Expand Down Expand Up @@ -70,7 +69,6 @@ func TestContractReadComplexReturn(t *testing.T) {
ContractAddress: "0xc59E3633BAAC79493d908e63626716e204A45EdF",
CallData: "0x9a6fc8f500000000000000000000000000000000000000000000000100000000000052e7",
ContractAbi: `[{"inputs":[{"internalType":"uint80","name":"_roundId","type":"uint80"}],"name":"getRoundData","outputs":[{"internalType":"uint80","name":"roundId","type":"uint80"},{"internalType":"int256","name":"answer","type":"int256"},{"internalType":"uint256","name":"startedAt","type":"uint256"},{"internalType":"uint256","name":"updatedAt","type":"uint256"},{"internalType":"uint80","name":"answeredInRound","type":"uint80"}],"stateMutability":"view","type":"function"}]`,
Method: "getRoundData",
}

nodes := []*avsproto.TaskNode{
Expand Down
41 changes: 41 additions & 0 deletions pkg/byte4/signature.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package byte4

import (
"bytes"
"fmt"
"strings"

"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/crypto"
)

// GetMethodFromCalldata returns the method name and ABI method for a given 4-byte selector or full calldata
func GetMethodFromCalldata(parsedABI abi.ABI, selector []byte) (*abi.Method, error) {
if len(selector) < 4 {
return nil, fmt.Errorf("invalid selector length: %d", len(selector))
}

// Get first 4 bytes of the calldata. This is the first 8 characters of the calldata
// Function calls in the Ethereum Virtual Machine(EVM) are specified by the first four bytes of data sent with a transaction. These 4-byte signatures are defined as the first four bytes of the Keccak hash (SHA3) of the canonical representation of the function signature.

methodID := selector[:4]

// Find matching method in ABI
for name, method := range parsedABI.Methods {
// Build the signature string from inputs
var types []string
for _, input := range method.Inputs {
types = append(types, input.Type.String())
}

// Create method signature: name(type1,type2,...)
sig := fmt.Sprintf("%v(%v)", name, strings.Join(types, ","))
hash := crypto.Keccak256([]byte(sig))[:4]

if bytes.Equal(hash, methodID) {
return &method, nil
}
}

return nil, fmt.Errorf("no matching method found for selector: 0x%x", methodID)
}
132 changes: 132 additions & 0 deletions pkg/byte4/signature_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package byte4

import (
"encoding/hex"
"strings"
"testing"

"github.com/ethereum/go-ethereum/accounts/abi"
)

func TestGetMethodFromSelector(t *testing.T) {
// ERC20 ABI with transfer and balanceOf methods. These hash can generate locally or getting from Etherscan/Remix
const abiJSON = `[
{
"constant": false,
"inputs": [
{
"name": "_to",
"type": "address"
},
{
"name": "_value",
"type": "uint256"
}
],
"name": "transfer",
"outputs": [],
"payable": false,
"stateMutability": "nonpayable",
"type": "function"
},
{
"constant": true,
"inputs": [
{
"name": "who",
"type": "address"
}
],
"name": "balanceOf",
"outputs": [
{
"name": "",
"type": "uint256"
}
],
"payable": false,
"stateMutability": "view",
"type": "function"
}
]`

parsedABI, err := abi.JSON(strings.NewReader(abiJSON))
if err != nil {
t.Fatalf("failed to parse ABI: %v", err)
}

// Helper function to decode hex string to bytes
decodeHex := func(s string) []byte {
b, err := hex.DecodeString(s)
if err != nil {
t.Fatalf("failed to decode hex: %v", err)
}
return b
}

// our test tables for a list of expected/evaluation
tests := []struct {
name string
selector []byte
wantMethod string
wantErr bool
errContains string
}{
{
name: "valid balanceOf selector",
selector: decodeHex("70a08231000000000000000000000000ce289bb9fb0a9591317981223cbe33d5dc42268d"),
wantMethod: "balanceOf",
wantErr: false,
},
{
name: "valid transfer selector",
selector: decodeHex("a9059cbb000000000000000000000000ce289bb9fb0a9591317981223cbe33d5dc42268d0000000000000000000000000000000000000000000000000de0b6b3a7640000"),
wantMethod: "transfer",
wantErr: false,
},
{
name: "invalid selector length",
selector: []byte{0x70, 0xa0},
wantErr: true,
errContains: "invalid selector length",
},
{
name: "unknown selector",
selector: decodeHex("12345678"),
wantErr: true,
errContains: "no matching method found",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
method, err := GetMethodFromCalldata(parsedABI, tt.selector)

if tt.wantErr {
if err == nil {
t.Error("expected error but got nil")
return
}
if !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("error %q does not contain %q", err.Error(), tt.errContains)
}
if method != nil {
t.Error("expected nil method but got non-nil")
}
return
}

if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if method == nil {
t.Error("expected non-nil method but got nil")
return
}
if method.Name != tt.wantMethod {
t.Errorf("got method %q, want %q", method.Name, tt.wantMethod)
}
})
}
}
Loading