Skip to content

Commit

Permalink
fix: pass contract_name to VyperContract (#338)
Browse files Browse the repository at this point in the history
add contract_name to VyperContract ctor. allows setting contract_name at deploy time.

---------

Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
PatrickAlphaC and charles-cooper authored Oct 25, 2024
1 parent 2badb06 commit f58c33c
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 19 deletions.
3 changes: 2 additions & 1 deletion boa/contracts/vvm/vvm_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def constructor(self):
return ABIFunction(t, contract_name=self.filename)
return None

def deploy(self, *args, env=None, **kwargs):
def deploy(self, *args, contract_name=None, env=None, **kwargs):
encoded_args = b""
if self.constructor is not None:
encoded_args = self.constructor.prepare_calldata(*args)
Expand All @@ -66,6 +66,7 @@ def deploy(self, *args, env=None, **kwargs):

address, _ = env.deploy_code(bytecode=self.bytecode + encoded_args, **kwargs)

# TODO: pass thru contract_name
return self.at(address)

@cached_property
Expand Down
13 changes: 8 additions & 5 deletions boa/contracts/vyper/vyper_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,13 @@ class _BaseVyperContract(_BaseEVMContract):
def __init__(
self,
compiler_data: CompilerData,
contract_name: Optional[str] = None,
env: Optional[Env] = None,
filename: Optional[str] = None,
):
contract_name = Path(compiler_data.contract_path).stem
if contract_name is None:
contract_name = Path(compiler_data.contract_path).stem

super().__init__(contract_name, env, filename)
self.compiler_data = compiler_data

Expand Down Expand Up @@ -185,12 +188,11 @@ def __init__(
env=None,
override_address=None,
blueprint_preamble=None,
contract_name=None,
filename=None,
gas=None,
):
# note slight code duplication with VyperContract ctor,
# maybe use common base class?
super().__init__(compiler_data, env, filename)
super().__init__(compiler_data, contract_name, env, filename)

deploy_bytecode = generate_blueprint_bytecode(
compiler_data.bytecode, blueprint_preamble
Expand Down Expand Up @@ -516,10 +518,11 @@ def __init__(
# whether to skip constructor
skip_initcode=False,
created_from: Address = None,
contract_name=None,
filename: str = None,
gas=None,
):
super().__init__(compiler_data, env, filename)
super().__init__(compiler_data, contract_name, env, filename)

self.created_from = created_from
self._computation = None
Expand Down
18 changes: 11 additions & 7 deletions boa/interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,15 @@ def get_module_fingerprint(


def compiler_data(
source_code: str, contract_name: str, filename: str | Path, deployer=None, **kwargs
source_code: str,
contract_name: str | None,
filename: str | Path,
deployer=None,
**kwargs,
) -> CompilerData:
global _disk_cache, _search_path

path = Path(contract_name)
path = Path(filename)
resolved_path = Path(filename).resolve(strict=False)

file_input = FileInput(
Expand Down Expand Up @@ -164,7 +168,7 @@ def get_compiler_data():

assert isinstance(deployer, type) or deployer is None
deployer_id = repr(deployer) # a unique str identifying the deployer class
cache_key = str((contract_name, fingerprint, kwargs, deployer_id))
cache_key = str((contract_name, filename, fingerprint, kwargs, deployer_id))
return _disk_cache.caching_lookup(cache_key, get_compiler_data)


Expand All @@ -188,9 +192,9 @@ def loads(
):
d = loads_partial(source_code, name, filename=filename, compiler_args=compiler_args)
if as_blueprint:
return d.deploy_as_blueprint(**kwargs)
return d.deploy_as_blueprint(contract_name=name, **kwargs)
else:
return d.deploy(*args, **kwargs)
return d.deploy(*args, contract_name=name, **kwargs)


def load_abi(filename: str, *args, name: str = None, **kwargs) -> ABIContractFactory:
Expand Down Expand Up @@ -239,8 +243,8 @@ def loads_partial(
dedent: bool = True,
compiler_args: dict = None,
) -> VyperDeployer:
name = name or "VyperContract"
filename = filename or "<unknown>"
if filename is None:
filename = "<unknown>"

if dedent:
source_code = textwrap.dedent(source_code)
Expand Down
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pytest
pytest-xdist
pytest-cov
sphinx-rtd-theme
requests-cache

# jupyter
jupyter_server
Expand Down
32 changes: 31 additions & 1 deletion tests/integration/network/anvil/test_network_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,38 @@ def test_failed_transaction():
# XXX: probably want to test deployment revert behavior


def test_deployment_db():
def test_deployment_db_overriden_contract_name():
with set_deployments_db(DeploymentsDB(":memory:")) as db:
arg = 5
contract_name = "test_deployment"

# contract is written to deployments db
contract = boa.loads(code, arg, contract_name=contract_name)

# test get_deployments()
deployment = next(db.get_deployments())

initcode = contract.compiler_data.bytecode + arg.to_bytes(32, "big")

# sanity check all the fields
assert deployment.contract_address == contract.address
assert deployment.contract_name == contract.contract_name
assert deployment.contract_name == contract_name
assert deployment.deployer == boa.env.eoa
assert deployment.rpc == boa.env._rpc.name
assert deployment.source_code == contract.deployer.solc_json
assert deployment.abi == contract.abi

# some sanity checks on tx_dict and rx_dict fields
assert to_bytes(deployment.tx_dict["data"]) == initcode
assert deployment.tx_dict["chainId"] == hex(boa.env.get_chain_id())
assert Address(deployment.receipt_dict["contractAddress"]) == contract.address


def test_deployment_db_no_overriden_name():
with set_deployments_db(DeploymentsDB(":memory:")) as db:
arg = 5
non_contract_name = "test_deployment"

# contract is written to deployments db
contract = boa.loads(code, arg)
Expand All @@ -88,6 +117,7 @@ def test_deployment_db():
# sanity check all the fields
assert deployment.contract_address == contract.address
assert deployment.contract_name == contract.contract_name
assert deployment.contract_name != non_contract_name
assert deployment.deployer == boa.env.eoa
assert deployment.rpc == boa.env._rpc.name
assert deployment.source_code == contract.deployer.solc_json
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/network/sepolia/test_sepolia_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ def test_raise_exception(simple_contract, amount):
def test_deployment_db():
with set_deployments_db(DeploymentsDB(":memory:")) as db:
arg = 5
contract_name = "test_deployment"

# contract is written to deployments db
contract = boa.loads(code, arg)
contract = boa.loads(code, arg, contract_name=contract_name)

# test get_deployments()
deployment = next(db.get_deployments())
Expand All @@ -87,6 +88,7 @@ def test_deployment_db():
# sanity check all the fields
assert deployment.contract_address == contract.address
assert deployment.contract_name == contract.contract_name
assert deployment.contract_name == contract_name
assert deployment.deployer == boa.env.eoa
assert deployment.rpc == boa.env._rpc.name
assert deployment.source_code == contract.deployer.solc_json
Expand Down
27 changes: 27 additions & 0 deletions tests/unitary/contracts/vyper/test_vyper_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,33 @@ def foo() -> bool:
c.foo()


def test_contract_name():
code = """
@external
def foo() -> bool:
return True
"""
c = boa.loads(code, name="return_one", filename="return_one.vy")

assert c.contract_name == "return_one"
assert c.filename == "return_one.vy"

c = boa.loads(code, filename="a/b/return_one.vy")

assert c.contract_name == "return_one"
assert c.filename == "a/b/return_one.vy"

c = boa.loads(code, filename=None, name="dummy_name")

assert c.contract_name == "dummy_name"
assert c.filename == "<unknown>"

c = boa.loads(code, filename=None, name=None)

assert c.contract_name == "<unknown>"
assert c.filename == "<unknown>"


def test_stomp():
code1 = """
VAR: immutable(uint256)
Expand Down
8 changes: 4 additions & 4 deletions tests/unitary/utils/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ def test_cache_contract_name():
x: constant(int128) = 1000
"""
assert _disk_cache is not None
test1 = compiler_data(code, "test1", __file__, VyperDeployer)
test2 = compiler_data(code, "test2", __file__, VyperDeployer)
test3 = compiler_data(code, "test1", __file__, VyperDeployer)
test1 = compiler_data(code, "test1", "test1.vy", VyperDeployer)
test2 = compiler_data(code, "test2", "test2.vy", VyperDeployer)
test3 = compiler_data(code, "test1", "test1.vy", VyperDeployer)
assert _to_dict(test1) == _to_dict(test3), "Should hit the cache"
assert _to_dict(test1) != _to_dict(test2), "Should be different objects"
assert str(test2.contract_path) == "test2"
assert str(test2.contract_path) == "test2.vy"


def test_cache_vvm():
Expand Down

0 comments on commit f58c33c

Please sign in to comment.