Skip to content

Commit

Permalink
Fixes trace transaction tests
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagodeev committed Sep 2, 2024
1 parent 84722cb commit 83ab9a3
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 25 deletions.
113 changes: 91 additions & 22 deletions rpc/trace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ func TestTransactionTrace(t *testing.T) {

type testSetType struct {
TransactionHash *felt.Felt
ExpectedResp *InvokeTxnTrace
ExpectedError *RPCError
ExpectedResp TxnTrace
ExpectedError error
}
testSet := map[string][]testSetType{
"mock": {
testSetType{
TransactionHash: utils.TestHexToFelt(t, "0x6a4a9c4f1a530f7d6dd7bba9b71f090a70d1e3bbde80998fde11a08aab8b282"),
ExpectedResp: &expectedResp,
ExpectedResp: expectedResp,
ExpectedError: nil,
},
testSetType{
Expand All @@ -61,7 +61,7 @@ func TestTransactionTrace(t *testing.T) {
"testnet": {
testSetType{
TransactionHash: utils.TestHexToFelt(t, "0x6a4a9c4f1a530f7d6dd7bba9b71f090a70d1e3bbde80998fde11a08aab8b282"),
ExpectedResp: &expectedResp,
ExpectedResp: expectedResp,
ExpectedError: nil,
},
},
Expand All @@ -70,12 +70,8 @@ func TestTransactionTrace(t *testing.T) {

for _, test := range testSet {
resp, err := testConfig.provider.TraceTransaction(context.Background(), test.TransactionHash)
if err != nil {
require.Equal(t, test.ExpectedError, err)
} else {
invokeTrace := resp.(InvokeTxnTrace)
require.Equal(t, invokeTrace, *test.ExpectedResp)
}
require.Equal(t, test.ExpectedError, err)
compareTraceTxs(t, test.ExpectedResp, resp)
}
}

Expand Down Expand Up @@ -144,8 +140,11 @@ func TestSimulateTransaction(t *testing.T) {
test.SimulateTxnInput.Txns,
test.SimulateTxnInput.SimulationFlags)
require.NoError(t, err)
require.Equal(t, test.ExpectedResp.Txns[0].FeeEstimate, resp[0].FeeEstimate)
require.Len(t, test.ExpectedResp.Txns, len(resp))

for i, trace := range resp {
require.Equal(t, test.ExpectedResp.Txns[i].FeeEstimate, trace.FeeEstimate)
compareTraceTxs(t, test.ExpectedResp.Txns[i].TxnTrace, trace.TxnTrace)
}
}
}

Expand All @@ -163,12 +162,13 @@ func TestSimulateTransaction(t *testing.T) {
// none
func TestTraceBlockTransactions(t *testing.T) {
testConfig := beforeEach(t)
require := require.New(t)

var blockTraceSepolia []Trace

expectedrespRaw, err := os.ReadFile("./tests/trace/sepoliaBlockTrace_0x42a4c6a4c3dffee2cce78f04259b499437049b0084c3296da9fbbec7eda79b2.json")
require.NoError(t, err, "Error ReadFile for TestTraceBlockTransactions")
require.NoError(t, json.Unmarshal(expectedrespRaw, &blockTraceSepolia), "Error unmarshalling testdata TestTraceBlockTransactions")
require.NoError(err, "Error ReadFile for TestTraceBlockTransactions")
require.NoError(json.Unmarshal(expectedrespRaw, &blockTraceSepolia), "Error unmarshalling testdata TestTraceBlockTransactions")

type testSetType struct {
BlockID BlockID
Expand All @@ -178,12 +178,12 @@ func TestTraceBlockTransactions(t *testing.T) {
testSet := map[string][]testSetType{
"devnet": {}, // devenet doesn't support TraceBlockTransactions https://0xspaceshard.github.io/starknet-devnet/docs/guide/json-rpc-api#trace-api
"mainnet": {},
"testnet": { // TODO: there is a conflict between the test data and the rpc data, even though the data came from the same source...
// testSetType{
// BlockID: WithBlockNumber(99433),
// ExpectedResp: blockTraceSepolia,
// ExpectedErr: nil,
// },
"testnet": {
testSetType{
BlockID: WithBlockNumber(99433),
ExpectedResp: blockTraceSepolia,
ExpectedErr: nil,
},
},
"mock": {
testSetType{
Expand All @@ -202,10 +202,79 @@ func TestTraceBlockTransactions(t *testing.T) {
resp, err := testConfig.provider.TraceBlockTransactions(context.Background(), test.BlockID)

if err != nil {
require.Equal(t, test.ExpectedErr, err)
require.Equal(test.ExpectedErr, err)
} else {
require.EqualValues(t, test.ExpectedResp, resp)
for i, trace := range resp {
require.Equal(test.ExpectedResp[i].TxnHash, trace.TxnHash)
compareTraceTxs(t, test.ExpectedResp[i].TraceRoot, trace.TraceRoot)
}
}

}
}

func compareTraceTxs(t *testing.T, traceTx1, traceTx2 TxnTrace) {
require := require.New(t)

switch traceTx := traceTx1.(type) {
case DeclareTxnTrace:
require.Equal(traceTx.ValidateInvocation, traceTx2.(DeclareTxnTrace).ValidateInvocation)
require.Equal(traceTx.FeeTransferInvocation, traceTx2.(DeclareTxnTrace).FeeTransferInvocation)
compareStateDiffs(t, traceTx.StateDiff, traceTx2.(DeclareTxnTrace).StateDiff)
require.Equal(traceTx.Type, traceTx2.(DeclareTxnTrace).Type)
require.Equal(traceTx.ExecutionResources, traceTx2.(DeclareTxnTrace).ExecutionResources)
case DeployAccountTxnTrace:
require.Equal(traceTx.ValidateInvocation, traceTx2.(DeployAccountTxnTrace).ValidateInvocation)
require.Equal(traceTx.ConstructorInvocation, traceTx2.(DeployAccountTxnTrace).ConstructorInvocation)
require.Equal(traceTx.FeeTransferInvocation, traceTx2.(DeployAccountTxnTrace).FeeTransferInvocation)
compareStateDiffs(t, traceTx.StateDiff, traceTx2.(DeployAccountTxnTrace).StateDiff)
require.Equal(traceTx.Type, traceTx2.(DeployAccountTxnTrace).Type)
require.Equal(traceTx.ExecutionResources, traceTx2.(DeployAccountTxnTrace).ExecutionResources)
case InvokeTxnTrace:
require.Equal(traceTx.ValidateInvocation, traceTx2.(InvokeTxnTrace).ValidateInvocation)
require.Equal(traceTx.ExecuteInvocation, traceTx2.(InvokeTxnTrace).ExecuteInvocation)
require.Equal(traceTx.FeeTransferInvocation, traceTx2.(InvokeTxnTrace).FeeTransferInvocation)
compareStateDiffs(t, traceTx.StateDiff, traceTx2.(InvokeTxnTrace).StateDiff)
require.Equal(traceTx.Type, traceTx2.(InvokeTxnTrace).Type)
require.Equal(traceTx.ExecutionResources, traceTx2.(InvokeTxnTrace).ExecutionResources)
case L1HandlerTxnTrace:
require.Equal(traceTx.FunctionInvocation, traceTx2.(L1HandlerTxnTrace).FunctionInvocation)
compareStateDiffs(t, traceTx.StateDiff, traceTx2.(L1HandlerTxnTrace).StateDiff)
require.Equal(traceTx.Type, traceTx2.(L1HandlerTxnTrace).Type)
}
}

func compareStateDiffs(t *testing.T, stateDiff1, stateDiff2 StateDiff) {
require.ElementsMatch(t, stateDiff1.DeprecatedDeclaredClasses, stateDiff2.DeprecatedDeclaredClasses)
require.ElementsMatch(t, stateDiff1.DeclaredClasses, stateDiff2.DeclaredClasses)
require.ElementsMatch(t, stateDiff1.DeployedContracts, stateDiff2.DeployedContracts)
require.ElementsMatch(t, stateDiff1.ReplacedClasses, stateDiff2.ReplacedClasses)
require.ElementsMatch(t, stateDiff1.Nonces, stateDiff2.Nonces)

// compares storage diffs (they come in a random order)
rawStorageDiff, err := json.Marshal(stateDiff2.StorageDiffs)
require.NoError(t, err)
var mapDiff []map[string]interface{}
require.NoError(t, json.Unmarshal(rawStorageDiff, &mapDiff))

for _, diff1 := range stateDiff1.StorageDiffs {
var diff2 ContractStorageDiffItem

for _, diffElem := range mapDiff {
address, ok := diffElem["address"]
require.True(t, ok)
addressFelt := utils.TestHexToFelt(t, address.(string))

if *addressFelt != *diff1.Address {
continue
}

err = remarshal(diffElem, &diff2)
require.NoError(t, err)
}
require.NotEmpty(t, diff2)

require.Equal(t, diff1.Address, diff2.Address)
require.ElementsMatch(t, diff1.StorageEntries, diff2.StorageEntries)
}
}
136 changes: 133 additions & 3 deletions rpc/types_trace.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
package rpc

import "github.com/NethermindEth/juno/core/felt"
import (
"encoding/json"
"fmt"

"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/starknet.go/utils"
)

type SimulateTransactionInput struct {
//a sequence of transactions to simulate, running each transaction on the state resulting from applying all the previous ones
Expand All @@ -24,8 +30,8 @@ type SimulateTransactionOutput struct {
}

type SimulatedTransaction struct {
TxnTrace `json:"transaction_trace"`
FeeEstimate
TxnTrace `json:"transaction_trace"`
FeeEstimate `json:"fee_estimation"`
}

type TxnTrace interface{}
Expand Down Expand Up @@ -130,3 +136,127 @@ type ExecInvocation struct {
FunctionInvocation FnInvocation `json:"function_invocation,omitempty"`
RevertReason string `json:"revert_reason,omitempty"`
}

// UnmarshalJSON unmarshals the data into a SimulatedTransaction object.
//
// It takes a byte slice as the parameter, representing the JSON data to be unmarshalled.
// The function returns an error if the unmarshalling process fails.
//
// Parameters:
// - data: The JSON data to be unmarshalled
// Returns:
// - error: An error if the unmarshalling process fails
func (txn *SimulatedTransaction) UnmarshalJSON(data []byte) error {
var dec map[string]interface{}
if err := json.Unmarshal(data, &dec); err != nil {
return err
}

// SimulatedTransaction wraps transactions in the TxnTrace field.
rawTxnTrace, err := utils.UnwrapJSON(dec, "transaction_trace")
if err != nil {
return err
}

trace, err := unmarshalTraceTxn(rawTxnTrace)
if err != nil {
return err
}

var feeEstimate FeeEstimate

if feeEstimateData, ok := dec["fee_estimation"]; ok {
err = remarshal(feeEstimateData, &feeEstimate)
if err != nil {
return err
}
} else {
return fmt.Errorf("fee estimate not found")
}

*txn = SimulatedTransaction{
TxnTrace: trace,
FeeEstimate: feeEstimate,
}
return nil
}

// UnmarshalJSON unmarshals the data into a Trace object.
//
// It takes a byte slice as the parameter, representing the JSON data to be unmarshalled.
// The function returns an error if the unmarshalling process fails.
//
// Parameters:
// - data: The JSON data to be unmarshalled
// Returns:
// - error: An error if the unmarshalling process fails
func (txn *Trace) UnmarshalJSON(data []byte) error {
var dec map[string]interface{}
if err := json.Unmarshal(data, &dec); err != nil {
return err
}

// Trace wrap trace transactions in the TraceRoot field.
rawTraceTx, err := utils.UnwrapJSON(dec, "trace_root")
if err != nil {
return err
}

t, err := unmarshalTraceTxn(rawTraceTx)
if err != nil {
return err
}

var txHash *felt.Felt
if txHashData, ok := dec["transaction_hash"]; ok {
txHashString, ok := txHashData.(string)
if !ok {
return fmt.Errorf("failed to unmarshal transaction hash, transaction_hash is not a string")
}
txHash, err = utils.HexToFelt(txHashString)
if err != nil {
return err
}
} else {
return fmt.Errorf("failed to unmarshal transaction hash, transaction_hash not found")
}

*txn = Trace{
TraceRoot: t,
TxnHash: txHash,
}
return nil
}

// unmarshalTraceTxn unmarshals a given interface and returns a TxnTrace.
//
// Parameter:
// - t: The interface{} to be unmarshalled
// Returns:
// - TxnTrace: a TxnTrace
// - error: an error if the unmarshaling process fails
func unmarshalTraceTxn(t interface{}) (TxnTrace, error) {
switch casted := t.(type) {
case map[string]interface{}:
switch TransactionType(casted["type"].(string)) {
case TransactionType_Declare:
var txn DeclareTxnTrace
err := remarshal(casted, &txn)
return txn, err
case TransactionType_DeployAccount:
var txn DeployAccountTxnTrace
err := remarshal(casted, &txn)
return txn, err
case TransactionType_Invoke:
var txn InvokeTxnTrace
err := remarshal(casted, &txn)
return txn, err
case TransactionType_L1Handler:
var txn L1HandlerTxnTrace
err := remarshal(casted, &txn)
return txn, err
}
}

return nil, fmt.Errorf("unknown transaction type: %v", t)
}

0 comments on commit 83ab9a3

Please sign in to comment.