diff --git a/cashu/core/base.py b/cashu/core/base.py index 204d3ea0..88c83552 100644 --- a/cashu/core/base.py +++ b/cashu/core/base.py @@ -4,10 +4,10 @@ from dataclasses import dataclass from enum import Enum from sqlite3 import Row -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union from loguru import logger -from pydantic import BaseModel, Field +from pydantic import BaseModel from .crypto.aes import AESCipher from .crypto.b_dhke import hash_to_curve @@ -45,6 +45,21 @@ class DLEQWallet(BaseModel): # ------- PROOFS ------- +class SpentState(Enum): + unspent = "UNSPENT" + spent = "SPENT" + pending = "PENDING" + + def __str__(self): + return self.name + + +class ProofState(BaseModel): + Y: str + state: SpentState + witness: Optional[str] = None + + class HTLCWitness(BaseModel): preimage: Optional[str] = None signature: Optional[str] = None @@ -85,8 +100,7 @@ class Proof(BaseModel): Value token """ - # NOTE: None for backwards compatibility for old clients that do not include the keyset id < 0.3 - id: Union[None, str] = "" + id: str = "" amount: int = 0 secret: str = "" # secret or message to be blinded and signed Y: str = "" # hash_to_curve(secret) @@ -199,11 +213,6 @@ def p2pksigs(self) -> List[str]: return P2PKWitness.from_witness(self.witness).signatures -class BlindedMessages(BaseModel): - # NOTE: not used in Pydantic validation - __root__: List[BlindedMessage] = [] - - class BlindedSignature(BaseModel): """ Blinded signature or "promise" which is the signature on a `BlindedMessage` @@ -321,274 +330,6 @@ def from_row(cls, row: Row): ) -# ------- API ------- - -# ------- API: INFO ------- - - -class MintMeltMethodSetting(BaseModel): - method: str - unit: str - min_amount: Optional[int] = None - max_amount: Optional[int] = None - - -class GetInfoResponse(BaseModel): - name: Optional[str] = None - pubkey: Optional[str] = None - version: Optional[str] = None - description: Optional[str] = None - description_long: Optional[str] = None - contact: Optional[List[List[str]]] = None - motd: Optional[str] = None - nuts: Optional[Dict[int, Any]] = None - - -class Nut15MppSupport(BaseModel): - method: str - unit: str - mpp: bool - - -class GetInfoResponse_deprecated(BaseModel): - name: Optional[str] = None - pubkey: Optional[str] = None - version: Optional[str] = None - description: Optional[str] = None - description_long: Optional[str] = None - contact: Optional[List[List[str]]] = None - nuts: Optional[List[str]] = None - motd: Optional[str] = None - parameter: Optional[dict] = None - - -# ------- API: KEYS ------- - - -class KeysResponseKeyset(BaseModel): - id: str - unit: str - keys: Dict[int, str] - - -class KeysResponse(BaseModel): - keysets: List[KeysResponseKeyset] - - -class KeysetsResponseKeyset(BaseModel): - id: str - unit: str - active: bool - - -class KeysetsResponse(BaseModel): - keysets: list[KeysetsResponseKeyset] - - -class KeysResponse_deprecated(BaseModel): - __root__: Dict[str, str] - - -class KeysetsResponse_deprecated(BaseModel): - keysets: list[str] - - -# ------- API: MINT QUOTE ------- - - -class PostMintQuoteRequest(BaseModel): - unit: str = Field(..., max_length=settings.mint_max_request_length) # output unit - amount: int = Field(..., gt=0) # output amount - - -class PostMintQuoteResponse(BaseModel): - quote: str # quote id - request: str # input payment request - paid: bool # whether the request has been paid - expiry: Optional[int] # expiry of the quote - - -# ------- API: MINT ------- - - -class PostMintRequest(BaseModel): - quote: str = Field(..., max_length=settings.mint_max_request_length) # quote id - outputs: List[BlindedMessage] = Field( - ..., max_items=settings.mint_max_request_length - ) - - -class PostMintResponse(BaseModel): - signatures: List[BlindedSignature] = [] - - -class GetMintResponse_deprecated(BaseModel): - pr: str - hash: str - - -class PostMintRequest_deprecated(BaseModel): - outputs: List[BlindedMessage_Deprecated] = Field( - ..., max_items=settings.mint_max_request_length - ) - - -class PostMintResponse_deprecated(BaseModel): - promises: List[BlindedSignature] = [] - - -# ------- API: MELT QUOTE ------- - - -class PostMeltQuoteRequest(BaseModel): - unit: str = Field(..., max_length=settings.mint_max_request_length) # input unit - request: str = Field( - ..., max_length=settings.mint_max_request_length - ) # output payment request - amount: Optional[int] = Field(default=None, gt=0) # input amount - - -class PostMeltQuoteResponse(BaseModel): - quote: str # quote id - amount: int # input amount - fee_reserve: int # input fee reserve - paid: bool # whether the request has been paid - expiry: Optional[int] # expiry of the quote - - -# ------- API: MELT ------- - - -class PostMeltRequest(BaseModel): - quote: str = Field(..., max_length=settings.mint_max_request_length) # quote id - inputs: List[Proof] = Field(..., max_items=settings.mint_max_request_length) - outputs: Union[List[BlindedMessage], None] = Field( - None, max_items=settings.mint_max_request_length - ) - - -class PostMeltResponse(BaseModel): - paid: Union[bool, None] - payment_preimage: Union[str, None] - change: Union[List[BlindedSignature], None] = None - - -class PostMeltRequest_deprecated(BaseModel): - proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length) - pr: str = Field(..., max_length=settings.mint_max_request_length) - outputs: Union[List[BlindedMessage_Deprecated], None] = Field( - None, max_items=settings.mint_max_request_length - ) - - -class PostMeltResponse_deprecated(BaseModel): - paid: Union[bool, None] - preimage: Union[str, None] - change: Union[List[BlindedSignature], None] = None - - -# ------- API: SPLIT ------- - - -class PostSplitRequest(BaseModel): - inputs: List[Proof] = Field(..., max_items=settings.mint_max_request_length) - outputs: List[BlindedMessage] = Field( - ..., max_items=settings.mint_max_request_length - ) - - -class PostSplitResponse(BaseModel): - signatures: List[BlindedSignature] - - -# deprecated since 0.13.0 -class PostSplitRequest_Deprecated(BaseModel): - proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length) - amount: Optional[int] = None - outputs: List[BlindedMessage_Deprecated] = Field( - ..., max_items=settings.mint_max_request_length - ) - - -class PostSplitResponse_Deprecated(BaseModel): - promises: List[BlindedSignature] = [] - - -class PostSplitResponse_Very_Deprecated(BaseModel): - fst: List[BlindedSignature] = [] - snd: List[BlindedSignature] = [] - deprecated: str = "The amount field is deprecated since 0.13.0" - - -# ------- API: CHECK ------- - - -class PostCheckStateRequest(BaseModel): - Ys: List[str] = Field(..., max_items=settings.mint_max_request_length) - - -class SpentState(Enum): - unspent = "UNSPENT" - spent = "SPENT" - pending = "PENDING" - - def __str__(self): - return self.name - - -class ProofState(BaseModel): - Y: str - state: SpentState - witness: Optional[str] = None - - -class PostCheckStateResponse(BaseModel): - states: List[ProofState] = [] - - -class CheckSpendableRequest_deprecated(BaseModel): - proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length) - - -class CheckSpendableResponse_deprecated(BaseModel): - spendable: List[bool] - pending: List[bool] - - -class CheckFeesRequest_deprecated(BaseModel): - pr: str = Field(..., max_length=settings.mint_max_request_length) - - -class CheckFeesResponse_deprecated(BaseModel): - fee: Union[int, None] - - -# ------- API: RESTORE ------- - - -class PostRestoreRequest(BaseModel): - outputs: List[BlindedMessage] = Field( - ..., max_items=settings.mint_max_request_length - ) - - -class PostRestoreRequest_Deprecated(BaseModel): - outputs: List[BlindedMessage_Deprecated] = Field( - ..., max_items=settings.mint_max_request_length - ) - - -class PostRestoreResponse(BaseModel): - outputs: List[BlindedMessage] = [] - signatures: List[BlindedSignature] = [] - promises: Optional[List[BlindedSignature]] = [] # deprecated since 0.15.1 - - # duplicate value of "signatures" for backwards compatibility with old clients < 0.15.1 - def __init__(self, **data): - super().__init__(**data) - self.promises = self.signatures - - # ------- KEYSETS ------- @@ -672,6 +413,7 @@ class WalletKeyset: valid_to: Union[str, None] = None first_seen: Union[str, None] = None active: Union[bool, None] = True + input_fee_ppk: int = 0 def __init__( self, @@ -683,13 +425,14 @@ def __init__( valid_to=None, first_seen=None, active=True, - use_deprecated_id=False, # BACKWARDS COMPATIBILITY < 0.15.0 + input_fee_ppk=0, ): self.valid_from = valid_from self.valid_to = valid_to self.first_seen = first_seen self.active = active self.mint_url = mint_url + self.input_fee_ppk = input_fee_ppk self.public_keys = public_keys # overwrite id by deriving it from the public keys @@ -698,19 +441,9 @@ def __init__( else: self.id = id - # BEGIN BACKWARDS COMPATIBILITY < 0.15.0 - if use_deprecated_id: - logger.warning( - "Using deprecated keyset id derivation for backwards compatibility <" - " 0.15.0" - ) - self.id = derive_keyset_id_deprecated(self.public_keys) - # END BACKWARDS COMPATIBILITY < 0.15.0 - self.unit = Unit[unit] - logger.trace(f"Derived keyset id {self.id} from public keys.") - if id and id != self.id and use_deprecated_id: + if id and id != self.id: logger.warning( f"WARNING: Keyset id {self.id} does not match the given id {id}." " Overwriting." @@ -743,6 +476,7 @@ def deserialize(serialized: str) -> Dict[int, PublicKey]: valid_to=row["valid_to"], first_seen=row["first_seen"], active=row["active"], + input_fee_ppk=row["input_fee_ppk"], ) @@ -756,6 +490,7 @@ class MintKeyset: active: bool unit: Unit derivation_path: str + input_fee_ppk: int seed: Optional[str] = None encrypted_seed: Optional[str] = None seed_encryption_method: Optional[str] = None @@ -780,6 +515,7 @@ def __init__( active: Optional[bool] = None, unit: Optional[str] = None, version: Optional[str] = None, + input_fee_ppk: Optional[int] = None, id: str = "", ): self.derivation_path = derivation_path @@ -801,6 +537,10 @@ def __init__( self.first_seen = first_seen self.active = bool(active) if active is not None else False self.version = version or settings.version + self.input_fee_ppk = input_fee_ppk or 0 + + if self.input_fee_ppk < 0: + raise Exception("Input fee must be non-negative.") self.version_tuple = tuple( [int(i) for i in self.version.split(".")] if self.version else [] @@ -930,11 +670,14 @@ class TokenV3(BaseModel): token: List[TokenV3Token] = [] memo: Optional[str] = None + unit: Optional[str] = None def to_dict(self, include_dleq=False): return_dict = dict(token=[t.to_dict(include_dleq) for t in self.token]) if self.memo: return_dict.update(dict(memo=self.memo)) # type: ignore + if self.unit: + return_dict.update(dict(unit=self.unit)) # type: ignore return return_dict def get_proofs(self): diff --git a/cashu/core/errors.py b/cashu/core/errors.py index 96a9c263..36700acf 100644 --- a/cashu/core/errors.py +++ b/cashu/core/errors.py @@ -35,12 +35,18 @@ def __init__(self): super().__init__(self.detail, code=self.code) +class TransactionNotBalancedError(TransactionError): + code = 11002 + + def __init__(self, detail): + super().__init__(detail, code=self.code) + + class SecretTooLongError(TransactionError): - detail = "secret too long" code = 11003 - def __init__(self): - super().__init__(self.detail, code=self.code) + def __init__(self, detail="secret too long"): + super().__init__(detail, code=self.code) class NoSecretInProofsError(TransactionError): @@ -51,6 +57,13 @@ def __init__(self): super().__init__(self.detail, code=self.code) +class TransactionUnitError(TransactionError): + code = 11005 + + def __init__(self, detail): + super().__init__(detail, code=self.code) + + class KeysetError(CashuError): detail = "keyset error" code = 12000 diff --git a/cashu/core/helpers.py b/cashu/core/helpers.py index ff43e225..f3f3f0ff 100644 --- a/cashu/core/helpers.py +++ b/cashu/core/helpers.py @@ -3,10 +3,21 @@ from functools import partial, wraps from typing import List -from ..core.base import BlindedSignature, Proof +from ..core.base import Amount, BlindedSignature, Proof, Unit from ..core.settings import settings +def amount_summary(proofs: List[Proof], unit: Unit) -> str: + amounts_we_have = [ + (amount, len([p for p in proofs if p.amount == amount])) + for amount in set([p.amount for p in proofs]) + ] + amounts_we_have.sort(key=lambda x: x[0]) + return ( + f"{', '.join([f'{Amount(unit, a).str()} ({c}x)' for a, c in amounts_we_have])}" + ) + + def sum_proofs(proofs: List[Proof]): return sum([p.amount for p in proofs]) diff --git a/cashu/core/models.py b/cashu/core/models.py new file mode 100644 index 00000000..81711cdc --- /dev/null +++ b/cashu/core/models.py @@ -0,0 +1,265 @@ +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, Field + +from .base import ( + BlindedMessage, + BlindedMessage_Deprecated, + BlindedSignature, + Proof, + ProofState, +) +from .settings import settings + +# ------- API ------- + +# ------- API: INFO ------- + + +class MintMeltMethodSetting(BaseModel): + method: str + unit: str + min_amount: Optional[int] = None + max_amount: Optional[int] = None + + +class GetInfoResponse(BaseModel): + name: Optional[str] = None + pubkey: Optional[str] = None + version: Optional[str] = None + description: Optional[str] = None + description_long: Optional[str] = None + contact: Optional[List[List[str]]] = None + motd: Optional[str] = None + nuts: Optional[Dict[int, Any]] = None + + +class Nut15MppSupport(BaseModel): + method: str + unit: str + mpp: bool + + +class GetInfoResponse_deprecated(BaseModel): + name: Optional[str] = None + pubkey: Optional[str] = None + version: Optional[str] = None + description: Optional[str] = None + description_long: Optional[str] = None + contact: Optional[List[List[str]]] = None + nuts: Optional[List[str]] = None + motd: Optional[str] = None + parameter: Optional[dict] = None + + +# ------- API: KEYS ------- + + +class KeysResponseKeyset(BaseModel): + id: str + unit: str + keys: Dict[int, str] + + +class KeysResponse(BaseModel): + keysets: List[KeysResponseKeyset] + + +class KeysetsResponseKeyset(BaseModel): + id: str + unit: str + active: bool + input_fee_ppk: Optional[int] = None + + +class KeysetsResponse(BaseModel): + keysets: list[KeysetsResponseKeyset] + + +class KeysResponse_deprecated(BaseModel): + __root__: Dict[str, str] + + +class KeysetsResponse_deprecated(BaseModel): + keysets: list[str] + + +# ------- API: MINT QUOTE ------- + + +class PostMintQuoteRequest(BaseModel): + unit: str = Field(..., max_length=settings.mint_max_request_length) # output unit + amount: int = Field(..., gt=0) # output amount + + +class PostMintQuoteResponse(BaseModel): + quote: str # quote id + request: str # input payment request + paid: bool # whether the request has been paid + expiry: Optional[int] # expiry of the quote + + +# ------- API: MINT ------- + + +class PostMintRequest(BaseModel): + quote: str = Field(..., max_length=settings.mint_max_request_length) # quote id + outputs: List[BlindedMessage] = Field( + ..., max_items=settings.mint_max_request_length + ) + + +class PostMintResponse(BaseModel): + signatures: List[BlindedSignature] = [] + + +class GetMintResponse_deprecated(BaseModel): + pr: str + hash: str + + +class PostMintRequest_deprecated(BaseModel): + outputs: List[BlindedMessage_Deprecated] = Field( + ..., max_items=settings.mint_max_request_length + ) + + +class PostMintResponse_deprecated(BaseModel): + promises: List[BlindedSignature] = [] + + +# ------- API: MELT QUOTE ------- + + +class PostMeltQuoteRequest(BaseModel): + unit: str = Field(..., max_length=settings.mint_max_request_length) # input unit + request: str = Field( + ..., max_length=settings.mint_max_request_length + ) # output payment request + amount: Optional[int] = Field(default=None, gt=0) # input amount + + +class PostMeltQuoteResponse(BaseModel): + quote: str # quote id + amount: int # input amount + fee_reserve: int # input fee reserve + paid: bool # whether the request has been paid + expiry: Optional[int] # expiry of the quote + + +# ------- API: MELT ------- + + +class PostMeltRequest(BaseModel): + quote: str = Field(..., max_length=settings.mint_max_request_length) # quote id + inputs: List[Proof] = Field(..., max_items=settings.mint_max_request_length) + outputs: Union[List[BlindedMessage], None] = Field( + None, max_items=settings.mint_max_request_length + ) + + +class PostMeltResponse(BaseModel): + paid: Union[bool, None] + payment_preimage: Union[str, None] + change: Union[List[BlindedSignature], None] = None + + +class PostMeltRequest_deprecated(BaseModel): + proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length) + pr: str = Field(..., max_length=settings.mint_max_request_length) + outputs: Union[List[BlindedMessage_Deprecated], None] = Field( + None, max_items=settings.mint_max_request_length + ) + + +class PostMeltResponse_deprecated(BaseModel): + paid: Union[bool, None] + preimage: Union[str, None] + change: Union[List[BlindedSignature], None] = None + + +# ------- API: SPLIT ------- + + +class PostSplitRequest(BaseModel): + inputs: List[Proof] = Field(..., max_items=settings.mint_max_request_length) + outputs: List[BlindedMessage] = Field( + ..., max_items=settings.mint_max_request_length + ) + + +class PostSplitResponse(BaseModel): + signatures: List[BlindedSignature] + + +# deprecated since 0.13.0 +class PostSplitRequest_Deprecated(BaseModel): + proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length) + amount: Optional[int] = None + outputs: List[BlindedMessage_Deprecated] = Field( + ..., max_items=settings.mint_max_request_length + ) + + +class PostSplitResponse_Deprecated(BaseModel): + promises: List[BlindedSignature] = [] + + +class PostSplitResponse_Very_Deprecated(BaseModel): + fst: List[BlindedSignature] = [] + snd: List[BlindedSignature] = [] + deprecated: str = "The amount field is deprecated since 0.13.0" + + +# ------- API: CHECK ------- + + +class PostCheckStateRequest(BaseModel): + Ys: List[str] = Field(..., max_items=settings.mint_max_request_length) + + +class PostCheckStateResponse(BaseModel): + states: List[ProofState] = [] + + +class CheckSpendableRequest_deprecated(BaseModel): + proofs: List[Proof] = Field(..., max_items=settings.mint_max_request_length) + + +class CheckSpendableResponse_deprecated(BaseModel): + spendable: List[bool] + pending: List[bool] + + +class CheckFeesRequest_deprecated(BaseModel): + pr: str = Field(..., max_length=settings.mint_max_request_length) + + +class CheckFeesResponse_deprecated(BaseModel): + fee: Union[int, None] + + +# ------- API: RESTORE ------- + + +class PostRestoreRequest(BaseModel): + outputs: List[BlindedMessage] = Field( + ..., max_items=settings.mint_max_request_length + ) + + +class PostRestoreRequest_Deprecated(BaseModel): + outputs: List[BlindedMessage_Deprecated] = Field( + ..., max_items=settings.mint_max_request_length + ) + + +class PostRestoreResponse(BaseModel): + outputs: List[BlindedMessage] = [] + signatures: List[BlindedSignature] = [] + promises: Optional[List[BlindedSignature]] = [] # deprecated since 0.15.1 + + # duplicate value of "signatures" for backwards compatibility with old clients < 0.15.1 + def __init__(self, **data): + super().__init__(**data) + self.promises = self.signatures diff --git a/cashu/core/settings.py b/cashu/core/settings.py index f28bb4b4..205c3099 100644 --- a/cashu/core/settings.py +++ b/cashu/core/settings.py @@ -58,6 +58,9 @@ class MintSettings(CashuSettings): mint_database: str = Field(default="data/mint") mint_test_database: str = Field(default="test_data/test_mint") + mint_max_secret_length: int = Field(default=512) + + mint_input_fee_ppk: int = Field(default=0) class MintBackends(MintSettings): @@ -170,6 +173,8 @@ class WalletSettings(CashuSettings): locktime_delta_seconds: int = Field(default=86400) # 1 day proofs_batch_size: int = Field(default=1000) + wallet_target_amount_count: int = Field(default=3) + class LndRestFundingSource(MintSettings): mint_lnd_rest_endpoint: Optional[str] = Field(default=None) diff --git a/cashu/lightning/base.py b/cashu/lightning/base.py index 8d35128e..afde6786 100644 --- a/cashu/lightning/base.py +++ b/cashu/lightning/base.py @@ -6,9 +6,9 @@ from ..core.base import ( Amount, MeltQuote, - PostMeltQuoteRequest, Unit, ) +from ..core.models import PostMeltQuoteRequest class StatusResponse(BaseModel): diff --git a/cashu/lightning/blink.py b/cashu/lightning/blink.py index e7ed2c60..42f6016a 100644 --- a/cashu/lightning/blink.py +++ b/cashu/lightning/blink.py @@ -11,7 +11,8 @@ ) from loguru import logger -from ..core.base import Amount, MeltQuote, PostMeltQuoteRequest, Unit +from ..core.base import Amount, MeltQuote, Unit +from ..core.models import PostMeltQuoteRequest from ..core.settings import settings from .base import ( InvoiceResponse, diff --git a/cashu/lightning/corelightningrest.py b/cashu/lightning/corelightningrest.py index d2fbbf31..ccc06772 100644 --- a/cashu/lightning/corelightningrest.py +++ b/cashu/lightning/corelightningrest.py @@ -10,8 +10,9 @@ ) from loguru import logger -from ..core.base import Amount, MeltQuote, PostMeltQuoteRequest, Unit +from ..core.base import Amount, MeltQuote, Unit from ..core.helpers import fee_reserve +from ..core.models import PostMeltQuoteRequest from ..core.settings import settings from .base import ( InvoiceResponse, diff --git a/cashu/lightning/fake.py b/cashu/lightning/fake.py index 5a8bcdcf..9ad5682c 100644 --- a/cashu/lightning/fake.py +++ b/cashu/lightning/fake.py @@ -15,8 +15,9 @@ encode, ) -from ..core.base import Amount, MeltQuote, PostMeltQuoteRequest, Unit +from ..core.base import Amount, MeltQuote, Unit from ..core.helpers import fee_reserve +from ..core.models import PostMeltQuoteRequest from ..core.settings import settings from .base import ( InvoiceResponse, diff --git a/cashu/lightning/lnbits.py b/cashu/lightning/lnbits.py index 35894a30..721e7046 100644 --- a/cashu/lightning/lnbits.py +++ b/cashu/lightning/lnbits.py @@ -6,8 +6,9 @@ decode, ) -from ..core.base import Amount, MeltQuote, PostMeltQuoteRequest, Unit +from ..core.base import Amount, MeltQuote, Unit from ..core.helpers import fee_reserve +from ..core.models import PostMeltQuoteRequest from ..core.settings import settings from .base import ( InvoiceResponse, diff --git a/cashu/lightning/lndrest.py b/cashu/lightning/lndrest.py index 04d6bc39..187912ff 100644 --- a/cashu/lightning/lndrest.py +++ b/cashu/lightning/lndrest.py @@ -12,8 +12,9 @@ ) from loguru import logger -from ..core.base import Amount, MeltQuote, PostMeltQuoteRequest, Unit +from ..core.base import Amount, MeltQuote, Unit from ..core.helpers import fee_reserve +from ..core.models import PostMeltQuoteRequest from ..core.settings import settings from .base import ( InvoiceResponse, diff --git a/cashu/lightning/strike.py b/cashu/lightning/strike.py index 7149a582..41c8d718 100644 --- a/cashu/lightning/strike.py +++ b/cashu/lightning/strike.py @@ -4,7 +4,8 @@ import httpx -from ..core.base import Amount, MeltQuote, PostMeltQuoteRequest, Unit +from ..core.base import Amount, MeltQuote, Unit +from ..core.models import PostMeltQuoteRequest from ..core.settings import settings from .base import ( InvoiceResponse, diff --git a/cashu/mint/crud.py b/cashu/mint/crud.py index 30d30b1c..a5d5a71a 100644 --- a/cashu/mint/crud.py +++ b/cashu/mint/crud.py @@ -34,7 +34,8 @@ async def get_keyset( derivation_path: str = "", seed: str = "", conn: Optional[Connection] = None, - ) -> List[MintKeyset]: ... + ) -> List[MintKeyset]: + ... @abstractmethod async def get_spent_proofs( @@ -42,7 +43,8 @@ async def get_spent_proofs( *, db: Database, conn: Optional[Connection] = None, - ) -> List[Proof]: ... + ) -> List[Proof]: + ... async def get_proof_used( self, @@ -50,7 +52,8 @@ async def get_proof_used( Y: str, db: Database, conn: Optional[Connection] = None, - ) -> Optional[Proof]: ... + ) -> Optional[Proof]: + ... @abstractmethod async def invalidate_proof( @@ -60,7 +63,8 @@ async def invalidate_proof( proof: Proof, quote_id: Optional[str] = None, conn: Optional[Connection] = None, - ) -> None: ... + ) -> None: + ... @abstractmethod async def get_all_melt_quotes_from_pending_proofs( @@ -68,7 +72,8 @@ async def get_all_melt_quotes_from_pending_proofs( *, db: Database, conn: Optional[Connection] = None, - ) -> List[MeltQuote]: ... + ) -> List[MeltQuote]: + ... @abstractmethod async def get_pending_proofs_for_quote( @@ -77,7 +82,8 @@ async def get_pending_proofs_for_quote( quote_id: str, db: Database, conn: Optional[Connection] = None, - ) -> List[Proof]: ... + ) -> List[Proof]: + ... @abstractmethod async def get_proofs_pending( @@ -86,7 +92,8 @@ async def get_proofs_pending( Ys: List[str], db: Database, conn: Optional[Connection] = None, - ) -> List[Proof]: ... + ) -> List[Proof]: + ... @abstractmethod async def set_proof_pending( @@ -96,7 +103,8 @@ async def set_proof_pending( proof: Proof, quote_id: Optional[str] = None, conn: Optional[Connection] = None, - ) -> None: ... + ) -> None: + ... @abstractmethod async def unset_proof_pending( @@ -105,7 +113,8 @@ async def unset_proof_pending( proof: Proof, db: Database, conn: Optional[Connection] = None, - ) -> None: ... + ) -> None: + ... @abstractmethod async def store_keyset( @@ -114,14 +123,16 @@ async def store_keyset( db: Database, keyset: MintKeyset, conn: Optional[Connection] = None, - ) -> None: ... + ) -> None: + ... @abstractmethod async def get_balance( self, db: Database, conn: Optional[Connection] = None, - ) -> int: ... + ) -> int: + ... @abstractmethod async def store_promise( @@ -135,7 +146,8 @@ async def store_promise( e: str = "", s: str = "", conn: Optional[Connection] = None, - ) -> None: ... + ) -> None: + ... @abstractmethod async def get_promise( @@ -144,7 +156,8 @@ async def get_promise( db: Database, b_: str, conn: Optional[Connection] = None, - ) -> Optional[BlindedSignature]: ... + ) -> Optional[BlindedSignature]: + ... @abstractmethod async def store_mint_quote( @@ -153,7 +166,8 @@ async def store_mint_quote( quote: MintQuote, db: Database, conn: Optional[Connection] = None, - ) -> None: ... + ) -> None: + ... @abstractmethod async def get_mint_quote( @@ -162,7 +176,8 @@ async def get_mint_quote( quote_id: str, db: Database, conn: Optional[Connection] = None, - ) -> Optional[MintQuote]: ... + ) -> Optional[MintQuote]: + ... @abstractmethod async def get_mint_quote_by_request( @@ -171,7 +186,8 @@ async def get_mint_quote_by_request( request: str, db: Database, conn: Optional[Connection] = None, - ) -> Optional[MintQuote]: ... + ) -> Optional[MintQuote]: + ... @abstractmethod async def update_mint_quote( @@ -180,7 +196,8 @@ async def update_mint_quote( quote: MintQuote, db: Database, conn: Optional[Connection] = None, - ) -> None: ... + ) -> None: + ... # @abstractmethod # async def update_mint_quote_paid( @@ -199,7 +216,8 @@ async def store_melt_quote( quote: MeltQuote, db: Database, conn: Optional[Connection] = None, - ) -> None: ... + ) -> None: + ... @abstractmethod async def get_melt_quote( @@ -209,7 +227,8 @@ async def get_melt_quote( db: Database, checking_id: Optional[str] = None, conn: Optional[Connection] = None, - ) -> Optional[MeltQuote]: ... + ) -> Optional[MeltQuote]: + ... @abstractmethod async def update_melt_quote( @@ -218,7 +237,8 @@ async def update_melt_quote( quote: MeltQuote, db: Database, conn: Optional[Connection] = None, - ) -> None: ... + ) -> None: + ... class LedgerCrudSqlite(LedgerCrud): @@ -586,8 +606,8 @@ async def store_keyset( await (conn or db).execute( # type: ignore f""" INSERT INTO {table_with_schema(db, 'keysets')} - (id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + (id, seed, encrypted_seed, seed_encryption_method, derivation_path, valid_from, valid_to, first_seen, active, version, unit, input_fee_ppk) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( keyset.id, @@ -601,6 +621,7 @@ async def store_keyset( True, keyset.version, keyset.unit.name, + keyset.input_fee_ppk, ), ) diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index 7d17e2c9..cb1c1bb7 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -14,9 +14,6 @@ Method, MintKeyset, MintQuote, - PostMeltQuoteRequest, - PostMeltQuoteResponse, - PostMintQuoteRequest, Proof, ProofState, SpentState, @@ -40,6 +37,11 @@ TransactionError, ) from ..core.helpers import sum_proofs +from ..core.models import ( + PostMeltQuoteRequest, + PostMeltQuoteResponse, + PostMintQuoteRequest, +) from ..core.settings import settings from ..core.split import amount_split from ..lightning.base import ( @@ -216,6 +218,7 @@ async def activate_keyset( seed=seed or self.seed, derivation_path=derivation_path, version=version or settings.version, + input_fee_ppk=settings.mint_input_fee_ppk, ) logger.debug(f"Generated new keyset {keyset.id}.") if autosave: @@ -298,9 +301,8 @@ async def _invalidate_proofs( async def _generate_change_promises( self, - input_amount: int, - output_amount: int, - output_fee_paid: int, + fee_provided: int, + fee_paid: int, outputs: Optional[List[BlindedMessage]], keyset: Optional[MintKeyset] = None, ) -> List[BlindedSignature]: @@ -326,34 +328,35 @@ async def _generate_change_promises( List[BlindedSignature]: Signatures on the outputs. """ # we make sure that the fee is positive - user_fee_paid = input_amount - output_amount - overpaid_fee = user_fee_paid - output_fee_paid + overpaid_fee = fee_provided - fee_paid + + if overpaid_fee == 0 or outputs is None: + return [] + logger.debug( - f"Lightning fee was: {output_fee_paid}. User paid: {user_fee_paid}. " + f"Lightning fee was: {fee_paid}. User provided: {fee_provided}. " f"Returning difference: {overpaid_fee}." ) - if overpaid_fee > 0 and outputs is not None: - return_amounts = amount_split(overpaid_fee) - - # We return at most as many outputs as were provided or as many as are - # required to pay back the overpaid fee. - n_return_outputs = min(len(outputs), len(return_amounts)) - - # we only need as many outputs as we have change to return - outputs = outputs[:n_return_outputs] - # we sort the return_amounts in descending order so we only - # take the largest values in the next step - return_amounts_sorted = sorted(return_amounts, reverse=True) - # we need to imprint these amounts into the blanket outputs - for i in range(len(outputs)): - outputs[i].amount = return_amounts_sorted[i] # type: ignore - if not self._verify_no_duplicate_outputs(outputs): - raise TransactionError("duplicate promises.") - return_promises = await self._generate_promises(outputs, keyset) - return return_promises - else: - return [] + return_amounts = amount_split(overpaid_fee) + + # We return at most as many outputs as were provided or as many as are + # required to pay back the overpaid fee. + n_return_outputs = min(len(outputs), len(return_amounts)) + + # we only need as many outputs as we have change to return + outputs = outputs[:n_return_outputs] + + # we sort the return_amounts in descending order so we only + # take the largest values in the next step + return_amounts_sorted = sorted(return_amounts, reverse=True) + # we need to imprint these amounts into the blanket outputs + for i in range(len(outputs)): + outputs[i].amount = return_amounts_sorted[i] # type: ignore + if not self._verify_no_duplicate_outputs(outputs): + raise TransactionError("duplicate promises.") + return_promises = await self._generate_promises(outputs, keyset) + return return_promises # ------- TRANSACTIONS ------- @@ -488,18 +491,14 @@ async def mint( logger.trace("called mint") await self._verify_outputs(outputs) sum_amount_outputs = sum([b.amount for b in outputs]) - - output_units = set([k.unit for k in [self.keysets[o.id] for o in outputs]]) - if not len(output_units) == 1: - raise TransactionError("outputs have different units") - output_unit = list(output_units)[0] + # we already know from _verify_outputs that all outputs have the same unit because they have the same keyset + output_unit = self.keysets[outputs[0].id].unit self.locks[quote_id] = ( self.locks.get(quote_id) or asyncio.Lock() ) # create a new lock if it doesn't exist async with self.locks[quote_id]: quote = await self.get_mint_quote(quote_id=quote_id) - if not quote.paid: raise QuoteNotPaidError() if quote.issued: @@ -564,14 +563,17 @@ async def melt_quote( if not mint_quote.checking_id: raise TransactionError("mint quote has no checking id") + internal_fee = Amount(unit, 0) # no internal fees + amount = Amount(unit, mint_quote.amount) + payment_quote = PaymentQuoteResponse( checking_id=mint_quote.checking_id, - amount=Amount(unit, mint_quote.amount), - fee=Amount(unit, amount=0), + amount=amount, + fee=internal_fee, ) logger.info( f"Issuing internal melt quote: {request} ->" - f" {mint_quote.quote} ({mint_quote.amount} {mint_quote.unit})" + f" {mint_quote.quote} ({amount.str()} + {internal_fee.str()} fees)" ) else: # not internal, get payment quote by backend @@ -586,6 +588,15 @@ async def melt_quote( if not payment_quote.fee.unit == unit: raise TransactionError("payment quote fee units do not match") + # verify that the amount of the proofs is not larger than the maximum allowed + if ( + settings.mint_max_peg_out + and payment_quote.amount.to(unit).amount > settings.mint_max_peg_out + ): + raise NotAllowedError( + f"Maximum melt amount is {settings.mint_max_peg_out} sat." + ) + # We assume that the request is a bolt11 invoice, this works since we # support only the bol11 method for now. invoice_obj = bolt11.decode(melt_quote.request) @@ -667,11 +678,16 @@ async def get_melt_quote(self, quote_id: str) -> MeltQuote: return melt_quote - async def melt_mint_settle_internally(self, melt_quote: MeltQuote) -> MeltQuote: + async def melt_mint_settle_internally( + self, melt_quote: MeltQuote, proofs: List[Proof] + ) -> MeltQuote: """Settles a melt quote internally if there is a mint quote with the same payment request. + `proofs` are passed to determine the ecash input transaction fees for this melt quote. + Args: melt_quote (MeltQuote): Melt quote to settle. + proofs (List[Proof]): Proofs provided for paying the Lightning invoice. Raises: Exception: Melt quote already paid. @@ -687,6 +703,7 @@ async def melt_mint_settle_internally(self, melt_quote: MeltQuote) -> MeltQuote: ) if not mint_quote: return melt_quote + # we settle the transaction internally if melt_quote.paid: raise TransactionError("melt quote already paid") @@ -715,15 +732,16 @@ async def melt_mint_settle_internally(self, melt_quote: MeltQuote) -> MeltQuote: f" {mint_quote.quote} ({melt_quote.amount} {melt_quote.unit})" ) - # we handle this transaction internally - melt_quote.fee_paid = 0 + melt_quote.fee_paid = 0 # no internal fees melt_quote.paid = True melt_quote.paid_time = int(time.time()) - await self.crud.update_melt_quote(quote=melt_quote, db=self.db) mint_quote.paid = True mint_quote.paid_time = melt_quote.paid_time - await self.crud.update_mint_quote(quote=mint_quote, db=self.db) + + async with self.db.connect() as conn: + await self.crud.update_melt_quote(quote=melt_quote, db=self.db, conn=conn) + await self.crud.update_mint_quote(quote=mint_quote, db=self.db, conn=conn) return melt_quote @@ -759,6 +777,7 @@ async def melt( # make sure that the outputs (for fee return) are in the same unit as the quote if outputs: + # _verify_outputs checks if all outputs have the same unit await self._verify_outputs(outputs, skip_amount_check=True) outputs_unit = self.keysets[outputs[0].id].unit if not melt_quote.unit == outputs_unit.name: @@ -768,11 +787,18 @@ async def melt( # verify that the amount of the input proofs is equal to the amount of the quote total_provided = sum_proofs(proofs) - total_needed = melt_quote.amount + (melt_quote.fee_reserve or 0) - if not total_provided >= total_needed: + input_fees = self.get_fees_for_proofs(proofs) + total_needed = melt_quote.amount + melt_quote.fee_reserve + input_fees + # we need the fees specifically for lightning to return the overpaid fees + fee_reserve_provided = total_provided - melt_quote.amount - input_fees + if total_provided < total_needed: raise TransactionError( f"not enough inputs provided for melt. Provided: {total_provided}, needed: {total_needed}" ) + if fee_reserve_provided < melt_quote.fee_reserve: + raise TransactionError( + f"not enough fee reserve provided for melt. Provided fee reserve: {fee_reserve_provided}, needed: {melt_quote.fee_reserve}" + ) # verify that the amount of the proofs is not larger than the maximum allowed if settings.mint_max_peg_out and total_provided > settings.mint_max_peg_out: @@ -789,7 +815,7 @@ async def melt( await self._set_proofs_pending(proofs, quote_id=melt_quote.quote) try: # settle the transaction internally if there is a mint quote with the same payment request - melt_quote = await self.melt_mint_settle_internally(melt_quote) + melt_quote = await self.melt_mint_settle_internally(melt_quote, proofs) # quote not paid yet (not internal), pay it with the backend if not melt_quote.paid: logger.debug(f"Lightning: pay invoice {melt_quote.request}") @@ -822,9 +848,8 @@ async def melt( return_promises: List[BlindedSignature] = [] if outputs: return_promises = await self._generate_change_promises( - input_amount=total_provided, - output_amount=melt_quote.amount, - output_fee_paid=melt_quote.fee_paid, + fee_provided=fee_reserve_provided, + fee_paid=melt_quote.fee_paid, outputs=outputs, keyset=self.keysets[outputs[0].id], ) @@ -898,12 +923,6 @@ async def restore( b_=output.B_, db=self.db, conn=conn ) if promise is not None: - # BEGIN backwards compatibility mints pre `m007_proofs_and_promises_store_id` - # add keyset id to promise if not present only if the current keyset - # is the only one ever used - if not promise.id and len(self.keysets) == 1: - promise.id = self.keyset.id - # END backwards compatibility signatures.append(promise) return_outputs.append(output) logger.trace(f"promise found: {promise}") diff --git a/cashu/mint/migrations.py b/cashu/mint/migrations.py index 664b80c4..73ad7e8e 100644 --- a/cashu/mint/migrations.py +++ b/cashu/mint/migrations.py @@ -763,3 +763,13 @@ async def m018_duplicate_deprecated_keyset_ids(db: Database): keyset.seed_encryption_method, ), ) + + +async def m019_add_fee_to_keysets(db: Database): + async with db.connect() as conn: + await conn.execute( + f"ALTER TABLE {table_with_schema(db, 'keysets')} ADD COLUMN input_fee_ppk INTEGER" + ) + await conn.execute( + f"UPDATE {table_with_schema(db, 'keysets')} SET input_fee_ppk = 0" + ) diff --git a/cashu/mint/router.py b/cashu/mint/router.py index 6011e09d..79277348 100644 --- a/cashu/mint/router.py +++ b/cashu/mint/router.py @@ -3,7 +3,8 @@ from fastapi import APIRouter, Request from loguru import logger -from ..core.base import ( +from ..core.errors import KeysetNotFoundError +from ..core.models import ( GetInfoResponse, KeysetsResponse, KeysetsResponseKeyset, @@ -25,7 +26,6 @@ PostSplitRequest, PostSplitResponse, ) -from ..core.errors import KeysetNotFoundError from ..core.settings import settings from ..mint.startup import ledger from .limit import limiter @@ -182,7 +182,10 @@ async def keysets() -> KeysetsResponse: for id, keyset in ledger.keysets.items(): keysets.append( KeysetsResponseKeyset( - id=id, unit=keyset.unit.name, active=keyset.active or False + id=keyset.id, + unit=keyset.unit.name, + active=keyset.active, + input_fee_ppk=keyset.input_fee_ppk, ) ) return KeysetsResponse(keysets=keysets) diff --git a/cashu/mint/router_deprecated.py b/cashu/mint/router_deprecated.py index 049bc528..67976d2a 100644 --- a/cashu/mint/router_deprecated.py +++ b/cashu/mint/router_deprecated.py @@ -3,9 +3,9 @@ from fastapi import APIRouter, Request from loguru import logger -from ..core.base import ( - BlindedMessage, - BlindedSignature, +from ..core.base import BlindedMessage, BlindedSignature, SpentState +from ..core.errors import CashuError +from ..core.models import ( CheckFeesRequest_deprecated, CheckFeesResponse_deprecated, CheckSpendableRequest_deprecated, @@ -25,9 +25,7 @@ PostSplitRequest_Deprecated, PostSplitResponse_Deprecated, PostSplitResponse_Very_Deprecated, - SpentState, ) -from ..core.errors import CashuError from ..core.settings import settings from .limit import limiter from .startup import ledger diff --git a/cashu/mint/verification.py b/cashu/mint/verification.py index c11fbe66..e97daf92 100644 --- a/cashu/mint/verification.py +++ b/cashu/mint/verification.py @@ -1,3 +1,4 @@ +import math from typing import Dict, List, Literal, Optional, Tuple, Union from loguru import logger @@ -19,6 +20,7 @@ SecretTooLongError, TokenAlreadySpentError, TransactionError, + TransactionUnitError, ) from ..core.settings import settings from ..lightning.base import LightningBackend @@ -47,7 +49,7 @@ async def verify_inputs_and_outputs( Args: proofs (List[Proof]): List of proofs to check. outputs (Optional[List[BlindedMessage]], optional): List of outputs to check. - Must be provided for a swap but not for a melt. Defaults to None. + Must be provided for a swap but not for a melt. Defaults to None. Raises: Exception: Scripts did not validate. @@ -75,7 +77,8 @@ async def verify_inputs_and_outputs( if not all([self._verify_input_spending_conditions(p) for p in proofs]): raise TransactionError("validation of input spending conditions failed.") - if not outputs: + if outputs is None: + # If no outputs are provided, we are melting return # Verify input and output amounts @@ -94,7 +97,6 @@ async def verify_inputs_and_outputs( [ self.keysets[p.id].unit == self.keysets[outputs[0].id].unit for p in proofs - if p.id ] ): raise TransactionError("input and output keysets have different units.") @@ -108,6 +110,8 @@ async def _verify_outputs( ): """Verify that the outputs are valid.""" logger.trace(f"Verifying {len(outputs)} outputs.") + if not outputs: + raise TransactionError("no outputs provided.") # Verify all outputs have the same keyset id if not all([o.id == outputs[0].id for o in outputs]): raise TransactionError("outputs have different keyset ids.") @@ -182,23 +186,21 @@ def _verify_secret_criteria(self, proof: Proof) -> Literal[True]: """Verifies that a secret is present and is not too long (DOS prevention).""" if proof.secret is None or proof.secret == "": raise NoSecretInProofsError() - if len(proof.secret) > 512: - raise SecretTooLongError() + if len(proof.secret) > settings.mint_max_secret_length: + raise SecretTooLongError( + f"secret too long. max: {settings.mint_max_secret_length}" + ) return True def _verify_proof_bdhke(self, proof: Proof) -> bool: """Verifies that the proof of promise was issued by this ledger.""" - # if no keyset id is given in proof, assume the current one - if not proof.id: - private_key_amount = self.keyset.private_keys[proof.amount] - else: - assert proof.id in self.keysets, f"keyset {proof.id} unknown" - logger.trace( - f"Validating proof {proof.secret} with keyset" - f" {self.keysets[proof.id].id}." - ) - # use the appropriate active keyset for this proof.id - private_key_amount = self.keysets[proof.id].private_keys[proof.amount] + assert proof.id in self.keysets, f"keyset {proof.id} unknown" + logger.trace( + f"Validating proof {proof.secret} with keyset" + f" {self.keysets[proof.id].id}." + ) + # use the appropriate active keyset for this proof.id + private_key_amount = self.keysets[proof.id].private_keys[proof.amount] C = PublicKey(bytes.fromhex(proof.C), raw=True) valid = b_dhke.verify(private_key_amount, C, proof.secret) @@ -231,23 +233,53 @@ def _verify_no_duplicate_outputs(self, outputs: List[BlindedMessage]) -> bool: def _verify_amount(self, amount: int) -> int: """Any amount used should be positive and not larger than 2^MAX_ORDER.""" valid = amount > 0 and amount < 2**settings.max_order - logger.trace(f"Verifying amount {amount} is valid: {valid}") if not valid: raise NotAllowedError("invalid amount: " + str(amount)) return amount - def _verify_equation_balanced( + def _verify_units_match( self, proofs: List[Proof], outs: Union[List[BlindedSignature], List[BlindedMessage]], + ) -> Unit: + """Verifies that the units of the inputs and outputs match.""" + units_proofs = [self.keysets[p.id].unit for p in proofs] + units_outputs = [self.keysets[o.id].unit for o in outs if o.id] + if not len(set(units_proofs)) == 1: + raise TransactionUnitError("inputs have different units.") + if not len(set(units_outputs)) == 1: + raise TransactionUnitError("outputs have different units.") + if not units_proofs[0] == units_outputs[0]: + raise TransactionUnitError("input and output keysets have different units.") + return units_proofs[0] + + def get_fees_for_proofs(self, proofs: List[Proof]) -> int: + if not len(set([self.keysets[p.id].unit for p in proofs])) == 1: + raise TransactionUnitError("inputs have different units.") + fee = math.ceil(sum([self.keysets[p.id].input_fee_ppk for p in proofs]) / 1000) + return fee + + def _verify_equation_balanced( + self, + proofs: List[Proof], + outs: List[BlindedMessage], ) -> None: """Verify that Σinputs - Σoutputs = 0. Outputs can be BlindedSignature or BlindedMessage. """ + if not proofs: + raise TransactionError("no proofs provided.") + if not outs: + raise TransactionError("no outputs provided.") + + _ = self._verify_units_match(proofs, outs) sum_inputs = sum(self._verify_amount(p.amount) for p in proofs) + fees_inputs = self.get_fees_for_proofs(proofs) sum_outputs = sum(self._verify_amount(p.amount) for p in outs) - if not sum_outputs - sum_inputs == 0: - raise TransactionError("inputs do not have same amount as outputs.") + if not sum_outputs + fees_inputs - sum_inputs == 0: + raise TransactionError( + f"inputs ({sum_inputs}) - fees ({fees_inputs}) vs outputs ({sum_outputs}) are not balanced." + ) def _verify_and_get_unit_method( self, unit_str: str, method_str: str diff --git a/cashu/wallet/api/router.py b/cashu/wallet/api/router.py index 611ceaa9..ab8b96e7 100644 --- a/cashu/wallet/api/router.py +++ b/cashu/wallet/api/router.py @@ -189,7 +189,7 @@ async def swap( # pay invoice from outgoing mint await outgoing_wallet.load_proofs(reload=True) - quote = await outgoing_wallet.request_melt(invoice.bolt11) + quote = await outgoing_wallet.melt_quote(invoice.bolt11) total_amount = quote.amount + quote.fee_reserve if outgoing_wallet.available_balance < total_amount: raise Exception("balance too low") @@ -237,16 +237,14 @@ async def send_command( default=None, description="Mint URL to send from (None for default mint)", ), - nosplit: bool = Query( - default=False, description="Do not split tokens before sending." - ), + offline: bool = Query(default=False, description="Force offline send."), ): global wallet if mint: wallet = await mint_wallet(mint) if not nostr: balance, token = await send( - wallet, amount=amount, lock=lock, legacy=False, split=not nosplit + wallet, amount=amount, lock=lock, legacy=False, offline=offline ) return SendResponse(balance=balance, token=token) else: diff --git a/cashu/wallet/cli/cli.py b/cashu/wallet/cli/cli.py index cd4accf7..60747f9d 100644 --- a/cashu/wallet/cli/cli.py +++ b/cashu/wallet/cli/cli.py @@ -138,7 +138,8 @@ async def cli(ctx: Context, host: str, walletname: str, unit: str, tests: bool): ctx.ensure_object(dict) ctx.obj["HOST"] = host or settings.mint_url - ctx.obj["UNIT"] = unit + ctx.obj["UNIT"] = unit or settings.wallet_unit + unit = ctx.obj["UNIT"] ctx.obj["WALLET_NAME"] = walletname settings.wallet_name = walletname @@ -147,16 +148,18 @@ async def cli(ctx: Context, host: str, walletname: str, unit: str, tests: bool): # otherwise it will create a mnemonic and store it in the database if ctx.invoked_subcommand == "restore": wallet = await Wallet.with_db( - ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True + ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True, unit=unit ) else: # # we need to run the migrations before we load the wallet for the first time # # otherwise the wallet will not be able to generate a new private key and store it wallet = await Wallet.with_db( - ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True + ctx.obj["HOST"], db_path, name=walletname, skip_db_read=True, unit=unit ) # now with the migrations done, we can load the wallet and generate a new mnemonic if needed - wallet = await Wallet.with_db(ctx.obj["HOST"], db_path, name=walletname) + wallet = await Wallet.with_db( + ctx.obj["HOST"], db_path, name=walletname, unit=unit + ) assert wallet, "Wallet not found." ctx.obj["WALLET"] = wallet @@ -193,7 +196,7 @@ async def pay( wallet: Wallet = ctx.obj["WALLET"] await wallet.load_mint() await print_balance(ctx) - quote = await wallet.request_melt(invoice, amount) + quote = await wallet.melt_quote(invoice, amount) logger.debug(f"Quote: {quote}") total_amount = quote.amount + quote.fee_reserve if not yes: @@ -214,7 +217,9 @@ async def pay( if wallet.available_balance < total_amount: print(" Error: Balance too low.") return - _, send_proofs = await wallet.split_to_send(wallet.proofs, total_amount) + send_proofs, fees = await wallet.select_to_send( + wallet.proofs, total_amount, include_fees=True + ) try: melt_response = await wallet.melt( send_proofs, invoice, quote.fee_reserve, quote.quote @@ -341,11 +346,11 @@ async def swap(ctx: Context): invoice = await incoming_wallet.request_mint(amount) # pay invoice from outgoing mint - quote = await outgoing_wallet.request_melt(invoice.bolt11) + quote = await outgoing_wallet.melt_quote(invoice.bolt11) total_amount = quote.amount + quote.fee_reserve if outgoing_wallet.available_balance < total_amount: raise Exception("balance too low") - _, send_proofs = await outgoing_wallet.split_to_send( + send_proofs, fees = await outgoing_wallet.select_to_send( outgoing_wallet.proofs, total_amount, set_reserved=True ) await outgoing_wallet.melt( @@ -372,8 +377,9 @@ async def swap(ctx: Context): @coro async def balance(ctx: Context, verbose): wallet: Wallet = ctx.obj["WALLET"] - await wallet.load_proofs(unit=False) unit_balances = wallet.balance_per_unit() + await wallet.load_proofs(reload=True) + if len(unit_balances) > 1 and not ctx.obj["UNIT"]: print(f"You have balances in {len(unit_balances)} units:") print("") @@ -397,7 +403,6 @@ async def balance(ctx: Context, verbose): await print_mint_balances(wallet) - await wallet.load_proofs(reload=True) if verbose: print( f"Balance: {wallet.unit.str(wallet.available_balance)} (pending:" @@ -447,11 +452,19 @@ async def balance(ctx: Context, verbose): "--yes", "-y", default=False, is_flag=True, help="Skip confirmation.", type=bool ) @click.option( - "--nosplit", - "-s", + "--offline", + "-o", + default=False, + is_flag=True, + help="Force offline send.", + type=bool, +) +@click.option( + "--include-fees", + "-f", default=False, is_flag=True, - help="Do not split tokens before sending.", + help="Include fees for receiving token.", type=bool, ) @click.pass_context @@ -466,7 +479,8 @@ async def send_command( legacy: bool, verbose: bool, yes: bool, - nosplit: bool, + offline: bool, + include_fees: bool, ): wallet: Wallet = ctx.obj["WALLET"] amount = int(amount * 100) if wallet.unit == Unit.usd else int(amount) @@ -476,8 +490,9 @@ async def send_command( amount=amount, lock=lock, legacy=legacy, - split=not nosplit, + offline=offline, include_dleq=dleq, + include_fees=include_fees, ) else: await send_nostr( @@ -514,7 +529,9 @@ async def receive_cli( # ask the user if they want to trust the new mints for mint_url in set([t.mint for t in tokenObj.token if t.mint]): mint_wallet = Wallet( - mint_url, os.path.join(settings.cashu_dir, wallet.name) + mint_url, + os.path.join(settings.cashu_dir, wallet.name), + unit=tokenObj.unit or wallet.unit.name, ) await verify_mint(mint_wallet, mint_url) receive_wallet = await receive(wallet, tokenObj) @@ -853,6 +870,8 @@ async def wallets(ctx): @coro async def info(ctx: Context, mint: bool, mnemonic: bool): wallet: Wallet = ctx.obj["WALLET"] + await wallet.load_keysets_from_db(unit=None) + print(f"Version: {settings.version}") print(f"Wallet: {ctx.obj['WALLET_NAME']}") if settings.debug: @@ -861,30 +880,38 @@ async def info(ctx: Context, mint: bool, mnemonic: bool): mint_list = await list_mints(wallet) print("Mints:") for mint_url in mint_list: - print(f" - {mint_url}") + print(f" - URL: {mint_url}") + keysets_strs = [ + f"ID: {k.id} unit: {k.unit.name} active: {str(bool(k.active)) + ' ' if k.active else str(bool(k.active))} fee (ppk): {k.input_fee_ppk}" + for k in wallet.keysets.values() + ] + if keysets_strs: + print(" - Keysets:") + for k in keysets_strs: + print(f" - {k}") if mint: wallet.url = mint_url try: - mint_info: dict = (await wallet._load_mint_info()).dict() - print("") - print("---- Mint information ----") - print("") - print(f"Mint URL: {mint_url}") + mint_info: dict = (await wallet.load_mint_info()).dict() if mint_info: - print(f"Mint name: {mint_info['name']}") + print(f" - Mint name: {mint_info['name']}") if mint_info.get("description"): - print(f"Description: {mint_info['description']}") + print(f" - Description: {mint_info['description']}") if mint_info.get("description_long"): - print(f"Long description: {mint_info['description_long']}") - if mint_info.get("contact"): - print(f"Contact: {mint_info['contact']}") + print( + f" - Long description: {mint_info['description_long']}" + ) + if mint_info.get("contact") and mint_info.get("contact") != [ + ["", ""] + ]: + print(f" - Contact: {mint_info['contact']}") if mint_info.get("version"): - print(f"Version: {mint_info['version']}") + print(f" - Version: {mint_info['version']}") if mint_info.get("motd"): - print(f"Message of the day: {mint_info['motd']}") + print(f" - Message of the day: {mint_info['motd']}") if mint_info.get("nuts"): print( - "Supported NUTS:" + " - Supported NUTS:" f" {', '.join(['NUT-'+str(k) for k in mint_info['nuts'].keys()])}" ) print("") @@ -896,14 +923,16 @@ async def info(ctx: Context, mint: bool, mnemonic: bool): assert wallet.mnemonic print(f"Mnemonic:\n - {wallet.mnemonic}") if settings.env_file: - print(f"Settings: {settings.env_file}") + print("Settings:") + print(f" - File: {settings.env_file}") if settings.tor: print(f"Tor enabled: {settings.tor}") if settings.nostr_private_key: try: client = NostrClient(private_key=settings.nostr_private_key, connect=False) - print(f"Nostr public key: {client.public_key.bech32()}") - print(f"Nostr relays: {', '.join(settings.nostr_relays)}") + print("Nostr:") + print(f" - Public key: {client.public_key.bech32()}") + print(f" - Relays: {', '.join(settings.nostr_relays)}") except Exception: print("Nostr: Error. Invalid key.") if settings.socks_proxy: @@ -972,7 +1001,9 @@ async def selfpay(ctx: Context, all: bool = False): mint_balance_dict = await wallet.balance_per_minturl() mint_balance = int(mint_balance_dict[wallet.url]["available"]) # send balance once to mark as reserved - await wallet.split_to_send(wallet.proofs, mint_balance, None, set_reserved=True) + await wallet.select_to_send( + wallet.proofs, mint_balance, set_reserved=True, include_fees=False + ) # load all reserved proofs (including the one we just sent) reserved_proofs = await get_reserved_proofs(wallet.db) if not len(reserved_proofs): diff --git a/cashu/wallet/cli/cli_helpers.py b/cashu/wallet/cli/cli_helpers.py index f5102534..c0c02df9 100644 --- a/cashu/wallet/cli/cli_helpers.py +++ b/cashu/wallet/cli/cli_helpers.py @@ -12,7 +12,7 @@ async def print_balance(ctx: Context): wallet: Wallet = ctx.obj["WALLET"] - await wallet.load_proofs(reload=True, unit=wallet.unit) + await wallet.load_proofs(reload=True) print(f"Balance: {wallet.unit.str(wallet.available_balance)}") @@ -24,11 +24,11 @@ async def get_unit_wallet(ctx: Context, force_select: bool = False): force_select (bool, optional): Force the user to select a unit. Defaults to False. """ wallet: Wallet = ctx.obj["WALLET"] - await wallet.load_proofs(reload=True, unit=False) + await wallet.load_proofs(reload=False) # show balances per unit unit_balances = wallet.balance_per_unit() - if ctx.obj["UNIT"] in [u.name for u in unit_balances] and not force_select: - wallet.unit = Unit[ctx.obj["UNIT"]] + if wallet.unit in [unit_balances.keys()] and not force_select: + return wallet elif len(unit_balances) > 1 and not ctx.obj["UNIT"]: print(f"You have balances in {len(unit_balances)} units:") print("") @@ -68,7 +68,7 @@ async def get_mint_wallet(ctx: Context, force_select: bool = False): """ # we load a dummy wallet so we can check the balance per mint wallet: Wallet = ctx.obj["WALLET"] - await wallet.load_proofs(reload=True) + await wallet.load_proofs(reload=False) mint_balances = await wallet.balance_per_minturl() if ctx.obj["HOST"] not in mint_balances and not force_select: @@ -102,6 +102,7 @@ async def get_mint_wallet(ctx: Context, force_select: bool = False): mint_url, os.path.join(settings.cashu_dir, ctx.obj["WALLET_NAME"]), name=wallet.name, + unit=wallet.unit.name, ) await mint_wallet.load_proofs(reload=True) diff --git a/cashu/wallet/crud.py b/cashu/wallet/crud.py index 5e658747..aab64f33 100644 --- a/cashu/wallet/crud.py +++ b/cashu/wallet/crud.py @@ -34,6 +34,7 @@ async def store_proof( async def get_proofs( *, db: Database, + id: Optional[str] = "", melt_id: str = "", mint_id: str = "", table: str = "proofs", @@ -42,6 +43,9 @@ async def get_proofs( clauses = [] values: List[Any] = [] + if id: + clauses.append("id = ?") + values.append(id) if melt_id: clauses.append("melt_id = ?") values.append(melt_id) @@ -169,8 +173,8 @@ async def store_keyset( await (conn or db).execute( # type: ignore """ INSERT INTO keysets - (id, mint_url, valid_from, valid_to, first_seen, active, public_keys, unit) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) + (id, mint_url, valid_from, valid_to, first_seen, active, public_keys, unit, input_fee_ppk) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( keyset.id, @@ -181,26 +185,29 @@ async def store_keyset( keyset.active, keyset.serialize(), keyset.unit.name, + keyset.input_fee_ppk, ), ) async def get_keysets( id: str = "", - mint_url: str = "", + mint_url: Optional[str] = None, + unit: Optional[str] = None, db: Optional[Database] = None, conn: Optional[Connection] = None, ) -> List[WalletKeyset]: clauses = [] values: List[Any] = [] - clauses.append("active = ?") - values.append(True) if id: clauses.append("id = ?") values.append(id) if mint_url: clauses.append("mint_url = ?") values.append(mint_url) + if unit: + clauses.append("unit = ?") + values.append(unit) where = "" if clauses: where = f"WHERE {' AND '.join(clauses)}" @@ -219,6 +226,24 @@ async def get_keysets( return ret +async def update_keyset( + keyset: WalletKeyset, + db: Database, + conn: Optional[Connection] = None, +) -> None: + await (conn or db).execute( + """ + UPDATE keysets + SET active = ? + WHERE id = ? + """, + ( + keyset.active, + keyset.id, + ), + ) + + async def store_lightning_invoice( db: Database, invoice: Invoice, diff --git a/cashu/wallet/helpers.py b/cashu/wallet/helpers.py index 547cf6bc..b20e3c62 100644 --- a/cashu/wallet/helpers.py +++ b/cashu/wallet/helpers.py @@ -40,23 +40,26 @@ async def redeem_TokenV3_multimint(wallet: Wallet, token: TokenV3) -> Wallet: Helper function to iterate thruogh a token with multiple mints and redeem them from these mints one keyset at a time. """ + if not token.unit: + # load unit from wallet keyset db + keysets = await get_keysets(id=token.token[0].proofs[0].id, db=wallet.db) + if keysets: + token.unit = keysets[0].unit.name + for t in token.token: assert t.mint, Exception( "redeem_TokenV3_multimint: multimint redeem without URL" ) mint_wallet = await Wallet.with_db( - t.mint, os.path.join(settings.cashu_dir, wallet.name) + t.mint, + os.path.join(settings.cashu_dir, wallet.name), + unit=token.unit or wallet.unit.name, ) keyset_ids = mint_wallet._get_proofs_keysets(t.proofs) - logger.trace(f"Keysets in tokens: {keyset_ids}") - # loop over all keysets - for keyset_id in set(keyset_ids): - await mint_wallet.load_mint(keyset_id) - mint_wallet.unit = mint_wallet.keysets[keyset_id].unit - # redeem proofs of this keyset - redeem_proofs = [p for p in t.proofs if p.id == keyset_id] - _, _ = await mint_wallet.redeem(redeem_proofs) - print(f"Received {mint_wallet.unit.str(sum_proofs(redeem_proofs))}") + logger.trace(f"Keysets in tokens: {' '.join(set(keyset_ids))}") + await mint_wallet.load_mint() + proofs_to_keep, _ = await mint_wallet.redeem(t.proofs) + print(f"Received {mint_wallet.unit.str(sum_proofs(proofs_to_keep))}") # return the last mint_wallet return mint_wallet @@ -137,19 +140,19 @@ async def receive( ) else: # this is very legacy code, virtually any token should have mint information - # no mint information present, we extract the proofs and use wallet's default mint - # first we load the mint URL from the DB + # no mint information present, we extract the proofs find the mint and unit from the db keyset_in_token = proofs[0].id assert keyset_in_token # we get the keyset from the db mint_keysets = await get_keysets(id=keyset_in_token, db=wallet.db) assert mint_keysets, Exception(f"we don't know this keyset: {keyset_in_token}") - mint_keyset = mint_keysets[0] + mint_keyset = [k for k in mint_keysets if k.id == keyset_in_token][0] assert mint_keyset.mint_url, Exception("we don't know this mint's URL") # now we have the URL mint_wallet = await Wallet.with_db( mint_keyset.mint_url, os.path.join(settings.cashu_dir, wallet.name), + unit=mint_keyset.unit.name or wallet.unit.name, ) await mint_wallet.load_mint(keyset_in_token) _, _ = await mint_wallet.redeem(proofs) @@ -166,8 +169,9 @@ async def send( amount: int, lock: str, legacy: bool, - split: bool = True, + offline: bool = False, include_dleq: bool = False, + include_fees: bool = False, ): """ Prints token to send to stdout. @@ -191,24 +195,19 @@ async def send( sig_all=True, n_sigs=1, ) + print(f"Secret lock: {secret_lock}") await wallet.load_proofs() - if split: - await wallet.load_mint() - _, send_proofs = await wallet.split_to_send( - wallet.proofs, amount, secret_lock, set_reserved=True - ) - else: - # get a proof with specific amount - send_proofs = [] - for p in wallet.proofs: - if not p.reserved and p.amount == amount: - send_proofs = [p] - break - assert send_proofs, Exception( - "No proof with this amount found. Available amounts:" - f" {set([p.amount for p in wallet.proofs])}" - ) + + await wallet.load_mint() + # get a proof with specific amount + send_proofs, fees = await wallet.select_to_send( + wallet.proofs, + amount, + set_reserved=False, + offline=offline, + include_fees=include_fees, + ) token = await wallet.serialize_proofs( send_proofs, diff --git a/cashu/wallet/lightning/lightning.py b/cashu/wallet/lightning/lightning.py index 6b23be5e..4f24d688 100644 --- a/cashu/wallet/lightning/lightning.py +++ b/cashu/wallet/lightning/lightning.py @@ -55,7 +55,7 @@ async def pay_invoice(self, pr: str) -> PaymentResponse: Returns: bool: True if successful """ - quote = await self.request_melt(pr) + quote = await self.melt_quote(pr) total_amount = quote.amount + quote.fee_reserve assert total_amount > 0, "amount is not positive" if self.available_balance < total_amount: diff --git a/cashu/wallet/migrations.py b/cashu/wallet/migrations.py index 21e0e158..83bac045 100644 --- a/cashu/wallet/migrations.py +++ b/cashu/wallet/migrations.py @@ -236,3 +236,10 @@ async def m011_keysets_add_unit(db: Database): # add column for storing the unit of a keyset await conn.execute("ALTER TABLE keysets ADD COLUMN unit TEXT") await conn.execute("UPDATE keysets SET unit = 'sat'") + + +async def m012_add_fee_to_keysets(db: Database): + async with db.connect() as conn: + # add column for storing the fee of a keyset + await conn.execute("ALTER TABLE keysets ADD COLUMN input_fee_ppk INTEGER") + await conn.execute("UPDATE keysets SET input_fee_ppk = 0") diff --git a/cashu/wallet/mint_info.py b/cashu/wallet/mint_info.py index 30ec3f6d..e9092dd6 100644 --- a/cashu/wallet/mint_info.py +++ b/cashu/wallet/mint_info.py @@ -2,7 +2,8 @@ from pydantic import BaseModel -from ..core.base import Nut15MppSupport, Unit +from ..core.base import Unit +from ..core.models import Nut15MppSupport class MintInfo(BaseModel): diff --git a/cashu/wallet/nostr.py b/cashu/wallet/nostr.py index 72b1a60e..357e33eb 100644 --- a/cashu/wallet/nostr.py +++ b/cashu/wallet/nostr.py @@ -63,7 +63,7 @@ async def send_nostr( await wallet.load_mint() await wallet.load_proofs() _, send_proofs = await wallet.split_to_send( - wallet.proofs, amount, set_reserved=True + wallet.proofs, amount, set_reserved=True, include_fees=False ) token = await wallet.serialize_proofs(send_proofs, include_dleq=include_dleq) diff --git a/cashu/wallet/proofs.py b/cashu/wallet/proofs.py new file mode 100644 index 00000000..d3bd4cf4 --- /dev/null +++ b/cashu/wallet/proofs.py @@ -0,0 +1,208 @@ +import base64 +import json +from itertools import groupby +from typing import Dict, List, Optional + +from loguru import logger + +from ..core.base import ( + Proof, + TokenV2, + TokenV2Mint, + TokenV3, + TokenV3Token, + Unit, + WalletKeyset, +) +from ..core.db import Database +from ..wallet.crud import ( + get_keysets, +) +from .protocols import SupportsDb, SupportsKeysets + + +class WalletProofs(SupportsDb, SupportsKeysets): + keyset_id: str + db: Database + + @staticmethod + def _get_proofs_per_keyset(proofs: List[Proof]): + return { + key: list(group) for key, group in groupby(proofs, lambda p: p.id) if key + } + + async def _get_proofs_per_minturl( + self, proofs: List[Proof], unit: Optional[Unit] = None + ) -> Dict[str, List[Proof]]: + ret: Dict[str, List[Proof]] = {} + keyset_ids = set([p.id for p in proofs]) + for id in keyset_ids: + if id is None: + continue + keysets_crud = await get_keysets(id=id, db=self.db) + assert keysets_crud, f"keyset {id} not found" + keyset: WalletKeyset = keysets_crud[0] + if unit and keyset.unit != unit: + continue + assert keyset.mint_url + if keyset.mint_url not in ret: + ret[keyset.mint_url] = [p for p in proofs if p.id == id] + else: + ret[keyset.mint_url].extend([p for p in proofs if p.id == id]) + return ret + + def _get_proofs_per_unit(self, proofs: List[Proof]) -> Dict[Unit, List[Proof]]: + ret: Dict[Unit, List[Proof]] = {} + for proof in proofs: + if proof.id not in self.keysets: + logger.error(f"Keyset {proof.id} not found in wallet.") + continue + unit = self.keysets[proof.id].unit + if unit not in ret: + ret[unit] = [proof] + else: + ret[unit].append(proof) + return ret + + def _get_proofs_keysets(self, proofs: List[Proof]) -> List[str]: + """Extracts all keyset ids from a list of proofs. + + Args: + proofs (List[Proof]): List of proofs to get the keyset id's of + """ + keysets: List[str] = [proof.id for proof in proofs] + return keysets + + async def _get_keyset_urls(self, keysets: List[str]) -> Dict[str, List[str]]: + """Retrieves the mint URLs for a list of keyset id's from the wallet's database. + Returns a dictionary from URL to keyset ID + + Args: + keysets (List[str]): List of keysets. + """ + mint_urls: Dict[str, List[str]] = {} + for ks in set(keysets): + keysets_db = await get_keysets(id=ks, db=self.db) + keyset_db = keysets_db[0] if keysets_db else None + if keyset_db and keyset_db.mint_url: + mint_urls[keyset_db.mint_url] = ( + mint_urls[keyset_db.mint_url] + [ks] + if mint_urls.get(keyset_db.mint_url) + else [ks] + ) + return mint_urls + + async def _make_token( + self, proofs: List[Proof], include_mints=True, include_unit=True + ) -> TokenV3: + """ + Takes list of proofs and produces a TokenV3 by looking up + the mint URLs by the keyset id from the database. + + Args: + proofs (List[Proof]): List of proofs to be included in the token + include_mints (bool, optional): Whether to include the mint URLs in the token. Defaults to True. + + Returns: + TokenV3: TokenV3 object + """ + token = TokenV3() + if include_unit: + token.unit = self.unit.name + + if include_mints: + # we create a map from mint url to keyset id and then group + # all proofs with their mint url to build a tokenv3 + + # extract all keysets from proofs + keysets = self._get_proofs_keysets(proofs) + # get all mint URLs for all unique keysets from db + mint_urls = await self._get_keyset_urls(keysets) + + # append all url-grouped proofs to token + for url, ids in mint_urls.items(): + mint_proofs = [p for p in proofs if p.id in ids] + token.token.append(TokenV3Token(mint=url, proofs=mint_proofs)) + else: + token_proofs = TokenV3Token(proofs=proofs) + token.token.append(token_proofs) + return token + + async def serialize_proofs( + self, proofs: List[Proof], include_mints=True, include_dleq=False, legacy=False + ) -> str: + """Produces sharable token with proofs and mint information. + + Args: + proofs (List[Proof]): List of proofs to be included in the token + include_mints (bool, optional): Whether to include the mint URLs in the token. Defaults to True. + legacy (bool, optional): Whether to produce a legacy V2 token. Defaults to False. + + Returns: + str: Serialized Cashu token + """ + + if legacy: + # V2 tokens + token_v2 = await self._make_token_v2(proofs, include_mints) + return await self._serialize_token_base64_tokenv2(token_v2) + + # # deprecated code for V1 tokens + # proofs_serialized = [p.to_dict() for p in proofs] + # return base64.urlsafe_b64encode( + # json.dumps(proofs_serialized).encode() + # ).decode() + + # V3 tokens + token = await self._make_token(proofs, include_mints) + return token.serialize(include_dleq) + + async def _make_token_v2(self, proofs: List[Proof], include_mints=True) -> TokenV2: + """ + Takes list of proofs and produces a TokenV2 by looking up + the keyset id and mint URLs from the database. + """ + # build token + token = TokenV2(proofs=proofs) + + # add mint information to the token, if requested + if include_mints: + # dummy object to hold information about the mint + mints: Dict[str, TokenV2Mint] = {} + # dummy object to hold all keyset id's we need to fetch from the db later + keysets: List[str] = [proof.id for proof in proofs if proof.id] + # iterate through unique keyset ids + for id in set(keysets): + # load the keyset from the db + keysets_db = await get_keysets(id=id, db=self.db) + keyset_db = keysets_db[0] if keysets_db else None + if keyset_db and keyset_db.mint_url and keyset_db.id: + # we group all mints according to URL + if keyset_db.mint_url not in mints: + mints[keyset_db.mint_url] = TokenV2Mint( + url=keyset_db.mint_url, + ids=[keyset_db.id], + ) + else: + # if a mint URL has multiple keysets, append to the already existing list + mints[keyset_db.mint_url].ids.append(keyset_db.id) + if len(mints) > 0: + # add mints grouped by url to the token + token.mints = list(mints.values()) + return token + + async def _serialize_token_base64_tokenv2(self, token: TokenV2) -> str: + """ + Takes a TokenV2 and serializes it in urlsafe_base64. + + Args: + token (TokenV2): TokenV2 object to be serialized + + Returns: + str: Serialized token + """ + # encode the token as a base64 string + token_base64 = base64.urlsafe_b64encode( + json.dumps(token.to_dict()).encode() + ).decode() + return token_base64 diff --git a/cashu/wallet/protocols.py b/cashu/wallet/protocols.py index 0dca62a2..1f381a19 100644 --- a/cashu/wallet/protocols.py +++ b/cashu/wallet/protocols.py @@ -1,7 +1,8 @@ -from typing import Protocol +from typing import Dict, Protocol import httpx +from ..core.base import Unit, WalletKeyset from ..core.crypto.secp import PrivateKey from ..core.db import Database @@ -15,7 +16,9 @@ class SupportsDb(Protocol): class SupportsKeysets(Protocol): + keysets: Dict[str, WalletKeyset] # holds keysets keyset_id: str + unit: Unit class SupportsHttpxClient(Protocol): diff --git a/cashu/wallet/secrets.py b/cashu/wallet/secrets.py index 1ef1eff5..3e6e5722 100644 --- a/cashu/wallet/secrets.py +++ b/cashu/wallet/secrets.py @@ -9,6 +9,7 @@ from ..core.crypto.secp import PrivateKey from ..core.db import Database +from ..core.secret import Secret from ..core.settings import settings from ..wallet.crud import ( bump_secret_derivation, @@ -93,19 +94,13 @@ async def _init_private_key(self, from_mnemonic: Optional[str] = None) -> None: except Exception as e: logger.error(e) - async def _generate_secret(self) -> str: + async def _generate_random_secret(self) -> str: """Returns base64 encoded deterministic random string. NOTE: This method should probably retire after `deterministic_secrets`. We are deriving secrets from a counter but don't store the respective blinding factor. We won't be able to restore any ecash generated with these secrets. """ - # secret_counter = await bump_secret_derivation(db=self.db, keyset_id=keyset_id) - # logger.trace(f"secret_counter: {secret_counter}") - # s, _, _ = await self.generate_determinstic_secret(secret_counter, keyset_id) - # # return s.decode("utf-8") - # return hashlib.sha256(s).hexdigest() - # return random 32 byte hex string return hashlib.sha256(os.urandom(32)).hexdigest() @@ -209,3 +204,29 @@ async def generate_secrets_from_to( rs = [PrivateKey(privkey=s[1], raw=True) for s in secrets_rs_derivationpaths] derivation_paths = [s[2] for s in secrets_rs_derivationpaths] return secrets, rs, derivation_paths + + async def generate_locked_secrets( + self, send_outputs: List[int], keep_outputs: List[int], secret_lock: Secret + ) -> Tuple[List[str], List[PrivateKey], List[str]]: + """Generates secrets and blinding factors for a transaction with `send_outputs` and `keep_outputs`. + + Args: + send_outputs (List[int]): List of amounts to send + keep_outputs (List[int]): List of amounts to keep + + Returns: + Tuple[List[str], List[PrivateKey], List[str]]: Secrets, blinding factors, derivation paths + """ + rs: List[PrivateKey] = [] + # generate secrets for receiver + secret_locks = [secret_lock.serialize() for i in range(len(send_outputs))] + logger.debug(f"Creating proofs with custom secrets: {secret_locks}") + # append predefined secrets (to send) to random secrets (to keep) + # generate secrets to keep + secrets = [ + await self._generate_random_secret() for s in range(len(keep_outputs)) + ] + secret_locks + # TODO: derive derivation paths from secrets + derivation_paths = ["custom"] * len(secrets) + + return secrets, rs, derivation_paths diff --git a/cashu/wallet/transactions.py b/cashu/wallet/transactions.py new file mode 100644 index 00000000..7ac5ffa8 --- /dev/null +++ b/cashu/wallet/transactions.py @@ -0,0 +1,212 @@ +import math +import uuid +from typing import Dict, List, Tuple, Union + +from loguru import logger + +from ..core.base import ( + Proof, + Unit, + WalletKeyset, +) +from ..core.db import Database +from ..core.helpers import amount_summary, sum_proofs +from ..wallet.crud import ( + update_proof, +) +from .protocols import SupportsDb, SupportsKeysets + + +class WalletTransactions(SupportsDb, SupportsKeysets): + keysets: Dict[str, WalletKeyset] # holds keysets + keyset_id: str + db: Database + unit: Unit + + def get_fees_for_keyset(self, amounts: List[int], keyset: WalletKeyset) -> int: + fees = max(math.ceil(sum([keyset.input_fee_ppk for a in amounts]) / 1000), 0) + return fees + + def get_fees_for_proofs(self, proofs: List[Proof]) -> int: + # for each proof, find the keyset with the same id and sum the fees + fees = max( + math.ceil(sum([self.keysets[p.id].input_fee_ppk for p in proofs]) / 1000), 0 + ) + return fees + + def get_fees_for_proofs_ppk(self, proofs: List[Proof]) -> int: + return sum([self.keysets[p.id].input_fee_ppk for p in proofs]) + + async def _select_proofs_to_send_( + self, proofs: List[Proof], amount_to_send: int, tolerance: int = 0 + ) -> List[Proof]: + send_proofs: List[Proof] = [] + NO_SELECTION: List[Proof] = [] + + logger.trace(f"proofs: {[p.amount for p in proofs]}") + # sort proofs by amount (descending) + sorted_proofs = sorted(proofs, key=lambda p: p.amount, reverse=True) + # only consider proofs smaller than the amount we want to send (+ tolerance) for coin selection + fee_for_single_proof = self.get_fees_for_proofs([sorted_proofs[0]]) + sorted_proofs = [ + p + for p in sorted_proofs + if p.amount <= amount_to_send + tolerance + fee_for_single_proof + ] + if not sorted_proofs: + logger.info( + f"no small-enough proofs to send. Have: {[p.amount for p in proofs]}" + ) + return NO_SELECTION + + target_amount = amount_to_send + + # compose the target amount from the remaining_proofs + logger.debug(f"sorted_proofs: {[p.amount for p in sorted_proofs]}") + for p in sorted_proofs: + # logger.debug(f"send_proofs: {[p.amount for p in send_proofs]}") + # logger.debug(f"target_amount: {target_amount}") + # logger.debug(f"p.amount: {p.amount}") + if sum_proofs(send_proofs) + p.amount <= target_amount + tolerance: + send_proofs.append(p) + target_amount = amount_to_send + self.get_fees_for_proofs(send_proofs) + + if sum_proofs(send_proofs) < amount_to_send: + logger.info("could not select proofs to reach target amount (too little).") + return NO_SELECTION + + fees = self.get_fees_for_proofs(send_proofs) + logger.debug(f"Selected sum of proofs: {sum_proofs(send_proofs)}, fees: {fees}") + return send_proofs + + async def _select_proofs_to_send( + self, + proofs: List[Proof], + amount_to_send: Union[int, float], + *, + include_fees: bool = True, + ) -> List[Proof]: + # check that enough spendable proofs exist + if sum_proofs(proofs) < amount_to_send: + return [] + + logger.trace( + f"_select_proofs_to_send – amount_to_send: {amount_to_send} – amounts we have: {amount_summary(proofs, self.unit)} (sum: {sum_proofs(proofs)})" + ) + + sorted_proofs = sorted(proofs, key=lambda p: p.amount) + + next_bigger = next( + (p for p in sorted_proofs if p.amount > amount_to_send), None + ) + + smaller_proofs = [p for p in sorted_proofs if p.amount <= amount_to_send] + smaller_proofs = sorted(smaller_proofs, key=lambda p: p.amount, reverse=True) + + if not smaller_proofs and next_bigger: + logger.trace( + "> no proofs smaller than amount_to_send, adding next bigger proof" + ) + return [next_bigger] + + if not smaller_proofs and not next_bigger: + logger.trace("> no proofs to select from") + return [] + + remainder = amount_to_send + selected_proofs = [smaller_proofs[0]] + fee_ppk = self.get_fees_for_proofs_ppk(selected_proofs) if include_fees else 0 + logger.debug(f"adding proof: {smaller_proofs[0].amount} – fee: {fee_ppk} ppk") + remainder -= smaller_proofs[0].amount - fee_ppk / 1000 + logger.debug(f"remainder: {remainder}") + if remainder > 0: + logger.trace( + f"> selecting more proofs from {amount_summary(smaller_proofs[1:], self.unit)} sum: {sum_proofs(smaller_proofs[1:])} to reach {remainder}" + ) + selected_proofs += await self._select_proofs_to_send( + smaller_proofs[1:], remainder, include_fees=include_fees + ) + sum_selected_proofs = sum_proofs(selected_proofs) + + if sum_selected_proofs < amount_to_send and next_bigger: + logger.trace("> adding next bigger proof") + return [next_bigger] + + logger.trace( + f"_select_proofs_to_send - selected proof amounts: {amount_summary(selected_proofs, self.unit)} (sum: {sum_proofs(selected_proofs)})" + ) + return selected_proofs + + async def _select_proofs_to_split( + self, proofs: List[Proof], amount_to_send: int + ) -> Tuple[List[Proof], int]: + """ + Selects proofs that can be used with the current mint. Implements a simple coin selection algorithm. + + The algorithm has two objectives: Get rid of all tokens from old epochs and include additional proofs from + the current epoch starting from the proofs with the largest amount. + + Rules: + 1) Proofs that are not marked as reserved + 2) Proofs that have a different keyset than the activated keyset_id of the mint + 3) Include all proofs that have an older keyset than the current keyset of the mint (to get rid of old epochs). + 4) If the target amount is not reached, add proofs of the current keyset until it is. + + Args: + proofs (List[Proof]): List of proofs to select from + amount_to_send (int): Amount to select proofs for + + Returns: + List[Proof]: List of proofs to send (including fees) + int: Fees for the transaction + + Raises: + Exception: If the balance is too low to send the amount + """ + logger.debug( + f"_select_proofs_to_split - amounts we have: {amount_summary(proofs, self.unit)}" + ) + send_proofs: List[Proof] = [] + + # check that enough spendable proofs exist + if sum_proofs(proofs) < amount_to_send: + raise Exception("balance too low.") + + # add all proofs that have an older keyset than the current keyset of the mint + proofs_old_epochs = [ + p for p in proofs if p.id != self.keysets[self.keyset_id].id + ] + send_proofs += proofs_old_epochs + + # coinselect based on amount only from the current keyset + # start with the proofs with the largest amount and add them until the target amount is reached + proofs_current_epoch = [ + p for p in proofs if p.id == self.keysets[self.keyset_id].id + ] + sorted_proofs_of_current_keyset = sorted( + proofs_current_epoch, key=lambda p: p.amount + ) + + while sum_proofs(send_proofs) < amount_to_send + self.get_fees_for_proofs( + send_proofs + ): + proof_to_add = sorted_proofs_of_current_keyset.pop() + send_proofs.append(proof_to_add) + + logger.trace( + f"_select_proofs_to_split – selected proof amounts: {[p.amount for p in send_proofs]}" + ) + fees = self.get_fees_for_proofs(send_proofs) + return send_proofs, fees + + async def set_reserved(self, proofs: List[Proof], reserved: bool) -> None: + """Mark a proof as reserved or reset it in the wallet db to avoid reuse when it is sent. + + Args: + proofs (List[Proof]): List of proofs to mark as reserved + reserved (bool): Whether to mark the proofs as reserved or not + """ + uuid_str = str(uuid.uuid1()) + for proof in proofs: + proof.reserved = True + await update_proof(proof, reserved=reserved, send_id=uuid_str, db=self.db) diff --git a/cashu/wallet/v1_api.py b/cashu/wallet/v1_api.py new file mode 100644 index 00000000..59e9ba17 --- /dev/null +++ b/cashu/wallet/v1_api.py @@ -0,0 +1,539 @@ +import json +import uuid +from posixpath import join +from typing import List, Optional, Tuple, Union + +import bolt11 +import httpx +from httpx import Response +from loguru import logger + +from ..core.base import ( + BlindedMessage, + BlindedSignature, + Proof, + ProofState, + SpentState, + Unit, + WalletKeyset, +) +from ..core.crypto.secp import PublicKey +from ..core.db import Database +from ..core.models import ( + CheckFeesResponse_deprecated, + GetInfoResponse, + KeysetsResponse, + KeysetsResponseKeyset, + KeysResponse, + PostCheckStateRequest, + PostCheckStateResponse, + PostMeltQuoteRequest, + PostMeltQuoteResponse, + PostMeltRequest, + PostMeltResponse, + PostMeltResponse_deprecated, + PostMintQuoteRequest, + PostMintQuoteResponse, + PostMintRequest, + PostMintResponse, + PostRestoreResponse, + PostSplitRequest, + PostSplitResponse, +) +from ..core.settings import settings +from ..tor.tor import TorProxy +from .crud import ( + get_lightning_invoice, +) +from .wallet_deprecated import LedgerAPIDeprecated + + +def async_set_httpx_client(func): + """ + Decorator that wraps around any async class method of LedgerAPI that makes + API calls. Sets some HTTP headers and starts a Tor instance if none is + already running and and sets local proxy to use it. + """ + + async def wrapper(self, *args, **kwargs): + # set proxy + proxies_dict = {} + proxy_url: Union[str, None] = None + if settings.tor and TorProxy().check_platform(): + self.tor = TorProxy(timeout=True) + self.tor.run_daemon(verbose=True) + proxy_url = "socks5://localhost:9050" + elif settings.socks_proxy: + proxy_url = f"socks5://{settings.socks_proxy}" + elif settings.http_proxy: + proxy_url = settings.http_proxy + if proxy_url: + proxies_dict.update({"all://": proxy_url}) + + headers_dict = {"Client-version": settings.version} + + self.httpx = httpx.AsyncClient( + verify=not settings.debug, + proxies=proxies_dict, # type: ignore + headers=headers_dict, + base_url=self.url, + timeout=None if settings.debug else 60, + ) + return await func(self, *args, **kwargs) + + return wrapper + + +def async_ensure_mint_loaded(func): + """Decorator that ensures that the mint is loaded before calling the wrapped + function. If the mint is not loaded, it will be loaded first. + """ + + async def wrapper(self, *args, **kwargs): + if not self.keysets: + await self.load_mint() + return await func(self, *args, **kwargs) + + return wrapper + + +class LedgerAPI(LedgerAPIDeprecated, object): + tor: TorProxy + db: Database # we need the db for melt_deprecated + httpx: httpx.AsyncClient + + def __init__(self, url: str, db: Database): + self.url = url + self.db = db + + @async_set_httpx_client + async def _init_s(self): + """Dummy function that can be called from outside to use LedgerAPI.s""" + return + + @staticmethod + def raise_on_error_request( + resp: Response, + ) -> None: + """Raises an exception if the response from the mint contains an error. + + Args: + resp_dict (Response): Response dict (previously JSON) from mint + + Raises: + Exception: if the response contains an error + """ + try: + resp_dict = resp.json() + except json.JSONDecodeError: + # if we can't decode the response, raise for status + resp.raise_for_status() + return + if "detail" in resp_dict: + logger.trace(f"Error from mint: {resp_dict}") + error_message = f"Mint Error: {resp_dict['detail']}" + if "code" in resp_dict: + error_message += f" (Code: {resp_dict['code']})" + raise Exception(error_message) + # raise for status if no error + resp.raise_for_status() + + """ + ENDPOINTS + """ + + @async_set_httpx_client + async def _get_keys(self) -> List[WalletKeyset]: + """API that gets the current keys of the mint + + Args: + url (str): Mint URL + + Returns: + WalletKeyset: Current mint keyset + + Raises: + Exception: If no keys are received from the mint + """ + resp = await self.httpx.get( + join(self.url, "/v1/keys"), + ) + # BEGIN backwards compatibility < 0.15.0 + # assume the mint has not upgraded yet if we get a 404 + if resp.status_code == 404: + ret = await self._get_keys_deprecated(self.url) + return [ret] + # END backwards compatibility < 0.15.0 + self.raise_on_error_request(resp) + keys_dict: dict = resp.json() + assert len(keys_dict), Exception("did not receive any keys") + keys = KeysResponse.parse_obj(keys_dict) + logger.debug( + f"Received {len(keys.keysets)} keysets from mint:" + f" {' '.join([k.id + f' ({k.unit})' for k in keys.keysets])}." + ) + ret = [ + WalletKeyset( + id=keyset.id, + unit=keyset.unit, + public_keys={ + int(amt): PublicKey(bytes.fromhex(val), raw=True) + for amt, val in keyset.keys.items() + }, + mint_url=self.url, + ) + for keyset in keys.keysets + ] + return ret + + @async_set_httpx_client + async def _get_keyset(self, keyset_id: str) -> WalletKeyset: + """API that gets the keys of a specific keyset from the mint. + + + Args: + keyset_id (str): base64 keyset ID, needs to be urlsafe-encoded before sending to mint (done in this method) + + Returns: + WalletKeyset: Keyset with ID keyset_id + + Raises: + Exception: If no keys are received from the mint + """ + keyset_id_urlsafe = keyset_id.replace("+", "-").replace("/", "_") + resp = await self.httpx.get( + join(self.url, f"/v1/keys/{keyset_id_urlsafe}"), + ) + # BEGIN backwards compatibility < 0.15.0 + # assume the mint has not upgraded yet if we get a 404 + if resp.status_code == 404: + ret = await self._get_keyset_deprecated(self.url, keyset_id) + return ret + # END backwards compatibility < 0.15.0 + self.raise_on_error_request(resp) + + keys_dict = resp.json() + assert len(keys_dict), Exception("did not receive any keys") + keys = KeysResponse.parse_obj(keys_dict) + this_keyset = keys.keysets[0] + keyset_keys = { + int(amt): PublicKey(bytes.fromhex(val), raw=True) + for amt, val in this_keyset.keys.items() + } + keyset = WalletKeyset( + id=keyset_id, + unit=this_keyset.unit, + public_keys=keyset_keys, + mint_url=self.url, + ) + return keyset + + @async_set_httpx_client + async def _get_keysets(self) -> List[KeysetsResponseKeyset]: + """API that gets a list of all active keysets of the mint. + + Returns: + KeysetsResponse (List[str]): List of all active keyset IDs of the mint + + Raises: + Exception: If no keysets are received from the mint + """ + resp = await self.httpx.get( + join(self.url, "/v1/keysets"), + ) + # BEGIN backwards compatibility < 0.15.0 + # assume the mint has not upgraded yet if we get a 404 + if resp.status_code == 404: + ret = await self._get_keysets_deprecated(self.url) + return ret + # END backwards compatibility < 0.15.0 + self.raise_on_error_request(resp) + + keysets_dict = resp.json() + keysets = KeysetsResponse.parse_obj(keysets_dict).keysets + if not keysets: + raise Exception("did not receive any keysets") + return keysets + + @async_set_httpx_client + async def _get_info(self) -> GetInfoResponse: + """API that gets the mint info. + + Returns: + GetInfoResponse: Current mint info + + Raises: + Exception: If the mint info request fails + """ + resp = await self.httpx.get( + join(self.url, "/v1/info"), + ) + # BEGIN backwards compatibility < 0.15.0 + # assume the mint has not upgraded yet if we get a 404 + if resp.status_code == 404: + ret = await self._get_info_deprecated() + return ret + # END backwards compatibility < 0.15.0 + self.raise_on_error_request(resp) + data: dict = resp.json() + mint_info: GetInfoResponse = GetInfoResponse.parse_obj(data) + return mint_info + + @async_set_httpx_client + @async_ensure_mint_loaded + async def mint_quote(self, amount: int, unit: Unit) -> PostMintQuoteResponse: + """Requests a mint quote from the server and returns a payment request. + + Args: + amount (int): Amount of tokens to mint + + Returns: + PostMintQuoteResponse: Mint Quote Response + + Raises: + Exception: If the mint request fails + """ + logger.trace("Requesting mint: GET /v1/mint/bolt11") + payload = PostMintQuoteRequest(unit=unit.name, amount=amount) + resp = await self.httpx.post( + join(self.url, "/v1/mint/quote/bolt11"), json=payload.dict() + ) + # BEGIN backwards compatibility < 0.15.0 + # assume the mint has not upgraded yet if we get a 404 + if resp.status_code == 404: + ret = await self.request_mint_deprecated(amount) + return ret + # END backwards compatibility < 0.15.0 + self.raise_on_error_request(resp) + return_dict = resp.json() + return PostMintQuoteResponse.parse_obj(return_dict) + + @async_set_httpx_client + @async_ensure_mint_loaded + async def mint( + self, outputs: List[BlindedMessage], quote: str + ) -> List[BlindedSignature]: + """Mints new coins and returns a proof of promise. + + Args: + outputs (List[BlindedMessage]): Outputs to mint new tokens with + quote (str): Quote ID. + + Returns: + list[Proof]: List of proofs. + + Raises: + Exception: If the minting fails + """ + outputs_payload = PostMintRequest(outputs=outputs, quote=quote) + logger.trace("Checking Lightning invoice. POST /v1/mint/bolt11") + + def _mintrequest_include_fields(outputs: List[BlindedMessage]): + """strips away fields from the model that aren't necessary for the /mint""" + outputs_include = {"id", "amount", "B_"} + return { + "quote": ..., + "outputs": {i: outputs_include for i in range(len(outputs))}, + } + + payload = outputs_payload.dict(include=_mintrequest_include_fields(outputs)) # type: ignore + resp = await self.httpx.post( + join(self.url, "/v1/mint/bolt11"), + json=payload, # type: ignore + ) + # BEGIN backwards compatibility < 0.15.0 + # assume the mint has not upgraded yet if we get a 404 + if resp.status_code == 404: + ret = await self.mint_deprecated(outputs, quote) + return ret + # END backwards compatibility < 0.15.0 + self.raise_on_error_request(resp) + response_dict = resp.json() + logger.trace("Lightning invoice checked. POST /v1/mint/bolt11") + promises = PostMintResponse.parse_obj(response_dict).signatures + return promises + + @async_set_httpx_client + @async_ensure_mint_loaded + async def melt_quote( + self, payment_request: str, unit: Unit, amount: Optional[int] = None + ) -> PostMeltQuoteResponse: + """Checks whether the Lightning payment is internal.""" + invoice_obj = bolt11.decode(payment_request) + assert invoice_obj.amount_msat, "invoice must have amount" + payload = PostMeltQuoteRequest( + unit=unit.name, request=payment_request, amount=amount + ) + resp = await self.httpx.post( + join(self.url, "/v1/melt/quote/bolt11"), + json=payload.dict(), + ) + # BEGIN backwards compatibility < 0.15.0 + # assume the mint has not upgraded yet if we get a 404 + if resp.status_code == 404: + ret: CheckFeesResponse_deprecated = await self.check_fees_deprecated( + payment_request + ) + quote_id = "deprecated_" + str(uuid.uuid4()) + return PostMeltQuoteResponse( + quote=quote_id, + amount=amount or invoice_obj.amount_msat // 1000, + fee_reserve=ret.fee or 0, + paid=False, + expiry=invoice_obj.expiry, + ) + # END backwards compatibility < 0.15.0 + self.raise_on_error_request(resp) + return_dict = resp.json() + return PostMeltQuoteResponse.parse_obj(return_dict) + + @async_set_httpx_client + @async_ensure_mint_loaded + async def melt( + self, + quote: str, + proofs: List[Proof], + outputs: Optional[List[BlindedMessage]], + ) -> PostMeltResponse: + """ + Accepts proofs and a lightning invoice to pay in exchange. + """ + + payload = PostMeltRequest(quote=quote, inputs=proofs, outputs=outputs) + + def _meltrequest_include_fields( + proofs: List[Proof], outputs: List[BlindedMessage] + ): + """strips away fields from the model that aren't necessary for the /melt""" + proofs_include = {"id", "amount", "secret", "C", "witness"} + outputs_include = {"id", "amount", "B_"} + return { + "quote": ..., + "inputs": {i: proofs_include for i in range(len(proofs))}, + "outputs": {i: outputs_include for i in range(len(outputs))}, + } + + resp = await self.httpx.post( + join(self.url, "/v1/melt/bolt11"), + json=payload.dict(include=_meltrequest_include_fields(proofs, outputs)), # type: ignore + timeout=None, + ) + # BEGIN backwards compatibility < 0.15.0 + # assume the mint has not upgraded yet if we get a 404 + if resp.status_code == 404: + invoice = await get_lightning_invoice(id=quote, db=self.db) + assert invoice, f"no invoice found for id {quote}" + ret: PostMeltResponse_deprecated = await self.melt_deprecated( + proofs=proofs, outputs=outputs, invoice=invoice.bolt11 + ) + return PostMeltResponse( + paid=ret.paid, payment_preimage=ret.preimage, change=ret.change + ) + # END backwards compatibility < 0.15.0 + self.raise_on_error_request(resp) + return_dict = resp.json() + return PostMeltResponse.parse_obj(return_dict) + + @async_set_httpx_client + @async_ensure_mint_loaded + async def split( + self, + proofs: List[Proof], + outputs: List[BlindedMessage], + ) -> List[BlindedSignature]: + """Consume proofs and create new promises based on amount split.""" + logger.debug("Calling split. POST /v1/swap") + split_payload = PostSplitRequest(inputs=proofs, outputs=outputs) + + # construct payload + def _splitrequest_include_fields(proofs: List[Proof]): + """strips away fields from the model that aren't necessary for /v1/swap""" + proofs_include = { + "id", + "amount", + "secret", + "C", + "witness", + } + return { + "outputs": ..., + "inputs": {i: proofs_include for i in range(len(proofs))}, + } + + resp = await self.httpx.post( + join(self.url, "/v1/swap"), + json=split_payload.dict(include=_splitrequest_include_fields(proofs)), # type: ignore + ) + # BEGIN backwards compatibility < 0.15.0 + # assume the mint has not upgraded yet if we get a 404 + if resp.status_code == 404: + ret = await self.split_deprecated(proofs, outputs) + return ret + # END backwards compatibility < 0.15.0 + self.raise_on_error_request(resp) + promises_dict = resp.json() + mint_response = PostSplitResponse.parse_obj(promises_dict) + promises = [BlindedSignature(**p.dict()) for p in mint_response.signatures] + + if len(promises) == 0: + raise Exception("received no splits.") + + return promises + + @async_set_httpx_client + @async_ensure_mint_loaded + async def check_proof_state(self, proofs: List[Proof]) -> PostCheckStateResponse: + """ + Checks whether the secrets in proofs are already spent or not and returns a list of booleans. + """ + payload = PostCheckStateRequest(Ys=[p.Y for p in proofs]) + resp = await self.httpx.post( + join(self.url, "/v1/checkstate"), + json=payload.dict(), + ) + # BEGIN backwards compatibility < 0.15.0 + # assume the mint has not upgraded yet if we get a 404 + if resp.status_code == 404: + ret = await self.check_proof_state_deprecated(proofs) + # convert CheckSpendableResponse_deprecated to CheckSpendableResponse + states: List[ProofState] = [] + for spendable, pending, p in zip(ret.spendable, ret.pending, proofs): + if spendable and not pending: + states.append(ProofState(Y=p.Y, state=SpentState.unspent)) + elif spendable and pending: + states.append(ProofState(Y=p.Y, state=SpentState.pending)) + else: + states.append(ProofState(Y=p.Y, state=SpentState.spent)) + ret = PostCheckStateResponse(states=states) + return ret + # END backwards compatibility < 0.15.0 + self.raise_on_error_request(resp) + return PostCheckStateResponse.parse_obj(resp.json()) + + @async_set_httpx_client + @async_ensure_mint_loaded + async def restore_promises( + self, outputs: List[BlindedMessage] + ) -> Tuple[List[BlindedMessage], List[BlindedSignature]]: + """ + Asks the mint to restore promises corresponding to outputs. + """ + payload = PostMintRequest(quote="restore", outputs=outputs) + resp = await self.httpx.post(join(self.url, "/v1/restore"), json=payload.dict()) + # BEGIN backwards compatibility < 0.15.0 + # assume the mint has not upgraded yet if we get a 404 + if resp.status_code == 404: + ret = await self.restore_promises_deprecated(outputs) + return ret + # END backwards compatibility < 0.15.0 + self.raise_on_error_request(resp) + response_dict = resp.json() + returnObj = PostRestoreResponse.parse_obj(response_dict) + + # BEGIN backwards compatibility < 0.15.1 + # if the mint returns promises, duplicate into signatures + if returnObj.promises: + returnObj.signatures = returnObj.promises + # END backwards compatibility < 0.15.1 + + return returnObj.outputs, returnObj.signatures diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index f0f3e573..107bb12a 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -1,64 +1,38 @@ -import base64 import copy -import json import time -import uuid -from itertools import groupby -from posixpath import join from typing import Dict, List, Optional, Tuple, Union import bolt11 -import httpx from bip32 import BIP32 -from httpx import Response from loguru import logger from ..core.base import ( BlindedMessage, BlindedSignature, - CheckFeesResponse_deprecated, DLEQWallet, - GetInfoResponse, Invoice, - KeysetsResponse, - KeysResponse, - PostCheckStateRequest, - PostCheckStateResponse, - PostMeltQuoteRequest, - PostMeltQuoteResponse, - PostMeltRequest, - PostMeltResponse, - PostMeltResponse_deprecated, - PostMintQuoteRequest, - PostMintQuoteResponse, - PostMintRequest, - PostMintResponse, - PostRestoreResponse, - PostSplitRequest, - PostSplitResponse, Proof, - ProofState, SpentState, - TokenV2, - TokenV2Mint, - TokenV3, - TokenV3Token, Unit, WalletKeyset, ) from ..core.crypto import b_dhke from ..core.crypto.secp import PrivateKey, PublicKey from ..core.db import Database -from ..core.helpers import calculate_number_of_blank_outputs, sum_proofs +from ..core.errors import KeysetNotFoundError +from ..core.helpers import amount_summary, calculate_number_of_blank_outputs, sum_proofs from ..core.migrations import migrate_databases +from ..core.models import ( + PostCheckStateResponse, + PostMeltQuoteResponse, + PostMeltResponse, +) from ..core.p2pk import Secret from ..core.settings import settings from ..core.split import amount_split -from ..tor.tor import TorProxy from ..wallet.crud import ( bump_secret_derivation, get_keysets, - get_lightning_invoice, get_proofs, invalidate_proof, secret_used, @@ -66,6 +40,7 @@ store_keyset, store_lightning_invoice, store_proof, + update_keyset, update_lightning_invoice, update_proof, ) @@ -73,652 +48,49 @@ from .htlc import WalletHTLC from .mint_info import MintInfo from .p2pk import WalletP2PK +from .proofs import WalletProofs from .secrets import WalletSecrets -from .wallet_deprecated import LedgerAPIDeprecated +from .transactions import WalletTransactions +from .v1_api import LedgerAPI -def async_set_httpx_client(func): +class Wallet( + LedgerAPI, WalletP2PK, WalletHTLC, WalletSecrets, WalletTransactions, WalletProofs +): """ - Decorator that wraps around any async class method of LedgerAPI that makes - API calls. Sets some HTTP headers and starts a Tor instance if none is - already running and and sets local proxy to use it. - """ - - async def wrapper(self, *args, **kwargs): - # set proxy - proxies_dict = {} - proxy_url: Union[str, None] = None - if settings.tor and TorProxy().check_platform(): - self.tor = TorProxy(timeout=True) - self.tor.run_daemon(verbose=True) - proxy_url = "socks5://localhost:9050" - elif settings.socks_proxy: - proxy_url = f"socks5://{settings.socks_proxy}" - elif settings.http_proxy: - proxy_url = settings.http_proxy - if proxy_url: - proxies_dict.update({"all://": proxy_url}) - - headers_dict = {"Client-version": settings.version} - - self.httpx = httpx.AsyncClient( - verify=not settings.debug, - proxies=proxies_dict, # type: ignore - headers=headers_dict, - base_url=self.url, - timeout=None if settings.debug else 60, - ) - return await func(self, *args, **kwargs) + Nutshell wallet class. - return wrapper + This class is the main interface to the Nutshell wallet. It is a subclass of the + LedgerAPI class, which provides the API methods to interact with the mint. + To use `Wallet`, initialize it with the mint URL and the path to the database directory. -def async_ensure_mint_loaded(func): - """Decorator that ensures that the mint is loaded before calling the wrapped - function. If the mint is not loaded, it will be loaded first. - """ + Initialize the wallet with `Wallet.with_db(url, db)`. This will load the private key and + all keysets from the database. - async def wrapper(self, *args, **kwargs): - if not self.keysets: - await self._load_mint() - return await func(self, *args, **kwargs) + Use `load_proofs` to load all proofs of the selected mint and unit from the database. - return wrapper + Use `load_mint` to load the public keys of the mint and fetch those that we don't have. + This will also load the mint info. + Use `mint_quote` to request a Lightning invoice for minting tokens. + Use `mint` to mint tokens of a specific amount after an invoice has been paid. + Use `melt_quote` to fetch a quote for paying a Lightning invoice. + Use `melt` to pay a Lightning invoice. + """ -class LedgerAPI(LedgerAPIDeprecated, object): keyset_id: str # holds current keyset id keysets: Dict[str, WalletKeyset] # holds keysets - mint_keyset_ids: List[str] # holds active keyset ids of the mint + # mint_keyset_ids: List[str] # holds active keyset ids of the mint unit: Unit mint_info: MintInfo # holds info about mint - tor: TorProxy - db: Database - httpx: httpx.AsyncClient - - def __init__(self, url: str, db: Database): - self.url = url - self.db = db - self.keysets = {} - - @async_set_httpx_client - async def _init_s(self): - """Dummy function that can be called from outside to use LedgerAPI.s""" - return - - @staticmethod - def raise_on_error_request( - resp: Response, - ) -> None: - """Raises an exception if the response from the mint contains an error. - - Args: - resp_dict (Response): Response dict (previously JSON) from mint - - Raises: - Exception: if the response contains an error - """ - try: - resp_dict = resp.json() - except json.JSONDecodeError: - # if we can't decode the response, raise for status - resp.raise_for_status() - return - if "detail" in resp_dict: - logger.trace(f"Error from mint: {resp_dict}") - error_message = f"Mint Error: {resp_dict['detail']}" - if "code" in resp_dict: - error_message += f" (Code: {resp_dict['code']})" - raise Exception(error_message) - # raise for status if no error - resp.raise_for_status() - - async def _load_mint_keys(self, keyset_id: Optional[str] = None) -> None: - """Loads keys from mint and stores them in the database. - - Args: - keyset_id (str, optional): keyset id to load. If given, requests keys for this keyset - from the mint. If not given, requests current keyset of the mint. Defaults to "". - - Raises: - AssertionError: if mint URL is not set - AssertionError: if no keys are received from the mint - """ - logger.trace(f"Loading mint keys: {keyset_id}") - assert len( - self.url - ), "Ledger not initialized correctly: mint URL not specified yet. " - - keyset: WalletKeyset - - # if we want to load a specific keyset - if keyset_id: - # check if this keyset is in db - logger.trace(f"Loading keyset {keyset_id} from database.") - keysets = await get_keysets(keyset_id, db=self.db) - if keysets: - logger.debug(f"Found keyset {keyset_id} in database.") - # select as current keyset - keyset = keysets[0] - else: - logger.trace( - f"Could not find keyset {keyset_id} in database. Loading keyset" - " from mint." - ) - keyset = await self._get_keys_of_keyset(keyset_id) - if keyset.id == keyset_id: - # NOTE: Derived keyset *could* have a different id than the one - # requested because of the duplicate keysets for < 0.15.0 that's - # why we make an explicit check here to not overwrite an existing - # keyset with the incoming one. - logger.debug( - f"Storing new mint keyset: {keyset.id} ({keyset.unit.name})" - ) - await store_keyset(keyset=keyset, db=self.db) - keysets = [keyset] - else: - # else we load all active keysets of the mint and choose - # an appropriate one as the current keyset - keysets = await self._get_keys() - assert len(keysets), Exception("did not receive any keys") - # check if we have all keysets in db - for keyset in keysets: - keysets_in_db = await get_keysets(keyset.id, db=self.db) - if not keysets_in_db: - logger.debug( - "Storing new current mint keyset:" - f" {keyset.id} ({keyset.unit.name})" - ) - await store_keyset(keyset=keyset, db=self.db) - - # select a keyset that matches the wallet unit - wallet_unit_keysets = [k for k in keysets if k.unit == self.unit] - assert len(wallet_unit_keysets) > 0, f"no keyset for unit {self.unit.name}." - keyset = [k for k in keysets if k.unit == self.unit][0] - - # load all keysets we have into memory - for k in keysets: - self.keysets[k.id] = k - - # make sure we have selected a current keyset - assert keyset - assert keyset.id - assert len(keyset.public_keys) > 0, "no public keys in keyset" - # set current keyset id - self.keyset_id = keyset.id - logger.debug(f"Current mint keyset: {self.keyset_id}") - - async def _load_mint_keysets(self) -> List[str]: - """Loads the keyset IDs of the mint. - - Returns: - List[str]: list of keyset IDs of the mint - - Raises: - AssertionError: if no keysets are received from the mint - """ - logger.trace("Loading mint keysets.") - mint_keysets = [] - try: - mint_keysets = await self._get_keyset_ids() - except Exception: - assert self.keysets[ - self.keyset_id - ].id, "could not get keysets from mint, and do not have keys" - pass - self.mint_keyset_ids = mint_keysets or [self.keysets[self.keyset_id].id] - logger.debug(f"Mint keysets: {self.mint_keyset_ids}") - return self.mint_keyset_ids - - async def _load_mint_info(self) -> MintInfo: - """Loads the mint info from the mint.""" - mint_info_resp = await self._get_info() - self.mint_info = MintInfo(**mint_info_resp.dict()) - logger.debug(f"Mint info: {self.mint_info}") - return self.mint_info - - async def _load_mint(self, keyset_id: str = "") -> None: - """ - Loads the public keys of the mint. Either gets the keys for the specified - `keyset_id` or gets the keys of the active keyset from the mint. - Gets the active keyset ids of the mint and stores in `self.mint_keyset_ids`. - """ - logger.trace("Loading mint.") - await self._load_mint_keys(keyset_id) - await self._load_mint_keysets() - try: - await self._load_mint_info() - except Exception as e: - logger.debug(f"Could not load mint info: {e}") - pass - - if keyset_id: - assert ( - keyset_id in self.mint_keyset_ids - ), f"keyset {keyset_id} not active on mint" - - async def _check_used_secrets(self, secrets): - """Checks if any of the secrets have already been used""" - logger.trace("Checking secrets.") - for s in secrets: - if await secret_used(s, db=self.db): - raise Exception(f"secret already used: {s}") - logger.trace("Secret check complete.") - - """ - ENDPOINTS - """ - - @async_set_httpx_client - async def _get_keys(self) -> List[WalletKeyset]: - """API that gets the current keys of the mint - - Args: - url (str): Mint URL - - Returns: - WalletKeyset: Current mint keyset - - Raises: - Exception: If no keys are received from the mint - """ - resp = await self.httpx.get( - join(self.url, "/v1/keys"), - ) - # BEGIN backwards compatibility < 0.15.0 - # assume the mint has not upgraded yet if we get a 404 - if resp.status_code == 404: - ret = await self._get_keys_deprecated(self.url) - return [ret] - # END backwards compatibility < 0.15.0 - self.raise_on_error_request(resp) - keys_dict: dict = resp.json() - assert len(keys_dict), Exception("did not receive any keys") - keys = KeysResponse.parse_obj(keys_dict) - logger.debug( - f"Received {len(keys.keysets)} keysets from mint:" - f" {' '.join([k.id + f' ({k.unit})' for k in keys.keysets])}." - ) - ret = [ - WalletKeyset( - id=keyset.id, - unit=keyset.unit, - public_keys={ - int(amt): PublicKey(bytes.fromhex(val), raw=True) - for amt, val in keyset.keys.items() - }, - mint_url=self.url, - ) - for keyset in keys.keysets - ] - return ret - - @async_set_httpx_client - async def _get_keys_of_keyset(self, keyset_id: str) -> WalletKeyset: - """API that gets the keys of a specific keyset from the mint. - - - Args: - keyset_id (str): base64 keyset ID, needs to be urlsafe-encoded before sending to mint (done in this method) - - Returns: - WalletKeyset: Keyset with ID keyset_id - - Raises: - Exception: If no keys are received from the mint - """ - keyset_id_urlsafe = keyset_id.replace("+", "-").replace("/", "_") - resp = await self.httpx.get( - join(self.url, f"/v1/keys/{keyset_id_urlsafe}"), - ) - # BEGIN backwards compatibility < 0.15.0 - # assume the mint has not upgraded yet if we get a 404 - if resp.status_code == 404: - ret = await self._get_keys_of_keyset_deprecated(self.url, keyset_id) - return ret - # END backwards compatibility < 0.15.0 - self.raise_on_error_request(resp) - - keys_dict = resp.json() - assert len(keys_dict), Exception("did not receive any keys") - keys = KeysResponse.parse_obj(keys_dict) - keyset_keys = { - int(amt): PublicKey(bytes.fromhex(val), raw=True) - for amt, val in keys.keysets[0].keys.items() - } - keyset = WalletKeyset( - id=keyset_id, - unit=keys.keysets[0].unit, - public_keys=keyset_keys, - mint_url=self.url, - ) - return keyset - - @async_set_httpx_client - async def _get_keyset_ids(self) -> List[str]: - """API that gets a list of all active keysets of the mint. - - Returns: - KeysetsResponse (List[str]): List of all active keyset IDs of the mint - - Raises: - Exception: If no keysets are received from the mint - """ - resp = await self.httpx.get( - join(self.url, "/v1/keysets"), - ) - # BEGIN backwards compatibility < 0.15.0 - # assume the mint has not upgraded yet if we get a 404 - if resp.status_code == 404: - ret = await self._get_keyset_ids_deprecated(self.url) - return ret - # END backwards compatibility < 0.15.0 - self.raise_on_error_request(resp) - - keysets_dict = resp.json() - keysets = KeysetsResponse.parse_obj(keysets_dict) - assert len(keysets.keysets), Exception("did not receive any keysets") - return [k.id for k in keysets.keysets] - - @async_set_httpx_client - async def _get_info(self) -> GetInfoResponse: - """API that gets the mint info. - - Returns: - GetInfoResponse: Current mint info - - Raises: - Exception: If the mint info request fails - """ - resp = await self.httpx.get( - join(self.url, "/v1/info"), - ) - # BEGIN backwards compatibility < 0.15.0 - # assume the mint has not upgraded yet if we get a 404 - if resp.status_code == 404: - ret = await self._get_info_deprecated() - return ret - # END backwards compatibility < 0.15.0 - self.raise_on_error_request(resp) - data: dict = resp.json() - mint_info: GetInfoResponse = GetInfoResponse.parse_obj(data) - return mint_info - - @async_set_httpx_client - @async_ensure_mint_loaded - async def mint_quote(self, amount) -> PostMintQuoteResponse: - """Requests a mint quote from the server and returns a payment request. - - Args: - amount (int): Amount of tokens to mint - - Returns: - PostMintQuoteResponse: Mint Quote Response - - Raises: - Exception: If the mint request fails - """ - logger.trace("Requesting mint: GET /v1/mint/bolt11") - payload = PostMintQuoteRequest(unit=self.unit.name, amount=amount) - resp = await self.httpx.post( - join(self.url, "/v1/mint/quote/bolt11"), json=payload.dict() - ) - # BEGIN backwards compatibility < 0.15.0 - # assume the mint has not upgraded yet if we get a 404 - if resp.status_code == 404: - ret = await self.request_mint_deprecated(amount) - return ret - # END backwards compatibility < 0.15.0 - self.raise_on_error_request(resp) - return_dict = resp.json() - return PostMintQuoteResponse.parse_obj(return_dict) - - @async_set_httpx_client - @async_ensure_mint_loaded - async def mint( - self, outputs: List[BlindedMessage], quote: str - ) -> List[BlindedSignature]: - """Mints new coins and returns a proof of promise. - - Args: - outputs (List[BlindedMessage]): Outputs to mint new tokens with - quote (str): Quote ID. - - Returns: - list[Proof]: List of proofs. - - Raises: - Exception: If the minting fails - """ - outputs_payload = PostMintRequest(outputs=outputs, quote=quote) - logger.trace("Checking Lightning invoice. POST /v1/mint/bolt11") - - def _mintrequest_include_fields(outputs: List[BlindedMessage]): - """strips away fields from the model that aren't necessary for the /mint""" - outputs_include = {"id", "amount", "B_"} - return { - "quote": ..., - "outputs": {i: outputs_include for i in range(len(outputs))}, - } - - payload = outputs_payload.dict(include=_mintrequest_include_fields(outputs)) # type: ignore - resp = await self.httpx.post( - join(self.url, "/v1/mint/bolt11"), - json=payload, # type: ignore - ) - # BEGIN backwards compatibility < 0.15.0 - # assume the mint has not upgraded yet if we get a 404 - if resp.status_code == 404: - ret = await self.mint_deprecated(outputs, quote) - return ret - # END backwards compatibility < 0.15.0 - self.raise_on_error_request(resp) - response_dict = resp.json() - logger.trace("Lightning invoice checked. POST /v1/mint/bolt11") - promises = PostMintResponse.parse_obj(response_dict).signatures - return promises - - @async_set_httpx_client - @async_ensure_mint_loaded - async def melt_quote( - self, payment_request: str, amount: Optional[int] = None - ) -> PostMeltQuoteResponse: - """Checks whether the Lightning payment is internal.""" - invoice_obj = bolt11.decode(payment_request) - assert invoice_obj.amount_msat, "invoice must have amount" - payload = PostMeltQuoteRequest( - unit=self.unit.name, request=payment_request, amount=amount - ) - resp = await self.httpx.post( - join(self.url, "/v1/melt/quote/bolt11"), - json=payload.dict(), - ) - # BEGIN backwards compatibility < 0.15.0 - # assume the mint has not upgraded yet if we get a 404 - if resp.status_code == 404: - ret: CheckFeesResponse_deprecated = await self.check_fees_deprecated( - payment_request - ) - quote_id = "deprecated_" + str(uuid.uuid4()) - return PostMeltQuoteResponse( - quote=quote_id, - amount=amount or invoice_obj.amount_msat // 1000, - fee_reserve=ret.fee or 0, - paid=False, - expiry=invoice_obj.expiry, - ) - # END backwards compatibility < 0.15.0 - self.raise_on_error_request(resp) - return_dict = resp.json() - return PostMeltQuoteResponse.parse_obj(return_dict) - - @async_set_httpx_client - @async_ensure_mint_loaded - async def melt( - self, - quote: str, - proofs: List[Proof], - outputs: Optional[List[BlindedMessage]], - ) -> PostMeltResponse: - """ - Accepts proofs and a lightning invoice to pay in exchange. - """ - - payload = PostMeltRequest(quote=quote, inputs=proofs, outputs=outputs) - - def _meltrequest_include_fields( - proofs: List[Proof], outputs: List[BlindedMessage] - ): - """strips away fields from the model that aren't necessary for the /melt""" - proofs_include = {"id", "amount", "secret", "C", "witness"} - outputs_include = {"id", "amount", "B_"} - return { - "quote": ..., - "inputs": {i: proofs_include for i in range(len(proofs))}, - "outputs": {i: outputs_include for i in range(len(outputs))}, - } - - resp = await self.httpx.post( - join(self.url, "/v1/melt/bolt11"), - json=payload.dict(include=_meltrequest_include_fields(proofs, outputs)), # type: ignore - timeout=None, - ) - # BEGIN backwards compatibility < 0.15.0 - # assume the mint has not upgraded yet if we get a 404 - if resp.status_code == 404: - invoice = await get_lightning_invoice(id=quote, db=self.db) - assert invoice, f"no invoice found for id {quote}" - ret: PostMeltResponse_deprecated = await self.melt_deprecated( - proofs=proofs, outputs=outputs, invoice=invoice.bolt11 - ) - return PostMeltResponse( - paid=ret.paid, payment_preimage=ret.preimage, change=ret.change - ) - # END backwards compatibility < 0.15.0 - self.raise_on_error_request(resp) - return_dict = resp.json() - return PostMeltResponse.parse_obj(return_dict) - - @async_set_httpx_client - @async_ensure_mint_loaded - async def split( - self, - proofs: List[Proof], - outputs: List[BlindedMessage], - ) -> List[BlindedSignature]: - """Consume proofs and create new promises based on amount split.""" - logger.debug("Calling split. POST /v1/swap") - split_payload = PostSplitRequest(inputs=proofs, outputs=outputs) - - # construct payload - def _splitrequest_include_fields(proofs: List[Proof]): - """strips away fields from the model that aren't necessary for /v1/swap""" - proofs_include = { - "id", - "amount", - "secret", - "C", - "witness", - } - return { - "outputs": ..., - "inputs": {i: proofs_include for i in range(len(proofs))}, - } - - resp = await self.httpx.post( - join(self.url, "/v1/swap"), - json=split_payload.dict(include=_splitrequest_include_fields(proofs)), # type: ignore - ) - # BEGIN backwards compatibility < 0.15.0 - # assume the mint has not upgraded yet if we get a 404 - if resp.status_code == 404: - ret = await self.split_deprecated(proofs, outputs) - return ret - # END backwards compatibility < 0.15.0 - self.raise_on_error_request(resp) - promises_dict = resp.json() - mint_response = PostSplitResponse.parse_obj(promises_dict) - promises = [BlindedSignature(**p.dict()) for p in mint_response.signatures] - - if len(promises) == 0: - raise Exception("received no splits.") - - return promises - - @async_set_httpx_client - @async_ensure_mint_loaded - async def check_proof_state(self, proofs: List[Proof]) -> PostCheckStateResponse: - """ - Checks whether the secrets in proofs are already spent or not and returns a list of booleans. - """ - payload = PostCheckStateRequest(Ys=[p.Y for p in proofs]) - resp = await self.httpx.post( - join(self.url, "/v1/checkstate"), - json=payload.dict(), - ) - # BEGIN backwards compatibility < 0.15.0 - # assume the mint has not upgraded yet if we get a 404 - if resp.status_code == 404: - ret = await self.check_proof_state_deprecated(proofs) - # convert CheckSpendableResponse_deprecated to CheckSpendableResponse - states: List[ProofState] = [] - for spendable, pending, p in zip(ret.spendable, ret.pending, proofs): - if spendable and not pending: - states.append(ProofState(Y=p.Y, state=SpentState.unspent)) - elif spendable and pending: - states.append(ProofState(Y=p.Y, state=SpentState.pending)) - else: - states.append(ProofState(Y=p.Y, state=SpentState.spent)) - ret = PostCheckStateResponse(states=states) - return ret - # END backwards compatibility < 0.15.0 - self.raise_on_error_request(resp) - return PostCheckStateResponse.parse_obj(resp.json()) - - @async_set_httpx_client - @async_ensure_mint_loaded - async def restore_promises( - self, outputs: List[BlindedMessage] - ) -> Tuple[List[BlindedMessage], List[BlindedSignature]]: - """ - Asks the mint to restore promises corresponding to outputs. - """ - payload = PostMintRequest(quote="restore", outputs=outputs) - resp = await self.httpx.post(join(self.url, "/v1/restore"), json=payload.dict()) - # BEGIN backwards compatibility < 0.15.0 - # assume the mint has not upgraded yet if we get a 404 - if resp.status_code == 404: - ret = await self.restore_promises_deprecated(outputs) - return ret - # END backwards compatibility < 0.15.0 - self.raise_on_error_request(resp) - response_dict = resp.json() - returnObj = PostRestoreResponse.parse_obj(response_dict) - - # BEGIN backwards compatibility < 0.15.1 - # if the mint returns promises, duplicate into signatures - if returnObj.promises: - returnObj.signatures = returnObj.promises - # END backwards compatibility < 0.15.1 - - return returnObj.outputs, returnObj.signatures - - -class Wallet(LedgerAPI, WalletP2PK, WalletHTLC, WalletSecrets): - """Minimal wallet wrapper.""" - mnemonic: str # holds mnemonic of the wallet seed: bytes # holds private key of the wallet generated from the mnemonic - # db: Database + db: Database bip32: BIP32 # private_key: Optional[PrivateKey] = None - def __init__( - self, - url: str, - db: str, - name: str = "no_name", - ): + def __init__(self, url: str, db: str, name: str = "no_name", unit: str = "sat"): """A Cashu wallet. Args: @@ -729,7 +101,7 @@ def __init__( self.db = Database("wallet", db) self.proofs: List[Proof] = [] self.name = name - self.unit = Unit[settings.wallet_unit] + self.unit = Unit[unit] super().__init__(url=url, db=self.db) logger.debug("Wallet initialized") @@ -744,6 +116,7 @@ async def with_db( db: str, name: str = "no_name", skip_db_read: bool = False, + unit: str = "sat", ): """Initializes a wallet with a database and initializes the private key. @@ -759,13 +132,17 @@ async def with_db( Wallet: Initialized wallet. """ logger.trace(f"Initializing wallet with database: {db}") - self = cls(url=url, db=db, name=name) + self = cls(url=url, db=db, name=name, unit=unit) await self._migrate_database() if not skip_db_read: logger.trace("Mint init: loading private key and keysets from db.") await self._init_private_key() keysets_list = await get_keysets(mint_url=url, db=self.db) - self.keysets = {k.id: k for k in keysets_list} + keysets_active_unit = [k for k in keysets_list if k.unit == self.unit] + self.keysets = {k.id: k for k in keysets_active_unit} + logger.debug( + f"Loaded keysets: {' '.join([k.id + f' {k.unit}' for k in keysets_active_unit])}" + ) return self @@ -778,40 +155,162 @@ async def _migrate_database(self): # ---------- API ---------- - async def load_mint(self, keyset_id: str = ""): - """Load a mint's keys with a given keyset_id if specified or else - loads the active keyset of the mint into self.keys. - Also loads all keyset ids into self.mint_keyset_ids. + async def load_mint_info(self) -> MintInfo: + """Loads the mint info from the mint.""" + mint_info_resp = await self._get_info() + self.mint_info = MintInfo(**mint_info_resp.dict()) + logger.debug(f"Mint info: {self.mint_info}") + return self.mint_info + + async def load_mint_keysets(self): + """Loads all keyset of the mint and makes sure we have them all in the database. - Args: - keyset_id (str, optional): _description_. Defaults to "". + Then loads all keysets from the database for the active mint and active unit into self.keysets. """ - await super()._load_mint(keyset_id) + logger.trace("Loading mint keysets.") + mint_keysets_resp = await self._get_keysets() + mint_keysets_dict = {k.id: k for k in mint_keysets_resp} + + # load all keysets of thisd mint from the db + keysets_in_db = await get_keysets(mint_url=self.url, db=self.db) + + # db is empty, get all keys from the mint and store them + if not keysets_in_db: + all_keysets = await self._get_keys() + for keyset in all_keysets: + keyset.active = mint_keysets_dict[keyset.id].active + keyset.input_fee_ppk = mint_keysets_dict[keyset.id].input_fee_ppk or 0 + await store_keyset(keyset=keyset, db=self.db) + + keysets_in_db = await get_keysets(mint_url=self.url, db=self.db) + keysets_in_db_dict = {k.id: k for k in keysets_in_db} + + # get all new keysets that are not in memory yet and store them in the database + for mint_keyset in mint_keysets_dict.values(): + if mint_keyset.id not in keysets_in_db_dict: + logger.debug( + f"Storing new mint keyset: {mint_keyset.id} ({mint_keyset.unit})" + ) + wallet_keyset = await self._get_keyset(mint_keyset.id) + wallet_keyset.active = mint_keyset.active + wallet_keyset.input_fee_ppk = mint_keyset.input_fee_ppk or 0 + await store_keyset(keyset=wallet_keyset, db=self.db) + + for mint_keyset in mint_keysets_dict.values(): + # if the active or the fee attributes have changed, update them in the database + if mint_keyset.id in keysets_in_db_dict: + changed = False + if mint_keyset.active != keysets_in_db_dict[mint_keyset.id].active: + keysets_in_db_dict[mint_keyset.id].active = mint_keyset.active + changed = True + if ( + mint_keyset.input_fee_ppk + and mint_keyset.input_fee_ppk + != keysets_in_db_dict[mint_keyset.id].input_fee_ppk + ): + keysets_in_db_dict[ + mint_keyset.id + ].input_fee_ppk = mint_keyset.input_fee_ppk + changed = True + if changed: + await update_keyset( + keyset=keysets_in_db_dict[mint_keyset.id], db=self.db + ) - async def load_proofs( - self, reload: bool = False, unit: Union[Unit, bool] = True - ) -> None: - """Load all proofs from the database.""" + await self.load_keysets_from_db() + + async def activate_keyset(self, keyset_id: Optional[str] = None) -> None: + """Activates a keyset by setting self.keyset_id. Either activates a specific keyset + of chooses one of the active keysets of the mint with the same unit as the wallet. + """ + + if keyset_id: + if keyset_id not in self.keysets: + await self.load_mint_keysets() + + if keyset_id not in self.keysets: + raise KeysetNotFoundError(keyset_id) + + if self.keysets[keyset_id].unit != self.unit: + raise Exception( + f"Keyset {keyset_id} has unit {self.keysets[keyset_id].unit.name}," + f" but wallet has unit {self.unit.name}." + ) + + if not self.keysets[keyset_id].active: + raise Exception(f"Keyset {keyset_id} is not active.") + + self.keyset_id = keyset_id + else: + # if no keyset_id is given, choose an active keyset with the same unit as the wallet + chosen_keyset = None + for keyset in self.keysets.values(): + if keyset.unit == self.unit and keyset.active: + chosen_keyset = keyset + break + + if not chosen_keyset: + raise Exception(f"No active keyset found for unit {self.unit.name}.") + + self.keyset_id = chosen_keyset.id + + logger.debug(f"Activated keyset {self.keyset_id}") + + async def load_mint(self, keyset_id: str = "") -> None: + """ + Loads the public keys of the mint. Either gets the keys for the specified + `keyset_id` or gets the keys of the active keyset from the mint. + Gets the active keyset ids of the mint and stores in `self.mint_keyset_ids`. + """ + logger.trace("Loading mint.") + await self.load_mint_keysets() + await self.activate_keyset(keyset_id) + try: + await self.load_mint_info() + except Exception as e: + logger.debug(f"Could not load mint info: {e}") + pass + + async def load_proofs(self, reload: bool = False) -> None: + """Load all proofs of the selected mint and unit (i.e. self.keysets) into memory.""" if self.proofs and not reload: logger.debug("Proofs already loaded.") return - self.proofs = await get_proofs(db=self.db) - await self.load_keysets() - unit = self.unit if unit is True else unit - if unit: - self.unit = unit - self.proofs = [ - p - for p in self.proofs - if p.id in self.keysets and self.keysets[p.id].unit == unit - ] - - async def load_keysets(self) -> None: - """Load all keysets from the database.""" - keysets = await get_keysets(db=self.db) + + self.proofs = [] + await self.load_keysets_from_db() + async with self.db.connect() as conn: + for keyset_id in self.keysets: + proofs = await get_proofs(db=self.db, id=keyset_id, conn=conn) + self.proofs.extend(proofs) + logger.trace( + f"Proofs loaded for keysets: {' '.join([k.id + f' ({k.unit})' for k in self.keysets.values()])}" + ) + + async def load_keysets_from_db( + self, url: Union[str, None] = "", unit: Union[str, None] = "" + ): + """Load all keysets of the selected mint and unit from the database into self.keysets.""" + # so that the caller can set unit = None, otherwise use defaults + if unit == "": + unit = self.unit.name + if url == "": + url = self.url + keysets = await get_keysets(mint_url=url, unit=unit, db=self.db) for keyset in keysets: self.keysets[keyset.id] = keyset + logger.trace( + f"Loaded keysets from db: {[(k.id, k.unit.name, k.input_fee_ppk) for k in self.keysets.values()]}" + ) + + async def _check_used_secrets(self, secrets): + """Checks if any of the secrets have already been used""" + logger.trace("Checking secrets.") + for s in secrets: + if await secret_used(s, db=self.db): + raise Exception(f"secret already used: {s}") + logger.trace("Secret check complete.") async def request_mint(self, amount: int) -> Invoice: """Request a Lightning invoice for minting tokens. @@ -822,7 +321,73 @@ async def request_mint(self, amount: int) -> Invoice: Returns: PostMintQuoteResponse: Mint Quote Response """ - mint_quote_response = await super().mint_quote(amount) + mint_quote_response = await super().mint_quote(amount, self.unit) + decoded_invoice = bolt11.decode(mint_quote_response.request) + invoice = Invoice( + amount=amount, + bolt11=mint_quote_response.request, + payment_hash=decoded_invoice.payment_hash, + id=mint_quote_response.quote, + out=False, + time_created=int(time.time()), + ) + await store_lightning_invoice(db=self.db, invoice=invoice) + return invoice + + def split_wallet_state(self, amount: int) -> List[int]: + """This function produces an amount split for outputs based on the current state of the wallet. + Its objective is to fill up the wallet so that it reaches `n_target` coins of each amount. + + Args: + amount (int): Amount to split + + Returns: + List[int]: List of amounts to mint + """ + # read the target count for each amount from settings + n_target = settings.wallet_target_amount_count + amounts_we_have = [p.amount for p in self.proofs if p.reserved is not True] + amounts_we_have.sort() + # NOTE: Do not assume 2^n here + all_possible_amounts: list[int] = [2**i for i in range(settings.max_order)] + amounts_we_want_ll = [ + [a] * max(0, n_target - amounts_we_have.count(a)) + for a in all_possible_amounts + ] + # flatten list of lists to list + amounts_we_want = [item for sublist in amounts_we_want_ll for item in sublist] + # sort by increasing amount + amounts_we_want.sort() + + logger.debug( + f"Amounts we have: {[(a, amounts_we_have.count(a)) for a in set(amounts_we_have)]}" + ) + amounts: list[int] = [] + while sum(amounts) < amount and amounts_we_want: + if sum(amounts) + amounts_we_want[0] > amount: + break + amounts.append(amounts_we_want.pop(0)) + + remaining_amount = amount - sum(amounts) + if remaining_amount > 0: + amounts += amount_split(remaining_amount) + + logger.debug(f"Amounts we want: {amounts}") + if sum(amounts) != amount: + raise Exception(f"Amounts do not sum to {amount}.") + + return amounts + + async def mint_quote(self, amount: int) -> Invoice: + """Request a Lightning invoice for minting tokens. + + Args: + amount (int): Amount for Lightning invoice in satoshis + + Returns: + Invoice: Lightning invoice for minting tokens + """ + mint_quote_response = await super().mint_quote(amount, self.unit) decoded_invoice = bolt11.decode(mint_quote_response.request) invoice = Invoice( amount=amount, @@ -866,8 +431,10 @@ async def mint( f"Can only mint amounts with 2^n up to {2**settings.max_order}." ) + # split based on our wallet state + amounts = split or self.split_wallet_state(amount) # if no split was specified, we use the canonical split - amounts = split or amount_split(amount) + # amounts = split or amount_split(amount) # quirk: we skip bumping the secret counter in the database since we are # not sure if the minting will succeed. If it succeeds, we will bump it @@ -911,7 +478,38 @@ async def redeem( """ # verify DLEQ of incoming proofs self.verify_proofs_dleq(proofs) - return await self.split(proofs, sum_proofs(proofs)) + return await self.split(proofs=proofs, amount=0) + + def swap_send_and_keep_output_amounts( + self, proofs: List[Proof], amount: int, fees: int = 0 + ) -> Tuple[List[int], List[int]]: + """This function generates a suitable amount split for the outputs to keep and the outputs to send. It + calculates the amount to keep based on the wallet state and the amount to send based on the amount + provided. + + Args: + proofs (List[Proof]): Proofs to be split. + amount (int): Amount to be sent. + + Returns: + Tuple[List[int], List[int]]: Two lists of amounts, one for keeping and one for sending. + """ + # create a suitable amount split based on the proofs provided + total = sum_proofs(proofs) + keep_amt, send_amt = total - amount, amount + logger.trace(f"Keep amount: {keep_amt}, send amount: {send_amt}") + logger.trace(f"Total input: {sum_proofs(proofs)}") + # generate splits for outputs + send_outputs = amount_split(send_amt) + + # we subtract the fee for the entire transaction from the amount to keep + keep_amt -= self.get_fees_for_proofs(proofs) + logger.trace(f"Keep amount: {keep_amt}") + + # we determine the amounts to keep based on the wallet state + keep_outputs = self.split_wallet_state(keep_amt) + + return keep_outputs, send_outputs async def split( self, @@ -919,53 +517,45 @@ async def split( amount: int, secret_lock: Optional[Secret] = None, ) -> Tuple[List[Proof], List[Proof]]: - """If secret_lock is None, random secrets will be generated for the tokens to keep (frst_outputs) - and the promises to send (scnd_outputs). + """Calls the swap API to split the proofs into two sets of proofs, one for keeping and one for sending. - If secret_lock is provided, the wallet will create blinded secrets with those to attach a - predefined spending condition to the tokens they want to send. + If secret_lock is None, random secrets will be generated for the tokens to keep (keep_outputs) + and the promises to send (send_outputs). If secret_lock is provided, the wallet will create + blinded secrets with those to attach a predefined spending condition to the tokens they want to send. Args: - proofs (List[Proof]): _description_ - amount (int): _description_ - secret_lock (Optional[Secret], optional): _description_. Defaults to None. + proofs (List[Proof]): Proofs to be split. + amount (int): Amount to be sent. + secret_lock (Optional[Secret], optional): Secret to lock the tokens to be sent. Defaults to None. Returns: - _type_: _description_ + Tuple[List[Proof], List[Proof]]: Two lists of proofs, one for keeping and one for sending. """ assert len(proofs) > 0, "no proofs provided." assert sum_proofs(proofs) >= amount, "amount too large." - assert amount > 0, "amount must be positive." + assert amount >= 0, "amount can't be negative." # make sure we're operating on an independent copy of proofs proofs = copy.copy(proofs) # potentially add witnesses to unlock provided proofs (if they indicate one) proofs = await self.add_witnesses_to_proofs(proofs) - # create a suitable amount split based on the proofs provided - total = sum_proofs(proofs) - frst_amt, scnd_amt = total - amount, amount - frst_outputs = amount_split(frst_amt) - scnd_outputs = amount_split(scnd_amt) + input_fees = self.get_fees_for_proofs(proofs) + logger.debug(f"Input fees: {input_fees}") + # create a suitable amount lists to keep and send based on the proofs + # provided and the state of the wallet + keep_outputs, send_outputs = self.swap_send_and_keep_output_amounts( + proofs, amount, input_fees + ) - amounts = frst_outputs + scnd_outputs + amounts = keep_outputs + send_outputs # generate secrets for new outputs if secret_lock is None: secrets, rs, derivation_paths = await self.generate_n_secrets(len(amounts)) else: - # NOTE: we use random blinding factors for locks, we won't be able to - # restore these tokens from a backup - rs = [] - # generate secrets for receiver - secret_locks = [secret_lock.serialize() for i in range(len(scnd_outputs))] - logger.debug(f"Creating proofs with custom secrets: {secret_locks}") - # append predefined secrets (to send) to random secrets (to keep) - # generate secrets to keep - secrets = [ - await self._generate_secret() for s in range(len(frst_outputs)) - ] + secret_locks - # TODO: derive derivation paths from secrets - derivation_paths = ["custom"] * len(secrets) + secrets, rs, derivation_paths = await self.generate_locked_secrets( + send_outputs, keep_outputs, secret_lock + ) assert len(secrets) == len( amounts @@ -989,11 +579,11 @@ async def split( await self.invalidate(proofs) - keep_proofs = new_proofs[: len(frst_outputs)] - send_proofs = new_proofs[len(frst_outputs) :] + keep_proofs = new_proofs[: len(keep_outputs)] + send_proofs = new_proofs[len(keep_outputs) :] return keep_proofs, send_proofs - async def request_melt( + async def melt_quote( self, invoice: str, amount: Optional[int] = None ) -> PostMeltQuoteResponse: """ @@ -1001,7 +591,7 @@ async def request_melt( """ if amount and not self.mint_info.supports_mpp("bolt11", self.unit): raise Exception("Mint does not support MPP, cannot specify amount.") - melt_quote = await self.melt_quote(invoice, amount) + melt_quote = await super().melt_quote(invoice, self.unit, amount) logger.debug( f"Mint wants {self.unit.str(melt_quote.fee_reserve)} as fee reserve." ) @@ -1148,8 +738,8 @@ async def _construct_proofs( for promise, secret, r, path in zip(promises, secrets, rs, derivation_paths): if promise.id not in self.keysets: logger.debug(f"Keyset {promise.id} not found in db. Loading from mint.") - # we don't have the keyset for this promise, so we load it - await self._load_mint_keys(promise.id) + # we don't have the keyset for this promise, so we load all keysets from the mint + await self.load_mint_keysets() assert promise.id in self.keysets, "Could not load keyset." C_ = PublicKey(bytes.fromhex(promise.C_), raw=True) C = b_dhke.step3_alice( @@ -1242,6 +832,18 @@ def _construct_outputs( return outputs, rs_return + async def construct_outputs(self, amounts: List[int]) -> List[BlindedMessage]: + """Constructs outputs for a list of amounts. + + Args: + amounts (List[int]): List of amounts to construct outputs for. + + Returns: + List[BlindedMessage]: List of blinded messages that can be sent to the mint. + """ + secrets, rs, _ = await self.generate_n_secrets(len(amounts)) + return self._construct_outputs(amounts, secrets, rs)[0] + async def _store_proofs(self, proofs): try: async with self.db.connect() as conn: @@ -1252,255 +854,6 @@ async def _store_proofs(self, proofs): logger.error(proofs) raise e - @staticmethod - def _get_proofs_per_keyset(proofs: List[Proof]): - return { - key: list(group) for key, group in groupby(proofs, lambda p: p.id) if key - } - - async def _get_proofs_per_minturl( - self, proofs: List[Proof], unit: Optional[Unit] = None - ) -> Dict[str, List[Proof]]: - ret: Dict[str, List[Proof]] = {} - keyset_ids = set([p.id for p in proofs]) - for id in keyset_ids: - if id is None: - continue - keysets_crud = await get_keysets(id=id, db=self.db) - assert keysets_crud, f"keyset {id} not found" - keyset: WalletKeyset = keysets_crud[0] - if unit and keyset.unit != unit: - continue - assert keyset.mint_url - if keyset.mint_url not in ret: - ret[keyset.mint_url] = [p for p in proofs if p.id == id] - else: - ret[keyset.mint_url].extend([p for p in proofs if p.id == id]) - return ret - - def _get_proofs_per_unit(self, proofs: List[Proof]) -> Dict[Unit, List[Proof]]: - ret: Dict[Unit, List[Proof]] = {} - for proof in proofs: - if proof.id not in self.keysets: - logger.error(f"Keyset {proof.id} not found in wallet.") - continue - unit = self.keysets[proof.id].unit - if unit not in ret: - ret[unit] = [proof] - else: - ret[unit].append(proof) - return ret - - def _get_proofs_keysets(self, proofs: List[Proof]) -> List[str]: - """Extracts all keyset ids from a list of proofs. - - Args: - proofs (List[Proof]): List of proofs to get the keyset id's of - """ - keysets: List[str] = [proof.id for proof in proofs if proof.id] - return keysets - - async def _get_keyset_urls(self, keysets: List[str]) -> Dict[str, List[str]]: - """Retrieves the mint URLs for a list of keyset id's from the wallet's database. - Returns a dictionary from URL to keyset ID - - Args: - keysets (List[str]): List of keysets. - """ - mint_urls: Dict[str, List[str]] = {} - for ks in set(keysets): - keysets_db = await get_keysets(id=ks, db=self.db) - keyset_db = keysets_db[0] if keysets_db else None - if keyset_db and keyset_db.mint_url: - mint_urls[keyset_db.mint_url] = ( - mint_urls[keyset_db.mint_url] + [ks] - if mint_urls.get(keyset_db.mint_url) - else [ks] - ) - return mint_urls - - async def _make_token(self, proofs: List[Proof], include_mints=True) -> TokenV3: - """ - Takes list of proofs and produces a TokenV3 by looking up - the mint URLs by the keyset id from the database. - - Args: - proofs (List[Proof]): List of proofs to be included in the token - include_mints (bool, optional): Whether to include the mint URLs in the token. Defaults to True. - - Returns: - TokenV3: TokenV3 object - """ - token = TokenV3() - - if include_mints: - # we create a map from mint url to keyset id and then group - # all proofs with their mint url to build a tokenv3 - - # extract all keysets from proofs - keysets = self._get_proofs_keysets(proofs) - # get all mint URLs for all unique keysets from db - mint_urls = await self._get_keyset_urls(keysets) - - # append all url-grouped proofs to token - for url, ids in mint_urls.items(): - mint_proofs = [p for p in proofs if p.id in ids] - token.token.append(TokenV3Token(mint=url, proofs=mint_proofs)) - else: - token_proofs = TokenV3Token(proofs=proofs) - token.token.append(token_proofs) - return token - - async def serialize_proofs( - self, proofs: List[Proof], include_mints=True, include_dleq=False, legacy=False - ) -> str: - """Produces sharable token with proofs and mint information. - - Args: - proofs (List[Proof]): List of proofs to be included in the token - include_mints (bool, optional): Whether to include the mint URLs in the token. Defaults to True. - legacy (bool, optional): Whether to produce a legacy V2 token. Defaults to False. - - Returns: - str: Serialized Cashu token - """ - - if legacy: - # V2 tokens - token_v2 = await self._make_token_v2(proofs, include_mints) - return await self._serialize_token_base64_tokenv2(token_v2) - - # # deprecated code for V1 tokens - # proofs_serialized = [p.to_dict() for p in proofs] - # return base64.urlsafe_b64encode( - # json.dumps(proofs_serialized).encode() - # ).decode() - - # V3 tokens - token = await self._make_token(proofs, include_mints) - return token.serialize(include_dleq) - - async def _make_token_v2(self, proofs: List[Proof], include_mints=True) -> TokenV2: - """ - Takes list of proofs and produces a TokenV2 by looking up - the keyset id and mint URLs from the database. - """ - # build token - token = TokenV2(proofs=proofs) - - # add mint information to the token, if requested - if include_mints: - # dummy object to hold information about the mint - mints: Dict[str, TokenV2Mint] = {} - # dummy object to hold all keyset id's we need to fetch from the db later - keysets: List[str] = [proof.id for proof in proofs if proof.id] - # iterate through unique keyset ids - for id in set(keysets): - # load the keyset from the db - keysets_db = await get_keysets(id=id, db=self.db) - keyset_db = keysets_db[0] if keysets_db else None - if keyset_db and keyset_db.mint_url and keyset_db.id: - # we group all mints according to URL - if keyset_db.mint_url not in mints: - mints[keyset_db.mint_url] = TokenV2Mint( - url=keyset_db.mint_url, - ids=[keyset_db.id], - ) - else: - # if a mint URL has multiple keysets, append to the already existing list - mints[keyset_db.mint_url].ids.append(keyset_db.id) - if len(mints) > 0: - # add mints grouped by url to the token - token.mints = list(mints.values()) - return token - - async def _serialize_token_base64_tokenv2(self, token: TokenV2) -> str: - """ - Takes a TokenV2 and serializes it in urlsafe_base64. - - Args: - token (TokenV2): TokenV2 object to be serialized - - Returns: - str: Serialized token - """ - # encode the token as a base64 string - token_base64 = base64.urlsafe_b64encode( - json.dumps(token.to_dict()).encode() - ).decode() - return token_base64 - - async def _select_proofs_to_send( - self, proofs: List[Proof], amount_to_send: int - ) -> List[Proof]: - """ - Selects proofs that can be used with the current mint. Implements a simple coin selection algorithm. - - The algorithm has two objectives: Get rid of all tokens from old epochs and include additional proofs from - the current epoch starting from the proofs with the largest amount. - - Rules: - 1) Proofs that are not marked as reserved - 2) Proofs that have a keyset id that is in self.mint_keyset_ids (all active keysets of mint) - 3) Include all proofs that have an older keyset than the current keyset of the mint (to get rid of old epochs). - 4) If the target amount is not reached, add proofs of the current keyset until it is. - - Args: - proofs (List[Proof]): List of proofs to select from - amount_to_send (int): Amount to select proofs for - - Returns: - List[Proof]: List of proofs to send - - Raises: - Exception: If the balance is too low to send the amount - """ - send_proofs: List[Proof] = [] - - # select proofs that are not reserved - proofs = [p for p in proofs if not p.reserved] - - # select proofs that are in the active keysets of the mint - proofs = [p for p in proofs if p.id in self.mint_keyset_ids or not p.id] - - # check that enough spendable proofs exist - if sum_proofs(proofs) < amount_to_send: - raise Exception("balance too low.") - - # add all proofs that have an older keyset than the current keyset of the mint - proofs_old_epochs = [ - p for p in proofs if p.id != self.keysets[self.keyset_id].id - ] - send_proofs += proofs_old_epochs - - # coinselect based on amount only from the current keyset - # start with the proofs with the largest amount and add them until the target amount is reached - proofs_current_epoch = [ - p for p in proofs if p.id == self.keysets[self.keyset_id].id - ] - sorted_proofs_of_current_keyset = sorted( - proofs_current_epoch, key=lambda p: p.amount - ) - - while sum_proofs(send_proofs) < amount_to_send: - proof_to_add = sorted_proofs_of_current_keyset.pop() - send_proofs.append(proof_to_add) - - logger.trace(f"selected proof amounts: {[p.amount for p in send_proofs]}") - return send_proofs - - async def set_reserved(self, proofs: List[Proof], reserved: bool) -> None: - """Mark a proof as reserved or reset it in the wallet db to avoid reuse when it is sent. - - Args: - proofs (List[Proof]): List of proofs to mark as reserved - reserved (bool): Whether to mark the proofs as reserved or not - """ - uuid_str = str(uuid.uuid1()) - for proof in proofs: - proof.reserved = True - await update_proof(proof, reserved=reserved, send_id=uuid_str, db=self.db) - async def invalidate( self, proofs: List[Proof], check_spendable=False ) -> List[Proof]: @@ -1540,15 +893,76 @@ async def invalidate( # ---------- TRANSACTION HELPERS ---------- + async def select_to_send( + self, + proofs: List[Proof], + amount: int, + *, + set_reserved: bool = False, + offline: bool = False, + include_fees: bool = True, + ) -> Tuple[List[Proof], int]: + """ + Selects proofs such that a desired `amount` can be sent. If the offline coin selection is unsuccessful, + and `offline` is set to False (default), we split the available proofs with the mint to get the desired `amount`. + + If `set_reserved` is set to True, the proofs are marked as reserved so they aren't used in other transactions. + + If `include_fees` is set to False, the swap fees are not included in the amount to be selected. + + Args: + proofs (List[Proof]): Proofs to split + amount (int): Amount to split to + set_reserved (bool, optional): If set, the proofs are marked as reserved. + + Returns: + List[Proof]: Proofs to send + int: Fees for the transaction + """ + # select proofs that are not reserved and are in the active keysets of the mint + proofs = self.active_proofs(proofs) + if sum_proofs(proofs) < amount: + raise Exception("balance too low.") + + # coin selection for potentially offline sending + send_proofs = await self._select_proofs_to_send( + proofs, amount, include_fees=include_fees + ) + fees = self.get_fees_for_proofs(send_proofs) + logger.trace( + f"select_to_send: selected: {self.unit.str(sum_proofs(send_proofs))} (+ {self.unit.str(fees)} fees) – wanted: {self.unit.str(amount)}" + ) + # offline coin selection unsuccessful, we need to swap proofs before we can send + if not send_proofs or sum_proofs(send_proofs) > amount + fees: + if not offline: + logger.debug("Offline coin selection unsuccessful. Splitting proofs.") + # we set the proofs as reserved later + _, send_proofs = await self.split_to_send( + proofs, amount, set_reserved=False + ) + else: + raise Exception( + "Could not select proofs in offline mode. Available amounts:" + + amount_summary(proofs, self.unit) + ) + if set_reserved: + await self.set_reserved(send_proofs, reserved=True) + return send_proofs, fees + async def split_to_send( self, proofs: List[Proof], amount: int, + *, secret_lock: Optional[Secret] = None, set_reserved: bool = False, - ): + include_fees: bool = True, + ) -> Tuple[List[Proof], List[Proof]]: """ - Splits proofs such that a certain amount can be sent. + Swaps a set of proofs with the mint to get a set that sums up to a desired amount that can be sent. The remaining + proofs are returned to be kept. All newly created proofs will be stored in the database but if `set_reserved` is set + to True, the proofs to be sent (which sum up to `amount`) will be marked as reserved so they aren't used in other + transactions. Args: proofs (List[Proof]): Proofs to split @@ -1561,13 +975,28 @@ async def split_to_send( Returns: Tuple[List[Proof], List[Proof]]: Tuple of proofs to keep and proofs to send """ - if secret_lock: - logger.debug(f"Spending conditions: {secret_lock}") - spendable_proofs = await self._select_proofs_to_send(proofs, amount) + # select proofs that are not reserved and are in the active keysets of the mint + proofs = self.active_proofs(proofs) + if sum_proofs(proofs) < amount: + raise Exception("balance too low.") + + # coin selection for swapping + # spendable_proofs, fees = await self._select_proofs_to_split(proofs, amount) + swap_proofs = await self._select_proofs_to_send( + proofs, amount, include_fees=True + ) + # add proofs from inactive keysets to swap_proofs to get rid of them + swap_proofs += [ + p + for p in proofs + if not self.keysets[p.id].active and not p.reserved and p not in swap_proofs + ] - keep_proofs, send_proofs = await self.split( - spendable_proofs, amount, secret_lock + fees = self.get_fees_for_proofs(swap_proofs) + logger.debug( + f"Amount to send: {self.unit.str(amount)} (+ {self.unit.str(fees)} fees)" ) + keep_proofs, send_proofs = await self.split(swap_proofs, amount, secret_lock) if set_reserved: await self.set_reserved(send_proofs, reserved=True) return keep_proofs, send_proofs @@ -1587,6 +1016,21 @@ def proof_amounts(self): """Returns a sorted list of amounts of all proofs""" return [p.amount for p in sorted(self.proofs, key=lambda p: p.amount)] + def active_proofs(self, proofs: List[Proof]): + """Returns a list of proofs that + - have an id that is in the current `self.keysets` which have the unit in `self.unit` + - are not reserved + """ + + def is_active_proof(p: Proof) -> bool: + return ( + p.id in self.keysets + and self.keysets[p.id].unit == self.unit + and not p.reserved + ) + + return [p for p in proofs if is_active_proof(p)] + def balance_per_keyset(self) -> Dict[str, Dict[str, Union[int, str]]]: ret: Dict[str, Dict[str, Union[int, str]]] = { key: { @@ -1689,8 +1133,7 @@ async def restore_wallet_from_mnemonic( await self._init_private_key(mnemonic) await self.load_mint() print("Restoring tokens...") - keyset_ids = self.mint_keyset_ids - for keyset_id in keyset_ids: + for keyset_id in self.keysets.keys(): await self.restore_tokens_for_keyset(keyset_id, to, batch) async def restore_promises_from_to( diff --git a/cashu/wallet/wallet_deprecated.py b/cashu/wallet/wallet_deprecated.py index f300b1b4..614b9191 100644 --- a/cashu/wallet/wallet_deprecated.py +++ b/cashu/wallet/wallet_deprecated.py @@ -10,6 +10,11 @@ BlindedMessage, BlindedMessage_Deprecated, BlindedSignature, + Proof, + WalletKeyset, +) +from ..core.crypto.secp import PublicKey +from ..core.models import ( CheckFeesRequest_deprecated, CheckFeesResponse_deprecated, CheckSpendableRequest_deprecated, @@ -18,6 +23,7 @@ GetInfoResponse_deprecated, GetMintResponse_deprecated, KeysetsResponse_deprecated, + KeysetsResponseKeyset, PostMeltRequest_deprecated, PostMeltResponse_deprecated, PostMintQuoteResponse, @@ -26,10 +32,7 @@ PostRestoreResponse, PostSplitRequest_Deprecated, PostSplitResponse_Deprecated, - Proof, - WalletKeyset, ) -from ..core.crypto.secp import PublicKey from ..core.settings import settings from ..tor.tor import TorProxy from .protocols import SupportsHttpxClient, SupportsMintURL @@ -78,7 +81,7 @@ def async_ensure_mint_loaded_deprecated(func): async def wrapper(self, *args, **kwargs): if not self.keysets: - await self._load_mint() + await self.load_mint() return await func(self, *args, **kwargs) return wrapper @@ -164,9 +167,7 @@ async def _get_keys_deprecated(self, url: str) -> WalletKeyset: return keyset @async_set_httpx_client - async def _get_keys_of_keyset_deprecated( - self, url: str, keyset_id: str - ) -> WalletKeyset: + async def _get_keyset_deprecated(self, url: str, keyset_id: str) -> WalletKeyset: """API that gets the keys of a specific keyset from the mint. @@ -201,8 +202,7 @@ async def _get_keys_of_keyset_deprecated( return keyset @async_set_httpx_client - @async_ensure_mint_loaded_deprecated - async def _get_keyset_ids_deprecated(self, url: str) -> List[str]: + async def _get_keysets_deprecated(self, url: str) -> List[KeysetsResponseKeyset]: """API that gets a list of all active keysets of the mint. Args: @@ -222,7 +222,11 @@ async def _get_keyset_ids_deprecated(self, url: str) -> List[str]: keysets_dict = resp.json() keysets = KeysetsResponse_deprecated.parse_obj(keysets_dict) assert len(keysets.keysets), Exception("did not receive any keysets") - return keysets.keysets + keysets_new = [ + KeysetsResponseKeyset(id=id, unit="sat", active=True) + for id in keysets.keysets + ] + return keysets_new @async_set_httpx_client @async_ensure_mint_loaded_deprecated diff --git a/tests/conftest.py b/tests/conftest.py index f3a9a6b8..6a23880a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,6 +45,7 @@ settings.mint_seed_decryption_key = "" settings.mint_max_balance = 0 settings.mint_lnd_enable_mpp = True +settings.mint_input_fee_ppk = 0 assert "test" in settings.cashu_dir shutil.rmtree(settings.cashu_dir, ignore_errors=True) diff --git a/tests/test_mint.py b/tests/test_mint.py index 0a46b0f3..e842e354 100644 --- a/tests/test_mint.py +++ b/tests/test_mint.py @@ -2,9 +2,10 @@ import pytest -from cashu.core.base import BlindedMessage, PostMintQuoteRequest, Proof +from cashu.core.base import BlindedMessage, Proof from cashu.core.crypto.b_dhke import step1_alice from cashu.core.helpers import calculate_number_of_blank_outputs +from cashu.core.models import PostMintQuoteRequest from cashu.core.settings import settings from cashu.mint.ledger import Ledger from tests.helpers import pay_if_regtest @@ -129,9 +130,9 @@ async def test_generate_promises(ledger: Ledger): async def test_generate_change_promises(ledger: Ledger): # Example slightly adapted from NUT-08 because we want to ensure the dynamic change # token amount works: `n_blank_outputs != n_returned_promises != 4`. - invoice_amount = 100_000 + # invoice_amount = 100_000 fee_reserve = 2_000 - total_provided = invoice_amount + fee_reserve + # total_provided = invoice_amount + fee_reserve actual_fee = 100 expected_returned_promises = 7 # Amounts = [4, 8, 32, 64, 256, 512, 1024] @@ -149,7 +150,7 @@ async def test_generate_change_promises(ledger: Ledger): ] promises = await ledger._generate_change_promises( - total_provided, invoice_amount, actual_fee, outputs + fee_provided=fee_reserve, fee_paid=actual_fee, outputs=outputs ) assert len(promises) == expected_returned_promises @@ -160,9 +161,9 @@ async def test_generate_change_promises(ledger: Ledger): async def test_generate_change_promises_legacy_wallet(ledger: Ledger): # Check if mint handles a legacy wallet implementation (always sends 4 blank # outputs) as well. - invoice_amount = 100_000 + # invoice_amount = 100_000 fee_reserve = 2_000 - total_provided = invoice_amount + fee_reserve + # total_provided = invoice_amount + fee_reserve actual_fee = 100 expected_returned_promises = 4 # Amounts = [64, 256, 512, 1024] @@ -179,9 +180,7 @@ async def test_generate_change_promises_legacy_wallet(ledger: Ledger): for b, _ in blinded_msgs ] - promises = await ledger._generate_change_promises( - total_provided, invoice_amount, actual_fee, outputs - ) + promises = await ledger._generate_change_promises(fee_reserve, actual_fee, outputs) assert len(promises) == expected_returned_promises assert sum([promise.amount for promise in promises]) == expected_returned_fees @@ -189,14 +188,14 @@ async def test_generate_change_promises_legacy_wallet(ledger: Ledger): @pytest.mark.asyncio async def test_generate_change_promises_returns_empty_if_no_outputs(ledger: Ledger): - invoice_amount = 100_000 + # invoice_amount = 100_000 fee_reserve = 1_000 - total_provided = invoice_amount + fee_reserve + # total_provided = invoice_amount + fee_reserve actual_fee_msat = 100_000 outputs = None promises = await ledger._generate_change_promises( - total_provided, invoice_amount, actual_fee_msat, outputs + fee_reserve, actual_fee_msat, outputs ) assert len(promises) == 0 diff --git a/tests/test_mint_api.py b/tests/test_mint_api.py index 5b105fcc..b374b227 100644 --- a/tests/test_mint_api.py +++ b/tests/test_mint_api.py @@ -3,14 +3,14 @@ import pytest import pytest_asyncio -from cashu.core.base import ( +from cashu.core.base import SpentState +from cashu.core.models import ( GetInfoResponse, MintMeltMethodSetting, PostCheckStateRequest, PostCheckStateResponse, PostRestoreRequest, PostRestoreResponse, - SpentState, ) from cashu.core.settings import settings from cashu.mint.ledger import Ledger @@ -89,6 +89,7 @@ async def test_api_keysets(ledger: Ledger): "id": "009a1f293253e41e", "unit": "sat", "active": True, + "input_fee_ppk": 0, }, ] } diff --git a/tests/test_mint_api_deprecated.py b/tests/test_mint_api_deprecated.py index f59b9b84..e1722562 100644 --- a/tests/test_mint_api_deprecated.py +++ b/tests/test_mint_api_deprecated.py @@ -2,12 +2,12 @@ import pytest import pytest_asyncio -from cashu.core.base import ( +from cashu.core.base import Proof +from cashu.core.models import ( CheckSpendableRequest_deprecated, CheckSpendableResponse_deprecated, PostRestoreRequest, PostRestoreResponse, - Proof, ) from cashu.mint.ledger import Ledger from cashu.wallet.crud import bump_secret_derivation diff --git a/tests/test_mint_db.py b/tests/test_mint_db.py index 92caf1e7..ea2d9a0d 100644 --- a/tests/test_mint_db.py +++ b/tests/test_mint_db.py @@ -1,7 +1,7 @@ import pytest import pytest_asyncio -from cashu.core.base import PostMeltQuoteRequest +from cashu.core.models import PostMeltQuoteRequest from cashu.mint.ledger import Ledger from cashu.wallet.wallet import Wallet from cashu.wallet.wallet import Wallet as Wallet1 diff --git a/tests/test_mint_fees.py b/tests/test_mint_fees.py new file mode 100644 index 00000000..106d1fbd --- /dev/null +++ b/tests/test_mint_fees.py @@ -0,0 +1,241 @@ +from typing import Optional + +import pytest +import pytest_asyncio + +from cashu.core.helpers import sum_proofs +from cashu.core.models import PostMeltQuoteRequest +from cashu.core.split import amount_split +from cashu.mint.ledger import Ledger +from cashu.wallet.wallet import Wallet +from cashu.wallet.wallet import Wallet as Wallet1 +from tests.conftest import SERVER_ENDPOINT +from tests.helpers import get_real_invoice, is_fake, is_regtest, pay_if_regtest + + +async def assert_err(f, msg): + """Compute f() and expect an error message 'msg'.""" + try: + await f + except Exception as exc: + if msg not in str(exc.args[0]): + raise Exception(f"Expected error: {msg}, got: {exc.args[0]}") + return + raise Exception(f"Expected error: {msg}, got no error") + + +@pytest_asyncio.fixture(scope="function") +async def wallet1(ledger: Ledger): + wallet1 = await Wallet1.with_db( + url=SERVER_ENDPOINT, + db="test_data/wallet1", + name="wallet1", + ) + await wallet1.load_mint() + yield wallet1 + + +def set_ledger_keyset_fees( + fee_ppk: int, ledger: Ledger, wallet: Optional[Wallet] = None +): + for keyset in ledger.keysets.values(): + keyset.input_fee_ppk = fee_ppk + + if wallet: + for wallet_keyset in wallet.keysets.values(): + wallet_keyset.input_fee_ppk = fee_ppk + + +@pytest.mark.asyncio +async def test_get_fees_for_proofs(wallet1: Wallet, ledger: Ledger): + invoice = await wallet1.request_mint(64) + pay_if_regtest(invoice.bolt11) + await wallet1.mint(64, split=[1] * 64, id=invoice.id) + + # two proofs + + set_ledger_keyset_fees(100, ledger) + proofs = [wallet1.proofs[0], wallet1.proofs[1]] + fees = ledger.get_fees_for_proofs(proofs) + assert fees == 1 + + set_ledger_keyset_fees(1234, ledger) + fees = ledger.get_fees_for_proofs(proofs) + assert fees == 3 + + set_ledger_keyset_fees(0, ledger) + fees = ledger.get_fees_for_proofs(proofs) + assert fees == 0 + + set_ledger_keyset_fees(1, ledger) + fees = ledger.get_fees_for_proofs(proofs) + assert fees == 1 + + # ten proofs + + ten_proofs = wallet1.proofs[:10] + set_ledger_keyset_fees(100, ledger) + fees = ledger.get_fees_for_proofs(ten_proofs) + assert fees == 1 + + set_ledger_keyset_fees(101, ledger) + fees = ledger.get_fees_for_proofs(ten_proofs) + assert fees == 2 + + # three proofs + + three_proofs = wallet1.proofs[:3] + set_ledger_keyset_fees(333, ledger) + fees = ledger.get_fees_for_proofs(three_proofs) + assert fees == 1 + + set_ledger_keyset_fees(334, ledger) + fees = ledger.get_fees_for_proofs(three_proofs) + assert fees == 2 + + +@pytest.mark.asyncio +@pytest.mark.skipif_with_fees(is_regtest, reason="only works with FakeWallet") +async def test_wallet_fee(wallet1: Wallet, ledger: Ledger): + # THIS TEST IS A FAKE, WE SET THE WALLET FEES MANUALLY IN set_ledger_keyset_fees + # It would be better to test if the wallet can get the fees from the mint itself + # but the ledger instance does not update the responses from the `mint` that is running in the background + # so we just pretend here and test really nothing... + + # set fees to 100 ppk + set_ledger_keyset_fees(100, ledger, wallet1) + + # check if all wallet keysets have the correct fees + for keyset in wallet1.keysets.values(): + assert keyset.input_fee_ppk == 100 + + +@pytest.mark.asyncio +async def test_split_with_fees(wallet1: Wallet, ledger: Ledger): + # set fees to 100 ppk + set_ledger_keyset_fees(100, ledger) + invoice = await wallet1.request_mint(64) + pay_if_regtest(invoice.bolt11) + await wallet1.mint(64, id=invoice.id) + + send_proofs, _ = await wallet1.select_to_send(wallet1.proofs, 10) + fees = ledger.get_fees_for_proofs(send_proofs) + assert fees == 1 + outputs = await wallet1.construct_outputs(amount_split(9)) + + promises = await ledger.split(proofs=send_proofs, outputs=outputs) + assert len(promises) == len(outputs) + assert [p.amount for p in promises] == [p.amount for p in outputs] + + +@pytest.mark.asyncio +async def test_split_with_high_fees(wallet1: Wallet, ledger: Ledger): + # set fees to 100 ppk + set_ledger_keyset_fees(1234, ledger) + invoice = await wallet1.request_mint(64) + pay_if_regtest(invoice.bolt11) + await wallet1.mint(64, id=invoice.id) + + send_proofs, _ = await wallet1.select_to_send(wallet1.proofs, 10) + fees = ledger.get_fees_for_proofs(send_proofs) + assert fees == 3 + outputs = await wallet1.construct_outputs(amount_split(7)) + + promises = await ledger.split(proofs=send_proofs, outputs=outputs) + assert len(promises) == len(outputs) + assert [p.amount for p in promises] == [p.amount for p in outputs] + + +@pytest.mark.asyncio +async def test_split_not_enough_fees(wallet1: Wallet, ledger: Ledger): + # set fees to 100 ppk + set_ledger_keyset_fees(100, ledger) + invoice = await wallet1.request_mint(64) + pay_if_regtest(invoice.bolt11) + await wallet1.mint(64, id=invoice.id) + + send_proofs, _ = await wallet1.select_to_send(wallet1.proofs, 10) + fees = ledger.get_fees_for_proofs(send_proofs) + assert fees == 1 + # with 10 sat input, we request 10 sat outputs but fees are 1 sat so the swap will fail + outputs = await wallet1.construct_outputs(amount_split(10)) + + await assert_err( + ledger.split(proofs=send_proofs, outputs=outputs), "are not balanced" + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif(is_regtest, reason="only works with FakeWallet") +async def test_melt_internal(wallet1: Wallet, ledger: Ledger): + # set fees to 100 ppk + set_ledger_keyset_fees(100, ledger, wallet1) + + # mint twice so we have enough to pay the second invoice back + invoice = await wallet1.request_mint(128) + await wallet1.mint(128, id=invoice.id) + assert wallet1.balance == 128 + + # create a mint quote so that we can melt to it internally + invoice_to_pay = await wallet1.request_mint(64) + invoice_payment_request = invoice_to_pay.bolt11 + + melt_quote = await ledger.melt_quote( + PostMeltQuoteRequest(request=invoice_payment_request, unit="sat") + ) + assert not melt_quote.paid + assert melt_quote.amount == 64 + assert melt_quote.fee_reserve == 0 + + melt_quote_pre_payment = await ledger.get_melt_quote(melt_quote.quote) + assert not melt_quote_pre_payment.paid, "melt quote should not be paid" + + # let's first try to melt without enough funds + send_proofs, fees = await wallet1.select_to_send(wallet1.proofs, 63) + # this should fail because we need 64 + 1 sat fees + assert sum_proofs(send_proofs) == 64 + await assert_err( + ledger.melt(proofs=send_proofs, quote=melt_quote.quote), + "not enough inputs provided for melt", + ) + + # the wallet respects the fees for coin selection + send_proofs, fees = await wallet1.select_to_send(wallet1.proofs, 64) + # includes 1 sat fees + assert sum_proofs(send_proofs) == 65 + await ledger.melt(proofs=send_proofs, quote=melt_quote.quote) + + melt_quote_post_payment = await ledger.get_melt_quote(melt_quote.quote) + assert melt_quote_post_payment.paid, "melt quote should be paid" + + +@pytest.mark.asyncio +@pytest.mark.skipif(is_fake, reason="only works with Regtest") +async def test_melt_external_with_fees(wallet1: Wallet, ledger: Ledger): + # set fees to 100 ppk + set_ledger_keyset_fees(100, ledger, wallet1) + + # mint twice so we have enough to pay the second invoice back + invoice = await wallet1.request_mint(128) + pay_if_regtest(invoice.bolt11) + await wallet1.mint(128, id=invoice.id) + assert wallet1.balance == 128 + + invoice_dict = get_real_invoice(64) + invoice_payment_request = invoice_dict["payment_request"] + + mint_quote = await wallet1.melt_quote(invoice_payment_request) + total_amount = mint_quote.amount + mint_quote.fee_reserve + send_proofs, fee = await wallet1.select_to_send(wallet1.proofs, total_amount) + melt_quote = await ledger.melt_quote( + PostMeltQuoteRequest(request=invoice_payment_request, unit="sat") + ) + + melt_quote_pre_payment = await ledger.get_melt_quote(melt_quote.quote) + assert not melt_quote_pre_payment.paid, "melt quote should not be paid" + + assert not melt_quote.paid, "melt quote should not be paid" + await ledger.melt(proofs=send_proofs, quote=melt_quote.quote) + + melt_quote_post_payment = await ledger.get_melt_quote(melt_quote.quote) + assert melt_quote_post_payment.paid, "melt quote should be paid" diff --git a/tests/test_mint_lightning_blink.py b/tests/test_mint_lightning_blink.py index 040e5374..f870d87b 100644 --- a/tests/test_mint_lightning_blink.py +++ b/tests/test_mint_lightning_blink.py @@ -2,7 +2,8 @@ import respx from httpx import Response -from cashu.core.base import Amount, MeltQuote, PostMeltQuoteRequest, Unit +from cashu.core.base import Amount, MeltQuote, Unit +from cashu.core.models import PostMeltQuoteRequest from cashu.core.settings import settings from cashu.lightning.blink import MINIMUM_FEE_MSAT, BlinkWallet # type: ignore diff --git a/tests/test_mint_operations.py b/tests/test_mint_operations.py index df773583..f3884fb5 100644 --- a/tests/test_mint_operations.py +++ b/tests/test_mint_operations.py @@ -1,8 +1,8 @@ import pytest import pytest_asyncio -from cashu.core.base import PostMeltQuoteRequest, PostMintQuoteRequest from cashu.core.helpers import sum_proofs +from cashu.core.models import PostMeltQuoteRequest, PostMintQuoteRequest from cashu.mint.ledger import Ledger from cashu.wallet.wallet import Wallet from cashu.wallet.wallet import Wallet as Wallet1 @@ -155,6 +155,18 @@ async def test_split(wallet1: Wallet, ledger: Ledger): assert [p.amount for p in promises] == [p.amount for p in outputs] +@pytest.mark.asyncio +async def test_split_with_no_outputs(wallet1: Wallet, ledger: Ledger): + invoice = await wallet1.request_mint(64) + pay_if_regtest(invoice.bolt11) + await wallet1.mint(64, id=invoice.id) + _, send_proofs = await wallet1.split_to_send(wallet1.proofs, 10, set_reserved=False) + await assert_err( + ledger.split(proofs=send_proofs, outputs=[]), + "no outputs provided", + ) + + @pytest.mark.asyncio async def test_split_with_input_less_than_outputs(wallet1: Wallet, ledger: Ledger): invoice = await wallet1.request_mint(64) @@ -165,19 +177,19 @@ async def test_split_with_input_less_than_outputs(wallet1: Wallet, ledger: Ledge wallet1.proofs, 10, set_reserved=False ) - all_send_proofs = send_proofs + keep_proofs + too_many_proofs = send_proofs + send_proofs - # generate outputs for all proofs, not only the sent ones + # generate more outputs than inputs secrets, rs, derivation_paths = await wallet1.generate_n_secrets( - len(all_send_proofs) + len(too_many_proofs) ) outputs, rs = wallet1._construct_outputs( - [p.amount for p in all_send_proofs], secrets, rs + [p.amount for p in too_many_proofs], secrets, rs ) await assert_err( ledger.split(proofs=send_proofs, outputs=outputs), - "inputs do not have same amount as outputs.", + "are not balanced", ) # make sure we can still spend our tokens @@ -201,7 +213,7 @@ async def test_split_with_input_more_than_outputs(wallet1: Wallet, ledger: Ledge await assert_err( ledger.split(proofs=inputs, outputs=outputs), - "inputs do not have same amount as outputs", + "are not balanced", ) # make sure we can still spend our tokens @@ -216,6 +228,9 @@ async def test_split_twice_with_same_outputs(wallet1: Wallet, ledger: Ledger): inputs1 = wallet1.proofs[:1] inputs2 = wallet1.proofs[1:] + assert inputs1[0].amount == 64 + assert inputs2[0].amount == 64 + output_amounts = [64] secrets, rs, derivation_paths = await wallet1.generate_n_secrets( len(output_amounts) diff --git a/tests/test_wallet.py b/tests/test_wallet.py index 32847edb..1cfc4718 100644 --- a/tests/test_wallet.py +++ b/tests/test_wallet.py @@ -42,14 +42,14 @@ async def assert_err(f, msg: Union[str, CashuError]): def assert_amt(proofs: List[Proof], expected: int): """Assert amounts the proofs contain.""" - assert [p.amount for p in proofs] == expected + assert sum([p.amount for p in proofs]) == expected async def reset_wallet_db(wallet: Wallet): await wallet.db.execute("DELETE FROM proofs") await wallet.db.execute("DELETE FROM proofs_used") await wallet.db.execute("DELETE FROM keysets") - await wallet._load_mint() + await wallet.load_mint() @pytest_asyncio.fixture(scope="function") @@ -97,7 +97,7 @@ async def test_get_keyset(wallet1: Wallet): # gets the keys of a specific keyset assert keyset.id is not None assert keyset.public_keys is not None - keys2 = await wallet1._get_keys_of_keyset(keyset.id) + keys2 = await wallet1._get_keyset(keyset.id) assert keys2.public_keys is not None assert len(keyset.public_keys) == len(keys2.public_keys) @@ -105,12 +105,12 @@ async def test_get_keyset(wallet1: Wallet): @pytest.mark.asyncio async def test_get_keyset_from_db(wallet1: Wallet): # first load it from the mint - # await wallet1._load_mint_keys() + # await wallet1.activate_keyset() # NOTE: conftest already called wallet.load_mint() which got the keys from the mint keyset1 = copy.copy(wallet1.keysets[wallet1.keyset_id]) # then load it from the db - await wallet1._load_mint_keys() + await wallet1.activate_keyset() keyset2 = copy.copy(wallet1.keysets[wallet1.keyset_id]) assert keyset1.public_keys == keyset2.public_keys @@ -133,17 +133,17 @@ async def test_get_info(wallet1: Wallet): @pytest.mark.asyncio async def test_get_nonexistent_keyset(wallet1: Wallet): await assert_err( - wallet1._get_keys_of_keyset("nonexistent"), + wallet1._get_keyset("nonexistent"), KeysetNotFoundError(), ) @pytest.mark.asyncio -async def test_get_keyset_ids(wallet1: Wallet): - keysets = await wallet1._get_keyset_ids() +async def test_get_keysets(wallet1: Wallet): + keysets = await wallet1._get_keysets() assert isinstance(keysets, list) assert len(keysets) > 0 - assert wallet1.keyset_id in keysets + assert wallet1.keyset_id in [k.id for k in keysets] @pytest.mark.asyncio @@ -156,6 +156,7 @@ async def test_request_mint(wallet1: Wallet): async def test_mint(wallet1: Wallet): invoice = await wallet1.request_mint(64) pay_if_regtest(invoice.bolt11) + expected_proof_amounts = wallet1.split_wallet_state(64) await wallet1.mint(64, id=invoice.id) assert wallet1.balance == 64 @@ -168,7 +169,8 @@ async def test_mint(wallet1: Wallet): proofs_minted = await get_proofs( db=wallet1.db, mint_id=invoice_db.id, table="proofs" ) - assert len(proofs_minted) == 1 + assert len(proofs_minted) == len(expected_proof_amounts) + assert all([p.amount in expected_proof_amounts for p in proofs_minted]) assert all([p.mint_id == invoice.id for p in proofs_minted]) @@ -212,11 +214,15 @@ async def test_split(wallet1: Wallet): pay_if_regtest(invoice.bolt11) await wallet1.mint(64, id=invoice.id) assert wallet1.balance == 64 + # the outputs we keep that we expect after the split + expected_proof_amounts = wallet1.split_wallet_state(44) p1, p2 = await wallet1.split(wallet1.proofs, 20) assert wallet1.balance == 64 assert sum_proofs(p1) == 44 - assert [p.amount for p in p1] == [4, 8, 32] + # what we keep should have the expected amounts + assert [p.amount for p in p1] == expected_proof_amounts assert sum_proofs(p2) == 20 + # what we send should be the optimal split assert [p.amount for p in p2] == [4, 16] assert all([p.id == wallet1.keyset_id for p in p1]) assert all([p.id == wallet1.keyset_id for p in p2]) @@ -227,13 +233,19 @@ async def test_split_to_send(wallet1: Wallet): invoice = await wallet1.request_mint(64) pay_if_regtest(invoice.bolt11) await wallet1.mint(64, id=invoice.id) - keep_proofs, spendable_proofs = await wallet1.split_to_send( + assert wallet1.balance == 64 + + # this will select 32 sats and them (nothing to keep) + keep_proofs, send_proofs = await wallet1.split_to_send( wallet1.proofs, 32, set_reserved=True ) - get_spendable = await wallet1._select_proofs_to_send(wallet1.proofs, 32) - assert keep_proofs == get_spendable + assert_amt(send_proofs, 32) + assert_amt(keep_proofs, 0) + spendable_proofs = await wallet1._select_proofs_to_send(wallet1.proofs, 32) assert sum_proofs(spendable_proofs) == 32 + + assert sum_proofs(send_proofs) == 32 assert wallet1.balance == 64 assert wallet1.available_balance == 32 @@ -271,7 +283,7 @@ async def test_melt(wallet1: Wallet): invoice_payment_hash = str(invoice.payment_hash) invoice_payment_request = invoice.bolt11 - quote = await wallet1.request_melt(invoice_payment_request) + quote = await wallet1.melt_quote(invoice_payment_request) total_amount = quote.amount + quote.fee_reserve if is_regtest: @@ -421,7 +433,7 @@ async def test_split_invalid_amount(wallet1: Wallet): await wallet1.mint(64, id=invoice.id) await assert_err( wallet1.split(wallet1.proofs, -1), - "amount must be positive.", + "amount can't be negative", ) @@ -436,13 +448,13 @@ async def test_token_state(wallet1: Wallet): @pytest.mark.asyncio -async def test_load_mint_keys_specific_keyset(wallet1: Wallet): - await wallet1._load_mint_keys() +async def testactivate_keyset_specific_keyset(wallet1: Wallet): + await wallet1.activate_keyset() assert list(wallet1.keysets.keys()) == ["009a1f293253e41e"] - await wallet1._load_mint_keys(keyset_id=wallet1.keyset_id) - await wallet1._load_mint_keys(keyset_id="009a1f293253e41e") + await wallet1.activate_keyset(keyset_id=wallet1.keyset_id) + await wallet1.activate_keyset(keyset_id="009a1f293253e41e") # expect deprecated keyset id to be present await assert_err( - wallet1._load_mint_keys(keyset_id="nonexistent"), - KeysetNotFoundError(), + wallet1.activate_keyset(keyset_id="nonexistent"), + KeysetNotFoundError("nonexistent"), ) diff --git a/tests/test_wallet_api.py b/tests/test_wallet_api.py index 7005948d..14602f2e 100644 --- a/tests/test_wallet_api.py +++ b/tests/test_wallet_api.py @@ -65,16 +65,16 @@ async def test_send(wallet: Wallet): @pytest.mark.asyncio async def test_send_without_split(wallet: Wallet): with TestClient(app) as client: - response = client.post("/send?amount=2&nosplit=true") + response = client.post("/send?amount=2&offline=true") assert response.status_code == 200 assert response.json()["balance"] @pytest.mark.skipif(is_regtest, reason="regtest") @pytest.mark.asyncio -async def test_send_without_split_but_wrong_amount(wallet: Wallet): +async def test_send_too_much(wallet: Wallet): with TestClient(app) as client: - response = client.post("/send?amount=10&nosplit=true") + response = client.post("/send?amount=110000") assert response.status_code == 400 diff --git a/tests/test_wallet_cli.py b/tests/test_wallet_cli.py index 884a9059..eceebe6e 100644 --- a/tests/test_wallet_cli.py +++ b/tests/test_wallet_cli.py @@ -175,6 +175,7 @@ def test_invoice_with_split(mint, cli_prefix): wallet = asyncio.run(init_wallet()) assert wallet.proof_amounts.count(1) >= 10 + @pytest.mark.skipif(not is_fake, reason="only on fakewallet") def test_invoices_with_minting(cli_prefix): # arrange @@ -223,6 +224,7 @@ def test_invoices_without_minting(cli_prefix): assert get_invoice_from_invoices_command(result.output)["ID"] == invoice.id assert get_invoice_from_invoices_command(result.output)["Paid"] == str(invoice.paid) + @pytest.mark.skipif(not is_fake, reason="only on fakewallet") def test_invoices_with_onlypaid_option(cli_prefix): # arrange @@ -263,6 +265,7 @@ def test_invoices_with_onlypaid_option_without_minting(cli_prefix): assert result.exit_code == 0 assert "No invoices found." in result.output + @pytest.mark.skipif(not is_fake, reason="only on fakewallet") def test_invoices_with_onlyunpaid_option(cli_prefix): # arrange @@ -322,6 +325,7 @@ def test_invoices_with_both_onlypaid_and_onlyunpaid_options(cli_prefix): in result.output ) + @pytest.mark.skipif(not is_fake, reason="only on fakewallet") def test_invoices_with_pending_option(cli_prefix): # arrange @@ -422,11 +426,11 @@ def test_send_legacy(mint, cli_prefix): assert token_str.startswith("eyJwcm9v"), "output is not as expected" -def test_send_without_split(mint, cli_prefix): +def test_send_offline(mint, cli_prefix): runner = CliRunner() result = runner.invoke( cli, - [*cli_prefix, "send", "2", "--nosplit"], + [*cli_prefix, "send", "2", "--offline"], ) assert result.exception is None print("SEND") @@ -434,13 +438,13 @@ def test_send_without_split(mint, cli_prefix): assert "cashuA" in result.output, "output does not have a token" -def test_send_without_split_but_wrong_amount(mint, cli_prefix): +def test_send_too_much(mint, cli_prefix): runner = CliRunner() result = runner.invoke( cli, - [*cli_prefix, "send", "10", "--nosplit"], + [*cli_prefix, "send", "100000"], ) - assert "No proof with this amount found" in str(result.exception) + assert "balance too low" in str(result.exception) def test_receive_tokenv3(mint, cli_prefix): diff --git a/tests/test_wallet_lightning.py b/tests/test_wallet_lightning.py index b797f3c5..0e89ac96 100644 --- a/tests/test_wallet_lightning.py +++ b/tests/test_wallet_lightning.py @@ -37,7 +37,7 @@ async def reset_wallet_db(wallet: LightningWallet): await wallet.db.execute("DELETE FROM proofs") await wallet.db.execute("DELETE FROM proofs_used") await wallet.db.execute("DELETE FROM keysets") - await wallet._load_mint() + await wallet.load_mint() @pytest_asyncio.fixture(scope="function") diff --git a/tests/test_wallet_restore.py b/tests/test_wallet_restore.py index 136425b3..4a558d18 100644 --- a/tests/test_wallet_restore.py +++ b/tests/test_wallet_restore.py @@ -42,7 +42,7 @@ async def reset_wallet_db(wallet: Wallet): await wallet.db.execute("DELETE FROM proofs") await wallet.db.execute("DELETE FROM proofs_used") await wallet.db.execute("DELETE FROM keysets") - await wallet._load_mint() + await wallet.load_mint() @pytest_asyncio.fixture(scope="function") @@ -206,7 +206,7 @@ async def test_restore_wallet_after_split_to_send(wallet3: Wallet): wallet3.proofs = [] assert wallet3.balance == 0 await wallet3.restore_promises_from_to(0, 100) - assert wallet3.balance == 64 * 2 + assert wallet3.balance == 96 await wallet3.invalidate(wallet3.proofs, check_spendable=True) assert wallet3.balance == 64 @@ -233,7 +233,7 @@ async def test_restore_wallet_after_send_and_receive(wallet3: Wallet, wallet2: W assert wallet3.proofs == [] assert wallet3.balance == 0 await wallet3.restore_promises_from_to(0, 100) - assert wallet3.balance == 64 + 2 * 32 + assert wallet3.balance == 96 await wallet3.invalidate(wallet3.proofs, check_spendable=True) assert wallet3.balance == 32 @@ -276,7 +276,7 @@ async def test_restore_wallet_after_send_and_self_receive(wallet3: Wallet): assert wallet3.proofs == [] assert wallet3.balance == 0 await wallet3.restore_promises_from_to(0, 100) - assert wallet3.balance == 64 + 2 * 32 + 32 + assert wallet3.balance == 128 await wallet3.invalidate(wallet3.proofs, check_spendable=True) assert wallet3.balance == 64 @@ -311,7 +311,7 @@ async def test_restore_wallet_after_send_twice( assert wallet3.balance == 0 await wallet3.restore_promises_from_to(0, 10) box.add(wallet3.proofs) - assert wallet3.balance == 5 + assert wallet3.balance == 4 await wallet3.invalidate(wallet3.proofs, check_spendable=True) assert wallet3.balance == 2 @@ -333,7 +333,7 @@ async def test_restore_wallet_after_send_twice( assert wallet3.balance == 0 await wallet3.restore_promises_from_to(0, 15) box.add(wallet3.proofs) - assert wallet3.balance == 7 + assert wallet3.balance == 6 await wallet3.invalidate(wallet3.proofs, check_spendable=True) assert wallet3.balance == 2 @@ -370,7 +370,7 @@ async def test_restore_wallet_after_send_and_self_receive_nonquadratic_value( assert wallet3.balance == 0 await wallet3.restore_promises_from_to(0, 20) box.add(wallet3.proofs) - assert wallet3.balance == 138 + assert wallet3.balance == 84 await wallet3.invalidate(wallet3.proofs, check_spendable=True) assert wallet3.balance == 64 @@ -389,6 +389,6 @@ async def test_restore_wallet_after_send_and_self_receive_nonquadratic_value( assert wallet3.proofs == [] assert wallet3.balance == 0 await wallet3.restore_promises_from_to(0, 50) - assert wallet3.balance == 182 + assert wallet3.balance == 108 await wallet3.invalidate(wallet3.proofs, check_spendable=True) assert wallet3.balance == 64