Skip to content

Commit

Permalink
ADd prepare_calldata to overloaded functions
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielSchiavini committed Feb 19, 2024
1 parent 3d2ae39 commit f81d850
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions boa/contracts/abi/abi_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def is_encodable(self, *args, **kwargs) -> bool:
for abi_type, arg in zip(self.argument_types, parsed_args)
)

def prepare_calldata(self, *args, **kwargs) -> bytes:
"""Prepare the call data for the function call."""
abi_args = self._merge_kwargs(*args, **kwargs)
return self.method_id + abi_encode(self.signature, abi_args)

def _merge_kwargs(self, *args, **kwargs) -> list:
"""Merge positional and keyword arguments into a single list."""
if len(kwargs) + len(args) != self.argument_count:
Expand Down Expand Up @@ -120,10 +125,6 @@ def __call__(self, *args, value=0, gas=None, sender=None, **kwargs):
case multiple:
return tuple(multiple)

def prepare_calldata(self, *args, **kwargs):
args = self._merge_kwargs(*args, **kwargs)
return self.method_id + abi_encode(self.signature, args)


class ABIOverload:
"""
Expand Down Expand Up @@ -154,6 +155,13 @@ def __init__(self, functions: list[ABIFunction]):
def name(self) -> str:
return self.functions[0].name

def prepare_calldata(self, *args, disambiguate_signature=None, **kwargs) -> bytes:
"""Prepare the calldata for the function that matches the given arguments."""
function = self._pick_overload(
*args, disambiguate_signature=disambiguate_signature, **kwargs
)
return function.prepare_calldata(*args, **kwargs)

def __call__(
self,
*args,
Expand All @@ -167,6 +175,15 @@ def __call__(
Call the function that matches the given arguments.
:raises Exception: if a single function is not found
"""
function = self._pick_overload(
*args, disambiguate_signature=disambiguate_signature, **kwargs
)
return function(*args, value=value, gas=gas, sender=sender, **kwargs)

def _pick_overload(
self, *args, disambiguate_signature=None, **kwargs
) -> ABIFunction:
"""Pick the function that matches the given arguments."""
if disambiguate_signature is None:
matches = [f for f in self.functions if f.is_encodable(*args, **kwargs)]
else:
Expand All @@ -177,7 +194,7 @@ def __call__(

match matches:
case [function]:
return function(*args, value=value, gas=gas, sender=sender, **kwargs)
return function
case []:
raise Exception(
f"Could not find matching {self.name} function for given arguments."
Expand Down

0 comments on commit f81d850

Please sign in to comment.