@@ -6,16 +6,15 @@ use iroh::{
66 endpoint:: { get_remote_node_id, Connecting } ,
77 Endpoint , NodeAddr , SecretKey ,
88} ;
9+ use quinn:: Connection ;
910use std:: {
10- io,
11- net:: { SocketAddr , SocketAddrV4 , SocketAddrV6 , ToSocketAddrs } ,
12- str:: FromStr ,
11+ collections:: HashMap , io, net:: { SocketAddr , SocketAddrV4 , SocketAddrV6 , ToSocketAddrs } , str:: FromStr , sync:: Arc
1312} ;
1413use tokio:: {
15- io:: { AsyncRead , AsyncWrite , AsyncWriteExt } ,
16- select,
14+ io:: { AsyncRead , AsyncWrite , AsyncWriteExt } , net:: UdpSocket , select
1715} ;
1816use tokio_util:: sync:: CancellationToken ;
17+ mod udpconn;
1918
2019/// Create a dumb pipe between two machines, using an iroh magicsocket.
2120///
@@ -54,6 +53,15 @@ pub enum Commands {
5453 /// connecting to a TCP socket for which you have to specify the host and port.
5554 ListenTcp ( ListenTcpArgs ) ,
5655
56+ /// Listen on a magicsocket and forward incoming connections to the specified
57+ /// host and port. Every incoming bidi stream is forwarded to a new connection.
58+ ///
59+ /// Will print a node ticket on stderr that can be used to connect.
60+ ///
61+ /// As far as the magic socket is concerned, this is listening. But it is
62+ /// connecting to a UDP socket for which you have to specify the host and port.
63+ ListenUdp ( ListenUdpArgs ) ,
64+
5765 /// Connect to a magicsocket, open a bidi stream, and forward stdin/stdout.
5866 ///
5967 /// A node ticket is required to connect.
@@ -67,6 +75,15 @@ pub enum Commands {
6775 /// As far as the magic socket is concerned, this is connecting. But it is
6876 /// listening on a TCP socket for which you have to specify the interface and port.
6977 ConnectTcp ( ConnectTcpArgs ) ,
78+
79+ /// Connect to a magicsocket, open a bidi stream, and forward stdin/stdout
80+ /// to it.
81+ ///
82+ /// A node ticket is required to connect.
83+ ///
84+ /// As far as the magic socket is concerned, this is connecting. But it is
85+ /// listening on a UDP socket for which you have to specify the interface and port.
86+ ConnectUdp ( ConnectUdpArgs ) ,
7087}
7188
7289#[ derive( Parser , Debug ) ]
@@ -140,6 +157,15 @@ pub struct ListenTcpArgs {
140157 pub common : CommonArgs ,
141158}
142159
160+ #[ derive( Parser , Debug ) ]
161+ pub struct ListenUdpArgs {
162+ #[ clap( long) ]
163+ pub host : String ,
164+
165+ #[ clap( flatten) ]
166+ pub common : CommonArgs ,
167+ }
168+
143169#[ derive( Parser , Debug ) ]
144170pub struct ConnectTcpArgs {
145171 /// The addresses to listen on for incoming tcp connections.
@@ -155,6 +181,21 @@ pub struct ConnectTcpArgs {
155181 pub common : CommonArgs ,
156182}
157183
184+ #[ derive( Parser , Debug ) ]
185+ pub struct ConnectUdpArgs {
186+ /// The addresses to listen on for incoming udp connections.
187+ ///
188+ /// To listen on all network interfaces, use 0.0.0.0:12345
189+ #[ clap( long) ]
190+ pub addr : String ,
191+
192+ /// The node to connect to
193+ pub ticket : NodeTicket ,
194+
195+ #[ clap( flatten) ]
196+ pub common : CommonArgs ,
197+ }
198+
158199#[ derive( Parser , Debug ) ]
159200pub struct ConnectArgs {
160201 /// The node to connect to
@@ -440,6 +481,126 @@ async fn connect_tcp(args: ConnectTcpArgs) -> anyhow::Result<()> {
440481 Ok ( ( ) )
441482}
442483
484+ pub struct SplitUdpConn {
485+ // TODO: Do we need to store this connection?
486+ // Holding on to this for the future where we need to cleanup the resources.
487+ connection : quinn:: Connection ,
488+ send : quinn:: SendStream ,
489+ }
490+
491+ impl SplitUdpConn {
492+ pub fn new ( connection : quinn:: Connection , send : quinn:: SendStream ) -> Self {
493+ Self {
494+ connection,
495+ send
496+ }
497+ }
498+ }
499+
500+ // 1- Receives request message from socket
501+ // 2- Forwards it to the quinn stream
502+ // 3- Receives response message back from quinn stream
503+ // 4- Forwards it back to the socket
504+ async fn connect_udp ( args : ConnectUdpArgs ) -> anyhow:: Result < ( ) > {
505+ let addrs = args
506+ . addr
507+ . to_socket_addrs ( )
508+ . context ( format ! ( "invalid host string {}" , args. addr) ) ?;
509+ let secret_key = get_or_create_secret ( ) ?;
510+ let mut builder = Endpoint :: builder ( ) . secret_key ( secret_key) . alpns ( vec ! [ ] ) ;
511+ if let Some ( addr) = args. common . magic_ipv4_addr {
512+ builder = builder. bind_addr_v4 ( addr) ;
513+ }
514+ if let Some ( addr) = args. common . magic_ipv6_addr {
515+ builder = builder. bind_addr_v6 ( addr) ;
516+ }
517+ let endpoint = builder. bind ( ) . await . context ( "unable to bind magicsock" ) ?;
518+ tracing:: info!( "udp listening on {:?}" , addrs) ;
519+ let socket = Arc :: new ( UdpSocket :: bind ( addrs. as_slice ( ) ) . await ?) ;
520+
521+ let node_addr = args. ticket . node_addr ( ) ;
522+ let mut buf: Vec < u8 > = vec ! [ 0u8 ; 65535 ] ;
523+ let mut conns = HashMap :: < SocketAddr , SplitUdpConn > :: new ( ) ;
524+ loop {
525+ match socket. recv_from ( & mut buf) . await {
526+ Ok ( ( size, sock_addr) ) => {
527+ // Check if we already have a connection for this socket address
528+ let connection = match conns. get_mut ( & sock_addr) {
529+ Some ( conn) => conn,
530+ None => {
531+ // We need to finish the connection to be done or we should use something like promise because
532+ // when the connection was getting established, it might receive another message.
533+ let endpoint = endpoint. clone ( ) ;
534+ let addr = node_addr. clone ( ) ;
535+ let handshake = !args. common . is_custom_alpn ( ) ;
536+ let alpn = args. common . alpn ( ) ?;
537+
538+ let remote_node_id = addr. node_id ;
539+ tracing:: info!( "forwarding UDP to {}" , remote_node_id) ;
540+
541+ // connect to the node, try only once
542+ let connection = endpoint
543+ . connect ( addr. clone ( ) , & alpn)
544+ . await
545+ . context ( format ! ( "error connecting to {}" , remote_node_id) ) ?;
546+ tracing:: info!( "connected to {}" , remote_node_id) ;
547+
548+ // open a bidi stream, try only once
549+ let ( mut send, recv) = connection
550+ . open_bi ( )
551+ . await
552+ . context ( format ! ( "error opening bidi stream to {}" , remote_node_id) ) ?;
553+ tracing:: info!( "opened bidi stream to {}" , remote_node_id) ;
554+
555+ // send the handshake unless we are using a custom alpn
556+ if handshake {
557+ send. write_all ( & dumbpipe:: HANDSHAKE ) . await ?;
558+ }
559+
560+ let sock_send = socket. clone ( ) ;
561+ // Spawn a task for listening the quinn connection, and forwarding the data to the UDP socket
562+ tokio:: spawn ( async move {
563+ // 3- Receives response message back from quinn stream
564+ // 4- Forwards it back to the socket
565+ if let Err ( cause) = udpconn:: handle_udp_accept ( sock_addr, sock_send, recv )
566+ . await {
567+ // log error at warn level
568+ //
569+ // we should know about it, but it's not fatal
570+ tracing:: warn!( "error handling connection: {}" , cause) ;
571+
572+ // TODO: cleanup resources
573+ }
574+ } ) ;
575+
576+ // Create and store the split connection
577+ let split_conn = SplitUdpConn :: new ( connection. clone ( ) , send) ;
578+ conns. insert ( sock_addr, split_conn) ;
579+ conns. get_mut ( & sock_addr) . expect ( "connection was just inserted" )
580+ }
581+ } ;
582+
583+ tracing:: info!( "forward_udp_to_quinn: Received {} bytes from {}" , size, sock_addr) ;
584+
585+ // 1- Receives request message from socket
586+ // 2- Forwards it to the quinn stream
587+ if let Err ( e) = connection. send . write_all ( & buf[ ..size] ) . await {
588+ tracing:: error!( "Error writing to Quinn stream: {}" , e) ;
589+ // TODO: Cleanup the resources on error.
590+ // Remove the failed connection
591+ // conns.remove(&sock_addr);
592+ return Err ( e. into ( ) ) ;
593+ }
594+ }
595+ Err ( e) => {
596+ tracing:: warn!( "error receiving from UDP socket: {}" , e) ;
597+ break ;
598+ }
599+ }
600+ }
601+ Ok ( ( ) )
602+ }
603+
443604/// Listen on a magicsocket and forward incoming connections to a tcp socket.
444605async fn listen_tcp ( args : ListenTcpArgs ) -> anyhow:: Result < ( ) > {
445606 let addrs = match args. host . to_socket_addrs ( ) {
@@ -533,15 +694,111 @@ async fn listen_tcp(args: ListenTcpArgs) -> anyhow::Result<()> {
533694 Ok ( ( ) )
534695}
535696
697+ /// Listen on a magicsocket and forward incoming connections to a udp socket.
698+ async fn listen_udp ( args : ListenUdpArgs ) -> anyhow:: Result < ( ) > {
699+ let addrs = match args. host . to_socket_addrs ( ) {
700+ Ok ( addrs) => addrs. collect :: < Vec < _ > > ( ) ,
701+ Err ( e) => anyhow:: bail!( "invalid host string {}: {}" , args. host, e) ,
702+ } ;
703+ let secret_key = get_or_create_secret ( ) ?;
704+ let mut builder = Endpoint :: builder ( )
705+ . alpns ( vec ! [ args. common. alpn( ) ?] )
706+ . secret_key ( secret_key) ;
707+ if let Some ( addr) = args. common . magic_ipv4_addr {
708+ builder = builder. bind_addr_v4 ( addr) ;
709+ }
710+ if let Some ( addr) = args. common . magic_ipv6_addr {
711+ builder = builder. bind_addr_v6 ( addr) ;
712+ }
713+ let endpoint = builder. bind ( ) . await ?;
714+ // wait for the endpoint to figure out its address before making a ticket
715+ endpoint. home_relay ( ) . initialized ( ) . await ?;
716+ let node_addr = endpoint. node_addr ( ) . await ?;
717+ let mut short = node_addr. clone ( ) ;
718+ let ticket = NodeTicket :: new ( node_addr) ;
719+ short. direct_addresses . clear ( ) ;
720+ let short = NodeTicket :: new ( short) ;
721+
722+ // print the ticket on stderr so it doesn't interfere with the data itself
723+ //
724+ // note that the tests rely on the ticket being the last thing printed
725+ eprintln ! ( "Forwarding incoming requests to '{}'." , args. host) ;
726+ eprintln ! ( "To connect, use e.g.:" ) ;
727+ eprintln ! ( "dumbpipe connect-udp {ticket}" ) ;
728+ if args. common . verbose > 0 {
729+ eprintln ! ( "or:\n dumbpipe connect-udp {}" , short) ;
730+ }
731+ tracing:: info!( "node id is {}" , ticket. node_addr( ) . node_id) ;
732+ tracing:: info!( "derp url is {:?}" , ticket. node_addr( ) . relay_url) ;
733+
734+ // handle a new incoming connection on the magic endpoint
735+ async fn handle_magic_accept (
736+ connecting : Connecting ,
737+ addrs : Vec < std:: net:: SocketAddr > ,
738+ handshake : bool ,
739+ ) -> anyhow:: Result < ( ) > {
740+ let connection = connecting. await . context ( "error accepting connection" ) ?;
741+ let remote_node_id = get_remote_node_id ( & connection) ?;
742+ tracing:: info!( "got connection from {}" , remote_node_id) ;
743+ let ( s, mut r) = connection
744+ . accept_bi ( )
745+ . await
746+ . context ( "error accepting stream" ) ?;
747+ tracing:: info!( "accepted bidi stream from {}" , remote_node_id) ;
748+ if handshake {
749+ // read the handshake and verify it
750+ let mut buf = [ 0u8 ; dumbpipe:: HANDSHAKE . len ( ) ] ;
751+ r. read_exact ( & mut buf) . await ?;
752+ anyhow:: ensure!( buf == dumbpipe:: HANDSHAKE , "invalid handshake" ) ;
753+ }
754+
755+ // 1- Receives request message from quinn stream
756+ // 2- Forwards it to the (addrs) via UDP socket
757+ // 3- Receives response message back from UDP socket
758+ // 4- Forwards it back to the quinn stream
759+ udpconn:: handle_udp_listen ( addrs. as_slice ( ) , r, s) . await ?;
760+ Ok ( ( ) )
761+ }
762+
763+ loop {
764+ let incoming = select ! {
765+ incoming = endpoint. accept( ) => incoming,
766+ _ = tokio:: signal:: ctrl_c( ) => {
767+ eprintln!( "got ctrl-c, exiting" ) ;
768+ break ;
769+ }
770+ } ;
771+ let Some ( incoming) = incoming else {
772+ break ;
773+ } ;
774+ let Ok ( connecting) = incoming. accept ( ) else {
775+ break ;
776+ } ;
777+ let addrs = addrs. clone ( ) ;
778+ let handshake = !args. common . is_custom_alpn ( ) ;
779+ tokio:: spawn ( async move {
780+ if let Err ( cause) = handle_magic_accept ( connecting, addrs, handshake) . await {
781+ // log error at warn level
782+ //
783+ // we should know about it, but it's not fatal
784+ tracing:: warn!( "error handling connection: {}" , cause) ;
785+ }
786+ } ) ;
787+ }
788+ Ok ( ( ) )
789+ }
790+
536791#[ tokio:: main]
537792async fn main ( ) -> anyhow:: Result < ( ) > {
538793 tracing_subscriber:: fmt:: init ( ) ;
539794 let args = Args :: parse ( ) ;
540795 let res = match args. command {
541796 Commands :: Listen ( args) => listen_stdio ( args) . await ,
542797 Commands :: ListenTcp ( args) => listen_tcp ( args) . await ,
798+ Commands :: ListenUdp ( args) => listen_udp ( args) . await ,
543799 Commands :: Connect ( args) => connect_stdio ( args) . await ,
544800 Commands :: ConnectTcp ( args) => connect_tcp ( args) . await ,
801+ Commands :: ConnectUdp ( args) => connect_udp ( args) . await ,
545802 } ;
546803 match res {
547804 Ok ( ( ) ) => std:: process:: exit ( 0 ) ,
0 commit comments