diff --git a/src/downloader.rs b/src/downloader.rs index 5cc46e653..7df0bd504 100644 --- a/src/downloader.rs +++ b/src/downloader.rs @@ -1,7 +1,6 @@ //! Handle downloading blobs and collections concurrently and from nodes. //! //! The [`Downloader`] interacts with four main components to this end. -//! - [`Dialer`]: Used to queue opening connections to nodes we need to perform downloads. //! - `ProviderMap`: Where the downloader obtains information about nodes that could be //! used to perform a download. //! - [`Store`]: Where data is stored. @@ -10,7 +9,7 @@ //! 1. The `ProviderMap` is queried for nodes. From these nodes some are selected //! prioritizing connected nodes with lower number of active requests. If no useful node is //! connected, or useful connected nodes have no capacity to perform the request, a connection -//! attempt is started using the [`Dialer`]. +//! attempt is started using the `DialerT`. //! 2. The download is queued for processing at a later time. Downloads are not performed right //! away. Instead, they are initially delayed to allow the node to obtain the data itself, and //! to wait for the new connection to be established if necessary. @@ -34,13 +33,16 @@ use std::{ fmt, future::Future, num::NonZeroUsize, + pin::Pin, sync::{ atomic::{AtomicU64, Ordering}, Arc, }, + task::Poll, time::Duration, }; +use anyhow::anyhow; use futures_lite::{future::BoxedLocal, Stream, StreamExt}; use hashlink::LinkedHashSet; use iroh::{endpoint, Endpoint, NodeAddr, NodeId}; @@ -51,7 +53,7 @@ use tokio::{ task::JoinSet, }; use tokio_util::{either::Either, sync::CancellationToken, time::delay_queue}; -use tracing::{debug, error_span, trace, warn, Instrument}; +use tracing::{debug, error, error_span, trace, warn, Instrument}; use crate::{ get::{db::DownloadProgress, Stats}, @@ -77,7 +79,7 @@ const SERVICE_CHANNEL_CAPACITY: usize = 128; pub struct IntentId(pub u64); /// Trait modeling a dialer. This allows for IO-less testing. -pub trait Dialer: Stream)> + Unpin { +trait DialerT: Stream)> + Unpin { /// Type of connections returned by the Dialer. type Connection: Clone + 'static; /// Dial a node. @@ -354,7 +356,7 @@ impl Downloader { { let me = endpoint.node_id().fmt_short(); let (msg_tx, msg_rx) = mpsc::channel(SERVICE_CHANNEL_CAPACITY); - let dialer = iroh::dialer::Dialer::new(endpoint); + let dialer = Dialer::new(endpoint); let create_future = move || { let getter = get::IoGetter { @@ -532,7 +534,7 @@ enum NodeState<'a, Conn> { } #[derive(Debug)] -struct Service { +struct Service { /// The getter performs individual requests. getter: G, /// Map to query for nodes that we believe have the data we are looking for. @@ -564,7 +566,7 @@ struct Service { /// Progress tracker progress_tracker: ProgressTracker, } -impl, D: Dialer> Service { +impl, D: DialerT> Service { fn new( getter: G, dialer: D, @@ -1492,7 +1494,7 @@ impl Queue { } } -impl Dialer for iroh::dialer::Dialer { +impl DialerT for Dialer { type Connection = endpoint::Connection; fn queue_dial(&mut self, node_id: NodeId) { @@ -1511,3 +1513,81 @@ impl Dialer for iroh::dialer::Dialer { self.endpoint().node_id() } } + +/// Dials nodes and maintains a queue of pending dials. +/// +/// The [`Dialer`] wraps an [`Endpoint`], connects to nodes through the endpoint, stores the +/// pending connect futures and emits finished connect results. +/// +/// The [`Dialer`] also implements [`Stream`] to retrieve the dialled connections. +#[derive(Debug)] +struct Dialer { + endpoint: Endpoint, + pending: JoinSet<(NodeId, anyhow::Result)>, + pending_dials: HashMap, +} + +impl Dialer { + /// Create a new dialer for a [`Endpoint`] + fn new(endpoint: Endpoint) -> Self { + Self { + endpoint, + pending: Default::default(), + pending_dials: Default::default(), + } + } + + /// Starts to dial a node by [`NodeId`]. + fn queue_dial(&mut self, node_id: NodeId, alpn: &'static [u8]) { + if self.is_pending(node_id) { + return; + } + let cancel = CancellationToken::new(); + self.pending_dials.insert(node_id, cancel.clone()); + let endpoint = self.endpoint.clone(); + self.pending.spawn(async move { + let res = tokio::select! { + biased; + _ = cancel.cancelled() => Err(anyhow!("Cancelled")), + res = endpoint.connect(node_id, alpn) => res + }; + (node_id, res) + }); + } + + /// Checks if a node is currently being dialed. + fn is_pending(&self, node: NodeId) -> bool { + self.pending_dials.contains_key(&node) + } + + /// Number of pending connections to be opened. + fn pending_count(&self) -> usize { + self.pending_dials.len() + } + + /// Returns a reference to the endpoint used in this dialer. + fn endpoint(&self) -> &Endpoint { + &self.endpoint + } +} + +impl Stream for Dialer { + type Item = (NodeId, anyhow::Result); + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match self.pending.poll_join_next(cx) { + Poll::Ready(Some(Ok((node_id, result)))) => { + self.pending_dials.remove(&node_id); + Poll::Ready(Some((node_id, result))) + } + Poll::Ready(Some(Err(e))) => { + error!("dialer error: {:?}", e); + Poll::Pending + } + _ => Poll::Pending, + } + } +} diff --git a/src/downloader/invariants.rs b/src/downloader/invariants.rs index 0409e3d92..6d49b3497 100644 --- a/src/downloader/invariants.rs +++ b/src/downloader/invariants.rs @@ -5,7 +5,7 @@ use super::*; /// invariants for the service. -impl, D: Dialer> Service { +impl, D: DialerT> Service { /// Checks the various invariants the service must maintain #[track_caller] pub(in crate::downloader) fn check_invariants(&self) { diff --git a/src/downloader/test/dialer.rs b/src/downloader/test/dialer.rs index fc5a93995..5124e2d3d 100644 --- a/src/downloader/test/dialer.rs +++ b/src/downloader/test/dialer.rs @@ -38,7 +38,7 @@ impl Default for TestingDialerInner { } } -impl Dialer for TestingDialer { +impl DialerT for TestingDialer { type Connection = NodeId; fn queue_dial(&mut self, node_id: NodeId) {