Skip to content

Commit

Permalink
add Amount type
Browse files Browse the repository at this point in the history
  • Loading branch information
callebtc committed Nov 21, 2023
1 parent 1746a12 commit 7d79d18
Show file tree
Hide file tree
Showing 12 changed files with 132 additions and 65 deletions.
36 changes: 36 additions & 0 deletions cashu/core/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import base64
import json
import math
from dataclasses import dataclass
from enum import Enum
from sqlite3 import Row
from typing import Dict, List, Optional, Union
Expand Down Expand Up @@ -442,6 +444,40 @@ def str(self, amount: int) -> str:
raise Exception("Invalid unit")


@dataclass
class Amount:
unit: Unit
amount: int

def to(self, to_unit: Unit, round: Optional[str] = None):
if self.unit == to_unit:
return self

if self.unit == Unit.sat:
if to_unit == Unit.msat:
return Amount(to_unit, self.amount * 1000)
else:
raise Exception(f"Cannot convert {self.unit.name} to {to_unit.name}")
elif self.unit == Unit.msat:
if to_unit == Unit.sat:
if round == "up":
return Amount(to_unit, math.ceil(self.amount / 1000))
elif round == "down":
return Amount(to_unit, math.floor(self.amount / 1000))
else:
return Amount(to_unit, self.amount // 1000)
else:
raise Exception(f"Cannot convert {self.unit.name} to {to_unit.name}")
else:
return self

def str(self) -> str:
return self.unit.str(self.amount)

def __repr__(self):
return self.unit.str(self.amount)


class Method(Enum):
bolt11 = 0

Expand Down
18 changes: 13 additions & 5 deletions cashu/lightning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from pydantic import BaseModel

from ..core.base import Amount, Unit


class StatusResponse(BaseModel):
error_message: Optional[str]
Expand All @@ -16,8 +18,8 @@ class InvoiceQuoteResponse(BaseModel):

class PaymentQuoteResponse(BaseModel):
checking_id: str
amount: int
fee: int
amount: Amount
fee: Amount


class InvoiceResponse(BaseModel):
Expand All @@ -30,14 +32,14 @@ class InvoiceResponse(BaseModel):
class PaymentResponse(BaseModel):
ok: Optional[bool] = None # True: paid, False: failed, None: pending or unknown
checking_id: Optional[str] = None
fee_msat: Optional[int] = None
fee: Optional[Amount] = None
preimage: Optional[str] = None
error_message: Optional[str] = None


class PaymentStatus(BaseModel):
paid: Optional[bool] = None
fee_msat: Optional[int] = None
fee: Optional[Amount] = None
preimage: Optional[str] = None

@property
Expand All @@ -60,14 +62,20 @@ def __str__(self) -> str:


class LightningBackend(ABC):
units: set[Unit]

def assert_unit_supported(self, unit: Unit):
if unit not in self.units:
raise Unsupported(f"Unit {unit} is not supported")

@abstractmethod
def status(self) -> Coroutine[None, None, StatusResponse]:
pass

@abstractmethod
def create_invoice(
self,
amount: int,
amount: Amount,
memo: Optional[str] = None,
description_hash: Optional[bytes] = None,
) -> Coroutine[None, None, InvoiceResponse]:
Expand Down
27 changes: 15 additions & 12 deletions cashu/lightning/corelightningrest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import json
import math
import random
from typing import AsyncGenerator, Dict, Optional

Expand All @@ -11,6 +10,7 @@
)
from loguru import logger

from ..core.base import Amount, Unit
from ..core.helpers import fee_reserve
from ..core.settings import settings
from .base import (
Expand All @@ -26,6 +26,8 @@


class CoreLightningRestWallet(LightningBackend):
units = set([Unit.sat, Unit.msat])

def __init__(self):
macaroon = settings.mint_corelightning_rest_macaroon
assert macaroon, "missing cln-rest macaroon"
Expand Down Expand Up @@ -88,15 +90,16 @@ async def status(self) -> StatusResponse:

async def create_invoice(
self,
amount: int,
amount: Amount,
memo: Optional[str] = None,
description_hash: Optional[bytes] = None,
unhashed_description: Optional[bytes] = None,
**kwargs,
) -> InvoiceResponse:
self.assert_unit_supported(amount.unit)
label = f"lbl{random.random()}"
data: Dict = {
"amount": amount * 1000,
"amount": amount.to(Unit.msat, round="up").amount,
"description": memo,
"label": label,
}
Expand Down Expand Up @@ -151,7 +154,7 @@ async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse
return PaymentResponse(
ok=False,
checking_id=None,
fee_msat=None,
fee=None,
preimage=None,
error_message=str(exc),
)
Expand All @@ -161,7 +164,7 @@ async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse
return PaymentResponse(
ok=False,
checking_id=None,
fee_msat=None,
fee=None,
preimage=None,
error_message=error_message,
)
Expand All @@ -186,7 +189,7 @@ async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse
return PaymentResponse(
ok=False,
checking_id=None,
fee_msat=None,
fee=None,
preimage=None,
error_message=error_message,
)
Expand All @@ -197,7 +200,7 @@ async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse
return PaymentResponse(
ok=False,
checking_id=None,
fee_msat=None,
fee=None,
preimage=None,
error_message="payment failed",
)
Expand All @@ -209,7 +212,7 @@ async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse
return PaymentResponse(
ok=self.statuses.get(data["status"]),
checking_id=checking_id,
fee_msat=fee_msat,
fee=Amount(unit=Unit.msat, amount=fee_msat) if fee_msat else None,
preimage=preimage,
error_message=None,
)
Expand Down Expand Up @@ -254,7 +257,7 @@ async def get_payment_status(self, checking_id: str) -> PaymentStatus:

return PaymentStatus(
paid=self.statuses.get(pay["status"]),
fee_msat=fee_msat,
fee=Amount(unit=Unit.msat, amount=fee_msat) if fee_msat else None,
preimage=preimage,
)
except Exception as e:
Expand Down Expand Up @@ -309,6 +312,6 @@ async def get_payment_quote(self, bolt11: str) -> PaymentQuoteResponse:
assert invoice_obj.amount_msat, "invoice has no amount."
amount_msat = int(invoice_obj.amount_msat)
fees_msat = fee_reserve(amount_msat)
fee_sat = math.ceil(fees_msat / 1000)
amount_sat = math.ceil(amount_msat / 1000)
return PaymentQuoteResponse(checking_id="", fee=fee_sat, amount=amount_sat)
fees = Amount(unit=Unit.msat, amount=fees_msat)
amount = Amount(unit=Unit.msat, amount=amount_msat)
return PaymentQuoteResponse(checking_id="", fee=fees, amount=amount)
20 changes: 10 additions & 10 deletions cashu/lightning/fake.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import hashlib
import math
import random
from datetime import datetime
from os import urandom
Expand All @@ -15,6 +14,7 @@
encode,
)

from ..core.base import Amount, Unit
from ..core.helpers import fee_reserve
from ..core.settings import settings
from .base import (
Expand All @@ -28,8 +28,7 @@


class FakeWallet(LightningBackend):
"""https://github.com/lnbits/lnbits"""

units = set([Unit.sat, Unit.msat])
queue: asyncio.Queue[Bolt11] = asyncio.Queue(0)
payment_secrets: Dict[str, str] = dict()
paid_invoices: Set[str] = set()
Expand All @@ -47,13 +46,14 @@ async def status(self) -> StatusResponse:

async def create_invoice(
self,
amount: int,
amount: Amount,
memo: Optional[str] = None,
description_hash: Optional[bytes] = None,
unhashed_description: Optional[bytes] = None,
expiry: Optional[int] = None,
payment_secret: Optional[bytes] = None,
) -> InvoiceResponse:
self.assert_unit_supported(amount.unit)
tags = Tags()

if description_hash:
Expand Down Expand Up @@ -83,7 +83,7 @@ async def create_invoice(

bolt11 = Bolt11(
currency="bc",
amount_msat=MilliSatoshi(amount * 1000),
amount_msat=MilliSatoshi(amount.to(Unit.msat, round="up").amount),
date=int(datetime.now().timestamp()),
tags=tags,
)
Expand All @@ -94,7 +94,7 @@ async def create_invoice(
ok=True, checking_id=payment_hash, payment_request=payment_request
)

async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse:
async def pay_invoice(self, bolt11: str, fee_limit: int) -> PaymentResponse:
invoice = decode(bolt11)

if settings.fakewallet_delay_payment:
Expand All @@ -106,7 +106,7 @@ async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse
return PaymentResponse(
ok=True,
checking_id=invoice.payment_hash,
fee_msat=0,
fee=Amount(unit=Unit.msat, amount=0),
preimage=self.payment_secrets.get(invoice.payment_hash) or "0" * 64,
)
else:
Expand Down Expand Up @@ -140,6 +140,6 @@ async def get_payment_quote(self, bolt11: str) -> PaymentQuoteResponse:
assert invoice_obj.amount_msat, "invoice has no amount."
amount_msat = int(invoice_obj.amount_msat)
fees_msat = fee_reserve(amount_msat)
fee_sat = math.ceil(fees_msat / 1000)
amount_sat = math.ceil(amount_msat / 1000)
return PaymentQuoteResponse(checking_id="", fee=fee_sat, amount=amount_sat)
fees = Amount(unit=Unit.msat, amount=fees_msat)
amount = Amount(unit=Unit.msat, amount=amount_msat)
return PaymentQuoteResponse(checking_id="", fee=fees, amount=amount)
16 changes: 10 additions & 6 deletions cashu/lightning/lnbits.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# type: ignore
import math
from typing import Optional

import httpx
from bolt11 import (
decode,
)

from ..core.base import Amount, Unit
from ..core.helpers import fee_reserve
from ..core.settings import settings
from .base import (
Expand All @@ -22,6 +22,8 @@
class LNbitsWallet(LightningBackend):
"""https://github.com/lnbits/lnbits"""

units = set([Unit.sat])

def __init__(self):
self.endpoint = settings.mint_lnbits_endpoint
self.client = httpx.AsyncClient(
Expand Down Expand Up @@ -57,12 +59,14 @@ async def status(self) -> StatusResponse:

async def create_invoice(
self,
amount: int,
amount: Amount,
memo: Optional[str] = None,
description_hash: Optional[bytes] = None,
unhashed_description: Optional[bytes] = None,
) -> InvoiceResponse:
data = {"out": False, "amount": amount}
self.assert_unit_supported(amount.unit)

data = {"out": False, "amount": amount.to(Unit.sat).amount}
if description_hash:
data["description_hash"] = description_hash.hex()
if unhashed_description:
Expand Down Expand Up @@ -153,6 +157,6 @@ async def get_payment_quote(self, bolt11: str) -> PaymentQuoteResponse:
assert invoice_obj.amount_msat, "invoice has no amount."
amount_msat = int(invoice_obj.amount_msat)
fees_msat = fee_reserve(amount_msat)
fee_sat = math.ceil(fees_msat / 1000)
amount_sat = math.ceil(amount_msat / 1000)
return PaymentQuoteResponse(checking_id="", fee=fee_sat, amount=amount_sat)
fees = Amount(unit=Unit.msat, amount=fees_msat)
amount = Amount(unit=Unit.msat, amount=amount_msat)
return PaymentQuoteResponse(checking_id="", fee=fees, amount=amount)
Loading

0 comments on commit 7d79d18

Please sign in to comment.