diff --git a/docs/api.rst b/docs/api.rst index 3348876..401dd47 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -103,6 +103,10 @@ Contract ABI :members: :special-members: __call__ +.. autoclass:: MultiMethod + :members: + :special-members: __call__ + .. autoclass:: Event :members: diff --git a/pons/__init__.py b/pons/__init__.py index c66bf30..7aac6f8 100644 --- a/pons/__init__.py +++ b/pons/__init__.py @@ -33,6 +33,7 @@ Fallback, Method, MethodCall, + MultiMethod, Mutability, Receive, ) @@ -83,6 +84,7 @@ "JSON", "Method", "MethodCall", + "MultiMethod", "Mutability", "PriorityFallback", "ProviderError", diff --git a/pons/_contract.py b/pons/_contract.py index be824bf..dde1f9a 100644 --- a/pons/_contract.py +++ b/pons/_contract.py @@ -1,6 +1,14 @@ -from typing import Any, Dict, List, Optional, Tuple - -from ._contract_abi import ContractABI, Error, Event, EventFilter, Method, Methods +from typing import Any, Dict, List, Optional, Tuple, Union + +from ._contract_abi import ( + ContractABI, + Error, + Event, + EventFilter, + Method, + Methods, + MultiMethod, +) from ._entities import Address, LogEntry, LogTopic from ._provider import JSON @@ -43,7 +51,12 @@ def __init__(self, contract_abi: ContractABI, data_bytes: bytes, *, payable: boo class BoundMethod: """A regular method bound to a specific contract's address.""" - def __init__(self, contract_abi: ContractABI, contract_address: Address, method: Method): + def __init__( + self, + contract_abi: ContractABI, + contract_address: Address, + method: Union[Method, MultiMethod], + ): self._contract_abi = contract_abi self._contract_address = contract_address self._method = method @@ -52,7 +65,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> "BoundMethodCall": """Returns a contract call with encoded arguments bound to a specific address.""" call = self._method(*args, **kwargs) return BoundMethodCall( - self._contract_abi, self._method, self._contract_address, call.data_bytes + self._contract_abi, call.method, self._contract_address, call.data_bytes ) diff --git a/pons/_contract_abi.py b/pons/_contract_abi.py index 767c189..37dea4d 100644 --- a/pons/_contract_abi.py +++ b/pons/_contract_abi.py @@ -1,6 +1,7 @@ import inspect from enum import Enum from functools import cached_property +from inspect import BoundArguments from itertools import chain from keyword import iskeyword from typing import ( @@ -83,13 +84,19 @@ def canonical_form(self) -> str: """Returns the signature serialized in the canonical form as a string.""" return "(" + ",".join(tp.canonical_form for tp in self._types) + ")" + def bind(self, *args: Any, **kwargs: Any) -> BoundArguments: + return self._signature.bind(*args, **kwargs) + + def encode_bound(self, bound_args: BoundArguments) -> bytes: + return encode_args(*zip(self._types, bound_args.args)) + def encode(self, *args: Any, **kwargs: Any) -> bytes: """ Encodes assorted positional/keyword arguments into the bytestring according to the ABI format. """ - bound_args = self._signature.bind(*args, **kwargs) - return encode_args(*zip(self._types, bound_args.args)) + bound_args = self.bind(*args, **kwargs) + return self.encode_bound(bound_args) def decode_into_tuple(self, value_bytes: bytes) -> Tuple[Any, ...]: """Decodes the packed bytestring into a list of values.""" @@ -363,25 +370,32 @@ def __init__( @property def name(self) -> str: + """The name of this method.""" return self._name @property def inputs(self) -> Signature: + """The input signature of this method.""" return self._inputs + def bind(self, *args: Any, **kwargs: Any) -> BoundArguments: + return self._inputs.bind(*args, **kwargs) + def __call__(self, *args: Any, **kwargs: Any) -> "MethodCall": """Returns an encoded call with given arguments.""" - return MethodCall(self._encode_call(*args, **kwargs)) + bound_args = self.bind(*args, **kwargs) + return self.call_bound(bound_args) + + def call_bound(self, bound_args: BoundArguments) -> "MethodCall": + input_bytes = self.inputs.encode_bound(bound_args) + encoded = self.selector + input_bytes + return MethodCall(self, encoded) @cached_property def selector(self) -> bytes: """Method's selector.""" return keccak(self.name.encode() + self.inputs.canonical_form.encode())[:SELECTOR_LENGTH] - def _encode_call(self, *args: Any, **kwargs: Any) -> bytes: - input_bytes = self.inputs.encode(*args, **kwargs) - return self.selector + input_bytes - def decode_output(self, output_bytes: bytes) -> Any: """Decodes the output from ABI-packed bytes.""" results = self.outputs.decode_into_tuple(output_bytes) @@ -389,11 +403,83 @@ def decode_output(self, output_bytes: bytes) -> Any: results = results[0] return results + def with_method(self, method: "Method") -> "MultiMethod": + return MultiMethod(self, method) + def __str__(self) -> str: returns = "" if self.outputs.empty else f" returns {self.outputs}" return f"function {self.name}{self.inputs} {self._mutability.value}{returns}" +class MultiMethod: + """ + An overloaded contract method, containing several :py:class:`Method` objects with the same name + but different input signatures. + """ + + def __init__(self, *methods: Method): + if len(methods) == 0: + raise ValueError("`methods` cannot be empty") + first_method = methods[0] + self._methods = {first_method.inputs.canonical_form: first_method} + self._name = first_method.name + + for method in methods[1:]: + self._add_method(method) + + def __getitem__(self, args: str) -> Method: + """ + Returns the :py:class:`Method` with the given canonical form of an input signature + (corresponding to :py:attr:`Signature.canonical_form`). + """ + return self._methods[args] + + @property + def name(self) -> str: + """The name of this method.""" + return self._name + + @property + def methods(self) -> Dict[str, Method]: + """All the overloaded methods, indexed by the canonical form of their input signatures.""" + return self._methods + + def _add_method(self, method: Method) -> None: + if method.name != self.name: + raise ValueError("All overloaded methods must have the same name") + if method.inputs.canonical_form in self._methods: + raise ValueError( + f"A method {self.name}{method.inputs.canonical_form} " + "is already registered in this MultiMethod" + ) + self._methods[method.inputs.canonical_form] = method + + def with_method(self, method: Method) -> "MultiMethod": + """Returns a new ``MultiMethod`` with the given method included.""" + new_mm = MultiMethod(*self._methods.values()) + new_mm._add_method(method) # noqa: SLF001 + return new_mm + + def __call__(self, *args: Any, **kwds: Any) -> "MethodCall": + """Returns an encoded call with given arguments.""" + for method in self._methods.values(): + try: + bound_args = method.bind(*args, **kwds) + except TypeError: + # If it's a non-overloaded method, we do not want to complicate things + if len(self._methods) == 1: + raise + + continue + + return method.call_bound(bound_args) + + raise TypeError("Could not find a suitable overloaded method for the given arguments") + + def __str__(self) -> str: + return "; ".join(str(method) for method in self._methods.values()) + + class Event: """ A contract event. @@ -601,7 +687,11 @@ class MethodCall: data_bytes: bytes """Encoded call arguments with the selector.""" - def __init__(self, data_bytes: bytes): + method: Method + """The method object that encoded this call.""" + + def __init__(self, method: Method, data_bytes: bytes): + self.method = method self.data_bytes = data_bytes @@ -661,7 +751,7 @@ class ContractABI: receive: Optional[Receive] """Contract's receive method.""" - method: Methods[Method] + method: Methods[Union[Method, MultiMethod]] """Contract's regular methods.""" event: Methods[Event] @@ -676,7 +766,7 @@ def from_json(cls, json_abi: List[Dict[str, JSON]]) -> "ContractABI": # noqa: C constructor = None fallback = None receive = None - methods = {} + methods: Dict[Any, Union[Method, MultiMethod]] = {} events = {} errors = {} @@ -687,12 +777,11 @@ def from_json(cls, json_abi: List[Dict[str, JSON]]) -> "ContractABI": # noqa: C constructor = Constructor.from_json(entry) elif entry["type"] == "function": + method = Method.from_json(entry) if entry["name"] in methods: - # TODO (#21): add support for overloaded methods - raise ValueError( - f"JSON ABI contains more than one declarations of `{entry['name']}`" - ) - methods[entry["name"]] = Method.from_json(entry) + methods[entry["name"]] = methods[entry["name"]].with_method(method) + else: + methods[entry["name"]] = method elif entry["type"] == "fallback": if fallback: @@ -735,7 +824,7 @@ def __init__( constructor: Optional[Constructor] = None, fallback: Optional[Fallback] = None, receive: Optional[Receive] = None, - methods: Optional[Iterable[Method]] = None, + methods: Optional[Iterable[Union[Method, MultiMethod]]] = None, events: Optional[Iterable[Event]] = None, errors: Optional[Iterable[Error]] = None, ): @@ -771,7 +860,9 @@ def resolve_error(self, error_data: bytes) -> Tuple[Error, Dict[str, Any]]: raise UnknownError(f"Could not find an error with selector {selector.hex()} in the ABI") def __str__(self) -> str: - all_methods: Iterable[Union[Constructor, Fallback, Receive, Method, Event, Error]] = chain( + all_methods: Iterable[ + Union[Constructor, Fallback, Receive, Method, MultiMethod, Event, Error] + ] = chain( [self.constructor] if self.constructor else [], [self.fallback] if self.fallback else [], [self.receive] if self.receive else [], @@ -779,5 +870,13 @@ def __str__(self) -> str: self.event, self.error, ) - method_list = [" " + str(method) for method in all_methods] + + indent = " " + + def to_str(item: Any) -> str: + if isinstance(item, MultiMethod): + return ("\n" + indent).join(str(method) for method in item.methods.values()) + return str(item) + + method_list = [indent + to_str(method) for method in all_methods] return "{\n" + "\n".join(method_list) + "\n}" diff --git a/tests/TestContractFunctionality.sol b/tests/TestContractFunctionality.sol index bbc0451..814b3ed 100644 --- a/tests/TestContractFunctionality.sol +++ b/tests/TestContractFunctionality.sol @@ -42,6 +42,14 @@ contract Test { return v1 + _x; } + function overloaded(uint256 _x, uint256 _y) public view returns (uint256) { + return _y + _x; + } + + function overloaded(uint256 _x) public view returns (uint256) { + return v1 + _x; + } + struct Inner { uint256 inner1; uint256 inner2; diff --git a/tests/test_contract_abi.py b/tests/test_contract_abi.py index ef06bac..369a590 100644 --- a/tests/test_contract_abi.py +++ b/tests/test_contract_abi.py @@ -12,6 +12,7 @@ Event, Fallback, Method, + MultiMethod, Mutability, Receive, abi, @@ -256,6 +257,79 @@ def test_method_errors(): Method.from_json(json) +def test_multi_method(): + method1 = Method( + name="someMethod", + mutability=Mutability.VIEW, + inputs=dict(a=abi.uint(8), b=abi.bool), + outputs=abi.uint(8), + ) + method2 = Method( + name="someMethod", + mutability=Mutability.VIEW, + inputs=dict(a=abi.uint(8)), + outputs=abi.uint(8), + ) + + multi_method = MultiMethod(method1, method2) + assert multi_method["(uint8,bool)"] == method1 + assert multi_method["(uint8)"] == method2 + + assert str(multi_method) == ( + "function someMethod(uint8 a, bool b) view returns (uint8); " + "function someMethod(uint8 a) view returns (uint8)" + ) + + # Create sequentially + multi_method = MultiMethod(method1).with_method(method2) + assert multi_method["(uint8,bool)"] == method1 + assert multi_method["(uint8)"] == method2 + + # Call the first method + call = multi_method(1, b=True) + assert call.method == method1 + + # Call the second method + call = multi_method(a=1) + assert call.method == method2 + + # Call with arguments not matching any of the methods + with pytest.raises( + TypeError, match="Could not find a suitable overloaded method for the given arguments" + ): + multi_method(1, True, 2) + + # If the multi-method only contains one method, raise the binding error right away + multi_method = MultiMethod(method1) + with pytest.raises(TypeError, match="missing a required argument: 'b'"): + multi_method(1) + + +def test_multi_method_errors(): + with pytest.raises(ValueError, match="`methods` cannot be empty"): + MultiMethod() + + method = Method( + name="someMethod", + mutability=Mutability.VIEW, + inputs=dict(a=abi.uint(8), b=abi.bool), + outputs=abi.uint(8), + ) + method_with_different_name = Method( + name="someMethod2", + mutability=Mutability.VIEW, + inputs=dict(a=abi.uint(8)), + outputs=abi.uint(8), + ) + + with pytest.raises(ValueError, match="All overloaded methods must have the same name"): + MultiMethod(method, method_with_different_name) + + msg = re.escape("A method someMethod(uint8,bool) is already registered in this MultiMethod") + with pytest.raises(ValueError, match=msg): + MultiMethod(method, method) + + def test_fallback(): fallback = Fallback.from_json(dict(type="fallback", stateMutability="payable")) assert fallback.payable @@ -425,6 +499,45 @@ def test_contract_abi_init(): assert isinstance(cabi.method.writeMethod, Method) +def test_overloaded_methods(): + json_abi = [ + dict( + type="function", + name="readMethod", + stateMutability="view", + inputs=[ + dict(type="uint8", name="a"), + dict(type="bool", name="b"), + ], + outputs=[ + dict(type="uint8", name=""), + ], + ), + dict( + type="function", + name="readMethod", + stateMutability="view", + inputs=[ + dict(type="uint8", name="a"), + ], + outputs=[ + dict(type="uint8", name=""), + ], + ), + ] + + cabi = ContractABI.from_json(json_abi) + assert str(cabi) == ( + "{\n" + " constructor() nonpayable\n" + " function readMethod(uint8 a, bool b) view returns (uint8)\n" + " function readMethod(uint8 a) view returns (uint8)\n" + "}" + ) + + assert isinstance(cabi.method.readMethod, MultiMethod) + + def test_no_constructor(): cabi = ContractABI() assert isinstance(cabi.constructor, Constructor) @@ -438,12 +551,6 @@ def test_contract_abi_errors(): ): abi = ContractABI.from_json([constructor_abi, constructor_abi]) - write_abi = dict(type="function", name="someMethod", stateMutability="payable", inputs=[]) - with pytest.raises( - ValueError, match="JSON ABI contains more than one declarations of `someMethod`" - ): - abi = ContractABI.from_json([write_abi, write_abi]) - fallback_abi = dict(type="fallback", stateMutability="payable") with pytest.raises(ValueError, match="JSON ABI contains more than one fallback declarations"): abi = ContractABI.from_json([fallback_abi, fallback_abi]) diff --git a/tests/test_contract_functionality.py b/tests/test_contract_functionality.py index af371b3..3c219c5 100644 --- a/tests/test_contract_functionality.py +++ b/tests/test_contract_functionality.py @@ -62,6 +62,20 @@ async def test_basics(session, root_signer, another_signer, compiled_contracts): assert result == (inner, outer) +async def test_overloaded_method(session, root_signer, another_signer, compiled_contracts): + compiled_contract = compiled_contracts["Test"] + + # Deploy the contract + call = compiled_contract.constructor(12345, 56789) + deployed_contract = await session.deploy(root_signer, call) + + result = await session.eth_call(deployed_contract.method.overloaded(123)) + assert result == (12345 + 123,) + + result = await session.eth_call(deployed_contract.method.overloaded(123, 456)) + assert result == (456 + 123,) + + async def test_read_only_mode(session, root_signer, compiled_contracts): # Test that a "nonpayable" (that is, mutating) method can still be invoked # via `eth_call`, and it will use the current state of the contract