Skip to content

Commit

Permalink
Merge pull request #154 from charles-cooper/fix/address-inheritance
Browse files Browse the repository at this point in the history
refactor: Address inherit from bytes
  • Loading branch information
charles-cooper authored Feb 16, 2024
2 parents 9055344 + 2593bfe commit 1f68e6b
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 52 deletions.
2 changes: 1 addition & 1 deletion boa/contracts/abi/abi_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def __init__(
super().__init__(env, filename=filename, address=address)
self._name = name
self._functions = functions
self._bytecode = self.env.vm.state.get_code(address.canonical_address)
self._bytecode = self.env.vm.state.get_code(address)
if not self._bytecode:
warn(
f"Requested {self} but there is no bytecode at that address!",
Expand Down
6 changes: 3 additions & 3 deletions boa/contracts/vyper/vyper_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def at(self, address: Any) -> "VyperContract":

ret = self.deploy(override_address=address, skip_initcode=True)
vm = ret.env.vm
bytecode = vm.state.get_code(address.canonical_address)
bytecode = vm.state.get_code(address)

ret._set_bytecode(bytecode)

Expand Down Expand Up @@ -361,7 +361,7 @@ def setpath(lens, path, val):
class StorageVar:
def __init__(self, contract, slot, typ):
self.contract = contract
self.addr = self.contract._address.canonical_address
self.addr = self.contract._address
self.accountdb = contract.env.vm.state._account_db
self.slot = slot
self.typ = typ
Expand Down Expand Up @@ -668,7 +668,7 @@ def event_for(self):

def decode_log(self, e):
log_id, address, topics, data = e
assert self._address.canonical_address == address
assert self._address == address
event_hash = topics[0]
event_t = self.event_for[event_hash]

Expand Down
22 changes: 11 additions & 11 deletions boa/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ def register_raw_precompile(address, fn, force=False):
address = Address(address)
if address in _precompiles and not force:
raise ValueError(f"Already registered: {address}")
_precompiles[address.canonical_address] = fn
_precompiles[address] = fn


def deregister_raw_precompile(address, force=True):
address = Address(address).canonical_address
address = Address(address)
if address not in _precompiles and not force:
raise ValueError("Not registered: {address}")
_precompiles.pop(address, None)
Expand Down Expand Up @@ -492,14 +492,14 @@ def reset_gas_metering_behavior(self) -> None:

# set balance of address in py-evm
def set_balance(self, addr, value):
self.vm.state.set_balance(Address(addr).canonical_address, value)
self.vm.state.set_balance(Address(addr), value)

# get balance of address in py-evm
def get_balance(self, addr):
return self.vm.state.get_balance(Address(addr).canonical_address)
return self.vm.state.get_balance(Address(addr))

def register_contract(self, address, obj):
addr = Address(address).canonical_address
addr = Address(address)
self._contracts[addr] = obj

# also register it in the registry for
Expand All @@ -516,13 +516,13 @@ def _lookup_contract_fast(self, address: PYEVM_Address):
def lookup_contract(self, address: _AddressType):
if address == b"":
return None
return self._contracts.get(Address(address).canonical_address)
return self._contracts.get(Address(address))

def alias(self, address, name):
self._aliases[Address(address).canonical_address] = name
self._aliases[Address(address)] = name

def lookup_alias(self, address):
return self._aliases[Address(address).canonical_address]
return self._aliases[Address(address)]

# advanced: reset warm/cold counters for addresses and storage
def _reset_access_counters(self):
Expand Down Expand Up @@ -577,7 +577,7 @@ def _get_sender(self, sender=None) -> PYEVM_Address:
sender = self.eoa
if self.eoa is None:
raise ValueError(f"{self}.eoa not defined!")
return Address(sender).canonical_address
return Address(sender)

def _update_gas_used(self, gas_used: int):
self._gas_tracker += gas_used
Expand Down Expand Up @@ -610,7 +610,7 @@ def deploy_code(
gas=gas,
value=value,
code=bytecode,
create_address=target_address.canonical_address,
create_address=target_address,
data=b"",
)

Expand Down Expand Up @@ -672,7 +672,7 @@ def execute_code(

sender = self._get_sender(sender)

to = Address(to_address).canonical_address
to = Address(to_address)

bytecode = override_bytecode
if override_bytecode is None:
Expand Down
2 changes: 1 addition & 1 deletion boa/interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def load_partial(filename: str, compiler_args=None) -> VyperDeployer: # type: i
def from_etherscan(
address: Any, name=None, uri="https://api.etherscan.io/api", api_key=None
):
addr = Address(address)
addr = Address(address).checksum_address
abi = fetch_abi_from_etherscan(addr, uri, api_key)
return ABIContractFactory.from_abi_dict(abi, name=name).at(addr)

Expand Down
20 changes: 5 additions & 15 deletions boa/test/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from typing import Any, Callable, Iterable, Optional, Union

from eth_abi.grammar import BasicType, TupleType, parse
from eth_utils import to_checksum_address
from hypothesis import given
from hypothesis import strategies as st
from hypothesis.strategies import SearchStrategy
from hypothesis.strategies._internal.deferred import DeferredStrategy

from boa.contracts.vyper.vyper_contract import VyperFunction
from boa.util.abi import Address

# hypothesis fuzzing strategies, adapted from brownie 0.19.2 (86258c7bd)
# in the future these may be superseded by eth-stdlib.
Expand Down Expand Up @@ -87,25 +87,15 @@ def _decimal_strategy(
return st.decimals(min_value=min_value, max_value=max_value, places=places)


def format_addr(t):
if isinstance(t, str):
t = t.encode("utf-8")
return to_checksum_address(t.rjust(20, b"\x00"))


def generate_random_string(n):
return ["".join(random.choices(string.ascii_lowercase, k=5)) for i in range(n)]
def generate_random_strings(n):
return [b"".join(random.choices(string.ascii_lowercase, k=5)) for i in range(n)]


@_exclude_filter
def _address_strategy(length: Optional[int] = 100) -> SearchStrategy:
random_strings = generate_random_string(length)
def _address_strategy() -> SearchStrategy:
# TODO: add addresses from the environment. probably everything in
# boa.env._contracts, boa.env._blueprints and boa.env.eoa.
accounts = [format_addr(i) for i in random_strings]
return _DeferredStrategyRepr(
lambda: st.sampled_from(list(accounts)[:length]), "accounts"
)
return st.binary(min_size=20, max_size=20).map(Address)


@_exclude_filter
Expand Down
28 changes: 16 additions & 12 deletions boa/util/abi.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,27 @@
# wrapper module around whatever encoder we are using
from typing import Annotated, Any
from typing import Any

from eth.codecs.abi import nodes
from eth.codecs.abi.decoder import Decoder
from eth.codecs.abi.encoder import Encoder
from eth.codecs.abi.exceptions import ABIError
from eth.codecs.abi.nodes import ABITypeNode
from eth.codecs.abi.parser import Parser
from eth_typing import Address as PYEVM_Address
from eth_utils import to_canonical_address, to_checksum_address

from boa.util.lrudict import lrudict

_parsers: dict[str, ABITypeNode] = {}


# XXX: inherit from bytes directly so that we can pass it to py-evm?
# inherit from `str` so that ABI encoder / decoder can work without failing
class Address(str): # (PYEVM_Address):
# inherit from bytes so we don't need conversion when interacting with pyevm
class Address(bytes):
# converting between checksum and canonical addresses is a hotspot;
# this class contains both and caches recently seen conversions
__slots__ = ("canonical_address",)
# TODO: maybe this class belongs in its own module
_cache = lrudict(1024)

canonical_address: Annotated[PYEVM_Address, "canonical address"]
checksum_address: str

def __new__(cls, address):
if isinstance(address, Address):
Expand All @@ -34,15 +32,14 @@ def __new__(cls, address):
except KeyError:
pass

checksum_address = to_checksum_address(address)
self = super().__new__(cls, checksum_address)
self.canonical_address = to_canonical_address(address)
canonical_address = to_canonical_address(address)
self = super().__new__(cls, canonical_address)
self.checksum_address = to_checksum_address(address)
cls._cache[address] = self
return self

def __repr__(self):
checksum_addr = super().__repr__()
return f"_Address({checksum_addr})"
return f"_Address({self.checksum_address})"


class _ABIEncoder(Encoder):
Expand All @@ -54,6 +51,13 @@ class _ABIEncoder(Encoder):
@classmethod
def visit_AddressNode(cls, node: nodes.AddressNode, value) -> bytes:
value = getattr(value, "address", value)

if isinstance(value, Address):
assert len(value) == 20 # guaranteed by to_canonical_address
# for performance, inline the implementation
# return the bytes value, left-pad with zeros
return value.rjust(32, b"\x00")

return super().visit_AddressNode(node, value)


Expand Down
2 changes: 1 addition & 1 deletion tests/unitary/jupyter/test_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_browser_sign_typed_data(
browser, display_mock, mock_inject_javascript, mock_callback
):
signer = browser.BrowserSigner(boa.env.generate_address())
signature = boa.env.generate_address()
signature = boa.env.generate_address().checksum_address
mock_callback("signTypedData", signature)
data = signer.sign_typed_data({"name": "My App"}, {"types": []}, {"data": "0x1234"})
assert data == signature
Expand Down
11 changes: 4 additions & 7 deletions tests/unitary/strategy/test_address.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
from hypothesis import HealthCheck, given, settings
from hypothesis.strategies._internal.deferred import DeferredStrategy
from hypothesis.strategies import SearchStrategy

from boa.test import strategy
from boa.util.abi import Address


def test_strategy():
assert isinstance(strategy("address"), DeferredStrategy)
assert isinstance(strategy("address"), SearchStrategy)


@given(value=strategy("address"))
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
def test_given(value):
assert isinstance(value, str)


def test_repr():
assert repr(strategy("address")) == "sampled_from(accounts)"
assert isinstance(value, Address)
2 changes: 1 addition & 1 deletion tests/unitary/test_blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ def test_create2_address():
blueprint_bytecode = boa.env.vm.state.get_code(
to_canonical_address(blueprint.address)
)
assert child_contract_address == get_create2_address(
assert child_contract_address.checksum_address == get_create2_address(
blueprint_bytecode, factory.address, salt
)

0 comments on commit 1f68e6b

Please sign in to comment.