From 5271c8996188f5ef113271544f034376ac9be250 Mon Sep 17 00:00:00 2001 From: Elias Rohrer Date: Mon, 30 Oct 2023 09:28:29 +0100 Subject: [PATCH] PREVIEW: Reformat project This commit is just included to give a preview for what the reformatted code will look like. --- lightning-background-processor/src/lib.rs | 1086 ++- lightning-block-sync/src/convert.rs | 297 +- lightning-block-sync/src/gossip.rs | 160 +- lightning-block-sync/src/http.rs | 154 +- lightning-block-sync/src/init.rs | 54 +- lightning-block-sync/src/lib.rs | 157 +- lightning-block-sync/src/poll.rs | 110 +- lightning-block-sync/src/rest.rs | 52 +- lightning-block-sync/src/rpc.rs | 69 +- lightning-block-sync/src/test_utils.rs | 29 +- lightning-invoice/src/de.rs | 482 +- lightning-invoice/src/lib.rs | 677 +- lightning-invoice/src/payment.rs | 160 +- lightning-invoice/src/ser.rs | 118 +- lightning-invoice/src/sync.rs | 2 +- lightning-invoice/src/tb.rs | 2 +- lightning-invoice/src/utils.rs | 1320 ++- lightning-invoice/tests/ser_de.rs | 39 +- lightning-net-tokio/src/lib.rs | 429 +- lightning-persister/src/fs_store.rs | 177 +- lightning-persister/src/lib.rs | 5 +- lightning-persister/src/test_utils.rs | 87 +- lightning-persister/src/utils.rs | 49 +- lightning-rapid-gossip-sync/src/error.rs | 4 +- lightning-rapid-gossip-sync/src/lib.rs | 43 +- lightning-rapid-gossip-sync/src/processing.rs | 158 +- lightning/src/blinded_path/message.rs | 32 +- lightning/src/blinded_path/mod.rs | 73 +- lightning/src/blinded_path/payment.rs | 353 +- lightning/src/blinded_path/utils.rs | 55 +- lightning/src/chain/chaininterface.rs | 32 +- lightning/src/chain/chainmonitor.rs | 664 +- lightning/src/chain/channelmonitor.rs | 3126 +++++-- lightning/src/chain/mod.rs | 38 +- lightning/src/chain/onchaintx.rs | 548 +- lightning/src/chain/package.rs | 884 +- lightning/src/chain/transaction.rs | 29 +- lightning/src/events/bump_transaction.rs | 375 +- lightning/src/events/mod.rs | 265 +- lightning/src/lib.rs | 43 +- lightning/src/ln/blinded_payment_tests.rs | 83 +- lightning/src/ln/chan_utils.rs | 1434 ++- lightning/src/ln/chanmon_update_fail_tests.rs | 2732 ++++-- lightning/src/ln/channel.rs | 6272 +++++++++---- lightning/src/ln/channel_id.rs | 18 +- lightning/src/ln/channelmanager.rs | 7827 +++++++++++----- lightning/src/ln/features.rs | 327 +- lightning/src/ln/functional_test_utils.rs | 2915 ++++-- lightning/src/ln/functional_tests.rs | 7844 +++++++++++++---- lightning/src/ln/inbound_payment.rs | 187 +- lightning/src/ln/mod.rs | 26 +- lightning/src/ln/monitor_tests.rs | 2669 ++++-- lightning/src/ln/msgs.rs | 1486 +++- lightning/src/ln/onion_route_tests.rs | 1821 +++- lightning/src/ln/onion_utils.rs | 631 +- lightning/src/ln/outbound_payment.rs | 1974 +++-- lightning/src/ln/payment_tests.rs | 3298 +++++-- lightning/src/ln/peer_channel_encryptor.rs | 636 +- lightning/src/ln/peer_handler.rs | 1800 ++-- lightning/src/ln/priv_short_conf_tests.rs | 897 +- lightning/src/ln/reload_tests.rs | 880 +- lightning/src/ln/reorg_tests.rs | 308 +- lightning/src/ln/script.rs | 40 +- lightning/src/ln/shutdown_tests.rs | 1198 ++- lightning/src/ln/wire.rs | 306 +- lightning/src/offers/invoice.rs | 1004 ++- lightning/src/offers/invoice_error.rs | 25 +- lightning/src/offers/invoice_request.rs | 974 +- lightning/src/offers/merkle.rs | 102 +- lightning/src/offers/offer.rs | 337 +- lightning/src/offers/parse.rs | 13 +- lightning/src/offers/refund.rs | 380 +- lightning/src/offers/signer.rs | 64 +- lightning/src/offers/test_utils.rs | 12 +- .../src/onion_message/functional_tests.rs | 243 +- lightning/src/onion_message/messenger.rs | 369 +- lightning/src/onion_message/mod.rs | 14 +- lightning/src/onion_message/offers.rs | 20 +- lightning/src/onion_message/packet.rs | 68 +- lightning/src/routing/gossip.rs | 1551 ++-- lightning/src/routing/mod.rs | 2 +- lightning/src/routing/router.rs | 5294 +++++++---- lightning/src/routing/scoring.rs | 2310 +++-- lightning/src/routing/test_utils.rs | 352 +- lightning/src/routing/utxo.rs | 756 +- lightning/src/sign/mod.rs | 1049 ++- lightning/src/sign/type_resolver.rs | 12 +- lightning/src/sync/debug_sync.rs | 65 +- lightning/src/sync/fairrwlock.rs | 4 +- lightning/src/sync/mod.rs | 21 +- lightning/src/sync/nostd_sync.rs | 32 +- lightning/src/sync/test_lockorder_checks.rs | 2 +- lightning/src/util/atomic_counter.rs | 9 +- lightning/src/util/base32.rs | 38 +- lightning/src/util/byte_utils.rs | 24 +- lightning/src/util/chacha20.rs | 518 +- lightning/src/util/chacha20poly1305rfc.rs | 67 +- lightning/src/util/config.rs | 23 +- lightning/src/util/crypto.rs | 17 +- lightning/src/util/errors.rs | 28 +- lightning/src/util/fuzz_wrappers.rs | 22 +- lightning/src/util/indexed_map.rs | 53 +- lightning/src/util/invoice.rs | 9 +- lightning/src/util/logger.rs | 17 +- lightning/src/util/macro_logger.rs | 77 +- lightning/src/util/message_signing.rs | 21 +- lightning/src/util/mod.rs | 19 +- lightning/src/util/persist.rs | 449 +- lightning/src/util/poly1305.rs | 236 +- lightning/src/util/scid_utils.rs | 104 +- lightning/src/util/ser.rs | 188 +- lightning/src/util/ser_macros.rs | 289 +- lightning/src/util/string.rs | 6 +- lightning/src/util/test_channel_signer.rs | 231 +- lightning/src/util/test_utils.rs | 592 +- lightning/src/util/time.rs | 26 +- lightning/src/util/transaction_utils.rs | 185 +- lightning/src/util/wakers.rs | 138 +- 118 files changed, 55924 insertions(+), 22583 deletions(-) diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs index aa6d0b0615e..41ba0633a01 100644 --- a/lightning-background-processor/src/lib.rs +++ b/lightning-background-processor/src/lib.rs @@ -5,12 +5,9 @@ // Prefix these with `rustdoc::` when we update our MSRV to be >= 1.52 to remove warnings. #![deny(broken_intra_doc_links)] #![deny(private_intra_doc_links)] - #![deny(missing_docs)] #![cfg_attr(not(feature = "futures"), deny(unsafe_code))] - #![cfg_attr(docsrs, feature(doc_auto_cfg))] - #![cfg_attr(all(not(feature = "std"), not(test)), no_std)] #[cfg(any(test, feature = "std"))] @@ -19,22 +16,23 @@ extern crate core; #[cfg(not(feature = "std"))] extern crate alloc; -#[macro_use] extern crate lightning; +#[macro_use] +extern crate lightning; extern crate lightning_rapid_gossip_sync; use lightning::chain; use lightning::chain::chaininterface::{BroadcasterInterface, FeeEstimator}; use lightning::chain::chainmonitor::{ChainMonitor, Persist}; -use lightning::sign::{EntropySource, NodeSigner, SignerProvider}; use lightning::events::{Event, PathFailure}; #[cfg(feature = "std")] use lightning::events::{EventHandler, EventsProvider}; use lightning::ln::channelmanager::ChannelManager; use lightning::ln::peer_handler::APeerManager; use lightning::routing::gossip::{NetworkGraph, P2PGossipSync}; -use lightning::routing::utxo::UtxoLookup; use lightning::routing::router::Router; use lightning::routing::scoring::{ScoreUpdate, WriteableScore}; +use lightning::routing::utxo::UtxoLookup; +use lightning::sign::{EntropySource, NodeSigner, SignerProvider}; use lightning::util::logger::Logger; use lightning::util::persist::Persister; #[cfg(feature = "std")] @@ -44,11 +42,11 @@ use lightning_rapid_gossip_sync::RapidGossipSync; use core::ops::Deref; use core::time::Duration; -#[cfg(feature = "std")] -use std::sync::Arc; #[cfg(feature = "std")] use core::sync::atomic::{AtomicBool, Ordering}; #[cfg(feature = "std")] +use std::sync::Arc; +#[cfg(feature = "std")] use std::thread::{self, JoinHandle}; #[cfg(feature = "std")] use std::time::Instant; @@ -124,10 +122,18 @@ const REBROADCAST_TIMER: u64 = 1; #[cfg(feature = "futures")] /// core::cmp::min is not currently const, so we define a trivial (and equivalent) replacement -const fn min_u64(a: u64, b: u64) -> u64 { if a < b { a } else { b } } +const fn min_u64(a: u64, b: u64) -> u64 { + if a < b { + a + } else { + b + } +} #[cfg(feature = "futures")] -const FASTEST_TIMER: u64 = min_u64(min_u64(FRESHNESS_TIMER, PING_TIMER), - min_u64(SCORER_PERSIST_TIMER, min_u64(FIRST_NETWORK_PRUNE_TIMER, REBROADCAST_TIMER))); +const FASTEST_TIMER: u64 = min_u64( + min_u64(FRESHNESS_TIMER, PING_TIMER), + min_u64(SCORER_PERSIST_TIMER, min_u64(FIRST_NETWORK_PRUNE_TIMER, REBROADCAST_TIMER)), +); /// Either [`P2PGossipSync`] or [`RapidGossipSync`]. pub enum GossipSync< @@ -136,8 +142,10 @@ pub enum GossipSync< G: Deref>, U: Deref, L: Deref, -> -where U::Target: UtxoLookup, L::Target: Logger { +> where + U::Target: UtxoLookup, + L::Target: Logger, +{ /// Gossip sync via the lightning peer-to-peer network as defined by BOLT 7. P2P(P), /// Rapid gossip sync from a trusted server. @@ -147,13 +155,16 @@ where U::Target: UtxoLookup, L::Target: Logger { } impl< - P: Deref>, - R: Deref>, - G: Deref>, - U: Deref, - L: Deref, -> GossipSync -where U::Target: UtxoLookup, L::Target: Logger { + P: Deref>, + R: Deref>, + G: Deref>, + U: Deref, + L: Deref, + > GossipSync +where + U::Target: UtxoLookup, + L::Target: Logger, +{ fn network_graph(&self) -> Option<&G> { match self { GossipSync::P2P(gossip_sync) => Some(gossip_sync.network_graph()), @@ -178,8 +189,12 @@ where U::Target: UtxoLookup, L::Target: Logger { } /// This is not exported to bindings users as the bindings concretize everything and have constructors for us -impl>, G: Deref>, U: Deref, L: Deref> - GossipSync, G, U, L> +impl< + P: Deref>, + G: Deref>, + U: Deref, + L: Deref, + > GossipSync, G, U, L> where U::Target: UtxoLookup, L::Target: Logger, @@ -191,15 +206,19 @@ where } /// This is not exported to bindings users as the bindings concretize everything and have constructors for us -impl<'a, R: Deref>, G: Deref>, L: Deref> +impl< + 'a, + R: Deref>, + G: Deref>, + L: Deref, + > GossipSync< &P2PGossipSync, R, G, &'a (dyn UtxoLookup + Send + Sync), L, - > -where + > where L::Target: Logger, { /// Initializes a new [`GossipSync::Rapid`] variant. @@ -216,8 +235,7 @@ impl<'a, L: Deref> &'a NetworkGraph, &'a (dyn UtxoLookup + Send + Sync), L, - > -where + > where L::Target: Logger, { /// Initializes a new [`GossipSync::None`] variant. @@ -226,11 +244,14 @@ where } } -fn handle_network_graph_update( - network_graph: &NetworkGraph, event: &Event -) where L::Target: Logger { +fn handle_network_graph_update(network_graph: &NetworkGraph, event: &Event) +where + L::Target: Logger, +{ if let Event::PaymentPathFailed { - failure: PathFailure::OnPath { network_update: Some(ref upd) }, .. } = event + failure: PathFailure::OnPath { network_update: Some(ref upd) }, + .. + } = event { network_graph.handle_network_update(upd); } @@ -239,7 +260,7 @@ fn handle_network_graph_update( /// Updates scorer based on event and returns whether an update occurred so we can decide whether /// to persist. fn update_scorer<'a, S: 'static + Deref + Send + Sync, SC: 'a + WriteableScore<'a>>( - scorer: &'a S, event: &Event + scorer: &'a S, event: &Event, ) -> bool { match event { Event::PaymentPathFailed { ref path, short_channel_id: Some(scid), .. } => { @@ -428,35 +449,50 @@ macro_rules! define_run_body { #[cfg(feature = "futures")] pub(crate) mod futures_util { use core::future::Future; - use core::task::{Poll, Waker, RawWaker, RawWakerVTable}; - use core::pin::Pin; use core::marker::Unpin; + use core::pin::Pin; + use core::task::{Poll, RawWaker, RawWakerVTable, Waker}; pub(crate) struct Selector< - A: Future + Unpin, B: Future + Unpin, C: Future + Unpin + A: Future + Unpin, + B: Future + Unpin, + C: Future + Unpin, > { pub a: A, pub b: B, pub c: C, } pub(crate) enum SelectorOutput { - A, B, C(bool), + A, + B, + C(bool), } impl< - A: Future + Unpin, B: Future + Unpin, C: Future + Unpin - > Future for Selector { + A: Future + Unpin, + B: Future + Unpin, + C: Future + Unpin, + > Future for Selector + { type Output = SelectorOutput; - fn poll(mut self: Pin<&mut Self>, ctx: &mut core::task::Context<'_>) -> Poll { + fn poll( + mut self: Pin<&mut Self>, ctx: &mut core::task::Context<'_>, + ) -> Poll { match Pin::new(&mut self.a).poll(ctx) { - Poll::Ready(()) => { return Poll::Ready(SelectorOutput::A); }, + Poll::Ready(()) => { + return Poll::Ready(SelectorOutput::A); + }, Poll::Pending => {}, } match Pin::new(&mut self.b).poll(ctx) { - Poll::Ready(()) => { return Poll::Ready(SelectorOutput::B); }, + Poll::Ready(()) => { + return Poll::Ready(SelectorOutput::B); + }, Poll::Pending => {}, } match Pin::new(&mut self.c).poll(ctx) { - Poll::Ready(res) => { return Poll::Ready(SelectorOutput::C(res)); }, + Poll::Ready(res) => { + return Poll::Ready(SelectorOutput::C(res)); + }, Poll::Pending => {}, } Poll::Pending @@ -466,17 +502,25 @@ pub(crate) mod futures_util { // If we want to poll a future without an async context to figure out if it has completed or // not without awaiting, we need a Waker, which needs a vtable...we fill it with dummy values // but sadly there's a good bit of boilerplate here. - fn dummy_waker_clone(_: *const ()) -> RawWaker { RawWaker::new(core::ptr::null(), &DUMMY_WAKER_VTABLE) } - fn dummy_waker_action(_: *const ()) { } + fn dummy_waker_clone(_: *const ()) -> RawWaker { + RawWaker::new(core::ptr::null(), &DUMMY_WAKER_VTABLE) + } + fn dummy_waker_action(_: *const ()) {} const DUMMY_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new( - dummy_waker_clone, dummy_waker_action, dummy_waker_action, dummy_waker_action); - pub(crate) fn dummy_waker() -> Waker { unsafe { Waker::from_raw(RawWaker::new(core::ptr::null(), &DUMMY_WAKER_VTABLE)) } } + dummy_waker_clone, + dummy_waker_action, + dummy_waker_action, + dummy_waker_action, + ); + pub(crate) fn dummy_waker() -> Waker { + unsafe { Waker::from_raw(RawWaker::new(core::ptr::null(), &DUMMY_WAKER_VTABLE)) } + } } #[cfg(feature = "futures")] -use futures_util::{Selector, SelectorOutput, dummy_waker}; -#[cfg(feature = "futures")] use core::task; +#[cfg(feature = "futures")] +use futures_util::{dummy_waker, Selector, SelectorOutput}; /// Processes background events in a future. /// @@ -599,7 +643,10 @@ pub async fn process_events_async< EventHandlerFuture: core::future::Future, EventHandler: Fn(Event) -> EventHandlerFuture, PS: 'static + Deref + Send, - M: 'static + Deref::Signer, CF, T, F, L, P>> + Send + Sync, + M: 'static + + Deref::Signer, CF, T, F, L, P>> + + Send + + Sync, CM: 'static + Deref> + Send + Sync, PGS: 'static + Deref> + Send + Sync, RGS: 'static + Deref> + Send, @@ -608,7 +655,7 @@ pub async fn process_events_async< S: 'static + Deref + Send + Sync, SC: for<'b> WriteableScore<'b>, SleepFuture: core::future::Future + core::marker::Unpin, - Sleeper: Fn(Duration) -> SleepFuture + Sleeper: Fn(Duration) -> SleepFuture, >( persister: PS, event_handler: EventHandler, chain_monitor: M, channel_manager: CM, gossip_sync: GossipSync, peer_manager: PM, logger: L, scorer: Option, @@ -629,51 +676,70 @@ where PS::Target: 'static + Persister<'a, CW, T, ES, NS, SP, F, R, L, SC>, { let mut should_break = false; - let async_event_handler = |event| { - let network_graph = gossip_sync.network_graph(); - let event_handler = &event_handler; - let scorer = &scorer; - let logger = &logger; - let persister = &persister; - async move { - if let Some(network_graph) = network_graph { - handle_network_graph_update(network_graph, &event) - } - if let Some(ref scorer) = scorer { - if update_scorer(scorer, &event) { - log_trace!(logger, "Persisting scorer after update"); - if let Err(e) = persister.persist_scorer(&scorer) { - log_error!(logger, "Error: Failed to persist scorer, check your disk and permissions {}", e) + let async_event_handler = + |event| { + let network_graph = gossip_sync.network_graph(); + let event_handler = &event_handler; + let scorer = &scorer; + let logger = &logger; + let persister = &persister; + async move { + if let Some(network_graph) = network_graph { + handle_network_graph_update(network_graph, &event) + } + if let Some(ref scorer) = scorer { + if update_scorer(scorer, &event) { + log_trace!(logger, "Persisting scorer after update"); + if let Err(e) = persister.persist_scorer(&scorer) { + log_error!(logger, "Error: Failed to persist scorer, check your disk and permissions {}", e) + } } } + event_handler(event).await; } - event_handler(event).await; - } - }; - define_run_body!(persister, - chain_monitor, chain_monitor.process_pending_events_async(async_event_handler).await, - channel_manager, channel_manager.process_pending_events_async(async_event_handler).await, - gossip_sync, peer_manager, logger, scorer, should_break, { + }; + define_run_body!( + persister, + chain_monitor, + chain_monitor.process_pending_events_async(async_event_handler).await, + channel_manager, + channel_manager.process_pending_events_async(async_event_handler).await, + gossip_sync, + peer_manager, + logger, + scorer, + should_break, + { let fut = Selector { a: channel_manager.get_event_or_persistence_needed_future(), b: chain_monitor.get_update_future(), - c: sleeper(if mobile_interruptable_platform { Duration::from_millis(100) } else { Duration::from_secs(FASTEST_TIMER) }), + c: sleeper(if mobile_interruptable_platform { + Duration::from_millis(100) + } else { + Duration::from_secs(FASTEST_TIMER) + }), }; match fut.await { - SelectorOutput::A|SelectorOutput::B => {}, + SelectorOutput::A | SelectorOutput::B => {}, SelectorOutput::C(exit) => { should_break = exit; - } + }, } - }, |t| sleeper(Duration::from_secs(t)), + }, + |t| sleeper(Duration::from_secs(t)), |fut: &mut SleepFuture, _| { let mut waker = dummy_waker(); let mut ctx = task::Context::from_waker(&mut waker); match core::pin::Pin::new(fut).poll(&mut ctx) { - task::Poll::Ready(exit) => { should_break = exit; true }, + task::Poll::Ready(exit) => { + should_break = exit; + true + }, task::Poll::Pending => false, } - }, mobile_interruptable_platform) + }, + mobile_interruptable_platform + ) } #[cfg(feature = "std")] @@ -738,17 +804,21 @@ impl BackgroundProcessor { P: 'static + Deref + Send + Sync, EH: 'static + EventHandler + Send, PS: 'static + Deref + Send, - M: 'static + Deref::Signer, CF, T, F, L, P>> + Send + Sync, + M: 'static + + Deref::Signer, CF, T, F, L, P>> + + Send + + Sync, CM: 'static + Deref> + Send + Sync, PGS: 'static + Deref> + Send + Sync, RGS: 'static + Deref> + Send, APM: APeerManager + Send + Sync, PM: 'static + Deref + Send + Sync, S: 'static + Deref + Send + Sync, - SC: for <'b> WriteableScore<'b>, + SC: for<'b> WriteableScore<'b>, >( persister: PS, event_handler: EH, chain_monitor: M, channel_manager: CM, - gossip_sync: GossipSync, peer_manager: PM, logger: L, scorer: Option, + gossip_sync: GossipSync, peer_manager: PM, logger: L, + scorer: Option, ) -> Self where UL::Target: 'static + UtxoLookup, @@ -782,14 +852,28 @@ impl BackgroundProcessor { } event_handler.handle_event(event); }; - define_run_body!(persister, chain_monitor, chain_monitor.process_pending_events(&event_handler), - channel_manager, channel_manager.process_pending_events(&event_handler), - gossip_sync, peer_manager, logger, scorer, stop_thread.load(Ordering::Acquire), - { Sleeper::from_two_futures( - channel_manager.get_event_or_persistence_needed_future(), - chain_monitor.get_update_future() - ).wait_timeout(Duration::from_millis(100)); }, - |_| Instant::now(), |time: &Instant, dur| time.elapsed().as_secs() > dur, false) + define_run_body!( + persister, + chain_monitor, + chain_monitor.process_pending_events(&event_handler), + channel_manager, + channel_manager.process_pending_events(&event_handler), + gossip_sync, + peer_manager, + logger, + scorer, + stop_thread.load(Ordering::Acquire), + { + Sleeper::from_two_futures( + channel_manager.get_event_or_persistence_needed_future(), + chain_monitor.get_update_future(), + ) + .wait_timeout(Duration::from_millis(100)); + }, + |_| Instant::now(), + |time: &Instant, dur| time.elapsed().as_secs() > dur, + false + ) }); Self { stop_thread: stop_thread_clone, thread_handle: Some(handle) } } @@ -844,48 +928,55 @@ impl Drop for BackgroundProcessor { #[cfg(all(feature = "std", test))] mod tests { + use super::{BackgroundProcessor, GossipSync, FRESHNESS_TIMER}; use bitcoin::blockdata::constants::{genesis_block, ChainHash}; use bitcoin::blockdata::locktime::PackedLockTime; use bitcoin::blockdata::transaction::{Transaction, TxOut}; use bitcoin::network::constants::Network; - use bitcoin::secp256k1::{SecretKey, PublicKey, Secp256k1}; - use lightning::chain::{BestBlock, Confirm, chainmonitor}; + use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; use lightning::chain::channelmonitor::ANTI_REORG_DELAY; - use lightning::sign::{InMemorySigner, KeysManager}; use lightning::chain::transaction::OutPoint; - use lightning::events::{Event, PathFailure, MessageSendEventsProvider, MessageSendEvent}; - use lightning::{get_event_msg, get_event}; - use lightning::ln::PaymentHash; + use lightning::chain::{chainmonitor, BestBlock, Confirm}; + use lightning::events::{Event, MessageSendEvent, MessageSendEventsProvider, PathFailure}; use lightning::ln::channelmanager; - use lightning::ln::channelmanager::{BREAKDOWN_TIMEOUT, ChainParameters, MIN_CLTV_EXPIRY_DELTA, PaymentId}; + use lightning::ln::channelmanager::{ + ChainParameters, PaymentId, BREAKDOWN_TIMEOUT, MIN_CLTV_EXPIRY_DELTA, + }; use lightning::ln::features::{ChannelFeatures, NodeFeatures}; use lightning::ln::functional_test_utils::*; use lightning::ln::msgs::{ChannelMessageHandler, Init}; - use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler}; + use lightning::ln::peer_handler::{ + IgnoringMessageHandler, MessageHandler, PeerManager, SocketDescriptor, + }; + use lightning::ln::PaymentHash; use lightning::routing::gossip::{NetworkGraph, NodeId, P2PGossipSync}; use lightning::routing::router::{DefaultRouter, Path, RouteHop}; - use lightning::routing::scoring::{ChannelUsage, ScoreUpdate, ScoreLookUp, LockableScore}; + use lightning::routing::scoring::{ChannelUsage, LockableScore, ScoreLookUp, ScoreUpdate}; + use lightning::sign::{InMemorySigner, KeysManager}; use lightning::util::config::UserConfig; + use lightning::util::persist::{ + KVStore, CHANNEL_MANAGER_PERSISTENCE_KEY, CHANNEL_MANAGER_PERSISTENCE_PRIMARY_NAMESPACE, + CHANNEL_MANAGER_PERSISTENCE_SECONDARY_NAMESPACE, NETWORK_GRAPH_PERSISTENCE_KEY, + NETWORK_GRAPH_PERSISTENCE_PRIMARY_NAMESPACE, NETWORK_GRAPH_PERSISTENCE_SECONDARY_NAMESPACE, + SCORER_PERSISTENCE_KEY, SCORER_PERSISTENCE_PRIMARY_NAMESPACE, + SCORER_PERSISTENCE_SECONDARY_NAMESPACE, + }; use lightning::util::ser::Writeable; use lightning::util::test_utils; - use lightning::util::persist::{KVStore, - CHANNEL_MANAGER_PERSISTENCE_PRIMARY_NAMESPACE, CHANNEL_MANAGER_PERSISTENCE_SECONDARY_NAMESPACE, CHANNEL_MANAGER_PERSISTENCE_KEY, - NETWORK_GRAPH_PERSISTENCE_PRIMARY_NAMESPACE, NETWORK_GRAPH_PERSISTENCE_SECONDARY_NAMESPACE, NETWORK_GRAPH_PERSISTENCE_KEY, - SCORER_PERSISTENCE_PRIMARY_NAMESPACE, SCORER_PERSISTENCE_SECONDARY_NAMESPACE, SCORER_PERSISTENCE_KEY}; + use lightning::{get_event, get_event_msg}; use lightning_persister::fs_store::FilesystemStore; + use lightning_rapid_gossip_sync::RapidGossipSync; use std::collections::VecDeque; - use std::{fs, env}; use std::path::PathBuf; - use std::sync::{Arc, Mutex}; use std::sync::mpsc::SyncSender; + use std::sync::{Arc, Mutex}; use std::time::Duration; - use lightning_rapid_gossip_sync::RapidGossipSync; - use super::{BackgroundProcessor, GossipSync, FRESHNESS_TIMER}; + use std::{env, fs}; const EVENT_DEADLINE: u64 = 5 * FRESHNESS_TIMER; #[derive(Clone, Hash, PartialEq, Eq)] - struct TestDescriptor{} + struct TestDescriptor {} impl SocketDescriptor for TestDescriptor { fn send_data(&mut self, _data: &[u8], _resume_read: bool) -> usize { 0 @@ -899,33 +990,63 @@ mod tests { #[cfg(not(c_bindings))] type LockingWrapper = Mutex; - type ChannelManager = - channelmanager::ChannelManager< - Arc, - Arc, - Arc, - Arc, - Arc, - Arc, - Arc, + Arc, + Arc, + Arc, + Arc, + Arc, + Arc< + DefaultRouter< Arc>>, Arc, Arc>, (), - TestScorer> + TestScorer, >, - Arc>; - - type ChainMonitor = chainmonitor::ChainMonitor, Arc, Arc, Arc, Arc>; - - type PGS = Arc>>, Arc, Arc>>; - type RGS = Arc>>, Arc>>; + >, + Arc, + >; + + type ChainMonitor = chainmonitor::ChainMonitor< + InMemorySigner, + Arc, + Arc, + Arc, + Arc, + Arc, + >; + + type PGS = Arc< + P2PGossipSync< + Arc>>, + Arc, + Arc, + >, + >; + type RGS = Arc< + RapidGossipSync< + Arc>>, + Arc, + >, + >; struct Node { node: Arc, p2p_gossip_sync: PGS, rapid_gossip_sync: RGS, - peer_manager: Arc, Arc, IgnoringMessageHandler, Arc, IgnoringMessageHandler, Arc>>, + peer_manager: Arc< + PeerManager< + TestDescriptor, + Arc, + Arc, + IgnoringMessageHandler, + Arc, + IgnoringMessageHandler, + Arc, + >, + >, chain_monitor: Arc, kv_store: Arc, tx_broadcaster: Arc, @@ -936,15 +1057,39 @@ mod tests { } impl Node { - fn p2p_gossip_sync(&self) -> GossipSync>>, Arc, Arc> { + fn p2p_gossip_sync( + &self, + ) -> GossipSync< + PGS, + RGS, + Arc>>, + Arc, + Arc, + > { GossipSync::P2P(self.p2p_gossip_sync.clone()) } - fn rapid_gossip_sync(&self) -> GossipSync>>, Arc, Arc> { + fn rapid_gossip_sync( + &self, + ) -> GossipSync< + PGS, + RGS, + Arc>>, + Arc, + Arc, + > { GossipSync::Rapid(self.rapid_gossip_sync.clone()) } - fn no_gossip_sync(&self) -> GossipSync>>, Arc, Arc> { + fn no_gossip_sync( + &self, + ) -> GossipSync< + PGS, + RGS, + Arc>>, + Arc, + Arc, + > { GossipSync::None } } @@ -953,8 +1098,10 @@ mod tests { fn drop(&mut self) { let data_dir = self.kv_store.get_data_dir(); match fs::remove_dir_all(data_dir.clone()) { - Err(e) => println!("Failed to remove test store directory {}: {}", data_dir.display(), e), - _ => {} + Err(e) => { + println!("Failed to remove test store directory {}: {}", data_dir.display(), e) + }, + _ => {}, } } } @@ -970,7 +1117,13 @@ mod tests { impl Persister { fn new(data_dir: PathBuf) -> Self { let kv_store = FilesystemStore::new(data_dir); - Self { graph_error: None, graph_persistence_notifier: None, manager_error: None, scorer_error: None, kv_store } + Self { + graph_error: None, + graph_persistence_notifier: None, + manager_error: None, + scorer_error: None, + kv_store, + } } fn with_graph_error(self, error: std::io::ErrorKind, message: &'static str) -> Self { @@ -991,53 +1144,63 @@ mod tests { } impl KVStore for Persister { - fn read(&self, primary_namespace: &str, secondary_namespace: &str, key: &str) -> lightning::io::Result> { + fn read( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> lightning::io::Result> { self.kv_store.read(primary_namespace, secondary_namespace, key) } - fn write(&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: &[u8]) -> lightning::io::Result<()> { - if primary_namespace == CHANNEL_MANAGER_PERSISTENCE_PRIMARY_NAMESPACE && - secondary_namespace == CHANNEL_MANAGER_PERSISTENCE_SECONDARY_NAMESPACE && - key == CHANNEL_MANAGER_PERSISTENCE_KEY + fn write( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: &[u8], + ) -> lightning::io::Result<()> { + if primary_namespace == CHANNEL_MANAGER_PERSISTENCE_PRIMARY_NAMESPACE + && secondary_namespace == CHANNEL_MANAGER_PERSISTENCE_SECONDARY_NAMESPACE + && key == CHANNEL_MANAGER_PERSISTENCE_KEY { if let Some((error, message)) = self.manager_error { - return Err(std::io::Error::new(error, message)) + return Err(std::io::Error::new(error, message)); } } - if primary_namespace == NETWORK_GRAPH_PERSISTENCE_PRIMARY_NAMESPACE && - secondary_namespace == NETWORK_GRAPH_PERSISTENCE_SECONDARY_NAMESPACE && - key == NETWORK_GRAPH_PERSISTENCE_KEY + if primary_namespace == NETWORK_GRAPH_PERSISTENCE_PRIMARY_NAMESPACE + && secondary_namespace == NETWORK_GRAPH_PERSISTENCE_SECONDARY_NAMESPACE + && key == NETWORK_GRAPH_PERSISTENCE_KEY { if let Some(sender) = &self.graph_persistence_notifier { match sender.send(()) { Ok(()) => {}, - Err(std::sync::mpsc::SendError(())) => println!("Persister failed to notify as receiver went away."), + Err(std::sync::mpsc::SendError(())) => { + println!("Persister failed to notify as receiver went away.") + }, } }; if let Some((error, message)) = self.graph_error { - return Err(std::io::Error::new(error, message)) + return Err(std::io::Error::new(error, message)); } } - if primary_namespace == SCORER_PERSISTENCE_PRIMARY_NAMESPACE && - secondary_namespace == SCORER_PERSISTENCE_SECONDARY_NAMESPACE && - key == SCORER_PERSISTENCE_KEY + if primary_namespace == SCORER_PERSISTENCE_PRIMARY_NAMESPACE + && secondary_namespace == SCORER_PERSISTENCE_SECONDARY_NAMESPACE + && key == SCORER_PERSISTENCE_KEY { if let Some((error, message)) = self.scorer_error { - return Err(std::io::Error::new(error, message)) + return Err(std::io::Error::new(error, message)); } } self.kv_store.write(primary_namespace, secondary_namespace, key, buf) } - fn remove(&self, primary_namespace: &str, secondary_namespace: &str, key: &str, lazy: bool) -> lightning::io::Result<()> { + fn remove( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, lazy: bool, + ) -> lightning::io::Result<()> { self.kv_store.remove(primary_namespace, secondary_namespace, key, lazy) } - fn list(&self, primary_namespace: &str, secondary_namespace: &str) -> lightning::io::Result> { + fn list( + &self, primary_namespace: &str, secondary_namespace: &str, + ) -> lightning::io::Result> { self.kv_store.list(primary_namespace, secondary_namespace) } } @@ -1065,14 +1228,21 @@ mod tests { } impl lightning::util::ser::Writeable for TestScorer { - fn write(&self, _: &mut W) -> Result<(), lightning::io::Error> { Ok(()) } + fn write( + &self, _: &mut W, + ) -> Result<(), lightning::io::Error> { + Ok(()) + } } impl ScoreLookUp for TestScorer { type ScoreParams = (); fn channel_penalty_msat( - &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId, _usage: ChannelUsage, _score_params: &Self::ScoreParams - ) -> u64 { unimplemented!(); } + &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId, + _usage: ChannelUsage, _score_params: &Self::ScoreParams, + ) -> u64 { + unimplemented!(); + } } impl ScoreUpdate for TestScorer { @@ -1091,7 +1261,7 @@ mod tests { }, TestResult::ProbeSuccess { path } => { panic!("Unexpected probe success: {:?}", path) - } + }, } } } @@ -1110,7 +1280,7 @@ mod tests { }, TestResult::ProbeSuccess { path } => { panic!("Unexpected probe success: {:?}", path) - } + }, } } } @@ -1129,7 +1299,7 @@ mod tests { }, TestResult::ProbeSuccess { path } => { panic!("Unexpected probe success: {:?}", path) - } + }, } } } @@ -1147,7 +1317,7 @@ mod tests { }, TestResult::ProbeSuccess { path } => { assert_eq!(actual_path, &path); - } + }, } } } @@ -1183,41 +1353,111 @@ mod tests { let mut nodes = Vec::new(); for i in 0..num_nodes { let tx_broadcaster = Arc::new(test_utils::TestBroadcaster::new(network)); - let fee_estimator = Arc::new(test_utils::TestFeeEstimator { sat_per_kw: Mutex::new(253) }); + let fee_estimator = + Arc::new(test_utils::TestFeeEstimator { sat_per_kw: Mutex::new(253) }); let logger = Arc::new(test_utils::TestLogger::with_id(format!("node {}", i))); let genesis_block = genesis_block(network); let network_graph = Arc::new(NetworkGraph::new(network, logger.clone())); let scorer = Arc::new(LockingWrapper::new(TestScorer::new())); let seed = [i as u8; 32]; - let router = Arc::new(DefaultRouter::new(network_graph.clone(), logger.clone(), seed, scorer.clone(), Default::default())); + let router = Arc::new(DefaultRouter::new( + network_graph.clone(), + logger.clone(), + seed, + scorer.clone(), + Default::default(), + )); let chain_source = Arc::new(test_utils::TestChainSource::new(Network::Bitcoin)); - let kv_store = Arc::new(FilesystemStore::new(format!("{}_persister_{}", &persist_dir, i).into())); + let kv_store = + Arc::new(FilesystemStore::new(format!("{}_persister_{}", &persist_dir, i).into())); let now = Duration::from_secs(genesis_block.header.time as u64); let keys_manager = Arc::new(KeysManager::new(&seed, now.as_secs(), now.subsec_nanos())); - let chain_monitor = Arc::new(chainmonitor::ChainMonitor::new(Some(chain_source.clone()), tx_broadcaster.clone(), logger.clone(), fee_estimator.clone(), kv_store.clone())); + let chain_monitor = Arc::new(chainmonitor::ChainMonitor::new( + Some(chain_source.clone()), + tx_broadcaster.clone(), + logger.clone(), + fee_estimator.clone(), + kv_store.clone(), + )); let best_block = BestBlock::from_network(network); let params = ChainParameters { network, best_block }; - let manager = Arc::new(ChannelManager::new(fee_estimator.clone(), chain_monitor.clone(), tx_broadcaster.clone(), router.clone(), logger.clone(), keys_manager.clone(), keys_manager.clone(), keys_manager.clone(), UserConfig::default(), params, genesis_block.header.time)); - let p2p_gossip_sync = Arc::new(P2PGossipSync::new(network_graph.clone(), Some(chain_source.clone()), logger.clone())); - let rapid_gossip_sync = Arc::new(RapidGossipSync::new(network_graph.clone(), logger.clone())); + let manager = Arc::new(ChannelManager::new( + fee_estimator.clone(), + chain_monitor.clone(), + tx_broadcaster.clone(), + router.clone(), + logger.clone(), + keys_manager.clone(), + keys_manager.clone(), + keys_manager.clone(), + UserConfig::default(), + params, + genesis_block.header.time, + )); + let p2p_gossip_sync = Arc::new(P2PGossipSync::new( + network_graph.clone(), + Some(chain_source.clone()), + logger.clone(), + )); + let rapid_gossip_sync = + Arc::new(RapidGossipSync::new(network_graph.clone(), logger.clone())); let msg_handler = MessageHandler { - chan_handler: Arc::new(test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet))), + chan_handler: Arc::new(test_utils::TestChannelMessageHandler::new( + ChainHash::using_genesis_block(Network::Testnet), + )), route_handler: Arc::new(test_utils::TestRoutingMessageHandler::new()), - onion_message_handler: IgnoringMessageHandler{}, custom_message_handler: IgnoringMessageHandler{} + onion_message_handler: IgnoringMessageHandler {}, + custom_message_handler: IgnoringMessageHandler {}, + }; + let peer_manager = Arc::new(PeerManager::new( + msg_handler, + 0, + &seed, + logger.clone(), + keys_manager.clone(), + )); + let node = Node { + node: manager, + p2p_gossip_sync, + rapid_gossip_sync, + peer_manager, + chain_monitor, + kv_store, + tx_broadcaster, + network_graph, + logger, + best_block, + scorer, }; - let peer_manager = Arc::new(PeerManager::new(msg_handler, 0, &seed, logger.clone(), keys_manager.clone())); - let node = Node { node: manager, p2p_gossip_sync, rapid_gossip_sync, peer_manager, chain_monitor, kv_store, tx_broadcaster, network_graph, logger, best_block, scorer }; nodes.push(node); } for i in 0..num_nodes { - for j in (i+1)..num_nodes { - nodes[i].node.peer_connected(&nodes[j].node.get_our_node_id(), &Init { - features: nodes[j].node.init_features(), networks: None, remote_network_address: None - }, true).unwrap(); - nodes[j].node.peer_connected(&nodes[i].node.get_our_node_id(), &Init { - features: nodes[i].node.init_features(), networks: None, remote_network_address: None - }, false).unwrap(); + for j in (i + 1)..num_nodes { + nodes[i] + .node + .peer_connected( + &nodes[j].node.get_our_node_id(), + &Init { + features: nodes[j].node.init_features(), + networks: None, + remote_network_address: None, + }, + true, + ) + .unwrap(); + nodes[j] + .node + .peer_connected( + &nodes[i].node.get_our_node_id(), + &Init { + features: nodes[i].node.init_features(), + networks: None, + remote_network_address: None, + }, + false, + ) + .unwrap(); } } @@ -1229,39 +1469,90 @@ mod tests { begin_open_channel!($node_a, $node_b, $channel_value); let events = $node_a.node.get_and_clear_pending_events(); assert_eq!(events.len(), 1); - let (temporary_channel_id, tx) = handle_funding_generation_ready!(events[0], $channel_value); - $node_a.node.funding_transaction_generated(&temporary_channel_id, &$node_b.node.get_our_node_id(), tx.clone()).unwrap(); - $node_b.node.handle_funding_created(&$node_a.node.get_our_node_id(), &get_event_msg!($node_a, MessageSendEvent::SendFundingCreated, $node_b.node.get_our_node_id())); + let (temporary_channel_id, tx) = + handle_funding_generation_ready!(events[0], $channel_value); + $node_a + .node + .funding_transaction_generated( + &temporary_channel_id, + &$node_b.node.get_our_node_id(), + tx.clone(), + ) + .unwrap(); + $node_b.node.handle_funding_created( + &$node_a.node.get_our_node_id(), + &get_event_msg!( + $node_a, + MessageSendEvent::SendFundingCreated, + $node_b.node.get_our_node_id() + ), + ); get_event!($node_b, Event::ChannelPending); - $node_a.node.handle_funding_signed(&$node_b.node.get_our_node_id(), &get_event_msg!($node_b, MessageSendEvent::SendFundingSigned, $node_a.node.get_our_node_id())); + $node_a.node.handle_funding_signed( + &$node_b.node.get_our_node_id(), + &get_event_msg!( + $node_b, + MessageSendEvent::SendFundingSigned, + $node_a.node.get_our_node_id() + ), + ); get_event!($node_a, Event::ChannelPending); tx - }} + }}; } macro_rules! begin_open_channel { ($node_a: expr, $node_b: expr, $channel_value: expr) => {{ - $node_a.node.create_channel($node_b.node.get_our_node_id(), $channel_value, 100, 42, None).unwrap(); - $node_b.node.handle_open_channel(&$node_a.node.get_our_node_id(), &get_event_msg!($node_a, MessageSendEvent::SendOpenChannel, $node_b.node.get_our_node_id())); - $node_a.node.handle_accept_channel(&$node_b.node.get_our_node_id(), &get_event_msg!($node_b, MessageSendEvent::SendAcceptChannel, $node_a.node.get_our_node_id())); - }} + $node_a + .node + .create_channel($node_b.node.get_our_node_id(), $channel_value, 100, 42, None) + .unwrap(); + $node_b.node.handle_open_channel( + &$node_a.node.get_our_node_id(), + &get_event_msg!( + $node_a, + MessageSendEvent::SendOpenChannel, + $node_b.node.get_our_node_id() + ), + ); + $node_a.node.handle_accept_channel( + &$node_b.node.get_our_node_id(), + &get_event_msg!( + $node_b, + MessageSendEvent::SendAcceptChannel, + $node_a.node.get_our_node_id() + ), + ); + }}; } macro_rules! handle_funding_generation_ready { ($event: expr, $channel_value: expr) => {{ match $event { - Event::FundingGenerationReady { temporary_channel_id, channel_value_satoshis, ref output_script, user_channel_id, .. } => { + Event::FundingGenerationReady { + temporary_channel_id, + channel_value_satoshis, + ref output_script, + user_channel_id, + .. + } => { assert_eq!(channel_value_satoshis, $channel_value); assert_eq!(user_channel_id, 42); - let tx = Transaction { version: 1 as i32, lock_time: PackedLockTime(0), input: Vec::new(), output: vec![TxOut { - value: channel_value_satoshis, script_pubkey: output_script.clone(), - }]}; + let tx = Transaction { + version: 1 as i32, + lock_time: PackedLockTime(0), + input: Vec::new(), + output: vec![TxOut { + value: channel_value_satoshis, + script_pubkey: output_script.clone(), + }], + }; (temporary_channel_id, tx) }, _ => panic!("Unexpected event"), } - }} + }}; } fn confirm_transaction_depth(node: &mut Node, tx: &Transaction, depth: u32) { @@ -1304,7 +1595,16 @@ mod tests { let data_dir = nodes[0].kv_store.get_data_dir(); let persister = Arc::new(Persister::new(data_dir)); let event_handler = |_: _| {}; - let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].p2p_gossip_sync(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone())); + let bg_processor = BackgroundProcessor::start( + persister, + event_handler, + nodes[0].chain_monitor.clone(), + nodes[0].node.clone(), + nodes[0].p2p_gossip_sync(), + nodes[0].peer_manager.clone(), + nodes[0].logger.clone(), + Some(nodes[0].scorer.clone()), + ); macro_rules! check_persisted_data { ($node: expr, $filepath: expr) => { @@ -1312,47 +1612,58 @@ mod tests { loop { expected_bytes.clear(); match $node.write(&mut expected_bytes) { - Ok(()) => { - match std::fs::read($filepath) { - Ok(bytes) => { - if bytes == expected_bytes { - break - } else { - continue - } - }, - Err(_) => continue - } + Ok(()) => match std::fs::read($filepath) { + Ok(bytes) => { + if bytes == expected_bytes { + break; + } else { + continue; + } + }, + Err(_) => continue, }, - Err(e) => panic!("Unexpected error: {}", e) + Err(e) => panic!("Unexpected error: {}", e), } } - } + }; } // Check that the initial channel manager data is persisted as expected. - let filepath = get_full_filepath(format!("{}_persister_0", &persist_dir), "manager".to_string()); + let filepath = + get_full_filepath(format!("{}_persister_0", &persist_dir), "manager".to_string()); check_persisted_data!(nodes[0].node, filepath.clone()); loop { - if !nodes[0].node.get_event_or_persist_condvar_value() { break } + if !nodes[0].node.get_event_or_persist_condvar_value() { + break; + } } // Force-close the channel. - nodes[0].node.force_close_broadcasting_latest_txn(&OutPoint { txid: tx.txid(), index: 0 }.to_channel_id(), &nodes[1].node.get_our_node_id()).unwrap(); + nodes[0] + .node + .force_close_broadcasting_latest_txn( + &OutPoint { txid: tx.txid(), index: 0 }.to_channel_id(), + &nodes[1].node.get_our_node_id(), + ) + .unwrap(); // Check that the force-close updates are persisted. check_persisted_data!(nodes[0].node, filepath.clone()); loop { - if !nodes[0].node.get_event_or_persist_condvar_value() { break } + if !nodes[0].node.get_event_or_persist_condvar_value() { + break; + } } // Check network graph is persisted - let filepath = get_full_filepath(format!("{}_persister_0", &persist_dir), "network_graph".to_string()); + let filepath = + get_full_filepath(format!("{}_persister_0", &persist_dir), "network_graph".to_string()); check_persisted_data!(nodes[0].network_graph, filepath.clone()); // Check scorer is persisted - let filepath = get_full_filepath(format!("{}_persister_0", &persist_dir), "scorer".to_string()); + let filepath = + get_full_filepath(format!("{}_persister_0", &persist_dir), "scorer".to_string()); check_persisted_data!(nodes[0].scorer, filepath.clone()); if !std::thread::panicking() { @@ -1369,16 +1680,30 @@ mod tests { let data_dir = nodes[0].kv_store.get_data_dir(); let persister = Arc::new(Persister::new(data_dir)); let event_handler = |_: _| {}; - let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone())); + let bg_processor = BackgroundProcessor::start( + persister, + event_handler, + nodes[0].chain_monitor.clone(), + nodes[0].node.clone(), + nodes[0].no_gossip_sync(), + nodes[0].peer_manager.clone(), + nodes[0].logger.clone(), + Some(nodes[0].scorer.clone()), + ); loop { let log_entries = nodes[0].logger.lines.lock().unwrap(); let desired_log_1 = "Calling ChannelManager's timer_tick_occurred".to_string(); let desired_log_2 = "Calling PeerManager's timer_tick_occurred".to_string(); let desired_log_3 = "Rebroadcasting monitor's pending claims".to_string(); - if log_entries.get(&("lightning_background_processor".to_string(), desired_log_1)).is_some() && - log_entries.get(&("lightning_background_processor".to_string(), desired_log_2)).is_some() && - log_entries.get(&("lightning_background_processor".to_string(), desired_log_3)).is_some() { - break + if log_entries + .get(&("lightning_background_processor".to_string(), desired_log_1)) + .is_some() && log_entries + .get(&("lightning_background_processor".to_string(), desired_log_2)) + .is_some() && log_entries + .get(&("lightning_background_processor".to_string(), desired_log_3)) + .is_some() + { + break; } } @@ -1394,9 +1719,20 @@ mod tests { open_channel!(nodes[0], nodes[1], 100000); let data_dir = nodes[0].kv_store.get_data_dir(); - let persister = Arc::new(Persister::new(data_dir).with_manager_error(std::io::ErrorKind::Other, "test")); + let persister = Arc::new( + Persister::new(data_dir).with_manager_error(std::io::ErrorKind::Other, "test"), + ); let event_handler = |_: _| {}; - let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone())); + let bg_processor = BackgroundProcessor::start( + persister, + event_handler, + nodes[0].chain_monitor.clone(), + nodes[0].node.clone(), + nodes[0].no_gossip_sync(), + nodes[0].peer_manager.clone(), + nodes[0].logger.clone(), + Some(nodes[0].scorer.clone()), + ); match bg_processor.join() { Ok(_) => panic!("Expected error persisting manager"), Err(e) => { @@ -1414,17 +1750,26 @@ mod tests { open_channel!(nodes[0], nodes[1], 100000); let data_dir = nodes[0].kv_store.get_data_dir(); - let persister = Arc::new(Persister::new(data_dir).with_manager_error(std::io::ErrorKind::Other, "test")); + let persister = Arc::new( + Persister::new(data_dir).with_manager_error(std::io::ErrorKind::Other, "test"), + ); let bp_future = super::process_events_async( - persister, |_: _| {async {}}, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), - nodes[0].rapid_gossip_sync(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), - Some(nodes[0].scorer.clone()), move |dur: Duration| { + persister, + |_: _| async {}, + nodes[0].chain_monitor.clone(), + nodes[0].node.clone(), + nodes[0].rapid_gossip_sync(), + nodes[0].peer_manager.clone(), + nodes[0].logger.clone(), + Some(nodes[0].scorer.clone()), + move |dur: Duration| { Box::pin(async move { tokio::time::sleep(dur).await; false // Never exit }) - }, false, + }, + false, ); match bp_future.await { Ok(_) => panic!("Expected error persisting manager"), @@ -1440,9 +1785,19 @@ mod tests { // Test that if we encounter an error during network graph persistence, an error gets returned. let (_, nodes) = create_nodes(2, "test_persist_network_graph_error"); let data_dir = nodes[0].kv_store.get_data_dir(); - let persister = Arc::new(Persister::new(data_dir).with_graph_error(std::io::ErrorKind::Other, "test")); + let persister = + Arc::new(Persister::new(data_dir).with_graph_error(std::io::ErrorKind::Other, "test")); let event_handler = |_: _| {}; - let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].p2p_gossip_sync(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone())); + let bg_processor = BackgroundProcessor::start( + persister, + event_handler, + nodes[0].chain_monitor.clone(), + nodes[0].node.clone(), + nodes[0].p2p_gossip_sync(), + nodes[0].peer_manager.clone(), + nodes[0].logger.clone(), + Some(nodes[0].scorer.clone()), + ); match bg_processor.stop() { Ok(_) => panic!("Expected error persisting network graph"), @@ -1458,9 +1813,19 @@ mod tests { // Test that if we encounter an error during scorer persistence, an error gets returned. let (_, nodes) = create_nodes(2, "test_persist_scorer_error"); let data_dir = nodes[0].kv_store.get_data_dir(); - let persister = Arc::new(Persister::new(data_dir).with_scorer_error(std::io::ErrorKind::Other, "test")); + let persister = + Arc::new(Persister::new(data_dir).with_scorer_error(std::io::ErrorKind::Other, "test")); let event_handler = |_: _| {}; - let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone())); + let bg_processor = BackgroundProcessor::start( + persister, + event_handler, + nodes[0].chain_monitor.clone(), + nodes[0].node.clone(), + nodes[0].no_gossip_sync(), + nodes[0].peer_manager.clone(), + nodes[0].logger.clone(), + Some(nodes[0].scorer.clone()), + ); match bg_processor.stop() { Ok(_) => panic!("Expected error persisting scorer"), @@ -1482,35 +1847,84 @@ mod tests { let (funding_generation_send, funding_generation_recv) = std::sync::mpsc::sync_channel(1); let (channel_pending_send, channel_pending_recv) = std::sync::mpsc::sync_channel(1); let event_handler = move |event: Event| match event { - Event::FundingGenerationReady { .. } => funding_generation_send.send(handle_funding_generation_ready!(event, channel_value)).unwrap(), + Event::FundingGenerationReady { .. } => funding_generation_send + .send(handle_funding_generation_ready!(event, channel_value)) + .unwrap(), Event::ChannelPending { .. } => channel_pending_send.send(()).unwrap(), Event::ChannelReady { .. } => {}, _ => panic!("Unexpected event: {:?}", event), }; - let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone())); + let bg_processor = BackgroundProcessor::start( + persister, + event_handler, + nodes[0].chain_monitor.clone(), + nodes[0].node.clone(), + nodes[0].no_gossip_sync(), + nodes[0].peer_manager.clone(), + nodes[0].logger.clone(), + Some(nodes[0].scorer.clone()), + ); // Open a channel and check that the FundingGenerationReady event was handled. begin_open_channel!(nodes[0], nodes[1], channel_value); let (temporary_channel_id, funding_tx) = funding_generation_recv .recv_timeout(Duration::from_secs(EVENT_DEADLINE)) .expect("FundingGenerationReady not handled within deadline"); - nodes[0].node.funding_transaction_generated(&temporary_channel_id, &nodes[1].node.get_our_node_id(), funding_tx.clone()).unwrap(); - nodes[1].node.handle_funding_created(&nodes[0].node.get_our_node_id(), &get_event_msg!(nodes[0], MessageSendEvent::SendFundingCreated, nodes[1].node.get_our_node_id())); + nodes[0] + .node + .funding_transaction_generated( + &temporary_channel_id, + &nodes[1].node.get_our_node_id(), + funding_tx.clone(), + ) + .unwrap(); + nodes[1].node.handle_funding_created( + &nodes[0].node.get_our_node_id(), + &get_event_msg!( + nodes[0], + MessageSendEvent::SendFundingCreated, + nodes[1].node.get_our_node_id() + ), + ); get_event!(nodes[1], Event::ChannelPending); - nodes[0].node.handle_funding_signed(&nodes[1].node.get_our_node_id(), &get_event_msg!(nodes[1], MessageSendEvent::SendFundingSigned, nodes[0].node.get_our_node_id())); - let _ = channel_pending_recv.recv_timeout(Duration::from_secs(EVENT_DEADLINE)) + nodes[0].node.handle_funding_signed( + &nodes[1].node.get_our_node_id(), + &get_event_msg!( + nodes[1], + MessageSendEvent::SendFundingSigned, + nodes[0].node.get_our_node_id() + ), + ); + let _ = channel_pending_recv + .recv_timeout(Duration::from_secs(EVENT_DEADLINE)) .expect("ChannelPending not handled within deadline"); // Confirm the funding transaction. confirm_transaction(&mut nodes[0], &funding_tx); - let as_funding = get_event_msg!(nodes[0], MessageSendEvent::SendChannelReady, nodes[1].node.get_our_node_id()); + let as_funding = get_event_msg!( + nodes[0], + MessageSendEvent::SendChannelReady, + nodes[1].node.get_our_node_id() + ); confirm_transaction(&mut nodes[1], &funding_tx); - let bs_funding = get_event_msg!(nodes[1], MessageSendEvent::SendChannelReady, nodes[0].node.get_our_node_id()); + let bs_funding = get_event_msg!( + nodes[1], + MessageSendEvent::SendChannelReady, + nodes[0].node.get_our_node_id() + ); nodes[0].node.handle_channel_ready(&nodes[1].node.get_our_node_id(), &bs_funding); - let _as_channel_update = get_event_msg!(nodes[0], MessageSendEvent::SendChannelUpdate, nodes[1].node.get_our_node_id()); + let _as_channel_update = get_event_msg!( + nodes[0], + MessageSendEvent::SendChannelUpdate, + nodes[1].node.get_our_node_id() + ); nodes[1].node.handle_channel_ready(&nodes[0].node.get_our_node_id(), &as_funding); - let _bs_channel_update = get_event_msg!(nodes[1], MessageSendEvent::SendChannelUpdate, nodes[0].node.get_our_node_id()); + let _bs_channel_update = get_event_msg!( + nodes[1], + MessageSendEvent::SendChannelUpdate, + nodes[0].node.get_our_node_id() + ); if !std::thread::panicking() { bg_processor.stop().unwrap(); @@ -1525,10 +1939,25 @@ mod tests { _ => panic!("Unexpected event: {:?}", event), }; let persister = Arc::new(Persister::new(data_dir)); - let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone())); + let bg_processor = BackgroundProcessor::start( + persister, + event_handler, + nodes[0].chain_monitor.clone(), + nodes[0].node.clone(), + nodes[0].no_gossip_sync(), + nodes[0].peer_manager.clone(), + nodes[0].logger.clone(), + Some(nodes[0].scorer.clone()), + ); // Force close the channel and check that the SpendableOutputs event was handled. - nodes[0].node.force_close_broadcasting_latest_txn(&nodes[0].node.list_channels()[0].channel_id, &nodes[1].node.get_our_node_id()).unwrap(); + nodes[0] + .node + .force_close_broadcasting_latest_txn( + &nodes[0].node.list_channels()[0].channel_id, + &nodes[1].node.get_our_node_id(), + ) + .unwrap(); let commitment_tx = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().pop().unwrap(); confirm_transaction_depth(&mut nodes[0], &commitment_tx, BREAKDOWN_TIMEOUT as u32); @@ -1551,13 +1980,25 @@ mod tests { let data_dir = nodes[0].kv_store.get_data_dir(); let persister = Arc::new(Persister::new(data_dir)); let event_handler = |_: _| {}; - let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone())); + let bg_processor = BackgroundProcessor::start( + persister, + event_handler, + nodes[0].chain_monitor.clone(), + nodes[0].node.clone(), + nodes[0].no_gossip_sync(), + nodes[0].peer_manager.clone(), + nodes[0].logger.clone(), + Some(nodes[0].scorer.clone()), + ); loop { let log_entries = nodes[0].logger.lines.lock().unwrap(); let expected_log = "Persisting scorer".to_string(); - if log_entries.get(&("lightning_background_processor".to_string(), expected_log)).is_some() { - break + if log_entries + .get(&("lightning_background_processor".to_string(), expected_log)) + .is_some() + { + break; } } @@ -1569,9 +2010,16 @@ mod tests { macro_rules! do_test_not_pruning_network_graph_until_graph_sync_completion { ($nodes: expr, $receive: expr, $sleep: expr) => { let features = ChannelFeatures::empty(); - $nodes[0].network_graph.add_channel_from_partial_announcement( - 42, 53, features, $nodes[0].node.get_our_node_id(), $nodes[1].node.get_our_node_id() - ).expect("Failed to update channel from partial announcement"); + $nodes[0] + .network_graph + .add_channel_from_partial_announcement( + 42, + 53, + features, + $nodes[0].node.get_our_node_id(), + $nodes[1].node.get_our_node_id(), + ) + .expect("Failed to update channel from partial announcement"); let original_graph_description = $nodes[0].network_graph.to_string(); assert!(original_graph_description.contains("42: features: 0000, node_one:")); assert_eq!($nodes[0].network_graph.read_only().channels().len(), 1); @@ -1580,30 +2028,36 @@ mod tests { $sleep; let log_entries = $nodes[0].logger.lines.lock().unwrap(); let loop_counter = "Calling ChannelManager's timer_tick_occurred".to_string(); - if *log_entries.get(&("lightning_background_processor".to_string(), loop_counter)) - .unwrap_or(&0) > 1 + if *log_entries + .get(&("lightning_background_processor".to_string(), loop_counter)) + .unwrap_or(&0) + > 1 { // Wait until the loop has gone around at least twice. - break + break; } } let initialization_input = vec![ - 76, 68, 75, 1, 111, 226, 140, 10, 182, 241, 179, 114, 193, 166, 162, 70, 174, 99, 247, - 79, 147, 30, 131, 101, 225, 90, 8, 156, 104, 214, 25, 0, 0, 0, 0, 0, 97, 227, 98, 218, - 0, 0, 0, 4, 2, 22, 7, 207, 206, 25, 164, 197, 231, 230, 231, 56, 102, 61, 250, 251, - 187, 172, 38, 46, 79, 247, 108, 44, 155, 48, 219, 238, 252, 53, 192, 6, 67, 2, 36, 125, - 157, 176, 223, 175, 234, 116, 94, 248, 201, 225, 97, 235, 50, 47, 115, 172, 63, 136, - 88, 216, 115, 11, 111, 217, 114, 84, 116, 124, 231, 107, 2, 158, 1, 242, 121, 152, 106, - 204, 131, 186, 35, 93, 70, 216, 10, 237, 224, 183, 89, 95, 65, 3, 83, 185, 58, 138, - 181, 64, 187, 103, 127, 68, 50, 2, 201, 19, 17, 138, 136, 149, 185, 226, 156, 137, 175, - 110, 32, 237, 0, 217, 90, 31, 100, 228, 149, 46, 219, 175, 168, 77, 4, 143, 38, 128, - 76, 97, 0, 0, 0, 2, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 0, 1, 0, 0, 255, 2, 68, - 226, 0, 6, 11, 0, 1, 2, 3, 0, 0, 0, 2, 0, 40, 0, 0, 0, 0, 0, 0, 3, 232, 0, 0, 3, 232, - 0, 0, 0, 1, 0, 0, 0, 0, 58, 85, 116, 216, 255, 8, 153, 192, 0, 2, 27, 0, 0, 25, 0, 0, - 0, 1, 0, 0, 0, 125, 255, 2, 68, 226, 0, 6, 11, 0, 1, 5, 0, 0, 0, 0, 29, 129, 25, 192, + 76, 68, 75, 1, 111, 226, 140, 10, 182, 241, 179, 114, 193, 166, 162, 70, 174, 99, + 247, 79, 147, 30, 131, 101, 225, 90, 8, 156, 104, 214, 25, 0, 0, 0, 0, 0, 97, 227, + 98, 218, 0, 0, 0, 4, 2, 22, 7, 207, 206, 25, 164, 197, 231, 230, 231, 56, 102, 61, + 250, 251, 187, 172, 38, 46, 79, 247, 108, 44, 155, 48, 219, 238, 252, 53, 192, 6, + 67, 2, 36, 125, 157, 176, 223, 175, 234, 116, 94, 248, 201, 225, 97, 235, 50, 47, + 115, 172, 63, 136, 88, 216, 115, 11, 111, 217, 114, 84, 116, 124, 231, 107, 2, 158, + 1, 242, 121, 152, 106, 204, 131, 186, 35, 93, 70, 216, 10, 237, 224, 183, 89, 95, + 65, 3, 83, 185, 58, 138, 181, 64, 187, 103, 127, 68, 50, 2, 201, 19, 17, 138, 136, + 149, 185, 226, 156, 137, 175, 110, 32, 237, 0, 217, 90, 31, 100, 228, 149, 46, 219, + 175, 168, 77, 4, 143, 38, 128, 76, 97, 0, 0, 0, 2, 0, 0, 255, 8, 153, 192, 0, 2, + 27, 0, 0, 0, 1, 0, 0, 255, 2, 68, 226, 0, 6, 11, 0, 1, 2, 3, 0, 0, 0, 2, 0, 40, 0, + 0, 0, 0, 0, 0, 3, 232, 0, 0, 3, 232, 0, 0, 0, 1, 0, 0, 0, 0, 58, 85, 116, 216, 255, + 8, 153, 192, 0, 2, 27, 0, 0, 25, 0, 0, 0, 1, 0, 0, 0, 125, 255, 2, 68, 226, 0, 6, + 11, 0, 1, 5, 0, 0, 0, 0, 29, 129, 25, 192, ]; - $nodes[0].rapid_gossip_sync.update_network_graph_no_std(&initialization_input[..], Some(1642291930)).unwrap(); + $nodes[0] + .rapid_gossip_sync + .update_network_graph_no_std(&initialization_input[..], Some(1642291930)) + .unwrap(); // this should have added two channels and pruned the previous one. assert_eq!($nodes[0].network_graph.read_only().channels().len(), 2); @@ -1612,23 +2066,35 @@ mod tests { // all channels should now be pruned assert_eq!($nodes[0].network_graph.read_only().channels().len(), 0); - } + }; } #[test] fn test_not_pruning_network_graph_until_graph_sync_completion() { let (sender, receiver) = std::sync::mpsc::sync_channel(1); - let (_, nodes) = create_nodes(2, "test_not_pruning_network_graph_until_graph_sync_completion"); + let (_, nodes) = + create_nodes(2, "test_not_pruning_network_graph_until_graph_sync_completion"); let data_dir = nodes[0].kv_store.get_data_dir(); let persister = Arc::new(Persister::new(data_dir).with_graph_persistence_notifier(sender)); let event_handler = |_: _| {}; - let background_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].rapid_gossip_sync(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone())); + let background_processor = BackgroundProcessor::start( + persister, + event_handler, + nodes[0].chain_monitor.clone(), + nodes[0].node.clone(), + nodes[0].rapid_gossip_sync(), + nodes[0].peer_manager.clone(), + nodes[0].logger.clone(), + Some(nodes[0].scorer.clone()), + ); - do_test_not_pruning_network_graph_until_graph_sync_completion!(nodes, + do_test_not_pruning_network_graph_until_graph_sync_completion!( + nodes, receiver.recv_timeout(Duration::from_secs(super::FIRST_NETWORK_PRUNE_TIMER * 5)), - std::thread::sleep(Duration::from_millis(1))); + std::thread::sleep(Duration::from_millis(1)) + ); background_processor.stop().unwrap(); } @@ -1638,15 +2104,22 @@ mod tests { async fn test_not_pruning_network_graph_until_graph_sync_completion_async() { let (sender, receiver) = std::sync::mpsc::sync_channel(1); - let (_, nodes) = create_nodes(2, "test_not_pruning_network_graph_until_graph_sync_completion_async"); + let (_, nodes) = + create_nodes(2, "test_not_pruning_network_graph_until_graph_sync_completion_async"); let data_dir = nodes[0].kv_store.get_data_dir(); let persister = Arc::new(Persister::new(data_dir).with_graph_persistence_notifier(sender)); let (exit_sender, exit_receiver) = tokio::sync::watch::channel(()); let bp_future = super::process_events_async( - persister, |_: _| {async {}}, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), - nodes[0].rapid_gossip_sync(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), - Some(nodes[0].scorer.clone()), move |dur: Duration| { + persister, + |_: _| async {}, + nodes[0].chain_monitor.clone(), + nodes[0].node.clone(), + nodes[0].rapid_gossip_sync(), + nodes[0].peer_manager.clone(), + nodes[0].logger.clone(), + Some(nodes[0].scorer.clone()), + move |dur: Duration| { let mut exit_receiver = exit_receiver.clone(); Box::pin(async move { tokio::select! { @@ -1654,20 +2127,28 @@ mod tests { _ = exit_receiver.changed() => true, } }) - }, false, + }, + false, ); let t1 = tokio::spawn(bp_future); let t2 = tokio::spawn(async move { - do_test_not_pruning_network_graph_until_graph_sync_completion!(nodes, { - let mut i = 0; - loop { - tokio::time::sleep(Duration::from_secs(super::FIRST_NETWORK_PRUNE_TIMER)).await; - if let Ok(()) = receiver.try_recv() { break Ok::<(), ()>(()); } - assert!(i < 5); - i += 1; - } - }, tokio::time::sleep(Duration::from_millis(1)).await); + do_test_not_pruning_network_graph_until_graph_sync_completion!( + nodes, + { + let mut i = 0; + loop { + tokio::time::sleep(Duration::from_secs(super::FIRST_NETWORK_PRUNE_TIMER)) + .await; + if let Ok(()) = receiver.try_recv() { + break Ok::<(), ()>(()); + } + assert!(i < 5); + i += 1; + } + }, + tokio::time::sleep(Duration::from_millis(1)).await + ); exit_sender.send(()).unwrap(); }); let (r1, r2) = tokio::join!(t1, t2); @@ -1782,9 +2263,21 @@ mod tests { let (_, nodes) = create_nodes(1, "test_payment_path_scoring"); let data_dir = nodes[0].kv_store.get_data_dir(); let persister = Arc::new(Persister::new(data_dir)); - let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone())); + let bg_processor = BackgroundProcessor::start( + persister, + event_handler, + nodes[0].chain_monitor.clone(), + nodes[0].node.clone(), + nodes[0].no_gossip_sync(), + nodes[0].peer_manager.clone(), + nodes[0].logger.clone(), + Some(nodes[0].scorer.clone()), + ); - do_test_payment_path_scoring!(nodes, receiver.recv_timeout(Duration::from_secs(EVENT_DEADLINE))); + do_test_payment_path_scoring!( + nodes, + receiver.recv_timeout(Duration::from_secs(EVENT_DEADLINE)) + ); if !std::thread::panicking() { bg_processor.stop().unwrap(); @@ -1792,7 +2285,12 @@ mod tests { let log_entries = nodes[0].logger.lines.lock().unwrap(); let expected_log = "Persisting scorer after update".to_string(); - assert_eq!(*log_entries.get(&("lightning_background_processor".to_string(), expected_log)).unwrap(), 5); + assert_eq!( + *log_entries + .get(&("lightning_background_processor".to_string(), expected_log)) + .unwrap(), + 5 + ); } #[tokio::test] @@ -1803,10 +2301,10 @@ mod tests { let sender_ref = sender.clone(); async move { match event { - Event::PaymentPathFailed { .. } => { sender_ref.send(event).await.unwrap() }, - Event::PaymentPathSuccessful { .. } => { sender_ref.send(event).await.unwrap() }, - Event::ProbeSuccessful { .. } => { sender_ref.send(event).await.unwrap() }, - Event::ProbeFailed { .. } => { sender_ref.send(event).await.unwrap() }, + Event::PaymentPathFailed { .. } => sender_ref.send(event).await.unwrap(), + Event::PaymentPathSuccessful { .. } => sender_ref.send(event).await.unwrap(), + Event::ProbeSuccessful { .. } => sender_ref.send(event).await.unwrap(), + Event::ProbeFailed { .. } => sender_ref.send(event).await.unwrap(), _ => panic!("Unexpected event: {:?}", event), } } @@ -1819,9 +2317,15 @@ mod tests { let (exit_sender, exit_receiver) = tokio::sync::watch::channel(()); let bp_future = super::process_events_async( - persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), - nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), - Some(nodes[0].scorer.clone()), move |dur: Duration| { + persister, + event_handler, + nodes[0].chain_monitor.clone(), + nodes[0].node.clone(), + nodes[0].no_gossip_sync(), + nodes[0].peer_manager.clone(), + nodes[0].logger.clone(), + Some(nodes[0].scorer.clone()), + move |dur: Duration| { let mut exit_receiver = exit_receiver.clone(); Box::pin(async move { tokio::select! { @@ -1829,7 +2333,8 @@ mod tests { _ = exit_receiver.changed() => true, } }) - }, false, + }, + false, ); let t1 = tokio::spawn(bp_future); let t2 = tokio::spawn(async move { @@ -1838,7 +2343,12 @@ mod tests { let log_entries = nodes[0].logger.lines.lock().unwrap(); let expected_log = "Persisting scorer after update".to_string(); - assert_eq!(*log_entries.get(&("lightning_background_processor".to_string(), expected_log)).unwrap(), 5); + assert_eq!( + *log_entries + .get(&("lightning_background_processor".to_string(), expected_log)) + .unwrap(), + 5 + ); }); let (r1, r2) = tokio::join!(t1, t2); diff --git a/lightning-block-sync/src/convert.rs b/lightning-block-sync/src/convert.rs index bf9e9577619..185b38dc311 100644 --- a/lightning-block-sync/src/convert.rs +++ b/lightning-block-sync/src/convert.rs @@ -10,15 +10,17 @@ use bitcoin::Transaction; use serde_json; +use bitcoin::hashes::Hash; use std::convert::From; use std::convert::TryFrom; use std::convert::TryInto; use std::str::FromStr; -use bitcoin::hashes::Hash; impl TryInto for JsonResponse { type Error = std::io::Error; - fn try_into(self) -> Result { Ok(self.0) } + fn try_into(self) -> Result { + Ok(self.0) + } } /// Conversion from `std::io::Error` into `BlockSourceError`. @@ -38,7 +40,12 @@ impl TryInto for BinaryResponse { fn try_into(self) -> std::io::Result { match encode::deserialize(&self.0) { - Err(_) => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid block data")), + Err(_) => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid block data", + )) + }, Ok(block) => Ok(block), } } @@ -49,9 +56,9 @@ impl TryInto for BinaryResponse { type Error = std::io::Error; fn try_into(self) -> std::io::Result { - BlockHash::from_slice(&self.0).map_err(|_| + BlockHash::from_slice(&self.0).map_err(|_| { std::io::Error::new(std::io::ErrorKind::InvalidData, "bad block hash length") - ) + }) } } @@ -62,18 +69,30 @@ impl TryInto for JsonResponse { fn try_into(self) -> std::io::Result { let header = match self.0 { - serde_json::Value::Array(mut array) if !array.is_empty() => array.drain(..).next().unwrap(), + serde_json::Value::Array(mut array) if !array.is_empty() => { + array.drain(..).next().unwrap() + }, serde_json::Value::Object(_) => self.0, - _ => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "unexpected JSON type")), + _ => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "unexpected JSON type", + )) + }, }; if !header.is_object() { - return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON object")); + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "expected JSON object", + )); } // Add an empty previousblockhash for the genesis block. match header.try_into() { - Err(_) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid header data")), + Err(_) => { + Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid header data")) + }, Ok(header) => Ok(header), } } @@ -83,19 +102,26 @@ impl TryFrom for BlockHeaderData { type Error = (); fn try_from(response: serde_json::Value) -> Result { - macro_rules! get_field { ($name: expr, $ty_access: tt) => { - response.get($name).ok_or(())?.$ty_access().ok_or(())? - } } + macro_rules! get_field { + ($name: expr, $ty_access: tt) => { + response.get($name).ok_or(())?.$ty_access().ok_or(())? + }; + } Ok(BlockHeaderData { header: BlockHeader { version: get_field!("version", as_i64).try_into().map_err(|_| ())?, prev_blockhash: if let Some(hash_str) = response.get("previousblockhash") { - BlockHash::from_hex(hash_str.as_str().ok_or(())?).map_err(|_| ())? - } else { BlockHash::all_zeros() }, - merkle_root: TxMerkleNode::from_hex(get_field!("merkleroot", as_str)).map_err(|_| ())?, + BlockHash::from_hex(hash_str.as_str().ok_or(())?).map_err(|_| ())? + } else { + BlockHash::all_zeros() + }, + merkle_root: TxMerkleNode::from_hex(get_field!("merkleroot", as_str)) + .map_err(|_| ())?, time: get_field!("time", as_u64).try_into().map_err(|_| ())?, - bits: u32::from_be_bytes(<[u8; 4]>::from_hex(get_field!("bits", as_str)).map_err(|_| ())?), + bits: u32::from_be_bytes( + <[u8; 4]>::from_hex(get_field!("bits", as_str)).map_err(|_| ())?, + ), nonce: get_field!("nonce", as_u64).try_into().map_err(|_| ())?, }, chainwork: hex_to_uint256(get_field!("chainwork", as_str)).map_err(|_| ())?, @@ -110,11 +136,18 @@ impl TryInto for JsonResponse { fn try_into(self) -> std::io::Result { match self.0.as_str() { - None => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON string")), + None => { + Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON string")) + }, Some(hex_data) => match Vec::::from_hex(hex_data) { - Err(_) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hex data")), + Err(_) => { + Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hex data")) + }, Ok(block_data) => match encode::deserialize(&block_data) { - Err(_) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid block data")), + Err(_) => Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid block data", + )), Ok(block) => Ok(block), }, }, @@ -128,27 +161,55 @@ impl TryInto<(BlockHash, Option)> for JsonResponse { fn try_into(self) -> std::io::Result<(BlockHash, Option)> { if !self.0.is_object() { - return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON object")); + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "expected JSON object", + )); } let hash = match &self.0["bestblockhash"] { serde_json::Value::String(hex_data) => match BlockHash::from_hex(&hex_data) { - Err(_) => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hex data")), + Err(_) => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid hex data", + )) + }, Ok(block_hash) => block_hash, }, - _ => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON string")), + _ => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "expected JSON string", + )) + }, }; let height = match &self.0["blocks"] { serde_json::Value::Null => None, serde_json::Value::Number(height) => match height.as_u64() { - None => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid height")), + None => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid height", + )) + }, Some(height) => match height.try_into() { - Err(_) => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid height")), + Err(_) => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid height", + )) + }, Ok(height) => Some(height), - } + }, + }, + _ => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "expected JSON number", + )) }, - _ => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON number")), }; Ok((hash, height)) @@ -159,20 +220,17 @@ impl TryInto for JsonResponse { type Error = std::io::Error; fn try_into(self) -> std::io::Result { match self.0.as_str() { - None => Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "expected JSON string", - )), + None => { + Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON string")) + }, Some(hex_data) => match Vec::::from_hex(hex_data) { - Err(_) => Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "invalid hex data", - )), + Err(_) => { + Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hex data")) + }, Ok(txid_data) => match encode::deserialize(&txid_data) { - Err(_) => Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "invalid txid", - )), + Err(_) => { + Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid txid")) + }, Ok(txid) => Ok(txid), }, }, @@ -205,14 +263,16 @@ impl TryInto for JsonResponse { } else { hex_data } - } + }, // result is a complete transaction (e.g. getrawtranaction verbose) _ => hex_data, }, - _ => return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "expected JSON string", - )), + _ => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "expected JSON string", + )) + }, } } else { // result is plain text (e.g. getrawtransaction no verbose) @@ -223,20 +283,16 @@ impl TryInto for JsonResponse { std::io::ErrorKind::InvalidData, "expected JSON string", )) - } + }, } }; match Vec::::from_hex(hex_tx) { - Err(_) => Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "invalid hex data", - )), + Err(_) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hex data")), Ok(tx_data) => match encode::deserialize(&tx_data) { - Err(_) => Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "invalid transaction", - )), + Err(_) => { + Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid transaction")) + }, Ok(tx) => Ok(tx), }, } @@ -248,11 +304,15 @@ impl TryInto for JsonResponse { fn try_into(self) -> std::io::Result { match self.0.as_str() { - None => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON string")), - Some(hex_data) if hex_data.len() != 64 => - Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hash length")), - Some(hex_data) => BlockHash::from_str(hex_data) - .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hex data")), + None => { + Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON string")) + }, + Some(hex_data) if hex_data.len() != 64 => { + Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hash length")) + }, + Some(hex_data) => BlockHash::from_str(hex_data).map_err(|_| { + std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hex data") + }), } } } @@ -261,25 +321,34 @@ impl TryInto for JsonResponse { /// - whether the `hit bitmap` field had any entries. Thus we condense the result down into only /// that. pub(crate) struct GetUtxosResponse { - pub(crate) hit_bitmap_nonempty: bool + pub(crate) hit_bitmap_nonempty: bool, } impl TryInto for JsonResponse { type Error = std::io::Error; fn try_into(self) -> std::io::Result { - let bitmap_str = - self.0.as_object().ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected an object"))? - .get("bitmap").ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "missing bitmap field"))? - .as_str().ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "bitmap should be an str"))?; - let mut hit_bitmap_nonempty = false; - for c in bitmap_str.chars() { - if c < '0' || c > '9' { - return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid byte")); - } - if c > '0' { hit_bitmap_nonempty = true; } + let bitmap_str = self + .0 + .as_object() + .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected an object"))? + .get("bitmap") + .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "missing bitmap field"))? + .as_str() + .ok_or(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "bitmap should be an str", + ))?; + let mut hit_bitmap_nonempty = false; + for c in bitmap_str.chars() { + if c < '0' || c > '9' { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid byte")); + } + if c > '0' { + hit_bitmap_nonempty = true; } - Ok(GetUtxosResponse { hit_bitmap_nonempty }) + } + Ok(GetUtxosResponse { hit_bitmap_nonempty }) } } @@ -287,8 +356,8 @@ impl TryInto for JsonResponse { pub(crate) mod tests { use super::*; use bitcoin::blockdata::constants::genesis_block; - use bitcoin::hashes::Hash; use bitcoin::hashes::hex::ToHex; + use bitcoin::hashes::Hash; use bitcoin::network::constants::Network; use serde_json::value::Number; use serde_json::Value; @@ -337,11 +406,10 @@ pub(crate) mod tests { #[test] fn into_block_header_from_json_response_with_invalid_header_response() { let block = genesis_block(Network::Bitcoin); - let mut response = JsonResponse(BlockHeaderData { - chainwork: block.header.work(), - height: 0, - header: block.header - }.into()); + let mut response = JsonResponse( + BlockHeaderData { chainwork: block.header.work(), height: 0, header: block.header } + .into(), + ); response.0["chainwork"].take(); match TryInto::::try_into(response) { @@ -356,11 +424,10 @@ pub(crate) mod tests { #[test] fn into_block_header_from_json_response_with_invalid_header_data() { let block = genesis_block(Network::Bitcoin); - let mut response = JsonResponse(BlockHeaderData { - chainwork: block.header.work(), - height: 0, - header: block.header - }.into()); + let mut response = JsonResponse( + BlockHeaderData { chainwork: block.header.work(), height: 0, header: block.header } + .into(), + ); response.0["chainwork"] = serde_json::json!("foobar"); match TryInto::::try_into(response) { @@ -375,11 +442,10 @@ pub(crate) mod tests { #[test] fn into_block_header_from_json_response_with_valid_header() { let block = genesis_block(Network::Bitcoin); - let response = JsonResponse(BlockHeaderData { - chainwork: block.header.work(), - height: 0, - header: block.header - }.into()); + let response = JsonResponse( + BlockHeaderData { chainwork: block.header.work(), height: 0, header: block.header } + .into(), + ); match TryInto::::try_into(response) { Err(e) => panic!("Unexpected error: {:?}", e), @@ -394,18 +460,20 @@ pub(crate) mod tests { #[test] fn into_block_header_from_json_response_with_valid_header_array() { let genesis_block = genesis_block(Network::Bitcoin); - let best_block_header = BlockHeader { - prev_blockhash: genesis_block.block_hash(), - ..genesis_block.header - }; + let best_block_header = + BlockHeader { prev_blockhash: genesis_block.block_hash(), ..genesis_block.header }; let chainwork = genesis_block.header.work() + best_block_header.work(); let response = JsonResponse(serde_json::json!([ - serde_json::Value::from(BlockHeaderData { - chainwork, height: 1, header: best_block_header, - }), - serde_json::Value::from(BlockHeaderData { - chainwork: genesis_block.header.work(), height: 0, header: genesis_block.header, - }), + serde_json::Value::from(BlockHeaderData { + chainwork, + height: 1, + header: best_block_header, + }), + serde_json::Value::from(BlockHeaderData { + chainwork: genesis_block.header.work(), + height: 0, + header: genesis_block.header, + }), ])); match TryInto::::try_into(response) { @@ -421,11 +489,10 @@ pub(crate) mod tests { #[test] fn into_block_header_from_json_response_without_previous_block_hash() { let block = genesis_block(Network::Bitcoin); - let mut response = JsonResponse(BlockHeaderData { - chainwork: block.header.work(), - height: 0, - header: block.header - }.into()); + let mut response = JsonResponse( + BlockHeaderData { chainwork: block.header.work(), height: 0, header: block.header } + .into(), + ); response.0.as_object_mut().unwrap().remove("previousblockhash"); match TryInto::::try_into(response) { @@ -607,7 +674,7 @@ pub(crate) mod tests { Err(e) => { assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON string"); - } + }, Ok(_) => panic!("Expected error"), } } @@ -619,7 +686,7 @@ pub(crate) mod tests { Err(e) => { assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); assert_eq!(e.get_ref().unwrap().to_string(), "invalid hex data"); - } + }, Ok(_) => panic!("Expected error"), } } @@ -631,7 +698,7 @@ pub(crate) mod tests { Err(e) => { assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); assert_eq!(e.get_ref().unwrap().to_string(), "invalid txid"); - } + }, Ok(_) => panic!("Expected error"), } } @@ -660,7 +727,7 @@ pub(crate) mod tests { Err(e) => { assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); assert_eq!(e.get_ref().unwrap().to_string(), "invalid hex data"); - } + }, Ok(_) => panic!("Expected error"), } } @@ -672,7 +739,7 @@ pub(crate) mod tests { Err(e) => { assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON string"); - } + }, Ok(_) => panic!("Expected error"), } } @@ -684,7 +751,7 @@ pub(crate) mod tests { Err(e) => { assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); assert_eq!(e.get_ref().unwrap().to_string(), "invalid transaction"); - } + }, Ok(_) => panic!("Expected error"), } } @@ -719,11 +786,8 @@ pub(crate) mod tests { match TryInto::::try_into(response) { Err(e) => { assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); - assert_eq!( - e.get_ref().unwrap().to_string(), - "expected JSON string" - ); - } + assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON string"); + }, Ok(_) => panic!("Expected error"), } } @@ -734,11 +798,12 @@ pub(crate) mod tests { match TryInto::::try_into(response) { Err(e) => { assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); - assert!( - e.get_ref().unwrap().to_string().contains( - "transaction couldn't be signed") - ); - } + assert!(e + .get_ref() + .unwrap() + .to_string() + .contains("transaction couldn't be signed")); + }, Ok(_) => panic!("Expected error"), } } diff --git a/lightning-block-sync/src/gossip.rs b/lightning-block-sync/src/gossip.rs index 3b6e9f68376..c4e848b0c36 100644 --- a/lightning-block-sync/src/gossip.rs +++ b/lightning-block-sync/src/gossip.rs @@ -6,24 +6,24 @@ use crate::{AsyncBlockSourceResult, BlockData, BlockSource, BlockSourceError}; use bitcoin::blockdata::block::Block; use bitcoin::blockdata::constants::ChainHash; -use bitcoin::blockdata::transaction::{TxOut, OutPoint}; +use bitcoin::blockdata::transaction::{OutPoint, TxOut}; use bitcoin::hash_types::BlockHash; use lightning::sign::NodeSigner; -use lightning::ln::peer_handler::{CustomMessageHandler, PeerManager, SocketDescriptor}; use lightning::ln::msgs::{ChannelMessageHandler, OnionMessageHandler}; +use lightning::ln::peer_handler::{CustomMessageHandler, PeerManager, SocketDescriptor}; use lightning::routing::gossip::{NetworkGraph, P2PGossipSync}; -use lightning::routing::utxo::{UtxoFuture, UtxoLookup, UtxoResult, UtxoLookupError}; +use lightning::routing::utxo::{UtxoFuture, UtxoLookup, UtxoLookupError, UtxoResult}; use lightning::util::logger::Logger; -use std::sync::{Arc, Mutex}; use std::collections::VecDeque; use std::future::Future; use std::ops::Deref; use std::pin::Pin; +use std::sync::{Arc, Mutex}; use std::task::Poll; /// A trait which extends [`BlockSource`] and can be queried to fetch the block at a given height @@ -32,12 +32,14 @@ use std::task::Poll; /// Note that while this is implementable for a [`BlockSource`] which returns filtered block data /// (i.e. [`BlockData::HeaderOnly`] for [`BlockSource::get_block`] requests), such an /// implementation will reject all gossip as it is not fully able to verify the UTXOs referenced. -pub trait UtxoSource : BlockSource + 'static { +pub trait UtxoSource: BlockSource + 'static { /// Fetches the block hash of the block at the given height. /// /// This will, in turn, be passed to to [`BlockSource::get_block`] to fetch the block needed /// for gossip validation. - fn get_block_hash_by_height<'a>(&'a self, block_height: u32) -> AsyncBlockSourceResult<'a, BlockHash>; + fn get_block_hash_by_height<'a>( + &'a self, block_height: u32, + ) -> AsyncBlockSourceResult<'a, BlockHash>; /// Returns true if the given output has *not* been spent, i.e. is a member of the current UTXO /// set. @@ -48,7 +50,7 @@ pub trait UtxoSource : BlockSource + 'static { /// /// If the `tokio` feature is enabled, this is implemented on `TokioSpawner` struct which /// delegates to `tokio::spawn()`. -pub trait FutureSpawner : Send + Sync + 'static { +pub trait FutureSpawner: Send + Sync + 'static { /// Spawns the given future as a background task. /// /// This method MUST NOT block on the given future immediately. @@ -68,8 +70,8 @@ impl FutureSpawner for TokioSpawner { /// A trivial future which joins two other futures and polls them at the same time, returning only /// once both complete. pub(crate) struct Joiner< - A: Future), BlockSourceError>> + Unpin, - B: Future> + Unpin, + A: Future), BlockSourceError>> + Unpin, + B: Future> + Unpin, > { pub a: A, pub b: B, @@ -78,16 +80,20 @@ pub(crate) struct Joiner< } impl< - A: Future), BlockSourceError>> + Unpin, - B: Future> + Unpin, -> Joiner { - fn new(a: A, b: B) -> Self { Self { a, b, a_res: None, b_res: None } } + A: Future), BlockSourceError>> + Unpin, + B: Future> + Unpin, + > Joiner +{ + fn new(a: A, b: B) -> Self { + Self { a, b, a_res: None, b_res: None } + } } impl< - A: Future), BlockSourceError>> + Unpin, - B: Future> + Unpin, -> Future for Joiner { + A: Future), BlockSourceError>> + Unpin, + B: Future> + Unpin, + > Future for Joiner +{ type Output = Result<((BlockHash, Option), BlockHash), BlockSourceError>; fn poll(mut self: Pin<&mut Self>, ctx: &mut core::task::Context<'_>) -> Poll { if self.a_res.is_none() { @@ -110,14 +116,13 @@ impl< } else { return Poll::Ready(Err(res.unwrap_err())); } - }, Poll::Pending => {}, } } if let Some(b_res) = self.b_res { if let Some(a_res) = self.a_res { - return Poll::Ready(Ok((a_res, b_res))) + return Poll::Ready(Ok((a_res, b_res))); } } Poll::Pending @@ -132,7 +137,8 @@ impl< /// value of 1024 should more than suffice), and ensure you have sufficient file descriptors /// available on both Bitcoin Core and your LDK application for each request to hold its own /// connection. -pub struct GossipVerifier>, Self, L>>, OM, L, CMH, NS>>, + peer_manager: Arc< + PeerManager< + Descriptor, + CM, + Arc>, Self, L>>, + OM, + L, + CMH, + NS, + >, + >, gossiper: Arc>, Self, L>>, spawn: S, block_cache: Arc>>, @@ -157,15 +173,17 @@ pub struct GossipVerifier GossipVerifier where +impl< + S: FutureSpawner, + Blocks: Deref + Send + Sync + Clone, + L: Deref + Send + Sync, + Descriptor: SocketDescriptor + Send + Sync, + CM: Deref + Send + Sync, + OM: Deref + Send + Sync, + CMH: Deref + Send + Sync, + NS: Deref + Send + Sync, + > GossipVerifier +where Blocks::Target: UtxoSource, L::Target: Logger, CM::Target: ChannelMessageHandler, @@ -177,15 +195,31 @@ impl>, Self, L>>, peer_manager: Arc>, Self, L>>, OM, L, CMH, NS>>) -> Self { + pub fn new( + source: Blocks, spawn: S, gossiper: Arc>, Self, L>>, + peer_manager: Arc< + PeerManager< + Descriptor, + CM, + Arc>, Self, L>>, + OM, + L, + CMH, + NS, + >, + >, + ) -> Self { Self { - source, spawn, gossiper, peer_manager, + source, + spawn, + gossiper, + peer_manager, block_cache: Arc::new(Mutex::new(VecDeque::with_capacity(BLOCK_CACHE_SIZE))), } } async fn retrieve_utxo( - source: Blocks, block_cache: Arc>>, short_channel_id: u64 + source: Blocks, block_cache: Arc>>, short_channel_id: u64, ) -> Result { let block_height = (short_channel_id >> 5 * 8) as u32; // block height is most significant three bytes let transaction_index = ((short_channel_id >> 2 * 8) & 0xffffff) as u32; @@ -193,9 +227,10 @@ impl { { + ($block: expr) => {{ if transaction_index as usize >= $block.txdata.len() { return Err(UtxoLookupError::UnknownTx); } @@ -206,7 +241,7 @@ impl return Err(UtxoLookupError::UnknownTx), BlockData::FullBlock(block) => block, @@ -255,7 +290,7 @@ impl Deref for GossipVerifier where +impl< + S: FutureSpawner, + Blocks: Deref + Send + Sync + Clone, + L: Deref + Send + Sync, + Descriptor: SocketDescriptor + Send + Sync, + CM: Deref + Send + Sync, + OM: Deref + Send + Sync, + CMH: Deref + Send + Sync, + NS: Deref + Send + Sync, + > Deref for GossipVerifier +where Blocks::Target: UtxoSource, L::Target: Logger, CM::Target: ChannelMessageHandler, @@ -283,19 +320,22 @@ impl &Self { self } + fn deref(&self) -> &Self { + self + } } - -impl UtxoLookup for GossipVerifier where +impl< + S: FutureSpawner, + Blocks: Deref + Send + Sync + Clone, + L: Deref + Send + Sync, + Descriptor: SocketDescriptor + Send + Sync, + CM: Deref + Send + Sync, + OM: Deref + Send + Sync, + CMH: Deref + Send + Sync, + NS: Deref + Send + Sync, + > UtxoLookup for GossipVerifier +where Blocks::Target: UtxoSource, L::Target: Logger, CM::Target: ChannelMessageHandler, diff --git a/lightning-block-sync/src/http.rs b/lightning-block-sync/src/http.rs index 58d66686f01..f7bac661b27 100644 --- a/lightning-block-sync/src/http.rs +++ b/lightning-block-sync/src/http.rs @@ -50,11 +50,7 @@ pub struct HttpEndpoint { impl HttpEndpoint { /// Creates an endpoint for the given host and default HTTP port. pub fn for_host(host: String) -> Self { - Self { - host, - port: None, - path: String::from("/"), - } + Self { host, port: None, path: String::from("/") } } /// Specifies a port to use with the endpoint. @@ -107,7 +103,10 @@ impl HttpClient { pub fn connect(endpoint: E) -> std::io::Result { let address = match endpoint.to_socket_addrs()?.next() { None => { - return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "could not resolve to any addresses")); + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "could not resolve to any addresses", + )); }, Some(address) => address, }; @@ -129,12 +128,16 @@ impl HttpClient { /// Returns the response body in `F` format. #[allow(dead_code)] pub async fn get(&mut self, uri: &str, host: &str) -> std::io::Result - where F: TryFrom, Error = std::io::Error> { + where + F: TryFrom, Error = std::io::Error>, + { let request = format!( "GET {} HTTP/1.1\r\n\ Host: {}\r\n\ Connection: keep-alive\r\n\ - \r\n", uri, host); + \r\n", + uri, host + ); let response_body = self.send_request_with_retry(&request).await?; F::try_from(response_body) } @@ -145,8 +148,12 @@ impl HttpClient { /// The request body consists of the provided JSON `content`. Returns the response body in `F` /// format. #[allow(dead_code)] - pub async fn post(&mut self, uri: &str, host: &str, auth: &str, content: serde_json::Value) -> std::io::Result - where F: TryFrom, Error = std::io::Error> { + pub async fn post( + &mut self, uri: &str, host: &str, auth: &str, content: serde_json::Value, + ) -> std::io::Result + where + F: TryFrom, Error = std::io::Error>, + { let content = content.to_string(); let request = format!( "POST {} HTTP/1.1\r\n\ @@ -156,7 +163,13 @@ impl HttpClient { Content-Type: application/json\r\n\ Content-Length: {}\r\n\ \r\n\ - {}", uri, host, auth, content.len(), content); + {}", + uri, + host, + auth, + content.len(), + content + ); let response_body = self.send_request_with_retry(&request).await?; F::try_from(response_body) } @@ -218,8 +231,10 @@ impl HttpClient { let mut reader = std::io::BufReader::new(limited_stream); macro_rules! read_line { - () => { read_line!(0) }; - ($retry_count: expr) => { { + () => { + read_line!(0) + }; + ($retry_count: expr) => {{ let mut line = String::new(); let mut timeout_count: u64 = 0; let bytes_read = loop { @@ -236,7 +251,7 @@ impl HttpClient { } else { continue; } - } + }, Err(e) => return Err(e), } }; @@ -245,17 +260,23 @@ impl HttpClient { 0 => None, _ => { // Remove trailing CRLF - if line.ends_with('\n') { line.pop(); if line.ends_with('\r') { line.pop(); } } + if line.ends_with('\n') { + line.pop(); + if line.ends_with('\r') { + line.pop(); + } + } Some(line) }, } - } } + }}; } // Read and parse status line // Note that we allow retrying a few times to reach TCP_STREAM_RESPONSE_TIMEOUT. - let status_line = read_line!(TCP_STREAM_RESPONSE_TIMEOUT.as_secs() / TCP_STREAM_TIMEOUT.as_secs()) - .ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no status line"))?; + let status_line = + read_line!(TCP_STREAM_RESPONSE_TIMEOUT.as_secs() / TCP_STREAM_TIMEOUT.as_secs()) + .ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no status line"))?; let status = HttpStatus::parse(&status_line)?; // Read and parse relevant headers @@ -263,11 +284,15 @@ impl HttpClient { loop { let line = read_line!() .ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no headers"))?; - if line.is_empty() { break; } + if line.is_empty() { + break; + } let header = HttpHeader::parse(&line)?; if header.has_name("Content-Length") { - let length = header.value.parse() + let length = header + .value + .parse() .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; if let HttpMessageLength::Empty = message_length { message_length = HttpMessageLength::ContentLength(length); @@ -285,10 +310,13 @@ impl HttpClient { let read_limit = MAX_HTTP_MESSAGE_BODY_SIZE - reader.buffer().len(); reader.get_mut().set_limit(read_limit as u64); let contents = match message_length { - HttpMessageLength::Empty => { Vec::new() }, + HttpMessageLength::Empty => Vec::new(), HttpMessageLength::ContentLength(length) => { if length == 0 || length > MAX_HTTP_MESSAGE_BODY_SIZE { - return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "out of range")) + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "out of range", + )); } else { let mut content = vec![0; length]; #[cfg(feature = "tokio")] @@ -301,7 +329,9 @@ impl HttpClient { HttpMessageLength::TransferEncoding(coding) => { if !coding.eq_ignore_ascii_case("chunked") { return Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, "unsupported transfer coding")) + std::io::ErrorKind::InvalidInput, + "unsupported transfer coding", + )); } else { let mut content = Vec::new(); #[cfg(feature = "tokio")] @@ -323,7 +353,8 @@ impl HttpClient { // Decode the chunk header to obtain the chunk size. let mut buffer = Vec::new(); - let mut decoder = chunked_transfer::Decoder::new(chunk_header.as_bytes()); + let mut decoder = + chunked_transfer::Decoder::new(chunk_header.as_bytes()); decoder.read_to_end(&mut buffer)?; // Read the chunk body. @@ -350,10 +381,7 @@ impl HttpClient { if !status.is_ok() { // TODO: Handle 3xx redirection responses. - let error = HttpError { - status_code: status.code.to_string(), - contents, - }; + let error = HttpError { status_code: status.code.to_string(), contents }; return Err(std::io::Error::new(std::io::ErrorKind::Other, error)); } @@ -391,20 +419,30 @@ impl<'a> HttpStatus<'a> { fn parse(line: &'a String) -> std::io::Result> { let mut tokens = line.splitn(3, ' '); - let http_version = tokens.next() + let http_version = tokens + .next() .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no HTTP-Version"))?; - if !http_version.eq_ignore_ascii_case("HTTP/1.1") && - !http_version.eq_ignore_ascii_case("HTTP/1.0") { - return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid HTTP-Version")); + if !http_version.eq_ignore_ascii_case("HTTP/1.1") + && !http_version.eq_ignore_ascii_case("HTTP/1.0") + { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid HTTP-Version", + )); } - let code = tokens.next() + let code = tokens + .next() .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no Status-Code"))?; if code.len() != 3 || !code.chars().all(|c| c.is_ascii_digit()) { - return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid Status-Code")); + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "invalid Status-Code", + )); } - let _reason = tokens.next() + let _reason = tokens + .next() .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no Reason-Phrase"))?; Ok(Self { code }) @@ -430,9 +468,11 @@ impl<'a> HttpHeader<'a> { /// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.2 fn parse(line: &'a String) -> std::io::Result> { let mut tokens = line.splitn(2, ':'); - let name = tokens.next() + let name = tokens + .next() .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no header name"))?; - let value = tokens.next() + let value = tokens + .next() .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no header value"))? .trim_start(); Ok(Self { name, value }) @@ -524,7 +564,7 @@ mod endpoint_tests { assert_eq!(addr, std_addrs.next().unwrap()); } assert!(std_addrs.next().is_none()); - } + }, } } } @@ -559,7 +599,11 @@ pub(crate) mod client_tests { "{}\r\n\ Content-Length: {}\r\n\ \r\n\ - {}", status, body.len(), body) + {}", + status, + body.len(), + body + ) }, MessageBody::ChunkedContent(body) => { let mut chuncked_body = Vec::new(); @@ -572,7 +616,10 @@ pub(crate) mod client_tests { "{}\r\n\ Transfer-Encoding: chunked\r\n\ \r\n\ - {}", status, String::from_utf8(chuncked_body).unwrap()) + {}", + status, + String::from_utf8(chuncked_body).unwrap() + ) }, }; HttpServer::responding_with(response) @@ -606,14 +653,20 @@ pub(crate) mod client_tests { .lines() .take_while(|line| !line.as_ref().unwrap().is_empty()) .count(); - if lines_read == 0 { continue; } + if lines_read == 0 { + continue; + } for chunk in response.as_bytes().chunks(16) { if shutdown_signaled.load(std::sync::atomic::Ordering::SeqCst) { return; } else { - if let Err(_) = stream.write(chunk) { break; } - if let Err(_) = stream.flush() { break; } + if let Err(_) = stream.write(chunk) { + break; + } + if let Err(_) = stream.flush() { + break; + } } } } @@ -636,8 +689,12 @@ pub(crate) mod client_tests { fn connect_to_unresolvable_host() { match HttpClient::connect(("example.invalid", 80)) { Err(e) => { - assert!(e.to_string().contains("failed to lookup address information") || - e.to_string().contains("No such host"), "{:?}", e); + assert!( + e.to_string().contains("failed to lookup address information") + || e.to_string().contains("No such host"), + "{:?}", + e + ); }, Ok(_) => panic!("Expected error"), } @@ -705,7 +762,9 @@ pub(crate) mod client_tests { let response = format!( "HTTP/1.1 302 Found\r\n\ Location: {}\r\n\ - \r\n", "Z".repeat(MAX_HTTP_MESSAGE_HEADER_SIZE)); + \r\n", + "Z".repeat(MAX_HTTP_MESSAGE_HEADER_SIZE) + ); let server = HttpServer::responding_with(response); let mut client = HttpClient::connect(&server.endpoint()).unwrap(); @@ -740,7 +799,8 @@ pub(crate) mod client_tests { "HTTP/1.1 200 OK\r\n\ Transfer-Encoding: gzip\r\n\ \r\n\ - foobar"); + foobar", + ); let server = HttpServer::responding_with(response); let mut client = HttpClient::connect(&server.endpoint()).unwrap(); diff --git a/lightning-block-sync/src/init.rs b/lightning-block-sync/src/init.rs index 5423bba5182..853a1a4c53e 100644 --- a/lightning-block-sync/src/init.rs +++ b/lightning-block-sync/src/init.rs @@ -1,8 +1,8 @@ //! Utilities to assist in the initial sync required to initialize or reload Rust-Lightning objects //! from disk. -use crate::{BlockSource, BlockSourceResult, Cache, ChainNotifier}; use crate::poll::{ChainPoller, Validate, ValidatedBlockHeader}; +use crate::{BlockSource, BlockSourceResult, Cache, ChainNotifier}; use bitcoin::blockdata::block::BlockHeader; use bitcoin::hash_types::BlockHash; @@ -18,12 +18,14 @@ use std::ops::Deref; /// start when there are no chain listeners to sync yet. /// /// [`SpvClient`]: crate::SpvClient -pub async fn validate_best_block_header(block_source: B) -> -BlockSourceResult where B::Target: BlockSource { +pub async fn validate_best_block_header( + block_source: B, +) -> BlockSourceResult +where + B::Target: BlockSource, +{ let (best_block_hash, best_block_height) = block_source.get_best_block().await?; - block_source - .get_header(&best_block_hash, best_block_height).await? - .validate(best_block_hash) + block_source.get_header(&best_block_hash, best_block_height).await?.validate(best_block_hash) } /// Performs a one-time sync of chain listeners using a single *trusted* block source, bringing each @@ -131,12 +133,17 @@ BlockSourceResult where B::Target: BlockSource { /// [`SpvClient`]: crate::SpvClient /// [`ChannelManager`]: lightning::ln::channelmanager::ChannelManager /// [`ChannelMonitor`]: lightning::chain::channelmonitor::ChannelMonitor -pub async fn synchronize_listeners( - block_source: B, - network: Network, - header_cache: &mut C, +pub async fn synchronize_listeners< + B: Deref + Sized + Send + Sync, + C: Cache, + L: chain::Listen + ?Sized, +>( + block_source: B, network: Network, header_cache: &mut C, mut chain_listeners: Vec<(BlockHash, &L)>, -) -> BlockSourceResult where B::Target: BlockSource { +) -> BlockSourceResult +where + B::Target: BlockSource, +{ let best_header = validate_best_block_header(&*block_source).await?; // Fetch the header for the block hash paired with each listener. @@ -144,9 +151,9 @@ pub async fn synchronize_listeners *header, - None => block_source - .get_header(&old_block_hash, None).await? - .validate(old_block_hash)? + None => { + block_source.get_header(&old_block_hash, None).await?.validate(old_block_hash)? + }, }; chain_listeners_with_old_headers.push((old_header, chain_listener)) } @@ -180,8 +187,10 @@ pub async fn synchronize_listeners Cache for ReadOnlyCache<'a, C> { struct DynamicChainListener<'a, L: chain::Listen + ?Sized>(&'a L); impl<'a, L: chain::Listen + ?Sized> chain::Listen for DynamicChainListener<'a, L> { - fn filtered_block_connected(&self, _header: &BlockHeader, _txdata: &chain::transaction::TransactionData, _height: u32) { + fn filtered_block_connected( + &self, _header: &BlockHeader, _txdata: &chain::transaction::TransactionData, _height: u32, + ) { unreachable!() } @@ -234,7 +245,9 @@ impl<'a, L: chain::Listen + ?Sized> chain::Listen for ChainListenerSet<'a, L> { } } - fn filtered_block_connected(&self, header: &BlockHeader, txdata: &chain::transaction::TransactionData, height: u32) { + fn filtered_block_connected( + &self, header: &BlockHeader, txdata: &chain::transaction::TransactionData, height: u32, + ) { for (starting_height, chain_listener) in self.0.iter() { if height > *starting_height { chain_listener.filtered_block_connected(header, txdata, height); @@ -249,8 +262,8 @@ impl<'a, L: chain::Listen + ?Sized> chain::Listen for ChainListenerSet<'a, L> { #[cfg(test)] mod tests { - use crate::test_utils::{Blockchain, MockChainListener}; use super::*; + use crate::test_utils::{Blockchain, MockChainListener}; use bitcoin::network::constants::Network; @@ -265,8 +278,7 @@ mod tests { let listener_2 = MockChainListener::new() .expect_block_connected(*chain.at_height(3)) .expect_block_connected(*chain.at_height(4)); - let listener_3 = MockChainListener::new() - .expect_block_connected(*chain.at_height(4)); + let listener_3 = MockChainListener::new().expect_block_connected(*chain.at_height(4)); let listeners = vec![ (chain.at_height(1).block_hash, &listener_1 as &dyn chain::Listen), diff --git a/lightning-block-sync/src/lib.rs b/lightning-block-sync/src/lib.rs index 3561a1b5d76..81fc6ce88ac 100644 --- a/lightning-block-sync/src/lib.rs +++ b/lightning-block-sync/src/lib.rs @@ -16,10 +16,8 @@ // Prefix these with `rustdoc::` when we update our MSRV to be >= 1.52 to remove warnings. #![deny(broken_intra_doc_links)] #![deny(private_intra_doc_links)] - #![deny(missing_docs)] #![deny(unsafe_code)] - #![cfg_attr(docsrs, feature(doc_auto_cfg))] #[cfg(any(feature = "rest-client", feature = "rpc-client"))] @@ -59,18 +57,21 @@ use std::ops::Deref; use std::pin::Pin; /// Abstract type for retrieving block headers and data. -pub trait BlockSource : Sync + Send { +pub trait BlockSource: Sync + Send { /// Returns the header for a given hash. A height hint may be provided in case a block source /// cannot easily find headers based on a hash. This is merely a hint and thus the returned /// header must have the same hash as was requested. Otherwise, an error must be returned. /// /// Implementations that cannot find headers based on the hash should return a `Transient` error /// when `height_hint` is `None`. - fn get_header<'a>(&'a self, header_hash: &'a BlockHash, height_hint: Option) -> AsyncBlockSourceResult<'a, BlockHeaderData>; + fn get_header<'a>( + &'a self, header_hash: &'a BlockHash, height_hint: Option, + ) -> AsyncBlockSourceResult<'a, BlockHeaderData>; /// Returns the block for a given hash. A headers-only block source should return a `Transient` /// error. - fn get_block<'a>(&'a self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, BlockData>; + fn get_block<'a>(&'a self, header_hash: &'a BlockHash) + -> AsyncBlockSourceResult<'a, BlockData>; /// Returns the hash of the best block and, optionally, its height. /// @@ -87,7 +88,8 @@ pub type BlockSourceResult = Result; // TODO: Replace with BlockSourceResult once `async` trait functions are supported. For details, // see: https://areweasyncyet.rs. /// Result type for asynchronous `BlockSource` requests. -pub type AsyncBlockSourceResult<'a, T> = Pin> + 'a + Send>>; +pub type AsyncBlockSourceResult<'a, T> = + Pin> + 'a + Send>>; /// Error type for `BlockSource` requests. /// @@ -112,20 +114,18 @@ pub enum BlockSourceErrorKind { impl BlockSourceError { /// Creates a new persistent error originated from the given error. pub fn persistent(error: E) -> Self - where E: Into> { - Self { - kind: BlockSourceErrorKind::Persistent, - error: error.into(), - } + where + E: Into>, + { + Self { kind: BlockSourceErrorKind::Persistent, error: error.into() } } /// Creates a new transient error originated from the given error. pub fn transient(error: E) -> Self - where E: Into> { - Self { - kind: BlockSourceErrorKind::Transient, - error: error.into(), - } + where + E: Into>, + { + Self { kind: BlockSourceErrorKind::Transient, error: error.into() } } /// Returns the kind of error. @@ -182,7 +182,9 @@ pub enum BlockData { /// Hence, there is a trade-off between a lower memory footprint and potentially increased network /// I/O as headers are re-fetched during fork detection. pub struct SpvClient<'a, P: Poll, C: Cache, L: Deref> -where L::Target: chain::Listen { +where + L::Target: chain::Listen, +{ chain_tip: ValidatedBlockHeader, chain_poller: P, chain_notifier: ChainNotifier<'a, C, L>, @@ -228,7 +230,10 @@ impl Cache for UnboundedCache { } } -impl<'a, P: Poll, C: Cache, L: Deref> SpvClient<'a, P, C, L> where L::Target: chain::Listen { +impl<'a, P: Poll, C: Cache, L: Deref> SpvClient<'a, P, C, L> +where + L::Target: chain::Listen, +{ /// Creates a new SPV client using `chain_tip` as the best known chain tip. /// /// Subsequent calls to [`poll_best_tip`] will poll for the best chain tip using the given chain @@ -240,9 +245,7 @@ impl<'a, P: Poll, C: Cache, L: Deref> SpvClient<'a, P, C, L> where L::Target: ch /// /// [`poll_best_tip`]: SpvClient::poll_best_tip pub fn new( - chain_tip: ValidatedBlockHeader, - chain_poller: P, - header_cache: &'a mut C, + chain_tip: ValidatedBlockHeader, chain_poller: P, header_cache: &'a mut C, chain_listener: L, ) -> Self { let chain_notifier = ChainNotifier { header_cache, chain_listener }; @@ -275,8 +278,10 @@ impl<'a, P: Poll, C: Cache, L: Deref> SpvClient<'a, P, C, L> where L::Target: ch /// Updates the chain tip, syncing the chain listener with any connected or disconnected /// blocks. Returns whether there were any such blocks. async fn update_chain_tip(&mut self, best_chain_tip: ValidatedBlockHeader) -> bool { - match self.chain_notifier.synchronize_listener( - best_chain_tip, &self.chain_tip, &mut self.chain_poller).await + match self + .chain_notifier + .synchronize_listener(best_chain_tip, &self.chain_tip, &mut self.chain_poller) + .await { Ok(_) => { self.chain_tip = best_chain_tip; @@ -294,7 +299,10 @@ impl<'a, P: Poll, C: Cache, L: Deref> SpvClient<'a, P, C, L> where L::Target: ch /// Notifies [listeners] of blocks that have been connected or disconnected from the chain. /// /// [listeners]: lightning::chain::Listen -pub struct ChainNotifier<'a, C: Cache, L: Deref> where L::Target: chain::Listen { +pub struct ChainNotifier<'a, C: Cache, L: Deref> +where + L::Target: chain::Listen, +{ /// Cache for looking up headers before fetching from a block source. header_cache: &'a mut C, @@ -320,7 +328,10 @@ struct ChainDifference { connected_blocks: Vec, } -impl<'a, C: Cache, L: Deref> ChainNotifier<'a, C, L> where L::Target: chain::Listen { +impl<'a, C: Cache, L: Deref> ChainNotifier<'a, C, L> +where + L::Target: chain::Listen, +{ /// Finds the first common ancestor between `new_header` and `old_header`, disconnecting blocks /// from `old_header` to get to that point and then connecting blocks until `new_header`. /// @@ -329,19 +340,16 @@ impl<'a, C: Cache, L: Deref> ChainNotifier<'a, C, L> where L::Target: chain::Lis /// ended up which may not be `new_header`. Note that the returned `Err` contains `Some` header /// if and only if the transition from `old_header` to `new_header` is valid. async fn synchronize_listener( - &mut self, - new_header: ValidatedBlockHeader, - old_header: &ValidatedBlockHeader, + &mut self, new_header: ValidatedBlockHeader, old_header: &ValidatedBlockHeader, chain_poller: &mut P, ) -> Result<(), (BlockSourceError, Option)> { - let difference = self.find_difference(new_header, old_header, chain_poller).await + let difference = self + .find_difference(new_header, old_header, chain_poller) + .await .map_err(|e| (e, None))?; self.disconnect_blocks(difference.disconnected_blocks); - self.connect_blocks( - difference.common_ancestor, - difference.connected_blocks, - chain_poller, - ).await + self.connect_blocks(difference.common_ancestor, difference.connected_blocks, chain_poller) + .await } /// Returns the changes needed to produce the chain with `current_header` as its tip from the @@ -349,9 +357,7 @@ impl<'a, C: Cache, L: Deref> ChainNotifier<'a, C, L> where L::Target: chain::Lis /// /// Walks backwards from `current_header` and `prev_header`, finding the common ancestor. async fn find_difference( - &self, - current_header: ValidatedBlockHeader, - prev_header: &ValidatedBlockHeader, + &self, current_header: ValidatedBlockHeader, prev_header: &ValidatedBlockHeader, chain_poller: &mut P, ) -> BlockSourceResult { let mut disconnected_blocks = Vec::new(); @@ -385,9 +391,7 @@ impl<'a, C: Cache, L: Deref> ChainNotifier<'a, C, L> where L::Target: chain::Lis /// Returns the previous header for the given header, either by looking it up in the cache or /// fetching it if not found. async fn look_up_previous_header( - &self, - chain_poller: &mut P, - header: &ValidatedBlockHeader, + &self, chain_poller: &mut P, header: &ValidatedBlockHeader, ) -> BlockSourceResult { match self.header_cache.look_up(&header.header.prev_blockhash) { Some(prev_header) => Ok(*prev_header), @@ -407,16 +411,13 @@ impl<'a, C: Cache, L: Deref> ChainNotifier<'a, C, L> where L::Target: chain::Lis /// Notifies the chain listeners of connected blocks. async fn connect_blocks( - &mut self, - mut new_tip: ValidatedBlockHeader, - mut connected_blocks: Vec, - chain_poller: &mut P, + &mut self, mut new_tip: ValidatedBlockHeader, + mut connected_blocks: Vec, chain_poller: &mut P, ) -> Result<(), (BlockSourceError, Option)> { for header in connected_blocks.drain(..).rev() { let height = header.height; - let block_data = chain_poller - .fetch_block(&header).await - .map_err(|e| (e, Some(new_tip)))?; + let block_data = + chain_poller.fetch_block(&header).await.map_err(|e| (e, Some(new_tip)))?; debug_assert_eq!(block_data.block_hash, header.block_hash); match block_data.deref() { @@ -438,8 +439,8 @@ impl<'a, C: Cache, L: Deref> ChainNotifier<'a, C, L> where L::Target: chain::Lis #[cfg(test)] mod spv_client_tests { - use crate::test_utils::{Blockchain, NullChainListener}; use super::*; + use crate::test_utils::{Blockchain, NullChainListener}; use bitcoin::network::constants::Network; @@ -565,8 +566,8 @@ mod spv_client_tests { #[cfg(test)] mod chain_notifier_tests { - use crate::test_utils::{Blockchain, MockChainListener}; use super::*; + use crate::test_utils::{Blockchain, MockChainListener}; use bitcoin::network::constants::Network; @@ -579,10 +580,8 @@ mod chain_notifier_tests { let chain_listener = &MockChainListener::new() .expect_block_connected(*chain.at_height(2)) .expect_block_connected(*new_tip); - let mut notifier = ChainNotifier { - header_cache: &mut chain.header_cache(0..=1), - chain_listener, - }; + let mut notifier = + ChainNotifier { header_cache: &mut chain.header_cache(0..=1), chain_listener }; let mut poller = poll::ChainPoller::new(&mut chain, Network::Testnet); match notifier.synchronize_listener(new_tip, &old_tip, &mut poller).await { Err((e, _)) => panic!("Unexpected error: {:?}", e), @@ -598,10 +597,8 @@ mod chain_notifier_tests { let new_tip = test_chain.tip(); let old_tip = main_chain.tip(); let chain_listener = &MockChainListener::new(); - let mut notifier = ChainNotifier { - header_cache: &mut main_chain.header_cache(0..=1), - chain_listener, - }; + let mut notifier = + ChainNotifier { header_cache: &mut main_chain.header_cache(0..=1), chain_listener }; let mut poller = poll::ChainPoller::new(&mut test_chain, Network::Testnet); match notifier.synchronize_listener(new_tip, &old_tip, &mut poller).await { Err((e, _)) => { @@ -622,10 +619,8 @@ mod chain_notifier_tests { let chain_listener = &MockChainListener::new() .expect_block_disconnected(*old_tip) .expect_block_connected(*new_tip); - let mut notifier = ChainNotifier { - header_cache: &mut main_chain.header_cache(0..=2), - chain_listener, - }; + let mut notifier = + ChainNotifier { header_cache: &mut main_chain.header_cache(0..=2), chain_listener }; let mut poller = poll::ChainPoller::new(&mut fork_chain, Network::Testnet); match notifier.synchronize_listener(new_tip, &old_tip, &mut poller).await { Err((e, _)) => panic!("Unexpected error: {:?}", e), @@ -645,10 +640,8 @@ mod chain_notifier_tests { .expect_block_disconnected(*old_tip) .expect_block_disconnected(*main_chain.at_height(2)) .expect_block_connected(*new_tip); - let mut notifier = ChainNotifier { - header_cache: &mut main_chain.header_cache(0..=3), - chain_listener, - }; + let mut notifier = + ChainNotifier { header_cache: &mut main_chain.header_cache(0..=3), chain_listener }; let mut poller = poll::ChainPoller::new(&mut fork_chain, Network::Testnet); match notifier.synchronize_listener(new_tip, &old_tip, &mut poller).await { Err((e, _)) => panic!("Unexpected error: {:?}", e), @@ -668,10 +661,8 @@ mod chain_notifier_tests { .expect_block_disconnected(*old_tip) .expect_block_connected(*fork_chain.at_height(2)) .expect_block_connected(*new_tip); - let mut notifier = ChainNotifier { - header_cache: &mut main_chain.header_cache(0..=2), - chain_listener, - }; + let mut notifier = + ChainNotifier { header_cache: &mut main_chain.header_cache(0..=2), chain_listener }; let mut poller = poll::ChainPoller::new(&mut fork_chain, Network::Testnet); match notifier.synchronize_listener(new_tip, &old_tip, &mut poller).await { Err((e, _)) => panic!("Unexpected error: {:?}", e), @@ -686,10 +677,8 @@ mod chain_notifier_tests { let new_tip = chain.tip(); let old_tip = chain.at_height(1); let chain_listener = &MockChainListener::new(); - let mut notifier = ChainNotifier { - header_cache: &mut chain.header_cache(0..=1), - chain_listener, - }; + let mut notifier = + ChainNotifier { header_cache: &mut chain.header_cache(0..=1), chain_listener }; let mut poller = poll::ChainPoller::new(&mut chain, Network::Testnet); match notifier.synchronize_listener(new_tip, &old_tip, &mut poller).await { Err((_, tip)) => assert_eq!(tip, None), @@ -704,10 +693,8 @@ mod chain_notifier_tests { let new_tip = chain.tip(); let old_tip = chain.at_height(1); let chain_listener = &MockChainListener::new(); - let mut notifier = ChainNotifier { - header_cache: &mut chain.header_cache(0..=3), - chain_listener, - }; + let mut notifier = + ChainNotifier { header_cache: &mut chain.header_cache(0..=3), chain_listener }; let mut poller = poll::ChainPoller::new(&mut chain, Network::Testnet); match notifier.synchronize_listener(new_tip, &old_tip, &mut poller).await { Err((_, tip)) => assert_eq!(tip, Some(old_tip)), @@ -721,12 +708,9 @@ mod chain_notifier_tests { let new_tip = chain.tip(); let old_tip = chain.at_height(1); - let chain_listener = &MockChainListener::new() - .expect_block_connected(*chain.at_height(2)); - let mut notifier = ChainNotifier { - header_cache: &mut chain.header_cache(0..=3), - chain_listener, - }; + let chain_listener = &MockChainListener::new().expect_block_connected(*chain.at_height(2)); + let mut notifier = + ChainNotifier { header_cache: &mut chain.header_cache(0..=3), chain_listener }; let mut poller = poll::ChainPoller::new(&mut chain, Network::Testnet); match notifier.synchronize_listener(new_tip, &old_tip, &mut poller).await { Err((_, tip)) => assert_eq!(tip, Some(chain.at_height(2))), @@ -743,15 +727,12 @@ mod chain_notifier_tests { let chain_listener = &MockChainListener::new() .expect_filtered_block_connected(*chain.at_height(2)) .expect_filtered_block_connected(*new_tip); - let mut notifier = ChainNotifier { - header_cache: &mut chain.header_cache(0..=1), - chain_listener, - }; + let mut notifier = + ChainNotifier { header_cache: &mut chain.header_cache(0..=1), chain_listener }; let mut poller = poll::ChainPoller::new(&mut chain, Network::Testnet); match notifier.synchronize_listener(new_tip, &old_tip, &mut poller).await { Err((e, _)) => panic!("Unexpected error: {:?}", e), Ok(_) => {}, } } - } diff --git a/lightning-block-sync/src/poll.rs b/lightning-block-sync/src/poll.rs index e7171cf3656..96d1b2eae76 100644 --- a/lightning-block-sync/src/poll.rs +++ b/lightning-block-sync/src/poll.rs @@ -1,6 +1,9 @@ //! Adapters that make one or more [`BlockSource`]s simpler to poll for new chain tip transitions. -use crate::{AsyncBlockSourceResult, BlockData, BlockHeaderData, BlockSource, BlockSourceError, BlockSourceResult}; +use crate::{ + AsyncBlockSourceResult, BlockData, BlockHeaderData, BlockSource, BlockSourceError, + BlockSourceResult, +}; use bitcoin::hash_types::BlockHash; use bitcoin::network::constants::Network; @@ -17,16 +20,19 @@ use std::ops::Deref; /// [`ChainPoller`]: ../struct.ChainPoller.html pub trait Poll { /// Returns a chain tip in terms of its relationship to the provided chain tip. - fn poll_chain_tip<'a>(&'a self, best_known_chain_tip: ValidatedBlockHeader) -> - AsyncBlockSourceResult<'a, ChainTip>; + fn poll_chain_tip<'a>( + &'a self, best_known_chain_tip: ValidatedBlockHeader, + ) -> AsyncBlockSourceResult<'a, ChainTip>; /// Returns the header that preceded the given header in the chain. - fn look_up_previous_header<'a>(&'a self, header: &'a ValidatedBlockHeader) -> - AsyncBlockSourceResult<'a, ValidatedBlockHeader>; + fn look_up_previous_header<'a>( + &'a self, header: &'a ValidatedBlockHeader, + ) -> AsyncBlockSourceResult<'a, ValidatedBlockHeader>; /// Returns the block associated with the given header. - fn fetch_block<'a>(&'a self, header: &'a ValidatedBlockHeader) -> - AsyncBlockSourceResult<'a, ValidatedBlock>; + fn fetch_block<'a>( + &'a self, header: &'a ValidatedBlockHeader, + ) -> AsyncBlockSourceResult<'a, ValidatedBlock>; } /// A chain tip relative to another chain tip in terms of block hash and chainwork. @@ -59,7 +65,8 @@ impl Validate for BlockHeaderData { type T = ValidatedBlockHeader; fn validate(self, block_hash: BlockHash) -> BlockSourceResult { - let pow_valid_block_hash = self.header + let pow_valid_block_hash = self + .header .validate_pow(&self.header.target()) .map_err(BlockSourceError::persistent)?; @@ -80,9 +87,8 @@ impl Validate for BlockData { BlockData::HeaderOnly(header) => header, }; - let pow_valid_block_hash = header - .validate_pow(&header.target()) - .map_err(BlockSourceError::persistent)?; + let pow_valid_block_hash = + header.validate_pow(&header.target()).map_err(BlockSourceError::persistent)?; if pow_valid_block_hash != block_hash { return Err(BlockSourceError::persistent("invalid block hash")); @@ -120,7 +126,9 @@ impl std::ops::Deref for ValidatedBlockHeader { impl ValidatedBlockHeader { /// Checks that the header correctly builds on previous_header: the claimed work differential /// matches the actual PoW and the difficulty transition is possible, i.e., within 4x. - fn check_builds_on(&self, previous_header: &ValidatedBlockHeader, network: Network) -> BlockSourceResult<()> { + fn check_builds_on( + &self, previous_header: &ValidatedBlockHeader, network: Network, + ) -> BlockSourceResult<()> { if self.header.prev_blockhash != previous_header.block_hash { return Err(BlockSourceError::persistent("invalid previous block hash")); } @@ -141,28 +149,28 @@ impl ValidatedBlockHeader { let min_target = previous_target >> 2; let max_target = previous_target << 2; if target > max_target || target < min_target { - return Err(BlockSourceError::persistent("invalid difficulty transition")) + return Err(BlockSourceError::persistent("invalid difficulty transition")); } } else if self.header.bits != previous_header.header.bits { - return Err(BlockSourceError::persistent("invalid difficulty")) + return Err(BlockSourceError::persistent("invalid difficulty")); } } Ok(()) } - /// Returns the [`BestBlock`] corresponding to this validated block header, which can be passed - /// into [`ChannelManager::new`] as part of its [`ChainParameters`]. Useful for ensuring that - /// the [`SpvClient`] and [`ChannelManager`] are initialized to the same block during a fresh - /// start. - /// - /// [`SpvClient`]: crate::SpvClient - /// [`ChainParameters`]: lightning::ln::channelmanager::ChainParameters - /// [`ChannelManager`]: lightning::ln::channelmanager::ChannelManager - /// [`ChannelManager::new`]: lightning::ln::channelmanager::ChannelManager::new - pub fn to_best_block(&self) -> BestBlock { - BestBlock::new(self.block_hash, self.inner.height) - } + /// Returns the [`BestBlock`] corresponding to this validated block header, which can be passed + /// into [`ChannelManager::new`] as part of its [`ChainParameters`]. Useful for ensuring that + /// the [`SpvClient`] and [`ChannelManager`] are initialized to the same block during a fresh + /// start. + /// + /// [`SpvClient`]: crate::SpvClient + /// [`ChainParameters`]: lightning::ln::channelmanager::ChainParameters + /// [`ChannelManager`]: lightning::ln::channelmanager::ChannelManager + /// [`ChannelManager::new`]: lightning::ln::channelmanager::ChannelManager::new + pub fn to_best_block(&self) -> BestBlock { + BestBlock::new(self.block_hash, self.inner.height) + } } /// A block with validated data against its transaction list and corresponding block hash. @@ -191,12 +199,12 @@ mod sealed { /// /// Other `Poll` implementations should be built using `ChainPoller` as it provides the simplest way /// of validating chain data and checking consistency. -pub struct ChainPoller + Sized + Send + Sync, T: BlockSource + ?Sized> { +pub struct ChainPoller + Sized + Send + Sync, T: BlockSource + ?Sized> { block_source: B, network: Network, } -impl + Sized + Send + Sync, T: BlockSource + ?Sized> ChainPoller { +impl + Sized + Send + Sync, T: BlockSource + ?Sized> ChainPoller { /// Creates a new poller for the given block source. /// /// If the `network` parameter is mainnet, then the difficulty between blocks is checked for @@ -206,19 +214,20 @@ impl + Sized + Send + Sync, T: BlockSource + ?Sized> ChainPol } } -impl + Sized + Send + Sync, T: BlockSource + ?Sized> Poll for ChainPoller { - fn poll_chain_tip<'a>(&'a self, best_known_chain_tip: ValidatedBlockHeader) -> - AsyncBlockSourceResult<'a, ChainTip> - { +impl + Sized + Send + Sync, T: BlockSource + ?Sized> Poll + for ChainPoller +{ + fn poll_chain_tip<'a>( + &'a self, best_known_chain_tip: ValidatedBlockHeader, + ) -> AsyncBlockSourceResult<'a, ChainTip> { Box::pin(async move { let (block_hash, height) = self.block_source.get_best_block().await?; if block_hash == best_known_chain_tip.header.block_hash() { return Ok(ChainTip::Common); } - let chain_tip = self.block_source - .get_header(&block_hash, height).await? - .validate(block_hash)?; + let chain_tip = + self.block_source.get_header(&block_hash, height).await?.validate(block_hash)?; if chain_tip.chainwork > best_known_chain_tip.chainwork { Ok(ChainTip::Better(chain_tip)) } else { @@ -227,9 +236,9 @@ impl + Sized + Send + Sync, T: BlockSource + ?Sized> Poll for }) } - fn look_up_previous_header<'a>(&'a self, header: &'a ValidatedBlockHeader) -> - AsyncBlockSourceResult<'a, ValidatedBlockHeader> - { + fn look_up_previous_header<'a>( + &'a self, header: &'a ValidatedBlockHeader, + ) -> AsyncBlockSourceResult<'a, ValidatedBlockHeader> { Box::pin(async move { if header.height == 0 { return Err(BlockSourceError::persistent("genesis block reached")); @@ -237,8 +246,10 @@ impl + Sized + Send + Sync, T: BlockSource + ?Sized> Poll for let previous_hash = &header.header.prev_blockhash; let height = header.height - 1; - let previous_header = self.block_source - .get_header(previous_hash, Some(height)).await? + let previous_header = self + .block_source + .get_header(previous_hash, Some(height)) + .await? .validate(*previous_hash)?; header.check_builds_on(&previous_header, self.network)?; @@ -246,22 +257,20 @@ impl + Sized + Send + Sync, T: BlockSource + ?Sized> Poll for }) } - fn fetch_block<'a>(&'a self, header: &'a ValidatedBlockHeader) -> - AsyncBlockSourceResult<'a, ValidatedBlock> - { + fn fetch_block<'a>( + &'a self, header: &'a ValidatedBlockHeader, + ) -> AsyncBlockSourceResult<'a, ValidatedBlock> { Box::pin(async move { - self.block_source - .get_block(&header.block_hash).await? - .validate(header.block_hash) + self.block_source.get_block(&header.block_hash).await?.validate(header.block_hash) }) } } #[cfg(test)] mod tests { - use crate::*; - use crate::test_utils::Blockchain; use super::*; + use crate::test_utils::Blockchain; + use crate::*; use bitcoin::util::uint::Uint256; #[tokio::test] @@ -308,7 +317,10 @@ mod tests { match poller.poll_chain_tip(best_known_chain_tip).await { Err(e) => { assert_eq!(e.kind(), BlockSourceErrorKind::Persistent); - assert_eq!(e.into_inner().as_ref().to_string(), "block target correct but not attained"); + assert_eq!( + e.into_inner().as_ref().to_string(), + "block target correct but not attained" + ); }, Ok(_) => panic!("Expected error"), } diff --git a/lightning-block-sync/src/rest.rs b/lightning-block-sync/src/rest.rs index 5690da12ea0..d77126675ec 100644 --- a/lightning-block-sync/src/rest.rs +++ b/lightning-block-sync/src/rest.rs @@ -1,14 +1,14 @@ //! Simple REST client implementation which implements [`BlockSource`] against a Bitcoin Core REST //! endpoint. -use crate::{BlockData, BlockHeaderData, BlockSource, AsyncBlockSourceResult}; -use crate::http::{BinaryResponse, HttpEndpoint, HttpClient, JsonResponse}; -use crate::gossip::UtxoSource; use crate::convert::GetUtxosResponse; +use crate::gossip::UtxoSource; +use crate::http::{BinaryResponse, HttpClient, HttpEndpoint, JsonResponse}; +use crate::{AsyncBlockSourceResult, BlockData, BlockHeaderData, BlockSource}; -use bitcoin::OutPoint; use bitcoin::hash_types::BlockHash; use bitcoin::hashes::hex::ToHex; +use bitcoin::OutPoint; use std::convert::TryFrom; use std::convert::TryInto; @@ -30,11 +30,16 @@ impl RestClient { /// Requests a resource encoded in `F` format and interpreted as type `T`. pub async fn request_resource(&self, resource_path: &str) -> std::io::Result - where F: TryFrom, Error = std::io::Error> + TryInto { + where + F: TryFrom, Error = std::io::Error> + TryInto, + { let host = format!("{}:{}", self.endpoint.host(), self.endpoint.port()); let uri = format!("{}/{}", self.endpoint.path().trim_end_matches("/"), resource_path); - let mut client = if let Some(client) = self.client.lock().unwrap().take() { client } - else { HttpClient::connect(&self.endpoint)? }; + let mut client = if let Some(client) = self.client.lock().unwrap().take() { + client + } else { + HttpClient::connect(&self.endpoint)? + }; let res = client.get::(&uri, &host).await?.try_into(); *self.client.lock().unwrap() = Some(client); res @@ -42,29 +47,37 @@ impl RestClient { } impl BlockSource for RestClient { - fn get_header<'a>(&'a self, header_hash: &'a BlockHash, _height: Option) -> AsyncBlockSourceResult<'a, BlockHeaderData> { + fn get_header<'a>( + &'a self, header_hash: &'a BlockHash, _height: Option, + ) -> AsyncBlockSourceResult<'a, BlockHeaderData> { Box::pin(async move { let resource_path = format!("headers/1/{}.json", header_hash.to_hex()); Ok(self.request_resource::(&resource_path).await?) }) } - fn get_block<'a>(&'a self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, BlockData> { + fn get_block<'a>( + &'a self, header_hash: &'a BlockHash, + ) -> AsyncBlockSourceResult<'a, BlockData> { Box::pin(async move { let resource_path = format!("block/{}.bin", header_hash.to_hex()); - Ok(BlockData::FullBlock(self.request_resource::(&resource_path).await?)) + Ok(BlockData::FullBlock( + self.request_resource::(&resource_path).await?, + )) }) } fn get_best_block<'a>(&'a self) -> AsyncBlockSourceResult<'a, (BlockHash, Option)> { - Box::pin(async move { - Ok(self.request_resource::("chaininfo.json").await?) - }) + Box::pin( + async move { Ok(self.request_resource::("chaininfo.json").await?) }, + ) } } impl UtxoSource for RestClient { - fn get_block_hash_by_height<'a>(&'a self, block_height: u32) -> AsyncBlockSourceResult<'a, BlockHash> { + fn get_block_hash_by_height<'a>( + &'a self, block_height: u32, + ) -> AsyncBlockSourceResult<'a, BlockHash> { Box::pin(async move { let resource_path = format!("blockhashbyheight/{}.bin", block_height); Ok(self.request_resource::(&resource_path).await?) @@ -73,7 +86,8 @@ impl UtxoSource for RestClient { fn is_output_unspent<'a>(&'a self, outpoint: OutPoint) -> AsyncBlockSourceResult<'a, bool> { Box::pin(async move { - let resource_path = format!("getutxos/{}-{}.json", outpoint.txid.to_hex(), outpoint.vout); + let resource_path = + format!("getutxos/{}-{}.json", outpoint.txid.to_hex(), outpoint.vout); let utxo_result = self.request_resource::(&resource_path).await?; Ok(utxo_result.hit_bitmap_nonempty) @@ -84,8 +98,8 @@ impl UtxoSource for RestClient { #[cfg(test)] mod tests { use super::*; - use crate::http::BinaryResponse; use crate::http::client_tests::{HttpServer, MessageBody}; + use crate::http::BinaryResponse; use bitcoin::hashes::Hash; /// Parses binary data as a string-encoded `u32`. @@ -98,7 +112,7 @@ mod tests { Ok(s) => match u32::from_str_radix(s, 10) { Err(e) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e)), Ok(n) => Ok(n), - } + }, } } } @@ -141,7 +155,7 @@ mod tests { let server = HttpServer::responding_with_ok(MessageBody::Content( // A real response contains a few more fields, but we actually only look at the // "bitmap" field, so this should suffice for testing - "{\"chainHeight\": 1, \"bitmap\":\"0\",\"utxos\":[]}" + "{\"chainHeight\": 1, \"bitmap\":\"0\",\"utxos\":[]}", )); let client = RestClient::new(server.endpoint()).unwrap(); @@ -155,7 +169,7 @@ mod tests { let server = HttpServer::responding_with_ok(MessageBody::Content( // A real response contains lots more data, but we actually only look at the "bitmap" // field, so this should suffice for testing - "{\"chainHeight\": 1, \"bitmap\":\"1\",\"utxos\":[]}" + "{\"chainHeight\": 1, \"bitmap\":\"1\",\"utxos\":[]}", )); let client = RestClient::new(server.endpoint()).unwrap(); diff --git a/lightning-block-sync/src/rpc.rs b/lightning-block-sync/src/rpc.rs index 0ad94040aca..ab73851359d 100644 --- a/lightning-block-sync/src/rpc.rs +++ b/lightning-block-sync/src/rpc.rs @@ -1,9 +1,9 @@ //! Simple RPC client implementation which implements [`BlockSource`] against a Bitcoin Core RPC //! endpoint. -use crate::{BlockData, BlockHeaderData, BlockSource, AsyncBlockSourceResult}; -use crate::http::{HttpClient, HttpEndpoint, HttpError, JsonResponse}; use crate::gossip::UtxoSource; +use crate::http::{HttpClient, HttpEndpoint, HttpError, JsonResponse}; +use crate::{AsyncBlockSourceResult, BlockData, BlockHeaderData, BlockSource}; use bitcoin::hash_types::BlockHash; use bitcoin::hashes::hex::ToHex; @@ -29,9 +29,9 @@ pub struct RpcError { } impl fmt::Display for RpcError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "RPC error {}: {}", self.code, self.message) - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "RPC error {}: {}", self.code, self.message) + } } impl Error for RpcError {} @@ -64,8 +64,12 @@ impl RpcClient { /// /// When an `Err` is returned, [`std::io::Error::into_inner`] may contain an [`RpcError`] if /// [`std::io::Error::kind`] is [`std::io::ErrorKind::Other`]. - pub async fn call_method(&self, method: &str, params: &[serde_json::Value]) -> std::io::Result - where JsonResponse: TryFrom, Error = std::io::Error> + TryInto { + pub async fn call_method( + &self, method: &str, params: &[serde_json::Value], + ) -> std::io::Result + where + JsonResponse: TryFrom, Error = std::io::Error> + TryInto, + { let host = format!("{}:{}", self.endpoint.host(), self.endpoint.port()); let uri = self.endpoint.path(); let content = serde_json::json!({ @@ -74,9 +78,13 @@ impl RpcClient { "id": &self.id.fetch_add(1, Ordering::AcqRel).to_string() }); - let mut client = if let Some(client) = self.client.lock().unwrap().take() { client } - else { HttpClient::connect(&self.endpoint)? }; - let http_response = client.post::(&uri, &host, &self.basic_auth, content).await; + let mut client = if let Some(client) = self.client.lock().unwrap().take() { + client + } else { + HttpClient::connect(&self.endpoint)? + }; + let http_response = + client.post::(&uri, &host, &self.basic_auth, content).await; *self.client.lock().unwrap() = Some(client); let mut response = match http_response { @@ -94,23 +102,30 @@ impl RpcClient { }; if !response.is_object() { - return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON object")); + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "expected JSON object", + )); } let error = &response["error"]; if !error.is_null() { // TODO: Examine error code for a more precise std::io::ErrorKind. - let rpc_error = RpcError { - code: error["code"].as_i64().unwrap_or(-1), - message: error["message"].as_str().unwrap_or("unknown error").to_string() + let rpc_error = RpcError { + code: error["code"].as_i64().unwrap_or(-1), + message: error["message"].as_str().unwrap_or("unknown error").to_string(), }; return Err(std::io::Error::new(std::io::ErrorKind::Other, rpc_error)); } let result = match response.get_mut("result") { Some(result) => result.take(), - None => - return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON result")), + None => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "expected JSON result", + )) + }, }; JsonResponse(result).try_into() @@ -118,14 +133,18 @@ impl RpcClient { } impl BlockSource for RpcClient { - fn get_header<'a>(&'a self, header_hash: &'a BlockHash, _height: Option) -> AsyncBlockSourceResult<'a, BlockHeaderData> { + fn get_header<'a>( + &'a self, header_hash: &'a BlockHash, _height: Option, + ) -> AsyncBlockSourceResult<'a, BlockHeaderData> { Box::pin(async move { let header_hash = serde_json::json!(header_hash.to_hex()); Ok(self.call_method("getblockheader", &[header_hash]).await?) }) } - fn get_block<'a>(&'a self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, BlockData> { + fn get_block<'a>( + &'a self, header_hash: &'a BlockHash, + ) -> AsyncBlockSourceResult<'a, BlockData> { Box::pin(async move { let header_hash = serde_json::json!(header_hash.to_hex()); let verbosity = serde_json::json!(0); @@ -134,14 +153,14 @@ impl BlockSource for RpcClient { } fn get_best_block<'a>(&'a self) -> AsyncBlockSourceResult<'a, (BlockHash, Option)> { - Box::pin(async move { - Ok(self.call_method("getblockchaininfo", &[]).await?) - }) + Box::pin(async move { Ok(self.call_method("getblockchaininfo", &[]).await?) }) } } impl UtxoSource for RpcClient { - fn get_block_hash_by_height<'a>(&'a self, block_height: u32) -> AsyncBlockSourceResult<'a, BlockHash> { + fn get_block_hash_by_height<'a>( + &'a self, block_height: u32, + ) -> AsyncBlockSourceResult<'a, BlockHash> { Box::pin(async move { let height_param = serde_json::json!(block_height); Ok(self.call_method("getblockhash", &[height_param]).await?) @@ -153,8 +172,8 @@ impl UtxoSource for RpcClient { let txid_param = serde_json::json!(outpoint.txid.to_hex()); let vout_param = serde_json::json!(outpoint.vout); let include_mempool = serde_json::json!(false); - let utxo_opt: serde_json::Value = self.call_method( - "gettxout", &[txid_param, vout_param, include_mempool]).await?; + let utxo_opt: serde_json::Value = + self.call_method("gettxout", &[txid_param, vout_param, include_mempool]).await?; Ok(!utxo_opt.is_null()) }) } @@ -230,7 +249,7 @@ mod tests { #[tokio::test] async fn call_method_returning_missing_result() { - let response = serde_json::json!({ }); + let response = serde_json::json!({}); let server = HttpServer::responding_with_ok(MessageBody::Content(response)); let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap(); diff --git a/lightning-block-sync/src/test_utils.rs b/lightning-block-sync/src/test_utils.rs index 597d2a85fd5..51a85cf40a2 100644 --- a/lightning-block-sync/src/test_utils.rs +++ b/lightning-block-sync/src/test_utils.rs @@ -1,12 +1,15 @@ -use crate::{AsyncBlockSourceResult, BlockData, BlockHeaderData, BlockSource, BlockSourceError, UnboundedCache}; use crate::poll::{Validate, ValidatedBlockHeader}; +use crate::{ + AsyncBlockSourceResult, BlockData, BlockHeaderData, BlockSource, BlockSourceError, + UnboundedCache, +}; use bitcoin::blockdata::block::{Block, BlockHeader}; use bitcoin::blockdata::constants::genesis_block; use bitcoin::hash_types::BlockHash; use bitcoin::network::constants::Network; -use bitcoin::util::uint::Uint256; use bitcoin::util::hash::bitcoin_merkle_root; +use bitcoin::util::uint::Uint256; use bitcoin::{PackedLockTime, Transaction}; use lightning::chain; @@ -48,9 +51,10 @@ impl Blockchain { version: 0, lock_time: PackedLockTime::ZERO, input: vec![], - output: vec![] + output: vec![], }; - let merkle_root = bitcoin_merkle_root(vec![coinbase.txid().as_hash()].into_iter()).unwrap(); + let merkle_root = + bitcoin_merkle_root(vec![coinbase.txid().as_hash()].into_iter()).unwrap(); self.blocks.push(Block { header: BlockHeader { version: 0, @@ -131,7 +135,9 @@ impl Blockchain { } impl BlockSource for Blockchain { - fn get_header<'a>(&'a self, header_hash: &'a BlockHash, _height_hint: Option) -> AsyncBlockSourceResult<'a, BlockHeaderData> { + fn get_header<'a>( + &'a self, header_hash: &'a BlockHash, _height_hint: Option, + ) -> AsyncBlockSourceResult<'a, BlockHeaderData> { Box::pin(async move { if self.without_headers { return Err(BlockSourceError::persistent("header not found")); @@ -151,7 +157,9 @@ impl BlockSource for Blockchain { }) } - fn get_block<'a>(&'a self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, BlockData> { + fn get_block<'a>( + &'a self, header_hash: &'a BlockHash, + ) -> AsyncBlockSourceResult<'a, BlockData> { Box::pin(async move { for (height, block) in self.blocks.iter().enumerate() { if block.header.block_hash() == *header_hash { @@ -188,7 +196,10 @@ impl BlockSource for Blockchain { pub struct NullChainListener; impl chain::Listen for NullChainListener { - fn filtered_block_connected(&self, _header: &BlockHeader, _txdata: &chain::transaction::TransactionData, _height: u32) {} + fn filtered_block_connected( + &self, _header: &BlockHeader, _txdata: &chain::transaction::TransactionData, _height: u32, + ) { + } fn block_disconnected(&self, _header: &BlockHeader, _height: u32) {} } @@ -236,7 +247,9 @@ impl chain::Listen for MockChainListener { } } - fn filtered_block_connected(&self, header: &BlockHeader, _txdata: &chain::transaction::TransactionData, height: u32) { + fn filtered_block_connected( + &self, header: &BlockHeader, _txdata: &chain::transaction::TransactionData, height: u32, + ) { match self.expected_filtered_blocks_connected.borrow_mut().pop_front() { None => { panic!("Unexpected filtered block connected: {:?}", header.block_hash()); diff --git a/lightning-invoice/src/de.rs b/lightning-invoice/src/de.rs index 5bfa9a042c3..093c5b216b0 100644 --- a/lightning-invoice/src/de.rs +++ b/lightning-invoice/src/de.rs @@ -1,31 +1,34 @@ -#[cfg(feature = "std")] -use std::error; use core::convert::TryFrom; use core::fmt; use core::fmt::{Display, Formatter}; use core::num::ParseIntError; use core::str; use core::str::FromStr; +#[cfg(feature = "std")] +use std::error; use bech32::{u5, FromBase32}; -use bitcoin::{PubkeyHash, ScriptHash}; +use crate::prelude::*; use bitcoin::util::address::WitnessVersion; -use bitcoin_hashes::Hash; +use bitcoin::{PubkeyHash, ScriptHash}; use bitcoin_hashes::sha256; -use crate::prelude::*; +use bitcoin_hashes::Hash; use lightning::ln::PaymentSecret; use lightning::routing::gossip::RoutingFees; use lightning::routing::router::{RouteHint, RouteHintHop}; use num_traits::{CheckedAdd, CheckedMul}; -use secp256k1::ecdsa::{RecoveryId, RecoverableSignature}; +use secp256k1::ecdsa::{RecoverableSignature, RecoveryId}; use secp256k1::PublicKey; -use super::{Bolt11Invoice, Sha256, TaggedField, ExpiryTime, MinFinalCltvExpiryDelta, Fallback, PayeePubKey, Bolt11InvoiceSignature, PositiveTimestamp, - Bolt11SemanticError, PrivateRoute, Bolt11ParseError, ParseOrSemanticError, Description, RawTaggedField, Currency, RawHrp, SiPrefix, RawBolt11Invoice, - constants, SignedRawBolt11Invoice, RawDataPart, Bolt11InvoiceFeatures}; +use super::{ + constants, Bolt11Invoice, Bolt11InvoiceFeatures, Bolt11InvoiceSignature, Bolt11ParseError, + Bolt11SemanticError, Currency, Description, ExpiryTime, Fallback, MinFinalCltvExpiryDelta, + ParseOrSemanticError, PayeePubKey, PositiveTimestamp, PrivateRoute, RawBolt11Invoice, + RawDataPart, RawHrp, RawTaggedField, Sha256, SiPrefix, SignedRawBolt11Invoice, TaggedField, +}; use self::hrp_sm::parse_hrp; @@ -52,7 +55,7 @@ mod hrp_sm { } else { Err(super::Bolt11ParseError::MalformedHRP) } - } + }, States::ParseL => { if read_symbol == 'n' { Ok(States::ParseN) @@ -92,7 +95,6 @@ mod hrp_sm { } } - struct StateMachine { state: States, position: usize, @@ -114,8 +116,8 @@ mod hrp_sm { fn update_range(range: &mut Option>, position: usize) { let new_range = match *range { - None => Range {start: position, end: position + 1}, - Some(ref r) => Range {start: r.start, end: r.end + 1}, + None => Range { start: position, end: position + 1 }, + Some(ref r) => Range { start: r.start, end: r.end + 1 }, }; *range = Some(new_range); } @@ -125,14 +127,14 @@ mod hrp_sm { match next_state { States::ParseCurrencyPrefix => { StateMachine::update_range(&mut self.currency_prefix, self.position) - } + }, States::ParseAmountNumber => { StateMachine::update_range(&mut self.amount_number, self.position) }, States::ParseAmountSiPrefix => { StateMachine::update_range(&mut self.amount_si_prefix, self.position) }, - _ => {} + _ => {}, } self.position += 1; @@ -167,18 +169,14 @@ mod hrp_sm { return Err(super::Bolt11ParseError::MalformedHRP); } - let currency = sm.currency_prefix().clone() - .map(|r| &input[r]).unwrap_or(""); - let amount = sm.amount_number().clone() - .map(|r| &input[r]).unwrap_or(""); - let si = sm.amount_si_prefix().clone() - .map(|r| &input[r]).unwrap_or(""); + let currency = sm.currency_prefix().clone().map(|r| &input[r]).unwrap_or(""); + let amount = sm.amount_number().clone().map(|r| &input[r]).unwrap_or(""); + let si = sm.amount_si_prefix().clone().map(|r| &input[r]).unwrap_or(""); Ok((currency, amount, si)) } } - impl FromStr for super::Currency { type Err = Bolt11ParseError; @@ -189,7 +187,7 @@ impl FromStr for super::Currency { "bcrt" => Ok(Currency::Regtest), "sb" => Ok(Currency::Simnet), "tbs" => Ok(Currency::Signet), - _ => Err(Bolt11ParseError::UnknownCurrency) + _ => Err(Bolt11ParseError::UnknownCurrency), } } } @@ -204,7 +202,7 @@ impl FromStr for SiPrefix { "u" => Ok(Micro), "n" => Ok(Nano), "p" => Ok(Pico), - _ => Err(Bolt11ParseError::UnknownSiPrefix) + _ => Err(Bolt11ParseError::UnknownSiPrefix), } } } @@ -281,18 +279,12 @@ impl FromStr for SignedRawBolt11Invoice { } let raw_hrp: RawHrp = hrp.parse()?; - let data_part = RawDataPart::from_base32(&data[..data.len()-104])?; + let data_part = RawDataPart::from_base32(&data[..data.len() - 104])?; Ok(SignedRawBolt11Invoice { - raw_invoice: RawBolt11Invoice { - hrp: raw_hrp, - data: data_part, - }, - hash: RawBolt11Invoice::hash_from_parts( - hrp.as_bytes(), - &data[..data.len()-104] - ), - signature: Bolt11InvoiceSignature::from_base32(&data[data.len()-104..])?, + raw_invoice: RawBolt11Invoice { hrp: raw_hrp, data: data_part }, + hash: RawBolt11Invoice::hash_from_parts(hrp.as_bytes(), &data[..data.len() - 104]), + signature: Bolt11InvoiceSignature::from_base32(&data[data.len() - 104..])?, }) } } @@ -305,11 +297,7 @@ impl FromStr for RawHrp { let currency = parts.0.parse::()?; - let amount = if !parts.1.is_empty() { - Some(parts.1.parse::()?) - } else { - None - }; + let amount = if !parts.1.is_empty() { Some(parts.1.parse::()?) } else { None }; let si_prefix: Option = if parts.2.is_empty() { None @@ -323,11 +311,7 @@ impl FromStr for RawHrp { Some(si) }; - Ok(RawHrp { - currency, - raw_amount: amount, - si_prefix, - }) + Ok(RawHrp { currency, raw_amount: amount, si_prefix }) } } @@ -335,17 +319,15 @@ impl FromBase32 for RawDataPart { type Err = Bolt11ParseError; fn from_base32(data: &[u5]) -> Result { - if data.len() < 7 { // timestamp length + if data.len() < 7 { + // timestamp length return Err(Bolt11ParseError::TooShortDataPart); } let timestamp = PositiveTimestamp::from_base32(&data[0..7])?; let tagged = parse_tagged_parts(&data[7..])?; - Ok(RawDataPart { - timestamp, - tagged_fields: tagged, - }) + Ok(RawDataPart { timestamp, tagged_fields: tagged }) } } @@ -354,10 +336,11 @@ impl FromBase32 for PositiveTimestamp { fn from_base32(b32: &[u5]) -> Result { if b32.len() != 7 { - return Err(Bolt11ParseError::InvalidSliceLength("PositiveTimestamp::from_base32()".into())); + return Err(Bolt11ParseError::InvalidSliceLength( + "PositiveTimestamp::from_base32()".into(), + )); } - let timestamp: u64 = parse_int_be(b32, 32) - .expect("7*5bit < 64bit, no overflow possible"); + let timestamp: u64 = parse_int_be(b32, 32).expect("7*5bit < 64bit, no overflow possible"); match PositiveTimestamp::from_unix_timestamp(timestamp) { Ok(t) => Ok(t), Err(_) => unreachable!(), @@ -369,28 +352,27 @@ impl FromBase32 for Bolt11InvoiceSignature { type Err = Bolt11ParseError; fn from_base32(signature: &[u5]) -> Result { if signature.len() != 104 { - return Err(Bolt11ParseError::InvalidSliceLength("Bolt11InvoiceSignature::from_base32()".into())); + return Err(Bolt11ParseError::InvalidSliceLength( + "Bolt11InvoiceSignature::from_base32()".into(), + )); } let recoverable_signature_bytes = Vec::::from_base32(signature)?; let signature = &recoverable_signature_bytes[0..64]; let recovery_id = RecoveryId::from_i32(recoverable_signature_bytes[64] as i32)?; - Ok(Bolt11InvoiceSignature(RecoverableSignature::from_compact( - signature, - recovery_id - )?)) + Ok(Bolt11InvoiceSignature(RecoverableSignature::from_compact(signature, recovery_id)?)) } } pub(crate) fn parse_int_be(digits: &[U], base: T) -> Option - where T: CheckedAdd + CheckedMul + From + Default, - U: Into + Copy +where + T: CheckedAdd + CheckedMul + From + Default, + U: Into + Copy, { - digits.iter().fold(Some(Default::default()), |acc, b| - acc - .and_then(|x| x.checked_mul(&base)) + digits.iter().fold(Some(Default::default()), |acc, b| { + acc.and_then(|x| x.checked_mul(&base)) .and_then(|x| x.checked_add(&(Into::::into(*b)).into())) - ) + }) } fn parse_tagged_parts(data: &[u5]) -> Result, Bolt11ParseError> { @@ -418,13 +400,12 @@ fn parse_tagged_parts(data: &[u5]) -> Result, Bolt11ParseErr data = &data[last_element..]; match TaggedField::from_base32(field) { - Ok(field) => { - parts.push(RawTaggedField::KnownSemantics(field)) - }, - Err(Bolt11ParseError::Skip)|Err(Bolt11ParseError::Bech32Error(bech32::Error::InvalidLength)) => { + Ok(field) => parts.push(RawTaggedField::KnownSemantics(field)), + Err(Bolt11ParseError::Skip) + | Err(Bolt11ParseError::Bech32Error(bech32::Error::InvalidLength)) => { parts.push(RawTaggedField::UnknownSemantics(field.into())) }, - Err(e) => {return Err(e)} + Err(e) => return Err(e), } } Ok(parts) @@ -439,35 +420,46 @@ impl FromBase32 for TaggedField { } let tag = field[0]; - let field_data = &field[3..]; + let field_data = &field[3..]; match tag.to_u8() { - constants::TAG_PAYMENT_HASH => - Ok(TaggedField::PaymentHash(Sha256::from_base32(field_data)?)), - constants::TAG_DESCRIPTION => - Ok(TaggedField::Description(Description::from_base32(field_data)?)), - constants::TAG_PAYEE_PUB_KEY => - Ok(TaggedField::PayeePubKey(PayeePubKey::from_base32(field_data)?)), - constants::TAG_DESCRIPTION_HASH => - Ok(TaggedField::DescriptionHash(Sha256::from_base32(field_data)?)), - constants::TAG_EXPIRY_TIME => - Ok(TaggedField::ExpiryTime(ExpiryTime::from_base32(field_data)?)), - constants::TAG_MIN_FINAL_CLTV_EXPIRY_DELTA => - Ok(TaggedField::MinFinalCltvExpiryDelta(MinFinalCltvExpiryDelta::from_base32(field_data)?)), - constants::TAG_FALLBACK => - Ok(TaggedField::Fallback(Fallback::from_base32(field_data)?)), - constants::TAG_PRIVATE_ROUTE => - Ok(TaggedField::PrivateRoute(PrivateRoute::from_base32(field_data)?)), - constants::TAG_PAYMENT_SECRET => - Ok(TaggedField::PaymentSecret(PaymentSecret::from_base32(field_data)?)), - constants::TAG_PAYMENT_METADATA => - Ok(TaggedField::PaymentMetadata(Vec::::from_base32(field_data)?)), - constants::TAG_FEATURES => - Ok(TaggedField::Features(Bolt11InvoiceFeatures::from_base32(field_data)?)), + constants::TAG_PAYMENT_HASH => { + Ok(TaggedField::PaymentHash(Sha256::from_base32(field_data)?)) + }, + constants::TAG_DESCRIPTION => { + Ok(TaggedField::Description(Description::from_base32(field_data)?)) + }, + constants::TAG_PAYEE_PUB_KEY => { + Ok(TaggedField::PayeePubKey(PayeePubKey::from_base32(field_data)?)) + }, + constants::TAG_DESCRIPTION_HASH => { + Ok(TaggedField::DescriptionHash(Sha256::from_base32(field_data)?)) + }, + constants::TAG_EXPIRY_TIME => { + Ok(TaggedField::ExpiryTime(ExpiryTime::from_base32(field_data)?)) + }, + constants::TAG_MIN_FINAL_CLTV_EXPIRY_DELTA => Ok(TaggedField::MinFinalCltvExpiryDelta( + MinFinalCltvExpiryDelta::from_base32(field_data)?, + )), + constants::TAG_FALLBACK => { + Ok(TaggedField::Fallback(Fallback::from_base32(field_data)?)) + }, + constants::TAG_PRIVATE_ROUTE => { + Ok(TaggedField::PrivateRoute(PrivateRoute::from_base32(field_data)?)) + }, + constants::TAG_PAYMENT_SECRET => { + Ok(TaggedField::PaymentSecret(PaymentSecret::from_base32(field_data)?)) + }, + constants::TAG_PAYMENT_METADATA => { + Ok(TaggedField::PaymentMetadata(Vec::::from_base32(field_data)?)) + }, + constants::TAG_FEATURES => { + Ok(TaggedField::Features(Bolt11InvoiceFeatures::from_base32(field_data)?)) + }, _ => { // "A reader MUST skip over unknown fields" Err(Bolt11ParseError::Skip) - } + }, } } } @@ -480,8 +472,10 @@ impl FromBase32 for Sha256 { // "A reader MUST skip over […] a p, [or] h […] field that does not have data_length 52 […]." Err(Bolt11ParseError::Skip) } else { - Ok(Sha256(sha256::Hash::from_slice(&Vec::::from_base32(field_data)?) - .expect("length was checked before (52 u5 -> 32 u8)"))) + Ok(Sha256( + sha256::Hash::from_slice(&Vec::::from_base32(field_data)?) + .expect("length was checked before (52 u5 -> 32 u8)"), + )) } } } @@ -492,9 +486,8 @@ impl FromBase32 for Description { fn from_base32(field_data: &[u5]) -> Result { let bytes = Vec::::from_base32(field_data)?; let description = String::from(str::from_utf8(&bytes)?); - Ok(Description::new(description).expect( - "Max len is 639=floor(1023*5/8) since the len field is only 10bits long" - )) + Ok(Description::new(description) + .expect("Max len is 639=floor(1023*5/8) since the len field is only 10bits long")) } } @@ -517,9 +510,7 @@ impl FromBase32 for ExpiryTime { type Err = Bolt11ParseError; fn from_base32(field_data: &[u5]) -> Result { - match parse_int_be::(field_data, 32) - .map(ExpiryTime::from_seconds) - { + match parse_int_be::(field_data, 32).map(ExpiryTime::from_seconds) { Some(t) => Ok(t), None => Err(Bolt11ParseError::IntegerOverflowError), } @@ -555,27 +546,29 @@ impl FromBase32 for Fallback { if bytes.len() < 2 || bytes.len() > 40 { return Err(Bolt11ParseError::InvalidSegWitProgramLength); } - let version = WitnessVersion::try_from(version).expect("0 through 16 are valid SegWit versions"); - Ok(Fallback::SegWitProgram { - version, - program: bytes - }) + let version = WitnessVersion::try_from(version) + .expect("0 through 16 are valid SegWit versions"); + Ok(Fallback::SegWitProgram { version, program: bytes }) }, 17 => { let pkh = match PubkeyHash::from_slice(&bytes) { Ok(pkh) => pkh, - Err(bitcoin_hashes::Error::InvalidLength(_, _)) => return Err(Bolt11ParseError::InvalidPubKeyHashLength), + Err(bitcoin_hashes::Error::InvalidLength(_, _)) => { + return Err(Bolt11ParseError::InvalidPubKeyHashLength) + }, }; Ok(Fallback::PubKeyHash(pkh)) - } + }, 18 => { let sh = match ScriptHash::from_slice(&bytes) { Ok(sh) => sh, - Err(bitcoin_hashes::Error::InvalidLength(_, _)) => return Err(Bolt11ParseError::InvalidScriptHashLength), + Err(bitcoin_hashes::Error::InvalidLength(_, _)) => { + return Err(Bolt11ParseError::InvalidScriptHashLength) + }, }; Ok(Fallback::ScriptHash(sh)) - } - _ => Err(Bolt11ParseError::Skip) + }, + _ => Err(Bolt11ParseError::Skip), } } } @@ -602,10 +595,12 @@ impl FromBase32 for PrivateRoute { let hop = RouteHintHop { src_node_id: PublicKey::from_slice(&hop_bytes[0..33])?, - short_channel_id: parse_int_be(&channel_id, 256).expect("short chan ID slice too big?"), + short_channel_id: parse_int_be(&channel_id, 256) + .expect("short chan ID slice too big?"), fees: RoutingFees { base_msat: parse_int_be(&hop_bytes[41..45], 256).expect("slice too big?"), - proportional_millionths: parse_int_be(&hop_bytes[45..49], 256).expect("slice too big?"), + proportional_millionths: parse_int_be(&hop_bytes[45..49], 256) + .expect("slice too big?"), }, cltv_expiry_delta: parse_int_be(&hop_bytes[49..51], 256).expect("slice too big?"), htlc_minimum_msat: None, @@ -625,19 +620,19 @@ impl Display for Bolt11ParseError { // TODO: find a way to combine the first three arms (e as error::Error?) Bolt11ParseError::Bech32Error(ref e) => { write!(f, "Invalid bech32: {}", e) - } + }, Bolt11ParseError::ParseAmountError(ref e) => { write!(f, "Invalid amount in hrp ({})", e) - } + }, Bolt11ParseError::MalformedSignature(ref e) => { write!(f, "Invalid secp256k1 signature: {}", e) - } + }, Bolt11ParseError::DescriptionDecodeError(ref e) => { write!(f, "Description is not a valid utf-8 string: {}", e) - } + }, Bolt11ParseError::InvalidSliceLength(ref function) => { write!(f, "Slice in function {} had the wrong length", function) - } + }, Bolt11ParseError::BadPrefix => f.write_str("did not begin with 'ln'"), Bolt11ParseError::UnknownCurrency => f.write_str("currency code unknown"), Bolt11ParseError::UnknownSiPrefix => f.write_str("unknown SI prefix"), @@ -664,9 +659,9 @@ impl Display for Bolt11ParseError { Bolt11ParseError::InvalidRecoveryId => { f.write_str("recovery id is out of range (should be in [0,3])") }, - Bolt11ParseError::Skip => { - f.write_str("the tagged field has to be skipped because of an unexpected, but allowed property") - }, + Bolt11ParseError::Skip => f.write_str( + "the tagged field has to be skipped because of an unexpected, but allowed property", + ), } } } @@ -687,13 +682,13 @@ impl error::Error for Bolt11ParseError {} impl error::Error for ParseOrSemanticError {} macro_rules! from_error { - ($my_error:expr, $extern_error:ty) => { - impl From<$extern_error> for Bolt11ParseError { - fn from(e: $extern_error) -> Self { - $my_error(e) - } - } - } + ($my_error:expr, $extern_error:ty) => { + impl From<$extern_error> for Bolt11ParseError { + fn from(e: $extern_error) -> Self { + $my_error(e) + } + } + }; } from_error!(Bolt11ParseError::MalformedSignature, secp256k1::Error); @@ -704,7 +699,7 @@ impl From for Bolt11ParseError { fn from(e: bech32::Error) -> Self { match e { bech32::Error::InvalidPadding => Bolt11ParseError::PaddingError, - _ => Bolt11ParseError::Bech32Error(e) + _ => Bolt11ParseError::Bech32Error(e), } } } @@ -724,27 +719,22 @@ impl From for ParseOrSemanticError { #[cfg(test)] mod test { use crate::de::Bolt11ParseError; - use secp256k1::PublicKey; use bech32::u5; use bitcoin_hashes::hex::FromHex; use bitcoin_hashes::sha256; + use secp256k1::PublicKey; const CHARSET_REV: [i8; 128] = [ - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - 15, -1, 10, 17, 21, 20, 26, 30, 7, 5, -1, -1, -1, -1, -1, -1, - -1, 29, -1, 24, 13, 25, 9, 8, 23, -1, 18, 22, 31, 27, 19, -1, - 1, 0, 3, 16, 11, 28, 12, 14, 6, 4, 2, -1, -1, -1, -1, -1, - -1, 29, -1, 24, 13, 25, 9, 8, 23, -1, 18, 22, 31, 27, 19, -1, - 1, 0, 3, 16, 11, 28, 12, 14, 6, 4, 2, -1, -1, -1, -1, -1 + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 15, -1, 10, 17, 21, 20, 26, 30, 7, 5, -1, -1, -1, -1, -1, -1, -1, 29, -1, 24, 13, + 25, 9, 8, 23, -1, 18, 22, 31, 27, 19, -1, 1, 0, 3, 16, 11, 28, 12, 14, 6, 4, 2, -1, -1, -1, + -1, -1, -1, 29, -1, 24, 13, 25, 9, 8, 23, -1, 18, 22, 31, 27, 19, -1, 1, 0, 3, 16, 11, 28, + 12, 14, 6, 4, 2, -1, -1, -1, -1, -1, ]; fn from_bech32(bytes_5b: &[u8]) -> Vec { - bytes_5b - .iter() - .map(|c| u5::try_from_u8(CHARSET_REV[*c as usize] as u8).unwrap()) - .collect() + bytes_5b.iter().map(|c| u5::try_from_u8(CHARSET_REV[*c as usize] as u8).unwrap()).collect() } #[test] @@ -774,21 +764,19 @@ mod test { use crate::Sha256; use bech32::FromBase32; - let input = from_bech32( - "qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypq".as_bytes() - ); + let input = from_bech32("qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypq".as_bytes()); let hash = sha256::Hash::from_hex( - "0001020304050607080900010203040506070809000102030405060708090102" - ).unwrap(); + "0001020304050607080900010203040506070809000102030405060708090102", + ) + .unwrap(); let expected = Ok(Sha256(hash)); assert_eq!(Sha256::from_base32(&input), expected); // make sure hashes of unknown length get skipped - let input_unexpected_length = from_bech32( - "qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypyq".as_bytes() - ); + let input_unexpected_length = + from_bech32("qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypyq".as_bytes()); assert_eq!(Sha256::from_base32(&input_unexpected_length), Err(Bolt11ParseError::Skip)); } @@ -811,18 +799,15 @@ mod test { let pk_bytes = [ 0x03, 0xe7, 0x15, 0x6a, 0xe3, 0x3b, 0x0a, 0x20, 0x8d, 0x07, 0x44, 0x19, 0x91, 0x63, 0x17, 0x7e, 0x90, 0x9e, 0x80, 0x17, 0x6e, 0x55, 0xd9, 0x7a, 0x2f, 0x22, 0x1e, 0xde, - 0x0f, 0x93, 0x4d, 0xd9, 0xad + 0x0f, 0x93, 0x4d, 0xd9, 0xad, ]; - let expected = Ok(PayeePubKey( - PublicKey::from_slice(&pk_bytes[..]).unwrap() - )); + let expected = Ok(PayeePubKey(PublicKey::from_slice(&pk_bytes[..]).unwrap())); assert_eq!(PayeePubKey::from_base32(&input), expected); // expects 33 bytes - let input_unexpected_length = from_bech32( - "q0n326hr8v9zprg8gsvezcch06gfaqqhde2aj730yg0durunfhvq".as_bytes() - ); + let input_unexpected_length = + from_bech32("q0n326hr8v9zprg8gsvezcch06gfaqqhde2aj730yg0durunfhvq".as_bytes()); assert_eq!(PayeePubKey::from_base32(&input_unexpected_length), Err(Bolt11ParseError::Skip)); } @@ -836,7 +821,10 @@ mod test { assert_eq!(ExpiryTime::from_base32(&input), expected); let input_too_large = from_bech32("sqqqqqqqqqqqq".as_bytes()); - assert_eq!(ExpiryTime::from_base32(&input_too_large), Err(Bolt11ParseError::IntegerOverflowError)); + assert_eq!( + ExpiryTime::from_base32(&input_too_large), + Err(Bolt11ParseError::IntegerOverflowError) + ); } #[test] @@ -854,55 +842,51 @@ mod test { fn test_parse_fallback() { use crate::Fallback; use bech32::FromBase32; - use bitcoin::{PubkeyHash, ScriptHash}; use bitcoin::util::address::WitnessVersion; + use bitcoin::{PubkeyHash, ScriptHash}; use bitcoin_hashes::Hash; let cases = vec![ ( from_bech32("3x9et2e20v6pu37c5d9vax37wxq72un98".as_bytes()), - Ok(Fallback::PubKeyHash(PubkeyHash::from_slice(&[ - 0x31, 0x72, 0xb5, 0x65, 0x4f, 0x66, 0x83, 0xc8, 0xfb, 0x14, 0x69, 0x59, 0xd3, - 0x47, 0xce, 0x30, 0x3c, 0xae, 0x4c, 0xa7 - ]).unwrap())) + Ok(Fallback::PubKeyHash( + PubkeyHash::from_slice(&[ + 0x31, 0x72, 0xb5, 0x65, 0x4f, 0x66, 0x83, 0xc8, 0xfb, 0x14, 0x69, 0x59, + 0xd3, 0x47, 0xce, 0x30, 0x3c, 0xae, 0x4c, 0xa7, + ]) + .unwrap(), + )), ), ( from_bech32("j3a24vwu6r8ejrss3axul8rxldph2q7z9".as_bytes()), - Ok(Fallback::ScriptHash(ScriptHash::from_slice(&[ - 0x8f, 0x55, 0x56, 0x3b, 0x9a, 0x19, 0xf3, 0x21, 0xc2, 0x11, 0xe9, 0xb9, 0xf3, - 0x8c, 0xdf, 0x68, 0x6e, 0xa0, 0x78, 0x45 - ]).unwrap())) + Ok(Fallback::ScriptHash( + ScriptHash::from_slice(&[ + 0x8f, 0x55, 0x56, 0x3b, 0x9a, 0x19, 0xf3, 0x21, 0xc2, 0x11, 0xe9, 0xb9, + 0xf3, 0x8c, 0xdf, 0x68, 0x6e, 0xa0, 0x78, 0x45, + ]) + .unwrap(), + )), ), ( from_bech32("qw508d6qejxtdg4y5r3zarvary0c5xw7k".as_bytes()), Ok(Fallback::SegWitProgram { version: WitnessVersion::V0, - program: Vec::from(&[ - 0x75u8, 0x1e, 0x76, 0xe8, 0x19, 0x91, 0x96, 0xd4, 0x54, 0x94, 0x1c, 0x45, - 0xd1, 0xb3, 0xa3, 0x23, 0xf1, 0x43, 0x3b, 0xd6 - ][..]) - }) - ), - ( - vec![u5::try_from_u8(21).unwrap(); 41], - Err(Bolt11ParseError::Skip) - ), - ( - vec![], - Err(Bolt11ParseError::UnexpectedEndOfTaggedFields) + program: Vec::from( + &[ + 0x75u8, 0x1e, 0x76, 0xe8, 0x19, 0x91, 0x96, 0xd4, 0x54, 0x94, 0x1c, + 0x45, 0xd1, 0xb3, 0xa3, 0x23, 0xf1, 0x43, 0x3b, 0xd6, + ][..], + ), + }), ), + (vec![u5::try_from_u8(21).unwrap(); 41], Err(Bolt11ParseError::Skip)), + (vec![], Err(Bolt11ParseError::UnexpectedEndOfTaggedFields)), ( vec![u5::try_from_u8(1).unwrap(); 81], - Err(Bolt11ParseError::InvalidSegWitProgramLength) + Err(Bolt11ParseError::InvalidSegWitProgramLength), ), - ( - vec![u5::try_from_u8(17).unwrap(); 1], - Err(Bolt11ParseError::InvalidPubKeyHashLength) - ), - ( - vec![u5::try_from_u8(18).unwrap(); 1], - Err(Bolt11ParseError::InvalidScriptHashLength) - ) + (vec![u5::try_from_u8(17).unwrap(); 1], Err(Bolt11ParseError::InvalidPubKeyHashLength)), + (vec![u5::try_from_u8(18).unwrap(); 1], Err(Bolt11ParseError::InvalidScriptHashLength)), ]; for (input, expected) in cases.into_iter() { @@ -912,11 +896,11 @@ mod test { #[test] fn test_parse_route() { - use lightning::routing::gossip::RoutingFees; - use lightning::routing::router::{RouteHint, RouteHintHop}; + use crate::de::parse_int_be; use crate::PrivateRoute; use bech32::FromBase32; - use crate::de::parse_int_be; + use lightning::routing::gossip::RoutingFees; + use lightning::routing::router::{RouteHint, RouteHintHop}; let input = from_bech32( "q20q82gphp2nflc7jtzrcazrra7wwgzxqc8u7754cdlpfrmccae92qgzqvzq2ps8pqqqqqqpqqqqq9qqqvpeuqa\ @@ -929,34 +913,32 @@ mod test { &[ 0x02u8, 0x9e, 0x03, 0xa9, 0x01, 0xb8, 0x55, 0x34, 0xff, 0x1e, 0x92, 0xc4, 0x3c, 0x74, 0x43, 0x1f, 0x7c, 0xe7, 0x20, 0x46, 0x06, 0x0f, 0xcf, 0x7a, 0x95, 0xc3, - 0x7e, 0x14, 0x8f, 0x78, 0xc7, 0x72, 0x55 - ][..] - ).unwrap(), - short_channel_id: parse_int_be(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08], 256).expect("short chan ID slice too big?"), - fees: RoutingFees { - base_msat: 1, - proportional_millionths: 20, - }, + 0x7e, 0x14, 0x8f, 0x78, 0xc7, 0x72, 0x55, + ][..], + ) + .unwrap(), + short_channel_id: parse_int_be(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08], 256) + .expect("short chan ID slice too big?"), + fees: RoutingFees { base_msat: 1, proportional_millionths: 20 }, cltv_expiry_delta: 3, htlc_minimum_msat: None, - htlc_maximum_msat: None + htlc_maximum_msat: None, }); expected.push(RouteHintHop { src_node_id: PublicKey::from_slice( &[ 0x03u8, 0x9e, 0x03, 0xa9, 0x01, 0xb8, 0x55, 0x34, 0xff, 0x1e, 0x92, 0xc4, 0x3c, 0x74, 0x43, 0x1f, 0x7c, 0xe7, 0x20, 0x46, 0x06, 0x0f, 0xcf, 0x7a, 0x95, 0xc3, - 0x7e, 0x14, 0x8f, 0x78, 0xc7, 0x72, 0x55 - ][..] - ).unwrap(), - short_channel_id: parse_int_be(&[0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a], 256).expect("short chan ID slice too big?"), - fees: RoutingFees { - base_msat: 2, - proportional_millionths: 30, - }, + 0x7e, 0x14, 0x8f, 0x78, 0xc7, 0x72, 0x55, + ][..], + ) + .unwrap(), + short_channel_id: parse_int_be(&[0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a], 256) + .expect("short chan ID slice too big?"), + fees: RoutingFees { base_msat: 2, proportional_millionths: 30 }, cltv_expiry_delta: 4, htlc_minimum_msat: None, - htlc_maximum_msat: None + htlc_maximum_msat: None, }); assert_eq!(PrivateRoute::from_base32(&input), Ok(PrivateRoute(RouteHint(expected)))); @@ -969,57 +951,69 @@ mod test { #[test] fn test_payment_secret_and_features_de_and_ser() { - use lightning::ln::features::Bolt11InvoiceFeatures; - use secp256k1::ecdsa::{RecoveryId, RecoverableSignature}; use crate::TaggedField::*; - use crate::{SiPrefix, SignedRawBolt11Invoice, Bolt11InvoiceSignature, RawBolt11Invoice, RawHrp, RawDataPart, - Currency, Sha256, PositiveTimestamp}; + use crate::{ + Bolt11InvoiceSignature, Currency, PositiveTimestamp, RawBolt11Invoice, RawDataPart, + RawHrp, Sha256, SiPrefix, SignedRawBolt11Invoice, + }; + use lightning::ln::features::Bolt11InvoiceFeatures; + use secp256k1::ecdsa::{RecoverableSignature, RecoveryId}; // Feature bits 9, 15, and 99 are set. - let expected_features = Bolt11InvoiceFeatures::from_le_bytes(vec![0, 130, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8]); + let expected_features = + Bolt11InvoiceFeatures::from_le_bytes(vec![0, 130, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8]); let invoice_str = "lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdeessp5zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zygs9q5sqqqqqqqqqqqqqqqpqsq67gye39hfg3zd8rgc80k32tvy9xk2xunwm5lzexnvpx6fd77en8qaq424dxgt56cag2dpt359k3ssyhetktkpqh24jqnjyw6uqd08sgptq44qu"; - let invoice = SignedRawBolt11Invoice { - raw_invoice: RawBolt11Invoice { - hrp: RawHrp { - currency: Currency::Bitcoin, - raw_amount: Some(25), - si_prefix: Some(SiPrefix::Milli) - }, - data: RawDataPart { - timestamp: PositiveTimestamp::from_unix_timestamp(1496314658).unwrap(), - tagged_fields: vec ! [ + let invoice = + SignedRawBolt11Invoice { + raw_invoice: RawBolt11Invoice { + hrp: RawHrp { + currency: Currency::Bitcoin, + raw_amount: Some(25), + si_prefix: Some(SiPrefix::Milli), + }, + data: RawDataPart { + timestamp: PositiveTimestamp::from_unix_timestamp(1496314658).unwrap(), + tagged_fields: vec ! [ PaymentHash(Sha256(sha256::Hash::from_hex( "0001020304050607080900010203040506070809000102030405060708090102" ).unwrap())).into(), Description(crate::Description::new("coffee beans".to_owned()).unwrap()).into(), PaymentSecret(crate::PaymentSecret([17; 32])).into(), - Features(expected_features).into()]} - }, - hash: [0xb1, 0x96, 0x46, 0xc3, 0xbc, 0x56, 0x76, 0x1d, 0x20, 0x65, 0x6e, 0x0e, 0x32, - 0xec, 0xd2, 0x69, 0x27, 0xb7, 0x62, 0x6e, 0x2a, 0x8b, 0xe6, 0x97, 0x71, 0x9f, - 0xf8, 0x7e, 0x44, 0x54, 0x55, 0xb9], - signature: Bolt11InvoiceSignature(RecoverableSignature::from_compact( - &[0xd7, 0x90, 0x4c, 0xc4, 0xb7, 0x4a, 0x22, 0x26, 0x9c, 0x68, 0xc1, 0xdf, 0x68, - 0xa9, 0x6c, 0x21, 0x4d, 0x65, 0x1b, 0x93, 0x76, 0xe9, 0xf1, 0x64, 0xd3, 0x60, - 0x4d, 0xa4, 0xb7, 0xde, 0xcc, 0xce, 0x0e, 0x82, 0xaa, 0xab, 0x4c, 0x85, 0xd3, - 0x58, 0xea, 0x14, 0xd0, 0xae, 0x34, 0x2d, 0xa3, 0x08, 0x12, 0xf9, 0x5d, 0x97, - 0x60, 0x82, 0xea, 0xac, 0x81, 0x39, 0x11, 0xda, 0xe0, 0x1a, 0xf3, 0xc1], - RecoveryId::from_i32(1).unwrap() - ).unwrap()), + Features(expected_features).into()], + }, + }, + hash: [ + 0xb1, 0x96, 0x46, 0xc3, 0xbc, 0x56, 0x76, 0x1d, 0x20, 0x65, 0x6e, 0x0e, 0x32, + 0xec, 0xd2, 0x69, 0x27, 0xb7, 0x62, 0x6e, 0x2a, 0x8b, 0xe6, 0x97, 0x71, 0x9f, + 0xf8, 0x7e, 0x44, 0x54, 0x55, 0xb9, + ], + signature: Bolt11InvoiceSignature( + RecoverableSignature::from_compact( + &[ + 0xd7, 0x90, 0x4c, 0xc4, 0xb7, 0x4a, 0x22, 0x26, 0x9c, 0x68, 0xc1, 0xdf, + 0x68, 0xa9, 0x6c, 0x21, 0x4d, 0x65, 0x1b, 0x93, 0x76, 0xe9, 0xf1, 0x64, + 0xd3, 0x60, 0x4d, 0xa4, 0xb7, 0xde, 0xcc, 0xce, 0x0e, 0x82, 0xaa, 0xab, + 0x4c, 0x85, 0xd3, 0x58, 0xea, 0x14, 0xd0, 0xae, 0x34, 0x2d, 0xa3, 0x08, + 0x12, 0xf9, 0x5d, 0x97, 0x60, 0x82, 0xea, 0xac, 0x81, 0x39, 0x11, 0xda, + 0xe0, 0x1a, 0xf3, 0xc1, + ], + RecoveryId::from_i32(1).unwrap(), + ) + .unwrap(), + ), }; assert_eq!(invoice_str, invoice.to_string()); - assert_eq!( - invoice_str.parse(), - Ok(invoice) - ); + assert_eq!(invoice_str.parse(), Ok(invoice)); } #[test] fn test_raw_signed_invoice_deserialization() { use crate::TaggedField::*; - use secp256k1::ecdsa::{RecoveryId, RecoverableSignature}; - use crate::{SignedRawBolt11Invoice, Bolt11InvoiceSignature, RawBolt11Invoice, RawHrp, RawDataPart, Currency, Sha256, - PositiveTimestamp}; + use crate::{ + Bolt11InvoiceSignature, Currency, PositiveTimestamp, RawBolt11Invoice, RawDataPart, + RawHrp, Sha256, SignedRawBolt11Invoice, + }; + use secp256k1::ecdsa::{RecoverableSignature, RecoveryId}; assert_eq!( "lnbc1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdpl2pkx2ctnv5sxxmmw\ diff --git a/lightning-invoice/src/lib.rs b/lightning-invoice/src/lib.rs index d953795cf8e..d7be7405e4b 100644 --- a/lightning-invoice/src/lib.rs +++ b/lightning-invoice/src/lib.rs @@ -1,15 +1,12 @@ // Prefix these with `rustdoc::` when we update our MSRV to be >= 1.52 to remove warnings. #![deny(broken_intra_doc_links)] #![deny(private_intra_doc_links)] - #![deny(missing_docs)] #![deny(non_upper_case_globals)] #![deny(non_camel_case_types)] #![deny(non_snake_case)] #![deny(unused_mut)] - #![cfg_attr(docsrs, feature(doc_auto_cfg))] - #![cfg_attr(feature = "strict", deny(warnings))] #![cfg_attr(all(not(feature = "std"), not(test)), no_std)] @@ -32,12 +29,13 @@ pub mod utils; extern crate bech32; extern crate bitcoin_hashes; -#[macro_use] extern crate lightning; -extern crate num_traits; -extern crate secp256k1; +#[macro_use] +extern crate lightning; extern crate alloc; #[cfg(any(test, feature = "std"))] extern crate core; +extern crate num_traits; +extern crate secp256k1; #[cfg(feature = "serde")] extern crate serde; @@ -45,34 +43,34 @@ extern crate serde; use std::time::SystemTime; use bech32::u5; -use bitcoin::{Address, Network, PubkeyHash, ScriptHash}; use bitcoin::util::address::{Payload, WitnessVersion}; -use bitcoin_hashes::{Hash, sha256}; +use bitcoin::{Address, Network, PubkeyHash, ScriptHash}; +use bitcoin_hashes::{sha256, Hash}; use lightning::ln::features::Bolt11InvoiceFeatures; use lightning::util::invoice::construct_invoice_preimage; +use secp256k1::ecdsa::RecoverableSignature; use secp256k1::PublicKey; use secp256k1::{Message, Secp256k1}; -use secp256k1::ecdsa::RecoverableSignature; use core::cmp::Ordering; -use core::fmt::{Display, Formatter, self}; +use core::fmt::{self, Display, Formatter}; use core::iter::FilterMap; use core::num::ParseIntError; use core::ops::Deref; use core::slice::Iter; -use core::time::Duration; use core::str; +use core::time::Duration; #[cfg(feature = "serde")] -use serde::{Deserialize, Deserializer,Serialize, Serializer, de::Error}; +use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer}; #[doc(no_inline)] pub use lightning::ln::PaymentSecret; #[doc(no_inline)] -pub use lightning::routing::router::{RouteHint, RouteHintHop}; -#[doc(no_inline)] pub use lightning::routing::gossip::RoutingFees; +#[doc(no_inline)] +pub use lightning::routing::router::{RouteHint, RouteHintHop}; mod de; mod ser; @@ -82,11 +80,11 @@ mod prelude { #[cfg(feature = "hashbrown")] extern crate hashbrown; - pub use alloc::{vec, vec::Vec, string::String, collections::VecDeque, boxed::Box}; - #[cfg(not(feature = "hashbrown"))] - pub use std::collections::{HashMap, HashSet, hash_map}; #[cfg(feature = "hashbrown")] - pub use self::hashbrown::{HashMap, HashSet, hash_map}; + pub use self::hashbrown::{hash_map, HashMap, HashSet}; + pub use alloc::{boxed::Box, collections::VecDeque, string::String, vec, vec::Vec}; + #[cfg(not(feature = "hashbrown"))] + pub use std::collections::{hash_map, HashMap, HashSet}; pub use alloc::string::ToString; } @@ -226,7 +224,14 @@ pub const DEFAULT_MIN_FINAL_CLTV_EXPIRY_DELTA: u64 = 18; /// /// This is not exported to bindings users as we likely need to manually select one set of boolean type parameters. #[derive(Eq, PartialEq, Debug, Clone)] -pub struct InvoiceBuilder { +pub struct InvoiceBuilder< + D: tb::Bool, + H: tb::Bool, + T: tb::Bool, + C: tb::Bool, + S: tb::Bool, + M: tb::Bool, +> { currency: Currency, amount: Option, si_prefix: Option, @@ -454,8 +459,10 @@ pub enum TaggedField { /// SHA-256 hash #[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] -pub struct Sha256(/// This is not exported to bindings users as the native hash types are not currently mapped - pub sha256::Hash); +pub struct Sha256( + /// This is not exported to bindings users as the native hash types are not currently mapped + pub sha256::Hash, +); impl Sha256 { /// Constructs a new [`Sha256`] from the given bytes, which are assumed to be the output of a @@ -490,10 +497,7 @@ pub struct MinFinalCltvExpiryDelta(pub u64); #[allow(missing_docs)] #[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] pub enum Fallback { - SegWitProgram { - version: WitnessVersion, - program: Vec, - }, + SegWitProgram { version: WitnessVersion, program: Vec }, PubKeyHash(PubkeyHash), ScriptHash(ScriptHash), } @@ -560,9 +564,20 @@ impl InvoiceBuilder InvoiceBuilder { +impl + InvoiceBuilder +{ /// Helper function to set the completeness flags. - fn set_flags(self) -> InvoiceBuilder { + fn set_flags< + DN: tb::Bool, + HN: tb::Bool, + TN: tb::Bool, + CN: tb::Bool, + SN: tb::Bool, + MN: tb::Bool, + >( + self, + ) -> InvoiceBuilder { InvoiceBuilder:: { currency: self.currency, amount: self.amount, @@ -621,41 +636,37 @@ impl InvoiceBuilder { +impl + InvoiceBuilder +{ /// Builds a [`RawBolt11Invoice`] if no [`CreationError`] occurred while construction any of the /// fields. pub fn build_raw(self) -> Result { - // If an error occurred at any time before, return it now if let Some(e) = self.error { return Err(e); } - let hrp = RawHrp { - currency: self.currency, - raw_amount: self.amount, - si_prefix: self.si_prefix, - }; + let hrp = + RawHrp { currency: self.currency, raw_amount: self.amount, si_prefix: self.si_prefix }; let timestamp = self.timestamp.expect("ensured to be Some(t) by type T"); - let tagged_fields = self.tagged_fields.into_iter().map(|tf| { - RawTaggedField::KnownSemantics(tf) - }).collect::>(); + let tagged_fields = self + .tagged_fields + .into_iter() + .map(|tf| RawTaggedField::KnownSemantics(tf)) + .collect::>(); - let data = RawDataPart { - timestamp, - tagged_fields, - }; + let data = RawDataPart { timestamp, tagged_fields }; - Ok(RawBolt11Invoice { - hrp, - data, - }) + Ok(RawBolt11Invoice { hrp, data }) } } -impl InvoiceBuilder { +impl + InvoiceBuilder +{ /// Set the description. This function is only available if no description (hash) was set. pub fn description(mut self, description: String) -> InvoiceBuilder { match Description::new(description) { @@ -666,25 +677,27 @@ impl InvoiceBui } /// Set the description hash. This function is only available if no description (hash) was set. - pub fn description_hash(mut self, description_hash: sha256::Hash) -> InvoiceBuilder { + pub fn description_hash( + mut self, description_hash: sha256::Hash, + ) -> InvoiceBuilder { self.tagged_fields.push(TaggedField::DescriptionHash(Sha256(description_hash))); self.set_flags() } /// Set the description or description hash. This function is only available if no description (hash) was set. - pub fn invoice_description(self, description: Bolt11InvoiceDescription) -> InvoiceBuilder { + pub fn invoice_description( + self, description: Bolt11InvoiceDescription, + ) -> InvoiceBuilder { match description { - Bolt11InvoiceDescription::Direct(desc) => { - self.description(desc.clone().into_inner()) - } - Bolt11InvoiceDescription::Hash(hash) => { - self.description_hash(hash.0) - } + Bolt11InvoiceDescription::Direct(desc) => self.description(desc.clone().into_inner()), + Bolt11InvoiceDescription::Hash(hash) => self.description_hash(hash.0), } } } -impl InvoiceBuilder { +impl + InvoiceBuilder +{ /// Set the payment hash. This function is only available if no payment hash was set. pub fn payment_hash(mut self, hash: sha256::Hash) -> InvoiceBuilder { self.tagged_fields.push(TaggedField::PaymentHash(Sha256(hash))); @@ -692,7 +705,9 @@ impl InvoiceBui } } -impl InvoiceBuilder { +impl + InvoiceBuilder +{ /// Sets the timestamp to a specific [`SystemTime`]. #[cfg(feature = "std")] pub fn timestamp(mut self, time: SystemTime) -> InvoiceBuilder { @@ -706,7 +721,9 @@ impl InvoiceBui /// Sets the timestamp to a duration since the Unix epoch, dropping the subsecond part (which /// is not representable in BOLT 11 invoices). - pub fn duration_since_epoch(mut self, time: Duration) -> InvoiceBuilder { + pub fn duration_since_epoch( + mut self, time: Duration, + ) -> InvoiceBuilder { match PositiveTimestamp::from_duration_since_epoch(time) { Ok(t) => self.timestamp = Some(t), Err(e) => self.error = Some(e), @@ -724,17 +741,27 @@ impl InvoiceBui } } -impl InvoiceBuilder { +impl + InvoiceBuilder +{ /// Sets `min_final_cltv_expiry_delta`. - pub fn min_final_cltv_expiry_delta(mut self, min_final_cltv_expiry_delta: u64) -> InvoiceBuilder { - self.tagged_fields.push(TaggedField::MinFinalCltvExpiryDelta(MinFinalCltvExpiryDelta(min_final_cltv_expiry_delta))); + pub fn min_final_cltv_expiry_delta( + mut self, min_final_cltv_expiry_delta: u64, + ) -> InvoiceBuilder { + self.tagged_fields.push(TaggedField::MinFinalCltvExpiryDelta(MinFinalCltvExpiryDelta( + min_final_cltv_expiry_delta, + ))); self.set_flags() } } -impl InvoiceBuilder { +impl + InvoiceBuilder +{ /// Sets the payment secret and relevant features. - pub fn payment_secret(mut self, payment_secret: PaymentSecret) -> InvoiceBuilder { + pub fn payment_secret( + mut self, payment_secret: PaymentSecret, + ) -> InvoiceBuilder { let mut found_features = false; for field in self.tagged_fields.iter_mut() { if let TaggedField::Features(f) = field { @@ -754,14 +781,18 @@ impl InvoiceBui } } -impl InvoiceBuilder { +impl + InvoiceBuilder +{ /// Sets the payment metadata. /// /// By default features are set to *optionally* allow the sender to include the payment metadata. /// If you wish to require that the sender include the metadata (and fail to parse the invoice if /// they don't support payment metadata fields), you need to call /// [`InvoiceBuilder::require_payment_metadata`] after this. - pub fn payment_metadata(mut self, payment_metadata: Vec) -> InvoiceBuilder { + pub fn payment_metadata( + mut self, payment_metadata: Vec, + ) -> InvoiceBuilder { self.tagged_fields.push(TaggedField::PaymentMetadata(payment_metadata)); let mut found_features = false; for field in self.tagged_fields.iter_mut() { @@ -779,7 +810,9 @@ impl InvoiceBui } } -impl InvoiceBuilder { +impl + InvoiceBuilder +{ /// Sets forwarding of payment metadata as required. A reader of the invoice which does not /// support sending payment metadata will fail to read the invoice. pub fn require_payment_metadata(mut self) -> InvoiceBuilder { @@ -792,7 +825,9 @@ impl InvoiceBui } } -impl InvoiceBuilder { +impl + InvoiceBuilder +{ /// Sets the `basic_mpp` feature as optional. pub fn basic_mpp(mut self) -> Self { for field in self.tagged_fields.iter_mut() { @@ -809,11 +844,10 @@ impl InvoiceBuilder(self, sign_function: F) -> Result - where F: FnOnce(&Message) -> RecoverableSignature + where + F: FnOnce(&Message) -> RecoverableSignature, { - let invoice = self.try_build_signed::<_, ()>(|hash| { - Ok(sign_function(hash)) - }); + let invoice = self.try_build_signed::<_, ()>(|hash| Ok(sign_function(hash))); match invoice { Ok(i) => Ok(i), @@ -825,8 +859,11 @@ impl InvoiceBuilder(self, sign_function: F) -> Result> - where F: FnOnce(&Message) -> Result + pub fn try_build_signed( + self, sign_function: F, + ) -> Result> + where + F: FnOnce(&Message) -> Result, { let raw = match self.build_raw() { Ok(r) => r, @@ -838,9 +875,7 @@ impl InvoiceBuilder return Err(SignOrCreationError::SignError(e)), }; - let invoice = Bolt11Invoice { - signed_invoice: signed, - }; + let invoice = Bolt11Invoice { signed_invoice: signed }; invoice.check_field_counts().expect("should be ensured by type signature of builder"); invoice.check_feature_bits().expect("should be ensured by type signature of builder"); @@ -850,7 +885,6 @@ impl InvoiceBuilder true, @@ -982,10 +1010,7 @@ impl RawBolt11Invoice { pub fn signable_hash(&self) -> [u8; 32] { use bech32::ToBase32; - RawBolt11Invoice::hash_from_parts( - self.hrp.to_string().as_bytes(), - &self.data.to_base32() - ) + RawBolt11Invoice::hash_from_parts(self.hrp.to_string().as_bytes(), &self.data.to_base32()) } /// Signs the invoice using the supplied `sign_method`. This function MAY fail with an error of @@ -995,7 +1020,8 @@ impl RawBolt11Invoice { /// This is not exported to bindings users as we don't currently support passing function pointers into methods /// explicitly. pub fn sign(self, sign_method: F) -> Result - where F: FnOnce(&Message) -> Result + where + F: FnOnce(&Message) -> Result, { let raw_hash = self.signable_hash(); let hash = Message::from_slice(&raw_hash[..]) @@ -1012,9 +1038,9 @@ impl RawBolt11Invoice { /// Returns an iterator over all tagged fields with known semantics. /// /// This is not exported to bindings users as there is not yet a manual mapping for a FilterMap - pub fn known_tagged_fields(&self) - -> FilterMap, fn(&RawTaggedField) -> Option<&TaggedField>> - { + pub fn known_tagged_fields( + &self, + ) -> FilterMap, fn(&RawTaggedField) -> Option<&TaggedField>> { // For 1.14.0 compatibility: closures' types can't be written an fn()->() in the // function's type signature. // TODO: refactor once impl Trait is available @@ -1025,7 +1051,7 @@ impl RawBolt11Invoice { } } - self.data.tagged_fields.iter().filter_map(match_raw ) + self.data.tagged_fields.iter().filter_map(match_raw) } pub fn payment_hash(&self) -> Option<&Sha256> { @@ -1075,7 +1101,7 @@ impl RawBolt11Invoice { pub fn amount_pico_btc(&self) -> Option { self.hrp.raw_amount.map(|v| { - v * self.hrp.si_prefix.as_ref().map_or(1_000_000_000_000, |si| { si.multiplier() }) + v * self.hrp.si_prefix.as_ref().map_or(1_000_000_000_000, |si| si.multiplier()) }) } @@ -1157,10 +1183,13 @@ impl Bolt11Invoice { /// Check that all mandatory fields are present fn check_field_counts(&self) -> Result<(), Bolt11SemanticError> { // "A writer MUST include exactly one p field […]." - let payment_hash_cnt = self.tagged_fields().filter(|&tf| match *tf { - TaggedField::PaymentHash(_) => true, - _ => false, - }).count(); + let payment_hash_cnt = self + .tagged_fields() + .filter(|&tf| match *tf { + TaggedField::PaymentHash(_) => true, + _ => false, + }) + .count(); if payment_hash_cnt < 1 { return Err(Bolt11SemanticError::NoPaymentHash); } else if payment_hash_cnt > 1 { @@ -1168,14 +1197,17 @@ impl Bolt11Invoice { } // "A writer MUST include either exactly one d or exactly one h field." - let description_cnt = self.tagged_fields().filter(|&tf| match *tf { - TaggedField::Description(_) | TaggedField::DescriptionHash(_) => true, - _ => false, - }).count(); - if description_cnt < 1 { + let description_cnt = self + .tagged_fields() + .filter(|&tf| match *tf { + TaggedField::Description(_) | TaggedField::DescriptionHash(_) => true, + _ => false, + }) + .count(); + if description_cnt < 1 { return Err(Bolt11SemanticError::NoDescription); } else if description_cnt > 1 { - return Err(Bolt11SemanticError::MultipleDescriptions); + return Err(Bolt11SemanticError::MultipleDescriptions); } self.check_payment_secret()?; @@ -1186,10 +1218,13 @@ impl Bolt11Invoice { /// Checks that there is exactly one payment secret field fn check_payment_secret(&self) -> Result<(), Bolt11SemanticError> { // "A writer MUST include exactly one `s` field." - let payment_secret_count = self.tagged_fields().filter(|&tf| match *tf { - TaggedField::PaymentSecret(_) => true, - _ => false, - }).count(); + let payment_secret_count = self + .tagged_fields() + .filter(|&tf| match *tf { + TaggedField::PaymentSecret(_) => true, + _ => false, + }) + .count(); if payment_secret_count < 1 { return Err(Bolt11SemanticError::NoPaymentSecret); } else if payment_secret_count > 1 { @@ -1238,10 +1273,12 @@ impl Bolt11Invoice { /// Check that the invoice is signed correctly and that key recovery works pub fn check_signature(&self) -> Result<(), Bolt11SemanticError> { match self.signed_invoice.recover_payee_pub_key() { - Err(secp256k1::Error::InvalidRecoveryId) => - return Err(Bolt11SemanticError::InvalidRecoveryId), - Err(secp256k1::Error::InvalidSignature) => - return Err(Bolt11SemanticError::InvalidSignature), + Err(secp256k1::Error::InvalidRecoveryId) => { + return Err(Bolt11SemanticError::InvalidRecoveryId) + }, + Err(secp256k1::Error::InvalidSignature) => { + return Err(Bolt11SemanticError::InvalidSignature) + }, Err(e) => panic!("no other error may occur, got {:?}", e), Ok(_) => {}, } @@ -1273,10 +1310,10 @@ impl Bolt11Invoice { /// /// assert!(Bolt11Invoice::from_signed(signed).is_ok()); /// ``` - pub fn from_signed(signed_invoice: SignedRawBolt11Invoice) -> Result { - let invoice = Bolt11Invoice { - signed_invoice, - }; + pub fn from_signed( + signed_invoice: SignedRawBolt11Invoice, + ) -> Result { + let invoice = Bolt11Invoice { signed_invoice }; invoice.check_field_counts()?; invoice.check_feature_bits()?; invoice.check_signature()?; @@ -1299,8 +1336,9 @@ impl Bolt11Invoice { /// Returns an iterator over all tagged fields of this `Bolt11Invoice`. /// /// This is not exported to bindings users as there is not yet a manual mapping for a FilterMap - pub fn tagged_fields(&self) - -> FilterMap, fn(&RawTaggedField) -> Option<&TaggedField>> { + pub fn tagged_fields( + &self, + ) -> FilterMap, fn(&RawTaggedField) -> Option<&TaggedField>> { self.signed_invoice.raw_invoice().known_tagged_fields() } @@ -1354,7 +1392,8 @@ impl Bolt11Invoice { /// Returns the invoice's expiry time, if present, otherwise [`DEFAULT_EXPIRY_TIME`]. pub fn expiry_time(&self) -> Duration { - self.signed_invoice.expiry_time() + self.signed_invoice + .expiry_time() .map(|x| x.0) .unwrap_or(Duration::from_secs(DEFAULT_EXPIRY_TIME)) } @@ -1377,7 +1416,8 @@ impl Bolt11Invoice { /// Returns the Duration remaining until the invoice expires. #[cfg(feature = "std")] pub fn duration_until_expiry(&self) -> Duration { - SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) .map(|now| self.expiration_remaining_from_epoch(now)) .unwrap_or(Duration::from_nanos(0)) } @@ -1393,13 +1433,15 @@ impl Bolt11Invoice { pub fn would_expire(&self, at_time: Duration) -> bool { self.duration_since_epoch() .checked_add(self.expiry_time()) - .unwrap_or_else(|| Duration::new(u64::max_value(), 1_000_000_000 - 1)) < at_time + .unwrap_or_else(|| Duration::new(u64::max_value(), 1_000_000_000 - 1)) + < at_time } /// Returns the invoice's `min_final_cltv_expiry_delta` time, if present, otherwise /// [`DEFAULT_MIN_FINAL_CLTV_EXPIRY_DELTA`]. pub fn min_final_cltv_expiry_delta(&self) -> u64 { - self.signed_invoice.min_final_cltv_expiry_delta() + self.signed_invoice + .min_final_cltv_expiry_delta() .map(|x| x.0) .unwrap_or(DEFAULT_MIN_FINAL_CLTV_EXPIRY_DELTA) } @@ -1413,21 +1455,20 @@ impl Bolt11Invoice { /// Returns a list of all fallback addresses as [`Address`]es pub fn fallback_addresses(&self) -> Vec
{ - self.fallbacks().iter().map(|fallback| { - let payload = match fallback { - Fallback::SegWitProgram { version, program } => { - Payload::WitnessProgram { version: *version, program: program.to_vec() } - } - Fallback::PubKeyHash(pkh) => { - Payload::PubkeyHash(*pkh) - } - Fallback::ScriptHash(sh) => { - Payload::ScriptHash(*sh) - } - }; - - Address { payload, network: self.network() } - }).collect() + self.fallbacks() + .iter() + .map(|fallback| { + let payload = match fallback { + Fallback::SegWitProgram { version, program } => { + Payload::WitnessProgram { version: *version, program: program.to_vec() } + }, + Fallback::PubKeyHash(pkh) => Payload::PubkeyHash(*pkh), + Fallback::ScriptHash(sh) => Payload::ScriptHash(*sh), + }; + + Address { payload, network: self.network() } + }) + .collect() } /// Returns a list of all routes included in the invoice @@ -1438,8 +1479,12 @@ impl Bolt11Invoice { /// Returns a list of all routes included in the invoice as the underlying hints pub fn route_hints(&self) -> Vec { find_all_extract!( - self.signed_invoice.known_tagged_fields(), TaggedField::PrivateRoute(ref x), x - ).map(|route| (**route).clone()).collect() + self.signed_invoice.known_tagged_fields(), + TaggedField::PrivateRoute(ref x), + x + ) + .map(|route| (**route).clone()) + .collect() } /// Returns the currency for which the invoice was issued @@ -1493,7 +1538,6 @@ impl TaggedField { } impl Description { - /// Creates a new `Description` if `description` is at most 1023 __bytes__ long, /// returns [`CreationError::DescriptionTooLong`] otherwise /// @@ -1650,7 +1694,7 @@ impl Display for CreationError { } #[cfg(feature = "std")] -impl std::error::Error for CreationError { } +impl std::error::Error for CreationError {} /// Errors that may occur when converting a [`RawBolt11Invoice`] to a [`Bolt11Invoice`]. They relate to /// the requirements sections in BOLT #11 @@ -1706,7 +1750,7 @@ impl Display for Bolt11SemanticError { } #[cfg(feature = "std")] -impl std::error::Error for Bolt11SemanticError { } +impl std::error::Error for Bolt11SemanticError {} /// When signing using a fallible method either an user-supplied `SignError` or a [`CreationError`] /// may occur. @@ -1730,13 +1774,19 @@ impl Display for SignOrCreationError { #[cfg(feature = "serde")] impl Serialize for Bolt11Invoice { - fn serialize(&self, serializer: S) -> Result where S: Serializer { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { serializer.serialize_str(self.to_string().as_str()) } } #[cfg(feature = "serde")] impl<'de> Deserialize<'de> for Bolt11Invoice { - fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { let bolt11 = String::deserialize(deserializer)? .parse::() .map_err(|e| D::Error::custom(format_args!("{:?}", e)))?; @@ -1761,24 +1811,28 @@ mod test { #[test] fn test_calc_invoice_hash() { - use crate::{RawBolt11Invoice, RawHrp, RawDataPart, Currency, PositiveTimestamp}; use crate::TaggedField::*; + use crate::{Currency, PositiveTimestamp, RawBolt11Invoice, RawDataPart, RawHrp}; let invoice = RawBolt11Invoice { - hrp: RawHrp { - currency: Currency::Bitcoin, - raw_amount: None, - si_prefix: None, - }, + hrp: RawHrp { currency: Currency::Bitcoin, raw_amount: None, si_prefix: None }, data: RawDataPart { timestamp: PositiveTimestamp::from_unix_timestamp(1496314658).unwrap(), tagged_fields: vec![ - PaymentHash(crate::Sha256(sha256::Hash::from_hex( - "0001020304050607080900010203040506070809000102030405060708090102" - ).unwrap())).into(), - Description(crate::Description::new( - "Please consider supporting this project".to_owned() - ).unwrap()).into(), + PaymentHash(crate::Sha256( + sha256::Hash::from_hex( + "0001020304050607080900010203040506070809000102030405060708090102", + ) + .unwrap(), + )) + .into(), + Description( + crate::Description::new( + "Please consider supporting this project".to_owned(), + ) + .unwrap(), + ) + .into(), ], }, }; @@ -1786,7 +1840,7 @@ mod test { let expected_hash = [ 0xc3, 0xd4, 0xe8, 0x3f, 0x64, 0x6f, 0xa7, 0x9a, 0x39, 0x3d, 0x75, 0x27, 0x7b, 0x1d, 0x85, 0x8d, 0xb1, 0xd1, 0xf7, 0xab, 0x71, 0x37, 0xdc, 0xb7, 0x83, 0x5d, 0xb2, 0xec, - 0xd5, 0x18, 0xe1, 0xc9 + 0xd5, 0x18, 0xe1, 0xc9, ]; assert_eq!(invoice.signable_hash(), expected_hash) @@ -1795,22 +1849,21 @@ mod test { #[test] fn test_check_signature() { use crate::TaggedField::*; + use crate::{ + Bolt11InvoiceSignature, Currency, PositiveTimestamp, RawBolt11Invoice, RawDataPart, + RawHrp, Sha256, SignedRawBolt11Invoice, + }; + use secp256k1::ecdsa::{RecoverableSignature, RecoveryId}; use secp256k1::Secp256k1; - use secp256k1::ecdsa::{RecoveryId, RecoverableSignature}; - use secp256k1::{SecretKey, PublicKey}; - use crate::{SignedRawBolt11Invoice, Bolt11InvoiceSignature, RawBolt11Invoice, RawHrp, RawDataPart, Currency, Sha256, - PositiveTimestamp}; - - let invoice = SignedRawBolt11Invoice { - raw_invoice: RawBolt11Invoice { - hrp: RawHrp { - currency: Currency::Bitcoin, - raw_amount: None, - si_prefix: None, - }, - data: RawDataPart { - timestamp: PositiveTimestamp::from_unix_timestamp(1496314658).unwrap(), - tagged_fields: vec ! [ + use secp256k1::{PublicKey, SecretKey}; + + let invoice = + SignedRawBolt11Invoice { + raw_invoice: RawBolt11Invoice { + hrp: RawHrp { currency: Currency::Bitcoin, raw_amount: None, si_prefix: None }, + data: RawDataPart { + timestamp: PositiveTimestamp::from_unix_timestamp(1496314658).unwrap(), + tagged_fields: vec ! [ PaymentHash(Sha256(sha256::Hash::from_hex( "0001020304050607080900010203040506070809000102030405060708090102" ).unwrap())).into(), @@ -1820,25 +1873,28 @@ mod test { ).unwrap() ).into(), ], + }, }, - }, - hash: [ - 0xc3, 0xd4, 0xe8, 0x3f, 0x64, 0x6f, 0xa7, 0x9a, 0x39, 0x3d, 0x75, 0x27, - 0x7b, 0x1d, 0x85, 0x8d, 0xb1, 0xd1, 0xf7, 0xab, 0x71, 0x37, 0xdc, 0xb7, - 0x83, 0x5d, 0xb2, 0xec, 0xd5, 0x18, 0xe1, 0xc9 - ], - signature: Bolt11InvoiceSignature(RecoverableSignature::from_compact( - & [ - 0x38u8, 0xec, 0x68, 0x91, 0x34, 0x5e, 0x20, 0x41, 0x45, 0xbe, 0x8a, - 0x3a, 0x99, 0xde, 0x38, 0xe9, 0x8a, 0x39, 0xd6, 0xa5, 0x69, 0x43, - 0x4e, 0x18, 0x45, 0xc8, 0xaf, 0x72, 0x05, 0xaf, 0xcf, 0xcc, 0x7f, - 0x42, 0x5f, 0xcd, 0x14, 0x63, 0xe9, 0x3c, 0x32, 0x88, 0x1e, 0xad, - 0x0d, 0x6e, 0x35, 0x6d, 0x46, 0x7e, 0xc8, 0xc0, 0x25, 0x53, 0xf9, - 0xaa, 0xb1, 0x5e, 0x57, 0x38, 0xb1, 0x1f, 0x12, 0x7f + hash: [ + 0xc3, 0xd4, 0xe8, 0x3f, 0x64, 0x6f, 0xa7, 0x9a, 0x39, 0x3d, 0x75, 0x27, 0x7b, + 0x1d, 0x85, 0x8d, 0xb1, 0xd1, 0xf7, 0xab, 0x71, 0x37, 0xdc, 0xb7, 0x83, 0x5d, + 0xb2, 0xec, 0xd5, 0x18, 0xe1, 0xc9, ], - RecoveryId::from_i32(0).unwrap() - ).unwrap()), - }; + signature: Bolt11InvoiceSignature( + RecoverableSignature::from_compact( + &[ + 0x38u8, 0xec, 0x68, 0x91, 0x34, 0x5e, 0x20, 0x41, 0x45, 0xbe, 0x8a, + 0x3a, 0x99, 0xde, 0x38, 0xe9, 0x8a, 0x39, 0xd6, 0xa5, 0x69, 0x43, 0x4e, + 0x18, 0x45, 0xc8, 0xaf, 0x72, 0x05, 0xaf, 0xcf, 0xcc, 0x7f, 0x42, 0x5f, + 0xcd, 0x14, 0x63, 0xe9, 0x3c, 0x32, 0x88, 0x1e, 0xad, 0x0d, 0x6e, 0x35, + 0x6d, 0x46, 0x7e, 0xc8, 0xc0, 0x25, 0x53, 0xf9, 0xaa, 0xb1, 0x5e, 0x57, + 0x38, 0xb1, 0x1f, 0x12, 0x7f, + ], + RecoveryId::from_i32(0).unwrap(), + ) + .unwrap(), + ), + }; assert!(invoice.check_signature()); @@ -1846,17 +1902,18 @@ mod test { &[ 0xe1, 0x26, 0xf6, 0x8f, 0x7e, 0xaf, 0xcc, 0x8b, 0x74, 0xf5, 0x4d, 0x26, 0x9f, 0xe2, 0x06, 0xbe, 0x71, 0x50, 0x00, 0xf9, 0x4d, 0xac, 0x06, 0x7d, 0x1c, 0x04, 0xa8, 0xca, - 0x3b, 0x2d, 0xb7, 0x34 - ][..] - ).unwrap(); + 0x3b, 0x2d, 0xb7, 0x34, + ][..], + ) + .unwrap(); let public_key = PublicKey::from_secret_key(&Secp256k1::new(), &private_key); assert_eq!(invoice.recover_payee_pub_key(), Ok(crate::PayeePubKey(public_key))); let (raw_invoice, _, _) = invoice.into_parts(); - let new_signed = raw_invoice.sign::<_, ()>(|hash| { - Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key)) - }).unwrap(); + let new_signed = raw_invoice + .sign::<_, ()>(|hash| Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key))) + .unwrap(); assert!(new_signed.check_signature()); } @@ -1864,31 +1921,35 @@ mod test { #[test] fn test_check_feature_bits() { use crate::TaggedField::*; + use crate::{ + Bolt11Invoice, Bolt11SemanticError, Currency, PositiveTimestamp, RawBolt11Invoice, + RawDataPart, RawHrp, Sha256, + }; use lightning::ln::features::Bolt11InvoiceFeatures; use secp256k1::Secp256k1; use secp256k1::SecretKey; - use crate::{Bolt11Invoice, RawBolt11Invoice, RawHrp, RawDataPart, Currency, Sha256, PositiveTimestamp, - Bolt11SemanticError}; let private_key = SecretKey::from_slice(&[42; 32]).unwrap(); let payment_secret = lightning::ln::PaymentSecret([21; 32]); let invoice_template = RawBolt11Invoice { - hrp: RawHrp { - currency: Currency::Bitcoin, - raw_amount: None, - si_prefix: None, - }, + hrp: RawHrp { currency: Currency::Bitcoin, raw_amount: None, si_prefix: None }, data: RawDataPart { timestamp: PositiveTimestamp::from_unix_timestamp(1496314658).unwrap(), - tagged_fields: vec ! [ - PaymentHash(Sha256(sha256::Hash::from_hex( - "0001020304050607080900010203040506070809000102030405060708090102" - ).unwrap())).into(), + tagged_fields: vec![ + PaymentHash(Sha256( + sha256::Hash::from_hex( + "0001020304050607080900010203040506070809000102030405060708090102", + ) + .unwrap(), + )) + .into(), Description( crate::Description::new( - "Please consider supporting this project".to_owned() - ).unwrap() - ).into(), + "Please consider supporting this project".to_owned(), + ) + .unwrap(), + ) + .into(), ], }, }; @@ -1897,8 +1958,11 @@ mod test { let invoice = { let mut invoice = invoice_template.clone(); invoice.data.tagged_fields.push(PaymentSecret(payment_secret).into()); - invoice.sign::<_, ()>(|hash| Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key))) - }.unwrap(); + invoice.sign::<_, ()>(|hash| { + Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key)) + }) + } + .unwrap(); assert_eq!(Bolt11Invoice::from_signed(invoice), Err(Bolt11SemanticError::InvalidFeatures)); // Missing feature bits @@ -1906,8 +1970,11 @@ mod test { let mut invoice = invoice_template.clone(); invoice.data.tagged_fields.push(PaymentSecret(payment_secret).into()); invoice.data.tagged_fields.push(Features(Bolt11InvoiceFeatures::empty()).into()); - invoice.sign::<_, ()>(|hash| Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key))) - }.unwrap(); + invoice.sign::<_, ()>(|hash| { + Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key)) + }) + } + .unwrap(); assert_eq!(Bolt11Invoice::from_signed(invoice), Err(Bolt11SemanticError::InvalidFeatures)); let mut payment_secret_features = Bolt11InvoiceFeatures::empty(); @@ -1918,31 +1985,43 @@ mod test { let mut invoice = invoice_template.clone(); invoice.data.tagged_fields.push(PaymentSecret(payment_secret).into()); invoice.data.tagged_fields.push(Features(payment_secret_features.clone()).into()); - invoice.sign::<_, ()>(|hash| Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key))) - }.unwrap(); + invoice.sign::<_, ()>(|hash| { + Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key)) + }) + } + .unwrap(); assert!(Bolt11Invoice::from_signed(invoice).is_ok()); // No payment secret or features let invoice = { let invoice = invoice_template.clone(); - invoice.sign::<_, ()>(|hash| Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key))) - }.unwrap(); + invoice.sign::<_, ()>(|hash| { + Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key)) + }) + } + .unwrap(); assert_eq!(Bolt11Invoice::from_signed(invoice), Err(Bolt11SemanticError::NoPaymentSecret)); // No payment secret or feature bits let invoice = { let mut invoice = invoice_template.clone(); invoice.data.tagged_fields.push(Features(Bolt11InvoiceFeatures::empty()).into()); - invoice.sign::<_, ()>(|hash| Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key))) - }.unwrap(); + invoice.sign::<_, ()>(|hash| { + Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key)) + }) + } + .unwrap(); assert_eq!(Bolt11Invoice::from_signed(invoice), Err(Bolt11SemanticError::NoPaymentSecret)); // Missing payment secret let invoice = { let mut invoice = invoice_template.clone(); invoice.data.tagged_fields.push(Features(payment_secret_features).into()); - invoice.sign::<_, ()>(|hash| Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key))) - }.unwrap(); + invoice.sign::<_, ()>(|hash| { + Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key)) + }) + } + .unwrap(); assert_eq!(Bolt11Invoice::from_signed(invoice), Err(Bolt11SemanticError::NoPaymentSecret)); // Multiple payment secrets @@ -1950,9 +2029,15 @@ mod test { let mut invoice = invoice_template; invoice.data.tagged_fields.push(PaymentSecret(payment_secret).into()); invoice.data.tagged_fields.push(PaymentSecret(payment_secret).into()); - invoice.sign::<_, ()>(|hash| Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key))) - }.unwrap(); - assert_eq!(Bolt11Invoice::from_signed(invoice), Err(Bolt11SemanticError::MultiplePaymentSecrets)); + invoice.sign::<_, ()>(|hash| { + Ok(Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key)) + }) + } + .unwrap(); + assert_eq!( + Bolt11Invoice::from_signed(invoice), + Err(Bolt11SemanticError::MultiplePaymentSecrets) + ); } #[test] @@ -1961,22 +2046,15 @@ mod test { let builder = InvoiceBuilder::new(Currency::Bitcoin) .description("Test".into()) - .payment_hash(sha256::Hash::from_slice(&[0;32][..]).unwrap()) + .payment_hash(sha256::Hash::from_slice(&[0; 32][..]).unwrap()) .duration_since_epoch(Duration::from_secs(1234567)); - let invoice = builder.clone() - .amount_milli_satoshis(1500) - .build_raw() - .unwrap(); + let invoice = builder.clone().amount_milli_satoshis(1500).build_raw().unwrap(); assert_eq!(invoice.hrp.si_prefix, Some(SiPrefix::Nano)); assert_eq!(invoice.hrp.raw_amount, Some(15)); - - let invoice = builder - .amount_milli_satoshis(150) - .build_raw() - .unwrap(); + let invoice = builder.amount_milli_satoshis(150).build_raw().unwrap(); assert_eq!(invoice.hrp.si_prefix, Some(SiPrefix::Pico)); assert_eq!(invoice.hrp.raw_amount, Some(1500)); @@ -1986,53 +2064,43 @@ mod test { fn test_builder_fail() { use crate::*; use lightning::routing::router::RouteHintHop; - use std::iter::FromIterator; use secp256k1::PublicKey; + use std::iter::FromIterator; let builder = InvoiceBuilder::new(Currency::Bitcoin) - .payment_hash(sha256::Hash::from_slice(&[0;32][..]).unwrap()) + .payment_hash(sha256::Hash::from_slice(&[0; 32][..]).unwrap()) .duration_since_epoch(Duration::from_secs(1234567)) .min_final_cltv_expiry_delta(144); - let too_long_string = String::from_iter( - (0..1024).map(|_| '?') - ); + let too_long_string = String::from_iter((0..1024).map(|_| '?')); - let long_desc_res = builder.clone() - .description(too_long_string) - .build_raw(); + let long_desc_res = builder.clone().description(too_long_string).build_raw(); assert_eq!(long_desc_res, Err(CreationError::DescriptionTooLong)); let route_hop = RouteHintHop { src_node_id: PublicKey::from_slice( - &[ - 0x03, 0x9e, 0x03, 0xa9, 0x01, 0xb8, 0x55, 0x34, 0xff, 0x1e, 0x92, 0xc4, - 0x3c, 0x74, 0x43, 0x1f, 0x7c, 0xe7, 0x20, 0x46, 0x06, 0x0f, 0xcf, 0x7a, - 0x95, 0xc3, 0x7e, 0x14, 0x8f, 0x78, 0xc7, 0x72, 0x55 - ][..] - ).unwrap(), + &[ + 0x03, 0x9e, 0x03, 0xa9, 0x01, 0xb8, 0x55, 0x34, 0xff, 0x1e, 0x92, 0xc4, 0x3c, + 0x74, 0x43, 0x1f, 0x7c, 0xe7, 0x20, 0x46, 0x06, 0x0f, 0xcf, 0x7a, 0x95, 0xc3, + 0x7e, 0x14, 0x8f, 0x78, 0xc7, 0x72, 0x55, + ][..], + ) + .unwrap(), short_channel_id: 0, - fees: RoutingFees { - base_msat: 0, - proportional_millionths: 0, - }, + fees: RoutingFees { base_msat: 0, proportional_millionths: 0 }, cltv_expiry_delta: 0, htlc_minimum_msat: None, htlc_maximum_msat: None, }; let too_long_route = RouteHint(vec![route_hop; 13]); - let long_route_res = builder.clone() - .description("Test".into()) - .private_route(too_long_route) - .build_raw(); + let long_route_res = + builder.clone().description("Test".into()).private_route(too_long_route).build_raw(); assert_eq!(long_route_res, Err(CreationError::RouteTooLong)); let sign_error_res = builder .description("Test".into()) .payment_secret(PaymentSecret([0; 32])) - .try_build_signed(|_| { - Err("ImaginaryError") - }); + .try_build_signed(|_| Err("ImaginaryError")); assert_eq!(sign_error_res, Err(SignOrCreationError::SignError("ImaginaryError"))); } @@ -2041,8 +2109,8 @@ mod test { use crate::*; use lightning::routing::router::RouteHintHop; use secp256k1::Secp256k1; - use secp256k1::{SecretKey, PublicKey}; - use std::time::{UNIX_EPOCH, Duration}; + use secp256k1::{PublicKey, SecretKey}; + use std::time::{Duration, UNIX_EPOCH}; let secp_ctx = Secp256k1::new(); @@ -2050,59 +2118,51 @@ mod test { &[ 0xe1, 0x26, 0xf6, 0x8f, 0x7e, 0xaf, 0xcc, 0x8b, 0x74, 0xf5, 0x4d, 0x26, 0x9f, 0xe2, 0x06, 0xbe, 0x71, 0x50, 0x00, 0xf9, 0x4d, 0xac, 0x06, 0x7d, 0x1c, 0x04, 0xa8, 0xca, - 0x3b, 0x2d, 0xb7, 0x34 - ][..] - ).unwrap(); + 0x3b, 0x2d, 0xb7, 0x34, + ][..], + ) + .unwrap(); let public_key = PublicKey::from_secret_key(&secp_ctx, &private_key); let route_1 = RouteHint(vec![ RouteHintHop { src_node_id: public_key, - short_channel_id: de::parse_int_be(&[123; 8], 256).expect("short chan ID slice too big?"), - fees: RoutingFees { - base_msat: 2, - proportional_millionths: 1, - }, + short_channel_id: de::parse_int_be(&[123; 8], 256) + .expect("short chan ID slice too big?"), + fees: RoutingFees { base_msat: 2, proportional_millionths: 1 }, cltv_expiry_delta: 145, htlc_minimum_msat: None, htlc_maximum_msat: None, }, RouteHintHop { src_node_id: public_key, - short_channel_id: de::parse_int_be(&[42; 8], 256).expect("short chan ID slice too big?"), - fees: RoutingFees { - base_msat: 3, - proportional_millionths: 2, - }, + short_channel_id: de::parse_int_be(&[42; 8], 256) + .expect("short chan ID slice too big?"), + fees: RoutingFees { base_msat: 3, proportional_millionths: 2 }, cltv_expiry_delta: 146, htlc_minimum_msat: None, htlc_maximum_msat: None, - } + }, ]); let route_2 = RouteHint(vec![ RouteHintHop { src_node_id: public_key, short_channel_id: 0, - fees: RoutingFees { - base_msat: 4, - proportional_millionths: 3, - }, + fees: RoutingFees { base_msat: 4, proportional_millionths: 3 }, cltv_expiry_delta: 147, htlc_minimum_msat: None, htlc_maximum_msat: None, }, RouteHintHop { src_node_id: public_key, - short_channel_id: de::parse_int_be(&[1; 8], 256).expect("short chan ID slice too big?"), - fees: RoutingFees { - base_msat: 5, - proportional_millionths: 4, - }, + short_channel_id: de::parse_int_be(&[1; 8], 256) + .expect("short chan ID slice too big?"), + fees: RoutingFees { base_msat: 5, proportional_millionths: 4 }, cltv_expiry_delta: 148, htlc_minimum_msat: None, htlc_maximum_msat: None, - } + }, ]); let builder = InvoiceBuilder::new(Currency::BitcoinTestnet) @@ -2111,17 +2171,18 @@ mod test { .payee_pub_key(public_key) .expiry_time(Duration::from_secs(54321)) .min_final_cltv_expiry_delta(144) - .fallback(Fallback::PubKeyHash(PubkeyHash::from_slice(&[0;20]).unwrap())) + .fallback(Fallback::PubKeyHash(PubkeyHash::from_slice(&[0; 20]).unwrap())) .private_route(route_1.clone()) .private_route(route_2.clone()) - .description_hash(sha256::Hash::from_slice(&[3;32][..]).unwrap()) - .payment_hash(sha256::Hash::from_slice(&[21;32][..]).unwrap()) + .description_hash(sha256::Hash::from_slice(&[3; 32][..]).unwrap()) + .payment_hash(sha256::Hash::from_slice(&[21; 32][..]).unwrap()) .payment_secret(PaymentSecret([42; 32])) .basic_mpp(); - let invoice = builder.clone().build_signed(|hash| { - secp_ctx.sign_ecdsa_recoverable(hash, &private_key) - }).unwrap(); + let invoice = builder + .clone() + .build_signed(|hash| secp_ctx.sign_ecdsa_recoverable(hash, &private_key)) + .unwrap(); assert!(invoice.check_signature().is_ok()); assert_eq!(invoice.tagged_fields().count(), 10); @@ -2130,22 +2191,27 @@ mod test { assert_eq!(invoice.amount_pico_btc(), Some(1230)); assert_eq!(invoice.currency(), Currency::BitcoinTestnet); #[cfg(feature = "std")] - assert_eq!( - invoice.timestamp().duration_since(UNIX_EPOCH).unwrap().as_secs(), - 1234567 - ); + assert_eq!(invoice.timestamp().duration_since(UNIX_EPOCH).unwrap().as_secs(), 1234567); assert_eq!(invoice.payee_pub_key(), Some(&public_key)); assert_eq!(invoice.expiry_time(), Duration::from_secs(54321)); assert_eq!(invoice.min_final_cltv_expiry_delta(), 144); - assert_eq!(invoice.fallbacks(), vec![&Fallback::PubKeyHash(PubkeyHash::from_slice(&[0;20]).unwrap())]); - let address = Address::from_script(&Script::new_p2pkh(&PubkeyHash::from_slice(&[0;20]).unwrap()), Network::Testnet).unwrap(); + assert_eq!(invoice.fallbacks(), vec![&Fallback::PubKeyHash( + PubkeyHash::from_slice(&[0; 20]).unwrap() + )]); + let address = Address::from_script( + &Script::new_p2pkh(&PubkeyHash::from_slice(&[0; 20]).unwrap()), + Network::Testnet, + ) + .unwrap(); assert_eq!(invoice.fallback_addresses(), vec![address]); assert_eq!(invoice.private_routes(), vec![&PrivateRoute(route_1), &PrivateRoute(route_2)]); assert_eq!( invoice.description(), - Bolt11InvoiceDescription::Hash(&Sha256(sha256::Hash::from_slice(&[3;32][..]).unwrap())) + Bolt11InvoiceDescription::Hash(&Sha256( + sha256::Hash::from_slice(&[3; 32][..]).unwrap() + )) ); - assert_eq!(invoice.payment_hash(), &sha256::Hash::from_slice(&[21;32][..]).unwrap()); + assert_eq!(invoice.payment_hash(), &sha256::Hash::from_slice(&[21; 32][..]).unwrap()); assert_eq!(invoice.payment_secret(), &PaymentSecret([42; 32])); let mut expected_features = Bolt11InvoiceFeatures::empty(); @@ -2166,7 +2232,7 @@ mod test { let signed_invoice = InvoiceBuilder::new(Currency::Bitcoin) .description("Test".into()) - .payment_hash(sha256::Hash::from_slice(&[0;32][..]).unwrap()) + .payment_hash(sha256::Hash::from_slice(&[0; 32][..]).unwrap()) .payment_secret(PaymentSecret([0; 32])) .duration_since_epoch(Duration::from_secs(1234567)) .build_raw() @@ -2192,7 +2258,7 @@ mod test { let signed_invoice = InvoiceBuilder::new(Currency::Bitcoin) .description("Test".into()) - .payment_hash(sha256::Hash::from_slice(&[0;32][..]).unwrap()) + .payment_hash(sha256::Hash::from_slice(&[0; 32][..]).unwrap()) .payment_secret(PaymentSecret([0; 32])) .duration_since_epoch(Duration::from_secs(1234567)) .build_raw() @@ -2224,7 +2290,8 @@ mod test { j5r6drg6k6zcqj0fcwg"; let invoice = invoice_str.parse::().unwrap(); let serialized_invoice = serde_json::to_string(&invoice).unwrap(); - let deserialized_invoice: super::Bolt11Invoice = serde_json::from_str(serialized_invoice.as_str()).unwrap(); + let deserialized_invoice: super::Bolt11Invoice = + serde_json::from_str(serialized_invoice.as_str()).unwrap(); assert_eq!(invoice, deserialized_invoice); assert_eq!(invoice_str, deserialized_invoice.to_string().as_str()); assert_eq!(invoice_str, serialized_invoice.as_str().trim_matches('\"')); diff --git a/lightning-invoice/src/payment.rs b/lightning-invoice/src/payment.rs index 89842591fde..6416064d3b2 100644 --- a/lightning-invoice/src/payment.rs +++ b/lightning-invoice/src/payment.rs @@ -9,17 +9,20 @@ //! Convenient utilities for paying Lightning invoices. -use crate::Bolt11Invoice; use crate::prelude::*; +use crate::Bolt11Invoice; use bitcoin_hashes::Hash; use lightning::chain; use lightning::chain::chaininterface::{BroadcasterInterface, FeeEstimator}; -use lightning::sign::{NodeSigner, SignerProvider, EntropySource}; +use lightning::ln::channelmanager::{ + AChannelManager, ChannelManager, PaymentId, ProbeSendFailure, RecipientOnionFields, Retry, + RetryableSendFailure, +}; use lightning::ln::PaymentHash; -use lightning::ln::channelmanager::{AChannelManager, ChannelManager, PaymentId, Retry, RetryableSendFailure, RecipientOnionFields, ProbeSendFailure}; use lightning::routing::router::{PaymentParameters, RouteParameters, Router}; +use lightning::sign::{EntropySource, NodeSigner, SignerProvider}; use lightning::util::logger::Logger; use core::fmt::Debug; @@ -34,9 +37,10 @@ use core::time::Duration; /// /// If you wish to use a different payment idempotency token, see [`pay_invoice_with_id`]. pub fn pay_invoice( - invoice: &Bolt11Invoice, retry_strategy: Retry, channelmanager: C + invoice: &Bolt11Invoice, retry_strategy: Retry, channelmanager: C, ) -> Result -where C::Target: AChannelManager, +where + C::Target: AChannelManager, { let payment_id = PaymentId(invoice.payment_hash().into_inner()); pay_invoice_with_id(invoice, payment_id, retry_strategy, channelmanager.get_cm()) @@ -54,11 +58,13 @@ where C::Target: AChannelManager, /// /// See [`pay_invoice`] for a variant which uses the [`PaymentHash`] for the idempotency token. pub fn pay_invoice_with_id( - invoice: &Bolt11Invoice, payment_id: PaymentId, retry_strategy: Retry, channelmanager: C + invoice: &Bolt11Invoice, payment_id: PaymentId, retry_strategy: Retry, channelmanager: C, ) -> Result<(), PaymentError> -where C::Target: AChannelManager, +where + C::Target: AChannelManager, { - let amt_msat = invoice.amount_milli_satoshis().ok_or(PaymentError::Invoice("amount missing"))?; + let amt_msat = + invoice.amount_milli_satoshis().ok_or(PaymentError::Invoice("amount missing"))?; pay_invoice_using_amount(invoice, amt_msat, payment_id, retry_strategy, channelmanager.get_cm()) } @@ -72,14 +78,20 @@ where C::Target: AChannelManager, /// If you wish to use a different payment idempotency token, see /// [`pay_zero_value_invoice_with_id`]. pub fn pay_zero_value_invoice( - invoice: &Bolt11Invoice, amount_msats: u64, retry_strategy: Retry, channelmanager: C + invoice: &Bolt11Invoice, amount_msats: u64, retry_strategy: Retry, channelmanager: C, ) -> Result -where C::Target: AChannelManager, +where + C::Target: AChannelManager, { let payment_id = PaymentId(invoice.payment_hash().into_inner()); - pay_zero_value_invoice_with_id(invoice, amount_msats, payment_id, retry_strategy, - channelmanager) - .map(|()| payment_id) + pay_zero_value_invoice_with_id( + invoice, + amount_msats, + payment_id, + retry_strategy, + channelmanager, + ) + .map(|()| payment_id) } /// Pays the given zero-value [`Bolt11Invoice`] using the given amount and custom idempotency key, @@ -95,29 +107,41 @@ where C::Target: AChannelManager, /// idempotency token. pub fn pay_zero_value_invoice_with_id( invoice: &Bolt11Invoice, amount_msats: u64, payment_id: PaymentId, retry_strategy: Retry, - channelmanager: C + channelmanager: C, ) -> Result<(), PaymentError> -where C::Target: AChannelManager, +where + C::Target: AChannelManager, { if invoice.amount_milli_satoshis().is_some() { Err(PaymentError::Invoice("amount unexpected")) } else { - pay_invoice_using_amount(invoice, amount_msats, payment_id, retry_strategy, - channelmanager.get_cm()) + pay_invoice_using_amount( + invoice, + amount_msats, + payment_id, + retry_strategy, + channelmanager.get_cm(), + ) } } fn pay_invoice_using_amount( invoice: &Bolt11Invoice, amount_msats: u64, payment_id: PaymentId, retry_strategy: Retry, - payer: P -) -> Result<(), PaymentError> where P::Target: Payer { + payer: P, +) -> Result<(), PaymentError> +where + P::Target: Payer, +{ let payment_hash = PaymentHash((*invoice.payment_hash()).into_inner()); let mut recipient_onion = RecipientOnionFields::secret_only(*invoice.payment_secret()); recipient_onion.payment_metadata = invoice.payment_metadata().map(|v| v.clone()); - let mut payment_params = PaymentParameters::from_node_id(invoice.recover_payee_pub_key(), - invoice.min_final_cltv_expiry_delta() as u32) - .with_expiry_time(expiry_time_from_unix_epoch(invoice).as_secs()) - .with_route_hints(invoice.route_hints()).unwrap(); + let mut payment_params = PaymentParameters::from_node_id( + invoice.recover_payee_pub_key(), + invoice.min_final_cltv_expiry_delta() as u32, + ) + .with_expiry_time(expiry_time_from_unix_epoch(invoice).as_secs()) + .with_route_hints(invoice.route_hints()) + .unwrap(); if let Some(features) = invoice.features() { payment_params = payment_params.with_bolt11_features(features.clone()).unwrap(); } @@ -132,12 +156,15 @@ fn pay_invoice_using_amount( pub fn preflight_probe_invoice( invoice: &Bolt11Invoice, channelmanager: C, liquidity_limit_multiplier: Option, ) -> Result, ProbingError> -where C::Target: AChannelManager, +where + C::Target: AChannelManager, { let amount_msat = if let Some(invoice_amount_msat) = invoice.amount_milli_satoshis() { invoice_amount_msat } else { - return Err(ProbingError::Invoice("Failed to send probe as no amount was given in the invoice.")); + return Err(ProbingError::Invoice( + "Failed to send probe as no amount was given in the invoice.", + )); }; let mut payment_params = PaymentParameters::from_node_id( @@ -153,7 +180,9 @@ where C::Target: AChannelManager, } let route_params = RouteParameters::from_payment_params_and_value(payment_params, amount_msat); - channelmanager.get_cm().send_preflight_probes(route_params, liquidity_limit_multiplier) + channelmanager + .get_cm() + .send_preflight_probes(route_params, liquidity_limit_multiplier) .map_err(ProbingError::Sending) } @@ -165,7 +194,8 @@ pub fn preflight_probe_zero_value_invoice( invoice: &Bolt11Invoice, amount_msat: u64, channelmanager: C, liquidity_limit_multiplier: Option, ) -> Result, ProbingError> -where C::Target: AChannelManager, +where + C::Target: AChannelManager, { if invoice.amount_milli_satoshis().is_some() { return Err(ProbingError::Invoice("amount unexpected")); @@ -184,7 +214,9 @@ where C::Target: AChannelManager, } let route_params = RouteParameters::from_payment_params_and_value(payment_params, amount_msat); - channelmanager.get_cm().send_preflight_probes(route_params, liquidity_limit_multiplier) + channelmanager + .get_cm() + .send_preflight_probes(route_params, liquidity_limit_multiplier) .map_err(ProbingError::Sending) } @@ -219,24 +251,25 @@ trait Payer { /// [`Route`]: lightning::routing::router::Route fn send_payment( &self, payment_hash: PaymentHash, recipient_onion: RecipientOnionFields, - payment_id: PaymentId, route_params: RouteParameters, retry_strategy: Retry + payment_id: PaymentId, route_params: RouteParameters, retry_strategy: Retry, ) -> Result<(), PaymentError>; } -impl Payer for ChannelManager +impl Payer + for ChannelManager where - M::Target: chain::Watch<::Signer>, - T::Target: BroadcasterInterface, - ES::Target: EntropySource, - NS::Target: NodeSigner, - SP::Target: SignerProvider, - F::Target: FeeEstimator, - R::Target: Router, - L::Target: Logger, + M::Target: chain::Watch<::Signer>, + T::Target: BroadcasterInterface, + ES::Target: EntropySource, + NS::Target: NodeSigner, + SP::Target: SignerProvider, + F::Target: FeeEstimator, + R::Target: Router, + L::Target: Logger, { fn send_payment( &self, payment_hash: PaymentHash, recipient_onion: RecipientOnionFields, - payment_id: PaymentId, route_params: RouteParameters, retry_strategy: Retry + payment_id: PaymentId, route_params: RouteParameters, retry_strategy: Retry, ) -> Result<(), PaymentError> { self.send_payment(payment_hash, recipient_onion, payment_id, route_params, retry_strategy) .map_err(PaymentError::Sending) @@ -246,15 +279,15 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{InvoiceBuilder, Currency}; + use crate::{Currency, InvoiceBuilder}; use bitcoin_hashes::sha256::Hash as Sha256; use lightning::events::Event; + use lightning::ln::functional_test_utils::*; use lightning::ln::msgs::ChannelMessageHandler; use lightning::ln::{PaymentPreimage, PaymentSecret}; - use lightning::ln::functional_test_utils::*; - use secp256k1::{SecretKey, Secp256k1}; + use secp256k1::{Secp256k1, SecretKey}; use std::collections::VecDeque; - use std::time::{SystemTime, Duration}; + use std::time::{Duration, SystemTime}; struct TestPayer { expectations: core::cell::RefCell>, @@ -262,9 +295,7 @@ mod tests { impl TestPayer { fn new() -> Self { - Self { - expectations: core::cell::RefCell::new(VecDeque::new()), - } + Self { expectations: core::cell::RefCell::new(VecDeque::new()) } } fn expect_send(self, value_msat: Amount) -> Self { @@ -288,7 +319,7 @@ mod tests { impl Payer for TestPayer { fn send_payment( &self, _payment_hash: PaymentHash, _recipient_onion: RecipientOnionFields, - _payment_id: PaymentId, route_params: RouteParameters, _retry_strategy: Retry + _payment_id: PaymentId, route_params: RouteParameters, _retry_strategy: Retry, ) -> Result<(), PaymentError> { self.check_value_msats(Amount(route_params.final_value_msat)); Ok(()) @@ -309,8 +340,7 @@ mod tests { fn duration_since_epoch() -> Duration { #[cfg(feature = "std")] - let duration_since_epoch = - SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap(); + let duration_since_epoch = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap(); #[cfg(not(feature = "std"))] let duration_since_epoch = Duration::from_secs(1234567); duration_since_epoch @@ -327,9 +357,7 @@ mod tests { .duration_since_epoch(duration_since_epoch()) .min_final_cltv_expiry_delta(144) .amount_milli_satoshis(128) - .build_signed(|hash| { - Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key) - }) + .build_signed(|hash| Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key)) .unwrap() } @@ -343,10 +371,8 @@ mod tests { .payment_secret(PaymentSecret([0; 32])) .duration_since_epoch(duration_since_epoch()) .min_final_cltv_expiry_delta(144) - .build_signed(|hash| { - Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key) - }) - .unwrap() + .build_signed(|hash| Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key)) + .unwrap() } #[test] @@ -357,7 +383,14 @@ mod tests { let final_value_msat = invoice.amount_milli_satoshis().unwrap(); let payer = TestPayer::new().expect_send(Amount(final_value_msat)); - pay_invoice_using_amount(&invoice, final_value_msat, payment_id, Retry::Attempts(0), &payer).unwrap(); + pay_invoice_using_amount( + &invoice, + final_value_msat, + payment_id, + Retry::Attempts(0), + &payer, + ) + .unwrap(); } #[test] @@ -368,7 +401,8 @@ mod tests { let amt_msat = 10_000; let payer = TestPayer::new().expect_send(Amount(amt_msat)); - pay_invoice_using_amount(&invoice, amt_msat, payment_id, Retry::Attempts(0), &payer).unwrap(); + pay_invoice_using_amount(&invoice, amt_msat, payment_id, Retry::Attempts(0), &payer) + .unwrap(); } #[test] @@ -384,7 +418,7 @@ mod tests { match pay_zero_value_invoice(&invoice, amt_msat, Retry::Attempts(0), nodes[0].node) { Err(PaymentError::Invoice("amount unexpected")) => {}, - _ => panic!() + _ => panic!(), } } @@ -413,8 +447,10 @@ mod tests { .amount_milli_satoshis(50_000) .payment_metadata(payment_metadata.clone()) .build_signed(|hash| { - Secp256k1::new().sign_ecdsa_recoverable(hash, - &nodes[1].keys_manager.backing.get_node_secret_key()) + Secp256k1::new().sign_ecdsa_recoverable( + hash, + &nodes[1].keys_manager.backing.get_node_secret_key(), + ) }) .unwrap(); @@ -432,7 +468,7 @@ mod tests { Event::PaymentClaimable { onion_fields, .. } => { assert_eq!(Some(payment_metadata), onion_fields.unwrap().payment_metadata); }, - _ => panic!("Unexpected event") + _ => panic!("Unexpected event"), } } } diff --git a/lightning-invoice/src/ser.rs b/lightning-invoice/src/ser.rs index dc5dba45da0..503f993e360 100644 --- a/lightning-invoice/src/ser.rs +++ b/lightning-invoice/src/ser.rs @@ -1,10 +1,13 @@ +use crate::prelude::*; +use bech32::{u5, Base32Len, ToBase32, WriteBase32}; use core::fmt; use core::fmt::{Display, Formatter}; -use bech32::{ToBase32, u5, WriteBase32, Base32Len}; -use crate::prelude::*; -use super::{Bolt11Invoice, Sha256, TaggedField, ExpiryTime, MinFinalCltvExpiryDelta, Fallback, PayeePubKey, Bolt11InvoiceSignature, PositiveTimestamp, - PrivateRoute, Description, RawTaggedField, Currency, RawHrp, SiPrefix, constants, SignedRawBolt11Invoice, RawDataPart}; +use super::{ + constants, Bolt11Invoice, Bolt11InvoiceSignature, Currency, Description, ExpiryTime, Fallback, + MinFinalCltvExpiryDelta, PayeePubKey, PositiveTimestamp, PrivateRoute, RawDataPart, RawHrp, + RawTaggedField, Sha256, SiPrefix, SignedRawBolt11Invoice, TaggedField, +}; /// Converts a stream of bytes written to it to base32. On finalization the according padding will /// be applied. That means the results of writing two data blocks with one or two `BytesToBase32` @@ -24,11 +27,7 @@ impl<'a, W: WriteBase32> BytesToBase32<'a, W> { /// Create a new bytes-to-base32 converter with `writer` as a sink for the resulting base32 /// data. pub fn new(writer: &'a mut W) -> BytesToBase32<'a, W> { - BytesToBase32 { - writer, - buffer: 0, - buffer_bits: 0, - } + BytesToBase32 { writer, buffer: 0, buffer_bits: 0 } } /// Add more bytes to the current conversion unit @@ -44,9 +43,7 @@ impl<'a, W: WriteBase32> BytesToBase32<'a, W> { // buffer holds too many bits, so we don't have to combine buffer bits with new bits // from this rounds byte. if self.buffer_bits >= 5 { - self.writer.write_u5( - u5::try_from_u8((self.buffer & 0b11111000) >> 3 ).expect("<32") - )?; + self.writer.write_u5(u5::try_from_u8((self.buffer & 0b11111000) >> 3).expect("<32"))?; self.buffer <<= 5; self.buffer_bits -= 5; } @@ -63,18 +60,16 @@ impl<'a, W: WriteBase32> BytesToBase32<'a, W> { Ok(()) } - pub fn finalize(mut self) -> Result<(), W::Err> { + pub fn finalize(mut self) -> Result<(), W::Err> { self.inner_finalize()?; core::mem::forget(self); Ok(()) } - fn inner_finalize(&mut self) -> Result<(), W::Err>{ + fn inner_finalize(&mut self) -> Result<(), W::Err> { // There can be at most two u5s left in the buffer after processing all bytes, write them. if self.buffer_bits >= 5 { - self.writer.write_u5( - u5::try_from_u8((self.buffer & 0b11111000) >> 3).expect("<32") - )?; + self.writer.write_u5(u5::try_from_u8((self.buffer & 0b11111000) >> 3).expect("<32"))?; self.buffer <<= 5; self.buffer_bits -= 5; } @@ -115,7 +110,7 @@ impl Display for Bolt11Invoice { impl Display for SignedRawBolt11Invoice { fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> { let hrp = self.raw_invoice.hrp.to_string(); - let mut data = self.raw_invoice.data.to_base32(); + let mut data = self.raw_invoice.data.to_base32(); data.extend_from_slice(&self.signature.to_base32()); bech32::encode_to_fmt(f, &hrp, data, bech32::Variant::Bech32).expect("HRP is valid")?; @@ -137,13 +132,7 @@ impl Display for RawHrp { None => String::new(), }; - write!( - f, - "ln{}{}{}", - self.currency, - amount, - si_prefix - ) + write!(f, "ln{}{}{}", self.currency, amount, si_prefix) } } @@ -162,14 +151,12 @@ impl Display for Currency { impl Display for SiPrefix { fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> { - write!(f, "{}", - match *self { - SiPrefix::Milli => "m", - SiPrefix::Micro => "u", - SiPrefix::Nano => "n", - SiPrefix::Pico => "p", - } - ) + write!(f, "{}", match *self { + SiPrefix::Milli => "m", + SiPrefix::Micro => "u", + SiPrefix::Nano => "n", + SiPrefix::Pico => "p", + }) } } @@ -215,7 +202,8 @@ fn encode_int_be_base256>(int: T) -> Vec { /// Appends the default value of `T` to the front of the `in_vec` till it reaches the length /// `target_length`. If `in_vec` already is too lang `None` is returned. fn try_stretch(mut in_vec: Vec, target_len: usize) -> Option> - where T: Default + Copy +where + T: Default + Copy, { if in_vec.len() > target_len { None @@ -248,7 +236,7 @@ impl ToBase32 for PositiveTimestamp { // FIXME: use writer for int encoding writer.write( &try_stretch(encode_int_be_base32(self.as_unix_timestamp()), 7) - .expect("Can't be longer due than 7 u5s due to timestamp bounds") + .expect("Can't be longer due than 7 u5s due to timestamp bounds"), ) } } @@ -256,12 +244,8 @@ impl ToBase32 for PositiveTimestamp { impl ToBase32 for RawTaggedField { fn write_base32(&self, writer: &mut W) -> Result<(), ::Err> { match *self { - RawTaggedField::UnknownSemantics(ref content) => { - writer.write(content) - }, - RawTaggedField::KnownSemantics(ref tagged_field) => { - tagged_field.write_base32(writer) - } + RawTaggedField::UnknownSemantics(ref content) => writer.write(content), + RawTaggedField::KnownSemantics(ref tagged_field) => tagged_field.write_base32(writer), } } } @@ -328,7 +312,7 @@ impl Base32Len for MinFinalCltvExpiryDelta { impl ToBase32 for Fallback { fn write_base32(&self, writer: &mut W) -> Result<(), ::Err> { match *self { - Fallback::SegWitProgram {version: v, program: ref p} => { + Fallback::SegWitProgram { version: v, program: ref p } => { writer.write_u5(Into::::into(v))?; p.write_base32(writer) }, @@ -339,7 +323,7 @@ impl ToBase32 for Fallback { Fallback::ScriptHash(ref hash) => { writer.write_u5(u5::try_from_u8(18).expect("18 < 32"))?; (&hash[..]).write_base32(writer) - } + }, } } } @@ -347,12 +331,10 @@ impl ToBase32 for Fallback { impl Base32Len for Fallback { fn base32_len(&self) -> usize { match *self { - Fallback::SegWitProgram {program: ref p, ..} => { + Fallback::SegWitProgram { program: ref p, .. } => { bytes_size_to_base32_size(p.len()) + 1 }, - Fallback::PubKeyHash(_) | Fallback::ScriptHash(_) => { - 33 - }, + Fallback::PubKeyHash(_) | Fallback::ScriptHash(_) => 33, } } } @@ -363,28 +345,21 @@ impl ToBase32 for PrivateRoute { for hop in (self.0).0.iter() { converter.append(&hop.src_node_id.serialize()[..])?; - let short_channel_id = try_stretch( - encode_int_be_base256(hop.short_channel_id), - 8 - ).expect("sizeof(u64) == 8"); + let short_channel_id = try_stretch(encode_int_be_base256(hop.short_channel_id), 8) + .expect("sizeof(u64) == 8"); converter.append(&short_channel_id)?; - let fee_base_msat = try_stretch( - encode_int_be_base256(hop.fees.base_msat), - 4 - ).expect("sizeof(u32) == 4"); + let fee_base_msat = try_stretch(encode_int_be_base256(hop.fees.base_msat), 4) + .expect("sizeof(u32) == 4"); converter.append(&fee_base_msat)?; - let fee_proportional_millionths = try_stretch( - encode_int_be_base256(hop.fees.proportional_millionths), - 4 - ).expect("sizeof(u32) == 4"); + let fee_proportional_millionths = + try_stretch(encode_int_be_base256(hop.fees.proportional_millionths), 4) + .expect("sizeof(u32) == 4"); converter.append(&fee_proportional_millionths)?; - let cltv_expiry_delta = try_stretch( - encode_int_be_base256(hop.cltv_expiry_delta), - 2 - ).expect("sizeof(u16) == 2"); + let cltv_expiry_delta = try_stretch(encode_int_be_base256(hop.cltv_expiry_delta), 2) + .expect("sizeof(u16) == 2"); converter.append(&cltv_expiry_delta)?; } @@ -404,17 +379,18 @@ impl ToBase32 for TaggedField { /// Writes a tagged field: tag, length and data. `tag` should be in `0..32` otherwise the /// function will panic. fn write_tagged_field(writer: &mut W, tag: u8, payload: &P) -> Result<(), W::Err> - where W: WriteBase32, - P: ToBase32 + Base32Len, + where + W: WriteBase32, + P: ToBase32 + Base32Len, { let len = payload.base32_len(); assert!(len < 1024, "Every tagged field data can be at most 1023 bytes long."); writer.write_u5(u5::try_from_u8(tag).expect("invalid tag, not in 0..32"))?; - writer.write(&try_stretch( - encode_int_be_base32(len as u64), - 2 - ).expect("Can't be longer than 2, see assert above."))?; + writer.write( + &try_stretch(encode_int_be_base32(len as u64), 2) + .expect("Can't be longer than 2, see assert above."), + )?; payload.write_base32(writer) } @@ -444,10 +420,10 @@ impl ToBase32 for TaggedField { write_tagged_field(writer, constants::TAG_PRIVATE_ROUTE, route_hops) }, TaggedField::PaymentSecret(ref payment_secret) => { - write_tagged_field(writer, constants::TAG_PAYMENT_SECRET, payment_secret) + write_tagged_field(writer, constants::TAG_PAYMENT_SECRET, payment_secret) }, TaggedField::PaymentMetadata(ref payment_metadata) => { - write_tagged_field(writer, constants::TAG_PAYMENT_METADATA, payment_metadata) + write_tagged_field(writer, constants::TAG_PAYMENT_METADATA, payment_metadata) }, TaggedField::Features(ref features) => { write_tagged_field(writer, constants::TAG_FEATURES, features) diff --git a/lightning-invoice/src/sync.rs b/lightning-invoice/src/sync.rs index fae923feb65..1da755f8af7 100644 --- a/lightning-invoice/src/sync.rs +++ b/lightning-invoice/src/sync.rs @@ -4,7 +4,7 @@ use core::ops::{Deref, DerefMut}; pub type LockResult = Result; pub struct Mutex { - inner: RefCell + inner: RefCell, } #[must_use = "if unused the Mutex will immediately unlock"] diff --git a/lightning-invoice/src/tb.rs b/lightning-invoice/src/tb.rs index dde8a53f99c..9f8a5135906 100644 --- a/lightning-invoice/src/tb.rs +++ b/lightning-invoice/src/tb.rs @@ -7,4 +7,4 @@ pub struct True {} pub struct False {} impl Bool for True {} -impl Bool for False {} \ No newline at end of file +impl Bool for False {} diff --git a/lightning-invoice/src/utils.rs b/lightning-invoice/src/utils.rs index a512b2de05d..91099e1ec46 100644 --- a/lightning-invoice/src/utils.rs +++ b/lightning-invoice/src/utils.rs @@ -2,23 +2,23 @@ use crate::{Bolt11Invoice, CreationError, Currency, InvoiceBuilder, SignOrCreationError}; -use crate::{prelude::*, Description, Bolt11InvoiceDescription, Sha256}; +use crate::{prelude::*, Bolt11InvoiceDescription, Description, Sha256}; use bech32::ToBase32; use bitcoin_hashes::Hash; +use core::iter::Iterator; +use core::ops::Deref; +use core::time::Duration; use lightning::chain; use lightning::chain::chaininterface::{BroadcasterInterface, FeeEstimator}; -use lightning::sign::{Recipient, NodeSigner, SignerProvider, EntropySource}; -use lightning::ln::{PaymentHash, PaymentSecret}; use lightning::ln::channelmanager::{ChannelDetails, ChannelManager, MIN_FINAL_CLTV_EXPIRY_DELTA}; use lightning::ln::channelmanager::{PhantomRouteHints, MIN_CLTV_EXPIRY_DELTA}; use lightning::ln::inbound_payment::{create, create_from_hash, ExpandedKey}; +use lightning::ln::{PaymentHash, PaymentSecret}; use lightning::routing::gossip::RoutingFees; use lightning::routing::router::{RouteHint, RouteHintHop, Router}; +use lightning::sign::{EntropySource, NodeSigner, Recipient, SignerProvider}; use lightning::util::logger::Logger; use secp256k1::PublicKey; -use core::ops::Deref; -use core::time::Duration; -use core::iter::Iterator; /// Utility to create an invoice that can be paid to one of multiple nodes, or a "phantom invoice." /// See [`PhantomKeysManager`] for more information on phantom node payments. @@ -62,8 +62,9 @@ use core::iter::Iterator; /// available and the current time is supplied by the caller. pub fn create_phantom_invoice( amt_msat: Option, payment_hash: Option, description: String, - invoice_expiry_delta_secs: u32, phantom_route_hints: Vec, entropy_source: ES, - node_signer: NS, logger: L, network: Currency, min_final_cltv_expiry_delta: Option, duration_since_epoch: Duration, + invoice_expiry_delta_secs: u32, phantom_route_hints: Vec, + entropy_source: ES, node_signer: NS, logger: L, network: Currency, + min_final_cltv_expiry_delta: Option, duration_since_epoch: Duration, ) -> Result> where ES::Target: EntropySource, @@ -71,10 +72,19 @@ where L::Target: Logger, { let description = Description::new(description).map_err(SignOrCreationError::CreationError)?; - let description = Bolt11InvoiceDescription::Direct(&description,); + let description = Bolt11InvoiceDescription::Direct(&description); _create_phantom_invoice::( - amt_msat, payment_hash, description, invoice_expiry_delta_secs, phantom_route_hints, - entropy_source, node_signer, logger, network, min_final_cltv_expiry_delta, duration_since_epoch, + amt_msat, + payment_hash, + description, + invoice_expiry_delta_secs, + phantom_route_hints, + entropy_source, + node_signer, + logger, + network, + min_final_cltv_expiry_delta, + duration_since_epoch, ) } @@ -119,7 +129,8 @@ where pub fn create_phantom_invoice_with_description_hash( amt_msat: Option, payment_hash: Option, invoice_expiry_delta_secs: u32, description_hash: Sha256, phantom_route_hints: Vec, entropy_source: ES, - node_signer: NS, logger: L, network: Currency, min_final_cltv_expiry_delta: Option, duration_since_epoch: Duration, + node_signer: NS, logger: L, network: Currency, min_final_cltv_expiry_delta: Option, + duration_since_epoch: Duration, ) -> Result> where ES::Target: EntropySource, @@ -127,40 +138,52 @@ where L::Target: Logger, { _create_phantom_invoice::( - amt_msat, payment_hash, Bolt11InvoiceDescription::Hash(&description_hash), - invoice_expiry_delta_secs, phantom_route_hints, entropy_source, node_signer, logger, network, - min_final_cltv_expiry_delta, duration_since_epoch, + amt_msat, + payment_hash, + Bolt11InvoiceDescription::Hash(&description_hash), + invoice_expiry_delta_secs, + phantom_route_hints, + entropy_source, + node_signer, + logger, + network, + min_final_cltv_expiry_delta, + duration_since_epoch, ) } const MAX_CHANNEL_HINTS: usize = 3; fn _create_phantom_invoice( - amt_msat: Option, payment_hash: Option, description: Bolt11InvoiceDescription, - invoice_expiry_delta_secs: u32, phantom_route_hints: Vec, entropy_source: ES, - node_signer: NS, logger: L, network: Currency, min_final_cltv_expiry_delta: Option, duration_since_epoch: Duration, + amt_msat: Option, payment_hash: Option, + description: Bolt11InvoiceDescription, invoice_expiry_delta_secs: u32, + phantom_route_hints: Vec, entropy_source: ES, node_signer: NS, logger: L, + network: Currency, min_final_cltv_expiry_delta: Option, duration_since_epoch: Duration, ) -> Result> where ES::Target: EntropySource, NS::Target: NodeSigner, L::Target: Logger, { - if phantom_route_hints.is_empty() { - return Err(SignOrCreationError::CreationError( - CreationError::MissingRouteHints, - )); + return Err(SignOrCreationError::CreationError(CreationError::MissingRouteHints)); } - if min_final_cltv_expiry_delta.is_some() && min_final_cltv_expiry_delta.unwrap().saturating_add(3) < MIN_FINAL_CLTV_EXPIRY_DELTA { - return Err(SignOrCreationError::CreationError(CreationError::MinFinalCltvExpiryDeltaTooShort)); + if min_final_cltv_expiry_delta.is_some() + && min_final_cltv_expiry_delta.unwrap().saturating_add(3) < MIN_FINAL_CLTV_EXPIRY_DELTA + { + return Err(SignOrCreationError::CreationError( + CreationError::MinFinalCltvExpiryDeltaTooShort, + )); } let invoice = match description { Bolt11InvoiceDescription::Direct(description) => { InvoiceBuilder::new(network).description(description.0.clone()) - } - Bolt11InvoiceDescription::Hash(hash) => InvoiceBuilder::new(network).description_hash(hash.0), + }, + Bolt11InvoiceDescription::Hash(hash) => { + InvoiceBuilder::new(network).description_hash(hash.0) + }, }; // If we ever see performance here being too slow then we should probably take this ExpandedKey as a parameter instead. @@ -171,8 +194,7 @@ where amt_msat, payment_hash, invoice_expiry_delta_secs, - duration_since_epoch - .as_secs(), + duration_since_epoch.as_secs(), min_final_cltv_expiry_delta, ) .map_err(|_| SignOrCreationError::CreationError(CreationError::InvalidAmount))?; @@ -183,15 +205,18 @@ where amt_msat, invoice_expiry_delta_secs, &entropy_source, - duration_since_epoch - .as_secs(), + duration_since_epoch.as_secs(), min_final_cltv_expiry_delta, ) .map_err(|_| SignOrCreationError::CreationError(CreationError::InvalidAmount))? }; - log_trace!(logger, "Creating phantom invoice from {} participating nodes with payment hash {}", - phantom_route_hints.len(), &payment_hash); + log_trace!( + logger, + "Creating phantom invoice from {} participating nodes with payment hash {}", + phantom_route_hints.len(), + &payment_hash + ); let mut invoice = invoice .duration_since_epoch(duration_since_epoch) @@ -199,28 +224,35 @@ where .payment_secret(payment_secret) .min_final_cltv_expiry_delta( // Add a buffer of 3 to the delta if present, otherwise use LDK's minimum. - min_final_cltv_expiry_delta.map(|x| x.saturating_add(3)).unwrap_or(MIN_FINAL_CLTV_EXPIRY_DELTA).into()) + min_final_cltv_expiry_delta + .map(|x| x.saturating_add(3)) + .unwrap_or(MIN_FINAL_CLTV_EXPIRY_DELTA) + .into(), + ) .expiry_time(Duration::from_secs(invoice_expiry_delta_secs.into())); if let Some(amt) = amt_msat { invoice = invoice.amount_milli_satoshis(amt); } - - for route_hint in select_phantom_hints(amt_msat, phantom_route_hints, logger).take(MAX_CHANNEL_HINTS) { + for route_hint in + select_phantom_hints(amt_msat, phantom_route_hints, logger).take(MAX_CHANNEL_HINTS) + { invoice = invoice.private_route(route_hint); } let raw_invoice = match invoice.build_raw() { Ok(inv) => inv, - Err(e) => return Err(SignOrCreationError::CreationError(e)) + Err(e) => return Err(SignOrCreationError::CreationError(e)), }; let hrp_str = raw_invoice.hrp.to_string(); let hrp_bytes = hrp_str.as_bytes(); let data_without_signature = raw_invoice.data.to_base32(); - let signed_raw_invoice = raw_invoice.sign(|_| node_signer.sign_invoice(hrp_bytes, &data_without_signature, Recipient::PhantomNode)); + let signed_raw_invoice = raw_invoice.sign(|_| { + node_signer.sign_invoice(hrp_bytes, &data_without_signature, Recipient::PhantomNode) + }); match signed_raw_invoice { Ok(inv) => Ok(Bolt11Invoice::from_signed(inv).unwrap()), - Err(e) => Err(SignOrCreationError::SignError(e)) + Err(e) => Err(SignOrCreationError::SignError(e)), } } @@ -232,16 +264,20 @@ where /// * Select one hint from each node, up to three hints or until we run out of hints. /// /// [`PhantomKeysManager`]: lightning::sign::PhantomKeysManager -fn select_phantom_hints(amt_msat: Option, phantom_route_hints: Vec, - logger: L) -> impl Iterator +fn select_phantom_hints( + amt_msat: Option, phantom_route_hints: Vec, logger: L, +) -> impl Iterator where L::Target: Logger, { let mut phantom_hints: Vec<_> = Vec::new(); for PhantomRouteHints { channels, phantom_scid, real_node_pubkey } in phantom_route_hints { - log_trace!(logger, "Generating phantom route hints for node {}", - log_pubkey!(real_node_pubkey)); + log_trace!( + logger, + "Generating phantom route hints for node {}", + log_pubkey!(real_node_pubkey) + ); let route_hints = sort_and_filter_channels(channels, amt_msat, &logger); // If we have any public channel, the route hints from `sort_and_filter_channels` will be @@ -265,10 +301,7 @@ where hint.0.push(RouteHintHop { src_node_id: real_node_pubkey, short_channel_id: phantom_scid, - fees: RoutingFees { - base_msat: 0, - proportional_millionths: 0, - }, + fees: RoutingFees { base_msat: 0, proportional_millionths: 0 }, cltv_expiry_delta: MIN_CLTV_EXPIRY_DELTA, htlc_minimum_msat: None, htlc_maximum_msat: None, @@ -329,7 +362,16 @@ fn rotate_through_iterators>(mut vecs: Vec) -> impl /// confirmations during routing. /// /// [`MIN_FINAL_CLTV_EXPIRY_DETLA`]: lightning::ln::channelmanager::MIN_FINAL_CLTV_EXPIRY_DELTA -pub fn create_invoice_from_channelmanager( +pub fn create_invoice_from_channelmanager< + M: Deref, + T: Deref, + ES: Deref, + NS: Deref, + SP: Deref, + F: Deref, + R: Deref, + L: Deref, +>( channelmanager: &ChannelManager, node_signer: NS, logger: L, network: Currency, amt_msat: Option, description: String, invoice_expiry_delta_secs: u32, min_final_cltv_expiry_delta: Option, @@ -345,11 +387,19 @@ where L::Target: Logger, { use std::time::SystemTime; - let duration = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) + let duration = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) .expect("for the foreseeable future this shouldn't happen"); create_invoice_from_channelmanager_and_duration_since_epoch( - channelmanager, node_signer, logger, network, amt_msat, - description, duration, invoice_expiry_delta_secs, min_final_cltv_expiry_delta, + channelmanager, + node_signer, + logger, + network, + amt_msat, + description, + duration, + invoice_expiry_delta_secs, + min_final_cltv_expiry_delta, ) } @@ -370,7 +420,16 @@ where /// confirmations during routing. /// /// [`MIN_FINAL_CLTV_EXPIRY_DETLA`]: lightning::ln::channelmanager::MIN_FINAL_CLTV_EXPIRY_DELTA -pub fn create_invoice_from_channelmanager_with_description_hash( +pub fn create_invoice_from_channelmanager_with_description_hash< + M: Deref, + T: Deref, + ES: Deref, + NS: Deref, + SP: Deref, + F: Deref, + R: Deref, + L: Deref, +>( channelmanager: &ChannelManager, node_signer: NS, logger: L, network: Currency, amt_msat: Option, description_hash: Sha256, invoice_expiry_delta_secs: u32, min_final_cltv_expiry_delta: Option, @@ -392,80 +451,132 @@ where .expect("for the foreseeable future this shouldn't happen"); create_invoice_from_channelmanager_with_description_hash_and_duration_since_epoch( - channelmanager, node_signer, logger, network, amt_msat, - description_hash, duration, invoice_expiry_delta_secs, min_final_cltv_expiry_delta, + channelmanager, + node_signer, + logger, + network, + amt_msat, + description_hash, + duration, + invoice_expiry_delta_secs, + min_final_cltv_expiry_delta, ) } /// See [`create_invoice_from_channelmanager_with_description_hash`] /// This version can be used in a `no_std` environment, where [`std::time::SystemTime`] is not /// available and the current time is supplied by the caller. -pub fn create_invoice_from_channelmanager_with_description_hash_and_duration_since_epoch( +pub fn create_invoice_from_channelmanager_with_description_hash_and_duration_since_epoch< + M: Deref, + T: Deref, + ES: Deref, + NS: Deref, + SP: Deref, + F: Deref, + R: Deref, + L: Deref, +>( channelmanager: &ChannelManager, node_signer: NS, logger: L, network: Currency, amt_msat: Option, description_hash: Sha256, - duration_since_epoch: Duration, invoice_expiry_delta_secs: u32, min_final_cltv_expiry_delta: Option, + duration_since_epoch: Duration, invoice_expiry_delta_secs: u32, + min_final_cltv_expiry_delta: Option, ) -> Result> - where - M::Target: chain::Watch<::Signer>, - T::Target: BroadcasterInterface, - ES::Target: EntropySource, - NS::Target: NodeSigner, - SP::Target: SignerProvider, - F::Target: FeeEstimator, - R::Target: Router, - L::Target: Logger, +where + M::Target: chain::Watch<::Signer>, + T::Target: BroadcasterInterface, + ES::Target: EntropySource, + NS::Target: NodeSigner, + SP::Target: SignerProvider, + F::Target: FeeEstimator, + R::Target: Router, + L::Target: Logger, { _create_invoice_from_channelmanager_and_duration_since_epoch( - channelmanager, node_signer, logger, network, amt_msat, + channelmanager, + node_signer, + logger, + network, + amt_msat, Bolt11InvoiceDescription::Hash(&description_hash), - duration_since_epoch, invoice_expiry_delta_secs, min_final_cltv_expiry_delta, + duration_since_epoch, + invoice_expiry_delta_secs, + min_final_cltv_expiry_delta, ) } /// See [`create_invoice_from_channelmanager`] /// This version can be used in a `no_std` environment, where [`std::time::SystemTime`] is not /// available and the current time is supplied by the caller. -pub fn create_invoice_from_channelmanager_and_duration_since_epoch( +pub fn create_invoice_from_channelmanager_and_duration_since_epoch< + M: Deref, + T: Deref, + ES: Deref, + NS: Deref, + SP: Deref, + F: Deref, + R: Deref, + L: Deref, +>( channelmanager: &ChannelManager, node_signer: NS, logger: L, network: Currency, amt_msat: Option, description: String, duration_since_epoch: Duration, invoice_expiry_delta_secs: u32, min_final_cltv_expiry_delta: Option, ) -> Result> - where - M::Target: chain::Watch<::Signer>, - T::Target: BroadcasterInterface, - ES::Target: EntropySource, - NS::Target: NodeSigner, - SP::Target: SignerProvider, - F::Target: FeeEstimator, - R::Target: Router, - L::Target: Logger, +where + M::Target: chain::Watch<::Signer>, + T::Target: BroadcasterInterface, + ES::Target: EntropySource, + NS::Target: NodeSigner, + SP::Target: SignerProvider, + F::Target: FeeEstimator, + R::Target: Router, + L::Target: Logger, { _create_invoice_from_channelmanager_and_duration_since_epoch( - channelmanager, node_signer, logger, network, amt_msat, + channelmanager, + node_signer, + logger, + network, + amt_msat, Bolt11InvoiceDescription::Direct( &Description::new(description).map_err(SignOrCreationError::CreationError)?, ), - duration_since_epoch, invoice_expiry_delta_secs, min_final_cltv_expiry_delta, + duration_since_epoch, + invoice_expiry_delta_secs, + min_final_cltv_expiry_delta, ) } -fn _create_invoice_from_channelmanager_and_duration_since_epoch( +fn _create_invoice_from_channelmanager_and_duration_since_epoch< + M: Deref, + T: Deref, + ES: Deref, + NS: Deref, + SP: Deref, + F: Deref, + R: Deref, + L: Deref, +>( channelmanager: &ChannelManager, node_signer: NS, logger: L, network: Currency, amt_msat: Option, description: Bolt11InvoiceDescription, - duration_since_epoch: Duration, invoice_expiry_delta_secs: u32, min_final_cltv_expiry_delta: Option, + duration_since_epoch: Duration, invoice_expiry_delta_secs: u32, + min_final_cltv_expiry_delta: Option, ) -> Result> - where - M::Target: chain::Watch<::Signer>, - T::Target: BroadcasterInterface, - ES::Target: EntropySource, - NS::Target: NodeSigner, - SP::Target: SignerProvider, - F::Target: FeeEstimator, - R::Target: Router, - L::Target: Logger, +where + M::Target: chain::Watch<::Signer>, + T::Target: BroadcasterInterface, + ES::Target: EntropySource, + NS::Target: NodeSigner, + SP::Target: SignerProvider, + F::Target: FeeEstimator, + R::Target: Router, + L::Target: Logger, { - if min_final_cltv_expiry_delta.is_some() && min_final_cltv_expiry_delta.unwrap().saturating_add(3) < MIN_FINAL_CLTV_EXPIRY_DELTA { - return Err(SignOrCreationError::CreationError(CreationError::MinFinalCltvExpiryDeltaTooShort)); + if min_final_cltv_expiry_delta.is_some() + && min_final_cltv_expiry_delta.unwrap().saturating_add(3) < MIN_FINAL_CLTV_EXPIRY_DELTA + { + return Err(SignOrCreationError::CreationError( + CreationError::MinFinalCltvExpiryDeltaTooShort, + )); } // `create_inbound_payment` only returns an error if the amount is greater than the total bitcoin @@ -474,64 +585,108 @@ fn _create_invoice_from_channelmanager_and_duration_since_epoch( +pub fn create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_hash< + M: Deref, + T: Deref, + ES: Deref, + NS: Deref, + SP: Deref, + F: Deref, + R: Deref, + L: Deref, +>( channelmanager: &ChannelManager, node_signer: NS, logger: L, network: Currency, amt_msat: Option, description: String, duration_since_epoch: Duration, - invoice_expiry_delta_secs: u32, payment_hash: PaymentHash, min_final_cltv_expiry_delta: Option, + invoice_expiry_delta_secs: u32, payment_hash: PaymentHash, + min_final_cltv_expiry_delta: Option, ) -> Result> - where - M::Target: chain::Watch<::Signer>, - T::Target: BroadcasterInterface, - ES::Target: EntropySource, - NS::Target: NodeSigner, - SP::Target: SignerProvider, - F::Target: FeeEstimator, - R::Target: Router, - L::Target: Logger, +where + M::Target: chain::Watch<::Signer>, + T::Target: BroadcasterInterface, + ES::Target: EntropySource, + NS::Target: NodeSigner, + SP::Target: SignerProvider, + F::Target: FeeEstimator, + R::Target: Router, + L::Target: Logger, { let payment_secret = channelmanager - .create_inbound_payment_for_hash(payment_hash, amt_msat, invoice_expiry_delta_secs, - min_final_cltv_expiry_delta) + .create_inbound_payment_for_hash( + payment_hash, + amt_msat, + invoice_expiry_delta_secs, + min_final_cltv_expiry_delta, + ) .map_err(|()| SignOrCreationError::CreationError(CreationError::InvalidAmount))?; _create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_hash( - channelmanager, node_signer, logger, network, amt_msat, + channelmanager, + node_signer, + logger, + network, + amt_msat, Bolt11InvoiceDescription::Direct( &Description::new(description).map_err(SignOrCreationError::CreationError)?, ), - duration_since_epoch, invoice_expiry_delta_secs, payment_hash, payment_secret, + duration_since_epoch, + invoice_expiry_delta_secs, + payment_hash, + payment_secret, min_final_cltv_expiry_delta, ) } -fn _create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_hash( +fn _create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_hash< + M: Deref, + T: Deref, + ES: Deref, + NS: Deref, + SP: Deref, + F: Deref, + R: Deref, + L: Deref, +>( channelmanager: &ChannelManager, node_signer: NS, logger: L, network: Currency, amt_msat: Option, description: Bolt11InvoiceDescription, duration_since_epoch: Duration, invoice_expiry_delta_secs: u32, payment_hash: PaymentHash, payment_secret: PaymentSecret, min_final_cltv_expiry_delta: Option, ) -> Result> - where - M::Target: chain::Watch<::Signer>, - T::Target: BroadcasterInterface, - ES::Target: EntropySource, - NS::Target: NodeSigner, - SP::Target: SignerProvider, - F::Target: FeeEstimator, - R::Target: Router, - L::Target: Logger, +where + M::Target: chain::Watch<::Signer>, + T::Target: BroadcasterInterface, + ES::Target: EntropySource, + NS::Target: NodeSigner, + SP::Target: SignerProvider, + F::Target: FeeEstimator, + R::Target: Router, + L::Target: Logger, { let our_node_pubkey = channelmanager.get_our_node_id(); let channels = channelmanager.list_channels(); - if min_final_cltv_expiry_delta.is_some() && min_final_cltv_expiry_delta.unwrap().saturating_add(3) < MIN_FINAL_CLTV_EXPIRY_DELTA { - return Err(SignOrCreationError::CreationError(CreationError::MinFinalCltvExpiryDeltaTooShort)); + if min_final_cltv_expiry_delta.is_some() + && min_final_cltv_expiry_delta.unwrap().saturating_add(3) < MIN_FINAL_CLTV_EXPIRY_DELTA + { + return Err(SignOrCreationError::CreationError( + CreationError::MinFinalCltvExpiryDeltaTooShort, + )); } log_trace!(logger, "Creating invoice with payment hash {}", &payment_hash); @@ -539,8 +694,10 @@ fn _create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_has let invoice = match description { Bolt11InvoiceDescription::Direct(description) => { InvoiceBuilder::new(network).description(description.0.clone()) - } - Bolt11InvoiceDescription::Hash(hash) => InvoiceBuilder::new(network).description_hash(hash.0), + }, + Bolt11InvoiceDescription::Hash(hash) => { + InvoiceBuilder::new(network).description_hash(hash.0) + }, }; let mut invoice = invoice @@ -551,7 +708,11 @@ fn _create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_has .basic_mpp() .min_final_cltv_expiry_delta( // Add a buffer of 3 to the delta if present, otherwise use LDK's minimum. - min_final_cltv_expiry_delta.map(|x| x.saturating_add(3)).unwrap_or(MIN_FINAL_CLTV_EXPIRY_DELTA).into()) + min_final_cltv_expiry_delta + .map(|x| x.saturating_add(3)) + .unwrap_or(MIN_FINAL_CLTV_EXPIRY_DELTA) + .into(), + ) .expiry_time(Duration::from_secs(invoice_expiry_delta_secs.into())); if let Some(amt) = amt_msat { invoice = invoice.amount_milli_satoshis(amt); @@ -564,15 +725,16 @@ fn _create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_has let raw_invoice = match invoice.build_raw() { Ok(inv) => inv, - Err(e) => return Err(SignOrCreationError::CreationError(e)) + Err(e) => return Err(SignOrCreationError::CreationError(e)), }; let hrp_str = raw_invoice.hrp.to_string(); let hrp_bytes = hrp_str.as_bytes(); let data_without_signature = raw_invoice.data.to_base32(); - let signed_raw_invoice = raw_invoice.sign(|_| node_signer.sign_invoice(hrp_bytes, &data_without_signature, Recipient::Node)); + let signed_raw_invoice = raw_invoice + .sign(|_| node_signer.sign_invoice(hrp_bytes, &data_without_signature, Recipient::Node)); match signed_raw_invoice { Ok(inv) => Ok(Bolt11Invoice::from_signed(inv).unwrap()), - Err(e) => Err(SignOrCreationError::SignError(e)) + Err(e) => Err(SignOrCreationError::SignError(e)), } } @@ -596,9 +758,7 @@ fn _create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_has /// * Sorted by lowest inbound capacity if an online channel with the minimum amount requested exists, /// otherwise sort by highest inbound capacity to give the payment the best chance of succeeding. fn sort_and_filter_channels( - channels: Vec, - min_inbound_capacity_msat: Option, - logger: &L, + channels: Vec, min_inbound_capacity_msat: Option, logger: &L, ) -> impl ExactSizeIterator where L::Target: Logger, @@ -621,12 +781,15 @@ where }, cltv_expiry_delta: forwarding_info.cltv_expiry_delta, htlc_minimum_msat: channel.inbound_htlc_minimum_msat, - htlc_maximum_msat: channel.inbound_htlc_maximum_msat,}]) + htlc_maximum_msat: channel.inbound_htlc_maximum_msat, + }]) }; log_trace!(logger, "Considering {} channels for invoice route hints", channels.len()); for channel in channels.into_iter().filter(|chan| chan.is_channel_ready) { - if channel.get_inbound_payment_scid().is_none() || channel.counterparty.forwarding_info.is_none() { + if channel.get_inbound_payment_scid().is_none() + || channel.counterparty.forwarding_info.is_none() + { log_trace!(logger, "Ignoring channel {} for invoice route hints", &channel.channel_id); continue; } @@ -640,15 +803,21 @@ where } else { // If any public channel exists, return no hints and let the sender // look at the public channels instead. - log_trace!(logger, "Not including channels in invoice route hints on account of public channel {}", - &channel.channel_id); + log_trace!( + logger, + "Not including channels in invoice route hints on account of public channel {}", + &channel.channel_id + ); return vec![].into_iter().take(MAX_CHANNEL_HINTS).map(route_hint_from_channel); } } if channel.inbound_capacity_msat >= min_inbound_capacity { if !min_capacity_channel_exists { - log_trace!(logger, "Channel with enough inbound capacity exists for invoice route hints"); + log_trace!( + logger, + "Channel with enough inbound capacity exists for invoice route hints" + ); min_capacity_channel_exists = true; } @@ -670,12 +839,16 @@ where let new_now_public = channel.is_public && !entry.get().is_public; // Decide whether we prefer the currently selected channel with the node to the new one, // based on their inbound capacity. - let prefer_current = prefer_current_channel(min_inbound_capacity_msat, current_max_capacity, - channel.inbound_capacity_msat); + let prefer_current = prefer_current_channel( + min_inbound_capacity_msat, + current_max_capacity, + channel.inbound_capacity_msat, + ); // If the public-ness of the channel has not changed (in which case simply defer to // `new_now_public), and this channel has more desirable inbound than the incumbent, // prefer to include this channel. - let new_channel_preferable = channel.is_public == entry.get().is_public && !prefer_current; + let new_channel_preferable = + channel.is_public == entry.get().is_public && !prefer_current; if new_now_public || new_channel_preferable { log_trace!(logger, @@ -695,10 +868,10 @@ where &channel.channel_id, channel.short_channel_id, channel.inbound_capacity_msat); } - } + }, hash_map::Entry::Vacant(entry) => { entry.insert(channel); - } + }, } } @@ -727,32 +900,44 @@ where has_enough_capacity } else if online_channel_exists { channel.is_usable - } else { true }; + } else { + true + }; if include_channel { - log_trace!(logger, "Including channel {} in invoice route hints", - &channel.channel_id); + log_trace!( + logger, + "Including channel {} in invoice route hints", + &channel.channel_id + ); } else if !has_enough_capacity { - log_trace!(logger, "Ignoring channel {} without enough capacity for invoice route hints", - &channel.channel_id); + log_trace!( + logger, + "Ignoring channel {} without enough capacity for invoice route hints", + &channel.channel_id + ); } else { debug_assert!(!channel.is_usable || (has_pub_unconf_chan && !channel.is_public)); - log_trace!(logger, "Ignoring channel {} with disconnected peer", - &channel.channel_id); + log_trace!( + logger, + "Ignoring channel {} with disconnected peer", + &channel.channel_id + ); } include_channel }) .collect::>(); - eligible_channels.sort_unstable_by(|a, b| { - if online_min_capacity_channel_exists { - a.inbound_capacity_msat.cmp(&b.inbound_capacity_msat) - } else { - b.inbound_capacity_msat.cmp(&a.inbound_capacity_msat) - }}); + eligible_channels.sort_unstable_by(|a, b| { + if online_min_capacity_channel_exists { + a.inbound_capacity_msat.cmp(&b.inbound_capacity_msat) + } else { + b.inbound_capacity_msat.cmp(&a.inbound_capacity_msat) + } + }); - eligible_channels.into_iter().take(MAX_CHANNEL_HINTS).map(route_hint_from_channel) + eligible_channels.into_iter().take(MAX_CHANNEL_HINTS).map(route_hint_from_channel) } /// prefer_current_channel chooses a channel to use for route hints between a currently selected and candidate @@ -766,13 +951,13 @@ where /// our change"). /// * If no channel above our minimum amount exists, then we just prefer the channel with the most inbound to give /// payments the best chance of succeeding in multiple parts. -fn prefer_current_channel(min_inbound_capacity_msat: Option, current_channel: u64, - candidate_channel: u64) -> bool { - +fn prefer_current_channel( + min_inbound_capacity_msat: Option, current_channel: u64, candidate_channel: u64, +) -> bool { // If no min amount is given for the hints, err of the side of caution and choose the largest channel inbound to // maximize chances of any payment succeeding. if min_inbound_capacity_msat.is_none() { - return current_channel > candidate_channel + return current_channel > candidate_channel; } let scaled_min_inbound = min_inbound_capacity_msat.unwrap() * 110; @@ -780,11 +965,11 @@ fn prefer_current_channel(min_inbound_capacity_msat: Option, current_channe let candidate_sufficient = candidate_channel * 100 >= scaled_min_inbound; if current_sufficient && candidate_sufficient { - return current_channel < candidate_channel + return current_channel < candidate_channel; } else if current_sufficient { - return true + return true; } else if candidate_sufficient { - return false + return false; } current_channel > candidate_channel @@ -792,21 +977,27 @@ fn prefer_current_channel(min_inbound_capacity_msat: Option, current_channe #[cfg(test)] mod test { + use crate::utils::{ + create_invoice_from_channelmanager_and_duration_since_epoch, rotate_through_iterators, + }; + use crate::{ + Bolt11InvoiceDescription, CreationError, Currency, Description, SignOrCreationError, + }; + use bitcoin_hashes::sha256::Hash as Sha256; + use bitcoin_hashes::{sha256, Hash}; use core::cell::RefCell; use core::time::Duration; - use crate::{Currency, Description, Bolt11InvoiceDescription, SignOrCreationError, CreationError}; - use bitcoin_hashes::{Hash, sha256}; - use bitcoin_hashes::sha256::Hash as Sha256; - use lightning::sign::PhantomKeysManager; - use lightning::events::{MessageSendEvent, MessageSendEventsProvider, Event, EventsProvider}; - use lightning::ln::{PaymentPreimage, PaymentHash}; - use lightning::ln::channelmanager::{PhantomRouteHints, MIN_FINAL_CLTV_EXPIRY_DELTA, PaymentId, RecipientOnionFields, Retry}; + use lightning::events::{Event, EventsProvider, MessageSendEvent, MessageSendEventsProvider}; + use lightning::ln::channelmanager::{ + PaymentId, PhantomRouteHints, RecipientOnionFields, Retry, MIN_FINAL_CLTV_EXPIRY_DELTA, + }; use lightning::ln::functional_test_utils::*; use lightning::ln::msgs::ChannelMessageHandler; + use lightning::ln::{PaymentHash, PaymentPreimage}; use lightning::routing::router::{PaymentParameters, RouteParameters}; - use lightning::util::test_utils; + use lightning::sign::PhantomKeysManager; use lightning::util::config::UserConfig; - use crate::utils::{create_invoice_from_channelmanager_and_duration_since_epoch, rotate_through_iterators}; + use lightning::util::test_utils; use std::collections::HashSet; #[test] @@ -836,7 +1027,6 @@ mod test { assert_eq!(crate::utils::prefer_current_channel(Some(200), 100, 150), false); } - #[test] fn test_from_channelmanager() { let chanmon_cfgs = create_chanmon_cfgs(2); @@ -846,37 +1036,67 @@ mod test { create_unannounced_chan_between_nodes_with_value(&nodes, 0, 1, 100000, 10001); let non_default_invoice_expiry_secs = 4200; let invoice = create_invoice_from_channelmanager_and_duration_since_epoch( - nodes[1].node, nodes[1].keys_manager, nodes[1].logger, Currency::BitcoinTestnet, - Some(10_000), "test".to_string(), Duration::from_secs(1234567), - non_default_invoice_expiry_secs, None).unwrap(); + nodes[1].node, + nodes[1].keys_manager, + nodes[1].logger, + Currency::BitcoinTestnet, + Some(10_000), + "test".to_string(), + Duration::from_secs(1234567), + non_default_invoice_expiry_secs, + None, + ) + .unwrap(); assert_eq!(invoice.amount_pico_btc(), Some(100_000)); // If no `min_final_cltv_expiry_delta` is specified, then it should be `MIN_FINAL_CLTV_EXPIRY_DELTA`. assert_eq!(invoice.min_final_cltv_expiry_delta(), MIN_FINAL_CLTV_EXPIRY_DELTA as u64); - assert_eq!(invoice.description(), Bolt11InvoiceDescription::Direct(&Description("test".to_string()))); - assert_eq!(invoice.expiry_time(), Duration::from_secs(non_default_invoice_expiry_secs.into())); + assert_eq!( + invoice.description(), + Bolt11InvoiceDescription::Direct(&Description("test".to_string())) + ); + assert_eq!( + invoice.expiry_time(), + Duration::from_secs(non_default_invoice_expiry_secs.into()) + ); // Invoice SCIDs should always use inbound SCID aliases over the real channel ID, if one is // available. let chan = &nodes[1].node.list_usable_channels()[0]; assert_eq!(invoice.route_hints().len(), 1); assert_eq!(invoice.route_hints()[0].0.len(), 1); - assert_eq!(invoice.route_hints()[0].0[0].short_channel_id, chan.inbound_scid_alias.unwrap()); + assert_eq!( + invoice.route_hints()[0].0[0].short_channel_id, + chan.inbound_scid_alias.unwrap() + ); assert_eq!(invoice.route_hints()[0].0[0].htlc_minimum_msat, chan.inbound_htlc_minimum_msat); assert_eq!(invoice.route_hints()[0].0[0].htlc_maximum_msat, chan.inbound_htlc_maximum_msat); - let payment_params = PaymentParameters::from_node_id(invoice.recover_payee_pub_key(), - invoice.min_final_cltv_expiry_delta() as u32) - .with_bolt11_features(invoice.features().unwrap().clone()).unwrap() - .with_route_hints(invoice.route_hints()).unwrap(); + let payment_params = PaymentParameters::from_node_id( + invoice.recover_payee_pub_key(), + invoice.min_final_cltv_expiry_delta() as u32, + ) + .with_bolt11_features(invoice.features().unwrap().clone()) + .unwrap() + .with_route_hints(invoice.route_hints()) + .unwrap(); let route_params = RouteParameters::from_payment_params_and_value( - payment_params, invoice.amount_milli_satoshis().unwrap()); + payment_params, + invoice.amount_milli_satoshis().unwrap(), + ); let payment_event = { let mut payment_hash = PaymentHash([0; 32]); payment_hash.0.copy_from_slice(&invoice.payment_hash().as_ref()[0..32]); - nodes[0].node.send_payment(payment_hash, - RecipientOnionFields::secret_only(*invoice.payment_secret()), - PaymentId(payment_hash.0), route_params, Retry::Attempts(0)).unwrap(); + nodes[0] + .node + .send_payment( + payment_hash, + RecipientOnionFields::secret_only(*invoice.payment_secret()), + PaymentId(payment_hash.0), + route_params, + Retry::Attempts(0), + ) + .unwrap(); let mut added_monitors = nodes[0].chain_monitor.added_monitors.lock().unwrap(); assert_eq!(added_monitors.len(), 1); added_monitors.clear(); @@ -884,10 +1104,14 @@ mod test { let mut events = nodes[0].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 1); SendEvent::from_event(events.remove(0)) - }; - nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &payment_event.msgs[0]); - nodes[1].node.handle_commitment_signed(&nodes[0].node.get_our_node_id(), &payment_event.commitment_msg); + nodes[1] + .node + .handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &payment_event.msgs[0]); + nodes[1].node.handle_commitment_signed( + &nodes[0].node.get_our_node_id(), + &payment_event.commitment_msg, + ); let mut added_monitors = nodes[1].chain_monitor.added_monitors.lock().unwrap(); assert_eq!(added_monitors.len(), 1); added_monitors.clear(); @@ -903,12 +1127,25 @@ mod test { let custom_min_final_cltv_expiry_delta = Some(50); let invoice = crate::utils::create_invoice_from_channelmanager_and_duration_since_epoch( - nodes[1].node, nodes[1].keys_manager, nodes[1].logger, Currency::BitcoinTestnet, - Some(10_000), "".into(), Duration::from_secs(1234567), 3600, + nodes[1].node, + nodes[1].keys_manager, + nodes[1].logger, + Currency::BitcoinTestnet, + Some(10_000), + "".into(), + Duration::from_secs(1234567), + 3600, if with_custom_delta { custom_min_final_cltv_expiry_delta } else { None }, - ).unwrap(); - assert_eq!(invoice.min_final_cltv_expiry_delta(), if with_custom_delta { - custom_min_final_cltv_expiry_delta.unwrap() + 3 /* Buffer */} else { MIN_FINAL_CLTV_EXPIRY_DELTA } as u64); + ) + .unwrap(); + assert_eq!( + invoice.min_final_cltv_expiry_delta(), + if with_custom_delta { + custom_min_final_cltv_expiry_delta.unwrap() + 3 /* Buffer */ + } else { + MIN_FINAL_CLTV_EXPIRY_DELTA + } as u64 + ); } #[test] @@ -926,10 +1163,17 @@ mod test { let custom_min_final_cltv_expiry_delta = Some(21); let invoice = crate::utils::create_invoice_from_channelmanager_and_duration_since_epoch( - nodes[1].node, nodes[1].keys_manager, nodes[1].logger, Currency::BitcoinTestnet, - Some(10_000), "".into(), Duration::from_secs(1234567), 3600, + nodes[1].node, + nodes[1].keys_manager, + nodes[1].logger, + Currency::BitcoinTestnet, + Some(10_000), + "".into(), + Duration::from_secs(1234567), + 3600, custom_min_final_cltv_expiry_delta, - ).unwrap(); + ) + .unwrap(); assert_eq!(invoice.min_final_cltv_expiry_delta(), MIN_FINAL_CLTV_EXPIRY_DELTA as u64); } @@ -946,7 +1190,12 @@ mod test { ).unwrap(); assert_eq!(invoice.amount_pico_btc(), Some(100_000)); assert_eq!(invoice.min_final_cltv_expiry_delta(), MIN_FINAL_CLTV_EXPIRY_DELTA as u64); - assert_eq!(invoice.description(), Bolt11InvoiceDescription::Hash(&crate::Sha256(Sha256::hash("Testing description_hash".as_bytes())))); + assert_eq!( + invoice.description(), + Bolt11InvoiceDescription::Hash(&crate::Sha256(Sha256::hash( + "Testing description_hash".as_bytes() + ))) + ); } #[test] @@ -963,7 +1212,10 @@ mod test { ).unwrap(); assert_eq!(invoice.amount_pico_btc(), Some(100_000)); assert_eq!(invoice.min_final_cltv_expiry_delta(), MIN_FINAL_CLTV_EXPIRY_DELTA as u64); - assert_eq!(invoice.description(), Bolt11InvoiceDescription::Direct(&Description("test".to_string()))); + assert_eq!( + invoice.description(), + Bolt11InvoiceDescription::Direct(&Description("test".to_string())) + ); assert_eq!(invoice.payment_hash(), &sha256::Hash::from_slice(&payment_hash.0[..]).unwrap()); } @@ -978,7 +1230,8 @@ mod test { // Create a private channel with lots of capacity and a lower value public channel (without // confirming the funding tx yet). - let unannounced_scid = create_unannounced_chan_between_nodes_with_value(&nodes, 0, 1, 10_000_000, 0); + let unannounced_scid = + create_unannounced_chan_between_nodes_with_value(&nodes, 0, 1, 10_000_000, 0); let conf_tx = create_chan_between_nodes_with_value_init(&nodes[0], &nodes[1], 10_000, 0); // Before the channel is available, we should include the unannounced_scid. @@ -990,20 +1243,37 @@ mod test { // channel we'll immediately switch to including it as a route hint, even though it isn't // yet announced. let pub_channel_scid = mine_transaction(&nodes[0], &conf_tx); - let node_a_pub_channel_ready = get_event_msg!(nodes[0], MessageSendEvent::SendChannelReady, nodes[1].node.get_our_node_id()); - nodes[1].node.handle_channel_ready(&nodes[0].node.get_our_node_id(), &node_a_pub_channel_ready); + let node_a_pub_channel_ready = get_event_msg!( + nodes[0], + MessageSendEvent::SendChannelReady, + nodes[1].node.get_our_node_id() + ); + nodes[1] + .node + .handle_channel_ready(&nodes[0].node.get_our_node_id(), &node_a_pub_channel_ready); assert_eq!(mine_transaction(&nodes[1], &conf_tx), pub_channel_scid); let events = nodes[1].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 2); if let MessageSendEvent::SendChannelReady { msg, .. } = &events[0] { nodes[0].node.handle_channel_ready(&nodes[1].node.get_our_node_id(), msg); - } else { panic!(); } + } else { + panic!(); + } if let MessageSendEvent::SendChannelUpdate { msg, .. } = &events[1] { nodes[0].node.handle_channel_update(&nodes[1].node.get_our_node_id(), msg); - } else { panic!(); } + } else { + panic!(); + } - nodes[1].node.handle_channel_update(&nodes[0].node.get_our_node_id(), &get_event_msg!(nodes[0], MessageSendEvent::SendChannelUpdate, nodes[1].node.get_our_node_id())); + nodes[1].node.handle_channel_update( + &nodes[0].node.get_our_node_id(), + &get_event_msg!( + nodes[0], + MessageSendEvent::SendChannelUpdate, + nodes[1].node.get_our_node_id() + ), + ); expect_channel_ready_event(&nodes[0], &nodes[1].node.get_our_node_id()); expect_channel_ready_event(&nodes[1], &nodes[0].node.get_our_node_id()); @@ -1020,7 +1290,11 @@ mod test { connect_blocks(&nodes[1], 5); match_invoice_routes(Some(5000), &nodes[1], scid_aliases.clone()); connect_blocks(&nodes[1], 1); - get_event_msg!(nodes[1], MessageSendEvent::SendAnnouncementSignatures, nodes[0].node.get_our_node_id()); + get_event_msg!( + nodes[1], + MessageSendEvent::SendAnnouncementSignatures, + nodes[0].node.get_our_node_id() + ); match_invoice_routes(Some(5000), &nodes[1], HashSet::new()); } @@ -1031,8 +1305,10 @@ mod test { let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, None, None]); let nodes = create_network(3, &node_cfgs, &node_chanmgrs); - let chan_1_0 = create_unannounced_chan_between_nodes_with_value(&nodes, 1, 0, 100000, 10001); - let chan_2_0 = create_unannounced_chan_between_nodes_with_value(&nodes, 2, 0, 100000, 10001); + let chan_1_0 = + create_unannounced_chan_between_nodes_with_value(&nodes, 1, 0, 100000, 10001); + let chan_2_0 = + create_unannounced_chan_between_nodes_with_value(&nodes, 2, 0, 100000, 10001); let mut scid_aliases = HashSet::new(); scid_aliases.insert(chan_1_0.0.short_channel_id_alias.unwrap()); @@ -1048,9 +1324,12 @@ mod test { let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); let nodes = create_network(2, &node_cfgs, &node_chanmgrs); - let _chan_1_0_inbound_below_amt = create_unannounced_chan_between_nodes_with_value(&nodes, 1, 0, 10_000, 0); - let _chan_1_0_large_inbound_above_amt = create_unannounced_chan_between_nodes_with_value(&nodes, 1, 0, 500_000, 0); - let chan_1_0_low_inbound_above_amt = create_unannounced_chan_between_nodes_with_value(&nodes, 1, 0, 200_000, 0); + let _chan_1_0_inbound_below_amt = + create_unannounced_chan_between_nodes_with_value(&nodes, 1, 0, 10_000, 0); + let _chan_1_0_large_inbound_above_amt = + create_unannounced_chan_between_nodes_with_value(&nodes, 1, 0, 500_000, 0); + let chan_1_0_low_inbound_above_amt = + create_unannounced_chan_between_nodes_with_value(&nodes, 1, 0, 200_000, 0); let mut scid_aliases = HashSet::new(); scid_aliases.insert(chan_1_0_low_inbound_above_amt.0.short_channel_id_alias.unwrap()); @@ -1134,31 +1413,69 @@ mod test { let node_cfgs = create_node_cfgs(3, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, None, None]); let nodes = create_network(3, &node_cfgs, &node_chanmgrs); - let chan_1_0 = create_unannounced_chan_between_nodes_with_value(&nodes, 1, 0, 100000, 10001); + let chan_1_0 = + create_unannounced_chan_between_nodes_with_value(&nodes, 1, 0, 100000, 10001); // Create an unannonced channel between `nodes[2]` and `nodes[0]`, for which the // `msgs::ChannelUpdate` is never handled for the node(s). As the `msgs::ChannelUpdate` // is never handled, the `channel.counterparty.forwarding_info` is never assigned. let mut private_chan_cfg = UserConfig::default(); private_chan_cfg.channel_handshake_config.announced_channel = false; - let temporary_channel_id = nodes[2].node.create_channel(nodes[0].node.get_our_node_id(), 1_000_000, 500_000_000, 42, Some(private_chan_cfg)).unwrap(); - let open_channel = get_event_msg!(nodes[2], MessageSendEvent::SendOpenChannel, nodes[0].node.get_our_node_id()); + let temporary_channel_id = nodes[2] + .node + .create_channel( + nodes[0].node.get_our_node_id(), + 1_000_000, + 500_000_000, + 42, + Some(private_chan_cfg), + ) + .unwrap(); + let open_channel = get_event_msg!( + nodes[2], + MessageSendEvent::SendOpenChannel, + nodes[0].node.get_our_node_id() + ); nodes[0].node.handle_open_channel(&nodes[2].node.get_our_node_id(), &open_channel); - let accept_channel = get_event_msg!(nodes[0], MessageSendEvent::SendAcceptChannel, nodes[2].node.get_our_node_id()); + let accept_channel = get_event_msg!( + nodes[0], + MessageSendEvent::SendAcceptChannel, + nodes[2].node.get_our_node_id() + ); nodes[2].node.handle_accept_channel(&nodes[0].node.get_our_node_id(), &accept_channel); let tx = sign_funding_transaction(&nodes[2], &nodes[0], 1_000_000, temporary_channel_id); - let conf_height = core::cmp::max(nodes[2].best_block_info().1 + 1, nodes[0].best_block_info().1 + 1); + let conf_height = + core::cmp::max(nodes[2].best_block_info().1 + 1, nodes[0].best_block_info().1 + 1); confirm_transaction_at(&nodes[2], &tx, conf_height); connect_blocks(&nodes[2], CHAN_CONFIRM_DEPTH - 1); confirm_transaction_at(&nodes[0], &tx, conf_height); connect_blocks(&nodes[0], CHAN_CONFIRM_DEPTH - 1); - let as_channel_ready = get_event_msg!(nodes[2], MessageSendEvent::SendChannelReady, nodes[0].node.get_our_node_id()); - nodes[2].node.handle_channel_ready(&nodes[0].node.get_our_node_id(), &get_event_msg!(nodes[0], MessageSendEvent::SendChannelReady, nodes[2].node.get_our_node_id())); - get_event_msg!(nodes[2], MessageSendEvent::SendChannelUpdate, nodes[0].node.get_our_node_id()); + let as_channel_ready = get_event_msg!( + nodes[2], + MessageSendEvent::SendChannelReady, + nodes[0].node.get_our_node_id() + ); + nodes[2].node.handle_channel_ready( + &nodes[0].node.get_our_node_id(), + &get_event_msg!( + nodes[0], + MessageSendEvent::SendChannelReady, + nodes[2].node.get_our_node_id() + ), + ); + get_event_msg!( + nodes[2], + MessageSendEvent::SendChannelUpdate, + nodes[0].node.get_our_node_id() + ); nodes[0].node.handle_channel_ready(&nodes[2].node.get_our_node_id(), &as_channel_ready); - get_event_msg!(nodes[0], MessageSendEvent::SendChannelUpdate, nodes[2].node.get_our_node_id()); + get_event_msg!( + nodes[0], + MessageSendEvent::SendChannelUpdate, + nodes[2].node.get_our_node_id() + ); expect_channel_ready_event(&nodes[0], &nodes[2].node.get_our_node_id()); expect_channel_ready_event(&nodes[2], &nodes[0].node.get_our_node_id()); @@ -1176,7 +1493,8 @@ mod test { let node_cfgs = create_node_cfgs(3, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, None, None]); let nodes = create_network(3, &node_cfgs, &node_chanmgrs); - let _chan_1_0 = create_unannounced_chan_between_nodes_with_value(&nodes, 1, 0, 100000, 10001); + let _chan_1_0 = + create_unannounced_chan_between_nodes_with_value(&nodes, 1, 0, 100000, 10001); let chan_2_0 = create_announced_chan_between_nodes_with_value(&nodes, 2, 0, 100000, 10001); nodes[2].node.handle_channel_update(&nodes[0].node.get_our_node_id(), &chan_2_0.1); @@ -1244,21 +1562,32 @@ mod test { } fn match_invoice_routes<'a, 'b: 'a, 'c: 'b>( - invoice_amt: Option, - invoice_node: &Node<'a, 'b, 'c>, - mut chan_ids_to_match: HashSet + invoice_amt: Option, invoice_node: &Node<'a, 'b, 'c>, + mut chan_ids_to_match: HashSet, ) { let invoice = create_invoice_from_channelmanager_and_duration_since_epoch( - invoice_node.node, invoice_node.keys_manager, invoice_node.logger, - Currency::BitcoinTestnet, invoice_amt, "test".to_string(), Duration::from_secs(1234567), - 3600, None).unwrap(); + invoice_node.node, + invoice_node.keys_manager, + invoice_node.logger, + Currency::BitcoinTestnet, + invoice_amt, + "test".to_string(), + Duration::from_secs(1234567), + 3600, + None, + ) + .unwrap(); let hints = invoice.private_routes(); for hint in hints { let hint_short_chan_id = (hint.0).0[0].short_channel_id; assert!(chan_ids_to_match.remove(&hint_short_chan_id)); } - assert!(chan_ids_to_match.is_empty(), "Unmatched short channel ids: {:?}", chan_ids_to_match); + assert!( + chan_ids_to_match.is_empty(), + "Unmatched short channel ids: {:?}", + chan_ids_to_match + ); } #[test] @@ -1274,8 +1603,10 @@ mod test { let seed_1 = [42u8; 32]; let seed_2 = [43u8; 32]; let cross_node_seed = [44u8; 32]; - chanmon_cfgs[1].keys_manager.backing = PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); - chanmon_cfgs[2].keys_manager.backing = PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); + chanmon_cfgs[1].keys_manager.backing = + PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); + chanmon_cfgs[2].keys_manager.backing = + PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); let node_cfgs = create_node_cfgs(3, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, None, None]); let nodes = create_network(3, &node_cfgs, &node_chanmgrs); @@ -1287,10 +1618,8 @@ mod test { nodes[2].node.handle_channel_update(&nodes[0].node.get_our_node_id(), &chan_0_2.0); let payment_amt = 10_000; - let route_hints = vec![ - nodes[1].node.get_phantom_route_hints(), - nodes[2].node.get_phantom_route_hints(), - ]; + let route_hints = + vec![nodes[1].node.get_phantom_route_hints(), nodes[2].node.get_phantom_route_hints()]; let user_payment_preimage = PaymentPreimage([1; 32]); let payment_hash = if user_generated_pmt_hash { @@ -1298,16 +1627,31 @@ mod test { } else { None }; - let genesis_timestamp = bitcoin::blockdata::constants::genesis_block(bitcoin::Network::Testnet).header.time as u64; + let genesis_timestamp = + bitcoin::blockdata::constants::genesis_block(bitcoin::Network::Testnet).header.time + as u64; let non_default_invoice_expiry_secs = 4200; - let invoice = - crate::utils::create_phantom_invoice::<&test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestLogger>( - Some(payment_amt), payment_hash, "test".to_string(), non_default_invoice_expiry_secs, - route_hints, nodes[1].keys_manager, nodes[1].keys_manager, nodes[1].logger, - Currency::BitcoinTestnet, None, Duration::from_secs(genesis_timestamp) - ).unwrap(); - let (payment_hash, payment_secret) = (PaymentHash(invoice.payment_hash().into_inner()), *invoice.payment_secret()); + let invoice = crate::utils::create_phantom_invoice::< + &test_utils::TestKeysInterface, + &test_utils::TestKeysInterface, + &test_utils::TestLogger, + >( + Some(payment_amt), + payment_hash, + "test".to_string(), + non_default_invoice_expiry_secs, + route_hints, + nodes[1].keys_manager, + nodes[1].keys_manager, + nodes[1].logger, + Currency::BitcoinTestnet, + None, + Duration::from_secs(genesis_timestamp), + ) + .unwrap(); + let (payment_hash, payment_secret) = + (PaymentHash(invoice.payment_hash().into_inner()), *invoice.payment_secret()); let payment_preimage = if user_generated_pmt_hash { user_payment_preimage } else { @@ -1315,23 +1659,42 @@ mod test { }; assert_eq!(invoice.min_final_cltv_expiry_delta(), MIN_FINAL_CLTV_EXPIRY_DELTA as u64); - assert_eq!(invoice.description(), Bolt11InvoiceDescription::Direct(&Description("test".to_string()))); + assert_eq!( + invoice.description(), + Bolt11InvoiceDescription::Direct(&Description("test".to_string())) + ); assert_eq!(invoice.route_hints().len(), 2); - assert_eq!(invoice.expiry_time(), Duration::from_secs(non_default_invoice_expiry_secs.into())); + assert_eq!( + invoice.expiry_time(), + Duration::from_secs(non_default_invoice_expiry_secs.into()) + ); assert!(!invoice.features().unwrap().supports_basic_mpp()); - let payment_params = PaymentParameters::from_node_id(invoice.recover_payee_pub_key(), - invoice.min_final_cltv_expiry_delta() as u32) - .with_bolt11_features(invoice.features().unwrap().clone()).unwrap() - .with_route_hints(invoice.route_hints()).unwrap(); + let payment_params = PaymentParameters::from_node_id( + invoice.recover_payee_pub_key(), + invoice.min_final_cltv_expiry_delta() as u32, + ) + .with_bolt11_features(invoice.features().unwrap().clone()) + .unwrap() + .with_route_hints(invoice.route_hints()) + .unwrap(); let params = RouteParameters::from_payment_params_and_value( - payment_params, invoice.amount_milli_satoshis().unwrap()); + payment_params, + invoice.amount_milli_satoshis().unwrap(), + ); let (payment_event, fwd_idx) = { let mut payment_hash = PaymentHash([0; 32]); payment_hash.0.copy_from_slice(&invoice.payment_hash().as_ref()[0..32]); - nodes[0].node.send_payment(payment_hash, - RecipientOnionFields::secret_only(*invoice.payment_secret()), - PaymentId(payment_hash.0), params, Retry::Attempts(0)).unwrap(); + nodes[0] + .node + .send_payment( + payment_hash, + RecipientOnionFields::secret_only(*invoice.payment_secret()), + PaymentId(payment_hash.0), + params, + Retry::Attempts(0), + ) + .unwrap(); let mut added_monitors = nodes[0].chain_monitor.added_monitors.lock().unwrap(); assert_eq!(added_monitors.len(), 1); added_monitors.clear(); @@ -1342,14 +1705,24 @@ mod test { MessageSendEvent::UpdateHTLCs { node_id, .. } => { if node_id == nodes[1].node.get_our_node_id() { 1 - } else { 2 } + } else { + 2 + } }, - _ => panic!("Unexpected event") + _ => panic!("Unexpected event"), }; (SendEvent::from_event(events.remove(0)), fwd_idx) }; - nodes[fwd_idx].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &payment_event.msgs[0]); - commitment_signed_dance!(nodes[fwd_idx], nodes[0], &payment_event.commitment_msg, false, true); + nodes[fwd_idx] + .node + .handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &payment_event.msgs[0]); + commitment_signed_dance!( + nodes[fwd_idx], + nodes[0], + &payment_event.commitment_msg, + false, + true + ); // Note that we have to "forward pending HTLCs" twice before we see the PaymentClaimable as // this "emulates" the payment taking two hops, providing some privacy to make phantom node @@ -1365,10 +1738,23 @@ mod test { nodes[fwd_idx].node.process_pending_events(&forward_event_handler); nodes[fwd_idx].node.process_pending_events(&forward_event_handler); - let payment_preimage_opt = if user_generated_pmt_hash { None } else { Some(payment_preimage) }; + let payment_preimage_opt = + if user_generated_pmt_hash { None } else { Some(payment_preimage) }; assert_eq!(other_events.borrow().len(), 1); - check_payment_claimable(&other_events.borrow()[0], payment_hash, payment_secret, payment_amt, payment_preimage_opt, invoice.recover_payee_pub_key()); - do_claim_payment_along_route(&nodes[0], &[&vec!(&nodes[fwd_idx])[..]], false, payment_preimage); + check_payment_claimable( + &other_events.borrow()[0], + payment_hash, + payment_secret, + payment_amt, + payment_preimage_opt, + invoice.recover_payee_pub_key(), + ); + do_claim_payment_along_route( + &nodes[0], + &[&vec![&nodes[fwd_idx]][..]], + false, + payment_preimage, + ); expect_payment_sent(&nodes[0], payment_preimage, None, true, true); } @@ -1379,8 +1765,10 @@ mod test { let seed_1 = [42u8; 32]; let seed_2 = [43u8; 32]; let cross_node_seed = [44u8; 32]; - chanmon_cfgs[1].keys_manager.backing = PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); - chanmon_cfgs[2].keys_manager.backing = PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); + chanmon_cfgs[1].keys_manager.backing = + PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); + chanmon_cfgs[2].keys_manager.backing = + PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); let node_cfgs = create_node_cfgs(3, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, None, None]); let nodes = create_network(3, &node_cfgs, &node_chanmgrs); @@ -1389,24 +1777,49 @@ mod test { create_unannounced_chan_between_nodes_with_value(&nodes, 0, 2, 100000, 10001); let payment_amt = 20_000; - let (payment_hash, _payment_secret) = nodes[1].node.create_inbound_payment(Some(payment_amt), 3600, None).unwrap(); - let route_hints = vec![ - nodes[1].node.get_phantom_route_hints(), - nodes[2].node.get_phantom_route_hints(), - ]; - - let invoice = crate::utils::create_phantom_invoice::<&test_utils::TestKeysInterface, - &test_utils::TestKeysInterface, &test_utils::TestLogger>(Some(payment_amt), Some(payment_hash), - "test".to_string(), 3600, route_hints, nodes[1].keys_manager, nodes[1].keys_manager, - nodes[1].logger, Currency::BitcoinTestnet, None, Duration::from_secs(1234567)).unwrap(); + let (payment_hash, _payment_secret) = + nodes[1].node.create_inbound_payment(Some(payment_amt), 3600, None).unwrap(); + let route_hints = + vec![nodes[1].node.get_phantom_route_hints(), nodes[2].node.get_phantom_route_hints()]; + + let invoice = crate::utils::create_phantom_invoice::< + &test_utils::TestKeysInterface, + &test_utils::TestKeysInterface, + &test_utils::TestLogger, + >( + Some(payment_amt), + Some(payment_hash), + "test".to_string(), + 3600, + route_hints, + nodes[1].keys_manager, + nodes[1].keys_manager, + nodes[1].logger, + Currency::BitcoinTestnet, + None, + Duration::from_secs(1234567), + ) + .unwrap(); let chan_0_1 = &nodes[1].node.list_usable_channels()[0]; - assert_eq!(invoice.route_hints()[0].0[0].htlc_minimum_msat, chan_0_1.inbound_htlc_minimum_msat); - assert_eq!(invoice.route_hints()[0].0[0].htlc_maximum_msat, chan_0_1.inbound_htlc_maximum_msat); + assert_eq!( + invoice.route_hints()[0].0[0].htlc_minimum_msat, + chan_0_1.inbound_htlc_minimum_msat + ); + assert_eq!( + invoice.route_hints()[0].0[0].htlc_maximum_msat, + chan_0_1.inbound_htlc_maximum_msat + ); let chan_0_2 = &nodes[2].node.list_usable_channels()[0]; - assert_eq!(invoice.route_hints()[1].0[0].htlc_minimum_msat, chan_0_2.inbound_htlc_minimum_msat); - assert_eq!(invoice.route_hints()[1].0[0].htlc_maximum_msat, chan_0_2.inbound_htlc_maximum_msat); + assert_eq!( + invoice.route_hints()[1].0[0].htlc_minimum_msat, + chan_0_2.inbound_htlc_minimum_msat + ); + assert_eq!( + invoice.route_hints()[1].0[0].htlc_maximum_msat, + chan_0_2.inbound_htlc_maximum_msat + ); } #[test] @@ -1418,25 +1831,42 @@ mod test { let nodes = create_network(3, &node_cfgs, &node_chanmgrs); let payment_amt = 20_000; - let route_hints = vec![ - nodes[1].node.get_phantom_route_hints(), - nodes[2].node.get_phantom_route_hints(), - ]; + let route_hints = + vec![nodes[1].node.get_phantom_route_hints(), nodes[2].node.get_phantom_route_hints()]; - let description_hash = crate::Sha256(Hash::hash("Description hash phantom invoice".as_bytes())); + let description_hash = + crate::Sha256(Hash::hash("Description hash phantom invoice".as_bytes())); let non_default_invoice_expiry_secs = 4200; let invoice = crate::utils::create_phantom_invoice_with_description_hash::< - &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestLogger, + &test_utils::TestKeysInterface, + &test_utils::TestKeysInterface, + &test_utils::TestLogger, >( - Some(payment_amt), None, non_default_invoice_expiry_secs, description_hash, - route_hints, nodes[1].keys_manager, nodes[1].keys_manager, nodes[1].logger, - Currency::BitcoinTestnet, None, Duration::from_secs(1234567), + Some(payment_amt), + None, + non_default_invoice_expiry_secs, + description_hash, + route_hints, + nodes[1].keys_manager, + nodes[1].keys_manager, + nodes[1].logger, + Currency::BitcoinTestnet, + None, + Duration::from_secs(1234567), ) .unwrap(); assert_eq!(invoice.amount_pico_btc(), Some(200_000)); assert_eq!(invoice.min_final_cltv_expiry_delta(), MIN_FINAL_CLTV_EXPIRY_DELTA as u64); - assert_eq!(invoice.expiry_time(), Duration::from_secs(non_default_invoice_expiry_secs.into())); - assert_eq!(invoice.description(), Bolt11InvoiceDescription::Hash(&crate::Sha256(Sha256::hash("Description hash phantom invoice".as_bytes())))); + assert_eq!( + invoice.expiry_time(), + Duration::from_secs(non_default_invoice_expiry_secs.into()) + ); + assert_eq!( + invoice.description(), + Bolt11InvoiceDescription::Hash(&crate::Sha256(Sha256::hash( + "Description hash phantom invoice".as_bytes() + ))) + ); } #[test] @@ -1448,22 +1878,41 @@ mod test { let nodes = create_network(3, &node_cfgs, &node_chanmgrs); let payment_amt = 20_000; - let route_hints = vec![ - nodes[1].node.get_phantom_route_hints(), - nodes[2].node.get_phantom_route_hints(), - ]; + let route_hints = + vec![nodes[1].node.get_phantom_route_hints(), nodes[2].node.get_phantom_route_hints()]; let user_payment_preimage = PaymentPreimage([1; 32]); - let payment_hash = Some(PaymentHash(Sha256::hash(&user_payment_preimage.0[..]).into_inner())); + let payment_hash = + Some(PaymentHash(Sha256::hash(&user_payment_preimage.0[..]).into_inner())); let non_default_invoice_expiry_secs = 4200; let min_final_cltv_expiry_delta = Some(100); let duration_since_epoch = Duration::from_secs(1234567); - let invoice = crate::utils::create_phantom_invoice::<&test_utils::TestKeysInterface, - &test_utils::TestKeysInterface, &test_utils::TestLogger>(Some(payment_amt), payment_hash, - "".to_string(), non_default_invoice_expiry_secs, route_hints, nodes[1].keys_manager, nodes[1].keys_manager, - nodes[1].logger, Currency::BitcoinTestnet, min_final_cltv_expiry_delta, duration_since_epoch).unwrap(); + let invoice = crate::utils::create_phantom_invoice::< + &test_utils::TestKeysInterface, + &test_utils::TestKeysInterface, + &test_utils::TestLogger, + >( + Some(payment_amt), + payment_hash, + "".to_string(), + non_default_invoice_expiry_secs, + route_hints, + nodes[1].keys_manager, + nodes[1].keys_manager, + nodes[1].logger, + Currency::BitcoinTestnet, + min_final_cltv_expiry_delta, + duration_since_epoch, + ) + .unwrap(); assert_eq!(invoice.amount_pico_btc(), Some(200_000)); - assert_eq!(invoice.min_final_cltv_expiry_delta(), (min_final_cltv_expiry_delta.unwrap() + 3) as u64); - assert_eq!(invoice.expiry_time(), Duration::from_secs(non_default_invoice_expiry_secs.into())); + assert_eq!( + invoice.min_final_cltv_expiry_delta(), + (min_final_cltv_expiry_delta.unwrap() + 3) as u64 + ); + assert_eq!( + invoice.expiry_time(), + Duration::from_secs(non_default_invoice_expiry_secs.into()) + ); } #[test] @@ -1473,14 +1922,18 @@ mod test { let seed_1 = [42u8; 32]; let seed_2 = [43u8; 32]; let cross_node_seed = [44u8; 32]; - chanmon_cfgs[1].keys_manager.backing = PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); - chanmon_cfgs[2].keys_manager.backing = PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); + chanmon_cfgs[1].keys_manager.backing = + PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); + chanmon_cfgs[2].keys_manager.backing = + PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); let node_cfgs = create_node_cfgs(3, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, None, None]); let nodes = create_network(3, &node_cfgs, &node_chanmgrs); - let chan_0_1 = create_unannounced_chan_between_nodes_with_value(&nodes, 0, 1, 100000, 10001); - let chan_0_2 = create_unannounced_chan_between_nodes_with_value(&nodes, 0, 2, 100000, 10001); + let chan_0_1 = + create_unannounced_chan_between_nodes_with_value(&nodes, 0, 1, 100000, 10001); + let chan_0_2 = + create_unannounced_chan_between_nodes_with_value(&nodes, 0, 2, 100000, 10001); let mut scid_aliases = HashSet::new(); scid_aliases.insert(chan_0_1.0.short_channel_id_alias.unwrap()); @@ -1489,28 +1942,34 @@ mod test { match_multi_node_invoice_routes( Some(10_000), &nodes[1], - vec![&nodes[1], &nodes[2],], + vec![&nodes[1], &nodes[2]], scid_aliases, - false + false, ); } #[test] #[cfg(feature = "std")] - fn test_multi_node_hints_includes_one_channel_of_each_counterparty_nodes_per_participating_node() { + fn test_multi_node_hints_includes_one_channel_of_each_counterparty_nodes_per_participating_node( + ) { let mut chanmon_cfgs = create_chanmon_cfgs(4); let seed_1 = [42u8; 32]; let seed_2 = [43u8; 32]; let cross_node_seed = [44u8; 32]; - chanmon_cfgs[2].keys_manager.backing = PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); - chanmon_cfgs[3].keys_manager.backing = PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); + chanmon_cfgs[2].keys_manager.backing = + PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); + chanmon_cfgs[3].keys_manager.backing = + PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); let node_cfgs = create_node_cfgs(4, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(4, &node_cfgs, &[None, None, None, None]); let nodes = create_network(4, &node_cfgs, &node_chanmgrs); - let chan_0_2 = create_unannounced_chan_between_nodes_with_value(&nodes, 0, 2, 100000, 10001); - let chan_0_3 = create_unannounced_chan_between_nodes_with_value(&nodes, 0, 3, 1000000, 10001); - let chan_1_3 = create_unannounced_chan_between_nodes_with_value(&nodes, 1, 3, 3_000_000, 10005); + let chan_0_2 = + create_unannounced_chan_between_nodes_with_value(&nodes, 0, 2, 100000, 10001); + let chan_0_3 = + create_unannounced_chan_between_nodes_with_value(&nodes, 0, 3, 1000000, 10001); + let chan_1_3 = + create_unannounced_chan_between_nodes_with_value(&nodes, 1, 3, 3_000_000, 10005); let mut scid_aliases = HashSet::new(); scid_aliases.insert(chan_0_2.0.short_channel_id_alias.unwrap()); @@ -1520,9 +1979,9 @@ mod test { match_multi_node_invoice_routes( Some(10_000), &nodes[2], - vec![&nodes[2], &nodes[3],], + vec![&nodes[2], &nodes[3]], scid_aliases, - false + false, ); } @@ -1533,38 +1992,79 @@ mod test { let seed_1 = [42u8; 32]; let seed_2 = [43u8; 32]; let cross_node_seed = [44u8; 32]; - chanmon_cfgs[2].keys_manager.backing = PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); - chanmon_cfgs[3].keys_manager.backing = PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); + chanmon_cfgs[2].keys_manager.backing = + PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); + chanmon_cfgs[3].keys_manager.backing = + PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); let node_cfgs = create_node_cfgs(4, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(4, &node_cfgs, &[None, None, None, None]); let nodes = create_network(4, &node_cfgs, &node_chanmgrs); - let chan_0_2 = create_unannounced_chan_between_nodes_with_value(&nodes, 0, 2, 100000, 10001); - let chan_0_3 = create_unannounced_chan_between_nodes_with_value(&nodes, 0, 3, 1000000, 10001); + let chan_0_2 = + create_unannounced_chan_between_nodes_with_value(&nodes, 0, 2, 100000, 10001); + let chan_0_3 = + create_unannounced_chan_between_nodes_with_value(&nodes, 0, 3, 1000000, 10001); // Create an unannonced channel between `nodes[1]` and `nodes[3]`, for which the // `msgs::ChannelUpdate` is never handled for the node(s). As the `msgs::ChannelUpdate` // is never handled, the `channel.counterparty.forwarding_info` is never assigned. let mut private_chan_cfg = UserConfig::default(); private_chan_cfg.channel_handshake_config.announced_channel = false; - let temporary_channel_id = nodes[1].node.create_channel(nodes[3].node.get_our_node_id(), 1_000_000, 500_000_000, 42, Some(private_chan_cfg)).unwrap(); - let open_channel = get_event_msg!(nodes[1], MessageSendEvent::SendOpenChannel, nodes[3].node.get_our_node_id()); + let temporary_channel_id = nodes[1] + .node + .create_channel( + nodes[3].node.get_our_node_id(), + 1_000_000, + 500_000_000, + 42, + Some(private_chan_cfg), + ) + .unwrap(); + let open_channel = get_event_msg!( + nodes[1], + MessageSendEvent::SendOpenChannel, + nodes[3].node.get_our_node_id() + ); nodes[3].node.handle_open_channel(&nodes[1].node.get_our_node_id(), &open_channel); - let accept_channel = get_event_msg!(nodes[3], MessageSendEvent::SendAcceptChannel, nodes[1].node.get_our_node_id()); + let accept_channel = get_event_msg!( + nodes[3], + MessageSendEvent::SendAcceptChannel, + nodes[1].node.get_our_node_id() + ); nodes[1].node.handle_accept_channel(&nodes[3].node.get_our_node_id(), &accept_channel); let tx = sign_funding_transaction(&nodes[1], &nodes[3], 1_000_000, temporary_channel_id); - let conf_height = core::cmp::max(nodes[1].best_block_info().1 + 1, nodes[3].best_block_info().1 + 1); + let conf_height = + core::cmp::max(nodes[1].best_block_info().1 + 1, nodes[3].best_block_info().1 + 1); confirm_transaction_at(&nodes[1], &tx, conf_height); connect_blocks(&nodes[1], CHAN_CONFIRM_DEPTH - 1); confirm_transaction_at(&nodes[3], &tx, conf_height); connect_blocks(&nodes[3], CHAN_CONFIRM_DEPTH - 1); - let as_channel_ready = get_event_msg!(nodes[1], MessageSendEvent::SendChannelReady, nodes[3].node.get_our_node_id()); - nodes[1].node.handle_channel_ready(&nodes[3].node.get_our_node_id(), &get_event_msg!(nodes[3], MessageSendEvent::SendChannelReady, nodes[1].node.get_our_node_id())); - get_event_msg!(nodes[1], MessageSendEvent::SendChannelUpdate, nodes[3].node.get_our_node_id()); + let as_channel_ready = get_event_msg!( + nodes[1], + MessageSendEvent::SendChannelReady, + nodes[3].node.get_our_node_id() + ); + nodes[1].node.handle_channel_ready( + &nodes[3].node.get_our_node_id(), + &get_event_msg!( + nodes[3], + MessageSendEvent::SendChannelReady, + nodes[1].node.get_our_node_id() + ), + ); + get_event_msg!( + nodes[1], + MessageSendEvent::SendChannelUpdate, + nodes[3].node.get_our_node_id() + ); nodes[3].node.handle_channel_ready(&nodes[1].node.get_our_node_id(), &as_channel_ready); - get_event_msg!(nodes[3], MessageSendEvent::SendChannelUpdate, nodes[1].node.get_our_node_id()); + get_event_msg!( + nodes[3], + MessageSendEvent::SendChannelUpdate, + nodes[1].node.get_our_node_id() + ); expect_channel_ready_event(&nodes[1], &nodes[3].node.get_our_node_id()); expect_channel_ready_event(&nodes[3], &nodes[1].node.get_our_node_id()); @@ -1578,9 +2078,9 @@ mod test { match_multi_node_invoice_routes( Some(10_000), &nodes[2], - vec![&nodes[2], &nodes[3],], + vec![&nodes[2], &nodes[3]], scid_aliases, - false + false, ); } @@ -1591,13 +2091,16 @@ mod test { let seed_1 = [42u8; 32]; let seed_2 = [43u8; 32]; let cross_node_seed = [44u8; 32]; - chanmon_cfgs[1].keys_manager.backing = PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); - chanmon_cfgs[2].keys_manager.backing = PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); + chanmon_cfgs[1].keys_manager.backing = + PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); + chanmon_cfgs[2].keys_manager.backing = + PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); let node_cfgs = create_node_cfgs(3, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, None, None]); let nodes = create_network(3, &node_cfgs, &node_chanmgrs); - let chan_0_1 = create_unannounced_chan_between_nodes_with_value(&nodes, 0, 1, 100000, 10001); + let chan_0_1 = + create_unannounced_chan_between_nodes_with_value(&nodes, 0, 1, 100000, 10001); let chan_2_0 = create_announced_chan_between_nodes_with_value(&nodes, 2, 0, 100000, 10001); nodes[2].node.handle_channel_update(&nodes[0].node.get_our_node_id(), &chan_2_0.1); @@ -1611,9 +2114,9 @@ mod test { match_multi_node_invoice_routes( Some(10_000), &nodes[1], - vec![&nodes[1], &nodes[2],], + vec![&nodes[1], &nodes[2]], scid_aliases, - true + true, ); } @@ -1624,8 +2127,10 @@ mod test { let seed_1 = [42u8; 32]; let seed_2 = [43u8; 32]; let cross_node_seed = [44u8; 32]; - chanmon_cfgs[1].keys_manager.backing = PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); - chanmon_cfgs[2].keys_manager.backing = PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); + chanmon_cfgs[1].keys_manager.backing = + PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); + chanmon_cfgs[2].keys_manager.backing = + PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); let node_cfgs = create_node_cfgs(4, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(4, &node_cfgs, &[None, None, None, None]); let nodes = create_network(4, &node_cfgs, &node_chanmgrs); @@ -1633,9 +2138,11 @@ mod test { let chan_0_2 = create_announced_chan_between_nodes_with_value(&nodes, 0, 2, 100000, 10001); nodes[0].node.handle_channel_update(&nodes[2].node.get_our_node_id(), &chan_0_2.1); nodes[2].node.handle_channel_update(&nodes[0].node.get_our_node_id(), &chan_0_2.0); - let _chan_1_2 = create_unannounced_chan_between_nodes_with_value(&nodes, 1, 2, 100000, 10001); + let _chan_1_2 = + create_unannounced_chan_between_nodes_with_value(&nodes, 1, 2, 100000, 10001); - let chan_0_3 = create_unannounced_chan_between_nodes_with_value(&nodes, 0, 3, 100000, 10001); + let chan_0_3 = + create_unannounced_chan_between_nodes_with_value(&nodes, 0, 3, 100000, 10001); // Hints should include `chan_0_3` from as `nodes[3]` only have private channels, and no // channels for `nodes[2]` as it contains a mix of public and private channels. @@ -1645,9 +2152,9 @@ mod test { match_multi_node_invoice_routes( Some(10_000), &nodes[2], - vec![&nodes[2], &nodes[3],], + vec![&nodes[2], &nodes[3]], scid_aliases, - true + true, ); } @@ -1658,16 +2165,22 @@ mod test { let seed_1 = [42u8; 32]; let seed_2 = [43u8; 32]; let cross_node_seed = [44u8; 32]; - chanmon_cfgs[1].keys_manager.backing = PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); - chanmon_cfgs[2].keys_manager.backing = PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); + chanmon_cfgs[1].keys_manager.backing = + PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); + chanmon_cfgs[2].keys_manager.backing = + PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); let node_cfgs = create_node_cfgs(3, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, None, None]); let nodes = create_network(3, &node_cfgs, &node_chanmgrs); - let _chan_0_1_below_amt = create_unannounced_chan_between_nodes_with_value(&nodes, 0, 1, 100_000, 0); - let _chan_0_1_above_amt_high_inbound = create_unannounced_chan_between_nodes_with_value(&nodes, 0, 1, 500_000, 0); - let chan_0_1_above_amt_low_inbound = create_unannounced_chan_between_nodes_with_value(&nodes, 0, 1, 180_000, 0); - let chan_0_2 = create_unannounced_chan_between_nodes_with_value(&nodes, 0, 2, 100000, 10001); + let _chan_0_1_below_amt = + create_unannounced_chan_between_nodes_with_value(&nodes, 0, 1, 100_000, 0); + let _chan_0_1_above_amt_high_inbound = + create_unannounced_chan_between_nodes_with_value(&nodes, 0, 1, 500_000, 0); + let chan_0_1_above_amt_low_inbound = + create_unannounced_chan_between_nodes_with_value(&nodes, 0, 1, 180_000, 0); + let chan_0_2 = + create_unannounced_chan_between_nodes_with_value(&nodes, 0, 2, 100000, 10001); let mut scid_aliases = HashSet::new(); scid_aliases.insert(chan_0_1_above_amt_low_inbound.0.short_channel_id_alias.unwrap()); @@ -1676,9 +2189,9 @@ mod test { match_multi_node_invoice_routes( Some(100_000_000), &nodes[1], - vec![&nodes[1], &nodes[2],], + vec![&nodes[1], &nodes[2]], scid_aliases, - false + false, ); } @@ -1689,8 +2202,10 @@ mod test { let seed_1 = [42u8; 32]; let seed_2 = [43u8; 32]; let cross_node_seed = [44u8; 32]; - chanmon_cfgs[1].keys_manager.backing = PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); - chanmon_cfgs[2].keys_manager.backing = PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); + chanmon_cfgs[1].keys_manager.backing = + PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); + chanmon_cfgs[2].keys_manager.backing = + PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); let node_cfgs = create_node_cfgs(4, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(4, &node_cfgs, &[None, None, None, None]); let nodes = create_network(4, &node_cfgs, &node_chanmgrs); @@ -1707,9 +2222,9 @@ mod test { match_multi_node_invoice_routes( Some(99_000_001), &nodes[2], - vec![&nodes[2], &nodes[3],], + vec![&nodes[2], &nodes[3]], scid_aliases_99_000_001_msat, - false + false, ); // Since the invoice is exactly at chan_0_3's inbound capacity, it should be included. @@ -1721,9 +2236,9 @@ mod test { match_multi_node_invoice_routes( Some(99_000_000), &nodes[2], - vec![&nodes[2], &nodes[3],], + vec![&nodes[2], &nodes[3]], scid_aliases_99_000_000_msat, - false + false, ); // Since the invoice is above all of `nodes[2]` channels' inbound capacity, all of @@ -1736,9 +2251,9 @@ mod test { match_multi_node_invoice_routes( Some(300_000_000), &nodes[2], - vec![&nodes[2], &nodes[3],], + vec![&nodes[2], &nodes[3]], scid_aliases_300_000_000_msat, - false + false, ); // Since the no specified amount, all channels should included. @@ -1750,9 +2265,9 @@ mod test { match_multi_node_invoice_routes( None, &nodes[2], - vec![&nodes[2], &nodes[3],], + vec![&nodes[2], &nodes[3]], scid_aliases_no_specified_amount, - false + false, ); } @@ -1764,12 +2279,17 @@ mod test { let seed_3 = [44 as u8; 32]; let seed_4 = [45 as u8; 32]; let cross_node_seed = [44 as u8; 32]; - chanmon_cfgs[2].keys_manager.backing = PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); - chanmon_cfgs[3].keys_manager.backing = PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); - chanmon_cfgs[4].keys_manager.backing = PhantomKeysManager::new(&seed_3, 43, 44, &cross_node_seed); - chanmon_cfgs[5].keys_manager.backing = PhantomKeysManager::new(&seed_4, 43, 44, &cross_node_seed); + chanmon_cfgs[2].keys_manager.backing = + PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); + chanmon_cfgs[3].keys_manager.backing = + PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); + chanmon_cfgs[4].keys_manager.backing = + PhantomKeysManager::new(&seed_3, 43, 44, &cross_node_seed); + chanmon_cfgs[5].keys_manager.backing = + PhantomKeysManager::new(&seed_4, 43, 44, &cross_node_seed); let node_cfgs = create_node_cfgs(6, &chanmon_cfgs); - let node_chanmgrs = create_node_chanmgrs(6, &node_cfgs, &[None, None, None, None, None, None]); + let node_chanmgrs = + create_node_chanmgrs(6, &node_cfgs, &[None, None, None, None, None, None]); let nodes = create_network(6, &node_cfgs, &node_chanmgrs); // Setup each phantom node with two channels from distinct peers. @@ -1821,8 +2341,10 @@ mod test { let seed_1 = [42 as u8; 32]; let seed_2 = [43 as u8; 32]; let cross_node_seed = [44 as u8; 32]; - chanmon_cfgs[1].keys_manager.backing = PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); - chanmon_cfgs[2].keys_manager.backing = PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); + chanmon_cfgs[1].keys_manager.backing = + PhantomKeysManager::new(&seed_1, 43, 44, &cross_node_seed); + chanmon_cfgs[2].keys_manager.backing = + PhantomKeysManager::new(&seed_2, 43, 44, &cross_node_seed); let node_cfgs = create_node_cfgs(5, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(5, &node_cfgs, &[None, None, None, None, None]); let nodes = create_network(5, &node_cfgs, &node_chanmgrs); @@ -1842,30 +2364,44 @@ mod test { match_multi_node_invoice_routes( Some(100_000_000), &nodes[3], - vec![&nodes[3], &nodes[4],], + vec![&nodes[3], &nodes[4]], scid_aliases, false, ); } fn match_multi_node_invoice_routes<'a, 'b: 'a, 'c: 'b>( - invoice_amt: Option, - invoice_node: &Node<'a, 'b, 'c>, - network_multi_nodes: Vec<&Node<'a, 'b, 'c>>, - mut chan_ids_to_match: HashSet, - nodes_contains_public_channels: bool - ){ - let phantom_route_hints = network_multi_nodes.iter() + invoice_amt: Option, invoice_node: &Node<'a, 'b, 'c>, + network_multi_nodes: Vec<&Node<'a, 'b, 'c>>, mut chan_ids_to_match: HashSet, + nodes_contains_public_channels: bool, + ) { + let phantom_route_hints = network_multi_nodes + .iter() .map(|node| node.node.get_phantom_route_hints()) .collect::>(); - let phantom_scids = phantom_route_hints.iter() + let phantom_scids = phantom_route_hints + .iter() .map(|route_hint| route_hint.phantom_scid) .collect::>(); - let invoice = crate::utils::create_phantom_invoice::<&test_utils::TestKeysInterface, - &test_utils::TestKeysInterface, &test_utils::TestLogger>(invoice_amt, None, "test".to_string(), - 3600, phantom_route_hints, invoice_node.keys_manager, invoice_node.keys_manager, - invoice_node.logger, Currency::BitcoinTestnet, None, Duration::from_secs(1234567)).unwrap(); + let invoice = crate::utils::create_phantom_invoice::< + &test_utils::TestKeysInterface, + &test_utils::TestKeysInterface, + &test_utils::TestLogger, + >( + invoice_amt, + None, + "test".to_string(), + 3600, + phantom_route_hints, + invoice_node.keys_manager, + invoice_node.keys_manager, + invoice_node.logger, + Currency::BitcoinTestnet, + None, + Duration::from_secs(1234567), + ) + .unwrap(); let invoice_hints = invoice.private_routes(); @@ -1883,10 +2419,14 @@ mod test { let phantom_scid = hints[1].short_channel_id; assert!(phantom_scids.contains(&phantom_scid)); }, - _ => panic!("Incorrect hint length generated") + _ => panic!("Incorrect hint length generated"), } } - assert!(chan_ids_to_match.is_empty(), "Unmatched short channel ids: {:?}", chan_ids_to_match); + assert!( + chan_ids_to_match.is_empty(), + "Unmatched short channel ids: {:?}", + chan_ids_to_match + ); } #[test] @@ -1896,11 +2436,20 @@ mod test { let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); let nodes = create_network(2, &node_cfgs, &node_chanmgrs); let result = crate::utils::create_invoice_from_channelmanager_and_duration_since_epoch( - nodes[1].node, nodes[1].keys_manager, nodes[1].logger, Currency::BitcoinTestnet, - Some(10_000), "Some description".into(), Duration::from_secs(1234567), 3600, Some(MIN_FINAL_CLTV_EXPIRY_DELTA - 4), + nodes[1].node, + nodes[1].keys_manager, + nodes[1].logger, + Currency::BitcoinTestnet, + Some(10_000), + "Some description".into(), + Duration::from_secs(1234567), + 3600, + Some(MIN_FINAL_CLTV_EXPIRY_DELTA - 4), ); match result { - Err(SignOrCreationError::CreationError(CreationError::MinFinalCltvExpiryDeltaTooShort)) => {}, + Err(SignOrCreationError::CreationError( + CreationError::MinFinalCltvExpiryDeltaTooShort, + )) => {}, _ => panic!(), } } @@ -1929,7 +2478,11 @@ mod test { assert_eq!(expected, result); // test three nestend vectors - let a = vec![vec!["a0"].into_iter(), vec!["a1", "b1", "c1"].into_iter(), vec!["a2"].into_iter()]; + let a = vec![ + vec!["a0"].into_iter(), + vec!["a1", "b1", "c1"].into_iter(), + vec!["a2"].into_iter(), + ]; let result = rotate_through_iterators(a).collect::>(); let expected = vec!["a0", "a1", "a2", "b1", "c1"]; @@ -1943,24 +2496,25 @@ mod test { assert_eq!(expected, result); // test single empty nested vector - let a:Vec> = vec![vec![].into_iter()]; + let a: Vec> = vec![vec![].into_iter()]; let result = rotate_through_iterators(a).collect::>(); - let expected:Vec<&str> = vec![]; + let expected: Vec<&str> = vec![]; assert_eq!(expected, result); // test first nested vector is empty - let a:Vec>= vec![vec![].into_iter(), vec!["a1", "b1", "c1"].into_iter()]; + let a: Vec> = + vec![vec![].into_iter(), vec!["a1", "b1", "c1"].into_iter()]; let result = rotate_through_iterators(a).collect::>(); let expected = vec!["a1", "b1", "c1"]; assert_eq!(expected, result); // test two empty vectors - let a:Vec> = vec![vec![].into_iter(), vec![].into_iter()]; + let a: Vec> = vec![vec![].into_iter(), vec![].into_iter()]; let result = rotate_through_iterators(a).collect::>(); - let expected:Vec<&str> = vec![]; + let expected: Vec<&str> = vec![]; assert_eq!(expected, result); // test an empty vector amongst other filled vectors diff --git a/lightning-invoice/tests/ser_de.rs b/lightning-invoice/tests/ser_de.rs index e21b82eae3c..05025b81ee9 100644 --- a/lightning-invoice/tests/ser_de.rs +++ b/lightning-invoice/tests/ser_de.rs @@ -1,9 +1,9 @@ extern crate bech32; extern crate bitcoin_hashes; +extern crate hex; extern crate lightning; extern crate lightning_invoice; extern crate secp256k1; -extern crate hex; use bitcoin::util::address::WitnessVersion; use bitcoin::{PubkeyHash, ScriptHash}; @@ -13,11 +13,11 @@ use lightning::ln::PaymentSecret; use lightning::routing::gossip::RoutingFees; use lightning::routing::router::{RouteHint, RouteHintHop}; use lightning_invoice::*; -use secp256k1::PublicKey; use secp256k1::ecdsa::{RecoverableSignature, RecoveryId}; +use secp256k1::PublicKey; use std::collections::HashSet; -use std::time::Duration; use std::str::FromStr; +use std::time::Duration; fn get_test_tuples() -> Vec<(String, SignedRawBolt11Invoice, bool, bool)> { vec![ @@ -387,7 +387,8 @@ fn get_test_tuples() -> Vec<(String, SignedRawBolt11Invoice, bool, bool)> { #[test] fn invoice_deserialize() { - for (serialized, deserialized, ignore_feature_diff, ignore_unknown_fields) in get_test_tuples() { + for (serialized, deserialized, ignore_feature_diff, ignore_unknown_fields) in get_test_tuples() + { eprintln!("Testing invoice {}...", serialized); let parsed = serialized.parse::().unwrap(); @@ -398,17 +399,33 @@ fn invoice_deserialize() { assert_eq!(deserialized_invoice.hrp, parsed_invoice.hrp); assert_eq!(deserialized_invoice.data.timestamp, parsed_invoice.data.timestamp); - let mut deserialized_hunks: HashSet<_> = deserialized_invoice.data.tagged_fields.iter().collect(); + let mut deserialized_hunks: HashSet<_> = + deserialized_invoice.data.tagged_fields.iter().collect(); let mut parsed_hunks: HashSet<_> = parsed_invoice.data.tagged_fields.iter().collect(); if ignore_feature_diff { - deserialized_hunks.retain(|h| - if let RawTaggedField::KnownSemantics(TaggedField::Features(_)) = h { false } else { true }); - parsed_hunks.retain(|h| - if let RawTaggedField::KnownSemantics(TaggedField::Features(_)) = h { false } else { true }); + deserialized_hunks.retain(|h| { + if let RawTaggedField::KnownSemantics(TaggedField::Features(_)) = h { + false + } else { + true + } + }); + parsed_hunks.retain(|h| { + if let RawTaggedField::KnownSemantics(TaggedField::Features(_)) = h { + false + } else { + true + } + }); } if ignore_unknown_fields { - parsed_hunks.retain(|h| - if let RawTaggedField::UnknownSemantics(_) = h { false } else { true }); + parsed_hunks.retain(|h| { + if let RawTaggedField::UnknownSemantics(_) = h { + false + } else { + true + } + }); } assert_eq!(deserialized_hunks, parsed_hunks); diff --git a/lightning-net-tokio/src/lib.rs b/lightning-net-tokio/src/lib.rs index bac18b2b398..77e904dc33d 100644 --- a/lightning-net-tokio/src/lib.rs +++ b/lightning-net-tokio/src/lib.rs @@ -25,31 +25,30 @@ // Prefix these with `rustdoc::` when we update our MSRV to be >= 1.52 to remove warnings. #![deny(broken_intra_doc_links)] #![deny(private_intra_doc_links)] - #![deny(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] use bitcoin::secp256k1::PublicKey; use tokio::net::TcpStream; -use tokio::time; use tokio::sync::mpsc; +use tokio::time; +use lightning::ln::msgs::SocketAddress; use lightning::ln::peer_handler; -use lightning::ln::peer_handler::SocketDescriptor as LnSocketTrait; use lightning::ln::peer_handler::APeerManager; -use lightning::ln::msgs::SocketAddress; +use lightning::ln::peer_handler::SocketDescriptor as LnSocketTrait; -use std::ops::Deref; -use std::task::{self, Poll}; use std::future::Future; +use std::hash::Hash; use std::net::SocketAddr; use std::net::TcpStream as StdTcpStream; -use std::sync::{Arc, Mutex}; +use std::ops::Deref; +use std::pin::Pin; use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{self, Poll}; use std::time::Duration; -use std::pin::Pin; -use std::hash::Hash; static ID_COUNTER: AtomicU64 = AtomicU64::new(0); @@ -58,27 +57,34 @@ static ID_COUNTER: AtomicU64 = AtomicU64::new(0); // define a trivial two- and three- select macro with the specific types we need and just use that. pub(crate) enum SelectorOutput { - A(Option<()>), B(Option<()>), C(tokio::io::Result<()>), + A(Option<()>), + B(Option<()>), + C(tokio::io::Result<()>), } pub(crate) struct TwoSelector< - A: Future> + Unpin, B: Future> + Unpin + A: Future> + Unpin, + B: Future> + Unpin, > { pub a: A, pub b: B, } -impl< - A: Future> + Unpin, B: Future> + Unpin -> Future for TwoSelector { +impl> + Unpin, B: Future> + Unpin> Future + for TwoSelector +{ type Output = SelectorOutput; fn poll(mut self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> Poll { match Pin::new(&mut self.a).poll(ctx) { - Poll::Ready(res) => { return Poll::Ready(SelectorOutput::A(res)); }, + Poll::Ready(res) => { + return Poll::Ready(SelectorOutput::A(res)); + }, Poll::Pending => {}, } match Pin::new(&mut self.b).poll(ctx) { - Poll::Ready(res) => { return Poll::Ready(SelectorOutput::B(res)); }, + Poll::Ready(res) => { + return Poll::Ready(SelectorOutput::B(res)); + }, Poll::Pending => {}, } Poll::Pending @@ -86,7 +92,9 @@ impl< } pub(crate) struct ThreeSelector< - A: Future> + Unpin, B: Future> + Unpin, C: Future> + Unpin + A: Future> + Unpin, + B: Future> + Unpin, + C: Future> + Unpin, > { pub a: A, pub b: B, @@ -94,20 +102,29 @@ pub(crate) struct ThreeSelector< } impl< - A: Future> + Unpin, B: Future> + Unpin, C: Future> + Unpin -> Future for ThreeSelector { + A: Future> + Unpin, + B: Future> + Unpin, + C: Future> + Unpin, + > Future for ThreeSelector +{ type Output = SelectorOutput; fn poll(mut self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> Poll { match Pin::new(&mut self.a).poll(ctx) { - Poll::Ready(res) => { return Poll::Ready(SelectorOutput::A(res)); }, + Poll::Ready(res) => { + return Poll::Ready(SelectorOutput::A(res)); + }, Poll::Pending => {}, } match Pin::new(&mut self.b).poll(ctx) { - Poll::Ready(res) => { return Poll::Ready(SelectorOutput::B(res)); }, + Poll::Ready(res) => { + return Poll::Ready(SelectorOutput::B(res)); + }, Poll::Pending => {}, } match Pin::new(&mut self.c).poll(ctx) { - Poll::Ready(res) => { return Poll::Ready(SelectorOutput::C(res)); }, + Poll::Ready(res) => { + return Poll::Ready(SelectorOutput::C(res)); + }, Poll::Pending => {}, } Poll::Pending @@ -141,9 +158,10 @@ struct Connection { } impl Connection { async fn poll_event_process( - peer_manager: PM, - mut event_receiver: mpsc::Receiver<()>, - ) where PM::Target: APeerManager { + peer_manager: PM, mut event_receiver: mpsc::Receiver<()>, + ) where + PM::Target: APeerManager, + { loop { if event_receiver.recv().await.is_none() { return; @@ -153,12 +171,11 @@ impl Connection { } async fn schedule_read( - peer_manager: PM, - us: Arc>, - reader: Arc, - mut read_wake_receiver: mpsc::Receiver<()>, - mut write_avail_receiver: mpsc::Receiver<()>, - ) where PM::Target: APeerManager { + peer_manager: PM, us: Arc>, reader: Arc, + mut read_wake_receiver: mpsc::Receiver<()>, mut write_avail_receiver: mpsc::Receiver<()>, + ) where + PM::Target: APeerManager, + { // Create a waker to wake up poll_event_process, above let (event_waker, event_receiver) = mpsc::channel(1); tokio::spawn(Self::poll_event_process(peer_manager.clone(), event_receiver)); @@ -179,7 +196,7 @@ impl Connection { // closed. // In this case, we do need to call peer_manager.socket_disconnected() to inform // Rust-Lightning that the socket is gone. - PeerDisconnected + PeerDisconnected, } let disconnect_type = loop { let read_paused = { @@ -194,28 +211,34 @@ impl Connection { TwoSelector { a: Box::pin(write_avail_receiver.recv()), b: Box::pin(read_wake_receiver.recv()), - }.await + } + .await } else { ThreeSelector { a: Box::pin(write_avail_receiver.recv()), b: Box::pin(read_wake_receiver.recv()), c: Box::pin(reader.readable()), - }.await + } + .await }; match select_result { SelectorOutput::A(v) => { assert!(v.is_some()); // We can't have dropped the sending end, its in the us Arc! - if peer_manager.as_ref().write_buffer_space_avail(&mut our_descriptor).is_err() { + if peer_manager.as_ref().write_buffer_space_avail(&mut our_descriptor).is_err() + { break Disconnect::CloseConnection; } }, SelectorOutput::B(_) => {}, SelectorOutput::C(res) => { - if res.is_err() { break Disconnect::PeerDisconnected; } + if res.is_err() { + break Disconnect::PeerDisconnected; + } match reader.try_read(&mut buf) { Ok(0) => break Disconnect::PeerDisconnected, Ok(len) => { - let read_res = peer_manager.as_ref().read_event(&mut our_descriptor, &buf[0..len]); + let read_res = + peer_manager.as_ref().read_event(&mut our_descriptor, &buf[0..len]); let mut us_lock = us.lock().unwrap(); match read_res { Ok(pause_read) => { @@ -250,7 +273,9 @@ impl Connection { } } - fn new(stream: StdTcpStream) -> (Arc, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc>) { + fn new( + stream: StdTcpStream, + ) -> (Arc, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc>) { // We only ever need a channel of depth 1 here: if we returned a non-full write to the // PeerManager, we will eventually get notified that there is room in the socket to write // new bytes, which will generate an event. That event will be popped off the queue before @@ -264,25 +289,30 @@ impl Connection { stream.set_nonblocking(true).unwrap(); let tokio_stream = Arc::new(TcpStream::from_std(stream).unwrap()); - (Arc::clone(&tokio_stream), write_receiver, read_receiver, - Arc::new(Mutex::new(Self { - writer: Some(tokio_stream), write_avail, read_waker, read_paused: false, - rl_requested_disconnect: false, - id: ID_COUNTER.fetch_add(1, Ordering::AcqRel) - }))) + ( + Arc::clone(&tokio_stream), + write_receiver, + read_receiver, + Arc::new(Mutex::new(Self { + writer: Some(tokio_stream), + write_avail, + read_waker, + read_paused: false, + rl_requested_disconnect: false, + id: ID_COUNTER.fetch_add(1, Ordering::AcqRel), + })), + ) } } fn get_addr_from_stream(stream: &StdTcpStream) -> Option { match stream.peer_addr() { - Ok(SocketAddr::V4(sockaddr)) => Some(SocketAddress::TcpIpV4 { - addr: sockaddr.ip().octets(), - port: sockaddr.port(), - }), - Ok(SocketAddr::V6(sockaddr)) => Some(SocketAddress::TcpIpV6 { - addr: sockaddr.ip().octets(), - port: sockaddr.port(), - }), + Ok(SocketAddr::V4(sockaddr)) => { + Some(SocketAddress::TcpIpV4 { addr: sockaddr.ip().octets(), port: sockaddr.port() }) + }, + Ok(SocketAddr::V6(sockaddr)) => { + Some(SocketAddress::TcpIpV6 { addr: sockaddr.ip().octets(), port: sockaddr.port() }) + }, Err(_) => None, } } @@ -294,17 +324,28 @@ fn get_addr_from_stream(stream: &StdTcpStream) -> Option { /// futures are freed, though, because all processing futures are spawned with tokio::spawn, you do /// not need to poll the provided future in order to make progress. pub fn setup_inbound( - peer_manager: PM, - stream: StdTcpStream, -) -> impl std::future::Future -where PM::Target: APeerManager { + peer_manager: PM, stream: StdTcpStream, +) -> impl std::future::Future +where + PM::Target: APeerManager, +{ let remote_addr = get_addr_from_stream(&stream); let (reader, write_receiver, read_receiver, us) = Connection::new(stream); #[cfg(test)] let last_us = Arc::clone(&us); - let handle_opt = if peer_manager.as_ref().new_inbound_connection(SocketDescriptor::new(us.clone()), remote_addr).is_ok() { - Some(tokio::spawn(Connection::schedule_read(peer_manager, us, reader, read_receiver, write_receiver))) + let handle_opt = if peer_manager + .as_ref() + .new_inbound_connection(SocketDescriptor::new(us.clone()), remote_addr) + .is_ok() + { + Some(tokio::spawn(Connection::schedule_read( + peer_manager, + us, + reader, + read_receiver, + write_receiver, + ))) } else { // Note that we will skip socket_disconnected here, in accordance with the PeerManager // requirements. @@ -336,16 +377,20 @@ where PM::Target: APeerManager { /// futures are freed, though, because all processing futures are spawned with tokio::spawn, you do /// not need to poll the provided future in order to make progress. pub fn setup_outbound( - peer_manager: PM, - their_node_id: PublicKey, - stream: StdTcpStream, -) -> impl std::future::Future -where PM::Target: APeerManager { + peer_manager: PM, their_node_id: PublicKey, stream: StdTcpStream, +) -> impl std::future::Future +where + PM::Target: APeerManager, +{ let remote_addr = get_addr_from_stream(&stream); let (reader, mut write_receiver, read_receiver, us) = Connection::new(stream); #[cfg(test)] let last_us = Arc::clone(&us); - let handle_opt = if let Ok(initial_send) = peer_manager.as_ref().new_outbound_connection(their_node_id, SocketDescriptor::new(us.clone()), remote_addr) { + let handle_opt = if let Ok(initial_send) = peer_manager.as_ref().new_outbound_connection( + their_node_id, + SocketDescriptor::new(us.clone()), + remote_addr, + ) { Some(tokio::spawn(async move { // We should essentially always have enough room in a TCP socket buffer to send the // initial 10s of bytes. However, tokio running in single-threaded mode will always @@ -364,13 +409,18 @@ where PM::Target: APeerManager { }, _ => { eprintln!("Failed to write first full message to socket!"); - peer_manager.as_ref().socket_disconnected(&SocketDescriptor::new(Arc::clone(&us))); + peer_manager + .as_ref() + .socket_disconnected(&SocketDescriptor::new(Arc::clone(&us))); break Err(()); - } + }, } } - }).await { - Connection::schedule_read(peer_manager, us, reader, read_receiver, write_receiver).await; + }) + .await + { + Connection::schedule_read(peer_manager, us, reader, read_receiver, write_receiver) + .await; } })) } else { @@ -408,18 +458,28 @@ where PM::Target: APeerManager { /// futures are spawned with tokio::spawn, you do not need to poll the second future in order to /// make progress. pub async fn connect_outbound( - peer_manager: PM, - their_node_id: PublicKey, - addr: SocketAddr, -) -> Option> -where PM::Target: APeerManager { - if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), async { TcpStream::connect(&addr).await.map(|s| s.into_std().unwrap()) }).await { + peer_manager: PM, their_node_id: PublicKey, addr: SocketAddr, +) -> Option> +where + PM::Target: APeerManager, +{ + if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), async { + TcpStream::connect(&addr).await.map(|s| s.into_std().unwrap()) + }) + .await + { Some(setup_outbound(peer_manager, their_node_id, stream)) - } else { None } + } else { + None + } } -const SOCK_WAKER_VTABLE: task::RawWakerVTable = - task::RawWakerVTable::new(clone_socket_waker, wake_socket_waker, wake_socket_waker_by_ref, drop_socket_waker); +const SOCK_WAKER_VTABLE: task::RawWakerVTable = task::RawWakerVTable::new( + clone_socket_waker, + wake_socket_waker, + wake_socket_waker_by_ref, + drop_socket_waker, +); fn clone_socket_waker(orig_ptr: *const ()) -> task::RawWaker { write_avail_to_waker(orig_ptr as *const mpsc::Sender<()>) @@ -479,7 +539,9 @@ impl peer_handler::SocketDescriptor for SocketDescriptor { us.read_paused = false; let _ = us.read_waker.try_send(()); } - if data.is_empty() { return 0; } + if data.is_empty() { + return 0; + } let waker = unsafe { task::Waker::from_raw(write_avail_to_waker(&us.write_avail)) }; let mut ctx = task::Context::from_waker(&waker); let mut written_len = 0; @@ -490,7 +552,9 @@ impl peer_handler::SocketDescriptor for SocketDescriptor { Ok(res) => { debug_assert_ne!(res, 0); written_len += res; - if written_len == data.len() { return written_len; } + if written_len == data.len() { + return written_len; + } }, Err(_) => return written_len, } @@ -519,10 +583,7 @@ impl peer_handler::SocketDescriptor for SocketDescriptor { } impl Clone for SocketDescriptor { fn clone(&self) -> Self { - Self { - conn: Arc::clone(&self.conn), - id: self.id, - } + Self { conn: Arc::clone(&self.conn), id: self.id } } } impl Eq for SocketDescriptor {} @@ -539,16 +600,16 @@ impl Hash for SocketDescriptor { #[cfg(test)] mod tests { + use bitcoin::blockdata::constants::ChainHash; + use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; + use bitcoin::Network; + use lightning::events::*; + use lightning::ln::features::NodeFeatures; use lightning::ln::features::*; use lightning::ln::msgs::*; use lightning::ln::peer_handler::{MessageHandler, PeerManager}; - use lightning::ln::features::NodeFeatures; use lightning::routing::gossip::NodeId; - use lightning::events::*; use lightning::util::test_utils::TestNodeSigner; - use bitcoin::Network; - use bitcoin::blockdata::constants::ChainHash; - use bitcoin::secp256k1::{Secp256k1, SecretKey, PublicKey}; use tokio::sync::mpsc; @@ -560,11 +621,18 @@ mod tests { pub struct TestLogger(); impl lightning::util::logger::Logger for TestLogger { fn log(&self, record: &lightning::util::logger::Record) { - println!("{:<5} [{} : {}, {}] {}", record.level.to_string(), record.module_path, record.file, record.line, record.args); + println!( + "{:<5} [{} : {}, {}] {}", + record.level.to_string(), + record.module_path, + record.file, + record.line, + record.args + ); } } - struct MsgHandler{ + struct MsgHandler { expected_pubkey: PublicKey, pubkey_connected: mpsc::Sender<()>, pubkey_disconnected: mpsc::Sender<()>, @@ -572,19 +640,63 @@ mod tests { msg_events: Mutex>, } impl RoutingMessageHandler for MsgHandler { - fn handle_node_announcement(&self, _msg: &NodeAnnouncement) -> Result { Ok(false) } - fn handle_channel_announcement(&self, _msg: &ChannelAnnouncement) -> Result { Ok(false) } - fn handle_channel_update(&self, _msg: &ChannelUpdate) -> Result { Ok(false) } - fn get_next_channel_announcement(&self, _starting_point: u64) -> Option<(ChannelAnnouncement, Option, Option)> { None } - fn get_next_node_announcement(&self, _starting_point: Option<&NodeId>) -> Option { None } - fn peer_connected(&self, _their_node_id: &PublicKey, _init_msg: &Init, _inbound: bool) -> Result<(), ()> { Ok(()) } - fn handle_reply_channel_range(&self, _their_node_id: &PublicKey, _msg: ReplyChannelRange) -> Result<(), LightningError> { Ok(()) } - fn handle_reply_short_channel_ids_end(&self, _their_node_id: &PublicKey, _msg: ReplyShortChannelIdsEnd) -> Result<(), LightningError> { Ok(()) } - fn handle_query_channel_range(&self, _their_node_id: &PublicKey, _msg: QueryChannelRange) -> Result<(), LightningError> { Ok(()) } - fn handle_query_short_channel_ids(&self, _their_node_id: &PublicKey, _msg: QueryShortChannelIds) -> Result<(), LightningError> { Ok(()) } - fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() } - fn provided_init_features(&self, _their_node_id: &PublicKey) -> InitFeatures { InitFeatures::empty() } - fn processing_queue_high(&self) -> bool { false } + fn handle_node_announcement( + &self, _msg: &NodeAnnouncement, + ) -> Result { + Ok(false) + } + fn handle_channel_announcement( + &self, _msg: &ChannelAnnouncement, + ) -> Result { + Ok(false) + } + fn handle_channel_update(&self, _msg: &ChannelUpdate) -> Result { + Ok(false) + } + fn get_next_channel_announcement( + &self, _starting_point: u64, + ) -> Option<(ChannelAnnouncement, Option, Option)> { + None + } + fn get_next_node_announcement( + &self, _starting_point: Option<&NodeId>, + ) -> Option { + None + } + fn peer_connected( + &self, _their_node_id: &PublicKey, _init_msg: &Init, _inbound: bool, + ) -> Result<(), ()> { + Ok(()) + } + fn handle_reply_channel_range( + &self, _their_node_id: &PublicKey, _msg: ReplyChannelRange, + ) -> Result<(), LightningError> { + Ok(()) + } + fn handle_reply_short_channel_ids_end( + &self, _their_node_id: &PublicKey, _msg: ReplyShortChannelIdsEnd, + ) -> Result<(), LightningError> { + Ok(()) + } + fn handle_query_channel_range( + &self, _their_node_id: &PublicKey, _msg: QueryChannelRange, + ) -> Result<(), LightningError> { + Ok(()) + } + fn handle_query_short_channel_ids( + &self, _their_node_id: &PublicKey, _msg: QueryShortChannelIds, + ) -> Result<(), LightningError> { + Ok(()) + } + fn provided_node_features(&self) -> NodeFeatures { + NodeFeatures::empty() + } + fn provided_init_features(&self, _their_node_id: &PublicKey) -> InitFeatures { + InitFeatures::empty() + } + fn processing_queue_high(&self) -> bool { + false + } } impl ChannelMessageHandler for MsgHandler { fn handle_open_channel(&self, _their_node_id: &PublicKey, _msg: &OpenChannel) {} @@ -595,13 +707,20 @@ mod tests { fn handle_shutdown(&self, _their_node_id: &PublicKey, _msg: &Shutdown) {} fn handle_closing_signed(&self, _their_node_id: &PublicKey, _msg: &ClosingSigned) {} fn handle_update_add_htlc(&self, _their_node_id: &PublicKey, _msg: &UpdateAddHTLC) {} - fn handle_update_fulfill_htlc(&self, _their_node_id: &PublicKey, _msg: &UpdateFulfillHTLC) {} + fn handle_update_fulfill_htlc(&self, _their_node_id: &PublicKey, _msg: &UpdateFulfillHTLC) { + } fn handle_update_fail_htlc(&self, _their_node_id: &PublicKey, _msg: &UpdateFailHTLC) {} - fn handle_update_fail_malformed_htlc(&self, _their_node_id: &PublicKey, _msg: &UpdateFailMalformedHTLC) {} + fn handle_update_fail_malformed_htlc( + &self, _their_node_id: &PublicKey, _msg: &UpdateFailMalformedHTLC, + ) { + } fn handle_commitment_signed(&self, _their_node_id: &PublicKey, _msg: &CommitmentSigned) {} fn handle_revoke_and_ack(&self, _their_node_id: &PublicKey, _msg: &RevokeAndACK) {} fn handle_update_fee(&self, _their_node_id: &PublicKey, _msg: &UpdateFee) {} - fn handle_announcement_signatures(&self, _their_node_id: &PublicKey, _msg: &AnnouncementSignatures) {} + fn handle_announcement_signatures( + &self, _their_node_id: &PublicKey, _msg: &AnnouncementSignatures, + ) { + } fn handle_channel_update(&self, _their_node_id: &PublicKey, _msg: &ChannelUpdate) {} fn handle_open_channel_v2(&self, _their_node_id: &PublicKey, _msg: &OpenChannelV2) {} fn handle_accept_channel_v2(&self, _their_node_id: &PublicKey, _msg: &AcceptChannelV2) {} @@ -620,16 +739,25 @@ mod tests { self.pubkey_disconnected.clone().try_send(()).unwrap(); } } - fn peer_connected(&self, their_node_id: &PublicKey, _init_msg: &Init, _inbound: bool) -> Result<(), ()> { + fn peer_connected( + &self, their_node_id: &PublicKey, _init_msg: &Init, _inbound: bool, + ) -> Result<(), ()> { if *their_node_id == self.expected_pubkey { self.pubkey_connected.clone().try_send(()).unwrap(); } Ok(()) } - fn handle_channel_reestablish(&self, _their_node_id: &PublicKey, _msg: &ChannelReestablish) {} + fn handle_channel_reestablish( + &self, _their_node_id: &PublicKey, _msg: &ChannelReestablish, + ) { + } fn handle_error(&self, _their_node_id: &PublicKey, _msg: &ErrorMessage) {} - fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() } - fn provided_init_features(&self, _their_node_id: &PublicKey) -> InitFeatures { InitFeatures::empty() } + fn provided_node_features(&self) -> NodeFeatures { + NodeFeatures::empty() + } + fn provided_init_features(&self, _their_node_id: &PublicKey) -> InitFeatures { + InitFeatures::empty() + } fn get_chain_hashes(&self) -> Option> { Some(vec![ChainHash::using_genesis_block(Network::Testnet)]) } @@ -655,7 +783,9 @@ mod tests { (std::net::TcpStream::connect("127.0.0.1:9999").unwrap(), listener.accept().unwrap().0) } else if let Ok(listener) = std::net::TcpListener::bind("127.0.0.1:46926") { (std::net::TcpStream::connect("127.0.0.1:46926").unwrap(), listener.accept().unwrap().0) - } else { panic!("Failed to bind to v4 localhost on common ports"); } + } else { + panic!("Failed to bind to v4 localhost on common ports"); + } } async fn do_basic_connection_test() { @@ -674,12 +804,22 @@ mod tests { disconnected_flag: AtomicBool::new(false), msg_events: Mutex::new(Vec::new()), }); - let a_manager = Arc::new(PeerManager::new(MessageHandler { - chan_handler: Arc::clone(&a_handler), - route_handler: Arc::clone(&a_handler), - onion_message_handler: Arc::new(lightning::ln::peer_handler::IgnoringMessageHandler{}), - custom_message_handler: Arc::new(lightning::ln::peer_handler::IgnoringMessageHandler{}), - }, 0, &[1; 32], Arc::new(TestLogger()), Arc::new(TestNodeSigner::new(a_key)))); + let a_manager = Arc::new(PeerManager::new( + MessageHandler { + chan_handler: Arc::clone(&a_handler), + route_handler: Arc::clone(&a_handler), + onion_message_handler: Arc::new( + lightning::ln::peer_handler::IgnoringMessageHandler {}, + ), + custom_message_handler: Arc::new( + lightning::ln::peer_handler::IgnoringMessageHandler {}, + ), + }, + 0, + &[1; 32], + Arc::new(TestLogger()), + Arc::new(TestNodeSigner::new(a_key)), + )); let (b_connected_sender, mut b_connected) = mpsc::channel(1); let (b_disconnected_sender, mut b_disconnected) = mpsc::channel(1); @@ -690,12 +830,22 @@ mod tests { disconnected_flag: AtomicBool::new(false), msg_events: Mutex::new(Vec::new()), }); - let b_manager = Arc::new(PeerManager::new(MessageHandler { - chan_handler: Arc::clone(&b_handler), - route_handler: Arc::clone(&b_handler), - onion_message_handler: Arc::new(lightning::ln::peer_handler::IgnoringMessageHandler{}), - custom_message_handler: Arc::new(lightning::ln::peer_handler::IgnoringMessageHandler{}), - }, 0, &[2; 32], Arc::new(TestLogger()), Arc::new(TestNodeSigner::new(b_key)))); + let b_manager = Arc::new(PeerManager::new( + MessageHandler { + chan_handler: Arc::clone(&b_handler), + route_handler: Arc::clone(&b_handler), + onion_message_handler: Arc::new( + lightning::ln::peer_handler::IgnoringMessageHandler {}, + ), + custom_message_handler: Arc::new( + lightning::ln::peer_handler::IgnoringMessageHandler {}, + ), + }, + 0, + &[2; 32], + Arc::new(TestLogger()), + Arc::new(TestNodeSigner::new(b_key)), + )); // We bind on localhost, hoping the environment is properly configured with a local // address. This may not always be the case in containers and the like, so if this test is @@ -710,7 +860,8 @@ mod tests { tokio::time::timeout(Duration::from_secs(1), b_connected.recv()).await.unwrap(); a_handler.msg_events.lock().unwrap().push(MessageSendEvent::HandleError { - node_id: b_pub, action: ErrorAction::DisconnectPeer { msg: None } + node_id: b_pub, + action: ErrorAction::DisconnectPeer { msg: None }, }); assert!(!a_handler.disconnected_flag.load(Ordering::SeqCst)); assert!(!b_handler.disconnected_flag.load(Ordering::SeqCst)); @@ -744,12 +895,22 @@ mod tests { let b_key = SecretKey::from_slice(&[2; 32]).unwrap(); let b_pub = PublicKey::from_secret_key(&secp_ctx, &b_key); - let a_manager = Arc::new(PeerManager::new(MessageHandler { - chan_handler: Arc::new(lightning::ln::peer_handler::ErroringMessageHandler::new()), - onion_message_handler: Arc::new(lightning::ln::peer_handler::IgnoringMessageHandler{}), - route_handler: Arc::new(lightning::ln::peer_handler::IgnoringMessageHandler{}), - custom_message_handler: Arc::new(lightning::ln::peer_handler::IgnoringMessageHandler{}), - }, 0, &[1; 32], Arc::new(TestLogger()), Arc::new(TestNodeSigner::new(a_key)))); + let a_manager = Arc::new(PeerManager::new( + MessageHandler { + chan_handler: Arc::new(lightning::ln::peer_handler::ErroringMessageHandler::new()), + onion_message_handler: Arc::new( + lightning::ln::peer_handler::IgnoringMessageHandler {}, + ), + route_handler: Arc::new(lightning::ln::peer_handler::IgnoringMessageHandler {}), + custom_message_handler: Arc::new( + lightning::ln::peer_handler::IgnoringMessageHandler {}, + ), + }, + 0, + &[1; 32], + Arc::new(TestLogger()), + Arc::new(TestNodeSigner::new(a_key)), + )); // Make two connections, one for an inbound and one for an outbound connection let conn_a = { @@ -763,12 +924,8 @@ mod tests { // Call connection setup inside new tokio tasks. let manager_reference = Arc::clone(&a_manager); - tokio::spawn(async move { - super::setup_inbound(manager_reference, conn_a).await - }); - tokio::spawn(async move { - super::setup_outbound(a_manager, b_pub, conn_b).await - }); + tokio::spawn(async move { super::setup_inbound(manager_reference, conn_a).await }); + tokio::spawn(async move { super::setup_outbound(a_manager, b_pub, conn_b).await }); } #[tokio::test(flavor = "multi_thread")] diff --git a/lightning-persister/src/fs_store.rs b/lightning-persister/src/fs_store.rs index c665d8083cb..8216a935c94 100644 --- a/lightning-persister/src/fs_store.rs +++ b/lightning-persister/src/fs_store.rs @@ -67,7 +67,9 @@ impl FilesystemStore { } } - fn get_dest_dir_path(&self, primary_namespace: &str, secondary_namespace: &str) -> std::io::Result { + fn get_dest_dir_path( + &self, primary_namespace: &str, secondary_namespace: &str, + ) -> std::io::Result { let mut dest_dir_path = { #[cfg(target_os = "windows")] { @@ -91,7 +93,9 @@ impl FilesystemStore { } impl KVStore for FilesystemStore { - fn read(&self, primary_namespace: &str, secondary_namespace: &str, key: &str) -> std::io::Result> { + fn read( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> std::io::Result> { check_namespace_key_validity(primary_namespace, secondary_namespace, Some(key), "read")?; let mut dest_file_path = self.get_dest_dir_path(primary_namespace, secondary_namespace)?; @@ -114,19 +118,19 @@ impl KVStore for FilesystemStore { Ok(buf) } - fn write(&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: &[u8]) -> std::io::Result<()> { + fn write( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: &[u8], + ) -> std::io::Result<()> { check_namespace_key_validity(primary_namespace, secondary_namespace, Some(key), "write")?; let mut dest_file_path = self.get_dest_dir_path(primary_namespace, secondary_namespace)?; dest_file_path.push(key); - let parent_directory = dest_file_path - .parent() - .ok_or_else(|| { - let msg = - format!("Could not retrieve parent directory of {}.", dest_file_path.display()); - std::io::Error::new(std::io::ErrorKind::InvalidInput, msg) - })?; + let parent_directory = dest_file_path.parent().ok_or_else(|| { + let msg = + format!("Could not retrieve parent directory of {}.", dest_file_path.display()); + std::io::Error::new(std::io::ErrorKind::InvalidInput, msg) + })?; fs::create_dir_all(&parent_directory)?; // Do a crazy dance with lots of fsync()s to be overly cautious here... @@ -186,11 +190,11 @@ impl KVStore for FilesystemStore { match res { Ok(()) => { // We fsync the dest file in hopes this will also flush the metadata to disk. - let dest_file = fs::OpenOptions::new().read(true).write(true) - .open(&dest_file_path)?; + let dest_file = + fs::OpenOptions::new().read(true).write(true).open(&dest_file_path)?; dest_file.sync_all()?; Ok(()) - } + }, Err(e) => Err(e), } } @@ -201,7 +205,9 @@ impl KVStore for FilesystemStore { res } - fn remove(&self, primary_namespace: &str, secondary_namespace: &str, key: &str, lazy: bool) -> std::io::Result<()> { + fn remove( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, lazy: bool, + ) -> std::io::Result<()> { check_namespace_key_validity(primary_namespace, secondary_namespace, Some(key), "remove")?; let mut dest_file_path = self.get_dest_dir_path(primary_namespace, secondary_namespace)?; @@ -229,8 +235,10 @@ impl KVStore for FilesystemStore { fs::remove_file(&dest_file_path)?; let parent_directory = dest_file_path.parent().ok_or_else(|| { - let msg = - format!("Could not retrieve parent directory of {}.", dest_file_path.display()); + let msg = format!( + "Could not retrieve parent directory of {}.", + dest_file_path.display() + ); std::io::Error::new(std::io::ErrorKind::InvalidInput, msg) })?; let dir_file = fs::OpenOptions::new().read(true).open(parent_directory)?; @@ -257,8 +265,8 @@ impl KVStore for FilesystemStore { // However, all this is partially based on assumptions and local experiments, as // Windows API is horribly underdocumented. let mut trash_file_path = dest_file_path.clone(); - let trash_file_ext = format!("{}.trash", - self.tmp_file_counter.fetch_add(1, Ordering::AcqRel)); + let trash_file_ext = + format!("{}.trash", self.tmp_file_counter.fetch_add(1, Ordering::AcqRel)); trash_file_path.set_extension(trash_file_ext); call!(unsafe { @@ -273,7 +281,9 @@ impl KVStore for FilesystemStore { { // We fsync the trash file in hopes this will also flush the original's file // metadata to disk. - let trash_file = fs::OpenOptions::new().read(true).write(true) + let trash_file = fs::OpenOptions::new() + .read(true) + .write(true) .open(&trash_file_path.clone())?; trash_file.sync_all()?; } @@ -290,7 +300,9 @@ impl KVStore for FilesystemStore { Ok(()) } - fn list(&self, primary_namespace: &str, secondary_namespace: &str) -> std::io::Result> { + fn list( + &self, primary_namespace: &str, secondary_namespace: &str, + ) -> std::io::Result> { check_namespace_key_validity(primary_namespace, secondary_namespace, None, "list")?; let prefixed_dest = self.get_dest_dir_path(primary_namespace, secondary_namespace)?; @@ -327,10 +339,17 @@ impl KVStore for FilesystemStore { // If we otherwise don't find a file at the given path something went wrong. if !metadata.is_file() { - debug_assert!(false, "Failed to list keys of {}/{}: file couldn't be accessed.", - PrintableString(primary_namespace), PrintableString(secondary_namespace)); - let msg = format!("Failed to list keys of {}/{}: file couldn't be accessed.", - PrintableString(primary_namespace), PrintableString(secondary_namespace)); + debug_assert!( + false, + "Failed to list keys of {}/{}: file couldn't be accessed.", + PrintableString(primary_namespace), + PrintableString(secondary_namespace) + ); + let msg = format!( + "Failed to list keys of {}/{}: file couldn't be accessed.", + PrintableString(primary_namespace), + PrintableString(secondary_namespace) + ); return Err(std::io::Error::new(std::io::ErrorKind::Other, msg)); } @@ -341,20 +360,36 @@ impl KVStore for FilesystemStore { keys.push(relative_path.to_string()) } } else { - debug_assert!(false, "Failed to list keys of {}/{}: file path is not valid UTF-8", - PrintableString(primary_namespace), PrintableString(secondary_namespace)); - let msg = format!("Failed to list keys of {}/{}: file path is not valid UTF-8", - PrintableString(primary_namespace), PrintableString(secondary_namespace)); + debug_assert!( + false, + "Failed to list keys of {}/{}: file path is not valid UTF-8", + PrintableString(primary_namespace), + PrintableString(secondary_namespace) + ); + let msg = format!( + "Failed to list keys of {}/{}: file path is not valid UTF-8", + PrintableString(primary_namespace), + PrintableString(secondary_namespace) + ); return Err(std::io::Error::new(std::io::ErrorKind::Other, msg)); } - } + }, Err(e) => { - debug_assert!(false, "Failed to list keys of {}/{}: {}", - PrintableString(primary_namespace), PrintableString(secondary_namespace), e); - let msg = format!("Failed to list keys of {}/{}: {}", - PrintableString(primary_namespace), PrintableString(secondary_namespace), e); + debug_assert!( + false, + "Failed to list keys of {}/{}: {}", + PrintableString(primary_namespace), + PrintableString(secondary_namespace), + e + ); + let msg = format!( + "Failed to list keys of {}/{}: {}", + PrintableString(primary_namespace), + PrintableString(secondary_namespace), + e + ); return Err(std::io::Error::new(std::io::ErrorKind::Other, msg)); - } + }, } } @@ -372,20 +407,17 @@ mod tests { use bitcoin::hashes::hex::FromHex; use bitcoin::Txid; - use lightning::chain::ChannelMonitorUpdateStatus; use lightning::chain::chainmonitor::Persist; use lightning::chain::transaction::OutPoint; + use lightning::chain::ChannelMonitorUpdateStatus; use lightning::check_closed_event; use lightning::events::{ClosureReason, MessageSendEventsProvider}; use lightning::ln::functional_test_utils::*; - use lightning::util::test_utils; use lightning::util::persist::read_channel_monitors; + use lightning::util::test_utils; use std::fs; #[cfg(target_os = "windows")] - use { - lightning::get_event_msg, - lightning::ln::msgs::ChannelMessageHandler, - }; + use {lightning::get_event_msg, lightning::ln::msgs::ChannelMessageHandler}; impl Drop for FilesystemStore { fn drop(&mut self) { @@ -393,7 +425,7 @@ mod tests { // fails. match fs::remove_dir_all(&self.data_dir) { Err(e) => println!("Failed to remove test persister directory: {}", e), - _ => {} + _ => {}, } } } @@ -417,14 +449,23 @@ mod tests { let chanmon_cfgs = create_chanmon_cfgs(1); let mut node_cfgs = create_node_cfgs(1, &chanmon_cfgs); - let chain_mon_0 = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[0].chain_source), &chanmon_cfgs[0].tx_broadcaster, &chanmon_cfgs[0].logger, &chanmon_cfgs[0].fee_estimator, &store, node_cfgs[0].keys_manager); + let chain_mon_0 = test_utils::TestChainMonitor::new( + Some(&chanmon_cfgs[0].chain_source), + &chanmon_cfgs[0].tx_broadcaster, + &chanmon_cfgs[0].logger, + &chanmon_cfgs[0].fee_estimator, + &store, + node_cfgs[0].keys_manager, + ); node_cfgs[0].chain_monitor = chain_mon_0; let node_chanmgrs = create_node_chanmgrs(1, &node_cfgs, &[None]); let nodes = create_network(1, &node_cfgs, &node_chanmgrs); // Check that read_channel_monitors() returns error if monitors/ is not a // directory. - assert!(read_channel_monitors(&store, nodes[0].keys_manager, nodes[0].keys_manager).is_err()); + assert!( + read_channel_monitors(&store, nodes[0].keys_manager, nodes[0].keys_manager).is_err() + ); } #[test] @@ -451,8 +492,17 @@ mod tests { let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); let nodes = create_network(2, &node_cfgs, &node_chanmgrs); let chan = create_announced_chan_between_nodes(&nodes, 0, 1); - nodes[1].node.force_close_broadcasting_latest_txn(&chan.2, &nodes[0].node.get_our_node_id()).unwrap(); - check_closed_event!(nodes[1], 1, ClosureReason::HolderForceClosed, [nodes[0].node.get_our_node_id()], 100000); + nodes[1] + .node + .force_close_broadcasting_latest_txn(&chan.2, &nodes[0].node.get_our_node_id()) + .unwrap(); + check_closed_event!( + nodes[1], + 1, + ClosureReason::HolderForceClosed, + [nodes[0].node.get_our_node_id()], + 100000 + ); let mut added_monitors = nodes[1].chain_monitor.added_monitors.lock().unwrap(); let update_map = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap(); let update_id = update_map.get(&added_monitors[0].0.to_channel_id()).unwrap(); @@ -466,12 +516,15 @@ mod tests { fs::set_permissions(path, perms).unwrap(); let test_txo = OutPoint { - txid: Txid::from_hex("8984484a580b825b9972d7adb15050b3ab624ccd731946b3eeddb92f4e7ef6be").unwrap(), - index: 0 + txid: Txid::from_hex( + "8984484a580b825b9972d7adb15050b3ab624ccd731946b3eeddb92f4e7ef6be", + ) + .unwrap(), + index: 0, }; match store.persist_new_channel(test_txo, &added_monitors[0].1, update_id.2) { ChannelMonitorUpdateStatus::UnrecoverableError => {}, - _ => panic!("unexpected result from persisting new channel") + _ => panic!("unexpected result from persisting new channel"), } nodes[1].node.get_and_clear_pending_msg_events(); @@ -490,8 +543,17 @@ mod tests { let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); let nodes = create_network(2, &node_cfgs, &node_chanmgrs); let chan = create_announced_chan_between_nodes(&nodes, 0, 1); - nodes[1].node.force_close_broadcasting_latest_txn(&chan.2, &nodes[0].node.get_our_node_id()).unwrap(); - check_closed_event!(nodes[1], 1, ClosureReason::HolderForceClosed, [nodes[0].node.get_our_node_id()], 100000); + nodes[1] + .node + .force_close_broadcasting_latest_txn(&chan.2, &nodes[0].node.get_our_node_id()) + .unwrap(); + check_closed_event!( + nodes[1], + 1, + ClosureReason::HolderForceClosed, + [nodes[0].node.get_our_node_id()], + 100000 + ); let mut added_monitors = nodes[1].chain_monitor.added_monitors.lock().unwrap(); let update_map = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap(); let update_id = update_map.get(&added_monitors[0].0.to_channel_id()).unwrap(); @@ -503,12 +565,15 @@ mod tests { let store = FilesystemStore::new(":<>/".into()); let test_txo = OutPoint { - txid: Txid::from_hex("8984484a580b825b9972d7adb15050b3ab624ccd731946b3eeddb92f4e7ef6be").unwrap(), - index: 0 + txid: Txid::from_hex( + "8984484a580b825b9972d7adb15050b3ab624ccd731946b3eeddb92f4e7ef6be", + ) + .unwrap(), + index: 0, }; match store.persist_new_channel(test_txo, &added_monitors[0].1, update_id.2) { ChannelMonitorUpdateStatus::UnrecoverableError => {}, - _ => panic!("unexpected result from persisting new channel") + _ => panic!("unexpected result from persisting new channel"), } nodes[1].node.get_and_clear_pending_msg_events(); @@ -526,6 +591,10 @@ pub mod bench { let store_a = super::FilesystemStore::new("bench_filesystem_store_a".into()); let store_b = super::FilesystemStore::new("bench_filesystem_store_b".into()); lightning::ln::channelmanager::bench::bench_two_sends( - bench, "bench_filesystem_persisted_sends", store_a, store_b); + bench, + "bench_filesystem_persisted_sends", + store_a, + store_b, + ); } } diff --git a/lightning-persister/src/lib.rs b/lightning-persister/src/lib.rs index ae258e137d7..ba6738f1272 100644 --- a/lightning-persister/src/lib.rs +++ b/lightning-persister/src/lib.rs @@ -3,12 +3,11 @@ // TODO: Prefix these with `rustdoc::` when we update our MSRV to be >= 1.52 to remove warnings. #![deny(broken_intra_doc_links)] #![deny(private_intra_doc_links)] - #![deny(missing_docs)] - #![cfg_attr(docsrs, feature(doc_auto_cfg))] -#[cfg(ldk_bench)] extern crate criterion; +#[cfg(ldk_bench)] +extern crate criterion; pub mod fs_store; diff --git a/lightning-persister/src/test_utils.rs b/lightning-persister/src/test_utils.rs index 360fa3492bf..e861e986c67 100644 --- a/lightning-persister/src/test_utils.rs +++ b/lightning-persister/src/test_utils.rs @@ -1,11 +1,12 @@ -use lightning::util::persist::{KVStore, KVSTORE_NAMESPACE_KEY_MAX_LEN, read_channel_monitors}; -use lightning::ln::functional_test_utils::{connect_block, create_announced_chan_between_nodes, - create_chanmon_cfgs, create_dummy_block, create_network, create_node_cfgs, create_node_chanmgrs, - send_payment}; use lightning::chain::channelmonitor::CLOSED_CHANNEL_UPDATE_ID; -use lightning::util::test_utils; -use lightning::{check_closed_broadcast, check_closed_event, check_added_monitors}; use lightning::events::ClosureReason; +use lightning::ln::functional_test_utils::{ + connect_block, create_announced_chan_between_nodes, create_chanmon_cfgs, create_dummy_block, + create_network, create_node_cfgs, create_node_chanmgrs, send_payment, +}; +use lightning::util::persist::{read_channel_monitors, KVStore, KVSTORE_NAMESPACE_KEY_MAX_LEN}; +use lightning::util::test_utils; +use lightning::{check_added_monitors, check_closed_broadcast, check_closed_event}; use std::panic::RefUnwindSafe; @@ -24,7 +25,9 @@ pub(crate) fn do_read_write_remove_list_persist(kv_s kv_store.write("", "", key, &data).unwrap(); let res = std::panic::catch_unwind(|| kv_store.write("", secondary_namespace, key, &data)); assert!(res.is_err()); - let res = std::panic::catch_unwind(|| kv_store.write(primary_namespace, secondary_namespace, "", &data)); + let res = std::panic::catch_unwind(|| { + kv_store.write(primary_namespace, secondary_namespace, "", &data) + }); assert!(res.is_err()); let listed_keys = kv_store.list(primary_namespace, secondary_namespace).unwrap(); @@ -62,8 +65,22 @@ pub(crate) fn do_read_write_remove_list_persist(kv_s pub(crate) fn do_test_store(store_0: &K, store_1: &K) { let chanmon_cfgs = create_chanmon_cfgs(2); let mut node_cfgs = create_node_cfgs(2, &chanmon_cfgs); - let chain_mon_0 = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[0].chain_source), &chanmon_cfgs[0].tx_broadcaster, &chanmon_cfgs[0].logger, &chanmon_cfgs[0].fee_estimator, store_0, node_cfgs[0].keys_manager); - let chain_mon_1 = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[1].chain_source), &chanmon_cfgs[1].tx_broadcaster, &chanmon_cfgs[1].logger, &chanmon_cfgs[1].fee_estimator, store_1, node_cfgs[1].keys_manager); + let chain_mon_0 = test_utils::TestChainMonitor::new( + Some(&chanmon_cfgs[0].chain_source), + &chanmon_cfgs[0].tx_broadcaster, + &chanmon_cfgs[0].logger, + &chanmon_cfgs[0].fee_estimator, + store_0, + node_cfgs[0].keys_manager, + ); + let chain_mon_1 = test_utils::TestChainMonitor::new( + Some(&chanmon_cfgs[1].chain_source), + &chanmon_cfgs[1].tx_broadcaster, + &chanmon_cfgs[1].logger, + &chanmon_cfgs[1].fee_estimator, + store_1, + node_cfgs[1].keys_manager, + ); node_cfgs[0].chain_monitor = chain_mon_0; node_cfgs[1].chain_monitor = chain_mon_1; let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); @@ -71,25 +88,31 @@ pub(crate) fn do_test_store(store_0: &K, store_1: &K) { // Check that the persisted channel data is empty before any channels are // open. - let mut persisted_chan_data_0 = read_channel_monitors(store_0, nodes[0].keys_manager, nodes[0].keys_manager).unwrap(); + let mut persisted_chan_data_0 = + read_channel_monitors(store_0, nodes[0].keys_manager, nodes[0].keys_manager).unwrap(); assert_eq!(persisted_chan_data_0.len(), 0); - let mut persisted_chan_data_1 = read_channel_monitors(store_1, nodes[1].keys_manager, nodes[1].keys_manager).unwrap(); + let mut persisted_chan_data_1 = + read_channel_monitors(store_1, nodes[1].keys_manager, nodes[1].keys_manager).unwrap(); assert_eq!(persisted_chan_data_1.len(), 0); // Helper to make sure the channel is on the expected update ID. macro_rules! check_persisted_data { ($expected_update_id: expr) => { - persisted_chan_data_0 = read_channel_monitors(store_0, nodes[0].keys_manager, nodes[0].keys_manager).unwrap(); + persisted_chan_data_0 = + read_channel_monitors(store_0, nodes[0].keys_manager, nodes[0].keys_manager) + .unwrap(); assert_eq!(persisted_chan_data_0.len(), 1); for (_, mon) in persisted_chan_data_0.iter() { assert_eq!(mon.get_latest_update_id(), $expected_update_id); } - persisted_chan_data_1 = read_channel_monitors(store_1, nodes[1].keys_manager, nodes[1].keys_manager).unwrap(); + persisted_chan_data_1 = + read_channel_monitors(store_1, nodes[1].keys_manager, nodes[1].keys_manager) + .unwrap(); assert_eq!(persisted_chan_data_1.len(), 1); for (_, mon) in persisted_chan_data_1.iter() { assert_eq!(mon.get_latest_update_id(), $expected_update_id); } - } + }; } // Create some initial channel and check that a channel was persisted. @@ -97,24 +120,48 @@ pub(crate) fn do_test_store(store_0: &K, store_1: &K) { check_persisted_data!(0); // Send a few payments and make sure the monitors are updated to the latest. - send_payment(&nodes[0], &vec!(&nodes[1])[..], 8000000); + send_payment(&nodes[0], &vec![&nodes[1]][..], 8000000); check_persisted_data!(5); - send_payment(&nodes[1], &vec!(&nodes[0])[..], 4000000); + send_payment(&nodes[1], &vec![&nodes[0]][..], 4000000); check_persisted_data!(10); // Force close because cooperative close doesn't result in any persisted // updates. - nodes[0].node.force_close_broadcasting_latest_txn(&nodes[0].node.list_channels()[0].channel_id, &nodes[1].node.get_our_node_id()).unwrap(); - check_closed_event!(nodes[0], 1, ClosureReason::HolderForceClosed, [nodes[1].node.get_our_node_id()], 100000); + nodes[0] + .node + .force_close_broadcasting_latest_txn( + &nodes[0].node.list_channels()[0].channel_id, + &nodes[1].node.get_our_node_id(), + ) + .unwrap(); + check_closed_event!( + nodes[0], + 1, + ClosureReason::HolderForceClosed, + [nodes[1].node.get_our_node_id()], + 100000 + ); check_closed_broadcast!(nodes[0], true); check_added_monitors!(nodes[0], 1); let node_txn = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap(); assert_eq!(node_txn.len(), 1); - connect_block(&nodes[1], &create_dummy_block(nodes[0].best_block_hash(), 42, vec![node_txn[0].clone(), node_txn[0].clone()])); + connect_block( + &nodes[1], + &create_dummy_block(nodes[0].best_block_hash(), 42, vec![ + node_txn[0].clone(), + node_txn[0].clone(), + ]), + ); check_closed_broadcast!(nodes[1], true); - check_closed_event!(nodes[1], 1, ClosureReason::CommitmentTxConfirmed, [nodes[0].node.get_our_node_id()], 100000); + check_closed_event!( + nodes[1], + 1, + ClosureReason::CommitmentTxConfirmed, + [nodes[0].node.get_our_node_id()], + 100000 + ); check_added_monitors!(nodes[1], 1); // Make sure everything is persisted as expected after close. diff --git a/lightning-persister/src/utils.rs b/lightning-persister/src/utils.rs index 59a615937c9..6a7900856cc 100644 --- a/lightning-persister/src/utils.rs +++ b/lightning-persister/src/utils.rs @@ -1,20 +1,31 @@ use lightning::util::persist::{KVSTORE_NAMESPACE_KEY_ALPHABET, KVSTORE_NAMESPACE_KEY_MAX_LEN}; use lightning::util::string::PrintableString; - pub(crate) fn is_valid_kvstore_str(key: &str) -> bool { - key.len() <= KVSTORE_NAMESPACE_KEY_MAX_LEN && key.chars().all(|c| KVSTORE_NAMESPACE_KEY_ALPHABET.contains(c)) + key.len() <= KVSTORE_NAMESPACE_KEY_MAX_LEN + && key.chars().all(|c| KVSTORE_NAMESPACE_KEY_ALPHABET.contains(c)) } pub(crate) fn check_namespace_key_validity( - primary_namespace: &str, secondary_namespace: &str, key: Option<&str>, operation: &str) --> Result<(), std::io::Error> { + primary_namespace: &str, secondary_namespace: &str, key: Option<&str>, operation: &str, +) -> Result<(), std::io::Error> { if let Some(key) = key { if key.is_empty() { - debug_assert!(false, "Failed to {} {}/{}/{}: key may not be empty.", operation, - PrintableString(primary_namespace), PrintableString(secondary_namespace), PrintableString(key)); - let msg = format!("Failed to {} {}/{}/{}: key may not be empty.", operation, - PrintableString(primary_namespace), PrintableString(secondary_namespace), PrintableString(key)); + debug_assert!( + false, + "Failed to {} {}/{}/{}: key may not be empty.", + operation, + PrintableString(primary_namespace), + PrintableString(secondary_namespace), + PrintableString(key) + ); + let msg = format!( + "Failed to {} {}/{}/{}: key may not be empty.", + operation, + PrintableString(primary_namespace), + PrintableString(secondary_namespace), + PrintableString(key) + ); return Err(std::io::Error::new(std::io::ErrorKind::Other, msg)); } @@ -29,7 +40,10 @@ pub(crate) fn check_namespace_key_validity( return Err(std::io::Error::new(std::io::ErrorKind::Other, msg)); } - if !is_valid_kvstore_str(primary_namespace) || !is_valid_kvstore_str(secondary_namespace) || !is_valid_kvstore_str(key) { + if !is_valid_kvstore_str(primary_namespace) + || !is_valid_kvstore_str(secondary_namespace) + || !is_valid_kvstore_str(key) + { debug_assert!(false, "Failed to {} {}/{}/{}: primary namespace, secondary namespace, and key must be valid.", operation, PrintableString(primary_namespace), PrintableString(secondary_namespace), PrintableString(key)); @@ -49,10 +63,19 @@ pub(crate) fn check_namespace_key_validity( return Err(std::io::Error::new(std::io::ErrorKind::Other, msg)); } if !is_valid_kvstore_str(primary_namespace) || !is_valid_kvstore_str(secondary_namespace) { - debug_assert!(false, "Failed to {} {}/{}: primary namespace and secondary namespace must be valid.", - operation, PrintableString(primary_namespace), PrintableString(secondary_namespace)); - let msg = format!("Failed to {} {}/{}: primary namespace and secondary namespace must be valid.", - operation, PrintableString(primary_namespace), PrintableString(secondary_namespace)); + debug_assert!( + false, + "Failed to {} {}/{}: primary namespace and secondary namespace must be valid.", + operation, + PrintableString(primary_namespace), + PrintableString(secondary_namespace) + ); + let msg = format!( + "Failed to {} {}/{}: primary namespace and secondary namespace must be valid.", + operation, + PrintableString(primary_namespace), + PrintableString(secondary_namespace) + ); return Err(std::io::Error::new(std::io::ErrorKind::Other, msg)); } } diff --git a/lightning-rapid-gossip-sync/src/error.rs b/lightning-rapid-gossip-sync/src/error.rs index ffd6760f8a9..e80ec83253e 100644 --- a/lightning-rapid-gossip-sync/src/error.rs +++ b/lightning-rapid-gossip-sync/src/error.rs @@ -34,7 +34,9 @@ impl Debug for GraphSyncError { fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { match self { GraphSyncError::DecodeError(e) => f.write_fmt(format_args!("DecodeError: {:?}", e)), - GraphSyncError::LightningError(e) => f.write_fmt(format_args!("LightningError: {:?}", e)) + GraphSyncError::LightningError(e) => { + f.write_fmt(format_args!("LightningError: {:?}", e)) + }, } } } diff --git a/lightning-rapid-gossip-sync/src/lib.rs b/lightning-rapid-gossip-sync/src/lib.rs index 5a61be7990e..63b1978b0c8 100644 --- a/lightning-rapid-gossip-sync/src/lib.rs +++ b/lightning-rapid-gossip-sync/src/lib.rs @@ -1,7 +1,6 @@ // Prefix these with `rustdoc::` when we update our MSRV to be >= 1.52 to remove warnings. #![deny(broken_intra_doc_links)] #![deny(private_intra_doc_links)] - #![deny(missing_docs)] #![deny(unsafe_code)] #![deny(non_upper_case_globals)] @@ -64,15 +63,16 @@ #![cfg_attr(all(not(feature = "std"), not(test)), no_std)] -#[cfg(ldk_bench)] extern crate criterion; +#[cfg(ldk_bench)] +extern crate criterion; #[cfg(not(feature = "std"))] extern crate alloc; -#[cfg(feature = "std")] -use std::fs::File; use core::ops::Deref; use core::sync::atomic::{AtomicBool, Ordering}; +#[cfg(feature = "std")] +use std::fs::File; use lightning::io; use lightning::routing::gossip::NetworkGraph; @@ -91,21 +91,22 @@ mod processing; /// See [crate-level documentation] for usage. /// /// [crate-level documentation]: crate -pub struct RapidGossipSync>, L: Deref> -where L::Target: Logger { +pub struct RapidGossipSync>, L: Deref> +where + L::Target: Logger, +{ network_graph: NG, logger: L, - is_initial_sync_complete: AtomicBool + is_initial_sync_complete: AtomicBool, } -impl>, L: Deref> RapidGossipSync where L::Target: Logger { +impl>, L: Deref> RapidGossipSync +where + L::Target: Logger, +{ /// Instantiate a new [`RapidGossipSync`] instance. pub fn new(network_graph: NG, logger: L) -> Self { - Self { - network_graph, - logger, - is_initial_sync_complete: AtomicBool::new(false) - } + Self { network_graph, logger, is_initial_sync_complete: AtomicBool::new(false) } } /// Sync gossip data from a file. @@ -117,8 +118,7 @@ impl>, L: Deref> RapidGossipSync where L /// #[cfg(feature = "std")] pub fn sync_network_graph_with_file_path( - &self, - sync_path: &str, + &self, sync_path: &str, ) -> Result { let mut file = File::open(sync_path)?; self.update_network_graph_from_byte_stream(&mut file) @@ -139,7 +139,9 @@ impl>, L: Deref> RapidGossipSync where L /// /// `update_data`: `&[u8]` binary stream that comprises the update data /// `current_time_unix`: `Option` optional current timestamp to verify data age - pub fn update_network_graph_no_std(&self, update_data: &[u8], current_time_unix: Option) -> Result { + pub fn update_network_graph_no_std( + &self, update_data: &[u8], current_time_unix: Option, + ) -> Result { let mut read_cursor = io::Cursor::new(update_data); self.update_network_graph_from_byte_stream_no_std(&mut read_cursor, current_time_unix) } @@ -165,10 +167,10 @@ mod tests { use bitcoin::Network; + use crate::RapidGossipSync; use lightning::ln::msgs::DecodeError; use lightning::routing::gossip::NetworkGraph; use lightning::util::test_utils::TestLogger; - use crate::RapidGossipSync; #[test] fn test_sync_from_file() { @@ -264,9 +266,10 @@ mod tests { let rapid_sync = RapidGossipSync::new(&network_graph, &logger); let start = std::time::Instant::now(); - let sync_result = rapid_sync - .sync_network_graph_with_file_path("./res/full_graph.lngossip"); - if let Err(crate::error::GraphSyncError::DecodeError(DecodeError::Io(io_error))) = &sync_result { + let sync_result = rapid_sync.sync_network_graph_with_file_path("./res/full_graph.lngossip"); + if let Err(crate::error::GraphSyncError::DecodeError(DecodeError::Io(io_error))) = + &sync_result + { let error_string = format!("Input file lightning-rapid-gossip-sync/res/full_graph.lngossip is missing! Download it from https://bitcoin.ninja/ldk-compressed_graph-285cb27df79-2022-07-21.bin\n\n{:?}", io_error); #[cfg(not(require_route_graph_test))] { diff --git a/lightning-rapid-gossip-sync/src/processing.rs b/lightning-rapid-gossip-sync/src/processing.rs index d54f1329798..3b521843e4f 100644 --- a/lightning-rapid-gossip-sync/src/processing.rs +++ b/lightning-rapid-gossip-sync/src/processing.rs @@ -5,14 +5,12 @@ use core::sync::atomic::Ordering; use bitcoin::blockdata::constants::ChainHash; use bitcoin::secp256k1::PublicKey; -use lightning::ln::msgs::{ - DecodeError, ErrorAction, LightningError, UnsignedChannelUpdate, -}; +use lightning::io; +use lightning::ln::msgs::{DecodeError, ErrorAction, LightningError, UnsignedChannelUpdate}; use lightning::routing::gossip::NetworkGraph; use lightning::util::logger::Logger; -use lightning::{log_debug, log_warn, log_trace, log_given_level, log_gossip}; use lightning::util::ser::{BigSize, Readable}; -use lightning::io; +use lightning::{log_debug, log_given_level, log_gossip, log_trace, log_warn}; use crate::error::GraphSyncError; use crate::RapidGossipSync; @@ -21,7 +19,7 @@ use crate::RapidGossipSync; use std::time::{SystemTime, UNIX_EPOCH}; #[cfg(not(feature = "std"))] -use alloc::{vec::Vec, borrow::ToOwned}; +use alloc::{borrow::ToOwned, vec::Vec}; /// The purpose of this prefix is to identify the serialization format, should other rapid gossip /// sync formats arise in the future. @@ -37,11 +35,13 @@ const MAX_INITIAL_NODE_ID_VECTOR_CAPACITY: u32 = 50_000; /// suggestion. const STALE_RGS_UPDATE_AGE_LIMIT_SECS: u64 = 60 * 60 * 24 * 14; -impl>, L: Deref> RapidGossipSync where L::Target: Logger { +impl>, L: Deref> RapidGossipSync +where + L::Target: Logger, +{ #[cfg(feature = "std")] pub(crate) fn update_network_graph_from_byte_stream( - &self, - read_cursor: &mut R, + &self, read_cursor: &mut R, ) -> Result { #[allow(unused_mut, unused_assignments)] let mut current_time_unix = None; @@ -49,15 +49,18 @@ impl>, L: Deref> RapidGossipSync where L { // Note that many tests rely on being able to set arbitrarily old timestamps, thus we // disable this check during tests! - current_time_unix = Some(SystemTime::now().duration_since(UNIX_EPOCH).expect("Time must be > 1970").as_secs()); + current_time_unix = Some( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time must be > 1970") + .as_secs(), + ); } self.update_network_graph_from_byte_stream_no_std(read_cursor, current_time_unix) } pub(crate) fn update_network_graph_from_byte_stream_no_std( - &self, - mut read_cursor: &mut R, - current_time_unix: Option + &self, mut read_cursor: &mut R, current_time_unix: Option, ) -> Result { log_trace!(self.logger, "Processing RGS data..."); let mut prefix = [0u8; 4]; @@ -70,19 +73,24 @@ impl>, L: Deref> RapidGossipSync where L let chain_hash: ChainHash = Readable::read(read_cursor)?; let ng_chain_hash = self.network_graph.get_chain_hash(); if chain_hash != ng_chain_hash { - return Err( - LightningError { - err: "Rapid Gossip Sync data's chain hash does not match the network graph's".to_owned(), - action: ErrorAction::IgnoreError, - }.into() - ); + return Err(LightningError { + err: "Rapid Gossip Sync data's chain hash does not match the network graph's" + .to_owned(), + action: ErrorAction::IgnoreError, + } + .into()); } let latest_seen_timestamp: u32 = Readable::read(read_cursor)?; if let Some(time) = current_time_unix { - if (latest_seen_timestamp as u64) < time.saturating_sub(STALE_RGS_UPDATE_AGE_LIMIT_SECS) { - return Err(LightningError{err: "Rapid Gossip Sync data is more than two weeks old".to_owned(), action: ErrorAction::IgnoreError}.into()); + if (latest_seen_timestamp as u64) < time.saturating_sub(STALE_RGS_UPDATE_AGE_LIMIT_SECS) + { + return Err(LightningError { + err: "Rapid Gossip Sync data is more than two weeks old".to_owned(), + action: ErrorAction::IgnoreError, + } + .into()); } } @@ -108,9 +116,8 @@ impl>, L: Deref> RapidGossipSync where L // handle SCID let scid_delta: BigSize = Readable::read(read_cursor)?; - let short_channel_id = previous_scid - .checked_add(scid_delta.0) - .ok_or(DecodeError::InvalidValue)?; + let short_channel_id = + previous_scid.checked_add(scid_delta.0).ok_or(DecodeError::InvalidValue)?; previous_scid = short_channel_id; let node_id_1_index: BigSize = Readable::read(read_cursor)?; @@ -122,8 +129,12 @@ impl>, L: Deref> RapidGossipSync where L let node_id_1 = node_ids[node_id_1_index.0 as usize]; let node_id_2 = node_ids[node_id_2_index.0 as usize]; - log_gossip!(self.logger, "Adding channel {} from RGS announcement at {}", - short_channel_id, latest_seen_timestamp); + log_gossip!( + self.logger, + "Adding channel {} from RGS announcement at {}", + short_channel_id, + latest_seen_timestamp + ); let announcement_result = network_graph.add_channel_from_partial_announcement( short_channel_id, @@ -136,7 +147,11 @@ impl>, L: Deref> RapidGossipSync where L if let ErrorAction::IgnoreDuplicateGossip = lightning_error.action { // everything is fine, just a duplicate channel announcement } else { - log_warn!(self.logger, "Failed to process channel announcement: {:?}", lightning_error); + log_warn!( + self.logger, + "Failed to process channel announcement: {:?}", + lightning_error + ); return Err(lightning_error.into()); } } @@ -160,9 +175,8 @@ impl>, L: Deref> RapidGossipSync where L for _ in 0..update_count { let scid_delta: BigSize = Readable::read(read_cursor)?; - let short_channel_id = previous_scid - .checked_add(scid_delta.0) - .ok_or(DecodeError::InvalidValue)?; + let short_channel_id = + previous_scid.checked_add(scid_delta.0).ok_or(DecodeError::InvalidValue)?; previous_scid = short_channel_id; let channel_flags: u8 = Readable::read(read_cursor)?; @@ -188,15 +202,17 @@ impl>, L: Deref> RapidGossipSync where L if (channel_flags & 0b_1000_0000) != 0 { // incremental update, field flags will indicate mutated values let read_only_network_graph = network_graph.read_only(); - if let Some(directional_info) = - read_only_network_graph.channels().get(&short_channel_id) + if let Some(directional_info) = read_only_network_graph + .channels() + .get(&short_channel_id) .and_then(|channel| channel.get_directional_info(channel_flags)) { synthetic_update.cltv_expiry_delta = directional_info.cltv_expiry_delta; synthetic_update.htlc_minimum_msat = directional_info.htlc_minimum_msat; synthetic_update.htlc_maximum_msat = directional_info.htlc_maximum_msat; synthetic_update.fee_base_msat = directional_info.fees.base_msat; - synthetic_update.fee_proportional_millionths = directional_info.fees.proportional_millionths; + synthetic_update.fee_proportional_millionths = + directional_info.fees.proportional_millionths; } else { log_trace!(self.logger, "Skipping application of channel update for chan {} with flags {} as original data is missing.", @@ -234,13 +250,23 @@ impl>, L: Deref> RapidGossipSync where L continue; } - log_gossip!(self.logger, "Updating channel {} with flags {} from RGS announcement at {}", - short_channel_id, channel_flags, latest_seen_timestamp); + log_gossip!( + self.logger, + "Updating channel {} with flags {} from RGS announcement at {}", + short_channel_id, + channel_flags, + latest_seen_timestamp + ); match network_graph.update_channel_unsigned(&synthetic_update) { Ok(_) => {}, Err(LightningError { action: ErrorAction::IgnoreDuplicateGossip, .. }) => {}, Err(LightningError { action: ErrorAction::IgnoreAndLog(level), err }) => { - log_given_level!(self.logger, level, "Failed to apply channel update: {:?}", err); + log_given_level!( + self.logger, + level, + "Failed to apply channel update: {:?}", + err + ); }, Err(LightningError { action: ErrorAction::IgnoreError, .. }) => {}, Err(e) => return Err(e.into()), @@ -274,21 +300,20 @@ mod tests { use crate::RapidGossipSync; const VALID_RGS_BINARY: [u8; 300] = [ - 76, 68, 75, 1, 111, 226, 140, 10, 182, 241, 179, 114, 193, 166, 162, 70, 174, 99, 247, - 79, 147, 30, 131, 101, 225, 90, 8, 156, 104, 214, 25, 0, 0, 0, 0, 0, 97, 227, 98, 218, - 0, 0, 0, 4, 2, 22, 7, 207, 206, 25, 164, 197, 231, 230, 231, 56, 102, 61, 250, 251, - 187, 172, 38, 46, 79, 247, 108, 44, 155, 48, 219, 238, 252, 53, 192, 6, 67, 2, 36, 125, - 157, 176, 223, 175, 234, 116, 94, 248, 201, 225, 97, 235, 50, 47, 115, 172, 63, 136, - 88, 216, 115, 11, 111, 217, 114, 84, 116, 124, 231, 107, 2, 158, 1, 242, 121, 152, 106, - 204, 131, 186, 35, 93, 70, 216, 10, 237, 224, 183, 89, 95, 65, 3, 83, 185, 58, 138, - 181, 64, 187, 103, 127, 68, 50, 2, 201, 19, 17, 138, 136, 149, 185, 226, 156, 137, 175, - 110, 32, 237, 0, 217, 90, 31, 100, 228, 149, 46, 219, 175, 168, 77, 4, 143, 38, 128, - 76, 97, 0, 0, 0, 2, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 0, 1, 0, 0, 255, 2, 68, - 226, 0, 6, 11, 0, 1, 2, 3, 0, 0, 0, 4, 0, 40, 0, 0, 0, 0, 0, 0, 3, 232, 0, 0, 3, 232, - 0, 0, 0, 1, 0, 0, 0, 0, 29, 129, 25, 192, 255, 8, 153, 192, 0, 2, 27, 0, 0, 60, 0, 0, - 0, 0, 0, 0, 0, 1, 0, 0, 0, 100, 0, 0, 2, 224, 0, 0, 0, 0, 58, 85, 116, 216, 0, 29, 0, - 0, 0, 1, 0, 0, 0, 125, 0, 0, 0, 0, 58, 85, 116, 216, 255, 2, 68, 226, 0, 6, 11, 0, 1, - 0, 0, 1, + 76, 68, 75, 1, 111, 226, 140, 10, 182, 241, 179, 114, 193, 166, 162, 70, 174, 99, 247, 79, + 147, 30, 131, 101, 225, 90, 8, 156, 104, 214, 25, 0, 0, 0, 0, 0, 97, 227, 98, 218, 0, 0, 0, + 4, 2, 22, 7, 207, 206, 25, 164, 197, 231, 230, 231, 56, 102, 61, 250, 251, 187, 172, 38, + 46, 79, 247, 108, 44, 155, 48, 219, 238, 252, 53, 192, 6, 67, 2, 36, 125, 157, 176, 223, + 175, 234, 116, 94, 248, 201, 225, 97, 235, 50, 47, 115, 172, 63, 136, 88, 216, 115, 11, + 111, 217, 114, 84, 116, 124, 231, 107, 2, 158, 1, 242, 121, 152, 106, 204, 131, 186, 35, + 93, 70, 216, 10, 237, 224, 183, 89, 95, 65, 3, 83, 185, 58, 138, 181, 64, 187, 103, 127, + 68, 50, 2, 201, 19, 17, 138, 136, 149, 185, 226, 156, 137, 175, 110, 32, 237, 0, 217, 90, + 31, 100, 228, 149, 46, 219, 175, 168, 77, 4, 143, 38, 128, 76, 97, 0, 0, 0, 2, 0, 0, 255, + 8, 153, 192, 0, 2, 27, 0, 0, 0, 1, 0, 0, 255, 2, 68, 226, 0, 6, 11, 0, 1, 2, 3, 0, 0, 0, 4, + 0, 40, 0, 0, 0, 0, 0, 0, 3, 232, 0, 0, 3, 232, 0, 0, 0, 1, 0, 0, 0, 0, 29, 129, 25, 192, + 255, 8, 153, 192, 0, 2, 27, 0, 0, 60, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 100, 0, 0, 2, 224, + 0, 0, 0, 0, 58, 85, 116, 216, 0, 29, 0, 0, 0, 1, 0, 0, 0, 125, 0, 0, 0, 0, 58, 85, 116, + 216, 255, 2, 68, 226, 0, 6, 11, 0, 1, 0, 0, 1, ]; const VALID_BINARY_TIMESTAMP: u64 = 1642291930; @@ -400,10 +425,7 @@ mod tests { let rapid_sync = RapidGossipSync::new(&network_graph, &logger); let initialization_result = rapid_sync.update_network_graph(&initialization_input[..]); if initialization_result.is_err() { - panic!( - "Unexpected initialization result: {:?}", - initialization_result - ) + panic!("Unexpected initialization result: {:?}", initialization_result) } assert_eq!(network_graph.read_only().channels().len(), 2); @@ -466,7 +488,8 @@ mod tests { 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 136, 0, 0, 0, 221, 255, 2, 68, 226, 0, 6, 11, 0, 1, 128, ]; - let update_result = rapid_sync.update_network_graph(&single_direction_incremental_update_input[..]); + let update_result = + rapid_sync.update_network_graph(&single_direction_incremental_update_input[..]); if update_result.is_err() { panic!("Unexpected update result: {:?}", update_result) } @@ -526,9 +549,11 @@ mod tests { 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 136, 0, 0, 0, 221, 255, 2, 68, 226, 0, 6, 11, 0, 1, 128, ]; - let update_result_1 = rapid_sync.update_network_graph(&single_direction_incremental_update_input[..]); + let update_result_1 = + rapid_sync.update_network_graph(&single_direction_incremental_update_input[..]); // Apply duplicate update - let update_result_2 = rapid_sync.update_network_graph(&single_direction_incremental_update_input[..]); + let update_result_2 = + rapid_sync.update_network_graph(&single_direction_incremental_update_input[..]); assert!(update_result_1.is_ok()); assert!(update_result_2.is_ok()); } @@ -591,7 +616,8 @@ mod tests { assert_eq!(network_graph.read_only().channels().len(), 0); let rapid_sync = RapidGossipSync::new(&network_graph, &logger); - let update_result = rapid_sync.update_network_graph_no_std(&VALID_RGS_BINARY, Some(latest_nonpruning_time)); + let update_result = rapid_sync + .update_network_graph_no_std(&VALID_RGS_BINARY, Some(latest_nonpruning_time)); assert!(update_result.is_ok()); assert_eq!(network_graph.read_only().channels().len(), 2); } @@ -601,7 +627,8 @@ mod tests { assert_eq!(network_graph.read_only().channels().len(), 0); let rapid_sync = RapidGossipSync::new(&network_graph, &logger); - let update_result = rapid_sync.update_network_graph_no_std(&VALID_RGS_BINARY, Some(latest_nonpruning_time + 1)); + let update_result = rapid_sync + .update_network_graph_no_std(&VALID_RGS_BINARY, Some(latest_nonpruning_time + 1)); assert!(update_result.is_ok()); assert_eq!(network_graph.read_only().channels().len(), 0); } @@ -620,7 +647,8 @@ mod tests { assert_eq!(network_graph.read_only().channels().len(), 0); let rapid_sync = RapidGossipSync::new(&network_graph, &logger); - let update_result = rapid_sync.update_network_graph_no_std(&VALID_RGS_BINARY, Some(latest_succeeding_time)); + let update_result = rapid_sync + .update_network_graph_no_std(&VALID_RGS_BINARY, Some(latest_succeeding_time)); assert!(update_result.is_ok()); assert_eq!(network_graph.read_only().channels().len(), 0); } @@ -630,7 +658,8 @@ mod tests { assert_eq!(network_graph.read_only().channels().len(), 0); let rapid_sync = RapidGossipSync::new(&network_graph, &logger); - let update_result = rapid_sync.update_network_graph_no_std(&VALID_RGS_BINARY, Some(earliest_failing_time)); + let update_result = rapid_sync + .update_network_graph_no_std(&VALID_RGS_BINARY, Some(earliest_failing_time)); assert!(update_result.is_err()); if let Err(GraphSyncError::LightningError(lightning_error)) = update_result { assert_eq!( @@ -690,7 +719,10 @@ mod tests { let update_result = rapid_sync.update_network_graph_no_std(&VALID_RGS_BINARY, Some(0)); assert!(update_result.is_err()); if let Err(GraphSyncError::LightningError(err)) = update_result { - assert_eq!(err.err, "Rapid Gossip Sync data's chain hash does not match the network graph's"); + assert_eq!( + err.err, + "Rapid Gossip Sync data's chain hash does not match the network graph's" + ); } else { panic!("Unexpected update result: {:?}", update_result) } diff --git a/lightning/src/blinded_path/message.rs b/lightning/src/blinded_path/message.rs index d2e81444ef6..cb9b49f435a 100644 --- a/lightning/src/blinded_path/message.rs +++ b/lightning/src/blinded_path/message.rs @@ -1,7 +1,7 @@ use bitcoin::secp256k1::{self, PublicKey, Secp256k1, SecretKey}; -use crate::blinded_path::{BlindedHop, BlindedPath}; use crate::blinded_path::utils; +use crate::blinded_path::{BlindedHop, BlindedPath}; use crate::io; use crate::io::Cursor; use crate::ln::onion_utils; @@ -55,9 +55,10 @@ impl Writeable for ReceiveTlvs { /// Construct blinded onion message hops for the given `unblinded_path`. pub(super) fn blinded_hops( - secp_ctx: &Secp256k1, unblinded_path: &[PublicKey], session_priv: &SecretKey + secp_ctx: &Secp256k1, unblinded_path: &[PublicKey], session_priv: &SecretKey, ) -> Result, secp256k1::Error> { - let blinded_tlvs = unblinded_path.iter() + let blinded_tlvs = unblinded_path + .iter() .skip(1) // The first node's TLVs contains the next node's pubkey .map(|pk| { ControlTlvs::Forward(ForwardTlvs { next_node_id: *pk, next_blinding_override: None }) @@ -70,28 +71,33 @@ pub(super) fn blinded_hops( // Advance the blinded onion message path by one hop, so make the second hop into the new // introduction node. pub(crate) fn advance_path_by_one( - path: &mut BlindedPath, node_signer: &NS, secp_ctx: &Secp256k1 -) -> Result<(), ()> where NS::Target: NodeSigner { + path: &mut BlindedPath, node_signer: &NS, secp_ctx: &Secp256k1, +) -> Result<(), ()> +where + NS::Target: NodeSigner, +{ let control_tlvs_ss = node_signer.ecdh(Recipient::Node, &path.blinding_point, None)?; let rho = onion_utils::gen_rho_from_shared_secret(&control_tlvs_ss.secret_bytes()); let encrypted_control_tlvs = path.blinded_hops.remove(0).encrypted_payload; let mut s = Cursor::new(&encrypted_control_tlvs); let mut reader = FixedLengthReader::new(&mut s, encrypted_control_tlvs.len() as u64); match ChaChaPolyReadAdapter::read(&mut reader, rho) { - Ok(ChaChaPolyReadAdapter { readable: ControlTlvs::Forward(ForwardTlvs { - mut next_node_id, next_blinding_override, - })}) => { + Ok(ChaChaPolyReadAdapter { + readable: ControlTlvs::Forward(ForwardTlvs { mut next_node_id, next_blinding_override }), + }) => { let mut new_blinding_point = match next_blinding_override { Some(blinding_point) => blinding_point, - None => { - onion_utils::next_hop_pubkey(secp_ctx, path.blinding_point, - control_tlvs_ss.as_ref()).map_err(|_| ())? - } + None => onion_utils::next_hop_pubkey( + secp_ctx, + path.blinding_point, + control_tlvs_ss.as_ref(), + ) + .map_err(|_| ())?, }; mem::swap(&mut path.blinding_point, &mut new_blinding_point); mem::swap(&mut path.introduction_node_id, &mut next_node_id); Ok(()) }, - _ => Err(()) + _ => Err(()), } } diff --git a/lightning/src/blinded_path/mod.rs b/lightning/src/blinded_path/mod.rs index d75b4f25b36..2b82ab92559 100644 --- a/lightning/src/blinded_path/mod.rs +++ b/lightning/src/blinded_path/mod.rs @@ -9,8 +9,8 @@ //! Creating blinded paths and related utilities live here. -pub mod payment; pub(crate) mod message; +pub mod payment; pub(crate) mod utils; use bitcoin::secp256k1::{self, PublicKey, Secp256k1, SecretKey}; @@ -57,8 +57,11 @@ pub struct BlindedHop { impl BlindedPath { /// Create a one-hop blinded path for a message. - pub fn one_hop_for_message( - recipient_node_id: PublicKey, entropy_source: &ES, secp_ctx: &Secp256k1 + pub fn one_hop_for_message< + ES: EntropySource + ?Sized, + T: secp256k1::Signing + secp256k1::Verification, + >( + recipient_node_id: PublicKey, entropy_source: &ES, secp_ctx: &Secp256k1, ) -> Result { Self::new_for_message(&[recipient_node_id], entropy_source, secp_ctx) } @@ -68,31 +71,46 @@ impl BlindedPath { /// /// Errors if no hops are provided or if `node_pk`(s) are invalid. // TODO: make all payloads the same size with padding + add dummy hops - pub fn new_for_message( - node_pks: &[PublicKey], entropy_source: &ES, secp_ctx: &Secp256k1 + pub fn new_for_message< + ES: EntropySource + ?Sized, + T: secp256k1::Signing + secp256k1::Verification, + >( + node_pks: &[PublicKey], entropy_source: &ES, secp_ctx: &Secp256k1, ) -> Result { - if node_pks.is_empty() { return Err(()) } + if node_pks.is_empty() { + return Err(()); + } let blinding_secret_bytes = entropy_source.get_secure_random_bytes(); - let blinding_secret = SecretKey::from_slice(&blinding_secret_bytes[..]).expect("RNG is busted"); + let blinding_secret = + SecretKey::from_slice(&blinding_secret_bytes[..]).expect("RNG is busted"); let introduction_node_id = node_pks[0]; Ok(BlindedPath { introduction_node_id, blinding_point: PublicKey::from_secret_key(secp_ctx, &blinding_secret), - blinded_hops: message::blinded_hops(secp_ctx, node_pks, &blinding_secret).map_err(|_| ())?, + blinded_hops: message::blinded_hops(secp_ctx, node_pks, &blinding_secret) + .map_err(|_| ())?, }) } /// Create a one-hop blinded path for a payment. - pub fn one_hop_for_payment( + pub fn one_hop_for_payment< + ES: EntropySource + ?Sized, + T: secp256k1::Signing + secp256k1::Verification, + >( payee_node_id: PublicKey, payee_tlvs: payment::ReceiveTlvs, entropy_source: &ES, - secp_ctx: &Secp256k1 + secp_ctx: &Secp256k1, ) -> Result<(BlindedPayInfo, Self), ()> { // This value is not considered in pathfinding for 1-hop blinded paths, because it's intended to // be in relation to a specific channel. let htlc_maximum_msat = u64::max_value(); Self::new_for_payment( - &[], payee_node_id, payee_tlvs, htlc_maximum_msat, entropy_source, secp_ctx + &[], + payee_node_id, + payee_tlvs, + htlc_maximum_msat, + entropy_source, + secp_ctx, ) } @@ -105,21 +123,31 @@ impl BlindedPath { /// /// [`ForwardTlvs`]: crate::blinded_path::payment::ForwardTlvs // TODO: make all payloads the same size with padding + add dummy hops - pub(crate) fn new_for_payment( + pub(crate) fn new_for_payment< + ES: EntropySource + ?Sized, + T: secp256k1::Signing + secp256k1::Verification, + >( intermediate_nodes: &[payment::ForwardNode], payee_node_id: PublicKey, payee_tlvs: payment::ReceiveTlvs, htlc_maximum_msat: u64, entropy_source: &ES, - secp_ctx: &Secp256k1 + secp_ctx: &Secp256k1, ) -> Result<(BlindedPayInfo, Self), ()> { let blinding_secret_bytes = entropy_source.get_secure_random_bytes(); - let blinding_secret = SecretKey::from_slice(&blinding_secret_bytes[..]).expect("RNG is busted"); + let blinding_secret = + SecretKey::from_slice(&blinding_secret_bytes[..]).expect("RNG is busted"); - let blinded_payinfo = payment::compute_payinfo(intermediate_nodes, &payee_tlvs, htlc_maximum_msat)?; + let blinded_payinfo = + payment::compute_payinfo(intermediate_nodes, &payee_tlvs, htlc_maximum_msat)?; Ok((blinded_payinfo, BlindedPath { introduction_node_id: intermediate_nodes.first().map_or(payee_node_id, |n| n.node_id), blinding_point: PublicKey::from_secret_key(secp_ctx, &blinding_secret), blinded_hops: payment::blinded_hops( - secp_ctx, intermediate_nodes, payee_node_id, payee_tlvs, &blinding_secret - ).map_err(|_| ())?, + secp_ctx, + intermediate_nodes, + payee_node_id, + payee_tlvs, + &blinding_secret, + ) + .map_err(|_| ())?, })) } } @@ -141,16 +169,14 @@ impl Readable for BlindedPath { let introduction_node_id = Readable::read(r)?; let blinding_point = Readable::read(r)?; let num_hops: u8 = Readable::read(r)?; - if num_hops == 0 { return Err(DecodeError::InvalidValue) } + if num_hops == 0 { + return Err(DecodeError::InvalidValue); + } let mut blinded_hops: Vec = Vec::with_capacity(num_hops.into()); for _ in 0..num_hops { blinded_hops.push(Readable::read(r)?); } - Ok(BlindedPath { - introduction_node_id, - blinding_point, - blinded_hops, - }) + Ok(BlindedPath { introduction_node_id, blinding_point, blinded_hops }) } } @@ -158,4 +184,3 @@ impl_writeable!(BlindedHop, { blinded_node_id, encrypted_payload }); - diff --git a/lightning/src/blinded_path/payment.rs b/lightning/src/blinded_path/payment.rs index 4edfb7d8de0..11f0d319fd0 100644 --- a/lightning/src/blinded_path/payment.rs +++ b/lightning/src/blinded_path/payment.rs @@ -4,12 +4,12 @@ use bitcoin::secp256k1::{self, PublicKey, Secp256k1, SecretKey}; -use crate::blinded_path::BlindedHop; use crate::blinded_path::utils; +use crate::blinded_path::BlindedHop; use crate::io; -use crate::ln::PaymentSecret; use crate::ln::features::BlindedHopFeatures; use crate::ln::msgs::DecodeError; +use crate::ln::PaymentSecret; use crate::offers::invoice::BlindedPayInfo; use crate::prelude::*; use crate::util::ser::{Readable, Writeable, Writer}; @@ -128,7 +128,7 @@ impl Readable for ReceiveTlvs { }); Ok(Self { payment_secret: payment_secret.0.unwrap(), - payment_constraints: payment_constraints.0.unwrap() + payment_constraints: payment_constraints.0.unwrap(), }) } } @@ -157,7 +157,9 @@ impl Readable for BlindedPaymentTlvs { let _padding: Option = _padding; if let Some(short_channel_id) = scid { - if payment_secret.is_some() { return Err(DecodeError::InvalidValue) } + if payment_secret.is_some() { + return Err(DecodeError::InvalidValue); + } Ok(BlindedPaymentTlvs::Forward(ForwardTlvs { short_channel_id, payment_relay: payment_relay.ok_or(DecodeError::InvalidValue)?, @@ -165,7 +167,9 @@ impl Readable for BlindedPaymentTlvs { features: features.ok_or(DecodeError::InvalidValue)?, })) } else { - if payment_relay.is_some() || features.is_some() { return Err(DecodeError::InvalidValue) } + if payment_relay.is_some() || features.is_some() { + return Err(DecodeError::InvalidValue); + } Ok(BlindedPaymentTlvs::Receive(ReceiveTlvs { payment_secret: payment_secret.ok_or(DecodeError::InvalidValue)?, payment_constraints: payment_constraints.0.unwrap(), @@ -176,12 +180,14 @@ impl Readable for BlindedPaymentTlvs { /// Construct blinded payment hops for the given `intermediate_nodes` and payee info. pub(super) fn blinded_hops( - secp_ctx: &Secp256k1, intermediate_nodes: &[ForwardNode], - payee_node_id: PublicKey, payee_tlvs: ReceiveTlvs, session_priv: &SecretKey + secp_ctx: &Secp256k1, intermediate_nodes: &[ForwardNode], payee_node_id: PublicKey, + payee_tlvs: ReceiveTlvs, session_priv: &SecretKey, ) -> Result, secp256k1::Error> { - let pks = intermediate_nodes.iter().map(|node| &node.node_id) - .chain(core::iter::once(&payee_node_id)); - let tlvs = intermediate_nodes.iter().map(|node| BlindedPaymentTlvsRef::Forward(&node.tlvs)) + let pks = + intermediate_nodes.iter().map(|node| &node.node_id).chain(core::iter::once(&payee_node_id)); + let tlvs = intermediate_nodes + .iter() + .map(|node| BlindedPaymentTlvsRef::Forward(&node.tlvs)) .chain(core::iter::once(BlindedPaymentTlvsRef::Receive(&payee_tlvs))); utils::construct_blinded_hops(secp_ctx, pks, tlvs, session_priv) } @@ -208,7 +214,7 @@ fn amt_to_forward_msat(inbound_amt_msat: u64, payment_relay: &PaymentRelay) -> O } pub(super) fn compute_payinfo( - intermediate_nodes: &[ForwardNode], payee_tlvs: &ReceiveTlvs, payee_htlc_maximum_msat: u64 + intermediate_nodes: &[ForwardNode], payee_tlvs: &ReceiveTlvs, payee_htlc_maximum_msat: u64, ) -> Result { let mut curr_base_fee: u64 = 0; let mut curr_prop_mil: u64 = 0; @@ -216,26 +222,31 @@ pub(super) fn compute_payinfo( for tlvs in intermediate_nodes.iter().rev().map(|n| &n.tlvs) { // In the future, we'll want to take the intersection of all supported features for the // `BlindedPayInfo`, but there are no features in that context right now. - if tlvs.features.requires_unknown_bits_from(&BlindedHopFeatures::empty()) { return Err(()) } + if tlvs.features.requires_unknown_bits_from(&BlindedHopFeatures::empty()) { + return Err(()); + } let next_base_fee = tlvs.payment_relay.fee_base_msat as u64; let next_prop_mil = tlvs.payment_relay.fee_proportional_millionths as u64; // Use integer arithmetic to compute `ceil(a/b)` as `(a+b-1)/b` // ((curr_base_fee * (1_000_000 + next_prop_mil)) / 1_000_000) + next_base_fee - curr_base_fee = curr_base_fee.checked_mul(1_000_000 + next_prop_mil) + curr_base_fee = curr_base_fee + .checked_mul(1_000_000 + next_prop_mil) .and_then(|f| f.checked_add(1_000_000 - 1)) .map(|f| f / 1_000_000) .and_then(|f| f.checked_add(next_base_fee)) .ok_or(())?; // ceil(((curr_prop_mil + 1_000_000) * (next_prop_mil + 1_000_000)) / 1_000_000) - 1_000_000 - curr_prop_mil = curr_prop_mil.checked_add(1_000_000) + curr_prop_mil = curr_prop_mil + .checked_add(1_000_000) .and_then(|f1| next_prop_mil.checked_add(1_000_000).and_then(|f2| f2.checked_mul(f1))) .and_then(|f| f.checked_add(1_000_000 - 1)) .map(|f| f / 1_000_000) .and_then(|f| f.checked_sub(1_000_000)) .ok_or(())?; - cltv_expiry_delta = cltv_expiry_delta.checked_add(tlvs.payment_relay.cltv_expiry_delta).ok_or(())?; + cltv_expiry_delta = + cltv_expiry_delta.checked_add(tlvs.payment_relay.cltv_expiry_delta).ok_or(())?; } let mut htlc_minimum_msat: u64 = 1; @@ -246,18 +257,22 @@ pub(super) fn compute_payinfo( // in the amount that this node receives and contribute towards reaching its min. htlc_minimum_msat = amt_to_forward_msat( core::cmp::max(node.tlvs.payment_constraints.htlc_minimum_msat, htlc_minimum_msat), - &node.tlvs.payment_relay - ).unwrap_or(1); // If underflow occurs, we definitely reached this node's min + &node.tlvs.payment_relay, + ) + .unwrap_or(1); // If underflow occurs, we definitely reached this node's min htlc_maximum_msat = amt_to_forward_msat( - core::cmp::min(node.htlc_maximum_msat, htlc_maximum_msat), &node.tlvs.payment_relay - ).ok_or(())?; // If underflow occurs, we cannot send to this hop without exceeding their max + core::cmp::min(node.htlc_maximum_msat, htlc_maximum_msat), + &node.tlvs.payment_relay, + ) + .ok_or(())?; // If underflow occurs, we cannot send to this hop without exceeding their max } - htlc_minimum_msat = core::cmp::max( - payee_tlvs.payment_constraints.htlc_minimum_msat, htlc_minimum_msat - ); + htlc_minimum_msat = + core::cmp::max(payee_tlvs.payment_constraints.htlc_minimum_msat, htlc_minimum_msat); htlc_maximum_msat = core::cmp::min(payee_htlc_maximum_msat, htlc_maximum_msat); - if htlc_maximum_msat < htlc_minimum_msat { return Err(()) } + if htlc_maximum_msat < htlc_minimum_msat { + return Err(()); + } Ok(BlindedPayInfo { fee_base_msat: u32::try_from(curr_base_fee).map_err(|_| ())?, fee_proportional_millionths: u32::try_from(curr_prop_mil).map_err(|_| ())?, @@ -281,58 +296,61 @@ impl_writeable_msg!(PaymentConstraints, { #[cfg(test)] mod tests { - use bitcoin::secp256k1::PublicKey; - use crate::blinded_path::payment::{ForwardNode, ForwardTlvs, ReceiveTlvs, PaymentConstraints, PaymentRelay}; - use crate::ln::PaymentSecret; + use crate::blinded_path::payment::{ + ForwardNode, ForwardTlvs, PaymentConstraints, PaymentRelay, ReceiveTlvs, + }; use crate::ln::features::BlindedHopFeatures; + use crate::ln::PaymentSecret; + use bitcoin::secp256k1::PublicKey; #[test] fn compute_payinfo() { // Taken from the spec example for aggregating blinded payment info. See // https://github.com/lightning/bolts/blob/master/proposals/route-blinding.md#blinded-payments let dummy_pk = PublicKey::from_slice(&[2; 33]).unwrap(); - let intermediate_nodes = vec![ForwardNode { - node_id: dummy_pk, - tlvs: ForwardTlvs { - short_channel_id: 0, - payment_relay: PaymentRelay { - cltv_expiry_delta: 144, - fee_proportional_millionths: 500, - fee_base_msat: 100, - }, - payment_constraints: PaymentConstraints { - max_cltv_expiry: 0, - htlc_minimum_msat: 100, + let intermediate_nodes = vec![ + ForwardNode { + node_id: dummy_pk, + tlvs: ForwardTlvs { + short_channel_id: 0, + payment_relay: PaymentRelay { + cltv_expiry_delta: 144, + fee_proportional_millionths: 500, + fee_base_msat: 100, + }, + payment_constraints: PaymentConstraints { + max_cltv_expiry: 0, + htlc_minimum_msat: 100, + }, + features: BlindedHopFeatures::empty(), }, - features: BlindedHopFeatures::empty(), + htlc_maximum_msat: u64::max_value(), }, - htlc_maximum_msat: u64::max_value(), - }, ForwardNode { - node_id: dummy_pk, - tlvs: ForwardTlvs { - short_channel_id: 0, - payment_relay: PaymentRelay { - cltv_expiry_delta: 144, - fee_proportional_millionths: 500, - fee_base_msat: 100, - }, - payment_constraints: PaymentConstraints { - max_cltv_expiry: 0, - htlc_minimum_msat: 1_000, + ForwardNode { + node_id: dummy_pk, + tlvs: ForwardTlvs { + short_channel_id: 0, + payment_relay: PaymentRelay { + cltv_expiry_delta: 144, + fee_proportional_millionths: 500, + fee_base_msat: 100, + }, + payment_constraints: PaymentConstraints { + max_cltv_expiry: 0, + htlc_minimum_msat: 1_000, + }, + features: BlindedHopFeatures::empty(), }, - features: BlindedHopFeatures::empty(), + htlc_maximum_msat: u64::max_value(), }, - htlc_maximum_msat: u64::max_value(), - }]; + ]; let recv_tlvs = ReceiveTlvs { payment_secret: PaymentSecret([0; 32]), - payment_constraints: PaymentConstraints { - max_cltv_expiry: 0, - htlc_minimum_msat: 1, - }, + payment_constraints: PaymentConstraints { max_cltv_expiry: 0, htlc_minimum_msat: 1 }, }; let htlc_maximum_msat = 100_000; - let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat).unwrap(); + let blinded_payinfo = + super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat).unwrap(); assert_eq!(blinded_payinfo.fee_base_msat, 201); assert_eq!(blinded_payinfo.fee_proportional_millionths, 1001); assert_eq!(blinded_payinfo.cltv_expiry_delta, 288); @@ -344,10 +362,7 @@ mod tests { fn compute_payinfo_1_hop() { let recv_tlvs = ReceiveTlvs { payment_secret: PaymentSecret([0; 32]), - payment_constraints: PaymentConstraints { - max_cltv_expiry: 0, - htlc_minimum_msat: 1, - }, + payment_constraints: PaymentConstraints { max_cltv_expiry: 0, htlc_minimum_msat: 1 }, }; let blinded_payinfo = super::compute_payinfo(&[], &recv_tlvs, 4242).unwrap(); assert_eq!(blinded_payinfo.fee_base_msat, 0); @@ -362,48 +377,49 @@ mod tests { // If no hops charge fees, the htlc_minimum_msat should just be the maximum htlc_minimum_msat // along the path. let dummy_pk = PublicKey::from_slice(&[2; 33]).unwrap(); - let intermediate_nodes = vec![ForwardNode { - node_id: dummy_pk, - tlvs: ForwardTlvs { - short_channel_id: 0, - payment_relay: PaymentRelay { - cltv_expiry_delta: 0, - fee_proportional_millionths: 0, - fee_base_msat: 0, - }, - payment_constraints: PaymentConstraints { - max_cltv_expiry: 0, - htlc_minimum_msat: 1, + let intermediate_nodes = vec![ + ForwardNode { + node_id: dummy_pk, + tlvs: ForwardTlvs { + short_channel_id: 0, + payment_relay: PaymentRelay { + cltv_expiry_delta: 0, + fee_proportional_millionths: 0, + fee_base_msat: 0, + }, + payment_constraints: PaymentConstraints { + max_cltv_expiry: 0, + htlc_minimum_msat: 1, + }, + features: BlindedHopFeatures::empty(), }, - features: BlindedHopFeatures::empty(), + htlc_maximum_msat: u64::max_value(), }, - htlc_maximum_msat: u64::max_value() - }, ForwardNode { - node_id: dummy_pk, - tlvs: ForwardTlvs { - short_channel_id: 0, - payment_relay: PaymentRelay { - cltv_expiry_delta: 0, - fee_proportional_millionths: 0, - fee_base_msat: 0, + ForwardNode { + node_id: dummy_pk, + tlvs: ForwardTlvs { + short_channel_id: 0, + payment_relay: PaymentRelay { + cltv_expiry_delta: 0, + fee_proportional_millionths: 0, + fee_base_msat: 0, + }, + payment_constraints: PaymentConstraints { + max_cltv_expiry: 0, + htlc_minimum_msat: 2_000, + }, + features: BlindedHopFeatures::empty(), }, - payment_constraints: PaymentConstraints { - max_cltv_expiry: 0, - htlc_minimum_msat: 2_000, - }, - features: BlindedHopFeatures::empty(), + htlc_maximum_msat: u64::max_value(), }, - htlc_maximum_msat: u64::max_value() - }]; + ]; let recv_tlvs = ReceiveTlvs { payment_secret: PaymentSecret([0; 32]), - payment_constraints: PaymentConstraints { - max_cltv_expiry: 0, - htlc_minimum_msat: 3, - }, + payment_constraints: PaymentConstraints { max_cltv_expiry: 0, htlc_minimum_msat: 3 }, }; let htlc_maximum_msat = 100_000; - let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat).unwrap(); + let blinded_payinfo = + super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat).unwrap(); assert_eq!(blinded_payinfo.htlc_minimum_msat, 2_000); } @@ -412,51 +428,53 @@ mod tests { // Create a path with varying fees and htlc_mins, and make sure htlc_minimum_msat ends up as the // max (htlc_min - following_fees) along the path. let dummy_pk = PublicKey::from_slice(&[2; 33]).unwrap(); - let intermediate_nodes = vec![ForwardNode { - node_id: dummy_pk, - tlvs: ForwardTlvs { - short_channel_id: 0, - payment_relay: PaymentRelay { - cltv_expiry_delta: 0, - fee_proportional_millionths: 500, - fee_base_msat: 1_000, - }, - payment_constraints: PaymentConstraints { - max_cltv_expiry: 0, - htlc_minimum_msat: 5_000, + let intermediate_nodes = vec![ + ForwardNode { + node_id: dummy_pk, + tlvs: ForwardTlvs { + short_channel_id: 0, + payment_relay: PaymentRelay { + cltv_expiry_delta: 0, + fee_proportional_millionths: 500, + fee_base_msat: 1_000, + }, + payment_constraints: PaymentConstraints { + max_cltv_expiry: 0, + htlc_minimum_msat: 5_000, + }, + features: BlindedHopFeatures::empty(), }, - features: BlindedHopFeatures::empty(), + htlc_maximum_msat: u64::max_value(), }, - htlc_maximum_msat: u64::max_value() - }, ForwardNode { - node_id: dummy_pk, - tlvs: ForwardTlvs { - short_channel_id: 0, - payment_relay: PaymentRelay { - cltv_expiry_delta: 0, - fee_proportional_millionths: 500, - fee_base_msat: 200, + ForwardNode { + node_id: dummy_pk, + tlvs: ForwardTlvs { + short_channel_id: 0, + payment_relay: PaymentRelay { + cltv_expiry_delta: 0, + fee_proportional_millionths: 500, + fee_base_msat: 200, + }, + payment_constraints: PaymentConstraints { + max_cltv_expiry: 0, + htlc_minimum_msat: 2_000, + }, + features: BlindedHopFeatures::empty(), }, - payment_constraints: PaymentConstraints { - max_cltv_expiry: 0, - htlc_minimum_msat: 2_000, - }, - features: BlindedHopFeatures::empty(), + htlc_maximum_msat: u64::max_value(), }, - htlc_maximum_msat: u64::max_value() - }]; + ]; let recv_tlvs = ReceiveTlvs { payment_secret: PaymentSecret([0; 32]), - payment_constraints: PaymentConstraints { - max_cltv_expiry: 0, - htlc_minimum_msat: 1, - }, + payment_constraints: PaymentConstraints { max_cltv_expiry: 0, htlc_minimum_msat: 1 }, }; let htlc_minimum_msat = 3798; - assert!(super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_minimum_msat - 1).is_err()); + assert!(super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_minimum_msat - 1) + .is_err()); let htlc_maximum_msat = htlc_minimum_msat + 1; - let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat).unwrap(); + let blinded_payinfo = + super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat).unwrap(); assert_eq!(blinded_payinfo.htlc_minimum_msat, htlc_minimum_msat); assert_eq!(blinded_payinfo.htlc_maximum_msat, htlc_maximum_msat); } @@ -466,48 +484,49 @@ mod tests { // Create a path with varying fees and `htlc_maximum_msat`s, and make sure the aggregated max // htlc ends up as the min (htlc_max - following_fees) along the path. let dummy_pk = PublicKey::from_slice(&[2; 33]).unwrap(); - let intermediate_nodes = vec![ForwardNode { - node_id: dummy_pk, - tlvs: ForwardTlvs { - short_channel_id: 0, - payment_relay: PaymentRelay { - cltv_expiry_delta: 0, - fee_proportional_millionths: 500, - fee_base_msat: 1_000, - }, - payment_constraints: PaymentConstraints { - max_cltv_expiry: 0, - htlc_minimum_msat: 1, + let intermediate_nodes = vec![ + ForwardNode { + node_id: dummy_pk, + tlvs: ForwardTlvs { + short_channel_id: 0, + payment_relay: PaymentRelay { + cltv_expiry_delta: 0, + fee_proportional_millionths: 500, + fee_base_msat: 1_000, + }, + payment_constraints: PaymentConstraints { + max_cltv_expiry: 0, + htlc_minimum_msat: 1, + }, + features: BlindedHopFeatures::empty(), }, - features: BlindedHopFeatures::empty(), + htlc_maximum_msat: 5_000, }, - htlc_maximum_msat: 5_000, - }, ForwardNode { - node_id: dummy_pk, - tlvs: ForwardTlvs { - short_channel_id: 0, - payment_relay: PaymentRelay { - cltv_expiry_delta: 0, - fee_proportional_millionths: 500, - fee_base_msat: 1, + ForwardNode { + node_id: dummy_pk, + tlvs: ForwardTlvs { + short_channel_id: 0, + payment_relay: PaymentRelay { + cltv_expiry_delta: 0, + fee_proportional_millionths: 500, + fee_base_msat: 1, + }, + payment_constraints: PaymentConstraints { + max_cltv_expiry: 0, + htlc_minimum_msat: 1, + }, + features: BlindedHopFeatures::empty(), }, - payment_constraints: PaymentConstraints { - max_cltv_expiry: 0, - htlc_minimum_msat: 1, - }, - features: BlindedHopFeatures::empty(), + htlc_maximum_msat: 10_000, }, - htlc_maximum_msat: 10_000 - }]; + ]; let recv_tlvs = ReceiveTlvs { payment_secret: PaymentSecret([0; 32]), - payment_constraints: PaymentConstraints { - max_cltv_expiry: 0, - htlc_minimum_msat: 1, - }, + payment_constraints: PaymentConstraints { max_cltv_expiry: 0, htlc_minimum_msat: 1 }, }; - let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, 10_000).unwrap(); + let blinded_payinfo = + super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, 10_000).unwrap(); assert_eq!(blinded_payinfo.htlc_maximum_msat, 3997); } } diff --git a/lightning/src/blinded_path/utils.rs b/lightning/src/blinded_path/utils.rs index c62b4e6c261..65ff6a79c62 100644 --- a/lightning/src/blinded_path/utils.rs +++ b/lightning/src/blinded_path/utils.rs @@ -9,11 +9,11 @@ //! Onion message utility methods live here. -use bitcoin::hashes::{Hash, HashEngine}; use bitcoin::hashes::hmac::{Hmac, HmacEngine}; use bitcoin::hashes::sha256::Hash as Sha256; -use bitcoin::secp256k1::{self, PublicKey, Secp256k1, SecretKey, Scalar}; +use bitcoin::hashes::{Hash, HashEngine}; use bitcoin::secp256k1::ecdh::SharedSecret; +use bitcoin::secp256k1::{self, PublicKey, Scalar, Secp256k1, SecretKey}; use super::{BlindedHop, BlindedPath}; use crate::ln::msgs::DecodeError; @@ -29,11 +29,11 @@ use crate::prelude::*; #[inline] pub(crate) fn construct_keys_callback<'a, T, I, F>( secp_ctx: &Secp256k1, unblinded_path: I, destination: Option, - session_priv: &SecretKey, mut callback: F + session_priv: &SecretKey, mut callback: F, ) -> Result<(), secp256k1::Error> where T: secp256k1::Signing + secp256k1::Verification, - I: Iterator, + I: Iterator, F: FnMut(PublicKey, SharedSecret, PublicKey, [u8; 32], Option, Option>), { let mut msg_blinding_point_priv = session_priv.clone(); @@ -45,7 +45,9 @@ where ($pk: expr, $blinded: expr, $encrypted_payload: expr) => {{ let encrypted_data_ss = SharedSecret::new(&$pk, &msg_blinding_point_priv); - let blinded_hop_pk = if $blinded { $pk } else { + let blinded_hop_pk = if $blinded { + $pk + } else { let hop_pk_blinding_factor = { let mut hmac = HmacEngine::::new(b"blinded_node_id"); hmac.input(encrypted_data_ss.as_ref()); @@ -57,14 +59,22 @@ where let rho = onion_utils::gen_rho_from_shared_secret(encrypted_data_ss.as_ref()); let unblinded_pk_opt = if $blinded { None } else { Some($pk) }; - callback(blinded_hop_pk, onion_packet_ss, onion_packet_pubkey, rho, unblinded_pk_opt, $encrypted_payload); + callback( + blinded_hop_pk, + onion_packet_ss, + onion_packet_pubkey, + rho, + unblinded_pk_opt, + $encrypted_payload, + ); (encrypted_data_ss, onion_packet_ss) - }} + }}; } macro_rules! build_keys_in_loop { ($pk: expr, $blinded: expr, $encrypted_payload: expr) => { - let (encrypted_data_ss, onion_packet_ss) = build_keys!($pk, $blinded, $encrypted_payload); + let (encrypted_data_ss, onion_packet_ss) = + build_keys!($pk, $blinded, $encrypted_payload); let msg_blinding_point_blinding_factor = { let mut sha = Sha256::engine(); @@ -73,7 +83,8 @@ where Sha256::from_engine(sha).into_inner() }; - msg_blinding_point_priv = msg_blinding_point_priv.mul_tweak(&Scalar::from_be_bytes(msg_blinding_point_blinding_factor).unwrap())?; + msg_blinding_point_priv = msg_blinding_point_priv + .mul_tweak(&Scalar::from_be_bytes(msg_blinding_point_blinding_factor).unwrap())?; msg_blinding_point = PublicKey::from_secret_key(secp_ctx, &msg_blinding_point_priv); let onion_packet_pubkey_blinding_factor = { @@ -82,7 +93,8 @@ where sha.input(onion_packet_ss.as_ref()); Sha256::from_engine(sha).into_inner() }; - onion_packet_pubkey_priv = onion_packet_pubkey_priv.mul_tweak(&Scalar::from_be_bytes(onion_packet_pubkey_blinding_factor).unwrap())?; + onion_packet_pubkey_priv = onion_packet_pubkey_priv + .mul_tweak(&Scalar::from_be_bytes(onion_packet_pubkey_blinding_factor).unwrap())?; onion_packet_pubkey = PublicKey::from_secret_key(secp_ctx, &onion_packet_pubkey_priv); }; } @@ -107,23 +119,30 @@ where // Panics if `unblinded_tlvs` length is less than `unblinded_pks` length pub(super) fn construct_blinded_hops<'a, T, I1, I2>( - secp_ctx: &Secp256k1, unblinded_pks: I1, mut unblinded_tlvs: I2, session_priv: &SecretKey + secp_ctx: &Secp256k1, unblinded_pks: I1, mut unblinded_tlvs: I2, session_priv: &SecretKey, ) -> Result, secp256k1::Error> where T: secp256k1::Signing + secp256k1::Verification, - I1: Iterator, + I1: Iterator, I2: Iterator, - I2::Item: Writeable + I2::Item: Writeable, { let mut blinded_hops = Vec::with_capacity(unblinded_pks.size_hint().0); construct_keys_callback( - secp_ctx, unblinded_pks, None, session_priv, + secp_ctx, + unblinded_pks, + None, + session_priv, |blinded_node_id, _, _, encrypted_payload_rho, _, _| { blinded_hops.push(BlindedHop { blinded_node_id, - encrypted_payload: encrypt_payload(unblinded_tlvs.next().unwrap(), encrypted_payload_rho), + encrypted_payload: encrypt_payload( + unblinded_tlvs.next().unwrap(), + encrypted_payload_rho, + ), }); - })?; + }, + )?; Ok(blinded_hops) } @@ -144,7 +163,9 @@ impl Readable for Padding { fn read(reader: &mut R) -> Result { loop { let mut buf = [0; 8192]; - if reader.read(&mut buf[..])? == 0 { break; } + if reader.read(&mut buf[..])? == 0 { + break; + } } Ok(Self {}) } diff --git a/lightning/src/chain/chaininterface.rs b/lightning/src/chain/chaininterface.rs index 73707f05236..ad14862dbb3 100644 --- a/lightning/src/chain/chaininterface.rs +++ b/lightning/src/chain/chaininterface.rs @@ -13,8 +13,8 @@ //! Includes traits for monitoring and receiving notifications of new blocks and block //! disconnections, transaction broadcasting, and feerate information requests. -use core::{cmp, ops::Deref}; use core::convert::TryInto; +use core::{cmp, ops::Deref}; use bitcoin::blockdata::transaction::Transaction; @@ -176,25 +176,29 @@ pub const FEERATE_FLOOR_SATS_PER_KW: u32 = 253; /// /// Note that this does *not* implement [`FeeEstimator`] to make it harder to accidentally mix the /// two. -pub(crate) struct LowerBoundedFeeEstimator(pub F) where F::Target: FeeEstimator; - -impl LowerBoundedFeeEstimator where F::Target: FeeEstimator { +pub(crate) struct LowerBoundedFeeEstimator(pub F) +where + F::Target: FeeEstimator; + +impl LowerBoundedFeeEstimator +where + F::Target: FeeEstimator, +{ /// Creates a new `LowerBoundedFeeEstimator` which wraps the provided fee_estimator pub fn new(fee_estimator: F) -> Self { LowerBoundedFeeEstimator(fee_estimator) } pub fn bounded_sat_per_1000_weight(&self, confirmation_target: ConfirmationTarget) -> u32 { - cmp::max( - self.0.get_est_sat_per_1000_weight(confirmation_target), - FEERATE_FLOOR_SATS_PER_KW, - ) + cmp::max(self.0.get_est_sat_per_1000_weight(confirmation_target), FEERATE_FLOOR_SATS_PER_KW) } } #[cfg(test)] mod tests { - use super::{FEERATE_FLOOR_SATS_PER_KW, LowerBoundedFeeEstimator, ConfirmationTarget, FeeEstimator}; + use super::{ + ConfirmationTarget, FeeEstimator, LowerBoundedFeeEstimator, FEERATE_FLOOR_SATS_PER_KW, + }; struct TestFeeEstimator { sat_per_kw: u32, @@ -212,7 +216,10 @@ mod tests { let test_fee_estimator = &TestFeeEstimator { sat_per_kw }; let fee_estimator = LowerBoundedFeeEstimator::new(test_fee_estimator); - assert_eq!(fee_estimator.bounded_sat_per_1000_weight(ConfirmationTarget::AnchorChannelFee), FEERATE_FLOOR_SATS_PER_KW); + assert_eq!( + fee_estimator.bounded_sat_per_1000_weight(ConfirmationTarget::AnchorChannelFee), + FEERATE_FLOOR_SATS_PER_KW + ); } #[test] @@ -221,6 +228,9 @@ mod tests { let test_fee_estimator = &TestFeeEstimator { sat_per_kw }; let fee_estimator = LowerBoundedFeeEstimator::new(test_fee_estimator); - assert_eq!(fee_estimator.bounded_sat_per_1000_weight(ConfirmationTarget::AnchorChannelFee), sat_per_kw); + assert_eq!( + fee_estimator.bounded_sat_per_1000_weight(ConfirmationTarget::AnchorChannelFee), + sat_per_kw + ); } } diff --git a/lightning/src/chain/chainmonitor.rs b/lightning/src/chain/chainmonitor.rs index e87d082d9a7..39797508796 100644 --- a/lightning/src/chain/chainmonitor.rs +++ b/lightning/src/chain/chainmonitor.rs @@ -24,28 +24,31 @@ //! servicing [`ChannelMonitor`] updates from the client. use bitcoin::blockdata::block::BlockHeader; -use bitcoin::hash_types::{Txid, BlockHash}; +use bitcoin::hash_types::{BlockHash, Txid}; use crate::chain; -use crate::chain::{ChannelMonitorUpdateStatus, Filter, WatchedOutput}; use crate::chain::chaininterface::{BroadcasterInterface, FeeEstimator}; -use crate::chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdate, Balance, MonitorEvent, TransactionOutputs, LATENCY_GRACE_PERIOD_BLOCKS}; +use crate::chain::channelmonitor::{ + Balance, ChannelMonitor, ChannelMonitorUpdate, MonitorEvent, TransactionOutputs, + LATENCY_GRACE_PERIOD_BLOCKS, +}; use crate::chain::transaction::{OutPoint, TransactionData}; -use crate::sign::WriteableEcdsaChannelSigner; +use crate::chain::{ChannelMonitorUpdateStatus, Filter, WatchedOutput}; use crate::events; use crate::events::{Event, EventHandler}; +use crate::ln::channelmanager::ChannelDetails; +use crate::sign::WriteableEcdsaChannelSigner; use crate::util::atomic_counter::AtomicCounter; -use crate::util::logger::Logger; use crate::util::errors::APIError; +use crate::util::logger::Logger; use crate::util::wakers::{Future, Notifier}; -use crate::ln::channelmanager::ChannelDetails; use crate::prelude::*; -use crate::sync::{RwLock, RwLockReadGuard, Mutex, MutexGuard}; +use crate::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard}; +use bitcoin::secp256k1::PublicKey; use core::iter::FromIterator; use core::ops::Deref; use core::sync::atomic::{AtomicUsize, Ordering}; -use bitcoin::secp256k1::PublicKey; mod update_origin { #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] @@ -82,7 +85,9 @@ impl MonitorUpdateId { pub(crate) fn from_monitor_update(update: &ChannelMonitorUpdate) -> Self { Self { contents: UpdateOrigin::OffChain(update.update_id) } } - pub(crate) fn from_new_monitor(monitor: &ChannelMonitor) -> Self { + pub(crate) fn from_new_monitor( + monitor: &ChannelMonitor, + ) -> Self { Self { contents: UpdateOrigin::OffChain(monitor.get_latest_update_id()) } } } @@ -158,7 +163,10 @@ pub trait Persist { /// /// [`ChannelManager`]: crate::ln::channelmanager::ChannelManager /// [`Writeable::write`]: crate::util::ser::Writeable::write - fn persist_new_channel(&self, channel_id: OutPoint, data: &ChannelMonitor, update_id: MonitorUpdateId) -> ChannelMonitorUpdateStatus; + fn persist_new_channel( + &self, channel_id: OutPoint, data: &ChannelMonitor, + update_id: MonitorUpdateId, + ) -> ChannelMonitorUpdateStatus; /// Update one channel's data. The provided [`ChannelMonitor`] has already applied the given /// update. @@ -193,7 +201,10 @@ pub trait Persist { /// [`ChannelMonitorUpdateStatus`] for requirements when returning errors. /// /// [`Writeable::write`]: crate::util::ser::Writeable::write - fn update_persisted_channel(&self, channel_id: OutPoint, update: Option<&ChannelMonitorUpdate>, data: &ChannelMonitor, update_id: MonitorUpdateId) -> ChannelMonitorUpdateStatus; + fn update_persisted_channel( + &self, channel_id: OutPoint, update: Option<&ChannelMonitorUpdate>, + data: &ChannelMonitor, update_id: MonitorUpdateId, + ) -> ChannelMonitorUpdateStatus; } struct MonitorHolder { @@ -226,13 +237,27 @@ struct MonitorHolder { } impl MonitorHolder { - fn has_pending_offchain_updates(&self, pending_monitor_updates_lock: &MutexGuard>) -> bool { - pending_monitor_updates_lock.iter().any(|update_id| - if let UpdateOrigin::OffChain(_) = update_id.contents { true } else { false }) + fn has_pending_offchain_updates( + &self, pending_monitor_updates_lock: &MutexGuard>, + ) -> bool { + pending_monitor_updates_lock.iter().any(|update_id| { + if let UpdateOrigin::OffChain(_) = update_id.contents { + true + } else { + false + } + }) } - fn has_pending_chainsync_updates(&self, pending_monitor_updates_lock: &MutexGuard>) -> bool { - pending_monitor_updates_lock.iter().any(|update_id| - if let UpdateOrigin::ChainSync(_) = update_id.contents { true } else { false }) + fn has_pending_chainsync_updates( + &self, pending_monitor_updates_lock: &MutexGuard>, + ) -> bool { + pending_monitor_updates_lock.iter().any(|update_id| { + if let UpdateOrigin::ChainSync(_) = update_id.contents { + true + } else { + false + } + }) } } @@ -268,12 +293,19 @@ impl Deref for LockedChannelMonitor< /// [`ChannelManager`]: crate::ln::channelmanager::ChannelManager /// [module-level documentation]: crate::chain::chainmonitor /// [`rebroadcast_pending_claims`]: Self::rebroadcast_pending_claims -pub struct ChainMonitor - where C::Target: chain::Filter, - T::Target: BroadcasterInterface, - F::Target: FeeEstimator, - L::Target: Logger, - P::Target: Persist, +pub struct ChainMonitor< + ChannelSigner: WriteableEcdsaChannelSigner, + C: Deref, + T: Deref, + F: Deref, + L: Deref, + P: Deref, +> where + C::Target: chain::Filter, + T::Target: BroadcasterInterface, + F::Target: FeeEstimator, + L::Target: Logger, + P::Target: Persist, { monitors: RwLock>>, /// When we generate a [`MonitorUpdateId`] for a chain-event monitor persistence, we need a @@ -294,12 +326,20 @@ pub struct ChainMonitor ChainMonitor -where C::Target: chain::Filter, - T::Target: BroadcasterInterface, - F::Target: FeeEstimator, - L::Target: Logger, - P::Target: Persist, +impl< + ChannelSigner: WriteableEcdsaChannelSigner, + C: Deref, + T: Deref, + F: Deref, + L: Deref, + P: Deref, + > ChainMonitor +where + C::Target: chain::Filter, + T::Target: BroadcasterInterface, + F::Target: FeeEstimator, + L::Target: Logger, + P::Target: Persist, { /// Dispatches to per-channel monitors, which are responsible for updating their on-chain view /// of a channel and reacting accordingly based on transactions in the given chain data. See @@ -312,16 +352,29 @@ where C::Target: chain::Filter, /// updated `txdata`. /// /// Calls which represent a new blockchain tip height should set `best_height`. - fn process_chain_data(&self, header: &BlockHeader, best_height: Option, txdata: &TransactionData, process: FN) - where - FN: Fn(&ChannelMonitor, &TransactionData) -> Vec + fn process_chain_data( + &self, header: &BlockHeader, best_height: Option, txdata: &TransactionData, + process: FN, + ) where + FN: Fn(&ChannelMonitor, &TransactionData) -> Vec, { let err_str = "ChannelMonitor[Update] persistence failed unrecoverably. This indicates we cannot continue normal operation and must shut down."; - let funding_outpoints: HashSet = HashSet::from_iter(self.monitors.read().unwrap().keys().cloned()); + let funding_outpoints: HashSet = + HashSet::from_iter(self.monitors.read().unwrap().keys().cloned()); for funding_outpoint in funding_outpoints.iter() { let monitor_lock = self.monitors.read().unwrap(); if let Some(monitor_state) = monitor_lock.get(funding_outpoint) { - if self.update_monitor_with_chain_data(header, best_height, txdata, &process, funding_outpoint, &monitor_state).is_err() { + if self + .update_monitor_with_chain_data( + header, + best_height, + txdata, + &process, + funding_outpoint, + &monitor_state, + ) + .is_err() + { // Take the monitors lock for writing so that we poison it and any future // operations going forward fail immediately. core::mem::drop(monitor_lock); @@ -336,7 +389,17 @@ where C::Target: chain::Filter, let monitor_states = self.monitors.write().unwrap(); for (funding_outpoint, monitor_state) in monitor_states.iter() { if !funding_outpoints.contains(funding_outpoint) { - if self.update_monitor_with_chain_data(header, best_height, txdata, &process, funding_outpoint, &monitor_state).is_err() { + if self + .update_monitor_with_chain_data( + header, + best_height, + txdata, + &process, + funding_outpoint, + &monitor_state, + ) + .is_err() + { log_error!(self.logger, "{}", err_str); panic!("{}", err_str); } @@ -356,8 +419,11 @@ where C::Target: chain::Filter, fn update_monitor_with_chain_data( &self, header: &BlockHeader, best_height: Option, txdata: &TransactionData, - process: FN, funding_outpoint: &OutPoint, monitor_state: &MonitorHolder - ) -> Result<(), ()> where FN: Fn(&ChannelMonitor, &TransactionData) -> Vec { + process: FN, funding_outpoint: &OutPoint, monitor_state: &MonitorHolder, + ) -> Result<(), ()> + where + FN: Fn(&ChannelMonitor, &TransactionData) -> Vec, + { let monitor = &monitor_state.monitor; let mut txn_outputs; { @@ -371,14 +437,28 @@ where C::Target: chain::Filter, // If there are not ChainSync persists awaiting completion, go ahead and // set last_chain_persist_height here - we wouldn't want the first // InProgress to always immediately be considered "overly delayed". - monitor_state.last_chain_persist_height.store(height as usize, Ordering::Release); + monitor_state + .last_chain_persist_height + .store(height as usize, Ordering::Release); } } - log_trace!(self.logger, "Syncing Channel Monitor for channel {}", log_funding_info!(monitor)); - match self.persister.update_persisted_channel(*funding_outpoint, None, monitor, update_id) { - ChannelMonitorUpdateStatus::Completed => - log_trace!(self.logger, "Finished syncing Channel Monitor for channel {}", log_funding_info!(monitor)), + log_trace!( + self.logger, + "Syncing Channel Monitor for channel {}", + log_funding_info!(monitor) + ); + match self.persister.update_persisted_channel( + *funding_outpoint, + None, + monitor, + update_id, + ) { + ChannelMonitorUpdateStatus::Completed => log_trace!( + self.logger, + "Finished syncing Channel Monitor for channel {}", + log_funding_info!(monitor) + ), ChannelMonitorUpdateStatus::InProgress => { log_debug!(self.logger, "Channel Monitor sync for channel {} in progress, holding events until completion!", log_funding_info!(monitor)); pending_monitor_updates.push(update_id); @@ -415,7 +495,9 @@ where C::Target: chain::Filter, /// pre-filter blocks or only fetch blocks matching a compact filter. Otherwise, clients may /// always need to fetch full blocks absent another means for determining which blocks contain /// transactions relevant to the watched channels. - pub fn new(chain_source: Option, broadcaster: T, logger: L, feeest: F, persister: P) -> Self { + pub fn new( + chain_source: Option, broadcaster: T, logger: L, feeest: F, persister: P, + ) -> Self { Self { monitors: RwLock::new(HashMap::new()), sync_persistence_id: AtomicCounter::new(), @@ -460,7 +542,9 @@ where C::Target: chain::Filter, /// /// Note that the result holds a mutex over our monitor set, and should not be held /// indefinitely. - pub fn get_monitor(&self, funding_txo: OutPoint) -> Result, ()> { + pub fn get_monitor( + &self, funding_txo: OutPoint, + ) -> Result, ()> { let lock = self.monitors.read().unwrap(); if lock.get(&funding_txo).is_some() { Ok(LockedChannelMonitor { lock, funding_txo }) @@ -480,20 +564,29 @@ where C::Target: chain::Filter, #[cfg(not(c_bindings))] /// Lists the pending updates for each [`ChannelMonitor`] (by `OutPoint` being monitored). pub fn list_pending_monitor_updates(&self) -> HashMap> { - self.monitors.read().unwrap().iter().map(|(outpoint, holder)| { - (*outpoint, holder.pending_monitor_updates.lock().unwrap().clone()) - }).collect() + self.monitors + .read() + .unwrap() + .iter() + .map(|(outpoint, holder)| { + (*outpoint, holder.pending_monitor_updates.lock().unwrap().clone()) + }) + .collect() } #[cfg(c_bindings)] /// Lists the pending updates for each [`ChannelMonitor`] (by `OutPoint` being monitored). pub fn list_pending_monitor_updates(&self) -> Vec<(OutPoint, Vec)> { - self.monitors.read().unwrap().iter().map(|(outpoint, holder)| { - (*outpoint, holder.pending_monitor_updates.lock().unwrap().clone()) - }).collect() + self.monitors + .read() + .unwrap() + .iter() + .map(|(outpoint, holder)| { + (*outpoint, holder.pending_monitor_updates.lock().unwrap().clone()) + }) + .collect() } - #[cfg(test)] pub fn remove_monitor(&self, funding_txo: &OutPoint) -> ChannelMonitor { self.monitors.write().unwrap().remove(funding_txo).unwrap().monitor @@ -515,10 +608,16 @@ where C::Target: chain::Filter, /// /// Returns an [`APIError::APIMisuseError`] if `funding_txo` does not match any currently /// registered [`ChannelMonitor`]s. - pub fn channel_monitor_updated(&self, funding_txo: OutPoint, completed_update_id: MonitorUpdateId) -> Result<(), APIError> { + pub fn channel_monitor_updated( + &self, funding_txo: OutPoint, completed_update_id: MonitorUpdateId, + ) -> Result<(), APIError> { let monitors = self.monitors.read().unwrap(); - let monitor_data = if let Some(mon) = monitors.get(&funding_txo) { mon } else { - return Err(APIError::APIMisuseError { err: format!("No ChannelMonitor matching funding outpoint {:?} found", funding_txo) }); + let monitor_data = if let Some(mon) = monitors.get(&funding_txo) { + mon + } else { + return Err(APIError::APIMisuseError { + err: format!("No ChannelMonitor matching funding outpoint {:?} found", funding_txo), + }); }; let mut pending_monitor_updates = monitor_data.pending_monitor_updates.lock().unwrap(); pending_monitor_updates.retain(|update_id| *update_id != completed_update_id); @@ -534,20 +633,28 @@ where C::Target: chain::Filter, // - we can still update our channel state, just as long as we don't return // `MonitorEvent`s from the monitor back to the `ChannelManager` until they // complete. - let monitor_is_pending_updates = monitor_data.has_pending_offchain_updates(&pending_monitor_updates); + let monitor_is_pending_updates = + monitor_data.has_pending_offchain_updates(&pending_monitor_updates); if monitor_is_pending_updates { // If there are still monitor updates pending, we cannot yet construct a // Completed event. return Ok(()); } - self.pending_monitor_events.lock().unwrap().push((funding_txo, vec![MonitorEvent::Completed { + self.pending_monitor_events.lock().unwrap().push(( funding_txo, - monitor_update_id: monitor_data.monitor.get_latest_update_id(), - }], monitor_data.monitor.get_counterparty_node_id())); + vec![MonitorEvent::Completed { + funding_txo, + monitor_update_id: monitor_data.monitor.get_latest_update_id(), + }], + monitor_data.monitor.get_counterparty_node_id(), + )); }, MonitorUpdateId { contents: UpdateOrigin::ChainSync(_) } => { if !monitor_data.has_pending_chainsync_updates(&pending_monitor_updates) { - monitor_data.last_chain_persist_height.store(self.highest_chain_height.load(Ordering::Acquire), Ordering::Release); + monitor_data.last_chain_persist_height.store( + self.highest_chain_height.load(Ordering::Acquire), + Ordering::Release, + ); // The next time release_pending_monitor_events is called, any events for this // ChannelMonitor will be returned. } @@ -563,11 +670,13 @@ where C::Target: chain::Filter, #[cfg(any(test, fuzzing))] pub fn force_channel_monitor_updated(&self, funding_txo: OutPoint, monitor_update_id: u64) { let monitors = self.monitors.read().unwrap(); - let counterparty_node_id = monitors.get(&funding_txo).and_then(|m| m.monitor.get_counterparty_node_id()); - self.pending_monitor_events.lock().unwrap().push((funding_txo, vec![MonitorEvent::Completed { + let counterparty_node_id = + monitors.get(&funding_txo).and_then(|m| m.monitor.get_counterparty_node_id()); + self.pending_monitor_events.lock().unwrap().push(( funding_txo, - monitor_update_id, - }], counterparty_node_id)); + vec![MonitorEvent::Completed { funding_txo, monitor_update_id }], + counterparty_node_id, + )); self.event_notifier.notify(); } @@ -586,8 +695,11 @@ where C::Target: chain::Filter, /// See the trait-level documentation of [`EventsProvider`] for requirements. /// /// [`EventsProvider`]: crate::events::EventsProvider - pub async fn process_pending_events_async Future>( - &self, handler: H + pub async fn process_pending_events_async< + Future: core::future::Future, + H: Fn(Event) -> Future, + >( + &self, handler: H, ) { // Sadly we can't hold the monitors read lock through an async call. Thus we have to do a // crazy dance to process a monitor's events then only remove them once we've done so. @@ -595,7 +707,10 @@ where C::Target: chain::Filter, for funding_txo in mons_to_process { let mut ev; super::channelmonitor::process_events_body!( - self.monitors.read().unwrap().get(&funding_txo).map(|m| &m.monitor), ev, handler(ev).await); + self.monitors.read().unwrap().get(&funding_txo).map(|m| &m.monitor), + ev, + handler(ev).await + ); } } @@ -620,14 +735,22 @@ where C::Target: chain::Filter, let monitors = self.monitors.read().unwrap(); for (_, monitor_holder) in &*monitors { monitor_holder.monitor.rebroadcast_pending_claims( - &*self.broadcaster, &*self.fee_estimator, &*self.logger + &*self.broadcaster, + &*self.fee_estimator, + &*self.logger, ) } } } -impl -chain::Listen for ChainMonitor +impl< + ChannelSigner: WriteableEcdsaChannelSigner, + C: Deref, + T: Deref, + F: Deref, + L: Deref, + P: Deref, + > chain::Listen for ChainMonitor where C::Target: chain::Filter, T::Target: BroadcasterInterface, @@ -635,26 +758,55 @@ where L::Target: Logger, P::Target: Persist, { - fn filtered_block_connected(&self, header: &BlockHeader, txdata: &TransactionData, height: u32) { - log_debug!(self.logger, "New best block {} at height {} provided via block_connected", header.block_hash(), height); + fn filtered_block_connected( + &self, header: &BlockHeader, txdata: &TransactionData, height: u32, + ) { + log_debug!( + self.logger, + "New best block {} at height {} provided via block_connected", + header.block_hash(), + height + ); self.process_chain_data(header, Some(height), &txdata, |monitor, txdata| { monitor.block_connected( - header, txdata, height, &*self.broadcaster, &*self.fee_estimator, &*self.logger) + header, + txdata, + height, + &*self.broadcaster, + &*self.fee_estimator, + &*self.logger, + ) }); } fn block_disconnected(&self, header: &BlockHeader, height: u32) { let monitor_states = self.monitors.read().unwrap(); - log_debug!(self.logger, "Latest block {} at height {} removed via block_disconnected", header.block_hash(), height); + log_debug!( + self.logger, + "Latest block {} at height {} removed via block_disconnected", + header.block_hash(), + height + ); for monitor_state in monitor_states.values() { monitor_state.monitor.block_disconnected( - header, height, &*self.broadcaster, &*self.fee_estimator, &*self.logger); + header, + height, + &*self.broadcaster, + &*self.fee_estimator, + &*self.logger, + ); } } } -impl -chain::Confirm for ChainMonitor +impl< + ChannelSigner: WriteableEcdsaChannelSigner, + C: Deref, + T: Deref, + F: Deref, + L: Deref, + P: Deref, + > chain::Confirm for ChainMonitor where C::Target: chain::Filter, T::Target: BroadcasterInterface, @@ -663,10 +815,22 @@ where P::Target: Persist, { fn transactions_confirmed(&self, header: &BlockHeader, txdata: &TransactionData, height: u32) { - log_debug!(self.logger, "{} provided transactions confirmed at height {} in block {}", txdata.len(), height, header.block_hash()); + log_debug!( + self.logger, + "{} provided transactions confirmed at height {} in block {}", + txdata.len(), + height, + header.block_hash() + ); self.process_chain_data(header, None, txdata, |monitor, txdata| { monitor.transactions_confirmed( - header, txdata, height, &*self.broadcaster, &*self.fee_estimator, &*self.logger) + header, + txdata, + height, + &*self.broadcaster, + &*self.fee_estimator, + &*self.logger, + ) }); } @@ -674,18 +838,33 @@ where log_debug!(self.logger, "Transaction {} reorganized out of chain", txid); let monitor_states = self.monitors.read().unwrap(); for monitor_state in monitor_states.values() { - monitor_state.monitor.transaction_unconfirmed(txid, &*self.broadcaster, &*self.fee_estimator, &*self.logger); + monitor_state.monitor.transaction_unconfirmed( + txid, + &*self.broadcaster, + &*self.fee_estimator, + &*self.logger, + ); } } fn best_block_updated(&self, header: &BlockHeader, height: u32) { - log_debug!(self.logger, "New best block {} at height {} provided via best_block_updated", header.block_hash(), height); + log_debug!( + self.logger, + "New best block {} at height {} provided via best_block_updated", + header.block_hash(), + height + ); self.process_chain_data(header, Some(height), &[], |monitor, txdata| { // While in practice there shouldn't be any recursive calls when given empty txdata, // it's still possible if a chain::Filter implementation returns a transaction. debug_assert!(txdata.is_empty()); monitor.best_block_updated( - header, height, &*self.broadcaster, &*self.fee_estimator, &*self.logger) + header, + height, + &*self.broadcaster, + &*self.fee_estimator, + &*self.logger, + ) }); } @@ -702,15 +881,24 @@ where } } -impl -chain::Watch for ChainMonitor -where C::Target: chain::Filter, - T::Target: BroadcasterInterface, - F::Target: FeeEstimator, - L::Target: Logger, - P::Target: Persist, +impl< + ChannelSigner: WriteableEcdsaChannelSigner, + C: Deref, + T: Deref, + F: Deref, + L: Deref, + P: Deref, + > chain::Watch for ChainMonitor +where + C::Target: chain::Filter, + T::Target: BroadcasterInterface, + F::Target: FeeEstimator, + L::Target: Logger, + P::Target: Persist, { - fn watch_channel(&self, funding_outpoint: OutPoint, monitor: ChannelMonitor) -> Result { + fn watch_channel( + &self, funding_outpoint: OutPoint, monitor: ChannelMonitor, + ) -> Result { let mut monitors = self.monitors.write().unwrap(); let entry = match monitors.entry(funding_outpoint) { hash_map::Entry::Occupied(_) => { @@ -719,17 +907,29 @@ where C::Target: chain::Filter, }, hash_map::Entry::Vacant(e) => e, }; - log_trace!(self.logger, "Got new ChannelMonitor for channel {}", log_funding_info!(monitor)); + log_trace!( + self.logger, + "Got new ChannelMonitor for channel {}", + log_funding_info!(monitor) + ); let update_id = MonitorUpdateId::from_new_monitor(&monitor); let mut pending_monitor_updates = Vec::new(); let persist_res = self.persister.persist_new_channel(funding_outpoint, &monitor, update_id); match persist_res { ChannelMonitorUpdateStatus::InProgress => { - log_info!(self.logger, "Persistence of new ChannelMonitor for channel {} in progress", log_funding_info!(monitor)); + log_info!( + self.logger, + "Persistence of new ChannelMonitor for channel {} in progress", + log_funding_info!(monitor) + ); pending_monitor_updates.push(update_id); }, ChannelMonitorUpdateStatus::Completed => { - log_info!(self.logger, "Persistence of new ChannelMonitor for channel {} completed", log_funding_info!(monitor)); + log_info!( + self.logger, + "Persistence of new ChannelMonitor for channel {} completed", + log_funding_info!(monitor) + ); }, ChannelMonitorUpdateStatus::UnrecoverableError => { let err_str = "ChannelMonitor[Update] persistence failed unrecoverably. This indicates we cannot continue normal operation and must shut down."; @@ -743,17 +943,24 @@ where C::Target: chain::Filter, entry.insert(MonitorHolder { monitor, pending_monitor_updates: Mutex::new(pending_monitor_updates), - last_chain_persist_height: AtomicUsize::new(self.highest_chain_height.load(Ordering::Acquire)), + last_chain_persist_height: AtomicUsize::new( + self.highest_chain_height.load(Ordering::Acquire), + ), }); Ok(persist_res) } - fn update_channel(&self, funding_txo: OutPoint, update: &ChannelMonitorUpdate) -> ChannelMonitorUpdateStatus { + fn update_channel( + &self, funding_txo: OutPoint, update: &ChannelMonitorUpdate, + ) -> ChannelMonitorUpdateStatus { // Update the monitor that watches the channel referred to by the given outpoint. let monitors = self.monitors.read().unwrap(); let ret = match monitors.get(&funding_txo) { None => { - log_error!(self.logger, "Failed to update channel monitor: no such monitor registered"); + log_error!( + self.logger, + "Failed to update channel monitor: no such monitor registered" + ); // We should never ever trigger this from within ChannelManager. Technically a // user could use this object with some proxying in between which makes this @@ -765,11 +972,21 @@ where C::Target: chain::Filter, }, Some(monitor_state) => { let monitor = &monitor_state.monitor; - log_trace!(self.logger, "Updating ChannelMonitor for channel {}", log_funding_info!(monitor)); - let update_res = monitor.update_monitor(update, &self.broadcaster, &self.fee_estimator, &self.logger); + log_trace!( + self.logger, + "Updating ChannelMonitor for channel {}", + log_funding_info!(monitor) + ); + let update_res = monitor.update_monitor( + update, + &self.broadcaster, + &self.fee_estimator, + &self.logger, + ); let update_id = MonitorUpdateId::from_monitor_update(update); - let mut pending_monitor_updates = monitor_state.pending_monitor_updates.lock().unwrap(); + let mut pending_monitor_updates = + monitor_state.pending_monitor_updates.lock().unwrap(); let persist_res = if update_res.is_err() { // Even if updating the monitor returns an error, the monitor's state will // still be changed. Therefore, we should persist the updated monitor despite the error. @@ -779,24 +996,38 @@ where C::Target: chain::Filter, log_warn!(self.logger, "Failed to update ChannelMonitor for channel {}. Going ahead and persisting the entire ChannelMonitor", log_funding_info!(monitor)); self.persister.update_persisted_channel(funding_txo, None, monitor, update_id) } else { - self.persister.update_persisted_channel(funding_txo, Some(update), monitor, update_id) + self.persister.update_persisted_channel( + funding_txo, + Some(update), + monitor, + update_id, + ) }; match persist_res { ChannelMonitorUpdateStatus::InProgress => { pending_monitor_updates.push(update_id); - log_debug!(self.logger, "Persistence of ChannelMonitorUpdate for channel {} in progress", log_funding_info!(monitor)); + log_debug!( + self.logger, + "Persistence of ChannelMonitorUpdate for channel {} in progress", + log_funding_info!(monitor) + ); }, ChannelMonitorUpdateStatus::Completed => { - log_debug!(self.logger, "Persistence of ChannelMonitorUpdate for channel {} completed", log_funding_info!(monitor)); + log_debug!( + self.logger, + "Persistence of ChannelMonitorUpdate for channel {} completed", + log_funding_info!(monitor) + ); + }, + ChannelMonitorUpdateStatus::UnrecoverableError => { /* we'll panic in a moment */ }, - ChannelMonitorUpdateStatus::UnrecoverableError => { /* we'll panic in a moment */ }, } if update_res.is_err() { ChannelMonitorUpdateStatus::InProgress } else { persist_res } - } + }, }; if let ChannelMonitorUpdateStatus::UnrecoverableError = ret { // Take the monitors lock for writing so that we poison it and any future @@ -810,26 +1041,45 @@ where C::Target: chain::Filter, ret } - fn release_pending_monitor_events(&self) -> Vec<(OutPoint, Vec, Option)> { + fn release_pending_monitor_events( + &self, + ) -> Vec<(OutPoint, Vec, Option)> { let mut pending_monitor_events = self.pending_monitor_events.lock().unwrap().split_off(0); for monitor_state in self.monitors.read().unwrap().values() { - let is_pending_monitor_update = monitor_state.has_pending_chainsync_updates(&monitor_state.pending_monitor_updates.lock().unwrap()); - if is_pending_monitor_update && - monitor_state.last_chain_persist_height.load(Ordering::Acquire) + LATENCY_GRACE_PERIOD_BLOCKS as usize - > self.highest_chain_height.load(Ordering::Acquire) + let is_pending_monitor_update = monitor_state.has_pending_chainsync_updates( + &monitor_state.pending_monitor_updates.lock().unwrap(), + ); + if is_pending_monitor_update + && monitor_state.last_chain_persist_height.load(Ordering::Acquire) + + LATENCY_GRACE_PERIOD_BLOCKS as usize + > self.highest_chain_height.load(Ordering::Acquire) { log_debug!(self.logger, "A Channel Monitor sync is still in progress, refusing to provide monitor events!"); } else { if is_pending_monitor_update { - log_error!(self.logger, "A ChannelMonitor sync took longer than {} blocks to complete.", LATENCY_GRACE_PERIOD_BLOCKS); - log_error!(self.logger, " To avoid funds-loss, we are allowing monitor updates to be released."); - log_error!(self.logger, " This may cause duplicate payment events to be generated."); + log_error!( + self.logger, + "A ChannelMonitor sync took longer than {} blocks to complete.", + LATENCY_GRACE_PERIOD_BLOCKS + ); + log_error!( + self.logger, + " To avoid funds-loss, we are allowing monitor updates to be released." + ); + log_error!( + self.logger, + " This may cause duplicate payment events to be generated." + ); } let monitor_events = monitor_state.monitor.get_and_clear_pending_monitor_events(); if monitor_events.len() > 0 { let monitor_outpoint = monitor_state.monitor.get_funding_txo().0; let counterparty_node_id = monitor_state.monitor.get_counterparty_node_id(); - pending_monitor_events.push((monitor_outpoint, monitor_events, counterparty_node_id)); + pending_monitor_events.push(( + monitor_outpoint, + monitor_events, + counterparty_node_id, + )); } } } @@ -837,12 +1087,20 @@ where C::Target: chain::Filter, } } -impl events::EventsProvider for ChainMonitor - where C::Target: chain::Filter, - T::Target: BroadcasterInterface, - F::Target: FeeEstimator, - L::Target: Logger, - P::Target: Persist, +impl< + ChannelSigner: WriteableEcdsaChannelSigner, + C: Deref, + T: Deref, + F: Deref, + L: Deref, + P: Deref, + > events::EventsProvider for ChainMonitor +where + C::Target: chain::Filter, + T::Target: BroadcasterInterface, + F::Target: FeeEstimator, + L::Target: Logger, + P::Target: Persist, { /// Processes [`SpendableOutputs`] events produced from each [`ChannelMonitor`] upon maturity. /// @@ -857,7 +1115,10 @@ impl(&self, handler: H) where H::Target: EventHandler { + fn process_pending_events(&self, handler: H) + where + H::Target: EventHandler, + { for monitor_state in self.monitors.read().unwrap().values() { monitor_state.monitor.process_pending_events(&handler); } @@ -866,16 +1127,19 @@ impl())); + let mut updates = Vec::with_capacity(cmp::min( + len as usize, + MAX_ALLOC_SIZE / ::core::mem::size_of::(), + )); for _ in 0..len { if let Some(upd) = MaybeReadable::read(r)? { updates.push(upd); @@ -261,14 +277,16 @@ impl_writeable_tlv_based!(HolderSignedTx, { impl HolderSignedTx { fn non_dust_htlcs(&self) -> Vec { - self.htlc_outputs.iter().filter_map(|(htlc, _, _)| { - if let Some(_) = htlc.transaction_output_index { - Some(htlc.clone()) - } else { - None - } - }) - .collect() + self.htlc_outputs + .iter() + .filter_map(|(htlc, _, _)| { + if let Some(_) = htlc.transaction_output_index { + Some(htlc.clone()) + } else { + None + } + }) + .collect() } } @@ -298,7 +316,7 @@ impl Readable for CounterpartyCommitmentParameters { // Versions prior to 0.0.100 had some per-HTLC state stored here, which is no longer // used. Read it for compatibility. let per_htlc_len: u64 = Readable::read(r)?; - for _ in 0..per_htlc_len { + for _ in 0..per_htlc_len { let _txid: Txid = Readable::read(r)?; let htlcs_count: u64 = Readable::read(r)?; for _ in 0..htlcs_count { @@ -315,7 +333,9 @@ impl Readable for CounterpartyCommitmentParameters { (4, on_counterparty_tx_csv, required), }); CounterpartyCommitmentParameters { - counterparty_delayed_payment_base_key: counterparty_delayed_payment_base_key.0.unwrap(), + counterparty_delayed_payment_base_key: counterparty_delayed_payment_base_key + .0 + .unwrap(), counterparty_htlc_base_key: counterparty_htlc_base_key.0.unwrap(), on_counterparty_tx_csv, } @@ -342,14 +362,15 @@ impl OnchainEventEntry { let mut conf_threshold = self.height + ANTI_REORG_DELAY - 1; match self.event { OnchainEvent::MaturingOutput { - descriptor: SpendableOutputDescriptor::DelayedPaymentOutput(ref descriptor) + descriptor: SpendableOutputDescriptor::DelayedPaymentOutput(ref descriptor), } => { // A CSV'd transaction is confirmable in block (input height) + CSV delay, which means // it's broadcastable when we see the previous block. - conf_threshold = cmp::max(conf_threshold, self.height + descriptor.to_self_delay as u32 - 1); + conf_threshold = + cmp::max(conf_threshold, self.height + descriptor.to_self_delay as u32 - 1); }, - OnchainEvent::FundingSpendConfirmation { on_local_output_csv: Some(csv), .. } | - OnchainEvent::HTLCSpendConfirmation { on_to_local_output_csv: Some(csv), .. } => { + OnchainEvent::FundingSpendConfirmation { on_local_output_csv: Some(csv), .. } + | OnchainEvent::HTLCSpendConfirmation { on_to_local_output_csv: Some(csv), .. } => { // A CSV'd transaction is confirmable in block (input height) + CSV delay, which means // it's broadcastable when we see the previous block. conf_threshold = cmp::max(conf_threshold, self.height + csv as u32 - 1); @@ -389,9 +410,7 @@ enum OnchainEvent { }, /// An output waiting on [`ANTI_REORG_DELAY`] confirmations before we hand the user the /// [`SpendableOutputDescriptor`]. - MaturingOutput { - descriptor: SpendableOutputDescriptor, - }, + MaturingOutput { descriptor: SpendableOutputDescriptor }, /// A spend of the funding output, either a commitment transaction or a cooperative closing /// transaction. FundingSpendConfirmation { @@ -454,7 +473,13 @@ impl MaybeReadable for OnchainEventEntry { (3, block_hash, option), (4, event, upgradable_required), }); - Ok(Some(Self { txid, transaction, height, block_hash, event: _init_tlv_based_struct_field!(event, upgradable_required) })) + Ok(Some(Self { + txid, + transaction, + height, + block_hash, + event: _init_tlv_based_struct_field!(event, upgradable_required), + })) } } @@ -522,8 +547,12 @@ pub(crate) enum ChannelMonitorUpdateStep { impl ChannelMonitorUpdateStep { fn variant_name(&self) -> &'static str { match self { - ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo { .. } => "LatestHolderCommitmentTXInfo", - ChannelMonitorUpdateStep::LatestCounterpartyCommitmentTXInfo { .. } => "LatestCounterpartyCommitmentTXInfo", + ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo { .. } => { + "LatestHolderCommitmentTXInfo" + }, + ChannelMonitorUpdateStep::LatestCounterpartyCommitmentTXInfo { .. } => { + "LatestCounterpartyCommitmentTXInfo" + }, ChannelMonitorUpdateStep::PaymentPreimage { .. } => "PaymentPreimage", ChannelMonitorUpdateStep::CommitmentSecret { .. } => "CommitmentSecret", ChannelMonitorUpdateStep::ChannelForceClosed { .. } => "ChannelForceClosed", @@ -656,14 +685,12 @@ impl Balance { /// On-chain fees required to claim the balance are not included in this amount. pub fn claimable_amount_satoshis(&self) -> u64 { match self { - Balance::ClaimableOnChannelClose { amount_satoshis, .. }| - Balance::ClaimableAwaitingConfirmations { amount_satoshis, .. }| - Balance::ContentiousClaimable { amount_satoshis, .. }| - Balance::CounterpartyRevokedOutputClaimable { amount_satoshis, .. } - => *amount_satoshis, - Balance::MaybeTimeoutClaimableHTLC { .. }| - Balance::MaybePreimageClaimableHTLC { .. } - => 0, + Balance::ClaimableOnChannelClose { amount_satoshis, .. } + | Balance::ClaimableAwaitingConfirmations { amount_satoshis, .. } + | Balance::ContentiousClaimable { amount_satoshis, .. } + | Balance::CounterpartyRevokedOutputClaimable { amount_satoshis, .. } => *amount_satoshis, + Balance::MaybeTimeoutClaimableHTLC { .. } + | Balance::MaybePreimageClaimableHTLC { .. } => 0, } } } @@ -687,7 +714,8 @@ struct IrrevocablyResolvedHTLC { // using `u32::max_value()` as a sentinal to indicate the HTLC was dust. impl Writeable for IrrevocablyResolvedHTLC { fn write(&self, writer: &mut W) -> Result<(), io::Error> { - let mapped_commitment_tx_output_idx = self.commitment_tx_output_idx.unwrap_or(u32::max_value()); + let mapped_commitment_tx_output_idx = + self.commitment_tx_output_idx.unwrap_or(u32::max_value()); write_tlv_fields!(writer, { (0, mapped_commitment_tx_output_idx, required), (1, self.resolving_txid, option), @@ -711,7 +739,11 @@ impl Readable for IrrevocablyResolvedHTLC { (3, resolving_tx, option), }); Ok(Self { - commitment_tx_output_idx: if mapped_commitment_tx_output_idx == u32::max_value() { None } else { Some(mapped_commitment_tx_output_idx) }, + commitment_tx_output_idx: if mapped_commitment_tx_output_idx == u32::max_value() { + None + } else { + Some(mapped_commitment_tx_output_idx) + }, resolving_txid, payment_preimage, resolving_tx, @@ -737,7 +769,10 @@ pub struct ChannelMonitor { pub(super) inner: Mutex>, } -impl Clone for ChannelMonitor where Signer: Clone { +impl Clone for ChannelMonitor +where + Signer: Clone, +{ fn clone(&self) -> Self { let inner = self.inner.lock().unwrap().clone(); ChannelMonitor::from_impl(inner) @@ -772,7 +807,8 @@ pub(crate) struct ChannelMonitorImpl { /// The set of outpoints in each counterparty commitment transaction. We always need at least /// the payment hash from `HTLCOutputInCommitment` to claim even a revoked commitment /// transaction broadcast as we need to be able to construct the witness script in all cases. - counterparty_claimable_outpoints: HashMap>)>>, + counterparty_claimable_outpoints: + HashMap>)>>, /// We cannot identify HTLC-Success or HTLC-Timeout transactions by themselves on the chain. /// Nor can we figure out their commitment numbers without the commitment transaction they are /// spending. Thus, in order to claim them via revocation key, we track all the counterparty @@ -895,14 +931,25 @@ pub(crate) struct ChannelMonitorImpl { /// Transaction outputs to watch for on-chain spends. pub type TransactionOutputs = (Txid, Vec<(u32, TxOut)>); -impl PartialEq for ChannelMonitor where Signer: PartialEq { +impl PartialEq for ChannelMonitor +where + Signer: PartialEq, +{ fn eq(&self, other: &Self) -> bool { // We need some kind of total lockorder. Absent a better idea, we sort by position in // memory and take locks in that order (assuming that we can't move within memory while a // lock is held). let ord = ((self as *const _) as usize) < ((other as *const _) as usize); - let a = if ord { self.inner.unsafe_well_ordered_double_lock_self() } else { other.inner.unsafe_well_ordered_double_lock_self() }; - let b = if ord { other.inner.unsafe_well_ordered_double_lock_self() } else { self.inner.unsafe_well_ordered_double_lock_self() }; + let a = if ord { + self.inner.unsafe_well_ordered_double_lock_self() + } else { + other.inner.unsafe_well_ordered_double_lock_self() + }; + let b = if ord { + other.inner.unsafe_well_ordered_double_lock_self() + } else { + self.inner.unsafe_well_ordered_double_lock_self() + }; a.eq(&b) } } @@ -927,7 +974,9 @@ impl Writeable for ChannelMonitorImpl Writeable for ChannelMonitorImpl Writeable for ChannelMonitorImpl Writeable for ChannelMonitorImpl Writeable for ChannelMonitorImpl true, - MonitorEvent::HolderForceClosed(_) => true, - _ => false, - }).count() as u64).to_be_bytes())?; + writer.write_all( + &(self + .pending_monitor_events + .iter() + .filter(|ev| match ev { + MonitorEvent::HTLCEvent(_) => true, + MonitorEvent::HolderForceClosed(_) => true, + _ => false, + }) + .count() as u64) + .to_be_bytes(), + )?; for event in self.pending_monitor_events.iter() { match event { MonitorEvent::HTLCEvent(upd) => { @@ -1052,7 +1114,8 @@ impl Writeable for ChannelMonitorImpl ChannelMonitor { ChannelMonitor { inner: Mutex::new(imp) } } - pub(crate) fn new(secp_ctx: Secp256k1, keys: Signer, shutdown_script: Option