Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: pass contract_name to VyperContract #338

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion boa/contracts/vvm/vvm_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def constructor(self):
return ABIFunction(t, contract_name=self.filename)
return None

def deploy(self, *args, env=None, **kwargs):
def deploy(self, *args, contract_name=None, env=None, **kwargs):
encoded_args = b""
if self.constructor is not None:
encoded_args = self.constructor.prepare_calldata(*args)
Expand All @@ -66,6 +66,7 @@ def deploy(self, *args, env=None, **kwargs):

address, _ = env.deploy_code(bytecode=self.bytecode + encoded_args, **kwargs)

# TODO: pass thru contract_name
return self.at(address)

@cached_property
Expand Down
13 changes: 8 additions & 5 deletions boa/contracts/vyper/vyper_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,13 @@ class _BaseVyperContract(_BaseEVMContract):
def __init__(
self,
compiler_data: CompilerData,
contract_name: Optional[str] = None,
env: Optional[Env] = None,
filename: Optional[str] = None,
):
contract_name = Path(compiler_data.contract_path).stem
if contract_name is None:
contract_name = Path(compiler_data.contract_path).stem

super().__init__(contract_name, env, filename)
self.compiler_data = compiler_data

Expand Down Expand Up @@ -185,12 +188,11 @@ def __init__(
env=None,
override_address=None,
blueprint_preamble=None,
contract_name=None,
filename=None,
gas=None,
):
# note slight code duplication with VyperContract ctor,
# maybe use common base class?
super().__init__(compiler_data, env, filename)
super().__init__(compiler_data, contract_name, env, filename)

deploy_bytecode = generate_blueprint_bytecode(
compiler_data.bytecode, blueprint_preamble
Expand Down Expand Up @@ -516,10 +518,11 @@ def __init__(
# whether to skip constructor
skip_initcode=False,
created_from: Address = None,
contract_name=None,
filename: str = None,
gas=None,
):
super().__init__(compiler_data, env, filename)
super().__init__(compiler_data, contract_name, env, filename)

self.created_from = created_from
self._computation = None
Expand Down
18 changes: 11 additions & 7 deletions boa/interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,15 @@ def get_module_fingerprint(


def compiler_data(
source_code: str, contract_name: str, filename: str | Path, deployer=None, **kwargs
source_code: str,
contract_name: str | None,
filename: str | Path,
deployer=None,
**kwargs,
) -> CompilerData:
global _disk_cache, _search_path

path = Path(contract_name)
path = Path(filename)
resolved_path = Path(filename).resolve(strict=False)

file_input = FileInput(
Expand Down Expand Up @@ -164,7 +168,7 @@ def get_compiler_data():

assert isinstance(deployer, type) or deployer is None
deployer_id = repr(deployer) # a unique str identifying the deployer class
cache_key = str((contract_name, fingerprint, kwargs, deployer_id))
cache_key = str((contract_name, filename, fingerprint, kwargs, deployer_id))
return _disk_cache.caching_lookup(cache_key, get_compiler_data)


Expand All @@ -188,9 +192,9 @@ def loads(
):
d = loads_partial(source_code, name, filename=filename, compiler_args=compiler_args)
if as_blueprint:
return d.deploy_as_blueprint(**kwargs)
return d.deploy_as_blueprint(contract_name=name, **kwargs)
else:
return d.deploy(*args, **kwargs)
return d.deploy(*args, contract_name=name, **kwargs)


def load_abi(filename: str, *args, name: str = None, **kwargs) -> ABIContractFactory:
Expand Down Expand Up @@ -239,8 +243,8 @@ def loads_partial(
dedent: bool = True,
compiler_args: dict = None,
) -> VyperDeployer:
name = name or "VyperContract"
filename = filename or "<unknown>"
if filename is None:
filename = "<unknown>"

if dedent:
source_code = textwrap.dedent(source_code)
Expand Down
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pytest
pytest-xdist
pytest-cov
sphinx-rtd-theme
requests-cache

# jupyter
jupyter_server
Expand Down
32 changes: 31 additions & 1 deletion tests/integration/network/anvil/test_network_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,38 @@ def test_failed_transaction():
# XXX: probably want to test deployment revert behavior


def test_deployment_db():
def test_deployment_db_overriden_contract_name():
with set_deployments_db(DeploymentsDB(":memory:")) as db:
arg = 5
contract_name = "test_deployment"

# contract is written to deployments db
contract = boa.loads(code, arg, contract_name=contract_name)

# test get_deployments()
deployment = next(db.get_deployments())

initcode = contract.compiler_data.bytecode + arg.to_bytes(32, "big")

# sanity check all the fields
assert deployment.contract_address == contract.address
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
assert deployment.contract_name == contract.contract_name
assert deployment.contract_name == contract_name
assert deployment.deployer == boa.env.eoa
assert deployment.rpc == boa.env._rpc.name
assert deployment.source_code == contract.deployer.solc_json
assert deployment.abi == contract.abi

# some sanity checks on tx_dict and rx_dict fields
assert to_bytes(deployment.tx_dict["data"]) == initcode
assert deployment.tx_dict["chainId"] == hex(boa.env.get_chain_id())
assert Address(deployment.receipt_dict["contractAddress"]) == contract.address


def test_deployment_db_no_overriden_name():
with set_deployments_db(DeploymentsDB(":memory:")) as db:
arg = 5
non_contract_name = "test_deployment"

# contract is written to deployments db
contract = boa.loads(code, arg)
Expand All @@ -88,6 +117,7 @@ def test_deployment_db():
# sanity check all the fields
assert deployment.contract_address == contract.address
assert deployment.contract_name == contract.contract_name
assert deployment.contract_name != non_contract_name
assert deployment.deployer == boa.env.eoa
assert deployment.rpc == boa.env._rpc.name
assert deployment.source_code == contract.deployer.solc_json
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/network/sepolia/test_sepolia_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ def test_raise_exception(simple_contract, amount):
def test_deployment_db():
with set_deployments_db(DeploymentsDB(":memory:")) as db:
arg = 5
contract_name = "test_deployment"

# contract is written to deployments db
contract = boa.loads(code, arg)
contract = boa.loads(code, arg, contract_name=contract_name)

# test get_deployments()
deployment = next(db.get_deployments())
Expand All @@ -87,6 +88,7 @@ def test_deployment_db():
# sanity check all the fields
assert deployment.contract_address == contract.address
assert deployment.contract_name == contract.contract_name
assert deployment.contract_name == contract_name
assert deployment.deployer == boa.env.eoa
assert deployment.rpc == boa.env._rpc.name
assert deployment.source_code == contract.deployer.solc_json
Expand Down
27 changes: 27 additions & 0 deletions tests/unitary/contracts/vyper/test_vyper_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,33 @@ def foo() -> bool:
c.foo()


def test_contract_name():
code = """
@external
def foo() -> bool:
return True
"""
c = boa.loads(code, name="return_one", filename="return_one.vy")

assert c.contract_name == "return_one"
assert c.filename == "return_one.vy"

c = boa.loads(code, filename="a/b/return_one.vy")

assert c.contract_name == "return_one"
assert c.filename == "a/b/return_one.vy"

c = boa.loads(code, filename=None, name="dummy_name")

assert c.contract_name == "dummy_name"
assert c.filename == "<unknown>"

c = boa.loads(code, filename=None, name=None)

assert c.contract_name == "<unknown>"
assert c.filename == "<unknown>"


def test_stomp():
code1 = """
VAR: immutable(uint256)
Expand Down
8 changes: 4 additions & 4 deletions tests/unitary/utils/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ def test_cache_contract_name():
x: constant(int128) = 1000
"""
assert _disk_cache is not None
test1 = compiler_data(code, "test1", __file__, VyperDeployer)
test2 = compiler_data(code, "test2", __file__, VyperDeployer)
test3 = compiler_data(code, "test1", __file__, VyperDeployer)
test1 = compiler_data(code, "test1", "test1.vy", VyperDeployer)
test2 = compiler_data(code, "test2", "test2.vy", VyperDeployer)
test3 = compiler_data(code, "test1", "test1.vy", VyperDeployer)
assert _to_dict(test1) == _to_dict(test3), "Should hit the cache"
assert _to_dict(test1) != _to_dict(test2), "Should be different objects"
assert str(test2.contract_path) == "test2"
assert str(test2.contract_path) == "test2.vy"


def test_cache_vvm():
Expand Down
Loading