@@ -4,26 +4,26 @@ use std::{
44 fmt:: Debug ,
55 future:: { Future , IntoFuture } ,
66 io,
7- ops:: Deref ,
87 sync:: Arc ,
9- time:: { Duration , SystemTime } ,
108} ;
119
1210use anyhow:: bail;
1311use genawaiter:: sync:: Gen ;
14- use iroh:: { endpoint :: Connection , Endpoint , NodeId } ;
12+ use iroh:: { Endpoint , NodeId } ;
1513use irpc:: { channel:: mpsc, rpc_requests} ;
1614use n0_future:: { future, stream, BufferedStreamExt , Stream , StreamExt } ;
1715use rand:: seq:: SliceRandom ;
1816use serde:: { de:: Error , Deserialize , Serialize } ;
19- use tokio:: { sync:: Mutex , task:: JoinSet } ;
20- use tokio_util:: time:: FutureExt ;
21- use tracing:: { info, instrument:: Instrument , warn} ;
17+ use tokio:: task:: JoinSet ;
18+ use tracing:: instrument:: Instrument ;
2219
23- use super :: { remote :: GetConnection , Store } ;
20+ use super :: Store ;
2421use crate :: {
2522 protocol:: { GetManyRequest , GetRequest } ,
26- util:: sink:: { Drain , IrpcSenderRefSink , Sink , TokioMpscSenderSink } ,
23+ util:: {
24+ connection_pool:: ConnectionPool ,
25+ sink:: { Drain , IrpcSenderRefSink , Sink , TokioMpscSenderSink } ,
26+ } ,
2727 BlobFormat , Hash , HashAndFormat ,
2828} ;
2929
@@ -69,7 +69,7 @@ impl DownloaderActor {
6969 fn new ( store : Store , endpoint : Endpoint ) -> Self {
7070 Self {
7171 store,
72- pool : ConnectionPool :: new ( endpoint, crate :: ALPN . to_vec ( ) ) ,
72+ pool : ConnectionPool :: new ( endpoint, crate :: ALPN , Default :: default ( ) ) ,
7373 tasks : JoinSet :: new ( ) ,
7474 running : HashSet :: new ( ) ,
7575 }
@@ -414,90 +414,6 @@ async fn split_request<'a>(
414414 } )
415415}
416416
417- #[ derive( Debug ) ]
418- struct ConnectionPoolInner {
419- alpn : Vec < u8 > ,
420- endpoint : Endpoint ,
421- connections : Mutex < HashMap < NodeId , Arc < Mutex < SlotState > > > > ,
422- retry_delay : Duration ,
423- connect_timeout : Duration ,
424- }
425-
426- #[ derive( Debug , Clone ) ]
427- struct ConnectionPool ( Arc < ConnectionPoolInner > ) ;
428-
429- #[ derive( Debug , Default ) ]
430- enum SlotState {
431- #[ default]
432- Initial ,
433- Connected ( Connection ) ,
434- AttemptFailed ( SystemTime ) ,
435- #[ allow( dead_code) ]
436- Evil ( String ) ,
437- }
438-
439- impl ConnectionPool {
440- fn new ( endpoint : Endpoint , alpn : Vec < u8 > ) -> Self {
441- Self (
442- ConnectionPoolInner {
443- endpoint,
444- alpn,
445- connections : Default :: default ( ) ,
446- retry_delay : Duration :: from_secs ( 5 ) ,
447- connect_timeout : Duration :: from_secs ( 2 ) ,
448- }
449- . into ( ) ,
450- )
451- }
452-
453- pub fn alpn ( & self ) -> & [ u8 ] {
454- & self . 0 . alpn
455- }
456-
457- pub fn endpoint ( & self ) -> & Endpoint {
458- & self . 0 . endpoint
459- }
460-
461- pub fn retry_delay ( & self ) -> Duration {
462- self . 0 . retry_delay
463- }
464-
465- fn dial ( & self , id : NodeId ) -> DialNode {
466- DialNode {
467- pool : self . clone ( ) ,
468- id,
469- }
470- }
471-
472- #[ allow( dead_code) ]
473- async fn mark_evil ( & self , id : NodeId , reason : String ) {
474- let slot = self
475- . 0
476- . connections
477- . lock ( )
478- . await
479- . entry ( id)
480- . or_default ( )
481- . clone ( ) ;
482- let mut t = slot. lock ( ) . await ;
483- * t = SlotState :: Evil ( reason)
484- }
485-
486- #[ allow( dead_code) ]
487- async fn mark_closed ( & self , id : NodeId ) {
488- let slot = self
489- . 0
490- . connections
491- . lock ( )
492- . await
493- . entry ( id)
494- . or_default ( )
495- . clone ( ) ;
496- let mut t = slot. lock ( ) . await ;
497- * t = SlotState :: Initial
498- }
499- }
500-
501417/// Execute a get request sequentially for multiple providers.
502418///
503419/// It will try each provider in order
@@ -526,13 +442,13 @@ async fn execute_get(
526442 request : request. clone ( ) ,
527443 } )
528444 . await ?;
529- let mut conn = pool. dial ( provider) ;
445+ let conn = pool. get_or_connect ( provider) ;
530446 let local = remote. local_for_request ( request. clone ( ) ) . await ?;
531447 if local. is_complete ( ) {
532448 return Ok ( ( ) ) ;
533449 }
534450 let local_bytes = local. local_bytes ( ) ;
535- let Ok ( conn) = conn. connection ( ) . await else {
451+ let Ok ( conn) = conn. await else {
536452 progress
537453 . send ( DownloadProgessItem :: ProviderFailed {
538454 id : provider,
@@ -543,7 +459,7 @@ async fn execute_get(
543459 } ;
544460 match remote
545461 . execute_get_sink (
546- conn,
462+ & conn,
547463 local. missing ( ) ,
548464 ( & mut progress) . with_map ( move |x| DownloadProgessItem :: Progress ( x + local_bytes) ) ,
549465 )
@@ -571,77 +487,6 @@ async fn execute_get(
571487 bail ! ( "Unable to download {}" , request. hash) ;
572488}
573489
574- #[ derive( Debug , Clone ) ]
575- struct DialNode {
576- pool : ConnectionPool ,
577- id : NodeId ,
578- }
579-
580- impl DialNode {
581- async fn connection_impl ( & self ) -> anyhow:: Result < Connection > {
582- info ! ( "Getting connection for node {}" , self . id) ;
583- let slot = self
584- . pool
585- . 0
586- . connections
587- . lock ( )
588- . await
589- . entry ( self . id )
590- . or_default ( )
591- . clone ( ) ;
592- info ! ( "Dialing node {}" , self . id) ;
593- let mut guard = slot. lock ( ) . await ;
594- match guard. deref ( ) {
595- SlotState :: Connected ( conn) => {
596- return Ok ( conn. clone ( ) ) ;
597- }
598- SlotState :: AttemptFailed ( time) => {
599- let elapsed = time. elapsed ( ) . unwrap_or_default ( ) ;
600- if elapsed <= self . pool . retry_delay ( ) {
601- bail ! (
602- "Connection attempt failed {} seconds ago" ,
603- elapsed. as_secs_f64( )
604- ) ;
605- }
606- }
607- SlotState :: Evil ( reason) => {
608- bail ! ( "Node is banned due to evil behavior: {reason}" ) ;
609- }
610- SlotState :: Initial => { }
611- }
612- let res = self
613- . pool
614- . endpoint ( )
615- . connect ( self . id , self . pool . alpn ( ) )
616- . timeout ( self . pool . 0 . connect_timeout )
617- . await ;
618- match res {
619- Ok ( Ok ( conn) ) => {
620- info ! ( "Connected to node {}" , self . id) ;
621- * guard = SlotState :: Connected ( conn. clone ( ) ) ;
622- Ok ( conn)
623- }
624- Ok ( Err ( e) ) => {
625- warn ! ( "Failed to connect to node {}: {}" , self . id, e) ;
626- * guard = SlotState :: AttemptFailed ( SystemTime :: now ( ) ) ;
627- Err ( e. into ( ) )
628- }
629- Err ( e) => {
630- warn ! ( "Failed to connect to node {}: {}" , self . id, e) ;
631- * guard = SlotState :: AttemptFailed ( SystemTime :: now ( ) ) ;
632- bail ! ( "Failed to connect to node: {}" , e) ;
633- }
634- }
635- }
636- }
637-
638- impl GetConnection for DialNode {
639- fn connection ( & mut self ) -> impl Future < Output = Result < Connection , anyhow:: Error > > + ' _ {
640- let this = self . clone ( ) ;
641- async move { this. connection_impl ( ) . await }
642- }
643- }
644-
645490/// Trait for pluggable content discovery strategies.
646491pub trait ContentDiscovery : Debug + Send + Sync + ' static {
647492 fn find_providers ( & self , hash : HashAndFormat ) -> n0_future:: stream:: Boxed < NodeId > ;
0 commit comments