Skip to content

Commit

Permalink
Add ClientSession.eth_get_logs() and eth_get_filter_logs()
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Feb 10, 2024
1 parent 95cf03e commit b52e8f0
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 28 deletions.
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Added
- ``eth_getCode`` support (as ``ClientSession.eth_get_code()``). (PR_64_)
- ``eth_getStorageAt`` support (as ``ClientSession.eth_get_storage_at()``). (PR_64_)
- Support for the ``logs`` field in ``TxReceipt``. (PR_68_)
- ``ClientSession.eth_get_logs()`` and ``eth_get_filter_logs()``. (PR_68_)


Fixed
Expand Down
101 changes: 73 additions & 28 deletions pons/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,49 @@ async def transact(

return results

def _encode_filter_params(
self,
source: Optional[Union[Address, Iterable[Address]]],
event_filter: Optional[EventFilter],
from_block: Union[int, Block],
to_block: Union[int, Block],
) -> JSON:
params: Dict[str, Any] = {
"fromBlock": rpc_encode_block(from_block),
"toBlock": rpc_encode_block(to_block),
}
if isinstance(source, Address):
params["address"] = source.rpc_encode()
elif source:
params["address"] = [address.rpc_encode() for address in source]
if event_filter:
encoded_topics: List[Optional[List[str]]] = []
for topic in event_filter.topics:
if topic is None:
encoded_topics.append(None)
else:
encoded_topics.append([elem.rpc_encode() for elem in topic])
params["topics"] = encoded_topics
return params

@rpc_call("eth_getLogs")
async def eth_get_logs(
self,
source: Optional[Union[Address, Iterable[Address]]] = None,
event_filter: Optional[EventFilter] = None,
from_block: Union[int, Block] = Block.LATEST,
to_block: Union[int, Block] = Block.LATEST,
) -> Tuple[LogEntry, ...]:
"""Calls the ``eth_getLogs`` RPC method."""
params = self._encode_filter_params(
source=source, event_filter=event_filter, from_block=from_block, to_block=to_block
)
result = await self._provider_session.rpc("eth_getLogs", params)
# TODO: this will go away with generalized RPC decoding.
if not isinstance(result, list):
raise InvalidResponse(f"Expected a list as a response, got {type(result).__name__}")
return tuple(LogEntry.rpc_decode(ResponseDict(elem)) for elem in result)

@rpc_call("eth_newBlockFilter")
async def eth_new_block_filter(self) -> BlockFilter:
"""Calls the ``eth_newBlockFilter`` RPC method."""
Expand All @@ -764,27 +807,38 @@ async def eth_new_filter(
to_block: Union[int, Block] = Block.LATEST,
) -> LogFilter:
"""Calls the ``eth_newFilter`` RPC method."""
params: Dict[str, Any] = {
"fromBlock": rpc_encode_block(from_block),
"toBlock": rpc_encode_block(to_block),
}
if isinstance(source, Address):
params["address"] = source.rpc_encode()
elif source:
params["address"] = [address.rpc_encode() for address in source]
if event_filter:
encoded_topics: List[Optional[List[str]]] = []
for topic in event_filter.topics:
if topic is None:
encoded_topics.append(None)
else:
encoded_topics.append([elem.rpc_encode() for elem in topic])
params["topics"] = encoded_topics

params = self._encode_filter_params(
source=source, event_filter=event_filter, from_block=from_block, to_block=to_block
)
result, provider_path = await self._provider_session.rpc_and_pin("eth_newFilter", params)
filter_id = LogFilterId.rpc_decode(result)
return LogFilter(id_=filter_id, provider_path=provider_path)

def _parse_filter_result(
self,
filter_: Union[BlockFilter, PendingTransactionFilter, LogFilter],
result: JSON,
) -> Union[Tuple[BlockHash, ...], Tuple[TxHash, ...], Tuple[LogEntry, ...]]:
# TODO: this will go away with generalized RPC decoding.
if not isinstance(result, list):
raise InvalidResponse(f"Expected a list as a response, got {type(result).__name__}")

if isinstance(filter_, BlockFilter):
return tuple(BlockHash.rpc_decode(elem) for elem in result)
if isinstance(filter_, PendingTransactionFilter):
return tuple(TxHash.rpc_decode(elem) for elem in result)
return tuple(LogEntry.rpc_decode(ResponseDict(elem)) for elem in result)

@rpc_call("eth_getFilterLogs")
async def eth_get_filter_logs(
self, filter_: Union[BlockFilter, PendingTransactionFilter, LogFilter]
) -> Union[Tuple[BlockHash, ...], Tuple[TxHash, ...], Tuple[LogEntry, ...]]:
"""Calls the ``eth_getFilterLogs`` RPC method."""
result = await self._provider_session.rpc_at_pin(
filter_.provider_path, "eth_getFilterLogs", filter_.id_.rpc_encode()
)
return self._parse_filter_result(filter_, result)

@rpc_call("eth_getFilterChanges")
async def eth_get_filter_changes(
self, filter_: Union[BlockFilter, PendingTransactionFilter, LogFilter]
Expand All @@ -794,19 +848,10 @@ async def eth_get_filter_changes(
Depending on what ``filter_`` was, returns a tuple of corresponding results.
"""
# TODO: split into separate functions with specific return types?
results = await self._provider_session.rpc_at_pin(
result = await self._provider_session.rpc_at_pin(
filter_.provider_path, "eth_getFilterChanges", filter_.id_.rpc_encode()
)

# TODO: this will go away with generalized RPC decoding.
if not isinstance(results, list):
raise InvalidResponse(f"Expected a list as a response, got {type(results).__name__}")

if isinstance(filter_, BlockFilter):
return tuple(BlockHash.rpc_decode(elem) for elem in results)
if isinstance(filter_, PendingTransactionFilter):
return tuple(TxHash.rpc_decode(elem) for elem in results)
return tuple(LogEntry.rpc_decode(ResponseDict(elem)) for elem in results)
return self._parse_filter_result(filter_, result)

async def iter_blocks(self, poll_interval: int = 1) -> AsyncIterator[BlockHash]:
"""Yields hashes of new blocks being mined."""
Expand Down
25 changes: 25 additions & 0 deletions pons/_local_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ def rpc(self, method: str, *args: Any) -> JSON:
eth_newPendingTransactionFilter=self.eth_new_pending_transaction_filter,
eth_newFilter=self.eth_new_filter,
eth_getFilterChanges=self.eth_get_filter_changes,
eth_getLogs=self.eth_get_logs,
eth_getFilterLogs=self.eth_get_filter_logs,
)
if method not in dispatch:
raise RPCError.method_not_found(method)
Expand Down Expand Up @@ -346,6 +348,29 @@ def eth_get_filter_changes(self, filter_id: str) -> JSON:
result["removed"] = False
return cast(JSON, results)

def eth_get_logs(self, params: Mapping[str, Any]) -> JSON:
address = params.get("address", None)
topics = params.get("topics", None)
results = self._ethereum_tester.get_logs(
from_block=rpc_decode_block(params["fromBlock"]),
to_block=rpc_decode_block(params["toBlock"]),
address=address,
topics=topics,
)
results = normalize_return_value(results)
for result in results:
# returned by regular RPC providers, but not by EthereumTester
result["removed"] = False
return cast(JSON, results)

def eth_get_filter_logs(self, filter_id: str) -> JSON:
results = self._ethereum_tester.get_all_filter_logs(rpc_decode_quantity(filter_id))
results = normalize_return_value(results)
for result in results:
# returned by regular RPC providers, but not by EthereumTester
result["removed"] = False
return cast(JSON, results)

@asynccontextmanager
async def session(self) -> AsyncIterator["LocalProviderSession"]:
yield LocalProviderSession(self)
Expand Down
55 changes: 55 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,61 @@ async def test_pending_transaction_filter(local_provider, session, root_signer,
assert tx_hashes == (tx_hash,)


async def test_eth_get_logs(
monkeypatch, local_provider, session, compiled_contracts, root_signer, another_signer
):
basic_contract = compiled_contracts["BasicContract"]
await session.transfer(root_signer, another_signer.address, Amount.ether(1))
contract1 = await session.deploy(root_signer, basic_contract.constructor(123))
contract2 = await session.deploy(another_signer, basic_contract.constructor(123))
await session.transact(root_signer, contract1.method.deposit(b"1234"))
await session.transact(another_signer, contract2.method.deposit2(b"4567"))

entries = await session.eth_get_logs(source=contract2.address)
assert len(entries) == 1
assert entries[0].address == contract2.address
assert (
normalize_topics(entries[0].topics)
== contract2.abi.event.Deposit2(another_signer.address, b"4567").topics
)

# Test an invalid response

monkeypatch.setattr(local_provider, "eth_get_logs", lambda _filter_id: {"foo": 1})

block_filter = await session.eth_new_block_filter()

with pytest.raises(
BadResponseFormat, match=r"eth_getLogs: Expected a list as a response, got dict"
):
await session.eth_get_logs(source=contract2.address)


async def test_eth_get_filter_logs(session, compiled_contracts, root_signer, another_signer):
basic_contract = compiled_contracts["BasicContract"]
await session.transfer(root_signer, another_signer.address, Amount.ether(1))
contract1 = await session.deploy(root_signer, basic_contract.constructor(123))
contract2 = await session.deploy(another_signer, basic_contract.constructor(123))

log_filter = await session.eth_new_filter()
await session.transact(root_signer, contract1.method.deposit(b"1234"))
await session.transact(another_signer, contract2.method.deposit2(b"4567"))

entries = await session.eth_get_filter_logs(log_filter)
assert len(entries) == 2
assert entries[0].address == contract1.address
assert entries[1].address == contract2.address

assert (
normalize_topics(entries[0].topics)
== contract1.abi.event.Deposit(root_signer.address, b"1234").topics
)
assert (
normalize_topics(entries[1].topics)
== contract2.abi.event.Deposit2(another_signer.address, b"4567").topics
)


async def test_log_filter_all(session, compiled_contracts, root_signer, another_signer):
basic_contract = compiled_contracts["BasicContract"]
await session.transfer(root_signer, another_signer.address, Amount.ether(1))
Expand Down

0 comments on commit b52e8f0

Please sign in to comment.