diff --git a/Cargo.toml b/Cargo.toml
index c46c162..2106d9b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -46,6 +46,10 @@ required-features = ["rusqlite", "callbacks"]
 name = "wallet"
 required-features = ["wallet", "trace", "rusqlite", "callbacks"]
 
+[[example]]
+name = "multi"
+required-features = ["wallet", "trace", "rusqlite", "callbacks"]
+
 [[example]]
 name = "events"
 required-features = ["wallet", "rusqlite", "events"]
diff --git a/examples/multi.rs b/examples/multi.rs
new file mode 100644
index 0000000..339b2b0
--- /dev/null
+++ b/examples/multi.rs
@@ -0,0 +1,105 @@
+use std::{
+    collections::HashSet,
+    net::{IpAddr, Ipv4Addr},
+};
+
+use bdk_kyoto::{
+    logger::TraceLogger,
+    multi::{MultiEventReceiver, MultiSyncRequest},
+};
+use bdk_wallet::{KeychainKind, Wallet};
+use kyoto::{HeaderCheckpoint, Network, NodeBuilder, ScriptBuf};
+
+#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
+struct WalletId(u8);
+
+const WALLET_ID_ONE: WalletId = WalletId(1);
+const WALLET_ID_TWO: WalletId = WalletId(2);
+
+const PRIV_RECV: &str = "wpkh(tprv8ZgxMBicQKsPdy6LMhUtFHAgpocR8GC6QmwMSFpZs7h6Eziw3SpThFfczTDh5rW2krkqffa11UpX3XkeTTB2FvzZKWXqPY54Y6Rq4AQ5R8L/84'/1'/0'/0/*)";
+const PRIV_CHANGE: &str = "wpkh(tprv8ZgxMBicQKsPdy6LMhUtFHAgpocR8GC6QmwMSFpZs7h6Eziw3SpThFfczTDh5rW2krkqffa11UpX3XkeTTB2FvzZKWXqPY54Y6Rq4AQ5R8L/84'/1'/0'/1/*)";
+const PUB_RECV: &str = "tr([7d94197e/86'/1'/0']tpubDCyQVJj8KzjiQsFjmb3KwECVXPvMwvAxxZGCP9XmWSopmjW3bCV3wD7TgxrUhiGSueDS1MU5X1Vb1YjYcp8jitXc5fXfdC1z68hDDEyKRNr/0/*)";
+const PUB_CHANGE: &str = "tr([7d94197e/86'/1'/0']tpubDCyQVJj8KzjiQsFjmb3KwECVXPvMwvAxxZGCP9XmWSopmjW3bCV3wD7TgxrUhiGSueDS1MU5X1Vb1YjYcp8jitXc5fXfdC1z68hDDEyKRNr/1/*)";
+
+const NETWORK: Network = Network::Signet;
+
+const PEER: IpAddr = IpAddr::V4(Ipv4Addr::new(23, 137, 57, 100));
+const NUM_PEERS: u8 = 1;
+
+const RECOVERY_HEIGHT: u32 = 170_000;
+
+fn get_scripts_for_wallets(wallets: &[&Wallet]) -> HashSet<ScriptBuf> {
+    let mut spks = HashSet::new();
+    for wallet in wallets {
+        for keychain in [KeychainKind::External, KeychainKind::Internal] {
+            let last_revealed = wallet
+                .spk_index()
+                .last_revealed_index(keychain)
+                .unwrap_or(0);
+            let lookahead_index = last_revealed + wallet.spk_index().lookahead();
+            for index in 0..=lookahead_index {
+                spks.insert(wallet.peek_address(keychain, index).script_pubkey());
+            }
+        }
+    }
+    spks
+}
+
+#[tokio::main]
+async fn main() -> anyhow::Result<()> {
+    let logger = TraceLogger::new()?;
+
+    let mut wallet_one = Wallet::create(PUB_RECV, PUB_CHANGE)
+        .network(NETWORK)
+        .lookahead(30)
+        .create_wallet_no_persist()?;
+
+    let mut wallet_two = Wallet::create(PRIV_RECV, PRIV_CHANGE)
+        .network(NETWORK)
+        .lookahead(30)
+        .create_wallet_no_persist()?;
+
+    let scripts = get_scripts_for_wallets(&[&wallet_one, &wallet_two]);
+
+    let (node, client) = NodeBuilder::new(NETWORK)
+        .add_peer(PEER)
+        .num_required_peers(NUM_PEERS)
+        .add_scripts(scripts)
+        .anchor_checkpoint(HeaderCheckpoint::closest_checkpoint_below_height(
+            RECOVERY_HEIGHT,
+            NETWORK,
+        ))
+        .build_node()?;
+
+    tokio::task::spawn(async move { node.run().await });
+
+    let request_one = MultiSyncRequest {
+        index: WALLET_ID_ONE,
+        checkpoint: wallet_one.local_chain().tip(),
+        spk_index: wallet_one.spk_index().clone(),
+    };
+    let request_two = MultiSyncRequest {
+        index: WALLET_ID_TWO,
+        checkpoint: wallet_two.local_chain().tip(),
+        spk_index: wallet_two.spk_index().clone(),
+    };
+    let requests = vec![request_one, request_two];
+
+    let (sender, receiver) = client.split();
+
+    let mut event_receiver = MultiEventReceiver::from_requests(requests, receiver)?;
+    let updates = event_receiver.updates(&logger).await;
+    for (id, update) in updates {
+        if id.eq(&WALLET_ID_ONE) {
+            wallet_one.apply_update(update)?;
+            let balance = wallet_one.balance().total().to_sat();
+            tracing::info!("Wallet one has {balance} satoshis");
+        } else if id.eq(&WALLET_ID_TWO) {
+            wallet_two.apply_update(update)?;
+            let balance = wallet_two.balance().total().to_sat();
+            tracing::info!("Wallet two has {balance} satoshis");
+        }
+    }
+    sender.shutdown().await?;
+    return Ok(());
+}
diff --git a/src/lib.rs b/src/lib.rs
index d7beb72..8671db3 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -189,6 +189,8 @@ pub use kyoto::{NodeState, Receiver, SyncUpdate, TxBroadcast, TxBroadcastPolicy,
 pub mod builder;
 #[cfg(feature = "callbacks")]
 pub mod logger;
+#[cfg(feature = "wallet")]
+pub mod multi;
 
 #[cfg(feature = "wallet")]
 #[derive(Debug)]
diff --git a/src/multi.rs b/src/multi.rs
new file mode 100644
index 0000000..faae0eb
--- /dev/null
+++ b/src/multi.rs
@@ -0,0 +1,176 @@
+//! Build an event receiver for multiple [`Wallet`](bdk_wallet).
+use core::fmt;
+use std::{
+    collections::{BTreeMap, HashMap},
+    hash::Hash,
+};
+
+use bdk_chain::{
+    bitcoin::FeeRate,
+    keychain_txout::KeychainTxOutIndex,
+    local_chain::{LocalChain, MissingGenesisError},
+    spk_client::FullScanResponse,
+    CheckPoint, ConfirmationBlockTime, IndexedTxGraph, TxUpdate,
+};
+use kyoto::{IndexedBlock, NodeMessage, Receiver, SyncUpdate};
+
+use crate::NodeEventHandler;
+use crate::StringExt;
+
+/// One of potentially multiple sync requests for the [`MultiEventReceiver`]
+/// to handle.
+#[derive(Debug)]
+pub struct MultiSyncRequest<H, K> {
+    /// A unique index to identify the [`Wallet`](bdk_wallet).
+    pub index: H,
+    /// The tip of the chain for this wallet.
+    pub checkpoint: CheckPoint,
+    /// The script pubkeys for this wallet.
+    pub spk_index: KeychainTxOutIndex<K>,
+}
+
+/// Interpret events from a running node to update
+/// multiple wallets in parallel.
+#[derive(Debug)]
+pub struct MultiEventReceiver<H, K> {
+    // channel receiver
+    receiver: kyoto::Receiver<NodeMessage>,
+    // map from wallet identifier to that wallet's local chain and indexed tx graph
+    map: HashMap<
+        H,
+        (
+            LocalChain,
+            IndexedTxGraph<ConfirmationBlockTime, KeychainTxOutIndex<K>>,
+        ),
+    >,
+    // the network minimum fee rate to broadcast a transaction
+    min_broadcast_fee: FeeRate,
+}
+
+impl<H, K> MultiEventReceiver<H, K>
+where
+    H: Hash + Eq + Clone + Copy,
+    K: fmt::Debug + Clone + Ord,
+{
+    /// Build a light client event handler from [`MultiSyncRequest`]s, each containing a [`KeychainTxOutIndex`] and [`CheckPoint`].
+    pub fn from_requests(
+        requests: impl IntoIterator<Item = MultiSyncRequest<H, K>>,
+        receiver: Receiver<NodeMessage>,
+    ) -> Result<Self, MissingGenesisError> {
+        let mut map = HashMap::new();
+        for MultiSyncRequest {
+            index,
+            checkpoint,
+            spk_index,
+        } in requests
+        {
+            map.insert(
+                index,
+                (
+                    LocalChain::from_tip(checkpoint)?,
+                    IndexedTxGraph::new(spk_index),
+                ),
+            );
+        }
+        Ok(Self {
+            receiver,
+            map,
+            min_broadcast_fee: FeeRate::BROADCAST_MIN,
+        })
+    }
+
+    /// Return the most recent update from the node once it has synced to the network's tip.
+    /// This may take a significant amount of time during wallet recoveries or for dormant wallets.
+    /// Note that you may call this method in a loop as long as the node is running.
+    ///
+    /// A reference to a [`NodeEventHandler`] is required, which handles events emitted from a
+    /// running node. Production applications should define how the application handles
+    /// these events and displays them to end users.
+    #[cfg(feature = "callbacks")]
+    pub async fn updates(
+        &mut self,
+        logger: &dyn NodeEventHandler,
+    ) -> impl Iterator<Item = (H, FullScanResponse<K>)> {
+        use bdk_chain::local_chain;
+
+        let mut chain_changeset = BTreeMap::new();
+        while let Ok(message) = self.receiver.recv().await {
+            self.log(&message, logger);
+            match message {
+                NodeMessage::Block(IndexedBlock { height, block }) => {
+                    let hash = block.header.block_hash();
+                    chain_changeset.insert(height, Some(hash));
+                    for (_, graph) in self.map.values_mut() {
+                        let _ = graph.apply_block_relevant(&block, height);
+                    }
+                }
+                NodeMessage::BlocksDisconnected(headers) => {
+                    for header in headers {
+                        let height = header.height;
+                        chain_changeset.insert(height, None);
+                    }
+                }
+                NodeMessage::Synced(SyncUpdate {
+                    tip: _,
+                    recent_history,
+                }) => {
+                    recent_history.into_iter().for_each(|(height, header)| {
+                        chain_changeset.insert(height, Some(header.block_hash()));
+                    });
+                    break;
+                }
+                NodeMessage::FeeFilter(fee_filter) => {
+                    if self.min_broadcast_fee < fee_filter {
+                        self.min_broadcast_fee = fee_filter;
+                    }
+                }
+                _ => (),
+            }
+        }
+        let mut responses = Vec::new();
+        for (index, (local_chain, graph)) in &mut self.map {
+            let tx_update = TxUpdate::from(graph.graph().clone());
+            let last_active_indices = graph.index.last_used_indices();
+            local_chain
+                .apply_changeset(&local_chain::ChangeSet::from(chain_changeset.clone()))
+                .expect("chain was initialized with genesis");
+            let update = FullScanResponse {
+                tx_update,
+                last_active_indices,
+                chain_update: Some(local_chain.tip()),
+            };
+            responses.push((*index, update));
+        }
+        responses.into_iter()
+    }
+
+    #[cfg(feature = "callbacks")]
+    fn log(&self, message: &NodeMessage, logger: &dyn NodeEventHandler) {
+        match message {
+            NodeMessage::Dialog(d) => logger.dialog(d.clone()),
+            NodeMessage::Warning(w) => logger.warning(w.clone()),
+            NodeMessage::StateChange(s) => logger.state_changed(*s),
+            NodeMessage::Block(b) => {
+                let hash = b.block.header.block_hash();
+                logger.dialog(format!("Applying Block: {hash}"));
+            }
+            NodeMessage::Synced(SyncUpdate {
+                tip,
+                recent_history: _,
+            }) => {
+                logger.synced(tip.height);
+            }
+            NodeMessage::BlocksDisconnected(headers) => {
+                logger.blocks_disconnected(headers.iter().map(|dc| dc.height).collect());
+            }
+            NodeMessage::TxSent(t) => {
+                logger.tx_sent(*t);
+            }
+            NodeMessage::TxBroadcastFailure(r) => {
+                logger.tx_failed(r.txid, r.reason.map(|reason| reason.into_string()))
+            }
+            NodeMessage::ConnectionsMet => logger.connections_met(),
+            _ => (),
+        }
+    }
+}
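The doc comment on `updates` notes that it may be called repeatedly while the node keeps running, whereas the new example applies a single round of updates and then shuts the node down. A minimal sketch of the continuous-sync variant, reusing the names from examples/multi.rs above; this is not part of the patch, and the surrounding setup from the example is assumed unchanged:

    // Keep applying updates for as long as the node is running. This loop never
    // calls `sender.shutdown()`, so it only illustrates the repeated-call pattern.
    loop {
        for (id, update) in event_receiver.updates(&logger).await {
            if id == WALLET_ID_ONE {
                wallet_one.apply_update(update)?;
                tracing::info!("Wallet one has {} satoshis", wallet_one.balance().total().to_sat());
            } else if id == WALLET_ID_TWO {
                wallet_two.apply_update(update)?;
                tracing::info!("Wallet two has {} satoshis", wallet_two.balance().total().to_sat());
            }
        }
    }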