diff --git a/cashu/wallet/api/router.py b/cashu/wallet/api/router.py index 0b96fa81..63820b9e 100644 --- a/cashu/wallet/api/router.py +++ b/cashu/wallet/api/router.py @@ -282,9 +282,9 @@ async def burn( wallet = await mint_wallet(mint) if not (all or token or force or delete) or (token and all): raise Exception( - "enter a token or use --all to burn all pending tokens, --force to check" - " all tokensor --delete with send ID to force-delete pending token from" - " list if mint is unavailable.", + "enter a token or use --all to burn all pending tokens, --force to" + " check all tokensor --delete with send ID to force-delete pending" + " token from list if mint is unavailable.", ) if all: # check only those who are flagged as reserved @@ -414,7 +414,7 @@ async def restore( if to < 0: raise Exception("Counter must be positive") await wallet.load_mint() - await wallet.restore_promises(0, to) + await wallet.restore_promises_from_to(0, to) await wallet.invalidate(wallet.proofs) wallet.status() return RestoreResponse(balance=wallet.available_balance) diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index 8926ffa7..69a40b50 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -1655,7 +1655,7 @@ async def restore_wallet_from_mnemonic( n_last_restored_proofs = 0 while stop_counter < to: print(f"Restoring token {i} to {i + batch}...") - restored_proofs = await self.restore_promises(i, i + batch - 1) + restored_proofs = await self.restore_promises_from_to(i, i + batch - 1) if len(restored_proofs) == 0: stop_counter += 1 spendable_proofs = await self.invalidate(restored_proofs) @@ -1679,7 +1679,9 @@ async def restore_wallet_from_mnemonic( print("No tokens restored.") return - async def restore_promises(self, from_counter: int, to_counter: int) -> List[Proof]: + async def restore_promises_from_to( + self, from_counter: int, to_counter: int + ) -> List[Proof]: """Restores promises from a given range of counters. This is for restoring a wallet from a mnemonic. Args: @@ -1698,14 +1700,42 @@ async def restore_promises(self, from_counter: int, to_counter: int) -> List[Pro # we generate outptus from deterministic secrets and rs regenerated_outputs, _ = self._construct_outputs(amounts_dummy, secrets, rs) # we ask the mint to reissue the promises - # restored_outputs is there so we can match the promises to the secrets and rs - restored_outputs, restored_promises = await super().restore_promises( - regenerated_outputs + proofs = await self.restore_promises( + outputs=regenerated_outputs, + secrets=secrets, + rs=rs, + derivation_paths=derivation_paths, + ) + + await set_secret_derivation( + db=self.db, keyset_id=self.keyset_id, counter=to_counter + 1 ) + return proofs + + async def restore_promises( + self, + outputs: List[BlindedMessage], + secrets: List[str], + rs: List[PrivateKey], + derivation_paths: List[str], + ) -> List[Proof]: + """Restores proofs from a list of outputs, secrets, rs and derivation paths. + + Args: + outputs (List[BlindedMessage]): Outputs for which we request promises + secrets (List[str]): Secrets generated for the outputs + rs (List[PrivateKey]): Random blinding factors generated for the outputs + derivation_paths (List[str]): Derivation paths for the secrets + + Returns: + List[Proof]: List of restored proofs + """ + # restored_outputs is there so we can match the promises to the secrets and rs + restored_outputs, restored_promises = await super().restore_promises(outputs) # now we need to filter out the secrets and rs that had a match matching_indices = [ idx - for idx, val in enumerate(regenerated_outputs) + for idx, val in enumerate(outputs) if val.B_ in [o.B_ for o in restored_outputs] ] secrets = [secrets[i] for i in matching_indices] @@ -1721,8 +1751,4 @@ async def restore_promises(self, from_counter: int, to_counter: int) -> List[Pro for proof in proofs: if proof.secret not in [p.secret for p in self.proofs]: self.proofs.append(proof) - - await set_secret_derivation( - db=self.db, keyset_id=self.keyset_id, counter=to_counter + 1 - ) return proofs diff --git a/tests/test_wallet.py b/tests/test_wallet.py index 9396ca52..f4ec78b6 100644 --- a/tests/test_wallet.py +++ b/tests/test_wallet.py @@ -387,7 +387,7 @@ async def test_restore_wallet_after_mint(wallet3: Wallet): await wallet3.load_proofs() wallet3.proofs = [] assert wallet3.balance == 0 - await wallet3.restore_promises(0, 20) + await wallet3.restore_promises_from_to(0, 20) assert wallet3.balance == 64 @@ -419,7 +419,7 @@ async def test_restore_wallet_after_split_to_send(wallet3: Wallet): await wallet3.load_proofs() wallet3.proofs = [] assert wallet3.balance == 0 - await wallet3.restore_promises(0, 100) + await wallet3.restore_promises_from_to(0, 100) assert wallet3.balance == 64 * 2 await wallet3.invalidate(wallet3.proofs) assert wallet3.balance == 64 @@ -443,7 +443,7 @@ async def test_restore_wallet_after_send_and_receive(wallet3: Wallet, wallet2: W await wallet3.load_proofs(reload=True) assert wallet3.proofs == [] assert wallet3.balance == 0 - await wallet3.restore_promises(0, 100) + await wallet3.restore_promises_from_to(0, 100) assert wallet3.balance == 64 + 2 * 32 await wallet3.invalidate(wallet3.proofs) assert wallet3.balance == 32 @@ -482,7 +482,7 @@ async def test_restore_wallet_after_send_and_self_receive(wallet3: Wallet): await wallet3.load_proofs(reload=True) assert wallet3.proofs == [] assert wallet3.balance == 0 - await wallet3.restore_promises(0, 100) + await wallet3.restore_promises_from_to(0, 100) assert wallet3.balance == 64 + 2 * 32 + 32 await wallet3.invalidate(wallet3.proofs) assert wallet3.balance == 64 @@ -512,7 +512,7 @@ async def test_restore_wallet_after_send_twice( await wallet3.load_proofs(reload=True) assert wallet3.proofs == [] assert wallet3.balance == 0 - await wallet3.restore_promises(0, 10) + await wallet3.restore_promises_from_to(0, 10) box.add(wallet3.proofs) assert wallet3.balance == 5 await wallet3.invalidate(wallet3.proofs) @@ -532,7 +532,7 @@ async def test_restore_wallet_after_send_twice( await wallet3.load_proofs(reload=True) assert wallet3.proofs == [] assert wallet3.balance == 0 - await wallet3.restore_promises(0, 15) + await wallet3.restore_promises_from_to(0, 15) box.add(wallet3.proofs) assert wallet3.balance == 7 await wallet3.invalidate(wallet3.proofs) @@ -565,7 +565,7 @@ async def test_restore_wallet_after_send_and_self_receive_nonquadratic_value( await wallet3.load_proofs(reload=True) assert wallet3.proofs == [] assert wallet3.balance == 0 - await wallet3.restore_promises(0, 20) + await wallet3.restore_promises_from_to(0, 20) box.add(wallet3.proofs) assert wallet3.balance == 138 await wallet3.invalidate(wallet3.proofs) @@ -583,7 +583,7 @@ async def test_restore_wallet_after_send_and_self_receive_nonquadratic_value( await wallet3.load_proofs(reload=True) assert wallet3.proofs == [] assert wallet3.balance == 0 - await wallet3.restore_promises(0, 50) + await wallet3.restore_promises_from_to(0, 50) assert wallet3.balance == 182 await wallet3.invalidate(wallet3.proofs) assert wallet3.balance == 64