Skip to content

Commit

Permalink
tool(script): state retrieval script (#4140)
Browse files Browse the repository at this point in the history
Co-authored-by: JimboJ <[email protected]>
  • Loading branch information
EclesioMeloJunior and jimjbrettj authored Sep 2, 2024
1 parent 357ba23 commit 26b7961
Show file tree
Hide file tree
Showing 4 changed files with 297 additions and 13 deletions.
46 changes: 46 additions & 0 deletions dot/network/messages/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

pb "github.com/ChainSafe/gossamer/dot/network/proto"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/pkg/trie"
"google.golang.org/protobuf/proto"
)

Expand Down Expand Up @@ -52,3 +53,48 @@ func (s *StateRequest) Decode(in []byte) error {
s.NoProof = message.NoProof
return nil
}

type StateResponse struct {
Entries []KeyValueStateEntry
Proof []byte
}

type KeyValueStateEntry struct {
StateRoot common.Hash
StateEntries trie.Entries
Complete bool
}

func (s *StateResponse) Decode(in []byte) error {
decodedResponse := &pb.StateResponse{}
err := proto.Unmarshal(in, decodedResponse)
if err != nil {
return err
}

s.Proof = make([]byte, len(decodedResponse.Proof))
copy(s.Proof, decodedResponse.Proof)

s.Entries = make([]KeyValueStateEntry, len(decodedResponse.Entries))
for idx, entry := range decodedResponse.Entries {
s.Entries[idx] = KeyValueStateEntry{
Complete: entry.Complete,
StateRoot: common.BytesToHash(entry.StateRoot),
}

trieFragment := make(trie.Entries, len(entry.Entries))
for stateEntryIdx, stateEntry := range entry.Entries {
trieFragment[stateEntryIdx] = trie.Entry{
Key: make([]byte, len(stateEntry.Key)),
Value: make([]byte, len(stateEntry.Value)),
}

copy(trieFragment[stateEntryIdx].Key, stateEntry.Key)
copy(trieFragment[stateEntryIdx].Value, stateEntry.Value)
}

s.Entries[idx].StateEntries = trieFragment
}

return nil
}
25 changes: 15 additions & 10 deletions scripts/p2p/common_p2p.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package p2p

import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -117,20 +118,22 @@ func parsePeerAddress(arg string) peer.AddrInfo {
return *p
}

func ReadStream(stream lip2pnetwork.Stream) []byte {
var errZeroLength = errors.New("zero length")

func ReadStream(stream lip2pnetwork.Stream) ([]byte, error) {
responseBuf := make([]byte, network.MaxBlockResponseSize)

length, _, err := network.ReadLEB128ToUint64(stream)
if err != nil {
log.Fatalf("reading response length: %s", err.Error())
return nil, fmt.Errorf("reading leb128: %w", err)
}

if length == 0 {
return nil
return nil, errZeroLength
}

if length > network.MaxBlockResponseSize {
log.Fatalf("%s: max %d, got %d", network.ErrGreaterThanMaxSize, network.MaxBlockResponseSize, length)
return nil, fmt.Errorf("%w: max %d, got %d", network.ErrGreaterThanMaxSize, network.MaxBlockResponseSize, length)
}

if length > uint64(len(responseBuf)) {
Expand All @@ -142,22 +145,22 @@ func ReadStream(stream lip2pnetwork.Stream) []byte {
for tot < int(length) {
n, err := stream.Read(responseBuf[tot:])
if err != nil {
log.Fatalf("reading stream: %s", err.Error())
return nil, fmt.Errorf("reading stream: %w", err)
}
tot += n
}

if tot != int(length) {
log.Fatalf("%s: expected %d bytes, received %d bytes", network.ErrFailedToReadEntireMessage, length, tot)
return nil, fmt.Errorf("%w: expected %d bytes, received %d bytes", network.ErrFailedToReadEntireMessage, length, tot)
}

return responseBuf[:tot]
return responseBuf[:tot], nil
}

func WriteStream(msg *messages.BlockRequestMessage, stream lip2pnetwork.Stream) {
func WriteStream(msg messages.P2PMessage, stream lip2pnetwork.Stream) error {
encMsg, err := msg.Encode()
if err != nil {
log.Fatalf("encoding message: %s", err.Error())
return fmt.Errorf("encoding message: %w", err)
}

msgLen := uint64(len(encMsg))
Expand All @@ -166,6 +169,8 @@ func WriteStream(msg *messages.BlockRequestMessage, stream lip2pnetwork.Stream)

_, err = stream.Write(encMsg)
if err != nil {
log.Fatalf("writing message: %s", err.Error())
return fmt.Errorf("writing message: %w", err)
}

return nil
}
16 changes: 13 additions & 3 deletions scripts/retrieve_block/retrieve_block.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,18 @@ func parseTargetBlock(arg string) variadic.Uint32OrHash {
}

func waitAndStoreResponse(stream lip2pnetwork.Stream, outputFile string) bool {
output := p2p.ReadStream(stream)
output, err := p2p.ReadStream(stream)
if len(output) == 0 {
return false
}

if err != nil {
log.Println(err.Error())
return false
}

blockResponse := &messages.BlockResponseMessage{}
err := blockResponse.Decode(output)
err = blockResponse.Decode(output)
if err != nil {
log.Fatalf("could not decode block response message: %s", err.Error())
}
Expand Down Expand Up @@ -125,7 +130,12 @@ func main() {
}

defer stream.Close() //nolint:errcheck
p2p.WriteStream(requestMessage, stream)
err = p2p.WriteStream(requestMessage, stream)
if err != nil {
log.Println(err.Error())
continue
}

if !waitAndStoreResponse(stream, os.Args[3]) {
continue
}
Expand Down
223 changes: 223 additions & 0 deletions scripts/retrieve_state/retrieve_state.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
// Copyright 2024 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package main

import (
"context"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"log"
"math/big"
"os"

"github.com/ChainSafe/gossamer/dot/network/messages"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/pkg/scale"
"github.com/ChainSafe/gossamer/pkg/trie"
"github.com/ChainSafe/gossamer/pkg/trie/inmemory"
"github.com/ChainSafe/gossamer/scripts/p2p"
lip2pnetwork "github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
)

var (
errZeroLengthResponse = errors.New("zero length response")
errEmptyStateEntries = errors.New("empty state entries")
)

type StateRequestProvider struct {
lastKeys [][]byte
collectedResponses []*messages.StateResponse
targetHash common.Hash
completed bool
}

func NewStateRequestProvider(target common.Hash) *StateRequestProvider {
return &StateRequestProvider{
lastKeys: [][]byte{},
targetHash: target,
collectedResponses: make([]*messages.StateResponse, 0),
}
}

func (s *StateRequestProvider) buildRequest() *messages.StateRequest {
return &messages.StateRequest{
Block: s.targetHash,
Start: s.lastKeys,
NoProof: true,
}
}

func (s *StateRequestProvider) processResponse(stateResponse *messages.StateResponse) (err error) {
if len(stateResponse.Entries) == 0 {
return errEmptyStateEntries
}

log.Printf("retrieved %d entries\n", len(stateResponse.Entries))
for idx, entry := range stateResponse.Entries {
log.Printf("\t#%d with %d entries (complete: %v, root: %s)\n",
idx, len(entry.StateEntries), entry.Complete, entry.StateRoot.String())
}

s.collectedResponses = append(s.collectedResponses, stateResponse)

if len(s.lastKeys) == 2 && len(stateResponse.Entries[0].StateEntries) == 0 {
// pop last item and keep the first
// do not remove the parent trie position.
s.lastKeys = s.lastKeys[:len(s.lastKeys)-1]
} else {
s.lastKeys = [][]byte{}
}

for _, state := range stateResponse.Entries {
if !state.Complete {
lastItemInResponse := state.StateEntries[len(state.StateEntries)-1]
s.lastKeys = append(s.lastKeys, lastItemInResponse.Key)
s.completed = false
} else {
s.completed = true
}
}

return nil
}

func (s *StateRequestProvider) buildTrie(expectedStorageRootHash common.Hash, destination string) error {
tt := inmemory.NewEmptyTrie()
tt.SetVersion(trie.V1)

entries := make([]string, 0)

for _, stateResponse := range s.collectedResponses {
for _, stateEntry := range stateResponse.Entries {
for _, kv := range stateEntry.StateEntries {

trieEntry := trie.Entry{Key: kv.Key, Value: kv.Value}
encodedTrieEntry, err := scale.Marshal(trieEntry)
if err != nil {
return err
}
entries = append(entries, common.BytesToHex(encodedTrieEntry))

if err := tt.Put(kv.Key, kv.Value); err != nil {
return err
}
}
}
}

rootHash := tt.MustHash()
if expectedStorageRootHash != rootHash {
log.Printf("\n\texpected root hash: %s\ngot root hash: %s\n",
expectedStorageRootHash.String(), rootHash.String())
}

fmt.Printf("=> trie root hash: %s\n", tt.MustHash().String())
encodedEntries, err := json.Marshal(entries)
if err != nil {
return err
}

err = os.WriteFile(destination, encodedEntries, 0o600)
return err
}

func main() {
if len(os.Args) != 5 {
log.Fatalf(`
script usage:
go run retrieve_state.go [hash] [expected storage root hash] [network chain spec] [output file]`)
}

targetBlockHash := common.MustHexToHash(os.Args[1])
expectedStorageRootHash := common.MustHexToHash(os.Args[2])
chain := p2p.ParseChainSpec(os.Args[3])

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

protocolID := protocol.ID(fmt.Sprintf("/%s/state/2", chain.ProtocolID))

p2pHost := p2p.SetupP2PClient()
bootnodes := p2p.ParseBootnodes(chain.Bootnodes)
provider := NewStateRequestProvider(targetBlockHash)

var (
pid peer.AddrInfo
refreshPeerID bool = true
)

for !provider.completed {
if refreshPeerID {
rng, err := rand.Int(rand.Reader, big.NewInt(int64(len(bootnodes))))
if err != nil {
panic(err)
}

pid = bootnodes[rng.Uint64()]
err = p2pHost.Connect(ctx, pid)
if err != nil {
log.Printf("WARN: while connecting: %s\n", err.Error())
continue
}

log.Printf("OK: requesting from peer %s\n", pid.String())
}

stream, err := p2pHost.NewStream(ctx, pid.ID, protocolID)
if err != nil {
log.Printf("WARN: failed to create stream using protocol %s: %s", protocolID, err.Error())
refreshPeerID = false
continue
}

err = sendAndProcessResponse(provider, stream)
if err != nil {
log.Printf("WARN: %s\n", err.Error())
refreshPeerID = true
continue
}

// keep using the same peer
refreshPeerID = false
}

if err := provider.buildTrie(expectedStorageRootHash, os.Args[4]); err != nil {
panic(err)
}
}

func sendAndProcessResponse(provider *StateRequestProvider, stream lip2pnetwork.Stream) error {
defer stream.Close() //nolint:errcheck

err := p2p.WriteStream(provider.buildRequest(), stream)
if err != nil {
return err
}

output, err := p2p.ReadStream(stream)
if err != nil {
return err
}

if len(output) == 0 {
return errZeroLengthResponse
}

stateResponse := &messages.StateResponse{}
err = stateResponse.Decode(output)
if err != nil {
return err
}

err = provider.processResponse(stateResponse)
if err != nil {
return err
}

return nil
}

0 comments on commit 26b7961

Please sign in to comment.