diff --git a/scripts/retrieve_state/retrieve_state.go b/scripts/retrieve_state/retrieve_state.go index 236d410b35..3ac70afdd3 100644 --- a/scripts/retrieve_state/retrieve_state.go +++ b/scripts/retrieve_state/retrieve_state.go @@ -27,6 +27,11 @@ import ( var ( errZeroLengthResponse = errors.New("zero length response") errEmptyStateEntries = errors.New("empty state entries") + + supportedVersions = map[string]trie.TrieLayout{ + "v0": trie.V0, + "v1": trie.V1, + } ) type StateRequestProvider struct { @@ -86,9 +91,10 @@ func (s *StateRequestProvider) processResponse(stateResponse *messages.StateResp return nil } -func (s *StateRequestProvider) buildTrie(expectedStorageRootHash common.Hash, destination string) error { +func (s *StateRequestProvider) buildTrie(expectedStorageRootHash common.Hash, + destination string, v trie.TrieLayout) error { tt := inmemory.NewEmptyTrie() - tt.SetVersion(trie.V1) + tt.SetVersion(v) entries := make([]string, 0) @@ -127,10 +133,14 @@ func (s *StateRequestProvider) buildTrie(expectedStorageRootHash common.Hash, de } func main() { - if len(os.Args) != 5 { - log.Fatalf(` - script usage: - go run retrieve_state.go [hash] [expected storage root hash] [network chain spec] [output file]`) + if len(os.Args) != 6 { + log.Fatalf(`script usage: + go run retrieve_state.go [block hash] [expected state root hash] [network chain spec] [v0|v1] [output file]`) + } + + version, ok := supportedVersions[os.Args[4]] + if !ok { + log.Fatalf("ERR version not supported: %s", os.Args[4]) } targetBlockHash := common.MustHexToHash(os.Args[1]) @@ -186,7 +196,7 @@ func main() { refreshPeerID = false } - if err := provider.buildTrie(expectedStorageRootHash, os.Args[4]); err != nil { + if err := provider.buildTrie(expectedStorageRootHash, os.Args[5], version); err != nil { panic(err) } }