Skip to content

Commit

Permalink
[FIX] Wallet sort outputs before swapping (#648)
Browse files Browse the repository at this point in the history
* sort proofs

* outputs-ordering

* mypy fix

* clean up

* test if output amounts are sorted

* clean up test

---------

Co-authored-by: callebtc <[email protected]>
  • Loading branch information
lollerfirst and callebtc authored Nov 5, 2024
1 parent 9cdfba5 commit ed0d25d
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 4 deletions.
22 changes: 18 additions & 4 deletions cashu/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def split_wallet_state(self, amount: int) -> List[int]:
# sort by increasing amount
amounts_we_want.sort()

logger.debug(
logger.trace(
f"Amounts we have: {[(a, amounts_we_have.count(a)) for a in set(amounts_we_have)]}"
)
amounts: list[int] = []
Expand All @@ -470,7 +470,7 @@ def split_wallet_state(self, amount: int) -> List[int]:
amounts += amount_split(remaining_amount)
amounts.sort()

logger.debug(f"Amounts we want: {amounts}")
logger.trace(f"Amounts we want: {amounts}")
if sum(amounts) != amount:
raise Exception(f"Amounts do not sum to {amount}.")

Expand Down Expand Up @@ -643,7 +643,7 @@ async def split(
proofs = self.add_witnesses_to_proofs(proofs)

input_fees = self.get_fees_for_proofs(proofs)
logger.debug(f"Input fees: {input_fees}")
logger.trace(f"Input fees: {input_fees}")
# create a suitable amounts to keep and send.
keep_outputs, send_outputs = self.determine_output_amounts(
proofs,
Expand Down Expand Up @@ -674,8 +674,22 @@ async def split(
# potentially add witnesses to outputs based on what requirement the proofs indicate
outputs = self.add_witnesses_to_outputs(proofs, outputs)

# sort outputs by amount, remember original order
sorted_outputs_with_indices = sorted(
enumerate(outputs), key=lambda p: p[1].amount
)
original_indices, sorted_outputs = zip(*sorted_outputs_with_indices)

# Call swap API
promises = await super().split(proofs, outputs)
sorted_promises = await super().split(proofs, sorted_outputs)

# sort promises back to original order
promises = [
promise
for _, promise in sorted(
zip(original_indices, sorted_promises), key=lambda x: x[0]
)
]

# Construct proofs from returned promises (i.e., unblind the signatures)
new_proofs = await self._construct_proofs(
Expand Down
57 changes: 57 additions & 0 deletions tests/test_wallet_requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import json

import pytest
import pytest_asyncio
import respx
from httpx import Request, Response

from cashu.core.base import BlindedSignature
from cashu.core.crypto.b_dhke import hash_to_curve
from cashu.wallet.wallet import Wallet
from cashu.wallet.wallet import Wallet as Wallet1
from tests.conftest import SERVER_ENDPOINT
from tests.helpers import pay_if_regtest


@pytest_asyncio.fixture(scope="function")
async def wallet1(mint):
wallet1 = await Wallet1.with_db(
url=SERVER_ENDPOINT,
db="test_data/wallet1",
name="wallet1",
)
await wallet1.load_mint()
yield wallet1


@pytest.mark.asyncio
async def test_swap_outputs_are_sorted(wallet1: Wallet):
await wallet1.load_mint()
mint_quote = await wallet1.request_mint(16)
await pay_if_regtest(mint_quote.request)
await wallet1.mint(16, quote_id=mint_quote.quote, split=[16])
assert wallet1.balance == 16

test_url = f"{wallet1.url}/v1/swap"
key = hash_to_curve("test".encode("utf-8"))
mock_blind_signature = BlindedSignature(
id=wallet1.keyset_id,
amount=8,
C_=key.serialize().hex(),
)
mock_response_data = {"signatures": [mock_blind_signature.dict()]}
with respx.mock() as mock:
route = mock.post(test_url).mock(
return_value=Response(200, json=mock_response_data)
)
await wallet1.select_to_send(wallet1.proofs, 5)

assert route.called
assert route.call_count == 1
request: Request = route.calls[0].request
assert request.method == "POST"
assert request.url == test_url
request_data = json.loads(request.content.decode("utf-8"))
output_amounts = [o["amount"] for o in request_data["outputs"]]
# assert that output amounts are sorted
assert output_amounts == sorted(output_amounts)

0 comments on commit ed0d25d

Please sign in to comment.