Skip to content

Commit

Permalink
feat(tests/scripts): create script to retrieve trie state via rpc (#3714
Browse files Browse the repository at this point in the history
)
  • Loading branch information
jimjbrettj authored Jan 29, 2024
1 parent 4566b14 commit 5ccea40
Show file tree
Hide file tree
Showing 2 changed files with 264 additions and 0 deletions.
138 changes: 138 additions & 0 deletions scripts/trie_state_script.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright 2024 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package main

import (
"context"
"encoding/json"
"fmt"
"os"
"time"

"github.com/ChainSafe/gossamer/dot/rpc/modules"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/trie"
"github.com/ChainSafe/gossamer/pkg/scale"
"github.com/ChainSafe/gossamer/tests/utils/rpc"
)

func fetchWithTimeout(ctx context.Context,
method, params string, target interface{}) {

// Can adjust timeout as desired, default is very long
getResponseCtx, getResponseCancel := context.WithTimeout(ctx, 1000000*time.Second)
defer getResponseCancel()
err := getResponse(getResponseCtx, method, params, target)
if err != nil {
panic(fmt.Sprintf("error getting response %v", err))
}
}

func getResponse(ctx context.Context, method, params string, target interface{}) (err error) {
const rpcPort = "8545"
endpoint := rpc.NewEndpoint(rpcPort)
respBody, err := rpc.Post(ctx, endpoint, method, params)
if err != nil {
return fmt.Errorf("cannot RPC post: %w", err)
}

err = rpc.Decode(respBody, &target)
if err != nil {
return fmt.Errorf("cannot decode RPC response: %w", err)
}

return nil
}

func writeTrieState(response modules.StateTrieResponse, destination string) {
encResponse, err := json.Marshal(response)
if err != nil {
panic(fmt.Sprintf("json marshalling response %v", err))
}

err = os.WriteFile(destination, encResponse, 0o600)
if err != nil {
panic(fmt.Sprintf("writing to file %v", err))
}
}

func fetchTrieState(ctx context.Context, blockHash common.Hash, destination string) modules.StateTrieResponse {
params := fmt.Sprintf(`["%s"]`, blockHash)
var response modules.StateTrieResponse
fetchWithTimeout(ctx, "state_trie", params, &response)

writeTrieState(response, destination)
return response
}

func compareStateRoots(response modules.StateTrieResponse, expectedStateRoot common.Hash, trieVersion trie.TrieLayout) {
entries := make(map[string]string, len(response))
for _, encodedEntry := range response {
bytesEncodedEntry := common.MustHexToBytes(encodedEntry)

entry := trie.Entry{}
err := scale.Unmarshal(bytesEncodedEntry, &entry)
if err != nil {
panic(fmt.Sprintf("error unmarshalling into trie entry %v", err))
}
entries[common.BytesToHex(entry.Key)] = common.BytesToHex(entry.Value)
}

newTrie, err := trie.LoadFromMap(entries)
if err != nil {
panic(fmt.Sprintf("loading trie from map %v", err))
}

trieHash := trieVersion.MustHash(newTrie)
if expectedStateRoot != trieHash {
panic("westendDevStateRoot does not match trieHash")
}
}

/*
This is a script to query the trie state from a specific block height from a running node.
Example commands to run a node:
1. ./bin/gossamer init --chain westend-dev --key alice
2. ./bin/gossamer --chain westend-dev --key alice --rpc-external=true --unsafe-rpc=true
Once the node has started and processed the block whose state you need, can execute the script like so:
1. go run trieStateScript.go <block hash> <destination file> <optional: expected state root> <optional: trie version>
*/
func main() {
if len(os.Args) < 3 {
panic("expected more arguments, block hash and destination file required")
}

blockHash, err := common.HexToHash(os.Args[1])
if err != nil {
panic("block hash must be in hex format")
}

destinationFile := os.Args[2]
expectedStateRoot := common.Hash{}
var trieVersion trie.TrieLayout
if len(os.Args) == 5 {
expectedStateRoot, err = common.HexToHash(os.Args[3])
if err != nil {
panic("expected state root must be in hex format")
}

trieVersion, err = trie.ParseVersion(os.Args[4])
if err != nil {
panic("trie version must be an integer")
}
} else if len(os.Args) != 3 {
panic("invalid number of arguments")
}

ctx, _ := context.WithCancel(context.Background()) //nolint
response := fetchTrieState(ctx, blockHash, destinationFile)

if !expectedStateRoot.IsEmpty() {
compareStateRoots(response, expectedStateRoot, trieVersion)
}
}
126 changes: 126 additions & 0 deletions scripts/trie_state_script_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright 2024 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package main

import (
"os"
"testing"

"github.com/ChainSafe/gossamer/dot/rpc/modules"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/trie"
"github.com/stretchr/testify/require"
)

// This is fake data used just for testing purposes
var testStateData = []string{"0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000", "0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000", "0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000", "0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000", "0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000", "0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000", "0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000", "0x801cb6f36e027abb2091cfb5110ab5087faacf00b9b41fda7a9268821c2a2b3e4ca404d43593c715fdd31c61141abd04a99fd6822c8558854ccde39a5684e7a56da27d0100000000000000"} //nolint

func clean(t *testing.T, file string) {
t.Helper()
err := os.Remove(file)
require.NoError(t, err)
}

func Test_writeTrieState(t *testing.T) {
writeTrieState(testStateData, "westendDevTestState.json")
_, err := os.Stat("./westendDevTestState.json")
require.NoError(t, err)

clean(t, "westendDevTestState.json")
}

func Test_compareStateRoots(t *testing.T) {
type args struct {
response modules.StateTrieResponse
expectedStateRoot common.Hash
trieVersion trie.TrieLayout
}
tests := []struct {
name string
args args
shouldPanic bool
}{
{
name: "happy_path",
args: args{
response: testStateData,
expectedStateRoot: common.MustHexToHash("0x3b1863ff981a31864be76037e4cf5c927b937dd8a8e1e25494128da7a95b5cdf"),
trieVersion: 0,
},
},
{
name: "invalid_trie_version",
args: args{
response: testStateData,
expectedStateRoot: common.MustHexToHash("0x6120d3afde6c139305bd7c0dcf50bdff5b620203e00c7491b2c30f95dccacc32"),
trieVersion: 21,
},
shouldPanic: true,
},
{
name: "hashes_do_not_match",
args: args{
response: testStateData,
expectedStateRoot: common.MustHexToHash("0x01"),
trieVersion: 21,
},
shouldPanic: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.shouldPanic {
require.Panics(t,
func() {
compareStateRoots(tt.args.response, tt.args.expectedStateRoot, tt.args.trieVersion)
},
"The code did not panic")
} else {
compareStateRoots(tt.args.response, tt.args.expectedStateRoot, tt.args.trieVersion)
}
})
}
}

func Test_cli(t *testing.T) {
tests := []struct {
name string
args []string
}{
{
name: "no_arguments",
},
{
name: "to_few_arguments",
args: []string{"0x01"},
},
{
name: "invalid_formatting_for_block_hash",
args: []string{"hello", "output.json"},
},
{
name: "no_trie_version",
args: []string{"0x01", "output.json", "0x01"},
},
{
name: "invalid_formatting_for_root_hash",
args: []string{"0x01", "output.json", "hello", "1"},
},
{
name: "invalid_trie_version",
args: []string{"0x01", "output.json", "0x01", "hello"},
},
{
name: "to_many_arguments",
args: []string{"0x01", "output.json", "0x01", "1", "0x01"},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
os.Args = tt.args
require.Panics(t, func() { main() }, "The code did not panic")
})
}
}

0 comments on commit 5ccea40

Please sign in to comment.