Skip to content

Commit

Permalink
[Wallet] Refactor restore_promises_from_to (#307)
Browse files Browse the repository at this point in the history
* refactor restore_promises_from_to

* fix mypy

* black

* fix tests
  • Loading branch information
callebtc authored Aug 25, 2023
1 parent e374d32 commit f551624
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 22 deletions.
8 changes: 4 additions & 4 deletions cashu/wallet/api/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ async def burn(
wallet = await mint_wallet(mint)
if not (all or token or force or delete) or (token and all):
raise Exception(
"enter a token or use --all to burn all pending tokens, --force to check"
" all tokensor --delete with send ID to force-delete pending token from"
" list if mint is unavailable.",
"enter a token or use --all to burn all pending tokens, --force to"
" check all tokensor --delete with send ID to force-delete pending"
" token from list if mint is unavailable.",
)
if all:
# check only those who are flagged as reserved
Expand Down Expand Up @@ -414,7 +414,7 @@ async def restore(
if to < 0:
raise Exception("Counter must be positive")
await wallet.load_mint()
await wallet.restore_promises(0, to)
await wallet.restore_promises_from_to(0, to)
await wallet.invalidate(wallet.proofs)
wallet.status()
return RestoreResponse(balance=wallet.available_balance)
Expand Down
46 changes: 36 additions & 10 deletions cashu/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,7 +1655,7 @@ async def restore_wallet_from_mnemonic(
n_last_restored_proofs = 0
while stop_counter < to:
print(f"Restoring token {i} to {i + batch}...")
restored_proofs = await self.restore_promises(i, i + batch - 1)
restored_proofs = await self.restore_promises_from_to(i, i + batch - 1)
if len(restored_proofs) == 0:
stop_counter += 1
spendable_proofs = await self.invalidate(restored_proofs)
Expand All @@ -1679,7 +1679,9 @@ async def restore_wallet_from_mnemonic(
print("No tokens restored.")
return

async def restore_promises(self, from_counter: int, to_counter: int) -> List[Proof]:
async def restore_promises_from_to(
self, from_counter: int, to_counter: int
) -> List[Proof]:
"""Restores promises from a given range of counters. This is for restoring a wallet from a mnemonic.
Args:
Expand All @@ -1698,14 +1700,42 @@ async def restore_promises(self, from_counter: int, to_counter: int) -> List[Pro
# we generate outptus from deterministic secrets and rs
regenerated_outputs, _ = self._construct_outputs(amounts_dummy, secrets, rs)
# we ask the mint to reissue the promises
# restored_outputs is there so we can match the promises to the secrets and rs
restored_outputs, restored_promises = await super().restore_promises(
regenerated_outputs
proofs = await self.restore_promises(
outputs=regenerated_outputs,
secrets=secrets,
rs=rs,
derivation_paths=derivation_paths,
)

await set_secret_derivation(
db=self.db, keyset_id=self.keyset_id, counter=to_counter + 1
)
return proofs

async def restore_promises(
self,
outputs: List[BlindedMessage],
secrets: List[str],
rs: List[PrivateKey],
derivation_paths: List[str],
) -> List[Proof]:
"""Restores proofs from a list of outputs, secrets, rs and derivation paths.
Args:
outputs (List[BlindedMessage]): Outputs for which we request promises
secrets (List[str]): Secrets generated for the outputs
rs (List[PrivateKey]): Random blinding factors generated for the outputs
derivation_paths (List[str]): Derivation paths for the secrets
Returns:
List[Proof]: List of restored proofs
"""
# restored_outputs is there so we can match the promises to the secrets and rs
restored_outputs, restored_promises = await super().restore_promises(outputs)
# now we need to filter out the secrets and rs that had a match
matching_indices = [
idx
for idx, val in enumerate(regenerated_outputs)
for idx, val in enumerate(outputs)
if val.B_ in [o.B_ for o in restored_outputs]
]
secrets = [secrets[i] for i in matching_indices]
Expand All @@ -1721,8 +1751,4 @@ async def restore_promises(self, from_counter: int, to_counter: int) -> List[Pro
for proof in proofs:
if proof.secret not in [p.secret for p in self.proofs]:
self.proofs.append(proof)

await set_secret_derivation(
db=self.db, keyset_id=self.keyset_id, counter=to_counter + 1
)
return proofs
16 changes: 8 additions & 8 deletions tests/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ async def test_restore_wallet_after_mint(wallet3: Wallet):
await wallet3.load_proofs()
wallet3.proofs = []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 20)
await wallet3.restore_promises_from_to(0, 20)
assert wallet3.balance == 64


Expand Down Expand Up @@ -419,7 +419,7 @@ async def test_restore_wallet_after_split_to_send(wallet3: Wallet):
await wallet3.load_proofs()
wallet3.proofs = []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 100)
await wallet3.restore_promises_from_to(0, 100)
assert wallet3.balance == 64 * 2
await wallet3.invalidate(wallet3.proofs)
assert wallet3.balance == 64
Expand All @@ -443,7 +443,7 @@ async def test_restore_wallet_after_send_and_receive(wallet3: Wallet, wallet2: W
await wallet3.load_proofs(reload=True)
assert wallet3.proofs == []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 100)
await wallet3.restore_promises_from_to(0, 100)
assert wallet3.balance == 64 + 2 * 32
await wallet3.invalidate(wallet3.proofs)
assert wallet3.balance == 32
Expand Down Expand Up @@ -482,7 +482,7 @@ async def test_restore_wallet_after_send_and_self_receive(wallet3: Wallet):
await wallet3.load_proofs(reload=True)
assert wallet3.proofs == []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 100)
await wallet3.restore_promises_from_to(0, 100)
assert wallet3.balance == 64 + 2 * 32 + 32
await wallet3.invalidate(wallet3.proofs)
assert wallet3.balance == 64
Expand Down Expand Up @@ -512,7 +512,7 @@ async def test_restore_wallet_after_send_twice(
await wallet3.load_proofs(reload=True)
assert wallet3.proofs == []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 10)
await wallet3.restore_promises_from_to(0, 10)
box.add(wallet3.proofs)
assert wallet3.balance == 5
await wallet3.invalidate(wallet3.proofs)
Expand All @@ -532,7 +532,7 @@ async def test_restore_wallet_after_send_twice(
await wallet3.load_proofs(reload=True)
assert wallet3.proofs == []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 15)
await wallet3.restore_promises_from_to(0, 15)
box.add(wallet3.proofs)
assert wallet3.balance == 7
await wallet3.invalidate(wallet3.proofs)
Expand Down Expand Up @@ -565,7 +565,7 @@ async def test_restore_wallet_after_send_and_self_receive_nonquadratic_value(
await wallet3.load_proofs(reload=True)
assert wallet3.proofs == []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 20)
await wallet3.restore_promises_from_to(0, 20)
box.add(wallet3.proofs)
assert wallet3.balance == 138
await wallet3.invalidate(wallet3.proofs)
Expand All @@ -583,7 +583,7 @@ async def test_restore_wallet_after_send_and_self_receive_nonquadratic_value(
await wallet3.load_proofs(reload=True)
assert wallet3.proofs == []
assert wallet3.balance == 0
await wallet3.restore_promises(0, 50)
await wallet3.restore_promises_from_to(0, 50)
assert wallet3.balance == 182
await wallet3.invalidate(wallet3.proofs)
assert wallet3.balance == 64

0 comments on commit f551624

Please sign in to comment.