diff --git a/boa/__init__.py b/boa/__init__.py index e8cea0b5..03e3fb04 100644 --- a/boa/__init__.py +++ b/boa/__init__.py @@ -13,9 +13,11 @@ load, load_abi, load_partial, + load_vyi, loads, loads_abi, loads_partial, + loads_vyi, ) from boa.network import NetworkEnv from boa.precompile import precompile diff --git a/boa/interpret.py b/boa/interpret.py index 4447335f..17136d3c 100644 --- a/boa/interpret.py +++ b/boa/interpret.py @@ -8,6 +8,7 @@ import vvm import vyper +from vyper.ast.parse import parse_to_ast from vyper.cli.vyper_compile import get_search_paths from vyper.compiler.input_bundle import ( ABIInput, @@ -17,6 +18,7 @@ ) from vyper.compiler.phases import CompilerData from vyper.compiler.settings import Settings, anchor_settings +from vyper.semantics.analysis.module import analyze_module from vyper.semantics.types.module import ModuleT from vyper.utils import sha256sum @@ -194,6 +196,7 @@ def loads( def load_abi(filename: str, *args, name: str = None, **kwargs) -> ABIContractFactory: if name is None: name = Path(filename).stem + # TODO: pass filename to ABIContractFactory with open(filename) as fp: return loads_abi(fp.read(), *args, name=name, **kwargs) @@ -202,6 +205,33 @@ def loads_abi(json_str: str, *args, name: str = None, **kwargs) -> ABIContractFa return ABIContractFactory.from_abi_dict(json.loads(json_str), name, *args, **kwargs) +# load from .vyi file. +# NOTE: substantially same interface as load_abi and loads_abi, consider +# fusing them into load_interface? +def load_vyi(filename: str, name: str = None) -> ABIContractFactory: + if name is None: + name = Path(filename).stem + with open(filename) as fp: + return loads_vyi(fp.read(), name=name, filename=filename) + + +# load interface from .vyi file string contents. +def loads_vyi(source_code: str, name: str = None, filename: str = None): + global _search_path + + ast = parse_to_ast(source_code) + + if name is None: + name = "VyperContract.vyi" + + search_paths = get_search_paths(_search_path) + input_bundle = FilesystemInputBundle(search_paths) + + module_t = analyze_module(ast, input_bundle, is_interface=True) + abi = module_t.interface.to_toplevel_abi_dict() + return ABIContractFactory(name, abi, filename=filename) + + def loads_partial( source_code: str, name: str = None, diff --git a/tests/unitary/contracts/vyper/test_vyi.py b/tests/unitary/contracts/vyper/test_vyi.py new file mode 100644 index 00000000..8269c9b2 --- /dev/null +++ b/tests/unitary/contracts/vyper/test_vyi.py @@ -0,0 +1,42 @@ +import pytest + +import boa + +FOO_CONTRACT = """ +@external +def foo() -> uint256: + return 5 +""" + +FOO_INTERFACE = """ +@external +def foo() -> uint256: + ... +""" + + +@pytest.fixture +def foo_contract(): + return boa.loads(FOO_CONTRACT) + + +@pytest.fixture +def foo_interface(foo_contract): + return boa.loads_vyi(FOO_INTERFACE).at(foo_contract.address) + + +# from file +@pytest.fixture +def foo_interface2(foo_contract, tmp_path): + p = tmp_path / "foo.vyi" + with p.open("w") as f: + f.write(FOO_INTERFACE) + return boa.load_vyi(p).at(foo_contract.address) + + +def test_foo_interface(foo_interface): + assert foo_interface.foo() == 5 + + +def test_foo_interface2(foo_interface2): + assert foo_interface2.foo() == 5