Skip to content

Commit

Permalink
Add support for method overloading
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Dec 27, 2023
1 parent 2a6a59c commit 67a9aef
Show file tree
Hide file tree
Showing 7 changed files with 276 additions and 29 deletions.
4 changes: 4 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ Contract ABI
:members:
:special-members: __call__

.. autoclass:: MultiMethod
:members:
:special-members: __call__

.. autoclass:: Event
:members:

Expand Down
2 changes: 2 additions & 0 deletions pons/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
Fallback,
Method,
MethodCall,
MultiMethod,
Mutability,
Receive,
)
Expand Down Expand Up @@ -83,6 +84,7 @@
"JSON",
"Method",
"MethodCall",
"MultiMethod",
"Mutability",
"PriorityFallback",
"ProviderError",
Expand Down
23 changes: 18 additions & 5 deletions pons/_contract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from typing import Any, Dict, List, Optional, Tuple

from ._contract_abi import ContractABI, Error, Event, EventFilter, Method, Methods
from typing import Any, Dict, List, Optional, Tuple, Union

from ._contract_abi import (
ContractABI,
Error,
Event,
EventFilter,
Method,
Methods,
MultiMethod,
)
from ._entities import Address, LogEntry, LogTopic
from ._provider import JSON

Expand Down Expand Up @@ -43,7 +51,12 @@ def __init__(self, contract_abi: ContractABI, data_bytes: bytes, *, payable: boo
class BoundMethod:
"""A regular method bound to a specific contract's address."""

def __init__(self, contract_abi: ContractABI, contract_address: Address, method: Method):
def __init__(
self,
contract_abi: ContractABI,
contract_address: Address,
method: Union[Method, MultiMethod],
):
self._contract_abi = contract_abi
self._contract_address = contract_address
self._method = method
Expand All @@ -52,7 +65,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> "BoundMethodCall":
"""Returns a contract call with encoded arguments bound to a specific address."""
call = self._method(*args, **kwargs)
return BoundMethodCall(
self._contract_abi, self._method, self._contract_address, call.data_bytes
self._contract_abi, call.method, self._contract_address, call.data_bytes
)


Expand Down
135 changes: 117 additions & 18 deletions pons/_contract_abi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
from enum import Enum
from functools import cached_property
from inspect import BoundArguments
from itertools import chain
from keyword import iskeyword
from typing import (
Expand Down Expand Up @@ -83,13 +84,19 @@ def canonical_form(self) -> str:
"""Returns the signature serialized in the canonical form as a string."""
return "(" + ",".join(tp.canonical_form for tp in self._types) + ")"

def bind(self, *args: Any, **kwargs: Any) -> BoundArguments:
return self._signature.bind(*args, **kwargs)

def encode_bound(self, bound_args: BoundArguments) -> bytes:
return encode_args(*zip(self._types, bound_args.args))

def encode(self, *args: Any, **kwargs: Any) -> bytes:
"""
Encodes assorted positional/keyword arguments into the bytestring
according to the ABI format.
"""
bound_args = self._signature.bind(*args, **kwargs)
return encode_args(*zip(self._types, bound_args.args))
bound_args = self.bind(*args, **kwargs)
return self.encode_bound(bound_args)

def decode_into_tuple(self, value_bytes: bytes) -> Tuple[Any, ...]:
"""Decodes the packed bytestring into a list of values."""
Expand Down Expand Up @@ -363,37 +370,116 @@ def __init__(

@property
def name(self) -> str:
"""The name of this method."""
return self._name

@property
def inputs(self) -> Signature:
"""The input signature of this method."""
return self._inputs

def bind(self, *args: Any, **kwargs: Any) -> BoundArguments:
return self._inputs.bind(*args, **kwargs)

def __call__(self, *args: Any, **kwargs: Any) -> "MethodCall":
"""Returns an encoded call with given arguments."""
return MethodCall(self._encode_call(*args, **kwargs))
bound_args = self.bind(*args, **kwargs)
return self.call_bound(bound_args)

def call_bound(self, bound_args: BoundArguments) -> "MethodCall":
input_bytes = self.inputs.encode_bound(bound_args)
encoded = self.selector + input_bytes
return MethodCall(self, encoded)

@cached_property
def selector(self) -> bytes:
"""Method's selector."""
return keccak(self.name.encode() + self.inputs.canonical_form.encode())[:SELECTOR_LENGTH]

def _encode_call(self, *args: Any, **kwargs: Any) -> bytes:
input_bytes = self.inputs.encode(*args, **kwargs)
return self.selector + input_bytes

def decode_output(self, output_bytes: bytes) -> Any:
"""Decodes the output from ABI-packed bytes."""
results = self.outputs.decode_into_tuple(output_bytes)
if self._single_output:
results = results[0]
return results

def with_method(self, method: "Method") -> "MultiMethod":
return MultiMethod(self, method)

def __str__(self) -> str:
returns = "" if self.outputs.empty else f" returns {self.outputs}"
return f"function {self.name}{self.inputs} {self._mutability.value}{returns}"


class MultiMethod:
"""
An overloaded contract method, containing several :py:class:`Method` objects with the same name
but different input signatures.
"""

def __init__(self, *methods: Method):
if len(methods) == 0:
raise ValueError("`methods` cannot be empty")
first_method = methods[0]
self._methods = {first_method.inputs.canonical_form: first_method}
self._name = first_method.name

for method in methods[1:]:
self._add_method(method)

def __getitem__(self, args: str) -> Method:
"""
Returns the :py:class:`Method` with the given canonical form of an input signature
(corresponding to :py:attr:`Signature.canonical_form`).
"""
return self._methods[args]

@property
def name(self) -> str:
"""The name of this method."""
return self._name

@property
def methods(self) -> Dict[str, Method]:
"""All the overloaded methods, indexed by the canonical form of their input signatures."""
return self._methods

def _add_method(self, method: Method) -> None:
if method.name != self.name:
raise ValueError("All overloaded methods must have the same name")
if method.inputs.canonical_form in self._methods:
raise ValueError(
f"A method {self.name}{method.inputs.canonical_form} "
"is already registered in this MultiMethod"
)
self._methods[method.inputs.canonical_form] = method

def with_method(self, method: Method) -> "MultiMethod":
"""Returns a new ``MultiMethod`` with the given method included."""
new_mm = MultiMethod(*self._methods.values())
new_mm._add_method(method) # noqa: SLF001
return new_mm

def __call__(self, *args: Any, **kwds: Any) -> "MethodCall":
"""Returns an encoded call with given arguments."""
for method in self._methods.values():
try:
bound_args = method.bind(*args, **kwds)
except TypeError:
# If it's a non-overloaded method, we do not want to complicate things
if len(self._methods) == 1:
raise

continue

return method.call_bound(bound_args)

raise TypeError("Could not find a suitable overloaded method for the given arguments")

def __str__(self) -> str:
return "; ".join(str(method) for method in self._methods.values())


class Event:
"""
A contract event.
Expand Down Expand Up @@ -601,7 +687,11 @@ class MethodCall:
data_bytes: bytes
"""Encoded call arguments with the selector."""

def __init__(self, data_bytes: bytes):
method: Method
"""The method object that encoded this call."""

def __init__(self, method: Method, data_bytes: bytes):
self.method = method
self.data_bytes = data_bytes


Expand Down Expand Up @@ -661,7 +751,7 @@ class ContractABI:
receive: Optional[Receive]
"""Contract's receive method."""

method: Methods[Method]
method: Methods[Union[Method, MultiMethod]]
"""Contract's regular methods."""

event: Methods[Event]
Expand All @@ -676,7 +766,7 @@ def from_json(cls, json_abi: List[Dict[str, JSON]]) -> "ContractABI": # noqa: C
constructor = None
fallback = None
receive = None
methods = {}
methods: Dict[Any, Union[Method, MultiMethod]] = {}
events = {}
errors = {}

Expand All @@ -687,12 +777,11 @@ def from_json(cls, json_abi: List[Dict[str, JSON]]) -> "ContractABI": # noqa: C
constructor = Constructor.from_json(entry)

elif entry["type"] == "function":
method = Method.from_json(entry)
if entry["name"] in methods:
# TODO (#21): add support for overloaded methods
raise ValueError(
f"JSON ABI contains more than one declarations of `{entry['name']}`"
)
methods[entry["name"]] = Method.from_json(entry)
methods[entry["name"]] = methods[entry["name"]].with_method(method)
else:
methods[entry["name"]] = method

elif entry["type"] == "fallback":
if fallback:
Expand Down Expand Up @@ -735,7 +824,7 @@ def __init__(
constructor: Optional[Constructor] = None,
fallback: Optional[Fallback] = None,
receive: Optional[Receive] = None,
methods: Optional[Iterable[Method]] = None,
methods: Optional[Iterable[Union[Method, MultiMethod]]] = None,
events: Optional[Iterable[Event]] = None,
errors: Optional[Iterable[Error]] = None,
):
Expand Down Expand Up @@ -771,13 +860,23 @@ def resolve_error(self, error_data: bytes) -> Tuple[Error, Dict[str, Any]]:
raise UnknownError(f"Could not find an error with selector {selector.hex()} in the ABI")

def __str__(self) -> str:
all_methods: Iterable[Union[Constructor, Fallback, Receive, Method, Event, Error]] = chain(
all_methods: Iterable[
Union[Constructor, Fallback, Receive, Method, MultiMethod, Event, Error]
] = chain(
[self.constructor] if self.constructor else [],
[self.fallback] if self.fallback else [],
[self.receive] if self.receive else [],
self.method,
self.event,
self.error,
)
method_list = [" " + str(method) for method in all_methods]

indent = " "

def to_str(item: Any) -> str:
if isinstance(item, MultiMethod):
return ("\n" + indent).join(str(method) for method in item.methods.values())
return str(item)

method_list = [indent + to_str(method) for method in all_methods]
return "{\n" + "\n".join(method_list) + "\n}"
8 changes: 8 additions & 0 deletions tests/TestContractFunctionality.sol
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ contract Test {
return v1 + _x;
}

function overloaded(uint256 _x, uint256 _y) public view returns (uint256) {
return _y + _x;
}

function overloaded(uint256 _x) public view returns (uint256) {
return v1 + _x;
}

struct Inner {
uint256 inner1;
uint256 inner2;
Expand Down
Loading

0 comments on commit 67a9aef

Please sign in to comment.