Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow comparing currency values to currency strings #2149

Merged
merged 2 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/ape/api/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from eth_pydantic_types import HexBytes

from ape.exceptions import ConversionError
from ape.types import AddressType, ContractCode
from ape.types import AddressType, ContractCode, CurrencyValue
from ape.utils import BaseInterface, abstractmethod, cached_property, log_instead_of_fail

if TYPE_CHECKING:
Expand Down Expand Up @@ -119,8 +119,10 @@ def balance(self) -> int:
"""
The total balance of the account.
"""

return self.provider.get_balance(self.address)
bal = self.provider.get_balance(self.address)
# By using CurrencyValue, we can compare with
# strings like "1 ether".
return CurrencyValue(bal)

# @balance.setter
# NOTE: commented out because of failure noted within `__setattr__`
Expand Down
19 changes: 19 additions & 0 deletions src/ape/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
BaseInterfaceModel,
ExtraAttributesMixin,
ExtraModelAttributes,
ManagerAccessMixin,
cached_property,
)
from ape.utils.misc import ZERO_ADDRESS, log_instead_of_fail, to_int
Expand Down Expand Up @@ -470,6 +471,23 @@ def generator(self) -> Iterator:
yield from self._generator


class CurrencyValue(int):
"""
An integer you can compare with currency-value
strings, such as ``"1 ether"``.
"""

def __eq__(self, other: Any) -> bool:
if isinstance(other, int):
return super().__eq__(other)
elif isinstance(other, str):
other_value = ManagerAccessMixin.conversion_manager.convert(other, int)
return super().__eq__(other_value)

# Try from the other end, if hasn't already.
return NotImplemented


__all__ = [
"ABI",
"AddressType",
Expand All @@ -487,6 +505,7 @@ def generator(self) -> Iterator:
"CoverageProject",
"CoverageReport",
"CoverageStatement",
"CurrencyValue",
"GasReport",
"MessageSignature",
"PackageManifest",
Expand Down
7 changes: 4 additions & 3 deletions src/ape_ethereum/_converters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from decimal import Decimal

from ape.api import ConverterAPI
from ape.api.convert import ConverterAPI
from ape.types import CurrencyValue

ETHER_UNITS = {
"eth": int(1e18),
Expand Down Expand Up @@ -35,5 +36,5 @@ def is_convertible(self, value: str) -> bool:

def convert(self, value: str) -> int:
value, unit = value.split(" ")

return int(Decimal(value) * ETHER_UNITS[unit.lower()])
converted_value = int(Decimal(value) * ETHER_UNITS[unit.lower()])
return CurrencyValue(converted_value)
5 changes: 4 additions & 1 deletion tests/functional/conversion/test_ether.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
unit=st.sampled_from(list(ETHER_UNITS.keys())),
)
def test_ether_conversions(value, unit, convert):
actual = convert(f"{value} {unit}", int)
currency_str = f"{value} {unit}"
actual = convert(currency_str, int)
expected = int(value * ETHER_UNITS[unit])
assert actual == expected
# Also show can compare directly with str.
assert actual == currency_str


def test_bad_type(convert):
Expand Down
11 changes: 10 additions & 1 deletion tests/functional/test_address.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from pydantic import BaseModel

from ape.api.address import BaseAddress
from ape.api.address import Address, BaseAddress
from ape.types import AddressType


Expand Down Expand Up @@ -31,3 +31,12 @@ class CustomModel(BaseModel):

model = CustomModel(address=zero_address)
assert model.address == zero_address


def test_balance(zero_address):
address = Address(zero_address)
actual = address.balance
expected = 0
assert actual == expected
# Also show can compare directly to currency-str.
assert actual == "0 ETH"
Loading