Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Mar 15, 2024
1 parent 990ff81 commit 21aa3ed
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 195 deletions.
182 changes: 112 additions & 70 deletions pons/_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import AsyncIterator, Iterable, Sequence
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from enum import Enum
from typing import Any, ParamSpec, TypeVar, cast

Expand Down Expand Up @@ -44,7 +44,7 @@
Type2Transaction,
)
from ._provider import InvalidResponse, Provider, ProviderSession
from ._serialization import StructuringError, structure, unstructure
from ._serialization import JSON, StructuringError, structure, unstructure
from ._signer import Signer


Expand Down Expand Up @@ -101,9 +101,8 @@ class ProviderError(RemoteError):

@classmethod
def from_rpc_error(cls, exc: RPCError) -> "ProviderError":
data = structure(None | bytes, exc.data)
parsed_code = RPCErrorCode.from_int(exc.code)
return cls(exc.code, parsed_code, exc.message, data)
return cls(exc.code, parsed_code, exc.message, exc.data)

def __init__(self, raw_code: int, code: RPCErrorCode, message: str, data: None | bytes = None):
super().__init__(raw_code, code, message, data)
Expand All @@ -124,41 +123,49 @@ def __str__(self) -> str:
RetType = TypeVar("RetType")


async def rpc_call(provider_session, method_name: str, ret_type, *args):
"""Catches various response formatting errors and returns them in a unified way."""
@contextmanager
def convert_errors(method_name: str) -> Iterator[None]:
try:
result = await provider_session.rpc(method_name, *(unstructure(arg) for arg in args))
return structure(ret_type, result)
yield
except (StructuringError, InvalidResponse) as exc:
raise BadResponseFormat(f"{method_name}: {exc}") from exc
except RPCError as exc:
raise ProviderError.from_rpc_error(exc) from exc


async def rpc_call_pin(provider_session, method_name: str, ret_type, *args):
async def rpc_call(
provider_session: ProviderSession, method_name: str, ret_type: type[RetType], *args: Any
) -> RetType:
"""Catches various response formatting errors and returns them in a unified way."""
try:
with convert_errors(method_name):
result = await provider_session.rpc(method_name, *(unstructure(arg) for arg in args))
return structure(ret_type, result)


async def rpc_call_pin(
provider_session: ProviderSession, method_name: str, ret_type: type[RetType], *args: Any
) -> tuple[RetType, tuple[int, ...]]:
"""Catches various response formatting errors and returns them in a unified way."""
with convert_errors(method_name):
result, provider_path = await provider_session.rpc_and_pin(
method_name, *(unstructure(arg) for arg in args)
)
return structure(ret_type, result), provider_path
except StructuringError as exc:
raise BadResponseFormat(f"{method_name}: {exc}") from exc
except RPCError as exc:
raise ProviderError.from_rpc_error(exc) from exc


async def rpc_call_at_pin(provider_session, provider_path, method_name: str, ret_type, *args):
async def rpc_call_at_pin(
provider_session: ProviderSession,
provider_path: tuple[int, ...],
method_name: str,
ret_type: type[RetType],
*args: Any,
) -> RetType:
"""Catches various response formatting errors and returns them in a unified way."""
try:
with convert_errors(method_name):
result = await provider_session.rpc_at_pin(
provider_path, method_name, *(unstructure(arg) for arg in args)
)
return structure(ret_type, result)
except StructuringError as exc:
raise BadResponseFormat(f"{method_name}: {exc}") from exc
except RPCError as exc:
raise ProviderError.from_rpc_error(exc) from exc


class ContractPanicReason(Enum):
Expand Down Expand Up @@ -301,14 +308,30 @@ async def eth_get_balance(self, address: Address, block: int | Block = Block.LAT

async def eth_get_transaction_by_hash(self, tx_hash: TxHash) -> None | TxInfo:
"""Calls the ``eth_getTransactionByHash`` RPC method."""
return await rpc_call(
self._provider_session, "eth_getTransactionByHash", None | TxInfo, tx_hash
# Need an explicit cast, mypy doesn't work with union types correctly.
# See https://github.com/python/mypy/issues/16935
return cast(
None | TxInfo,
await rpc_call(
self._provider_session,
"eth_getTransactionByHash",
None | TxInfo, # type: ignore[arg-type]
tx_hash,
),
)

async def eth_get_transaction_receipt(self, tx_hash: TxHash) -> None | TxReceipt:
"""Calls the ``eth_getTransactionReceipt`` RPC method."""
return await rpc_call(
self._provider_session, "eth_getTransactionReceipt", None | TxReceipt, tx_hash
# Need an explicit cast, mypy doesn't work with union types correctly.
# See https://github.com/python/mypy/issues/16935
return cast(
None | TxReceipt,
await rpc_call(
self._provider_session,
"eth_getTransactionReceipt",
None | TxReceipt, # type: ignore[arg-type]
tx_hash,
),
)

async def eth_get_transaction_count(
Expand Down Expand Up @@ -461,24 +484,34 @@ async def eth_get_block_by_hash(
self, block_hash: BlockHash, *, with_transactions: bool = False
) -> None | BlockInfo:
"""Calls the ``eth_getBlockByHash`` RPC method."""
return await rpc_call(
self._provider_session,
"eth_getBlockByHash",
# Need an explicit cast, mypy doesn't work with union types correctly.
# See https://github.com/python/mypy/issues/16935
return cast(
None | BlockInfo,
block_hash,
with_transactions,
await rpc_call(
self._provider_session,
"eth_getBlockByHash",
None | BlockInfo, # type: ignore[arg-type]
block_hash,
with_transactions,
),
)

async def eth_get_block_by_number(
self, block: int | Block = Block.LATEST, *, with_transactions: bool = False
) -> None | BlockInfo:
"""Calls the ``eth_getBlockByNumber`` RPC method."""
return await rpc_call(
self._provider_session,
"eth_getBlockByNumber",
# Need an explicit cast, mypy doesn't work with union types correctly.
# See https://github.com/python/mypy/issues/16935
return cast(
None | BlockInfo,
block,
with_transactions,
await rpc_call(
self._provider_session,
"eth_getBlockByNumber",
None | BlockInfo, # type: ignore[arg-type]
block,
with_transactions,
),
)

async def broadcast_transfer(
Expand All @@ -500,16 +533,19 @@ async def broadcast_transfer(
max_gas_price = await self.eth_gas_price()
max_tip = min(Amount.gwei(1), max_gas_price)
nonce = await self.eth_get_transaction_count(signer.address, Block.PENDING)
tx = unstructure(
Type2Transaction(
chain_id=chain_id,
to=destination_address,
value=amount,
gas=gas,
max_fee_per_gas=max_gas_price,
max_priority_fee_per_gas=max_tip,
nonce=nonce,
)
tx = cast(
dict[str, JSON],
unstructure(
Type2Transaction(
chain_id=chain_id,
to=destination_address,
value=amount,
gas=gas,
max_fee_per_gas=max_gas_price,
max_priority_fee_per_gas=max_tip,
nonce=nonce,
)
),
)
signed_tx = signer.sign_transaction(tx)
return await self._eth_send_raw_transaction(signed_tx)
Expand Down Expand Up @@ -567,16 +603,19 @@ async def deploy(
max_gas_price = await self.eth_gas_price()
max_tip = min(Amount.gwei(1), max_gas_price)
nonce = await self.eth_get_transaction_count(signer.address, Block.PENDING)
tx = unstructure(
Type2Transaction(
chain_id=chain_id,
value=amount,
gas=gas,
max_fee_per_gas=max_gas_price,
max_priority_fee_per_gas=max_tip,
nonce=nonce,
data=call.data_bytes,
)
tx = cast(
dict[str, JSON],
unstructure(
Type2Transaction(
chain_id=chain_id,
value=amount,
gas=gas,
max_fee_per_gas=max_gas_price,
max_priority_fee_per_gas=max_tip,
nonce=nonce,
data=call.data_bytes,
)
),
)
signed_tx = signer.sign_transaction(tx)
tx_hash = await self._eth_send_raw_transaction(signed_tx)
Expand Down Expand Up @@ -622,17 +661,20 @@ async def broadcast_transact(
max_gas_price = await self.eth_gas_price()
max_tip = min(Amount.gwei(1), max_gas_price)
nonce = await self.eth_get_transaction_count(signer.address, Block.PENDING)
tx = unstructure(
Type2Transaction(
chain_id=chain_id,
to=call.contract_address,
value=amount,
gas=gas,
max_fee_per_gas=max_gas_price,
max_priority_fee_per_gas=max_tip,
nonce=nonce,
data=call.data_bytes,
)
tx = cast(
dict[str, JSON],
unstructure(
Type2Transaction(
chain_id=chain_id,
to=call.contract_address,
value=amount,
gas=gas,
max_fee_per_gas=max_gas_price,
max_priority_fee_per_gas=max_tip,
nonce=nonce,
data=call.data_bytes,
)
),
)
signed_tx = signer.sign_transaction(tx)
return await self._eth_send_raw_transaction(signed_tx)
Expand Down Expand Up @@ -681,10 +723,6 @@ async def transact(
)
event_results = []
for log_entry in log_entries:
# We can't ensure it statically, since `eth_getFilterChanges` return type depends
# on the filter passed to it.
log_entry = cast(LogEntry, log_entry)

if log_entry.transaction_hash != receipt.transaction_hash:
continue

Expand All @@ -703,6 +741,8 @@ async def eth_get_logs(
to_block: int | Block = Block.LATEST,
) -> tuple[LogEntry, ...]:
"""Calls the ``eth_getLogs`` RPC method."""
if isinstance(source, Iterable):
source = tuple(source)
params = FilterParams(
from_block=from_block,
to_block=to_block,
Expand Down Expand Up @@ -733,6 +773,8 @@ async def eth_new_filter(
to_block: int | Block = Block.LATEST,
) -> LogFilter:
"""Calls the ``eth_newFilter`` RPC method."""
if isinstance(source, Iterable):
source = tuple(source)
params = FilterParams(
from_block=from_block,
to_block=to_block,
Expand Down
12 changes: 8 additions & 4 deletions pons/_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def __init__(self, value: int):
def __hash__(self) -> int:
return hash(self._value)

def __int__(self) -> int:
return self._value

def _check_type(self: TypedQuantityLike, other: Any) -> TypedQuantityLike:
if type(self) != type(other):
raise TypeError(f"Incompatible types: {type(self).__name__} and {type(other).__name__}")
Expand Down Expand Up @@ -442,7 +445,8 @@ class TxReceipt:
"""An array of log objects generated by this transaction."""

@property
def succeeded(self):
def succeeded(self) -> bool:
"""``True`` if the transaction succeeded."""
return self.status == 1


Expand Down Expand Up @@ -492,7 +496,7 @@ class RPCError(Exception):

@classmethod
def invalid_request(cls) -> "RPCError":
return cls(RPCErrorCode.INVALID_REQUEST.value, "invalid json request")
return cls(ErrorCode(RPCErrorCode.INVALID_REQUEST.value), "invalid json request")


# EIP-2930 transaction
Expand Down Expand Up @@ -540,5 +544,5 @@ class FilterParams:

from_block: None | int | Block = None
to_block: None | int | Block = None
address: None | Address | list[Address] = None
topics: None | list[None | LogTopic | list[LogTopic]] = None
address: None | Address | tuple[Address, ...] = None
topics: None | tuple[None | LogTopic | tuple[LogTopic, ...], ...] = None
4 changes: 2 additions & 2 deletions pons/_local_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from alysis import RPCError as AlysisRPCError
from eth_account import Account

from ._entities import Amount, RPCError
from ._entities import Amount, ErrorCode, RPCError
from ._provider import JSON, Provider, ProviderSession
from ._signer import AccountSigner, Signer

Expand Down Expand Up @@ -62,7 +62,7 @@ def rpc(self, method: str, *args: Any) -> JSON:
try:
return self._rpc_node.rpc(method, *args)
except AlysisRPCError as exc:
raise RPCError(exc.code, exc.message, exc.data) from exc
raise RPCError(ErrorCode(exc.code), exc.message, exc.data) from exc

@asynccontextmanager
async def session(self) -> AsyncIterator["LocalProviderSession"]:
Expand Down
Loading

0 comments on commit 21aa3ed

Please sign in to comment.