Skip to content

Commit

Permalink
api: add FindStorage RPC call to NeoRpcClient (#288)
Browse files Browse the repository at this point in the history
  • Loading branch information
ixje authored Oct 9, 2023
1 parent b0fd679 commit 8725549
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 1 deletion.
53 changes: 52 additions & 1 deletion neo3/api/noderpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@
from enum import Enum, IntEnum
from contextlib import suppress
from dataclasses import dataclass
from typing import Optional, TypedDict, Any, Protocol, Iterator, Union, cast, Type
from typing import (
Optional,
TypedDict,
Any,
Protocol,
Iterator,
Union,
cast,
Type,
AsyncGenerator,
)
from collections.abc import Sequence
from neo3.core import types, cryptography, interfaces, serialization
from neo3.contracts import manifest, nef, contract, abi
Expand Down Expand Up @@ -887,6 +897,47 @@ async def calculate_network_fee(self, tx: bytes | transaction.Transaction) -> in
result = await self._do_post("calculatenetworkfee", params)
return int(result["networkfee"])

async def find_states(
self, contract_hash: types.UInt160 | str, prefix: Optional[bytes] = None
) -> AsyncGenerator[tuple[bytes, bytes], None]:
"""
Fetch the smart contract storage state.
Args:
contract_hash: the hash of the smart contract to call.
prefix: storage prefix to search for. If omitted will return all storage
Returns:
a storage key/value pair
Examples:
# prints all deployed
prefix_contract_hash = b"\x0c"
async with api.NeoRpcClient("https://testnet1.neo.coz.io:443") as client:
async for k, v in client.find_states(CONTRACT_HASHES.MANAGEMENT, prefix_contract_hash):
print(k, v)
"""
if isinstance(contract_hash, str):
contract_hash = types.UInt160.from_string(contract_hash)
contract_hash = f"0x{contract_hash}"

if prefix is None:
prefix = b""
_prefix = base64.b64encode(prefix).decode()
start = 0
while True:
response = await self._do_post(
"findstorage", [contract_hash, _prefix, start]
)
for pair in response["results"]:
key = base64.b64decode(pair["key"])
value = base64.b64decode(pair["value"])
yield key, value
if not response["truncated"]:
break
start = response["next"]

async def get_application_log_transaction(
self, tx_hash: types.UInt256 | str
) -> TransactionApplicationLogResponse:
Expand Down
49 changes: 49 additions & 0 deletions tests/api/test_noderpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
class TestNeoRpcClient(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self) -> None:
self.client = api.NeoRpcClient("localhost")
# CAREFULL THIS PATCHES ALL aiohttp CALLS!
self.helper = aioresponses()
self.helper.start()

Expand Down Expand Up @@ -43,6 +44,54 @@ async def test_calculate_network_fee(self):
)
self.assertEqual(123, response)

async def test_find_states(self):
key1 = b"\x0c\x00\x00\x00\x01"
key2 = b"\x0c\x00\x00\x00\x02"
key3 = b"\x0c\x00\x00\x00\x03"

value1 = b"\x97\"\x8dq\xd20\xaf\xde\\\xce\x8f\xf9'\x1f*\x9d(\x88u\xf0"
value2 = b"\x92,\x15\xa9\xa0\xe9\x00\x02\xed\xb4o\x1e>\xe4\xb7V\x8c\xb7%F"
value3 = b"\xe0\x98^\x9d\xf0w\xb0\x88v\x1eV\xb3m\x97\xef\x89\x08F\x12\x13"

captured1 = {
"truncated": True,
"next": 2,
"results": [
{
"key": base64.b64encode(key1).decode(),
"value": base64.b64encode(value1).decode(),
},
{
"key": base64.b64encode(key2).decode(),
"value": base64.b64encode(value2).decode(),
},
],
}
captured2 = {
"truncated": False,
"next": 3,
"results": [
{
"key": base64.b64encode(key3).decode(),
"value": base64.b64encode(value3).decode(),
}
],
}
self.mock_response(captured1)
self.mock_response(captured2)
from neo3.contracts.contract import CONTRACT_HASHES

results = []
async for k, v in self.client.find_states(CONTRACT_HASHES.MANAGEMENT, b"\x0c"):
results.append((k, v))
self.assertEqual(3, len(results))
self.assertEqual(key1, results[0][0])
self.assertEqual(value1, results[0][1])
self.assertEqual(key2, results[1][0])
self.assertEqual(value2, results[1][1])
self.assertEqual(key3, results[2][0])
self.assertEqual(value3, results[2][1])

async def test_get_application_log_transaction(self):
captured = {
"txid": "0x7da6ae7ff9d0b7af3d32f3a2feb2aa96c2a27ef8b651f9a132cfaad6ef20724c",
Expand Down

0 comments on commit 8725549

Please sign in to comment.