From f81d850747ce0c227d5311994298c2ba1ced34f2 Mon Sep 17 00:00:00 2001 From: Daniel Schiavini Date: Mon, 19 Feb 2024 13:58:01 +0100 Subject: [PATCH] ADd prepare_calldata to overloaded functions --- boa/contracts/abi/abi_contract.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/boa/contracts/abi/abi_contract.py b/boa/contracts/abi/abi_contract.py index 77d98edb..e1b14f7b 100644 --- a/boa/contracts/abi/abi_contract.py +++ b/boa/contracts/abi/abi_contract.py @@ -83,6 +83,11 @@ def is_encodable(self, *args, **kwargs) -> bool: for abi_type, arg in zip(self.argument_types, parsed_args) ) + def prepare_calldata(self, *args, **kwargs) -> bytes: + """Prepare the call data for the function call.""" + abi_args = self._merge_kwargs(*args, **kwargs) + return self.method_id + abi_encode(self.signature, abi_args) + def _merge_kwargs(self, *args, **kwargs) -> list: """Merge positional and keyword arguments into a single list.""" if len(kwargs) + len(args) != self.argument_count: @@ -120,10 +125,6 @@ def __call__(self, *args, value=0, gas=None, sender=None, **kwargs): case multiple: return tuple(multiple) - def prepare_calldata(self, *args, **kwargs): - args = self._merge_kwargs(*args, **kwargs) - return self.method_id + abi_encode(self.signature, args) - class ABIOverload: """ @@ -154,6 +155,13 @@ def __init__(self, functions: list[ABIFunction]): def name(self) -> str: return self.functions[0].name + def prepare_calldata(self, *args, disambiguate_signature=None, **kwargs) -> bytes: + """Prepare the calldata for the function that matches the given arguments.""" + function = self._pick_overload( + *args, disambiguate_signature=disambiguate_signature, **kwargs + ) + return function.prepare_calldata(*args, **kwargs) + def __call__( self, *args, @@ -167,6 +175,15 @@ def __call__( Call the function that matches the given arguments. :raises Exception: if a single function is not found """ + function = self._pick_overload( + *args, disambiguate_signature=disambiguate_signature, **kwargs + ) + return function(*args, value=value, gas=gas, sender=sender, **kwargs) + + def _pick_overload( + self, *args, disambiguate_signature=None, **kwargs + ) -> ABIFunction: + """Pick the function that matches the given arguments.""" if disambiguate_signature is None: matches = [f for f in self.functions if f.is_encodable(*args, **kwargs)] else: @@ -177,7 +194,7 @@ def __call__( match matches: case [function]: - return function(*args, value=value, gas=gas, sender=sender, **kwargs) + return function case []: raise Exception( f"Could not find matching {self.name} function for given arguments."