From 59e8a00f811302223e63e313e8c8124a6bcaefde Mon Sep 17 00:00:00 2001 From: DanGould Date: Mon, 30 Dec 2024 12:58:39 -0500 Subject: [PATCH 1/3] Use BoxError abstraction Reduce a bunch of duplicated boilerplate code --- src/error.rs | 2 ++ src/gateway_uri.rs | 3 ++- src/lib.rs | 17 ++++------------- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/src/error.rs b/src/error.rs index 5d13ba9..fb4e513 100644 --- a/src/error.rs +++ b/src/error.rs @@ -5,6 +5,8 @@ use hyper::{Response, StatusCode}; use crate::{empty, full}; +pub(crate) type BoxError = Box; + #[derive(Debug)] #[allow(clippy::enum_variant_names)] pub(crate) enum Error { diff --git a/src/gateway_uri.rs b/src/gateway_uri.rs index 7767daa..dc2f8b3 100644 --- a/src/gateway_uri.rs +++ b/src/gateway_uri.rs @@ -1,11 +1,12 @@ use http::Uri; +use crate::error::BoxError; /// A normalized gateway origin URI with a default port if none is specified. #[derive(Debug, Clone, PartialEq, Eq)] pub struct GatewayUri(Uri); impl GatewayUri { - pub fn new(mut gateway_origin: Uri) -> Result> { + pub fn new(mut gateway_origin: Uri) -> Result { let (scheme, default_port) = match gateway_origin.scheme_str() { Some("http") => ("http", 80), Some("https") | None => ("https", 443), diff --git a/src/lib.rs b/src/lib.rs index 1415486..59a65c0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,7 +25,7 @@ use tracing::{debug, error, info, instrument}; pub mod error; mod gateway_uri; -use crate::error::Error; +use crate::error::{BoxError, Error}; #[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))] pub mod bootstrap; @@ -37,10 +37,7 @@ pub static EXPECTED_MEDIA_TYPE: Lazy = Lazy::new(|| HeaderValue::from_str("message/ohttp-req").expect("Invalid HeaderValue")); #[instrument] -pub async fn listen_tcp( - port: u16, - gateway_origin: Uri, -) -> Result<(), Box> { +pub async fn listen_tcp(port: u16, gateway_origin: Uri) -> Result<(), BoxError> { let addr = SocketAddr::from(([0, 0, 0, 0], port)); let listener = TcpListener::bind(addr).await?; println!("OHTTP relay listening on tcp://{}", addr); @@ -48,20 +45,14 @@ pub async fn listen_tcp( } #[instrument] -pub async fn listen_socket( - socket_path: &str, - gateway_origin: Uri, -) -> Result<(), Box> { +pub async fn listen_socket(socket_path: &str, gateway_origin: Uri) -> Result<(), BoxError> { let listener = UnixListener::bind(socket_path)?; info!("OHTTP relay listening on socket: {}", socket_path); ohttp_relay(listener, gateway_origin).await } #[instrument(skip(listener))] -async fn ohttp_relay( - mut listener: L, - gateway_origin: Uri, -) -> Result<(), Box> +async fn ohttp_relay(mut listener: L, gateway_origin: Uri) -> Result<(), BoxError> where L: Listener + Unpin, L::Io: AsyncRead + AsyncWrite + Unpin + Send + 'static, From 9cbec6831370efb489260848cf49ae9e3085f421 Mon Sep 17 00:00:00 2001 From: DanGould Date: Mon, 30 Dec 2024 12:58:39 -0500 Subject: [PATCH 2/3] Add listen_tcp_on_free_port to return a test port Previously in tests downstream ohttp_relay was initiated with a port that may no longer be free by the time it got bound. By having this code bind on and return the port the indirection is removed. --- Cargo.toml | 1 + src/lib.rs | 65 ++++++++++++++++++++++++++++++-------------- src/main.rs | 3 +- tests/integration.rs | 30 ++++++++++++++++---- 4 files changed, 71 insertions(+), 28 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fd3a03e..4024878 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ default = ["bootstrap"] bootstrap = ["connect-bootstrap", "ws-bootstrap"] connect-bootstrap = [] ws-bootstrap = ["futures", "hyper-tungstenite", "rustls", "tokio-tungstenite"] +_test-util = [] [dependencies] futures = { version = "0.3", optional = true } diff --git a/src/lib.rs b/src/lib.rs index 59a65c0..359a76f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,7 +37,10 @@ pub static EXPECTED_MEDIA_TYPE: Lazy = Lazy::new(|| HeaderValue::from_str("message/ohttp-req").expect("Invalid HeaderValue")); #[instrument] -pub async fn listen_tcp(port: u16, gateway_origin: Uri) -> Result<(), BoxError> { +pub async fn listen_tcp( + port: u16, + gateway_origin: Uri, +) -> Result>, BoxError> { let addr = SocketAddr::from(([0, 0, 0, 0], port)); let listener = TcpListener::bind(addr).await?; println!("OHTTP relay listening on tcp://{}", addr); @@ -45,39 +48,59 @@ pub async fn listen_tcp(port: u16, gateway_origin: Uri) -> Result<(), BoxError> } #[instrument] -pub async fn listen_socket(socket_path: &str, gateway_origin: Uri) -> Result<(), BoxError> { +pub async fn listen_socket( + socket_path: &str, + gateway_origin: Uri, +) -> Result>, BoxError> { let listener = UnixListener::bind(socket_path)?; info!("OHTTP relay listening on socket: {}", socket_path); ohttp_relay(listener, gateway_origin).await } +#[cfg(feature = "_test-util")] +pub async fn listen_tcp_on_free_port( + gateway_origin: Uri, +) -> Result<(u16, tokio::task::JoinHandle>), BoxError> { + let listener = tokio::net::TcpListener::bind("[::]:0").await?; + let port = listener.local_addr()?.port(); + println!("Directory server binding to port {}", listener.local_addr()?); + let handle = ohttp_relay(listener, gateway_origin).await?; + Ok((port, handle)) +} + #[instrument(skip(listener))] -async fn ohttp_relay(mut listener: L, gateway_origin: Uri) -> Result<(), BoxError> +async fn ohttp_relay( + mut listener: L, + gateway_origin: Uri, +) -> Result>, BoxError> where - L: Listener + Unpin, + L: Listener + Unpin + Send + 'static, L::Io: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let gateway_origin = GatewayUri::new(gateway_origin)?; let gateway_origin: Arc = Arc::new(gateway_origin); - while let Ok((stream, _)) = listener.accept().await { - let gateway_origin = gateway_origin.clone(); - let io = TokioIo::new(stream); - tokio::spawn(async move { - if let Err(err) = http1::Builder::new() - .serve_connection( - io, - service_fn(move |req| serve_ohttp_relay(req, gateway_origin.clone())), - ) - .with_upgrades() - .await - { - error!("Error serving connection: {:?}", err); - } - }); - } + let handle = tokio::spawn(async move { + while let Ok((stream, _)) = listener.accept().await { + let gateway_origin = gateway_origin.clone(); + let io = TokioIo::new(stream); + tokio::spawn(async move { + if let Err(err) = http1::Builder::new() + .serve_connection( + io, + service_fn(move |req| serve_ohttp_relay(req, gateway_origin.clone())), + ) + .with_upgrades() + .await + { + error!("Error serving connection: {:?}", err); + } + }); + } + Ok(()) + }); - Ok(()) + Ok(handle) } #[instrument] diff --git a/src/main.rs b/src/main.rs index b4daec7..9fd60e4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,8 +26,7 @@ async fn main() -> Result<(), Box> { } (Err(_), Err(_)) => ohttp_relay::listen_tcp(DEFAULT_PORT, gateway_origin).await?, } - - Ok(()) + .await? } fn init_tracing() { diff --git a/tests/integration.rs b/tests/integration.rs index 1121802..351a1ed 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1,4 +1,5 @@ #[cfg(test)] +#[cfg(feature = "_test-util")] mod integration { use std::fs::File; use std::io::Read; @@ -34,7 +35,13 @@ mod integration { async fn test_request_response_tcp() { let gateway_port = find_free_port(); let gateway = Uri::from_str(&format!("http://0.0.0.0:{}", gateway_port)).unwrap(); - let relay_port = find_free_port(); + let (relay_port, relay_handle) = + listen_tcp_on_free_port(gateway).await.expect("Failed to listen on free port"); + let relay_task = tokio::spawn(async move { + if let Err(e) = relay_handle.await { + eprintln!("Relay failed: {}", e); + } + }); let n_http_port = find_free_port(); let n_https_port = find_free_port(); let nginx_cert = gen_localhost_cert(); @@ -46,7 +53,7 @@ mod integration { _ = example_gateway_http(gateway_port) => { assert!(false, "Gateway is long running"); } - _ = listen_tcp(relay_port, gateway) => { + _ = relay_task => { assert!(false, "Relay is long running"); } _ = ohttp_req(n_https_port, nginx_cert_der) => {} @@ -67,6 +74,13 @@ mod integration { let nginx_cert = gen_localhost_cert(); let nginx_cert_der = cert_to_cert_der(&nginx_cert); let socket_path_str = socket_path.to_str().unwrap(); + let relay_handle = + listen_socket(socket_path_str, gateway).await.expect("Failed to listen on socket"); + let relay_task = tokio::spawn(async move { + if let Err(e) = relay_handle.await { + eprintln!("Relay failed: {}", e); + } + }); let n_http_port = find_free_port(); let n_https_port = find_free_port(); let _nginx = @@ -76,7 +90,7 @@ mod integration { _ = example_gateway_http(gateway_port) => { assert!(false, "Gateway is long running"); } - _ = listen_socket(socket_path_str, gateway) => { + _ = relay_task => { assert!(false, "Relay is long running"); } _ = ohttp_req(n_https_port, nginx_cert_der) => {} @@ -286,10 +300,16 @@ mod integration { { let gateway_port = find_free_port(); let gateway = Uri::from_str(&format!("http://0.0.0.0:{}", gateway_port)).unwrap(); - let relay_port = find_free_port(); let nginx_cert = gen_localhost_cert(); let gateway_cert = gen_localhost_cert(); let gateway_cert_der = cert_to_cert_der(&gateway_cert); + let (relay_port, relay_handle) = + listen_tcp_on_free_port(gateway).await.expect("Failed to listen on free port"); + let relay_task = tokio::spawn(async move { + if let Err(e) = relay_handle.await { + eprintln!("Relay failed: {}", e); + } + }); let n_http_port = find_free_port(); let n_https_port = find_free_port(); let _nginx = start_nginx( @@ -303,7 +323,7 @@ mod integration { _ = example_gateway_https(gateway_port, gateway_cert) => { assert!(false, "Gateway is long running"); } - _ = listen_tcp(relay_port, gateway) => { + _ = relay_task => { assert!(false, "Relay is long running"); } _ = client_fn(n_http_port, gateway_port, gateway_cert_der) => {} From d8f75786df8a4a554a02db3dedf755125fd125e4 Mon Sep 17 00:00:00 2001 From: DanGould Date: Mon, 30 Dec 2024 16:41:42 -0500 Subject: [PATCH 3/3] Fix nginx race NGINX was also returning before binding to the port, sometimes leading to port contention with other test services Fix by waiting for it to connect before attempting to start other services --- tests/integration.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/integration.rs b/tests/integration.rs index 351a1ed..5d44519 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -454,6 +454,15 @@ mod integration { .spawn() .expect("Failed to start nginx"); + let start = std::time::Instant::now(); + let timeout = std::time::Duration::from_secs(5); + while start.elapsed() < timeout { + if let Ok(_) = std::net::TcpStream::connect(format!("127.0.0.1:{}", n_https_port)) { + break; + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + // Keep the config file open as long as NGINX is using it std::mem::forget(config_file);