diff --git a/src/ape/contracts/base.py b/src/ape/contracts/base.py index 63c42c65fa..d61f4e7010 100644 --- a/src/ape/contracts/base.py +++ b/src/ape/contracts/base.py @@ -1514,7 +1514,7 @@ def _cache_wrap(self, function: Callable) -> ReceiptAPI: except ContractLogicError as err: if address := err.address: self.chain_manager.contracts[address] = self.contract_type - err._set_tb() # Re-try setting source traceback + err = err.with_ape_traceback() # Re-try setting source traceback new_err = None try: # Try enrichment again now that the contract type is cached. diff --git a/src/ape/exceptions.py b/src/ape/exceptions.py index 9907d4f7f5..b164813b43 100644 --- a/src/ape/exceptions.py +++ b/src/ape/exceptions.py @@ -8,7 +8,7 @@ from inspect import getframeinfo, stack from pathlib import Path from types import CodeType, TracebackType -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast import click from eth_typing import Hash32 @@ -163,6 +163,12 @@ class MethodNonPayableError(ContractDataError): """ +_TRACE_ARG = Optional[Union["TraceAPI", Callable[[], Optional["TraceAPI"]]]] +_SOURCE_TRACEBACK_ARG = Optional[ + Union["SourceTraceback", Callable[[], Optional["SourceTraceback"]]] +] + + class TransactionError(ApeException): """ Raised when issues occur related to transactions. @@ -176,25 +182,28 @@ def __init__( base_err: Optional[Exception] = None, code: Optional[int] = None, txn: Optional[FailedTxn] = None, - trace: Optional["TraceAPI"] = None, + trace: _TRACE_ARG = None, contract_address: Optional["AddressType"] = None, - source_traceback: Optional["SourceTraceback"] = None, + source_traceback: _SOURCE_TRACEBACK_ARG = None, project: Optional["ProjectManager"] = None, + set_ape_traceback: bool = False, # Overriden in ContractLogicError ): message = message or (str(base_err) if base_err else self.DEFAULT_MESSAGE) self.message = message self.base_err = base_err self.code = code self.txn = txn - self.trace = trace + self._trace = trace self.contract_address = contract_address - self.source_traceback: Optional["SourceTraceback"] = source_traceback + self._source_traceback = source_traceback self._project = project ex_message = f"({code}) {message}" if code else message # Finalizes expected revert message. super().__init__(ex_message) - self._set_tb() + + if set_ape_traceback: + self.with_ape_traceback() @property def address(self) -> Optional["AddressType"]: @@ -223,15 +232,51 @@ def contract_type(self) -> Optional[ContractType]: except (RecursionError, ProviderNotConnectedError): return None - def _set_tb(self): - if not self.source_traceback and self.txn: - self.source_traceback = _get_ape_traceback_from_tx(self.txn) + @property + def trace(self) -> Optional["TraceAPI"]: + tr = self._trace + if callable(tr): + result = tr() + self._trace = result + return result + + return tr + + @trace.setter + def trace(self, value): + self._trace = value + + @property + def source_traceback(self) -> Optional["SourceTraceback"]: + tb = self._source_traceback + result: Optional["SourceTraceback"] + if callable(tb): + result = tb() + self._source_traceback = result + else: + result = tb + + return result - if src_tb := self.source_traceback: + @source_traceback.setter + def source_traceback(self, value): + self._source_traceback = value + + def _get_ape_traceback(self) -> Optional[TracebackType]: + source_tb = self.source_traceback + if not source_tb and self.txn: + source_tb = _get_ape_traceback_from_tx(self.txn) + + if src_tb := source_tb: # Create a custom Pythonic traceback using lines from the sources # found from analyzing the trace of the transaction. if py_tb := _get_custom_python_traceback(self, src_tb, project=self._project): - self.__traceback__ = py_tb + return py_tb + + return None + + def with_ape_traceback(self): + return self.with_traceback(self._get_ape_traceback()) class VirtualMachineError(TransactionError): @@ -250,19 +295,22 @@ def __init__( self, revert_message: Optional[str] = None, txn: Optional[FailedTxn] = None, - trace: Optional["TraceAPI"] = None, + trace: _TRACE_ARG = None, contract_address: Optional["AddressType"] = None, - source_traceback: Optional["SourceTraceback"] = None, + source_traceback: _SOURCE_TRACEBACK_ARG = None, base_err: Optional[Exception] = None, + project: Optional["ProjectManager"] = None, + set_ape_traceback: bool = True, # Overriden default. ): self.txn = txn - self.trace = trace self.contract_address = contract_address super().__init__( base_err=base_err, contract_address=contract_address, message=revert_message, + project=project, + set_ape_traceback=set_ape_traceback, source_traceback=source_traceback, trace=trace, txn=txn, @@ -313,8 +361,15 @@ def __init__( code: Optional[int] = None, txn: Optional[FailedTxn] = None, base_err: Optional[Exception] = None, + set_ape_traceback: bool = False, ): - super().__init__("The transaction ran out of gas.", code=code, txn=txn, base_err=base_err) + super().__init__( + "The transaction ran out of gas.", + code=code, + txn=txn, + base_err=base_err, + set_ape_traceback=set_ape_traceback, + ) class NetworkError(ApeException): @@ -786,10 +841,10 @@ def __init__( abi: ErrorABI, inputs: dict[str, Any], txn: Optional[FailedTxn] = None, - trace: Optional["TraceAPI"] = None, + trace: _TRACE_ARG = None, contract_address: Optional["AddressType"] = None, base_err: Optional[Exception] = None, - source_traceback: Optional["SourceTraceback"] = None, + source_traceback: _SOURCE_TRACEBACK_ARG = None, ): self.abi = abi self.inputs = inputs diff --git a/src/ape/managers/compilers.py b/src/ape/managers/compilers.py index b953a46e36..76552680cb 100644 --- a/src/ape/managers/compilers.py +++ b/src/ape/managers/compilers.py @@ -319,7 +319,7 @@ def get_custom_error(self, err: ContractLogicError) -> Optional[CustomError]: HexBytes(message), address, base_err=err.base_err, - source_traceback=err.source_traceback, + source_traceback=lambda: err.source_traceback, trace=err.trace, txn=err.txn, ) diff --git a/src/ape/pytest/coverage.py b/src/ape/pytest/coverage.py index bc4674243e..9df48b8227 100644 --- a/src/ape/pytest/coverage.py +++ b/src/ape/pytest/coverage.py @@ -1,6 +1,6 @@ from collections.abc import Iterable from pathlib import Path -from typing import Optional +from typing import Callable, Optional, Union import click from ethpm_types.abi import MethodABI @@ -23,16 +23,33 @@ class CoverageData(ManagerAccessMixin): - def __init__(self, project: ProjectManager, sources: Iterable[ContractSource]): + def __init__( + self, + project: ProjectManager, + sources: Union[Iterable[ContractSource], Callable[[], Iterable[ContractSource]]], + ): self.project = project - self.sources = list(sources) + self._sources: Union[Iterable[ContractSource], Callable[[], Iterable[ContractSource]]] = ( + sources + ) self._report: Optional[CoverageReport] = None - self._init_coverage_profile() # Inits self._report. + + @property + def sources(self) -> list[ContractSource]: + if isinstance(self._sources, list): + return self._sources + + elif callable(self._sources): + # Lazily evaluated. + self._sources = self._sources() + + self._sources = [src for src in self._sources] + return self._sources @property def report(self) -> CoverageReport: if self._report is None: - return self._init_coverage_profile() + self._report = self._init_coverage_profile() return self._report @@ -69,7 +86,6 @@ def _init_coverage_profile( for project in report.projects: project.sources = [x for x in project.sources if len(x.statements) > 0] - self._report = report return report def cover( @@ -142,11 +158,20 @@ def __init__( else: self._output_path = Path.cwd() - sources = self._project._contract_sources + # Data gets initialized lazily (if coverage is needed). + self._data: Optional[CoverageData] = None - self.data: Optional[CoverageData] = ( - CoverageData(self._project, sources) if self.config_wrapper.track_coverage else None - ) + @property + def data(self) -> Optional[CoverageData]: + if not self.config_wrapper.track_coverage: + return None + + elif self._data is None: + # First time being initialized. + self._data = CoverageData(self._project, lambda: self._project._contract_sources) + return self._data + + return self._data @property def enabled(self) -> bool: diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index 70bd56bbd2..02bbd57b54 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -59,7 +59,7 @@ LogFilter, SourceTraceback, ) -from ape.utils import gas_estimation_error_message, to_int +from ape.utils import ManagerAccessMixin, gas_estimation_error_message, to_int from ape.utils.misc import DEFAULT_MAX_RETRIES_TX from ape_ethereum._print import CONSOLE_ADDRESS, console_contract from ape_ethereum.trace import CallTrace, TraceApproach, TransactionTrace @@ -348,11 +348,11 @@ def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional[BlockID] = N else: tx_to_trace[key] = val - trace = CallTrace(tx=txn) tx_error = self.get_virtual_machine_error( err, txn=txn, - trace=trace, + trace=lambda: CallTrace(tx=txn), + set_ape_traceback=False, ) # If this is the cause of a would-be revert, @@ -362,7 +362,11 @@ def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional[BlockID] = N message = gas_estimation_error_message(tx_error) raise TransactionError( - message, base_err=tx_error, txn=txn, source_traceback=tx_error.source_traceback + message, + base_err=tx_error, + txn=txn, + source_traceback=lambda: tx_error.source_traceback, + set_ape_traceback=True, ) from err @cached_property @@ -541,33 +545,28 @@ def _eth_call( try: result = self.make_request("eth_call", arguments) except Exception as err: - trace = None - tb = None contract_address = arguments[0].get("to") + _lazy_call_trace = _LazyCallTrace(arguments) + if not skip_trace: if address := contract_address: try: contract_type = self.chain_manager.contracts.get(address) except RecursionError: # Occurs when already in the middle of fetching this contract. - contract_type = None - else: - contract_type = None - - trace = CallTrace( - tx=arguments[0], arguments=arguments[1:], use_tokens_for_symbols=True - ) - method_id = arguments[0].get("data", "")[:10] or None - tb = None - if contract_type and method_id: - if contract_src := self.local_project._create_contract_source(contract_type): - tb = SourceTraceback.create(contract_src, trace, method_id) + pass + else: + _lazy_call_trace.contract_type = contract_type vm_err = self.get_virtual_machine_error( - err, trace=trace, contract_address=contract_address, source_traceback=tb + err, + trace=lambda: _lazy_call_trace.trace, + contract_address=contract_address, + source_traceback=lambda: _lazy_call_trace.source_traceback, + set_ape_traceback=raise_on_revert, ) if raise_on_revert: - raise vm_err from err + raise vm_err.with_ape_traceback() from err else: logger.error(vm_err) @@ -1009,7 +1008,9 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: txn_hash = to_hex(self.web3.eth.send_raw_transaction(txn.serialize_transaction())) except (ValueError, Web3ContractLogicError) as err: - vm_err = self.get_virtual_machine_error(err, txn=txn) + vm_err = self.get_virtual_machine_error( + err, txn=txn, set_ape_traceback=txn.raise_on_revert + ) if txn.raise_on_revert: raise vm_err from err else: @@ -1057,7 +1058,9 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: try: self.web3.eth.call(txn_params) except Exception as err: - vm_err = self.get_virtual_machine_error(err, txn=txn) + vm_err = self.get_virtual_machine_error( + err, txn=txn, set_ape_traceback=txn.raise_on_revert + ) receipt.error = vm_err if txn.raise_on_revert: raise vm_err from err @@ -1221,6 +1224,7 @@ def _handle_execution_reverted( trace: Optional[TraceAPI] = None, contract_address: Optional[AddressType] = None, source_traceback: Optional[SourceTraceback] = None, + set_ape_traceback: Optional[bool] = None, ) -> ContractLogicError: if hasattr(exception, "args") and len(exception.args) == 2: message = exception.args[0].replace("execution reverted: ", "") @@ -1234,6 +1238,9 @@ def _handle_execution_reverted( "contract_address": contract_address, "source_traceback": source_traceback, } + if set_ape_traceback is not None: + params["set_ape_traceback"] = set_ape_traceback + no_reason = message == "execution reverted" if isinstance(exception, Web3ContractLogicError) and no_reason: @@ -1580,3 +1587,29 @@ def _is_ws_url(val: str) -> bool: def _is_ipc_path(val: str) -> bool: return val.endswith(".ipc") + + +class _LazyCallTrace(ManagerAccessMixin): + def __init__(self, eth_call_args: list): + self._arguments = eth_call_args + + self.contract_type = None + + @cached_property + def trace(self) -> CallTrace: + return CallTrace( + tx=self._arguments[0], arguments=self._arguments[1:], use_tokens_for_symbols=True + ) + + @cached_property + def source_traceback(self) -> Optional[SourceTraceback]: + ct = self.contract_type + if ct is None: + return None + + method_id = self._arguments[0].get("data", "")[:10] or None + if ct and method_id: + if contract_src := self.local_project._create_contract_source(ct): + return SourceTraceback.create(contract_src, self.trace, method_id) + + return None diff --git a/src/ape_pm/compiler.py b/src/ape_pm/compiler.py index ad7ac652b5..01f75b8185 100644 --- a/src/ape_pm/compiler.py +++ b/src/ape_pm/compiler.py @@ -147,7 +147,7 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError: abi, inputs, txn=err.txn, - trace=err.trace, + trace=lambda: err.trace, contract_address=address, - source_traceback=err.source_traceback, + source_traceback=lambda: err.source_traceback, ) diff --git a/src/ape_test/provider.py b/src/ape_test/provider.py index 05bea7044c..e7b9a2f59f 100644 --- a/src/ape_test/provider.py +++ b/src/ape_test/provider.py @@ -131,7 +131,7 @@ def estimate_gas_cost( try: return estimate_gas(txn_data, block_identifier=block_id) except (ValidationError, TransactionFailed, Web3ContractLogicError) as err: - ape_err = self.get_virtual_machine_error(err, txn=txn) + ape_err = self.get_virtual_machine_error(err, txn=txn, set_ape_traceback=False) gas_match = self._INVALID_NONCE_PATTERN.match(str(ape_err)) if gas_match: # Sometimes, EthTester is confused about the sender nonce @@ -148,11 +148,15 @@ def estimate_gas_cost( return value elif isinstance(ape_err, ContractLogicError): - raise ape_err from err + raise ape_err.with_ape_traceback() from err else: message = gas_estimation_error_message(ape_err) raise TransactionError( - message, base_err=ape_err, txn=txn, source_traceback=ape_err.source_traceback + message, + base_err=ape_err, + txn=txn, + source_traceback=lambda: ape_err.source_traceback, + set_ape_traceback=False, ) from ape_err @property @@ -224,7 +228,7 @@ def send_call( result = HexBytes("0x") except (TransactionFailed, Web3ContractLogicError) as err: - vm_err = self.get_virtual_machine_error(err, txn=txn) + vm_err = self.get_virtual_machine_error(err, txn=txn, set_ape_traceback=False) if raise_on_revert: raise vm_err from err else: @@ -244,7 +248,7 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: txn.serialize_transaction().hex() ) except (ValidationError, TransactionFailed, Web3ContractLogicError) as err: - vm_err = self.get_virtual_machine_error(err, txn=txn) + vm_err = self.get_virtual_machine_error(err, txn=txn, set_ape_traceback=False) if txn.raise_on_revert: raise vm_err from err else: @@ -281,7 +285,7 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: try: self.web3.eth.call(txn_params) except (ValidationError, TransactionFailed, Web3ContractLogicError) as err: - vm_err = self.get_virtual_machine_error(err, txn=receipt) + vm_err = self.get_virtual_machine_error(err, txn=receipt, set_ape_traceback=False) receipt.error = vm_err if txn.raise_on_revert: raise vm_err from err diff --git a/tests/functional/test_contract_instance.py b/tests/functional/test_contract_instance.py index eabdf13ac9..423187c125 100644 --- a/tests/functional/test_contract_instance.py +++ b/tests/functional/test_contract_instance.py @@ -203,7 +203,7 @@ def test_revert_allow(not_owner, contract_instance): def test_revert_handles_compiler_panic(owner, contract_instance): - # note: setBalance is a weird name - it actually adjust the balance. + # note: setBalance is a weird name - it actually adjusts the balance. # first, set it to be 1 less than an overflow. contract_instance.setBalance(owner, 2**256 - 1, sender=owner) # then, add 1 more, so it should no overflow and cause a compiler panic. diff --git a/tests/functional/test_exceptions.py b/tests/functional/test_exceptions.py index cd88449666..8bbd999672 100644 --- a/tests/functional/test_exceptions.py +++ b/tests/functional/test_exceptions.py @@ -108,8 +108,10 @@ def test_call_with_source_tb_and_not_txn(self, mocker, project_with_contract): mock_exec.closure = mock_closure mock_tb.__getitem__.return_value = mock_exec mock_tb.__len__.return_value = 1 - - err = TransactionError(source_traceback=mock_tb, project=project_with_contract) + mock_tb.return_value = mock_tb + err = TransactionError( + source_traceback=mock_tb, project=project_with_contract, set_ape_traceback=True + ) # Have to raise for sys.exc_info() to be available. try: @@ -117,13 +119,36 @@ def test_call_with_source_tb_and_not_txn(self, mocker, project_with_contract): except Exception: pass - assert err.__traceback__ is not None + def assert_ape_traceback(err_arg): + assert err_arg.__traceback__ is not None + # The Vyper-frame gets injected at tb_next. + assert err_arg.__traceback__.tb_next is not None + actual = str(err_arg.__traceback__.tb_next.tb_frame) + assert src_path in actual + + assert_ape_traceback(err) + + err2 = TransactionError( + source_traceback=mock_tb, + project=project_with_contract, + set_ape_traceback=False, + ) + try: + raise err2 + except Exception: + pass + + # No Ape frames are here. + if err2.__traceback__: + assert err2.__traceback__.tb_next is None - # The Vyper-frame gets injected at tb_next. - assert err.__traceback__.tb_next is not None + err3 = ContractLogicError(source_traceback=mock_tb, project=project_with_contract) + try: + raise err3 + except Exception: + pass - actual = str(err.__traceback__.tb_next.tb_frame) - assert src_path in actual + assert_ape_traceback(err3) class TestNetworkNotFoundError: