diff --git a/rpc/storage_test.go b/rpc/storage_test.go index 311e97f5b9..0182413719 100644 --- a/rpc/storage_test.go +++ b/rpc/storage_test.go @@ -2,6 +2,7 @@ package rpc_test import ( "context" + "errors" "testing" "time" @@ -162,20 +163,24 @@ func TestStorageProof(t *testing.T) { t.Run("Trie proofs sanity check", func(t *testing.T) { t.Parallel() - kbs := key.Bytes() - kKey := trie.NewKey(251, kbs[:]) - proof, err := trie.GetProof(&kKey, tempTrie) + kKey := tempTrie.FeltToKey(key) + proof := trie.NewProofSet() + err := tempTrie.Prove(key, proof) require.NoError(t, err) root, err := tempTrie.Root() require.NoError(t, err) - require.True(t, trie.VerifyProof(root, &kKey, value, proof, tempTrie.HashFunc())) + leaf, err := trie.VerifyProof(root, &kKey, proof, tempTrie.HashFunc()) + require.NoError(t, err) + require.Equal(t, leaf, value) // non-membership test - kbs = noSuchKey.Bytes() - kKey = trie.NewKey(251, kbs[:]) - proof, err = trie.GetProof(&kKey, tempTrie) + kKey = tempTrie.FeltToKey(noSuchKey) + proof = trie.NewProofSet() + err = tempTrie.Prove(key, proof) + require.NoError(t, err) + leaf, err = trie.VerifyProof(root, &kKey, proof, tempTrie.HashFunc()) require.NoError(t, err) - require.True(t, trie.VerifyProof(root, &kKey, nil, proof, tempTrie.HashFunc())) + require.Equal(t, felt.Zero, *leaf) }) t.Run("global roots are filled", func(t *testing.T) { t.Parallel() @@ -792,22 +797,15 @@ func TestVerifyPathfinderResponse(t *testing.T) { firstContractAddr := utils.HexToFelt(t, "0x5a03b82d726f9bb31ba41ea3a0c1143f90241e37c9a4a92174d168cda9c716d") firstContractLeaf := utils.HexToFelt(t, "0x6b45780618ce075fb4543396b3a6949915c04962b2e411c4f1b2a6813d540da") - verifyIf(t, - root, firstContractAddr, firstContractLeaf, - extractAndReorder(t, root, firstContractLeaf, result.ContractsProof.Nodes), - crypto.Pedersen) + verifyIf(t, root, firstContractAddr, firstContractLeaf, result.ContractsProof.Nodes, crypto.Pedersen) }) t.Run("second contract proof verification", func(t *testing.T) { t.Parallel() - t.Skip("verification issue on the length=1 edge node") secondContractAddr := utils.HexToFelt(t, "0x5fbaa249500be29fee38fdd90a7a2651a8d3935c14167570f6863f563d838f0") secondContractLeaf := utils.HexToFelt(t, "0x25790175fe1fbeed47cbf510a41fba8676bea20a0c8888d4b9090b8f5cf19b8") - verifyIf(t, - root, secondContractAddr, secondContractLeaf, - extractAndReorder(t, root, secondContractLeaf, result.ContractsProof.Nodes), - crypto.Pedersen) + verifyIf(t, root, secondContractAddr, secondContractLeaf, result.ContractsProof.Nodes, crypto.Pedersen) }) } @@ -843,14 +841,21 @@ func verifyIf( ) { t.Helper() - pnodes := []trie.ProofNode{} + proofSet := trie.NewProofSet() for _, hn := range proof { - pnodes = append(pnodes, hn.Node.AsProofNode()) + proofSet.Put(*hn.Hash, hn.Node.AsProofNode()) } kbs := key.Bytes() kkey := trie.NewKey(251, kbs[:]) - require.True(t, trie.VerifyProof(root, &kkey, value, pnodes, hashF)) + leaf, err := trie.VerifyProof(root, &kkey, proofSet, hashF) + require.NoError(t, err) + + // non-membership test + if value == nil { + value = felt.Zero.Clone() + } + require.Equal(t, leaf, value) } func verifyGlobalStateRoot(t *testing.T, globalStateRoot, classRoot, storageRoot *felt.Felt) { @@ -861,47 +866,3 @@ func verifyGlobalStateRoot(t *testing.T, globalStateRoot, classRoot, storageRoot assert.Equal(t, globalStateRoot, crypto.PoseidonArray(stateVersion, storageRoot, classRoot)) } } - -// extractAndReorder extracts single proof path from the root node to the edge node with `Child == leaf` -// nodes may contain many paths for different leaves, so we select one starting from the root and leading to the leaf -func extractAndReorder(t *testing.T, root, leaf *felt.Felt, nodes []*rpc.HashToNode) []*rpc.HashToNode { - t.Helper() - - // parents is reversed child to parent node mapping - parents := make(map[felt.Felt]*rpc.HashToNode) - for _, node := range nodes { - switch it := node.Node.(type) { - case *rpc.MerkleEdgeNode: - parents[*it.Child] = node - case *rpc.MerkleBinaryNode: - parents[*it.Left] = node - parents[*it.Right] = node - } - } - require.Contains(t, parents, *leaf) - - // extract path from leaf to root - path := []*rpc.HashToNode{} - limit := 256 - for next := *leaf; next != *root; { - node := parents[next] - path = append(path, node) - next = *node.Hash - - limit-- - if limit == 0 { - t.Fatal("cycle in the proof path") - } - } - - edge := parents[*leaf].Hash - require.Equal(t, edge, path[0].Hash) - require.Equal(t, root, path[len(path)-1].Hash) - - // reverse the path to start from the root - for i, j := 0, len(path)-1; i < j; i, j = i+1, j-1 { - path[i], path[j] = path[j], path[i] - } - - return path -}