diff --git a/Cargo.toml b/Cargo.toml index fd3a03e..4024878 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ default = ["bootstrap"] bootstrap = ["connect-bootstrap", "ws-bootstrap"] connect-bootstrap = [] ws-bootstrap = ["futures", "hyper-tungstenite", "rustls", "tokio-tungstenite"] +_test-util = [] [dependencies] futures = { version = "0.3", optional = true } diff --git a/src/error.rs b/src/error.rs index 5d13ba9..fb4e513 100644 --- a/src/error.rs +++ b/src/error.rs @@ -5,6 +5,8 @@ use hyper::{Response, StatusCode}; use crate::{empty, full}; +pub(crate) type BoxError = Box; + #[derive(Debug)] #[allow(clippy::enum_variant_names)] pub(crate) enum Error { diff --git a/src/gateway_uri.rs b/src/gateway_uri.rs index 7767daa..dc2f8b3 100644 --- a/src/gateway_uri.rs +++ b/src/gateway_uri.rs @@ -1,11 +1,12 @@ use http::Uri; +use crate::error::BoxError; /// A normalized gateway origin URI with a default port if none is specified. #[derive(Debug, Clone, PartialEq, Eq)] pub struct GatewayUri(Uri); impl GatewayUri { - pub fn new(mut gateway_origin: Uri) -> Result> { + pub fn new(mut gateway_origin: Uri) -> Result { let (scheme, default_port) = match gateway_origin.scheme_str() { Some("http") => ("http", 80), Some("https") | None => ("https", 443), diff --git a/src/lib.rs b/src/lib.rs index 1415486..359a76f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,7 +25,7 @@ use tracing::{debug, error, info, instrument}; pub mod error; mod gateway_uri; -use crate::error::Error; +use crate::error::{BoxError, Error}; #[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))] pub mod bootstrap; @@ -40,7 +40,7 @@ pub static EXPECTED_MEDIA_TYPE: Lazy = pub async fn listen_tcp( port: u16, gateway_origin: Uri, -) -> Result<(), Box> { +) -> Result>, BoxError> { let addr = SocketAddr::from(([0, 0, 0, 0], port)); let listener = TcpListener::bind(addr).await?; println!("OHTTP relay listening on tcp://{}", addr); @@ -51,42 +51,56 @@ pub async fn listen_tcp( pub async fn listen_socket( socket_path: &str, gateway_origin: Uri, -) -> Result<(), Box> { +) -> Result>, BoxError> { let listener = UnixListener::bind(socket_path)?; info!("OHTTP relay listening on socket: {}", socket_path); ohttp_relay(listener, gateway_origin).await } +#[cfg(feature = "_test-util")] +pub async fn listen_tcp_on_free_port( + gateway_origin: Uri, +) -> Result<(u16, tokio::task::JoinHandle>), BoxError> { + let listener = tokio::net::TcpListener::bind("[::]:0").await?; + let port = listener.local_addr()?.port(); + println!("Directory server binding to port {}", listener.local_addr()?); + let handle = ohttp_relay(listener, gateway_origin).await?; + Ok((port, handle)) +} + #[instrument(skip(listener))] async fn ohttp_relay( mut listener: L, gateway_origin: Uri, -) -> Result<(), Box> +) -> Result>, BoxError> where - L: Listener + Unpin, + L: Listener + Unpin + Send + 'static, L::Io: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let gateway_origin = GatewayUri::new(gateway_origin)?; let gateway_origin: Arc = Arc::new(gateway_origin); - while let Ok((stream, _)) = listener.accept().await { - let gateway_origin = gateway_origin.clone(); - let io = TokioIo::new(stream); - tokio::spawn(async move { - if let Err(err) = http1::Builder::new() - .serve_connection( - io, - service_fn(move |req| serve_ohttp_relay(req, gateway_origin.clone())), - ) - .with_upgrades() - .await - { - error!("Error serving connection: {:?}", err); - } - }); - } + let handle = tokio::spawn(async move { + while let Ok((stream, _)) = listener.accept().await { + let gateway_origin = gateway_origin.clone(); + let io = TokioIo::new(stream); + tokio::spawn(async move { + if let Err(err) = http1::Builder::new() + .serve_connection( + io, + service_fn(move |req| serve_ohttp_relay(req, gateway_origin.clone())), + ) + .with_upgrades() + .await + { + error!("Error serving connection: {:?}", err); + } + }); + } + Ok(()) + }); - Ok(()) + Ok(handle) } #[instrument] diff --git a/src/main.rs b/src/main.rs index b4daec7..9fd60e4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,8 +26,7 @@ async fn main() -> Result<(), Box> { } (Err(_), Err(_)) => ohttp_relay::listen_tcp(DEFAULT_PORT, gateway_origin).await?, } - - Ok(()) + .await? } fn init_tracing() { diff --git a/tests/integration.rs b/tests/integration.rs index 1121802..5d44519 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1,4 +1,5 @@ #[cfg(test)] +#[cfg(feature = "_test-util")] mod integration { use std::fs::File; use std::io::Read; @@ -34,7 +35,13 @@ mod integration { async fn test_request_response_tcp() { let gateway_port = find_free_port(); let gateway = Uri::from_str(&format!("http://0.0.0.0:{}", gateway_port)).unwrap(); - let relay_port = find_free_port(); + let (relay_port, relay_handle) = + listen_tcp_on_free_port(gateway).await.expect("Failed to listen on free port"); + let relay_task = tokio::spawn(async move { + if let Err(e) = relay_handle.await { + eprintln!("Relay failed: {}", e); + } + }); let n_http_port = find_free_port(); let n_https_port = find_free_port(); let nginx_cert = gen_localhost_cert(); @@ -46,7 +53,7 @@ mod integration { _ = example_gateway_http(gateway_port) => { assert!(false, "Gateway is long running"); } - _ = listen_tcp(relay_port, gateway) => { + _ = relay_task => { assert!(false, "Relay is long running"); } _ = ohttp_req(n_https_port, nginx_cert_der) => {} @@ -67,6 +74,13 @@ mod integration { let nginx_cert = gen_localhost_cert(); let nginx_cert_der = cert_to_cert_der(&nginx_cert); let socket_path_str = socket_path.to_str().unwrap(); + let relay_handle = + listen_socket(socket_path_str, gateway).await.expect("Failed to listen on socket"); + let relay_task = tokio::spawn(async move { + if let Err(e) = relay_handle.await { + eprintln!("Relay failed: {}", e); + } + }); let n_http_port = find_free_port(); let n_https_port = find_free_port(); let _nginx = @@ -76,7 +90,7 @@ mod integration { _ = example_gateway_http(gateway_port) => { assert!(false, "Gateway is long running"); } - _ = listen_socket(socket_path_str, gateway) => { + _ = relay_task => { assert!(false, "Relay is long running"); } _ = ohttp_req(n_https_port, nginx_cert_der) => {} @@ -286,10 +300,16 @@ mod integration { { let gateway_port = find_free_port(); let gateway = Uri::from_str(&format!("http://0.0.0.0:{}", gateway_port)).unwrap(); - let relay_port = find_free_port(); let nginx_cert = gen_localhost_cert(); let gateway_cert = gen_localhost_cert(); let gateway_cert_der = cert_to_cert_der(&gateway_cert); + let (relay_port, relay_handle) = + listen_tcp_on_free_port(gateway).await.expect("Failed to listen on free port"); + let relay_task = tokio::spawn(async move { + if let Err(e) = relay_handle.await { + eprintln!("Relay failed: {}", e); + } + }); let n_http_port = find_free_port(); let n_https_port = find_free_port(); let _nginx = start_nginx( @@ -303,7 +323,7 @@ mod integration { _ = example_gateway_https(gateway_port, gateway_cert) => { assert!(false, "Gateway is long running"); } - _ = listen_tcp(relay_port, gateway) => { + _ = relay_task => { assert!(false, "Relay is long running"); } _ = client_fn(n_http_port, gateway_port, gateway_cert_der) => {} @@ -434,6 +454,15 @@ mod integration { .spawn() .expect("Failed to start nginx"); + let start = std::time::Instant::now(); + let timeout = std::time::Duration::from_secs(5); + while start.elapsed() < timeout { + if let Ok(_) = std::net::TcpStream::connect(format!("127.0.0.1:{}", n_https_port)) { + break; + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + // Keep the config file open as long as NGINX is using it std::mem::forget(config_file);