From 25261c1778e6cf920e76e7517ee122ac258cdd88 Mon Sep 17 00:00:00 2001 From: DanGould Date: Mon, 30 Dec 2024 12:58:39 -0500 Subject: [PATCH] Add listen_tcp_on_free_port to return a test port Previously in tests downstream ohttp_relay was initiated with a port that may no longer be free by the time it got bound. By having this code bind on and return the port the indirection is removed. --- Cargo.toml | 1 + src/lib.rs | 61 +++++++++++++++++++++++++++++++++++------------------- 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1c9c660..08a8fcc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ default = ["bootstrap"] bootstrap = ["connect-bootstrap", "ws-bootstrap"] connect-bootstrap = [] ws-bootstrap = ["futures", "hyper-tungstenite", "rustls", "tokio-tungstenite"] +_test-util = [] [dependencies] futures = { version = "0.3", optional = true } diff --git a/src/lib.rs b/src/lib.rs index 59a65c0..700cbe1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,43 +41,62 @@ pub async fn listen_tcp(port: u16, gateway_origin: Uri) -> Result<(), BoxError> let addr = SocketAddr::from(([0, 0, 0, 0], port)); let listener = TcpListener::bind(addr).await?; println!("OHTTP relay listening on tcp://{}", addr); - ohttp_relay(listener, gateway_origin).await + ohttp_relay(listener, gateway_origin).await?; + Ok(()) } #[instrument] pub async fn listen_socket(socket_path: &str, gateway_origin: Uri) -> Result<(), BoxError> { let listener = UnixListener::bind(socket_path)?; info!("OHTTP relay listening on socket: {}", socket_path); - ohttp_relay(listener, gateway_origin).await + ohttp_relay(listener, gateway_origin).await?; + Ok(()) +} + +#[cfg(feature = "_test-util")] +pub async fn listen_tcp_on_free_port( + gateway_origin: Uri, +) -> Result<(u16, tokio::task::JoinHandle>), BoxError> { + let listener = tokio::net::TcpListener::bind("[::]:0").await?; + let port = listener.local_addr()?.port(); + println!("Directory server binding to port {}", listener.local_addr()?); + let handle = ohttp_relay(listener, gateway_origin).await?; + Ok((port, handle)) } #[instrument(skip(listener))] -async fn ohttp_relay(mut listener: L, gateway_origin: Uri) -> Result<(), BoxError> +async fn ohttp_relay( + mut listener: L, + gateway_origin: Uri, +) -> Result>, BoxError> where - L: Listener + Unpin, + L: Listener + Unpin + Send + 'static, L::Io: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let gateway_origin = GatewayUri::new(gateway_origin)?; let gateway_origin: Arc = Arc::new(gateway_origin); - while let Ok((stream, _)) = listener.accept().await { - let gateway_origin = gateway_origin.clone(); - let io = TokioIo::new(stream); - tokio::spawn(async move { - if let Err(err) = http1::Builder::new() - .serve_connection( - io, - service_fn(move |req| serve_ohttp_relay(req, gateway_origin.clone())), - ) - .with_upgrades() - .await - { - error!("Error serving connection: {:?}", err); - } - }); - } + let handle = tokio::spawn(async move { + while let Ok((stream, _)) = listener.accept().await { + let gateway_origin = gateway_origin.clone(); + let io = TokioIo::new(stream); + tokio::spawn(async move { + if let Err(err) = http1::Builder::new() + .serve_connection( + io, + service_fn(move |req| serve_ohttp_relay(req, gateway_origin.clone())), + ) + .with_upgrades() + .await + { + error!("Error serving connection: {:?}", err); + } + }); + } + Ok(()) + }); - Ok(()) + Ok(handle) } #[instrument]