From 4fdf8e9228defd54503b010b1b7dccd48276b490 Mon Sep 17 00:00:00 2001 From: antazoey Date: Tue, 29 Oct 2024 16:06:17 -0500 Subject: [PATCH] perf: Enable flake8 type checks (#2352) --- .pre-commit-config.yaml | 2 +- setup.cfg | 5 +- setup.py | 1 + src/ape/_cli.py | 8 ++- src/ape/api/accounts.py | 5 +- src/ape/api/address.py | 4 +- src/ape/api/compiler.py | 36 +++++----- src/ape/api/explorers.py | 16 +++-- src/ape/api/networks.py | 35 +++++----- src/ape/api/providers.py | 87 +++++++++++++------------ src/ape/api/trace.py | 8 ++- src/ape/api/transactions.py | 25 +++---- src/ape/cli/commands.py | 13 ++-- src/ape/contracts/base.py | 44 +++++++------ src/ape/logging.py | 4 +- src/ape/managers/accounts.py | 6 +- src/ape/managers/chain.py | 29 +++++---- src/ape/managers/compilers.py | 17 ++--- src/ape/managers/config.py | 10 +-- src/ape/managers/converters.py | 8 ++- src/ape/managers/networks.py | 22 ++++--- src/ape/managers/query.py | 2 +- src/ape/pytest/config.py | 9 +-- src/ape/pytest/coverage.py | 39 ++++++----- src/ape/pytest/fixtures.py | 38 ++++++----- src/ape/pytest/gas.py | 26 ++++---- src/ape/pytest/plugin.py | 8 ++- src/ape/pytest/runners.py | 28 ++++---- src/ape/types/address.py | 9 ++- src/ape/types/coverage.py | 10 ++- src/ape/types/signatures.py | 22 ++++--- src/ape/types/trace.py | 9 +-- src/ape/types/units.py | 8 ++- src/ape/utils/misc.py | 4 +- src/ape/utils/os.py | 7 +- src/ape/utils/rpc.py | 2 +- src/ape_accounts/accounts.py | 21 ++++-- src/ape_cache/query.py | 6 +- src/ape_ethereum/_print.py | 15 +++-- src/ape_ethereum/ecosystem.py | 30 +++++---- src/ape_ethereum/multicall/handlers.py | 25 +++---- src/ape_ethereum/provider.py | 50 +++++++------- src/ape_ethereum/trace.py | 34 ++++++---- src/ape_ethereum/transactions.py | 11 ++-- src/ape_node/provider.py | 18 +++-- src/ape_pm/compiler.py | 8 ++- src/ape_pm/project.py | 5 +- src/ape_test/accounts.py | 18 +++-- src/ape_test/provider.py | 28 ++++---- tests/functional/conftest.py | 9 ++- tests/functional/test_config.py | 9 ++- tests/functional/test_contract_event.py | 8 ++- tests/functional/test_explorer.py | 15 +++-- tests/functional/test_receipt.py | 8 ++- 54 files changed, 519 insertions(+), 405 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8c85fc0f1b..a0a339b0ec 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: rev: 7.1.1 hooks: - id: flake8 - additional_dependencies: [flake8-breakpoint, flake8-print, flake8-pydantic] + additional_dependencies: [flake8-breakpoint, flake8-print, flake8-pydantic, flake8-type-checking] - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.13.0 diff --git a/setup.cfg b/setup.cfg index 3272b9ceeb..0b71eb61a4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,10 +7,13 @@ exclude = build .eggs tests/integration/cli/projects -ignore = E704,W503,PYD002 +ignore = E704,W503,PYD002,TC003,TC006 per-file-ignores = # Need signal handler before imports src/ape/__init__.py: E402 # Test data causes long lines tests/functional/data/python/__init__.py: E501 tests/functional/utils/expected_traces.py: E501 + +type-checking-pydantic-enabled = True +type-checking-sqlalchemy-enabled = True diff --git a/setup.py b/setup.py index 6d516c59c8..932989e652 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ "flake8-breakpoint>=1.1.0,<2", # Detect breakpoints left in code "flake8-print>=4.0.1,<5", # Detect print statements left in code "flake8-pydantic", # For detecting issues with Pydantic models + "flake8-type-checking", # Detect imports to move in/out of type-checking blocks "isort>=5.13.2,<6", # Import sorting linter "mdformat>=0.7.18", # Auto-formatter for markdown "mdformat-gfm>=0.3.5", # Needed for formatting GitHub-flavored markdown diff --git a/src/ape/_cli.py b/src/ape/_cli.py index caf223ce0a..8afe38318b 100644 --- a/src/ape/_cli.py +++ b/src/ape/_cli.py @@ -7,18 +7,20 @@ from importlib import import_module from importlib.metadata import entry_points from pathlib import Path -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from warnings import catch_warnings, simplefilter import click import rich import yaml -from click import Context from ape.cli.options import ape_cli_context from ape.exceptions import Abort, ApeException, ConfigError, handle_ape_exception from ape.logging import logger +if TYPE_CHECKING: + from click import Context + _DIFFLIB_CUT_OFF = 0.6 @@ -53,7 +55,7 @@ def _validate_config(): class ApeCLI(click.MultiCommand): _CLI_GROUP_NAME = "ape_cli_subcommands" - def parse_args(self, ctx: Context, args: list[str]) -> list[str]: + def parse_args(self, ctx: "Context", args: list[str]) -> list[str]: # Validate the config before any argument parsing, # as arguments may utilize config. if "--help" not in args and args != []: diff --git a/src/ape/api/accounts.py b/src/ape/api/accounts.py index 9320bfa5f1..c8c357ffb6 100644 --- a/src/ape/api/accounts.py +++ b/src/ape/api/accounts.py @@ -10,7 +10,6 @@ from eip712.messages import SignableMessage as EIP712SignableMessage from eth_account import Account from eth_account.messages import encode_defunct -from eth_pydantic_types import HexBytes from eth_utils import to_hex from ethpm_types import ContractType @@ -31,6 +30,8 @@ from ape.utils.misc import raises_not_implemented if TYPE_CHECKING: + from eth_pydantic_types import HexBytes + from ape.contracts import ContractContainer, ContractInstance @@ -65,7 +66,7 @@ def alias(self) -> Optional[str]: """ return None - def sign_raw_msghash(self, msghash: HexBytes) -> Optional[MessageSignature]: + def sign_raw_msghash(self, msghash: "HexBytes") -> Optional[MessageSignature]: """ Sign a raw message hash. diff --git a/src/ape/api/address.py b/src/ape/api/address.py index cd62661c41..7ab52a81c3 100644 --- a/src/ape/api/address.py +++ b/src/ape/api/address.py @@ -7,13 +7,13 @@ from ape.exceptions import ConversionError from ape.types.address import AddressType from ape.types.units import CurrencyValue -from ape.types.vm import ContractCode from ape.utils.basemodel import BaseInterface from ape.utils.misc import log_instead_of_fail if TYPE_CHECKING: from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.managers.chain import AccountHistory + from ape.types.vm import ContractCode class BaseAddress(BaseInterface): @@ -146,7 +146,7 @@ def __setattr__(self, attr: str, value: Any) -> None: super().__setattr__(attr, value) @property - def code(self) -> ContractCode: + def code(self) -> "ContractCode": """ The raw bytes of the smart-contract code at the address. """ diff --git a/src/ape/api/compiler.py b/src/ape/api/compiler.py index e870edc745..eec9c47265 100644 --- a/src/ape/api/compiler.py +++ b/src/ape/api/compiler.py @@ -4,21 +4,21 @@ from pathlib import Path from typing import TYPE_CHECKING, Optional -from eth_pydantic_types import HexBytes -from ethpm_types import ContractType -from ethpm_types.source import Content, ContractSource -from packaging.version import Version - -from ape.api.config import PluginConfig -from ape.api.trace import TraceAPI from ape.exceptions import APINotImplementedError, ContractLogicError -from ape.types.coverage import ContractSourceCoverage -from ape.types.trace import SourceTraceback from ape.utils.basemodel import BaseInterfaceModel from ape.utils.misc import log_instead_of_fail, raises_not_implemented if TYPE_CHECKING: + from eth_pydantic_types import HexBytes + from ethpm_types import ContractType + from ethpm_types.source import Content, ContractSource + from packaging.version import Version + + from ape.api.config import PluginConfig + from ape.api.trace import TraceAPI from ape.managers.project import ProjectManager + from ape.types.coverage import ContractSourceCoverage + from ape.types.trace import SourceTraceback class CompilerAPI(BaseInterfaceModel): @@ -44,7 +44,7 @@ def name(self) -> str: The name of the compiler. """ - def get_config(self, project: Optional["ProjectManager"] = None) -> PluginConfig: + def get_config(self, project: Optional["ProjectManager"] = None) -> "PluginConfig": """ The combination of settings from ``ape-config.yaml`` and ``.compiler_settings``. @@ -79,7 +79,7 @@ def get_compiler_settings( # type: ignore[empty-body] contract_filepaths: Iterable[Path], project: Optional["ProjectManager"] = None, **overrides, - ) -> dict[Version, dict]: + ) -> dict["Version", dict]: """ Get a mapping of the settings that would be used to compile each of the sources by the compiler version number. @@ -101,7 +101,7 @@ def compile( contract_filepaths: Iterable[Path], project: Optional["ProjectManager"], settings: Optional[dict] = None, - ) -> Iterator[ContractType]: + ) -> Iterator["ContractType"]: """ Compile the given source files. All compiler plugins must implement this function. @@ -123,7 +123,7 @@ def compile_code( # type: ignore[empty-body] project: Optional["ProjectManager"], settings: Optional[dict] = None, **kwargs, - ) -> ContractType: + ) -> "ContractType": """ Compile a program. @@ -162,7 +162,7 @@ def get_version_map( # type: ignore[empty-body] self, contract_filepaths: Iterable[Path], project: Optional["ProjectManager"] = None, - ) -> dict[Version, set[Path]]: + ) -> dict["Version", set[Path]]: """ Get a map of versions to source paths. @@ -218,8 +218,8 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError: @raises_not_implemented def trace_source( # type: ignore[empty-body] - self, contract_source: ContractSource, trace: TraceAPI, calldata: HexBytes - ) -> SourceTraceback: + self, contract_source: "ContractSource", trace: "TraceAPI", calldata: "HexBytes" + ) -> "SourceTraceback": """ Get a source-traceback for the given contract type. The source traceback object contains all the control paths taken in the transaction. @@ -239,7 +239,7 @@ def trace_source( # type: ignore[empty-body] @raises_not_implemented def flatten_contract( # type: ignore[empty-body] self, path: Path, project: Optional["ProjectManager"] = None, **kwargs - ) -> Content: + ) -> "Content": """ Get the content of a flattened contract via its source path. Plugin implementations handle import resolution, SPDX de-duplication, @@ -259,7 +259,7 @@ def flatten_contract( # type: ignore[empty-body] @raises_not_implemented def init_coverage_profile( - self, source_coverage: ContractSourceCoverage, contract_source: ContractSource + self, source_coverage: "ContractSourceCoverage", contract_source: "ContractSource" ): # type: ignore[empty-body] """ Initialize an empty report for the given source ID. Modifies the given source diff --git a/src/ape/api/explorers.py b/src/ape/api/explorers.py index 2344c06d05..e1901d4a0f 100644 --- a/src/ape/api/explorers.py +++ b/src/ape/api/explorers.py @@ -1,12 +1,14 @@ from abc import abstractmethod -from typing import Optional - -from ethpm_types import ContractType +from typing import TYPE_CHECKING, Optional from ape.api.networks import NetworkAPI -from ape.types.address import AddressType from ape.utils.basemodel import BaseInterfaceModel +if TYPE_CHECKING: + from ethpm_types import ContractType + + from ape.types.address import AddressType + class ExplorerAPI(BaseInterfaceModel): """ @@ -18,7 +20,7 @@ class ExplorerAPI(BaseInterfaceModel): network: NetworkAPI @abstractmethod - def get_address_url(self, address: AddressType) -> str: + def get_address_url(self, address: "AddressType") -> str: """ Get an address URL, such as for a transaction. @@ -42,7 +44,7 @@ def get_transaction_url(self, transaction_hash: str) -> str: """ @abstractmethod - def get_contract_type(self, address: AddressType) -> Optional[ContractType]: + def get_contract_type(self, address: "AddressType") -> Optional["ContractType"]: """ Get the contract type for a given address if it has been published to this explorer. @@ -54,7 +56,7 @@ def get_contract_type(self, address: AddressType) -> Optional[ContractType]: """ @abstractmethod - def publish_contract(self, address: AddressType): + def publish_contract(self, address: "AddressType"): """ Publish a contract to the explorer. diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index 92145dbbf6..5f965428f1 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -12,8 +12,6 @@ ) from eth_pydantic_types import HexBytes from eth_utils import keccak, to_int -from ethpm_types import ContractType -from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI from pydantic import model_validator from ape.exceptions import ( @@ -26,8 +24,7 @@ SignatureError, ) from ape.logging import logger -from ape.types.address import AddressType, RawAddress -from ape.types.events import ContractLog +from ape.types.address import AddressType from ape.types.gas import AutoGasLimit, GasLimit from ape.utils.basemodel import ( BaseInterfaceModel, @@ -47,6 +44,12 @@ from .config import PluginConfig if TYPE_CHECKING: + from ethpm_types import ContractType + from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI + + from ape.types.address import RawAddress + from ape.types.events import ContractLog + from .explorers import ExplorerAPI from .providers import BlockAPI, ProviderAPI, UpstreamProvider from .trace import TraceAPI @@ -135,7 +138,7 @@ def custom_network(self) -> "NetworkAPI": @classmethod @abstractmethod - def decode_address(cls, raw_address: RawAddress) -> AddressType: + def decode_address(cls, raw_address: "RawAddress") -> AddressType: """ Convert a raw address to the ecosystem's native address type. @@ -149,7 +152,7 @@ def decode_address(cls, raw_address: RawAddress) -> AddressType: @classmethod @abstractmethod - def encode_address(cls, address: AddressType) -> RawAddress: + def encode_address(cls, address: AddressType) -> "RawAddress": """ Convert the ecosystem's native address type to a raw integer or str address. @@ -162,7 +165,7 @@ def encode_address(cls, address: AddressType) -> RawAddress: @raises_not_implemented def encode_contract_blueprint( # type: ignore[empty-body] - self, contract_type: ContractType, *args, **kwargs + self, contract_type: "ContractType", *args, **kwargs ) -> "TransactionAPI": """ Encode a unique type of transaction that allows contracts to be created @@ -386,7 +389,7 @@ def set_default_network(self, network_name: str): @abstractmethod def encode_deployment( - self, deployment_bytecode: HexBytes, abi: ConstructorABI, *args, **kwargs + self, deployment_bytecode: HexBytes, abi: "ConstructorABI", *args, **kwargs ) -> "TransactionAPI": """ Create a deployment transaction in the given ecosystem. @@ -404,7 +407,7 @@ def encode_deployment( @abstractmethod def encode_transaction( - self, address: AddressType, abi: MethodABI, *args, **kwargs + self, address: AddressType, abi: "MethodABI", *args, **kwargs ) -> "TransactionAPI": """ Encode a transaction object from a contract function's ABI and call arguments. @@ -421,12 +424,12 @@ def encode_transaction( """ @abstractmethod - def decode_logs(self, logs: Sequence[dict], *events: EventABI) -> Iterator[ContractLog]: + def decode_logs(self, logs: Sequence[dict], *events: "EventABI") -> Iterator["ContractLog"]: """ Decode any contract logs that match the given event ABI from the raw log data. Args: - logs (Sequence[Dict]): A list of raw log data from the chain. + logs (Sequence[dict]): A list of raw log data from the chain. *events (EventABI): Event definitions to decode. Returns: @@ -464,7 +467,7 @@ def create_transaction(self, **kwargs) -> "TransactionAPI": """ @abstractmethod - def decode_calldata(self, abi: Union[ConstructorABI, MethodABI], calldata: bytes) -> dict: + def decode_calldata(self, abi: Union["ConstructorABI", "MethodABI"], calldata: bytes) -> dict: """ Decode method calldata. @@ -479,7 +482,7 @@ def decode_calldata(self, abi: Union[ConstructorABI, MethodABI], calldata: bytes """ @abstractmethod - def encode_calldata(self, abi: Union[ConstructorABI, MethodABI], *args: Any) -> HexBytes: + def encode_calldata(self, abi: Union["ConstructorABI", "MethodABI"], *args: Any) -> HexBytes: """ Encode method calldata. @@ -492,7 +495,7 @@ def encode_calldata(self, abi: Union[ConstructorABI, MethodABI], *args: Any) -> """ @abstractmethod - def decode_returndata(self, abi: MethodABI, raw_data: bytes) -> Any: + def decode_returndata(self, abi: "MethodABI", raw_data: bytes) -> Any: """ Get the result of a contract call. @@ -586,7 +589,7 @@ def get_proxy_info(self, address: AddressType) -> Optional[ProxyInfoAPI]: """ return None - def get_method_selector(self, abi: MethodABI) -> HexBytes: + def get_method_selector(self, abi: "MethodABI") -> HexBytes: """ Get a contract method selector, typically via hashing such as ``keccak``. Defaults to using ``keccak`` but can be overridden in different ecosystems. @@ -626,7 +629,7 @@ def enrich_trace(self, trace: "TraceAPI", **kwargs) -> "TraceAPI": @raises_not_implemented def get_python_types( # type: ignore[empty-body] - self, abi_type: ABIType + self, abi_type: "ABIType" ) -> Union[type, Sequence]: """ Get the Python types for a given ABI type. diff --git a/src/ape/api/providers.py b/src/ape/api/providers.py index dd5bc51c15..b0e1695f0b 100644 --- a/src/ape/api/providers.py +++ b/src/ape/api/providers.py @@ -16,14 +16,10 @@ from subprocess import DEVNULL, PIPE, Popen from typing import TYPE_CHECKING, Any, Optional, Union, cast -from eth_pydantic_types import HexBytes -from ethpm_types.abi import EventABI from pydantic import Field, computed_field, field_serializer, model_validator -from ape.api.config import PluginConfig from ape.api.networks import NetworkAPI from ape.api.query import BlockTransactionQuery -from ape.api.trace import TraceAPI from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.exceptions import ( APINotImplementedError, @@ -35,10 +31,7 @@ VirtualMachineError, ) from ape.logging import LogLevel, logger -from ape.types.address import AddressType from ape.types.basic import HexInt -from ape.types.events import ContractLog, LogFilter -from ape.types.vm import BlockID, ContractCode, SnapshotID from ape.utils.basemodel import BaseInterfaceModel from ape.utils.misc import ( EMPTY_BYTES32, @@ -51,7 +44,15 @@ from ape.utils.rpc import RPCHeaders if TYPE_CHECKING: + from eth_pydantic_types import HexBytes + from ethpm_types.abi import EventABI + from ape.api.accounts import TestAccountAPI + from ape.api.config import PluginConfig + from ape.api.trace import TraceAPI + from ape.types.address import AddressType + from ape.types.events import ContractLog, LogFilter + from ape.types.vm import BlockID, ContractCode, SnapshotID class BlockAPI(BaseInterfaceModel): @@ -254,7 +255,7 @@ def ws_uri(self) -> Optional[str]: return None @property - def settings(self) -> PluginConfig: + def settings(self) -> "PluginConfig": """ The combination of settings from ``ape-config.yaml`` and ``.provider_settings``. """ @@ -303,7 +304,7 @@ def chain_id(self) -> int: """ @abstractmethod - def get_balance(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + def get_balance(self, address: "AddressType", block_id: Optional["BlockID"] = None) -> int: """ Get the balance of an account. @@ -317,7 +318,9 @@ def get_balance(self, address: AddressType, block_id: Optional[BlockID] = None) """ @abstractmethod - def get_code(self, address: AddressType, block_id: Optional[BlockID] = None) -> ContractCode: + def get_code( + self, address: "AddressType", block_id: Optional["BlockID"] = None + ) -> "ContractCode": """ Get the bytes a contract. @@ -370,7 +373,7 @@ def stream_request( # type: ignore[empty-body] """ # TODO: In 0.9, delete this method. - def get_storage_at(self, *args, **kwargs) -> HexBytes: + def get_storage_at(self, *args, **kwargs) -> "HexBytes": warnings.warn( "'provider.get_storage_at()' is deprecated. Use 'provider.get_storage()'.", DeprecationWarning, @@ -379,8 +382,8 @@ def get_storage_at(self, *args, **kwargs) -> HexBytes: @raises_not_implemented def get_storage( # type: ignore[empty-body] - self, address: AddressType, slot: int, block_id: Optional[BlockID] = None - ) -> HexBytes: + self, address: "AddressType", slot: int, block_id: Optional["BlockID"] = None + ) -> "HexBytes": """ Gets the raw value of a storage slot of a contract. @@ -395,7 +398,7 @@ def get_storage( # type: ignore[empty-body] """ @abstractmethod - def get_nonce(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + def get_nonce(self, address: "AddressType", block_id: Optional["BlockID"] = None) -> int: """ Get the number of times an account has transacted. @@ -409,7 +412,7 @@ def get_nonce(self, address: AddressType, block_id: Optional[BlockID] = None) -> """ @abstractmethod - def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional[BlockID] = None) -> int: + def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional["BlockID"] = None) -> int: """ Estimate the cost of gas for a transaction. @@ -444,7 +447,7 @@ def max_gas(self) -> int: """ @property - def config(self) -> PluginConfig: + def config(self) -> "PluginConfig": """ The provider's configuration. """ @@ -482,7 +485,7 @@ def base_fee(self) -> int: raise APINotImplementedError("base_fee is not implemented by this provider") @abstractmethod - def get_block(self, block_id: BlockID) -> BlockAPI: + def get_block(self, block_id: "BlockID") -> BlockAPI: """ Get a block. @@ -502,10 +505,10 @@ def get_block(self, block_id: BlockID) -> BlockAPI: def send_call( self, txn: TransactionAPI, - block_id: Optional[BlockID] = None, + block_id: Optional["BlockID"] = None, state: Optional[dict] = None, **kwargs, - ) -> HexBytes: # Return value of function + ) -> "HexBytes": # Return value of function """ Execute a new transaction call immediately without creating a transaction on the block chain. @@ -538,7 +541,7 @@ def get_receipt(self, txn_hash: str, **kwargs) -> ReceiptAPI: """ @abstractmethod - def get_transactions_by_block(self, block_id: BlockID) -> Iterator[TransactionAPI]: + def get_transactions_by_block(self, block_id: "BlockID") -> Iterator[TransactionAPI]: """ Get the information about a set of transactions from a block. @@ -552,7 +555,7 @@ def get_transactions_by_block(self, block_id: BlockID) -> Iterator[TransactionAP @raises_not_implemented def get_transactions_by_account_nonce( # type: ignore[empty-body] self, - account: AddressType, + account: "AddressType", start_nonce: int = 0, stop_nonce: int = -1, ) -> Iterator[ReceiptAPI]: @@ -581,7 +584,7 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: """ @abstractmethod - def get_contract_logs(self, log_filter: LogFilter) -> Iterator[ContractLog]: + def get_contract_logs(self, log_filter: "LogFilter") -> Iterator["ContractLog"]: """ Get logs from contracts. @@ -622,25 +625,25 @@ def send_private_transaction(self, txn: TransactionAPI, **kwargs) -> ReceiptAPI: raise _create_raises_not_implemented_error(self.send_private_transaction) @raises_not_implemented - def snapshot(self) -> SnapshotID: # type: ignore[empty-body] + def snapshot(self) -> "SnapshotID": # type: ignore[empty-body] """ Defined to make the ``ProviderAPI`` interchangeable with a :class:`~ape.api.providers.TestProviderAPI`, as in :class:`ape.managers.chain.ChainManager`. Raises: - :class:`~ape.exceptions.APINotImplementedError`: Unless overriden. + :class:`~ape.exceptions.APINotImplementedError`: Unless overridden. """ @raises_not_implemented - def restore(self, snapshot_id: SnapshotID): + def restore(self, snapshot_id: "SnapshotID"): """ Defined to make the ``ProviderAPI`` interchangeable with a :class:`~ape.api.providers.TestProviderAPI`, as in :class:`ape.managers.chain.ChainManager`. Raises: - :class:`~ape.exceptions.APINotImplementedError`: Unless overriden. + :class:`~ape.exceptions.APINotImplementedError`: Unless overridden. """ @raises_not_implemented @@ -651,7 +654,7 @@ def set_timestamp(self, new_timestamp: int): :class:`ape.managers.chain.ChainManager`. Raises: - :class:`~ape.exceptions.APINotImplementedError`: Unless overriden. + :class:`~ape.exceptions.APINotImplementedError`: Unless overridden. """ @raises_not_implemented @@ -662,11 +665,11 @@ def mine(self, num_blocks: int = 1): :class:`ape.managers.chain.ChainManager`. Raises: - :class:`~ape.exceptions.APINotImplementedError`: Unless overriden. + :class:`~ape.exceptions.APINotImplementedError`: Unless overridden. """ @raises_not_implemented - def set_balance(self, address: AddressType, amount: int): + def set_balance(self, address: "AddressType", amount: int): """ Change the balance of an account. @@ -693,7 +696,7 @@ def __repr__(self) -> str: @raises_not_implemented def set_code( # type: ignore[empty-body] - self, address: AddressType, code: ContractCode + self, address: "AddressType", code: "ContractCode" ) -> bool: """ Change the code of a smart contract, for development purposes. @@ -706,7 +709,7 @@ def set_code( # type: ignore[empty-body] @raises_not_implemented def set_storage( # type: ignore[empty-body] - self, address: AddressType, slot: int, value: HexBytes + self, address: "AddressType", slot: int, value: "HexBytes" ): """ Sets the raw value of a storage slot of a contract. @@ -718,7 +721,7 @@ def set_storage( # type: ignore[empty-body] """ @raises_not_implemented - def unlock_account(self, address: AddressType) -> bool: # type: ignore[empty-body] + def unlock_account(self, address: "AddressType") -> bool: # type: ignore[empty-body] """ Ask the provider to allow an address to submit transactions without validating signatures. This feature is intended to be subclassed by a @@ -736,7 +739,7 @@ def unlock_account(self, address: AddressType) -> bool: # type: ignore[empty-bo """ @raises_not_implemented - def relock_account(self, address: AddressType): + def relock_account(self, address: "AddressType"): """ Stop impersonating an account. @@ -746,13 +749,13 @@ def relock_account(self, address: AddressType): @raises_not_implemented def get_transaction_trace( # type: ignore[empty-body] - self, txn_hash: Union[HexBytes, str] - ) -> TraceAPI: + self, txn_hash: Union["HexBytes", str] + ) -> "TraceAPI": """ Provide a detailed description of opcodes. Args: - transaction_hash (Union[HexBytes, str]): The hash of a transaction + txn_hash (Union[HexBytes, str]): The hash of a transaction to trace. Returns: @@ -794,12 +797,12 @@ def poll_blocks( # type: ignore[empty-body] def poll_logs( # type: ignore[empty-body] self, stop_block: Optional[int] = None, - address: Optional[AddressType] = None, + address: Optional["AddressType"] = None, topics: Optional[list[Union[str, list[str]]]] = None, required_confirmations: Optional[int] = None, new_block_timeout: Optional[int] = None, - events: Optional[list[EventABI]] = None, - ) -> Iterator[ContractLog]: + events: Optional[list["EventABI"]] = None, + ) -> Iterator["ContractLog"]: """ Poll new blocks. Optionally set a start block to include historical blocks. @@ -874,11 +877,11 @@ class TestProviderAPI(ProviderAPI): """ @cached_property - def test_config(self) -> PluginConfig: + def test_config(self) -> "PluginConfig": return self.config_manager.get_config("test") @abstractmethod - def snapshot(self) -> SnapshotID: + def snapshot(self) -> "SnapshotID": """ Record the current state of the blockchain with intent to later call the method :meth:`~ape.managers.chain.ChainManager.revert` @@ -889,7 +892,7 @@ def snapshot(self) -> SnapshotID: """ @abstractmethod - def restore(self, snapshot_id: SnapshotID): + def restore(self, snapshot_id: "SnapshotID"): """ Regress the current call using the given snapshot ID. Allows developers to go back to a previous state. diff --git a/src/ape/api/trace.py b/src/ape/api/trace.py index f4398fb58b..009f65e3a8 100644 --- a/src/ape/api/trace.py +++ b/src/ape/api/trace.py @@ -1,11 +1,13 @@ import sys from abc import abstractmethod from collections.abc import Iterator, Sequence -from typing import IO, Any, Optional +from typing import IO, TYPE_CHECKING, Any, Optional -from ape.types.trace import ContractFunctionPath, GasReport from ape.utils.basemodel import BaseInterfaceModel +if TYPE_CHECKING: + from ape.types.trace import ContractFunctionPath, GasReport + class TraceAPI(BaseInterfaceModel): """ @@ -22,7 +24,7 @@ def show(self, verbose: bool = False, file: IO[str] = sys.stdout): @abstractmethod def get_gas_report( self, exclude: Optional[Sequence["ContractFunctionPath"]] = None - ) -> GasReport: + ) -> "GasReport": """ Get the gas report. """ diff --git a/src/ape/api/transactions.py b/src/ape/api/transactions.py index 8a68da6163..930ac26b10 100644 --- a/src/ape/api/transactions.py +++ b/src/ape/api/transactions.py @@ -2,18 +2,16 @@ import time from abc import abstractmethod from collections.abc import Iterator -from datetime import datetime +from datetime import datetime as datetime_type from functools import cached_property from typing import IO, TYPE_CHECKING, Any, NoReturn, Optional, Union from eth_pydantic_types import HexBytes, HexStr from eth_utils import is_hex, to_hex, to_int -from ethpm_types.abi import EventABI, MethodABI from pydantic import ConfigDict, field_validator from pydantic.fields import Field from tqdm import tqdm # type: ignore -from ape.api.explorers import ExplorerAPI from ape.exceptions import ( NetworkError, ProviderNotConnectedError, @@ -24,17 +22,20 @@ from ape.logging import logger from ape.types.address import AddressType from ape.types.basic import HexInt -from ape.types.events import ContractLogContainer from ape.types.gas import AutoGasLimit from ape.types.signatures import TransactionSignature -from ape.types.trace import SourceTraceback from ape.utils.basemodel import BaseInterfaceModel, ExtraAttributesMixin, ExtraModelAttributes from ape.utils.misc import log_instead_of_fail, raises_not_implemented if TYPE_CHECKING: + from ethpm_types.abi import EventABI, MethodABI + + from ape.api.explorers import ExplorerAPI from ape.api.providers import BlockAPI from ape.api.trace import TraceAPI from ape.contracts import ContractEvent + from ape.types.events import ContractLogContainer + from ape.types.trace import SourceTraceback class TransactionAPI(BaseInterfaceModel): @@ -352,7 +353,7 @@ def trace(self) -> "TraceAPI": return self.provider.get_transaction_trace(self.txn_hash) @property - def _explorer(self) -> Optional[ExplorerAPI]: + def _explorer(self) -> Optional["ExplorerAPI"]: return self.provider.network.explorer @property @@ -377,11 +378,11 @@ def timestamp(self) -> int: return self.block.timestamp @property - def datetime(self) -> datetime: + def datetime(self) -> "datetime_type": return self.block.datetime @cached_property - def events(self) -> ContractLogContainer: + def events(self) -> "ContractLogContainer": """ All the events that were emitted from this call. """ @@ -392,9 +393,9 @@ def events(self) -> ContractLogContainer: def decode_logs( self, abi: Optional[ - Union[list[Union[EventABI, "ContractEvent"]], Union[EventABI, "ContractEvent"]] + Union[list[Union["EventABI", "ContractEvent"]], Union["EventABI", "ContractEvent"]] ] = None, - ) -> ContractLogContainer: + ) -> "ContractLogContainer": """ Decode the logs on the receipt. @@ -482,7 +483,7 @@ def _await_confirmations(self): time.sleep(time_to_sleep) @property - def method_called(self) -> Optional[MethodABI]: + def method_called(self) -> Optional["MethodABI"]: """ The method ABI of the method called to produce this receipt. """ @@ -502,7 +503,7 @@ def return_value(self) -> Any: @property @raises_not_implemented - def source_traceback(self) -> SourceTraceback: # type: ignore[empty-body] + def source_traceback(self) -> "SourceTraceback": # type: ignore[empty-body] """ A Pythonic style traceback for both failing and non-failing receipts. Requires a provider that implements diff --git a/src/ape/cli/commands.py b/src/ape/cli/commands.py index fb4d305363..63a8a2f246 100644 --- a/src/ape/cli/commands.py +++ b/src/ape/cli/commands.py @@ -3,17 +3,18 @@ from typing import TYPE_CHECKING, Any, Optional import click -from click import Context from ape.cli.choices import _NONE_NETWORK, NetworkChoice from ape.exceptions import NetworkError if TYPE_CHECKING: + from click import Context + from ape.api.networks import ProviderContextManager from ape.api.providers import ProviderAPI -def get_param_from_ctx(ctx: Context, param: str) -> Optional[Any]: +def get_param_from_ctx(ctx: "Context", param: str) -> Optional[Any]: if value := ctx.params.get(param): return value @@ -24,7 +25,7 @@ def get_param_from_ctx(ctx: Context, param: str) -> Optional[Any]: return None -def parse_network(ctx: Context) -> Optional["ProviderContextManager"]: +def parse_network(ctx: "Context") -> Optional["ProviderContextManager"]: from ape.utils.basemodel import ManagerAccessMixin as access interactive = get_param_from_ctx(ctx, "interactive") @@ -70,7 +71,7 @@ def __init__(self, *args, **kwargs): self._network_callback = kwargs.pop("network_callback", None) super().__init__(*args, **kwargs) - def parse_args(self, ctx: Context, args: list[str]) -> list[str]: + def parse_args(self, ctx: "Context", args: list[str]) -> list[str]: arguments = args # Renamed for better pdb support. provider_module = import_module("ape.api.providers") base_type = provider_module.ProviderAPI if self._use_cls_types else str @@ -96,7 +97,7 @@ def parse_args(self, ctx: Context, args: list[str]) -> list[str]: return super().parse_args(ctx, arguments) - def invoke(self, ctx: Context) -> Any: + def invoke(self, ctx: "Context") -> Any: if self.callback is None: return @@ -106,7 +107,7 @@ def invoke(self, ctx: Context) -> Any: else: return self._invoke(ctx) - def _invoke(self, ctx: Context, provider: Optional["ProviderAPI"] = None): + def _invoke(self, ctx: "Context", provider: Optional["ProviderAPI"] = None): # Will be put back with correct value if needed. # Else, causes issues. ctx.params.pop("network", None) diff --git a/src/ape/contracts/base.py b/src/ape/contracts/base.py index 1dbe3c0d8f..9c73fb8158 100644 --- a/src/ape/contracts/base.py +++ b/src/ape/contracts/base.py @@ -4,13 +4,13 @@ from functools import cached_property, partial, singledispatchmethod from itertools import islice from pathlib import Path -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import click import pandas as pd from eth_pydantic_types import HexBytes from eth_utils import to_hex -from ethpm_types.abi import ConstructorABI, ErrorABI, EventABI, MethodABI +from ethpm_types.abi import EventABI, MethodABI from ethpm_types.contract_type import ABI_W_SELECTOR_T, ContractType from IPython.lib.pretty import for_type @@ -22,7 +22,6 @@ extract_fields, validate_and_expand_columns, ) -from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.exceptions import ( ApeAttributeError, ArgumentsLengthError, @@ -49,12 +48,17 @@ ) from ape.utils.misc import log_instead_of_fail +if TYPE_CHECKING: + from ethpm_types.abi import ConstructorABI, ErrorABI + + from ape.api.transactions import ReceiptAPI, TransactionAPI + class ContractConstructor(ManagerAccessMixin): def __init__( self, deployment_bytecode: HexBytes, - abi: ConstructorABI, + abi: "ConstructorABI", ) -> None: self.deployment_bytecode = deployment_bytecode self.abi = abi @@ -76,14 +80,14 @@ def decode_input(self, calldata: bytes) -> tuple[str, dict[str, Any]]: decoded_inputs = self.provider.network.ecosystem.decode_calldata(self.abi, calldata) return self.abi.selector, decoded_inputs - def serialize_transaction(self, *args, **kwargs) -> TransactionAPI: + def serialize_transaction(self, *args, **kwargs) -> "TransactionAPI": arguments = self.conversion_manager.convert_method_args(self.abi, args) converted_kwargs = self.conversion_manager.convert_method_kwargs(kwargs) return self.provider.network.ecosystem.encode_deployment( self.deployment_bytecode, self.abi, *arguments, **converted_kwargs ) - def __call__(self, private: bool = False, *args, **kwargs) -> ReceiptAPI: + def __call__(self, private: bool = False, *args, **kwargs) -> "ReceiptAPI": txn = self.serialize_transaction(*args, **kwargs) if "sender" in kwargs and isinstance(kwargs["sender"], AccountAPI): @@ -109,7 +113,7 @@ def __init__(self, abi: MethodABI, address: AddressType) -> None: def __repr__(self) -> str: return self.abi.signature - def serialize_transaction(self, *args, **kwargs) -> TransactionAPI: + def serialize_transaction(self, *args, **kwargs) -> "TransactionAPI": converted_kwargs = self.conversion_manager.convert_method_kwargs(kwargs) return self.provider.network.ecosystem.encode_transaction( self.address, self.abi, *args, **converted_kwargs @@ -343,7 +347,7 @@ def __init__(self, abi: MethodABI, address: AddressType) -> None: def __repr__(self) -> str: return self.abi.signature - def serialize_transaction(self, *args, **kwargs) -> TransactionAPI: + def serialize_transaction(self, *args, **kwargs) -> "TransactionAPI": if "sender" in kwargs and isinstance(kwargs["sender"], (ContractInstance, Address)): # Automatically impersonate contracts (if API available) when sender kwargs["sender"] = self.account_manager.test_accounts[kwargs["sender"].address] @@ -354,7 +358,7 @@ def serialize_transaction(self, *args, **kwargs) -> TransactionAPI: self.address, self.abi, *arguments, **converted_kwargs ) - def __call__(self, *args, **kwargs) -> ReceiptAPI: + def __call__(self, *args, **kwargs) -> "ReceiptAPI": txn = self.serialize_transaction(*args, **kwargs) private = kwargs.get("private", False) @@ -370,7 +374,7 @@ def __call__(self, *args, **kwargs) -> ReceiptAPI: class ContractTransactionHandler(ContractMethodHandler): - def as_transaction(self, *args, **kwargs) -> TransactionAPI: + def as_transaction(self, *args, **kwargs) -> "TransactionAPI": """ Get a :class:`~ape.api.transactions.TransactionAPI` for this contract method invocation. This is useful @@ -421,7 +425,7 @@ def call(self) -> ContractCallHandler: return ContractCallHandler(self.contract, self.abis) - def __call__(self, *args, **kwargs) -> ReceiptAPI: + def __call__(self, *args, **kwargs) -> "ReceiptAPI": contract_transaction = self._as_transaction(*args) if "sender" not in kwargs and self.account_manager.default_sender is not None: kwargs["sender"] = self.account_manager.default_sender @@ -727,7 +731,7 @@ def range( ) yield from self.query_manager.query(contract_event_query) # type: ignore - def from_receipt(self, receipt: ReceiptAPI) -> list[ContractLog]: + def from_receipt(self, receipt: "ReceiptAPI") -> list[ContractLog]: """ Get all the events from the given receipt. @@ -864,7 +868,7 @@ def decode_input(self, calldata: bytes) -> tuple[str, dict[str, Any]]: input_dict = ecosystem.decode_calldata(method, rest_calldata) return method.selector, input_dict - def _create_custom_error_type(self, abi: ErrorABI, **kwargs) -> type[CustomError]: + def _create_custom_error_type(self, abi: "ErrorABI", **kwargs) -> type[CustomError]: def exec_body(namespace): namespace["abi"] = abi namespace["contract"] = self @@ -929,7 +933,7 @@ def __init__( (txn_hash if isinstance(txn_hash, str) else to_hex(txn_hash)) if txn_hash else None ) - def __call__(self, *args, **kwargs) -> ReceiptAPI: + def __call__(self, *args, **kwargs) -> "ReceiptAPI": has_value = kwargs.get("value") has_data = kwargs.get("data") or kwargs.get("input") has_non_payable_fallback = ( @@ -953,7 +957,7 @@ def __call__(self, *args, **kwargs) -> ReceiptAPI: return super().__call__(*args, **kwargs) @classmethod - def from_receipt(cls, receipt: ReceiptAPI, contract_type: ContractType) -> "ContractInstance": + def from_receipt(cls, receipt: "ReceiptAPI", contract_type: ContractType) -> "ContractInstance": """ Create a contract instance from the contract deployment receipt. """ @@ -1074,7 +1078,7 @@ def call_view_method(self, method_name: str, *args, **kwargs) -> Any: name = self.contract_type.name or ContractType.__name__ raise ApeAttributeError(f"'{name}' has no attribute '{method_name}'.") - def invoke_transaction(self, method_name: str, *args, **kwargs) -> ReceiptAPI: + def invoke_transaction(self, method_name: str, *args, **kwargs) -> "ReceiptAPI": """ Call a contract's function directly using the method_name. This function is for non-view function's which may change @@ -1183,7 +1187,7 @@ def _events_(self) -> dict[str, list[ContractEvent]]: @cached_property def _errors_(self) -> dict[str, list[type[CustomError]]]: - abis: dict[str, list[ErrorABI]] = {} + abis: dict[str, list["ErrorABI"]] = {} try: for abi in self.contract_type.errors: @@ -1434,7 +1438,7 @@ def constructor(self) -> ContractConstructor: deployment_bytecode=self.contract_type.get_deployment_bytecode() or HexBytes(""), ) - def __call__(self, *args, **kwargs) -> TransactionAPI: + def __call__(self, *args, **kwargs) -> "TransactionAPI": args_length = len(args) inputs_length = ( len(self.constructor.abi.inputs) @@ -1500,7 +1504,7 @@ def deploy(self, *args, publish: bool = False, **kwargs) -> ContractInstance: instance.base_path = self.base_path or self.local_project.contracts_folder return instance - def _cache_wrap(self, function: Callable) -> ReceiptAPI: + def _cache_wrap(self, function: Callable) -> "ReceiptAPI": """ A helper method to ensure a contract type is cached as early on as possible to help enrich errors from ``deploy()`` transactions @@ -1525,7 +1529,7 @@ def _cache_wrap(self, function: Callable) -> ReceiptAPI: raise # The error after caching. - def declare(self, *args, **kwargs) -> ReceiptAPI: + def declare(self, *args, **kwargs) -> "ReceiptAPI": transaction = self.provider.network.ecosystem.encode_contract_blueprint( self.contract_type, *args, **kwargs ) diff --git a/src/ape/logging.py b/src/ape/logging.py index a874c3353f..34fa511e1e 100644 --- a/src/ape/logging.py +++ b/src/ape/logging.py @@ -287,7 +287,9 @@ def _format_logger( def get_logger( - name: str, fmt: Optional[str] = None, handlers: Optional[Sequence[Callable[[str], str]]] = None + name: str, + fmt: Optional[str] = None, + handlers: Optional[Sequence[Callable[[str], str]]] = None, ) -> logging.Logger: """ Get a logger with the given ``name`` and configure it for usage with Ape. diff --git a/src/ape/managers/accounts.py b/src/ape/managers/accounts.py index 858cf0de8a..e45939204e 100644 --- a/src/ape/managers/accounts.py +++ b/src/ape/managers/accounts.py @@ -25,7 +25,7 @@ @contextlib.contextmanager def _use_sender( account: Union[AccountAPI, TestAccountAPI] -) -> Generator[AccountAPI, TestAccountAPI, None]: +) -> "Generator[AccountAPI, TestAccountAPI, None]": try: _DEFAULT_SENDERS.append(account) yield account @@ -160,7 +160,7 @@ def stop_impersonating(self, address: AddressType): def generate_test_account(self, container_name: str = "test") -> TestAccountAPI: return self.containers[container_name].generate_account() - def use_sender(self, account_id: Union[TestAccountAPI, AddressType, int]) -> ContextManager: + def use_sender(self, account_id: Union[TestAccountAPI, AddressType, int]) -> "ContextManager": account = account_id if isinstance(account_id, TestAccountAPI) else self[account_id] return _use_sender(account) @@ -412,7 +412,7 @@ def __contains__(self, address: AddressType) -> bool: def use_sender( self, account_id: Union[AccountAPI, AddressType, str, int], - ) -> ContextManager: + ) -> "ContextManager": if not isinstance(account_id, AccountAPI): if isinstance(account_id, int) or is_hex(account_id): account = self[account_id] diff --git a/src/ape/managers/chain.py b/src/ape/managers/chain.py index fbfde6f326..4c14bad123 100644 --- a/src/ape/managers/chain.py +++ b/src/ape/managers/chain.py @@ -6,13 +6,11 @@ from functools import partial, singledispatchmethod from pathlib import Path from statistics import mean, median -from typing import IO, Optional, Union, cast +from typing import IO, TYPE_CHECKING, Optional, Union, cast import pandas as pd -from eth_pydantic_types import HexBytes from ethpm_types import ABI, ContractType from rich.box import SIMPLE -from rich.console import Console as RichConsole from rich.table import Table from ape.api.address import BaseAddress @@ -42,11 +40,16 @@ from ape.logging import get_rich_console, logger from ape.managers.base import BaseManager from ape.types.address import AddressType -from ape.types.trace import GasReport, SourceTraceback -from ape.types.vm import SnapshotID from ape.utils.basemodel import BaseInterfaceModel from ape.utils.misc import is_evm_precompile, is_zero_hex, log_instead_of_fail, nonreentrant +if TYPE_CHECKING: + from eth_pydantic_types import HexBytes + from rich.console import Console as RichConsole + + from ape.types.trace import GasReport, SourceTraceback + from ape.types.vm import SnapshotID + class BlockContainer(BaseManager): """ @@ -1131,7 +1134,7 @@ def instance_at( self, address: Union[str, AddressType], contract_type: Optional[ContractType] = None, - txn_hash: Optional[Union[str, HexBytes]] = None, + txn_hash: Optional[Union[str, "HexBytes"]] = None, abi: Optional[Union[list[ABI], dict, str, Path]] = None, ) -> ContractInstance: """ @@ -1413,7 +1416,7 @@ class ReportManager(BaseManager): **NOTE**: This class is not part of the public API. """ - def show_gas(self, report: GasReport, file: Optional[IO[str]] = None): + def show_gas(self, report: "GasReport", file: Optional[IO[str]] = None): tables: list[Table] = [] for contract_id, method_calls in report.items(): @@ -1454,16 +1457,16 @@ def show_gas(self, report: GasReport, file: Optional[IO[str]] = None): self.echo(*tables, file=file) def echo( - self, *rich_items, file: Optional[IO[str]] = None, console: Optional[RichConsole] = None + self, *rich_items, file: Optional[IO[str]] = None, console: Optional["RichConsole"] = None ): console = console or get_rich_console(file) console.print(*rich_items) def show_source_traceback( self, - traceback: SourceTraceback, + traceback: "SourceTraceback", file: Optional[IO[str]] = None, - console: Optional[RichConsole] = None, + console: Optional["RichConsole"] = None, failing: bool = True, ): console = console or get_rich_console(file) @@ -1471,7 +1474,7 @@ def show_source_traceback( console.print(str(traceback), style=style) def show_events( - self, events: list, file: Optional[IO[str]] = None, console: Optional[RichConsole] = None + self, events: list, file: Optional[IO[str]] = None, console: Optional["RichConsole"] = None ): console = console or get_rich_console(file) console.print("Events emitted:") @@ -1587,7 +1590,7 @@ def __repr__(self) -> str: cls_name = getattr(type(self), "__name__", ChainManager.__name__) return f"<{cls_name} ({props})>" - def snapshot(self) -> SnapshotID: + def snapshot(self) -> "SnapshotID": """ Record the current state of the blockchain with intent to later call the method :meth:`~ape.managers.chain.ChainManager.revert` @@ -1607,7 +1610,7 @@ def snapshot(self) -> SnapshotID: return snapshot_id - def restore(self, snapshot_id: Optional[SnapshotID] = None): + def restore(self, snapshot_id: Optional["SnapshotID"] = None): """ Regress the current call using the given snapshot ID. Allows developers to go back to a previous state. diff --git a/src/ape/managers/compilers.py b/src/ape/managers/compilers.py index 09cc989fa4..b37a2cefec 100644 --- a/src/ape/managers/compilers.py +++ b/src/ape/managers/compilers.py @@ -5,10 +5,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union from eth_pydantic_types import HexBytes -from ethpm_types import ContractType -from ethpm_types.source import Content -from ape.api.compiler import CompilerAPI from ape.contracts import ContractContainer from ape.exceptions import CompilerError, ContractLogicError, CustomError from ape.logging import logger @@ -23,6 +20,10 @@ from ape.utils.os import get_full_extension if TYPE_CHECKING: + from ethpm_types.contract_type import ContractType + from ethpm_types.source import Content + + from ape.api.compiler import CompilerAPI from ape.managers.project import ProjectManager @@ -39,7 +40,7 @@ class CompilerManager(BaseManager, ExtraAttributesMixin): from ape import compilers # "compilers" is the CompilerManager singleton """ - _registered_compilers_cache: dict[Path, dict[str, CompilerAPI]] = {} + _registered_compilers_cache: dict[Path, dict[str, "CompilerAPI"]] = {} @log_instead_of_fail(default="") def __repr__(self) -> str: @@ -59,7 +60,7 @@ def __getattr__(self, attr_name: str) -> Any: return get_attribute_with_extras(self, attr_name) @cached_property - def registered_compilers(self) -> dict[str, CompilerAPI]: + def registered_compilers(self) -> dict[str, "CompilerAPI"]: """ Each compile-able file extension mapped to its respective :class:`~ape.api.compiler.CompilerAPI` instance. @@ -80,7 +81,7 @@ def registered_compilers(self) -> dict[str, CompilerAPI]: return registered_compilers - def get_compiler(self, name: str, settings: Optional[dict] = None) -> Optional[CompilerAPI]: + def get_compiler(self, name: str, settings: Optional[dict] = None) -> Optional["CompilerAPI"]: for compiler in self.registered_compilers.values(): if compiler.name != name: continue @@ -98,7 +99,7 @@ def compile( contract_filepaths: Union[Path, str, Iterable[Union[Path, str]]], project: Optional["ProjectManager"] = None, settings: Optional[dict] = None, - ) -> Iterator[ContractType]: + ) -> Iterator["ContractType"]: """ Invoke :meth:`ape.ape.compiler.CompilerAPI.compile` for each of the given files. For example, use the `ape-solidity plugin `__ @@ -333,7 +334,7 @@ def get_custom_error(self, err: ContractLogicError) -> Optional[CustomError]: except NotImplementedError: return None - def flatten_contract(self, path: Path, **kwargs) -> Content: + def flatten_contract(self, path: Path, **kwargs) -> "Content": """ Get the flattened version of a contract via its source path. Delegates to the matching :class:`~ape.api.compilers.CompilerAPI`. diff --git a/src/ape/managers/config.py b/src/ape/managers/config.py index 7739703008..92d94bf53a 100644 --- a/src/ape/managers/config.py +++ b/src/ape/managers/config.py @@ -3,9 +3,7 @@ from contextlib import contextmanager from functools import cached_property from pathlib import Path -from typing import Any, Optional - -from ethpm_types import PackageManifest +from typing import TYPE_CHECKING, Any, Optional from ape.api.config import ApeConfig from ape.managers.base import BaseManager @@ -20,6 +18,10 @@ from ape.utils.os import create_tempdir, in_tempdir from ape.utils.rpc import RPCHeaders +if TYPE_CHECKING: + from ethpm_types import PackageManifest + + CONFIG_FILE_NAME = "ape-config.yaml" @@ -93,7 +95,7 @@ def merge_with_global(self, project_config: ApeConfig) -> ApeConfig: return ApeConfig.model_validate(merged_data) @classmethod - def extract_config(cls, manifest: PackageManifest, **overrides) -> ApeConfig: + def extract_config(cls, manifest: "PackageManifest", **overrides) -> ApeConfig: """ Calculate the ape-config data from a package manifest. diff --git a/src/ape/managers/converters.py b/src/ape/managers/converters.py index ca141b53ff..85c48f6e69 100644 --- a/src/ape/managers/converters.py +++ b/src/ape/managers/converters.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta, timezone from decimal import Decimal from functools import cached_property -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union from dateutil.parser import parse from eth_pydantic_types import Address, HexBytes @@ -16,7 +16,6 @@ to_checksum_address, to_int, ) -from ethpm_types import ConstructorABI, EventABI, MethodABI from ape.api.address import BaseAddress from ape.api.convert import ConverterAPI @@ -28,6 +27,9 @@ from .base import BaseManager +if TYPE_CHECKING: + from ethpm_types import ConstructorABI, EventABI, MethodABI + class HexConverter(ConverterAPI): """ @@ -400,7 +402,7 @@ def convert(self, value: Any, to_type: Union[type, tuple, list]) -> Any: def convert_method_args( self, - abi: Union[MethodABI, ConstructorABI, EventABI], + abi: Union["MethodABI", "ConstructorABI", "EventABI"], arguments: Sequence[Any], ): input_types = [i.canonical_type for i in abi.inputs] diff --git a/src/ape/managers/networks.py b/src/ape/managers/networks.py index b116140690..8297ce43a3 100644 --- a/src/ape/managers/networks.py +++ b/src/ape/managers/networks.py @@ -1,9 +1,8 @@ from collections.abc import Collection, Iterator from functools import cached_property -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from ape.api.networks import EcosystemAPI, NetworkAPI, ProviderContextManager -from ape.api.providers import ProviderAPI from ape.exceptions import EcosystemNotFoundError, NetworkError, NetworkNotFoundError from ape.managers.base import BaseManager from ape.utils.basemodel import ( @@ -13,9 +12,12 @@ only_raise_attribute_error, ) from ape.utils.misc import _dict_overlay, log_instead_of_fail -from ape.utils.rpc import RPCHeaders from ape_ethereum.provider import EthereumNodeProvider +if TYPE_CHECKING: + from ape.api.providers import ProviderAPI + from ape.utils.rpc import RPCHeaders + class NetworkManager(BaseManager, ExtraAttributesMixin): """ @@ -32,7 +34,7 @@ class NetworkManager(BaseManager, ExtraAttributesMixin): ... """ - _active_provider: Optional[ProviderAPI] = None + _active_provider: Optional["ProviderAPI"] = None _default_ecosystem_name: Optional[str] = None # For adhoc adding custom networks, or incorporating some defined @@ -47,7 +49,7 @@ def __repr__(self) -> str: return f"<{content}>" @property - def active_provider(self) -> Optional[ProviderAPI]: + def active_provider(self) -> Optional["ProviderAPI"]: """ The currently connected provider if one exists. Otherwise, returns ``None``. """ @@ -55,7 +57,7 @@ def active_provider(self) -> Optional[ProviderAPI]: return self._active_provider @active_provider.setter - def active_provider(self, new_value: ProviderAPI): + def active_provider(self, new_value: "ProviderAPI"): self._active_provider = new_value @property @@ -88,7 +90,7 @@ def ecosystem(self) -> EcosystemAPI: def get_request_headers( self, ecosystem_name: str, network_name: str, provider_name: str - ) -> RPCHeaders: + ) -> "RPCHeaders": """ All request headers to be used when connecting to this network. """ @@ -249,9 +251,9 @@ def _plugin_ecosystems(self) -> dict[str, EcosystemAPI]: def create_custom_provider( self, connection_str: str, - provider_cls: type[ProviderAPI] = EthereumNodeProvider, + provider_cls: type["ProviderAPI"] = EthereumNodeProvider, provider_name: Optional[str] = None, - ) -> ProviderAPI: + ) -> "ProviderAPI": """ Create a custom connection to a URI using the EthereumNodeProvider provider. **NOTE**: This provider will assume EVM-like behavior and this is generally not recommended. @@ -444,7 +446,7 @@ def get_provider_from_choice( self, network_choice: Optional[str] = None, provider_settings: Optional[dict] = None, - ) -> ProviderAPI: + ) -> "ProviderAPI": """ Get a :class:`~ape.api.providers.ProviderAPI` from a network choice. A network choice is any value returned from diff --git a/src/ape/managers/query.py b/src/ape/managers/query.py index 86cd971b2c..c96cca06a6 100644 --- a/src/ape/managers/query.py +++ b/src/ape/managers/query.py @@ -14,7 +14,7 @@ QueryAPI, QueryType, ) -from ape.api.transactions import ReceiptAPI, TransactionAPI +from ape.api.transactions import ReceiptAPI, TransactionAPI # noqa: TC002 from ape.contracts.base import ContractLog, LogFilter from ape.exceptions import QueryEngineError from ape.logging import logger diff --git a/src/ape/pytest/config.py b/src/ape/pytest/config.py index 99f6db77c7..77825f338f 100644 --- a/src/ape/pytest/config.py +++ b/src/ape/pytest/config.py @@ -1,11 +1,12 @@ from functools import cached_property -from typing import Any, Optional, Union - -from _pytest.config import Config as PytestConfig +from typing import TYPE_CHECKING, Any, Optional, Union from ape.types.trace import ContractFunctionPath from ape.utils.basemodel import ManagerAccessMixin +if TYPE_CHECKING: + from _pytest.config import Config as PytestConfig + def _get_config_exclusions(config) -> list[ContractFunctionPath]: return [ @@ -21,7 +22,7 @@ class ConfigWrapper(ManagerAccessMixin): Pytest config object for ease-of-use and code-sharing. """ - def __init__(self, pytest_config: PytestConfig): + def __init__(self, pytest_config: "PytestConfig"): self.pytest_config = pytest_config @cached_property diff --git a/src/ape/pytest/coverage.py b/src/ape/pytest/coverage.py index 9adee6fa31..784a025f06 100644 --- a/src/ape/pytest/coverage.py +++ b/src/ape/pytest/coverage.py @@ -1,36 +1,39 @@ from collections.abc import Iterable from pathlib import Path -from typing import Callable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union import click -from ethpm_types.abi import MethodABI -from ethpm_types.source import ContractSource from ape.logging import logger -from ape.managers.project import ProjectManager -from ape.pytest.config import ConfigWrapper from ape.types.coverage import CoverageProject, CoverageReport -from ape.types.trace import ContractFunctionPath, ControlFlow, SourceTraceback from ape.utils.basemodel import ManagerAccessMixin from ape.utils.misc import get_current_timestamp_ms from ape.utils.os import get_full_extension, get_relative_path from ape.utils.trace import parse_coverage_tables +if TYPE_CHECKING: + from ethpm_types.abi import MethodABI + from ethpm_types.source import ContractSource + + from ape.managers.project import ProjectManager + from ape.pytest.config import ConfigWrapper + from ape.types.trace import ContractFunctionPath, ControlFlow, SourceTraceback + class CoverageData(ManagerAccessMixin): def __init__( self, - project: ProjectManager, - sources: Union[Iterable[ContractSource], Callable[[], Iterable[ContractSource]]], + project: "ProjectManager", + sources: Union[Iterable["ContractSource"], Callable[[], Iterable["ContractSource"]]], ): self.project = project - self._sources: Union[Iterable[ContractSource], Callable[[], Iterable[ContractSource]]] = ( - sources - ) + self._sources: Union[ + Iterable["ContractSource"], Callable[[], Iterable["ContractSource"]] + ] = sources self._report: Optional[CoverageReport] = None @property - def sources(self) -> list[ContractSource]: + def sources(self) -> list["ContractSource"]: if isinstance(self._sources, list): return self._sources @@ -138,8 +141,8 @@ def cover( class CoverageTracker(ManagerAccessMixin): def __init__( self, - config_wrapper: ConfigWrapper, - project: Optional[ProjectManager] = None, + config_wrapper: "ConfigWrapper", + project: Optional["ProjectManager"] = None, output_path: Optional[Path] = None, ): self.config_wrapper = config_wrapper @@ -173,7 +176,7 @@ def enabled(self) -> bool: return self.config_wrapper.track_coverage @property - def exclusions(self) -> list[ContractFunctionPath]: + def exclusions(self) -> list["ContractFunctionPath"]: return self.config_wrapper.coverage_exclusions def reset(self): @@ -182,7 +185,7 @@ def reset(self): def cover( self, - traceback: SourceTraceback, + traceback: "SourceTraceback", contract: Optional[str] = None, function: Optional[str] = None, ): @@ -259,7 +262,7 @@ def cover( def _cover( self, - control_flow: ControlFlow, + control_flow: "ControlFlow", last_path: Optional[Path] = None, last_pcs: Optional[set[int]] = None, last_call: Optional[str] = None, @@ -281,7 +284,7 @@ def _cover( inc_fn = last_call is None or last_call != control_flow.closure.full_name return self.data.cover(control_flow.source_path, new_pcs, inc_fn_hits=inc_fn) - def hit_function(self, contract_source: ContractSource, method: MethodABI): + def hit_function(self, contract_source: "ContractSource", method: "MethodABI"): """ Another way to increment a function's hit count. Providers may not offer a way to trace calls but this method is available to still increment function diff --git a/src/ape/pytest/fixtures.py b/src/ape/pytest/fixtures.py index 80a77b0279..925a10d903 100644 --- a/src/ape/pytest/fixtures.py +++ b/src/ape/pytest/fixtures.py @@ -1,23 +1,25 @@ from collections.abc import Iterator from fnmatch import fnmatch from functools import cached_property -from typing import Optional +from typing import TYPE_CHECKING, Optional import pytest from eth_utils import to_hex -from ape.api.accounts import TestAccountAPI -from ape.api.transactions import ReceiptAPI from ape.exceptions import BlockNotFoundError, ChainError from ape.logging import logger -from ape.managers.chain import ChainManager -from ape.managers.networks import NetworkManager -from ape.managers.project import ProjectManager -from ape.pytest.config import ConfigWrapper -from ape.types.vm import SnapshotID from ape.utils.basemodel import ManagerAccessMixin from ape.utils.rpc import allow_disconnected +if TYPE_CHECKING: + from ape.api.accounts import TestAccountAPI + from ape.api.transactions import ReceiptAPI + from ape.managers.chain import ChainManager + from ape.managers.networks import NetworkManager + from ape.managers.project import ProjectManager + from ape.pytest.config import ConfigWrapper + from ape.types.vm import SnapshotID + class PytestApeFixtures(ManagerAccessMixin): # NOTE: Avoid including links, markdown, or rst in method-docs @@ -27,7 +29,7 @@ class PytestApeFixtures(ManagerAccessMixin): _supports_snapshot: bool = True receipt_capture: "ReceiptCapture" - def __init__(self, config_wrapper: ConfigWrapper, receipt_capture: "ReceiptCapture"): + def __init__(self, config_wrapper: "ConfigWrapper", receipt_capture: "ReceiptCapture"): self.config_wrapper = config_wrapper self.receipt_capture = receipt_capture @@ -40,7 +42,7 @@ def _track_transactions(self) -> bool: ) @pytest.fixture(scope="session") - def accounts(self) -> list[TestAccountAPI]: + def accounts(self) -> list["TestAccountAPI"]: """ A collection of pre-funded accounts. """ @@ -54,21 +56,21 @@ def compilers(self): return self.compiler_manager @pytest.fixture(scope="session") - def chain(self) -> ChainManager: + def chain(self) -> "ChainManager": """ Manipulate the blockchain, such as mine or change the pending timestamp. """ return self.chain_manager @pytest.fixture(scope="session") - def networks(self) -> NetworkManager: + def networks(self) -> "NetworkManager": """ Connect to other networks in your tests. """ return self.network_manager @pytest.fixture(scope="session") - def project(self) -> ProjectManager: + def project(self) -> "ProjectManager": """ Access contract types and dependencies. """ @@ -121,7 +123,7 @@ def _isolation(self) -> Iterator[None]: _function_isolation = pytest.fixture(_isolation, scope="function") @allow_disconnected - def _snapshot(self) -> Optional[SnapshotID]: + def _snapshot(self) -> Optional["SnapshotID"]: try: return self.chain_manager.snapshot() except NotImplementedError: @@ -135,7 +137,7 @@ def _snapshot(self) -> Optional[SnapshotID]: return None @allow_disconnected - def _restore(self, snapshot_id: SnapshotID): + def _restore(self, snapshot_id: "SnapshotID"): if snapshot_id not in self.chain_manager._snapshots[self.provider.chain_id]: return try: @@ -150,11 +152,11 @@ def _restore(self, snapshot_id: SnapshotID): class ReceiptCapture(ManagerAccessMixin): - config_wrapper: ConfigWrapper - receipt_map: dict[str, dict[str, ReceiptAPI]] = {} + config_wrapper: "ConfigWrapper" + receipt_map: dict[str, dict[str, "ReceiptAPI"]] = {} enter_blocks: list[int] = [] - def __init__(self, config_wrapper: ConfigWrapper): + def __init__(self, config_wrapper: "ConfigWrapper"): self.config_wrapper = config_wrapper def __enter__(self): diff --git a/src/ape/pytest/gas.py b/src/ape/pytest/gas.py index 3b8e0e63ce..1f37af2b68 100644 --- a/src/ape/pytest/gas.py +++ b/src/ape/pytest/gas.py @@ -1,16 +1,20 @@ -from typing import Optional +from typing import TYPE_CHECKING, Optional -from ethpm_types.abi import MethodABI -from ethpm_types.source import ContractSource from evm_trace.gas import merge_reports -from ape.api.trace import TraceAPI -from ape.pytest.config import ConfigWrapper -from ape.types.address import AddressType -from ape.types.trace import ContractFunctionPath, GasReport +from ape.types.trace import GasReport from ape.utils.basemodel import ManagerAccessMixin from ape.utils.trace import _exclude_gas, parse_gas_table +if TYPE_CHECKING: + from ethpm_types.abi import MethodABI + from ethpm_types.source import ContractSource + + from ape.api.trace import TraceAPI + from ape.pytest.config import ConfigWrapper + from ape.types.address import AddressType + from ape.types.trace import ContractFunctionPath + class GasTracker(ManagerAccessMixin): """ @@ -18,7 +22,7 @@ class GasTracker(ManagerAccessMixin): contracts in your test suite. """ - def __init__(self, config_wrapper: ConfigWrapper): + def __init__(self, config_wrapper: "ConfigWrapper"): self.config_wrapper = config_wrapper self.session_gas_report: Optional[GasReport] = None @@ -27,7 +31,7 @@ def enabled(self) -> bool: return self.config_wrapper.track_gas @property - def gas_exclusions(self) -> list[ContractFunctionPath]: + def gas_exclusions(self) -> list["ContractFunctionPath"]: return self.config_wrapper.gas_exclusions def show_session_gas(self) -> bool: @@ -38,7 +42,7 @@ def show_session_gas(self) -> bool: self.chain_manager._reports.echo(*tables) return True - def append_gas(self, trace: TraceAPI, contract_address: AddressType): + def append_gas(self, trace: "TraceAPI", contract_address: "AddressType"): contract_type = self.chain_manager.contracts.get(contract_address) if not contract_type: # Skip unknown contracts. @@ -47,7 +51,7 @@ def append_gas(self, trace: TraceAPI, contract_address: AddressType): report = trace.get_gas_report(exclude=self.gas_exclusions) self._merge(report) - def append_toplevel_gas(self, contract: ContractSource, method: MethodABI, gas_cost: int): + def append_toplevel_gas(self, contract: "ContractSource", method: "MethodABI", gas_cost: int): exclusions = self.gas_exclusions or [] if (contract_id := contract.contract_type.name) and not _exclude_gas( exclusions, contract_id, method.selector diff --git a/src/ape/pytest/plugin.py b/src/ape/pytest/plugin.py index 7cfe0c106f..72d09c1809 100644 --- a/src/ape/pytest/plugin.py +++ b/src/ape/pytest/plugin.py @@ -1,8 +1,7 @@ import sys from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional -from ape.api.networks import EcosystemAPI from ape.exceptions import ConfigError from ape.pytest.config import ConfigWrapper from ape.pytest.coverage import CoverageTracker @@ -11,8 +10,11 @@ from ape.pytest.runners import PytestApeRunner from ape.utils.basemodel import ManagerAccessMixin +if TYPE_CHECKING: + from ape.api.networks import EcosystemAPI -def _get_default_network(ecosystem: Optional[EcosystemAPI] = None) -> str: + +def _get_default_network(ecosystem: Optional["EcosystemAPI"] = None) -> str: if ecosystem is None: ecosystem = ManagerAccessMixin.network_manager.default_ecosystem diff --git a/src/ape/pytest/runners.py b/src/ape/pytest/runners.py index 98a708e99b..e41f724027 100644 --- a/src/ape/pytest/runners.py +++ b/src/ape/pytest/runners.py @@ -1,30 +1,32 @@ from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional import click import pytest from _pytest._code.code import Traceback as PytestTraceback from rich import print as rich_print -from ape.api.networks import ProviderContextManager from ape.exceptions import ConfigError from ape.logging import LogLevel -from ape.pytest.config import ConfigWrapper -from ape.pytest.coverage import CoverageTracker -from ape.pytest.fixtures import ReceiptCapture -from ape.pytest.gas import GasTracker -from ape.types.coverage import CoverageReport from ape.utils.basemodel import ManagerAccessMixin from ape_console._cli import console +if TYPE_CHECKING: + from ape.api.networks import ProviderContextManager + from ape.pytest.config import ConfigWrapper + from ape.pytest.coverage import CoverageTracker + from ape.pytest.fixtures import ReceiptCapture + from ape.pytest.gas import GasTracker + from ape.types.coverage import CoverageReport + class PytestApeRunner(ManagerAccessMixin): def __init__( self, - config_wrapper: ConfigWrapper, - receipt_capture: ReceiptCapture, - gas_tracker: GasTracker, - coverage_tracker: CoverageTracker, + config_wrapper: "ConfigWrapper", + receipt_capture: "ReceiptCapture", + gas_tracker: "GasTracker", + coverage_tracker: "CoverageTracker", ): self.config_wrapper = config_wrapper self.receipt_capture = receipt_capture @@ -36,11 +38,11 @@ def __init__( self.coverage_tracker = coverage_tracker @property - def _provider_context(self) -> ProviderContextManager: + def _provider_context(self) -> "ProviderContextManager": return self.network_manager.parse_network_choice(self.config_wrapper.network) @property - def _coverage_report(self) -> Optional[CoverageReport]: + def _coverage_report(self) -> Optional["CoverageReport"]: return self.coverage_tracker.data.report if self.coverage_tracker.data else None def pytest_exception_interact(self, report, call): diff --git a/src/ape/types/address.py b/src/ape/types/address.py index 1b68301a85..8572825edd 100644 --- a/src/ape/types/address.py +++ b/src/ape/types/address.py @@ -1,12 +1,15 @@ -from typing import Annotated, Any, Optional, Union +from typing import TYPE_CHECKING, Annotated, Any, Optional, Union from eth_pydantic_types import Address as _Address from eth_pydantic_types import HashBytes20, HashStr20 from eth_typing import ChecksumAddress -from pydantic_core.core_schema import ValidationInfo from ape.utils.basemodel import ManagerAccessMixin +if TYPE_CHECKING: + from pydantic_core.core_schema import ValidationInfo + + RawAddress = Union[str, int, HashStr20, HashBytes20] """ A raw data-type representation of an address. @@ -23,7 +26,7 @@ class _AddressValidator(_Address, ManagerAccessMixin): """ @classmethod - def __eth_pydantic_validate__(cls, value: Any, info: Optional[ValidationInfo] = None) -> str: + def __eth_pydantic_validate__(cls, value: Any, info: Optional["ValidationInfo"] = None) -> str: if type(value) in (list, tuple): return cls.conversion_manager.convert(value, list[AddressType]) diff --git a/src/ape/types/coverage.py b/src/ape/types/coverage.py index 760d20935a..1179fe7758 100644 --- a/src/ape/types/coverage.py +++ b/src/ape/types/coverage.py @@ -2,12 +2,12 @@ from datetime import datetime from html.parser import HTMLParser from pathlib import Path -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from xml.dom.minidom import getDOMImplementation from xml.etree.ElementTree import Element, SubElement, tostring import requests -from ethpm_types.source import ContractSource, SourceLocation +from ethpm_types.source import SourceLocation from pydantic import NonNegativeInt, field_validator from ape.logging import logger @@ -15,6 +15,10 @@ from ape.utils.misc import get_current_timestamp_ms from ape.version import version as ape_version +if TYPE_CHECKING: + from ethpm_types.source import ContractSource + + _APE_DOCS_URL = "https://docs.apeworx.io/ape/stable/index.html" _DTD_URL = "https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd" _CSS = """ @@ -545,7 +549,7 @@ def model_dump(self, *args, **kwargs) -> dict: return attribs - def include(self, contract_source: ContractSource) -> ContractSourceCoverage: + def include(self, contract_source: "ContractSource") -> ContractSourceCoverage: for src in self.sources: if src.source_id == contract_source.source_id: return src diff --git a/src/ape/types/signatures.py b/src/ape/types/signatures.py index 60db85fcb1..c3f857d919 100644 --- a/src/ape/types/signatures.py +++ b/src/ape/types/signatures.py @@ -1,5 +1,5 @@ from collections.abc import Iterator -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from eth_account import Account from eth_account.messages import SignableMessage @@ -9,13 +9,15 @@ from ape.utils.misc import as_our_module, log_instead_of_fail -try: - # Only on Python 3.11 - from typing import Self # type: ignore -except ImportError: - from typing_extensions import Self # type: ignore +if TYPE_CHECKING: + from ape.types.address import AddressType + + try: + # Only on Python 3.11 + from typing import Self # type: ignore + except ImportError: + from typing_extensions import Self # type: ignore -from ape.types.address import AddressType # Fix 404 in doc link. as_our_module( @@ -89,7 +91,7 @@ def __iter__(self) -> Iterator[Union[int, bytes]]: yield self.s @classmethod - def from_rsv(cls, rsv: HexBytes) -> Self: + def from_rsv(cls, rsv: HexBytes) -> "Self": # NOTE: Values may be padded. if len(rsv) != 65: raise ValueError("Length of RSV bytes must be 65.") @@ -97,7 +99,7 @@ def from_rsv(cls, rsv: HexBytes) -> Self: return cls(r=HexBytes(rsv[:32]), s=HexBytes(rsv[32:64]), v=rsv[64]) @classmethod - def from_vrs(cls, vrs: HexBytes) -> Self: + def from_vrs(cls, vrs: HexBytes) -> "Self": # NOTE: Values may be padded. if len(vrs) != 65: raise ValueError("Length of VRS bytes must be 65.") @@ -122,7 +124,7 @@ class MessageSignature(_Signature): """ -def recover_signer(msg: SignableMessage, sig: MessageSignature) -> AddressType: +def recover_signer(msg: SignableMessage, sig: MessageSignature) -> "AddressType": """ Get the address of the signer. diff --git a/src/ape/types/trace.py b/src/ape/types/trace.py index dfa65ddeb7..1ac708edf9 100644 --- a/src/ape/types/trace.py +++ b/src/ape/types/trace.py @@ -5,7 +5,6 @@ from eth_pydantic_types import HexBytes from ethpm_types import ASTNode, BaseModel -from ethpm_types.ast import SourceLocation from ethpm_types.source import ( Closure, Content, @@ -20,6 +19,8 @@ from ape.utils.misc import log_instead_of_fail if TYPE_CHECKING: + from ethpm_types.ast import SourceLocation + from ape.api.trace import TraceAPI @@ -162,7 +163,7 @@ def pcs(self) -> set[int]: def extend( self, - location: SourceLocation, + location: "SourceLocation", pcs: Optional[set[int]] = None, ws_start: Optional[int] = None, ): @@ -441,7 +442,7 @@ def format(self) -> str: def add_jump( self, - location: SourceLocation, + location: "SourceLocation", function: Function, depth: int, pcs: Optional[set[int]] = None, @@ -469,7 +470,7 @@ def add_jump( ControlFlow.model_rebuild() self._add(asts, content, pcs, function, depth, source_path=source_path) - def extend_last(self, location: SourceLocation, pcs: Optional[set[int]] = None): + def extend_last(self, location: "SourceLocation", pcs: Optional[set[int]] = None): """ Extend the last node with more content. diff --git a/src/ape/types/units.py b/src/ape/types/units.py index c81d122a4c..22a2c1b1e1 100644 --- a/src/ape/types/units.py +++ b/src/ape/types/units.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from pydantic_core.core_schema import ( CoreSchema, @@ -7,11 +7,13 @@ no_info_plain_validator_function, plain_serializer_function_ser_schema, ) -from typing_extensions import TypeAlias from ape.exceptions import ConversionError from ape.utils.basemodel import ManagerAccessMixin +if TYPE_CHECKING: + from typing_extensions import TypeAlias + class CurrencyValueComparable(int): """ @@ -72,7 +74,7 @@ def _serialize(value): CurrencyValueComparable.__name__ = int.__name__ -CurrencyValue: TypeAlias = CurrencyValueComparable +CurrencyValue: "TypeAlias" = CurrencyValueComparable """ An alias to :class:`~ape.types.CurrencyValueComparable` for situations when you know for sure the type is a currency-value diff --git a/src/ape/utils/misc.py b/src/ape/utils/misc.py index 6fae748b61..b369f7b1c5 100644 --- a/src/ape/utils/misc.py +++ b/src/ape/utils/misc.py @@ -63,7 +63,7 @@ ) -_python_version = ( +_python_version: str = ( f"{sys.version_info.major}.{sys.version_info.minor}" f".{sys.version_info.micro} {sys.version_info.releaselevel}" ) @@ -193,7 +193,7 @@ def get_package_version(obj: Any) -> str: return "" -__version__ = get_package_version(__name__) +__version__: str = get_package_version(__name__) def load_config(path: Path, expand_envars=True, must_exist=False) -> dict: diff --git a/src/ape/utils/os.py b/src/ape/utils/os.py index 822ea2387b..f66d9736c1 100644 --- a/src/ape/utils/os.py +++ b/src/ape/utils/os.py @@ -211,12 +211,7 @@ def create_tempdir(name: Optional[str] = None) -> Iterator[Path]: def run_in_tempdir( - fn: Callable[ - [ - Path, - ], - Any, - ], + fn: Callable[[Path], Any], name: Optional[str] = None, ): """ diff --git a/src/ape/utils/rpc.py b/src/ape/utils/rpc.py index 3cfa7b54e2..f552fb1c86 100644 --- a/src/ape/utils/rpc.py +++ b/src/ape/utils/rpc.py @@ -8,7 +8,7 @@ from ape.logging import logger from ape.utils.misc import __version__, _python_version -USER_AGENT = f"Ape/{__version__} (Python/{_python_version})" +USER_AGENT: str = f"Ape/{__version__} (Python/{_python_version})" def allow_disconnected(fn: Callable): diff --git a/src/ape_accounts/accounts.py b/src/ape_accounts/accounts.py index aea2f66f60..7dd0bae941 100644 --- a/src/ape_accounts/accounts.py +++ b/src/ape_accounts/accounts.py @@ -3,28 +3,31 @@ from collections.abc import Iterator from os import environ from pathlib import Path -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import click from eip712.messages import EIP712Message from eth_account import Account as EthAccount from eth_account.hdaccount import ETHEREUM_DEFAULT_PATH from eth_account.messages import encode_defunct -from eth_account.signers.local import LocalAccount from eth_keys import keys # type: ignore from eth_pydantic_types import HexBytes from eth_utils import to_bytes, to_hex from ape.api.accounts import AccountAPI, AccountContainerAPI -from ape.api.transactions import TransactionAPI from ape.exceptions import AccountsError from ape.logging import logger -from ape.types.address import AddressType from ape.types.signatures import MessageSignature, SignableMessage, TransactionSignature from ape.utils.basemodel import ManagerAccessMixin from ape.utils.misc import log_instead_of_fail from ape.utils.validators import _validate_account_alias, _validate_account_passphrase +if TYPE_CHECKING: + from eth_account.signers.local import LocalAccount + + from ape.api.transactions import TransactionAPI + from ape.types.address import AddressType + class InvalidPasswordError(AccountsError): """ @@ -83,7 +86,7 @@ def keyfile(self) -> dict: return json.loads(self.keyfile_path.read_text()) @property - def address(self) -> AddressType: + def address(self) -> "AddressType": return self.network_manager.ethereum.decode_address(self.keyfile["address"]) @property @@ -220,7 +223,9 @@ def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature] s=to_bytes(signed_msg.s), ) - def sign_transaction(self, txn: TransactionAPI, **signer_options) -> Optional[TransactionAPI]: + def sign_transaction( + self, txn: "TransactionAPI", **signer_options + ) -> Optional["TransactionAPI"]: user_approves = self.__autosign or click.confirm(f"{txn}\n\nSign: ") if not user_approves: return None @@ -292,7 +297,9 @@ def __decrypt_keyfile(self, passphrase: str) -> bytes: raise InvalidPasswordError() from err -def _write_and_return_account(alias: str, passphrase: str, account: LocalAccount) -> KeyfileAccount: +def _write_and_return_account( + alias: str, passphrase: str, account: "LocalAccount" +) -> KeyfileAccount: """Write an account to disk and return an Ape KeyfileAccount""" path = ManagerAccessMixin.account_manager.containers["accounts"].data_folder.joinpath( f"{alias}.json" diff --git a/src/ape_cache/query.py b/src/ape_cache/query.py index 30e32aeab7..deccb1c651 100644 --- a/src/ape_cache/query.py +++ b/src/ape_cache/query.py @@ -4,11 +4,11 @@ from typing import Any, Optional, cast from sqlalchemy import create_engine, func -from sqlalchemy.engine import CursorResult +from sqlalchemy.engine import CursorResult # noqa: TC002 from sqlalchemy.sql import column, insert, select -from sqlalchemy.sql.expression import Insert, Select +from sqlalchemy.sql.expression import Insert, Select # noqa: TC002 -from ape.api.providers import BlockAPI +from ape.api.providers import BlockAPI # noqa: TC002 from ape.api.query import ( BaseInterfaceModel, BlockQuery, diff --git a/src/ape_ethereum/_print.py b/src/ape_ethereum/_print.py index e2fc56bc7a..4102765d63 100644 --- a/src/ape_ethereum/_print.py +++ b/src/ape_ethereum/_print.py @@ -20,26 +20,29 @@ """ from collections.abc import Iterable -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from eth_abi import decode from eth_typing import ChecksumAddress from eth_utils import add_0x_prefix, decode_hex, to_hex from ethpm_types import ContractType, MethodABI -from evm_trace import CallTreeNode from hexbytes import HexBytes -from typing_extensions import TypeGuard import ape from ape_ethereum._console_log_abi import CONSOLE_LOG_ABI +if TYPE_CHECKING: + from evm_trace import CallTreeNode + from typing_extensions import TypeGuard + + CONSOLE_ADDRESS = cast(ChecksumAddress, "0x000000000000000000636F6e736F6c652e6c6f67") VYPER_PRINT_METHOD_ID = HexBytes("0x23cdd8e8") # log(string,bytes) console_contract = ContractType(abi=CONSOLE_LOG_ABI, contractName="console") -def is_console_log(call: CallTreeNode) -> TypeGuard[CallTreeNode]: +def is_console_log(call: "CallTreeNode") -> "TypeGuard[CallTreeNode]": """Determine if a call is a standard console.log() call""" return ( call.address == HexBytes(CONSOLE_ADDRESS) @@ -47,7 +50,7 @@ def is_console_log(call: CallTreeNode) -> TypeGuard[CallTreeNode]: ) -def is_vyper_print(call: CallTreeNode) -> TypeGuard[CallTreeNode]: +def is_vyper_print(call: "CallTreeNode") -> "TypeGuard[CallTreeNode]": """Determine if a call is a standard Vyper print() call""" if call.address != HexBytes(CONSOLE_ADDRESS) or call.calldata[:4] != VYPER_PRINT_METHOD_ID: return False @@ -79,7 +82,7 @@ def vyper_print(calldata: str) -> tuple[Any]: return tuple(data) -def extract_debug_logs(call: CallTreeNode) -> Iterable[tuple[Any]]: +def extract_debug_logs(call: "CallTreeNode") -> Iterable[tuple[Any]]: """Filter calls to console.log() and print() from a transactions call tree""" if is_vyper_print(call) and call.calldata is not None: yield vyper_print(add_0x_prefix(to_hex(call.calldata[4:]))) diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index 82636f65fa..86ea1b3d16 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -2,7 +2,7 @@ from collections.abc import Iterator, Sequence from decimal import Decimal from functools import cached_property -from typing import Any, ClassVar, Optional, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast from eth_abi import decode, encode from eth_abi.exceptions import InsufficientDataBytes, NonEmptyPaddingBytes @@ -20,7 +20,6 @@ to_checksum_address, to_hex, ) -from ethpm_types import ContractType from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI from pydantic import Field, computed_field, field_validator, model_validator from pydantic_settings import SettingsConfigDict @@ -28,8 +27,6 @@ from ape.api.config import PluginConfig from ape.api.networks import EcosystemAPI from ape.api.providers import BlockAPI -from ape.api.trace import TraceAPI -from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.contracts.base import ContractCall from ape.exceptions import ( ApeException, @@ -80,6 +77,13 @@ TransactionType, ) +if TYPE_CHECKING: + from ethpm_types import ContractType + + from ape.api.trace import TraceAPI + from ape.api.transactions import ReceiptAPI, TransactionAPI + + NETWORKS = { # chain_id, network_id "mainnet": (1, 1), @@ -418,7 +422,7 @@ def decode_address(cls, raw_address: RawAddress) -> AddressType: def encode_address(cls, address: AddressType) -> RawAddress: return f"{address}" - def decode_transaction_type(self, transaction_type_id: Any) -> type[TransactionAPI]: + def decode_transaction_type(self, transaction_type_id: Any) -> type["TransactionAPI"]: if isinstance(transaction_type_id, TransactionType): tx_type = transaction_type_id elif isinstance(transaction_type_id, int): @@ -435,8 +439,8 @@ def decode_transaction_type(self, transaction_type_id: Any) -> type[TransactionA return DynamicFeeTransaction def encode_contract_blueprint( - self, contract_type: ContractType, *args, **kwargs - ) -> TransactionAPI: + self, contract_type: "ContractType", *args, **kwargs + ) -> "TransactionAPI": # EIP-5202 implementation. bytes_obj = contract_type.deployment_bytecode contract_bytes = (bytes_obj.to_bytes() or b"") if bytes_obj else b"" @@ -546,7 +550,7 @@ def str_to_slot(text): return None - def decode_receipt(self, data: dict) -> ReceiptAPI: + def decode_receipt(self, data: dict) -> "ReceiptAPI": status = data.get("status") if status is not None: status = self.conversion_manager.convert(status, int) @@ -864,7 +868,7 @@ def encode_transaction( return cast(BaseTransaction, txn) - def create_transaction(self, **kwargs) -> TransactionAPI: + def create_transaction(self, **kwargs) -> "TransactionAPI": """ Returns a transaction using the given constructor kwargs. @@ -902,7 +906,7 @@ def create_transaction(self, **kwargs) -> TransactionAPI: tx_data["data"] = b"" # Deduce the transaction type. - transaction_types: dict[TransactionType, type[TransactionAPI]] = { + transaction_types: dict[TransactionType, type["TransactionAPI"]] = { TransactionType.STATIC: StaticFeeTransaction, TransactionType.ACCESS_LIST: AccessListTransaction, TransactionType.DYNAMIC: DynamicFeeTransaction, @@ -973,7 +977,7 @@ def create_transaction(self, **kwargs) -> TransactionAPI: return txn_class.model_validate(tx_data) - def decode_logs(self, logs: Sequence[dict], *events: EventABI) -> Iterator["ContractLog"]: + def decode_logs(self, logs: Sequence[dict], *events: EventABI) -> Iterator[ContractLog]: if not logs: return @@ -1052,7 +1056,7 @@ def get_abi(_topic: HexStr) -> Optional[LogInputABICollection]: ), ) - def enrich_trace(self, trace: TraceAPI, **kwargs) -> TraceAPI: + def enrich_trace(self, trace: "TraceAPI", **kwargs) -> "TraceAPI": kwargs["trace"] = trace if not isinstance(trace, Trace): # Can only enrich `ape_ethereum.trace.Trace` (or subclass) implementations. @@ -1416,7 +1420,7 @@ def _enrich_revert_message(self, call: dict) -> dict: def _get_contract_type_for_enrichment( self, address: AddressType, **kwargs - ) -> Optional[ContractType]: + ) -> Optional["ContractType"]: if not (contract_type := kwargs.get("contract_type")): try: contract_type = self.chain_manager.contracts.get(address) diff --git a/src/ape_ethereum/multicall/handlers.py b/src/ape_ethereum/multicall/handlers.py index b8722fe7e5..9e48208834 100644 --- a/src/ape_ethereum/multicall/handlers.py +++ b/src/ape_ethereum/multicall/handlers.py @@ -1,12 +1,10 @@ from collections.abc import Iterator from functools import cached_property from types import ModuleType -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union -from eth_pydantic_types import HexBytes from ethpm_types import ContractType -from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.contracts.base import ( ContractCallHandler, ContractInstance, @@ -16,7 +14,6 @@ ) from ape.exceptions import ChainError, DecodingError from ape.logging import logger -from ape.types.address import AddressType from ape.utils.abi import MethodABI from ape.utils.basemodel import ManagerAccessMixin @@ -28,11 +25,17 @@ ) from .exceptions import InvalidOption, UnsupportedChainError, ValueRequired +if TYPE_CHECKING: + from eth_pydantic_types import HexBytes + + from ape.api.transactions import ReceiptAPI, TransactionAPI + from ape.types.address import AddressType + class BaseMulticall(ManagerAccessMixin): def __init__( self, - address: AddressType = MULTICALL3_ADDRESS, + address: "AddressType" = MULTICALL3_ADDRESS, supported_chains: Optional[list[int]] = None, ) -> None: """ @@ -159,13 +162,13 @@ class Call(BaseMulticall): def __init__( self, - address: AddressType = MULTICALL3_ADDRESS, + address: "AddressType" = MULTICALL3_ADDRESS, supported_chains: Optional[list[int]] = None, ) -> None: super().__init__(address=address, supported_chains=supported_chains) self.abis: list[MethodABI] = [] - self._result: Union[None, list[tuple[bool, HexBytes]]] = None + self._result: Union[None, list[tuple[bool, "HexBytes"]]] = None @property def handler(self) -> ContractCallHandler: # type: ignore[override] @@ -180,7 +183,7 @@ def add(self, call: ContractMethodHandler, *args, **kwargs): return self @property - def returnData(self) -> list[HexBytes]: + def returnData(self) -> list["HexBytes"]: # NOTE: this property is kept camelCase to align with the raw EVM struct result = self._result # Declare for typing reasons. return [res.returnData if res.success else None for res in result] # type: ignore @@ -225,7 +228,7 @@ def __call__(self, **call_kwargs) -> Iterator[Any]: self._result = self.handler(self.calls, **call_kwargs) return self._decode_results() - def as_transaction(self, **txn_kwargs) -> TransactionAPI: + def as_transaction(self, **txn_kwargs) -> "TransactionAPI": """ Encode the Multicall transaction as a ``TransactionAPI`` object, but do not execute it. @@ -272,7 +275,7 @@ def _validate_calls(self, **txn_kwargs) -> None: # NOTE: Won't fail if `value` is provided otherwise (won't do anything either) - def __call__(self, **txn_kwargs) -> ReceiptAPI: + def __call__(self, **txn_kwargs) -> "ReceiptAPI": """ Execute the Multicall transaction. The transaction will broadcast again every time the ``Transaction`` object is called. @@ -290,7 +293,7 @@ def __call__(self, **txn_kwargs) -> ReceiptAPI: self._validate_calls(**txn_kwargs) return self.handler(self.calls, **txn_kwargs) - def as_transaction(self, **txn_kwargs) -> TransactionAPI: + def as_transaction(self, **txn_kwargs) -> "TransactionAPI": """ Encode the Multicall transaction as a ``TransactionAPI`` object, but do not execute it. diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index 53b63b2bf2..c1ff49e704 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -9,14 +9,13 @@ from copy import copy from functools import cached_property, wraps from pathlib import Path -from typing import Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast import ijson # type: ignore import requests from eth_pydantic_types import HexBytes from eth_typing import BlockNumber, HexStr from eth_utils import add_0x_prefix, is_hex, to_hex -from ethpm_types import EventABI from evmchains import get_random_rpc from pydantic.dataclasses import dataclass from requests import HTTPError @@ -39,7 +38,6 @@ from ape.api.address import Address from ape.api.providers import BlockAPI, ProviderAPI -from ape.api.trace import TraceAPI from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.exceptions import ( _SOURCE_TRACEBACK_ARG, @@ -58,17 +56,23 @@ VirtualMachineError, ) from ape.logging import logger, sanitize_url -from ape.types.address import AddressType from ape.types.events import ContractLog, LogFilter from ape.types.gas import AutoGasLimit from ape.types.trace import SourceTraceback -from ape.types.vm import BlockID, ContractCode from ape.utils.basemodel import ManagerAccessMixin from ape.utils.misc import DEFAULT_MAX_RETRIES_TX, gas_estimation_error_message, to_int from ape_ethereum._print import CONSOLE_ADDRESS, console_contract from ape_ethereum.trace import CallTrace, TraceApproach, TransactionTrace from ape_ethereum.transactions import AccessList, AccessListTransaction, TransactionStatusEnum +if TYPE_CHECKING: + from ethpm_types import EventABI + + from ape.api.trace import TraceAPI + from ape.types.address import AddressType + from ape.types.vm import BlockID, ContractCode + + DEFAULT_PORT = 8545 DEFAULT_HOSTNAME = "localhost" DEFAULT_SETTINGS = {"uri": f"http://{DEFAULT_HOSTNAME}:{DEFAULT_PORT}"} @@ -322,7 +326,7 @@ def update_settings(self, new_settings: dict): self.provider_settings.update(new_settings) self.connect() - def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional[BlockID] = None) -> int: + def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional["BlockID"] = None) -> int: # NOTE: Using JSON mode since used as request data. txn_dict = txn.model_dump(by_alias=True, mode="json") @@ -410,7 +414,7 @@ def priority_fee(self) -> int: "eth_maxPriorityFeePerGas not supported in this RPC. Please specify manually." ) from err - def get_block(self, block_id: BlockID) -> BlockAPI: + def get_block(self, block_id: "BlockID") -> BlockAPI: if isinstance(block_id, str) and block_id.isnumeric(): block_id = int(block_id) @@ -429,17 +433,19 @@ def _get_latest_block(self) -> BlockAPI: def _get_latest_block_rpc(self) -> dict: return self.make_request("eth_getBlockByNumber", ["latest", False]) - def get_nonce(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + def get_nonce(self, address: "AddressType", block_id: Optional["BlockID"] = None) -> int: return self.web3.eth.get_transaction_count(address, block_identifier=block_id) - def get_balance(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + def get_balance(self, address: "AddressType", block_id: Optional["BlockID"] = None) -> int: return self.web3.eth.get_balance(address, block_identifier=block_id) - def get_code(self, address: AddressType, block_id: Optional[BlockID] = None) -> ContractCode: + def get_code( + self, address: "AddressType", block_id: Optional["BlockID"] = None + ) -> "ContractCode": return self.web3.eth.get_code(address, block_identifier=block_id) def get_storage( - self, address: AddressType, slot: int, block_id: Optional[BlockID] = None + self, address: "AddressType", slot: int, block_id: Optional["BlockID"] = None ) -> HexBytes: try: return HexBytes(self.web3.eth.get_storage_at(address, slot, block_identifier=block_id)) @@ -449,7 +455,7 @@ def get_storage( raise # Raise original error - def get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI: + def get_transaction_trace(self, transaction_hash: str, **kwargs) -> "TraceAPI": if transaction_hash in self._transaction_trace_cache: return self._transaction_trace_cache[transaction_hash] @@ -463,7 +469,7 @@ def get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI: def send_call( self, txn: TransactionAPI, - block_id: Optional[BlockID] = None, + block_id: Optional["BlockID"] = None, state: Optional[dict] = None, **kwargs: Any, ) -> HexBytes: @@ -694,7 +700,7 @@ def _create_receipt(self, **kwargs) -> ReceiptAPI: data = {"provider": self, **kwargs} return self.network.ecosystem.decode_receipt(data) - def get_transactions_by_block(self, block_id: BlockID) -> Iterator[TransactionAPI]: + def get_transactions_by_block(self, block_id: "BlockID") -> Iterator[TransactionAPI]: if isinstance(block_id, str): block_id = HexStr(block_id) @@ -707,7 +713,7 @@ def get_transactions_by_block(self, block_id: BlockID) -> Iterator[TransactionAP def get_transactions_by_account_nonce( self, - account: AddressType, + account: "AddressType", start_nonce: int = 0, stop_nonce: int = -1, ) -> Iterator[ReceiptAPI]: @@ -732,7 +738,7 @@ def get_transactions_by_account_nonce( def _find_txn_by_account_and_nonce( self, - account: AddressType, + account: "AddressType", start_nonce: int, stop_nonce: int, start_block: int, @@ -878,11 +884,11 @@ def assert_chain_activity(): def poll_logs( self, stop_block: Optional[int] = None, - address: Optional[AddressType] = None, + address: Optional["AddressType"] = None, topics: Optional[list[Union[str, list[str]]]] = None, required_confirmations: Optional[int] = None, new_block_timeout: Optional[int] = None, - events: Optional[list[EventABI]] = None, + events: Optional[list["EventABI"]] = None, ) -> Iterator[ContractLog]: events = events or [] if required_confirmations is None: @@ -1169,7 +1175,7 @@ def stream_request(self, method: str, params: Iterable, iter_path: str = "result del results[:] def create_access_list( - self, transaction: TransactionAPI, block_id: Optional[BlockID] = None + self, transaction: TransactionAPI, block_id: Optional["BlockID"] = None ) -> list[AccessList]: """ Get the access list for a transaction use ``eth_createAccessList``. @@ -1248,7 +1254,7 @@ def _handle_execution_reverted( exception: Union[Exception, str], txn: Optional[TransactionAPI] = None, trace: _TRACE_ARG = None, - contract_address: Optional[AddressType] = None, + contract_address: Optional["AddressType"] = None, source_traceback: _SOURCE_TRACEBACK_ARG = None, set_ape_traceback: Optional[bool] = None, ) -> ContractLogicError: @@ -1548,7 +1554,7 @@ def _log_connection(self, client_name: str): ) logger.info(f"{msg} {suffix}.") - def ots_get_contract_creator(self, address: AddressType) -> Optional[dict]: + def ots_get_contract_creator(self, address: "AddressType") -> Optional[dict]: if self._ots_api_level is None: return None @@ -1559,7 +1565,7 @@ def ots_get_contract_creator(self, address: AddressType) -> Optional[dict]: return result - def _get_contract_creation_receipt(self, address: AddressType) -> Optional[ReceiptAPI]: + def _get_contract_creation_receipt(self, address: "AddressType") -> Optional[ReceiptAPI]: if result := self.ots_get_contract_creator(address): tx_hash = result["hash"] return self.get_receipt(tx_hash) diff --git a/src/ape_ethereum/trace.py b/src/ape_ethereum/trace.py index 0812624f96..e67463cb0e 100644 --- a/src/ape_ethereum/trace.py +++ b/src/ape_ethereum/trace.py @@ -5,11 +5,10 @@ from collections.abc import Iterable, Iterator, Sequence from enum import Enum from functools import cached_property -from typing import IO, Any, Optional, Union +from typing import IO, TYPE_CHECKING, Any, Optional, Union from eth_pydantic_types import HexStr from eth_utils import is_0x_prefixed, to_hex -from ethpm_types import ContractType, MethodABI from evm_trace import ( CallTreeNode, CallType, @@ -25,17 +24,22 @@ from pydantic import field_validator from rich.tree import Tree -from ape.api.networks import EcosystemAPI from ape.api.trace import TraceAPI from ape.api.transactions import TransactionAPI from ape.exceptions import ContractLogicError, ProviderError, TransactionNotFoundError from ape.logging import get_rich_console, logger -from ape.types.address import AddressType -from ape.types.trace import ContractFunctionPath, GasReport from ape.utils.misc import ZERO_ADDRESS, is_evm_precompile, is_zero_hex, log_instead_of_fail from ape.utils.trace import TraceStyles, _exclude_gas from ape_ethereum._print import extract_debug_logs +if TYPE_CHECKING: + from ethpm_types import ContractType, MethodABI + + from ape.api.networks import EcosystemAPI + from ape.types.address import AddressType + from ape.types.trace import ContractFunctionPath, GasReport + + _INDENT = 2 _WRAP_THRESHOLD = 50 _REVERT_PREFIX = "0x08c379a00000000000000000000000000000000000000000000000000000000000000020" @@ -174,11 +178,11 @@ def frames(self) -> Iterator[TraceFrame]: yield from create_trace_frames(iter(self.raw_trace_frames)) @property - def addresses(self) -> Iterator[AddressType]: + def addresses(self) -> Iterator["AddressType"]: yield from self.get_addresses_used() @cached_property - def root_contract_type(self) -> Optional[ContractType]: + def root_contract_type(self) -> Optional["ContractType"]: if address := self.transaction.get("to"): try: return self.chain_manager.contracts.get(address) @@ -188,7 +192,7 @@ def root_contract_type(self) -> Optional[ContractType]: return None @cached_property - def root_method_abi(self) -> Optional[MethodABI]: + def root_method_abi(self) -> Optional["MethodABI"]: method_id = self.transaction.get("data", b"")[:10] if ct := self.root_contract_type: try: @@ -199,7 +203,7 @@ def root_method_abi(self) -> Optional[MethodABI]: return None @property - def _ecosystem(self) -> EcosystemAPI: + def _ecosystem(self) -> "EcosystemAPI": if provider := self.network_manager.active_provider: return provider.network.ecosystem @@ -357,13 +361,15 @@ def show(self, verbose: bool = False, file: IO[str] = sys.stdout): console.print(root) - def get_gas_report(self, exclude: Optional[Sequence[ContractFunctionPath]] = None) -> GasReport: + def get_gas_report( + self, exclude: Optional[Sequence["ContractFunctionPath"]] = None + ) -> "GasReport": call = self.enriched_calltree return self._get_gas_report_from_call(call, exclude=exclude) def _get_gas_report_from_call( - self, call: dict, exclude: Optional[Sequence[ContractFunctionPath]] = None - ) -> GasReport: + self, call: dict, exclude: Optional[Sequence["ContractFunctionPath"]] = None + ) -> "GasReport": tx = self.transaction # Enrich transfers. @@ -388,7 +394,7 @@ def _get_gas_report_from_call( return merge_reports(*sub_reports) elif not is_zero_hex(call["method_id"]) and not is_evm_precompile(call["method_id"]): - report: GasReport = { + report: "GasReport" = { call["contract_id"]: { call["method_id"]: ( [int(call["gas_cost"])] if call.get("gas_cost") is not None else [] @@ -434,7 +440,7 @@ def _debug_trace_transaction_struct_logs_to_call(self) -> CallTreeNode: def _get_tree(self, verbose: bool = False) -> Tree: return parse_rich_tree(self.enriched_calltree, verbose=verbose) - def _get_abi(self, call: dict) -> Optional[MethodABI]: + def _get_abi(self, call: dict) -> Optional["MethodABI"]: if not (addr := call.get("address")): return self.root_method_abi if not (calldata := call.get("calldata")): diff --git a/src/ape_ethereum/transactions.py b/src/ape_ethereum/transactions.py index 4385973469..83ebbd57a3 100644 --- a/src/ape_ethereum/transactions.py +++ b/src/ape_ethereum/transactions.py @@ -1,7 +1,7 @@ import sys from enum import Enum, IntEnum from functools import cached_property -from typing import IO, Any, Optional, Union +from typing import IO, TYPE_CHECKING, Any, Optional, Union from eth_abi import decode from eth_account import Account as EthAccount @@ -11,12 +11,10 @@ ) from eth_pydantic_types import HexBytes from eth_utils import decode_hex, encode_hex, keccak, to_hex, to_int -from ethpm_types import ContractType from ethpm_types.abi import EventABI, MethodABI from pydantic import BaseModel, Field, field_validator, model_validator from ape.api.transactions import ReceiptAPI, TransactionAPI -from ape.contracts import ContractEvent from ape.exceptions import OutOfGasError, SignatureError, TransactionError from ape.logging import logger from ape.types.address import AddressType @@ -26,6 +24,11 @@ from ape.utils.misc import ZERO_ADDRESS from ape_ethereum.trace import Trace, _events_to_trees +if TYPE_CHECKING: + from ethpm_types import ContractType + + from ape.contracts import ContractEvent + class TransactionStatusEnum(IntEnum): """ @@ -221,7 +224,7 @@ def debug_logs_typed(self) -> list[tuple[Any]]: return list(trace.debug_logs) @cached_property - def contract_type(self) -> Optional[ContractType]: + def contract_type(self) -> Optional["ContractType"]: if address := (self.receiver or self.contract_address): return self.chain_manager.contracts.get(address) diff --git a/src/ape_node/provider.py b/src/ape_node/provider.py index 6f459324ff..95bd54d2b7 100644 --- a/src/ape_node/provider.py +++ b/src/ape_node/provider.py @@ -2,13 +2,12 @@ import shutil from pathlib import Path from subprocess import DEVNULL, PIPE, Popen -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from eth_utils import add_0x_prefix, to_hex from evmchains import get_random_rpc from geth.chain import initialize_chain from geth.process import BaseGethProcess -from geth.types import GenesisDataTypedDict from geth.wrapper import construct_test_chain_kwargs from pydantic import field_validator from pydantic_settings import SettingsConfigDict @@ -16,11 +15,9 @@ from web3.middleware import geth_poa_middleware as ExtraDataToPOAMiddleware from yarl import URL -from ape.api.accounts import TestAccountAPI from ape.api.config import PluginConfig from ape.api.providers import SubprocessProvider, TestProviderAPI from ape.logging import LogLevel, logger -from ape.types.vm import SnapshotID from ape.utils.misc import ZERO_ADDRESS, log_instead_of_fail, raises_not_implemented from ape.utils.process import JoinableQueue, spawn from ape.utils.testing import ( @@ -39,10 +36,17 @@ ) from ape_ethereum.trace import TraceApproach +if TYPE_CHECKING: + from geth.types import GenesisDataTypedDict + + from ape.api.accounts import TestAccountAPI + from ape.types.vm import SnapshotID + + Alloc = dict[str, dict[str, Any]] -def create_genesis_data(alloc: Alloc, chain_id: int) -> GenesisDataTypedDict: +def create_genesis_data(alloc: Alloc, chain_id: int) -> "GenesisDataTypedDict": """ A wrapper around genesis data for py-geth that fills in more defaults. @@ -398,10 +402,10 @@ def disconnect(self): super().disconnect() - def snapshot(self) -> SnapshotID: + def snapshot(self) -> "SnapshotID": return self._get_latest_block().number or 0 - def restore(self, snapshot_id: SnapshotID): + def restore(self, snapshot_id: "SnapshotID"): if isinstance(snapshot_id, int): block_number_int = snapshot_id block_number_hex_str = str(to_hex(snapshot_id)) diff --git a/src/ape_pm/compiler.py b/src/ape_pm/compiler.py index 297af2ad25..ebd5ea97ee 100644 --- a/src/ape_pm/compiler.py +++ b/src/ape_pm/compiler.py @@ -2,7 +2,7 @@ from collections.abc import Iterable, Iterator from json import JSONDecodeError from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional from eth_pydantic_types import HexBytes from eth_utils import is_0x_prefixed @@ -11,9 +11,11 @@ from ape.api.compiler import CompilerAPI from ape.exceptions import CompilerError, ContractLogicError from ape.logging import logger -from ape.managers.project import ProjectManager from ape.utils.os import get_relative_path +if TYPE_CHECKING: + from ape.managers.project import ProjectManager + class InterfaceCompiler(CompilerAPI): """ @@ -64,7 +66,7 @@ def compile( def compile_code( self, code: str, - project: Optional[ProjectManager] = None, + project: Optional["ProjectManager"] = None, **kwargs, ) -> ContractType: code = code or "[]" diff --git a/src/ape_pm/project.py b/src/ape_pm/project.py index 7e9a53cc4c..0091bbe142 100644 --- a/src/ape_pm/project.py +++ b/src/ape_pm/project.py @@ -1,5 +1,7 @@ import sys from collections.abc import Iterable +from pathlib import Path +from typing import Any, Optional from ape.utils._github import _GithubClient, github_client @@ -10,9 +12,6 @@ else: import toml as tomllib # type: ignore[no-redef] -from pathlib import Path -from typing import Any, Optional - from yaml import safe_load from ape.api.config import ApeConfig diff --git a/src/ape_test/accounts.py b/src/ape_test/accounts.py index 0dbe558250..43c46a29fd 100644 --- a/src/ape_test/accounts.py +++ b/src/ape_test/accounts.py @@ -1,6 +1,6 @@ import warnings from collections.abc import Iterator -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from eip712.messages import EIP712Message from eth_account import Account as EthAccount @@ -11,9 +11,7 @@ from eth_utils import to_bytes, to_hex from ape.api.accounts import TestAccountAPI, TestAccountContainerAPI -from ape.api.transactions import TransactionAPI from ape.exceptions import ProviderNotConnectedError, SignatureError -from ape.types.address import AddressType from ape.types.signatures import MessageSignature, TransactionSignature from ape.utils.testing import ( DEFAULT_NUMBER_OF_TEST_ACCOUNTS, @@ -22,6 +20,10 @@ generate_dev_accounts, ) +if TYPE_CHECKING: + from ape.api.transactions import TransactionAPI + from ape.types.address import AddressType + class TestAccountContainer(TestAccountContainerAPI): generated_accounts: list["TestAccount"] = [] @@ -82,7 +84,9 @@ def generate_account(self, index: Optional[int] = None) -> "TestAccountAPI": return account @classmethod - def init_test_account(cls, index: int, address: AddressType, private_key: str) -> "TestAccount": + def init_test_account( + cls, index: int, address: "AddressType", private_key: str + ) -> "TestAccount": return TestAccount( index=index, address_str=address, @@ -105,7 +109,7 @@ def alias(self) -> str: return f"TEST::{self.index}" @property - def address(self) -> AddressType: + def address(self) -> "AddressType": return self.network_manager.ethereum.decode_address(self.address_str) def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature]: @@ -129,7 +133,9 @@ def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature] ) return None - def sign_transaction(self, txn: TransactionAPI, **signer_options) -> Optional[TransactionAPI]: + def sign_transaction( + self, txn: "TransactionAPI", **signer_options + ) -> Optional["TransactionAPI"]: # Signs any transaction that's given to it. # NOTE: Using JSON mode, as only primitive types can be signed. tx_data = txn.model_dump(mode="json", by_alias=True, exclude={"sender"}) diff --git a/src/ape_test/provider.py b/src/ape_test/provider.py index afd12e227a..f6c63e8060 100644 --- a/src/ape_test/provider.py +++ b/src/ape_test/provider.py @@ -18,8 +18,6 @@ from web3.types import TxParams from ape.api.providers import BlockAPI, TestProviderAPI -from ape.api.trace import TraceAPI -from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.exceptions import ( APINotImplementedError, ContractLogicError, @@ -31,8 +29,6 @@ ) from ape.logging import logger from ape.types.address import AddressType -from ape.types.events import ContractLog, LogFilter -from ape.types.vm import BlockID, SnapshotID from ape.utils.misc import gas_estimation_error_message from ape.utils.testing import DEFAULT_TEST_HD_PATH from ape_ethereum.provider import Web3Provider @@ -41,6 +37,10 @@ if TYPE_CHECKING: from ape.api.accounts import TestAccountAPI + from ape.api.trace import TraceAPI + from ape.api.transactions import ReceiptAPI, TransactionAPI + from ape.types.events import ContractLog, LogFilter + from ape.types.vm import BlockID, SnapshotID class LocalProvider(TestProviderAPI, Web3Provider): @@ -121,7 +121,7 @@ def update_settings(self, new_settings: dict): self.connect() def estimate_gas_cost( - self, txn: TransactionAPI, block_id: Optional[BlockID] = None, **kwargs + self, txn: "TransactionAPI", block_id: Optional["BlockID"] = None, **kwargs ) -> int: if isinstance(self.network.gas_limit, int): return self.network.gas_limit @@ -201,8 +201,8 @@ def base_fee(self) -> int: def send_call( self, - txn: TransactionAPI, - block_id: Optional[BlockID] = None, + txn: "TransactionAPI", + block_id: Optional["BlockID"] = None, state: Optional[dict] = None, **kwargs, ) -> HexBytes: @@ -244,7 +244,7 @@ def send_call( return HexBytes(result) - def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: + def send_transaction(self, txn: "TransactionAPI") -> "ReceiptAPI": vm_err = None txn_dict = None try: @@ -304,10 +304,10 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: return receipt - def snapshot(self) -> SnapshotID: + def snapshot(self) -> "SnapshotID": return self.evm_backend.take_snapshot() - def restore(self, snapshot_id: SnapshotID): + def restore(self, snapshot_id: "SnapshotID"): if snapshot_id: current_hash = self._get_latest_block_rpc().get("hash") if current_hash != snapshot_id: @@ -341,18 +341,18 @@ def set_timestamp(self, new_timestamp: int): def mine(self, num_blocks: int = 1): self.evm_backend.mine_blocks(num_blocks) - def get_balance(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + def get_balance(self, address: AddressType, block_id: Optional["BlockID"] = None) -> int: # perf: Using evm_backend directly instead of going through web3. return self.evm_backend.get_balance( HexBytes(address), block_number="latest" if block_id is None else block_id ) - def get_nonce(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + def get_nonce(self, address: AddressType, block_id: Optional["BlockID"] = None) -> int: return self.evm_backend.get_nonce( HexBytes(address), block_number="latest" if block_id is None else block_id ) - def get_contract_logs(self, log_filter: LogFilter) -> Iterator[ContractLog]: + def get_contract_logs(self, log_filter: "LogFilter") -> Iterator["ContractLog"]: from_block = max(0, log_filter.start_block) if log_filter.stop_block is None: @@ -397,7 +397,7 @@ def _get_last_base_fee(self) -> int: raise APINotImplementedError("No base fee found in block.") - def get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI: + def get_transaction_trace(self, transaction_hash: str, **kwargs) -> "TraceAPI": if "call_trace_approach" not in kwargs: kwargs["call_trace_approach"] = TraceApproach.BASIC diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index d06dd7bee0..2e7b2e0cf2 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from pathlib import Path from shutil import copytree -from typing import Optional, cast +from typing import TYPE_CHECKING, Optional, cast import pytest from eth_pydantic_types import HexBytes @@ -18,10 +18,13 @@ from ape.logging import LogLevel from ape.logging import logger as _logger from ape.types.address import AddressType -from ape.types.events import ContractLog from ape.utils.misc import LOCAL_NETWORK_NAME from ape_ethereum.proxies import minimal_proxy as _minimal_proxy_container +if TYPE_CHECKING: + from ape.types.events import ContractLog + + ALIAS_2 = "__FUNCTIONAL_TESTS_ALIAS_2__" TEST_ADDRESS = cast(AddressType, "0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045") BASE_PROJECTS_DIRECTORY = (Path(__file__).parent / "data" / "projects").absolute() @@ -431,7 +434,7 @@ def PollDaemon(): @pytest.fixture def assert_log_values(contract_instance): def _assert_log_values( - log: ContractLog, + log: "ContractLog", number: int, previous_number: Optional[int] = None, address: Optional[AddressType] = None, diff --git a/tests/functional/test_config.py b/tests/functional/test_config.py index 2c90a5ba06..ec49a34c0f 100644 --- a/tests/functional/test_config.py +++ b/tests/functional/test_config.py @@ -1,7 +1,7 @@ import os import re from pathlib import Path -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import pytest from pydantic import ValidationError @@ -10,12 +10,15 @@ from ape.api.config import ApeConfig, ConfigEnum, PluginConfig from ape.exceptions import ConfigError from ape.managers.config import CONFIG_FILE_NAME, merge_configs -from ape.types.gas import GasLimit from ape.utils.os import create_tempdir from ape_ethereum.ecosystem import EthereumConfig, NetworkConfig from ape_networks import CustomNetwork from tests.functional.conftest import PROJECT_WITH_LONG_CONTRACTS_FOLDER +if TYPE_CHECKING: + from ape.types.gas import GasLimit + + CONTRACTS_FOLDER = "pathsomewhwere" NUMBER_OF_TEST_ACCOUNTS = 31 YAML_CONTENT = rf""" @@ -277,7 +280,7 @@ def test_network_gas_limit_default(config): assert eth_config.local.gas_limit == "max" -def _sepolia_with_gas_limit(gas_limit: GasLimit) -> dict: +def _sepolia_with_gas_limit(gas_limit: "GasLimit") -> dict: return { "ethereum": { "sepolia": { diff --git a/tests/functional/test_contract_event.py b/tests/functional/test_contract_event.py index 531366dedf..8dfa555b7e 100644 --- a/tests/functional/test_contract_event.py +++ b/tests/functional/test_contract_event.py @@ -1,6 +1,6 @@ import time from queue import Queue -from typing import Optional +from typing import TYPE_CHECKING, Optional import pytest from eth_pydantic_types import HexBytes @@ -8,11 +8,13 @@ from eth_utils import to_hex from ethpm_types import ContractType -from ape.api.transactions import ReceiptAPI from ape.exceptions import ProviderError from ape.types.events import ContractLog from ape.types.units import CurrencyValueComparable +if TYPE_CHECKING: + from ape.api.transactions import ReceiptAPI + @pytest.fixture def assert_log_values(owner, chain): @@ -38,7 +40,7 @@ def test_contract_logs_from_receipts(owner, contract_instance, assert_log_values receipt_1 = contract_instance.setNumber(2, sender=owner) receipt_2 = contract_instance.setNumber(3, sender=owner) - def assert_receipt_logs(receipt: ReceiptAPI, num: int): + def assert_receipt_logs(receipt: "ReceiptAPI", num: int): logs = event_type.from_receipt(receipt) assert len(logs) == 1 assert_log_values(logs[0], num) diff --git a/tests/functional/test_explorer.py b/tests/functional/test_explorer.py index db9e35e725..44214112eb 100644 --- a/tests/functional/test_explorer.py +++ b/tests/functional/test_explorer.py @@ -1,23 +1,26 @@ -from typing import Optional +from typing import TYPE_CHECKING, Optional import pytest -from ethpm_types import ContractType from ape.api.explorers import ExplorerAPI -from ape.types.address import AddressType + +if TYPE_CHECKING: + from ethpm_types import ContractType + + from ape.types.address import AddressType class MyExplorer(ExplorerAPI): def get_transaction_url(self, transaction_hash: str) -> str: return "" - def get_address_url(self, address: AddressType) -> str: + def get_address_url(self, address: "AddressType") -> str: return "" - def get_contract_type(self, address: AddressType) -> Optional[ContractType]: + def get_contract_type(self, address: "AddressType") -> Optional["ContractType"]: return None - def publish_contract(self, address: AddressType): + def publish_contract(self, address: "AddressType"): return diff --git a/tests/functional/test_receipt.py b/tests/functional/test_receipt.py index 318b65e562..1d99da80ec 100644 --- a/tests/functional/test_receipt.py +++ b/tests/functional/test_receipt.py @@ -1,12 +1,16 @@ +from typing import TYPE_CHECKING + import pytest from rich.table import Table from rich.tree import Tree -from ape.api import ReceiptAPI from ape.exceptions import ContractLogicError, OutOfGasError from ape.utils import ManagerAccessMixin from ape_ethereum.transactions import DynamicFeeTransaction, Receipt, TransactionStatusEnum +if TYPE_CHECKING: + from ape.api import ReceiptAPI + @pytest.fixture def deploy_receipt(vyper_contract_instance): @@ -147,7 +151,7 @@ def test_decode_logs(owner, contract_instance, assert_log_values): receipt_1 = contract_instance.setNumber(2, sender=owner) receipt_2 = contract_instance.setNumber(3, sender=owner) - def assert_receipt_logs(receipt: ReceiptAPI, num: int): + def assert_receipt_logs(receipt: "ReceiptAPI", num: int): logs = receipt.decode_logs(event_type) assert len(logs) == 1 assert_log_values(logs[0], num)