diff --git a/Cargo.lock b/Cargo.lock index 264c816f6..d5402576f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3040,6 +3040,7 @@ dependencies = [ "scuffle-context", "scuffle-future-ext", "scuffle-workspace-hack", + "socket2", "thiserror 2.0.12", "tokio", "tokio-rustls", @@ -3281,6 +3282,7 @@ dependencies = [ "serde", "serde_json", "smallvec", + "socket2", "stable_deref_trait", "syn", "sync_wrapper", diff --git a/Justfile b/Justfile index 20d68af98..314d041a4 100644 --- a/Justfile +++ b/Justfile @@ -16,7 +16,7 @@ lint *args: alias coverage := test test *args: - #!/bin/bash + #!/usr/bin/env bash set -euo pipefail INSTA_FORCE_PASS=1 cargo +{{RUST_TOOLCHAIN}} llvm-cov clean --workspace @@ -34,7 +34,7 @@ coverage-serve: miniserve target/llvm-cov/html --index index.html --port 3000 grind *args: - #!/bin/bash + #!/usr/bin/env bash set -euo pipefail # Runs valgrind on the tests. @@ -46,7 +46,7 @@ grind *args: alias docs := doc doc *args: - #!/bin/bash + #!/usr/bin/env bash set -euo pipefail # `--cfg docsrs` enables us to write feature hints in the form of `#[cfg_attr(docsrs, doc(cfg(feature = "some-feature")))]` diff --git a/changes.d/pr-407.toml b/changes.d/pr-407.toml new file mode 100644 index 000000000..1502a3ff6 --- /dev/null +++ b/changes.d/pr-407.toml @@ -0,0 +1,4 @@ +[[scuffle-http]] +category = "feat" +description = "add ability to configure sockets using callbacks" +authors = ["@DrSloth", "@lennartkloock"] diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml index 70da8f6a5..a82b2903a 100644 --- a/crates/http/Cargo.toml +++ b/crates/http/Cargo.toml @@ -20,6 +20,7 @@ futures = { version = "0.3.31", default-features = false, features = ["alloc"]} bon = "3.3.2" pin-project-lite = "0.2.16" scuffle-context.workspace = true +socket2 = "0.5.8" # HTTP parsing http = "1.2.0" diff --git a/crates/http/src/backend/h3/mod.rs b/crates/http/src/backend/h3/mod.rs index 34c599510..55c9e2e34 100644 --- a/crates/http/src/backend/h3/mod.rs +++ b/crates/http/src/backend/h3/mod.rs @@ -11,6 +11,7 @@ use tracing::Instrument; use utils::copy_response_body; use crate::error::Error; +use crate::server::CreateSocketCallback; use crate::service::{HttpService, HttpServiceFactory}; pub mod body; @@ -36,6 +37,8 @@ pub struct Http3Backend { /// Use `[::]` for a dual-stack listener. /// For example, use `[::]:80` to bind to port 80 on both IPv4 and IPv6. bind: SocketAddr, + /// Callback to configure socket + create_custom_sock: Option, /// rustls config. /// /// Use this field to set the server into TLS mode. @@ -67,7 +70,27 @@ where let server_config = h3_quinn::quinn::ServerConfig::with_crypto(Arc::new(crypto)); // Bind the UDP socket - let socket = std::net::UdpSocket::bind(self.bind)?; + let socket = { + let sock = if let Some(cfg_fn) = self.create_custom_sock.as_ref() { + cfg_fn.call(self.bind)? + } else { + let sock = socket2::Socket::new( + match self.bind { + SocketAddr::V4(_) => socket2::Domain::IPV4, + SocketAddr::V6(_) => socket2::Domain::IPV6, + }, + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + )?; + + sock.set_nonblocking(true)?; + sock.bind(&socket2::SockAddr::from(self.bind))?; + + sock + }; + + std::net::UdpSocket::from(sock) + }; // Runtime for the quinn endpoint let runtime = h3_quinn::quinn::default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?; diff --git a/crates/http/src/backend/hyper/mod.rs b/crates/http/src/backend/hyper/mod.rs index bf6074925..1a3128dd1 100644 --- a/crates/http/src/backend/hyper/mod.rs +++ b/crates/http/src/backend/hyper/mod.rs @@ -7,6 +7,7 @@ use scuffle_context::ContextFutExt; use tracing::Instrument; use crate::error::Error; +use crate::server::CreateSocketCallback; use crate::service::{HttpService, HttpServiceFactory}; mod handler; @@ -33,6 +34,8 @@ pub struct HyperBackend { /// Use `[::]` for a dual-stack listener. /// For example, use `[::]:80` to bind to port 80 on both IPv4 and IPv6. bind: SocketAddr, + /// Callback to create a custom socket + create_custom_sock: Option, /// rustls config. /// /// Use this field to set the server into TLS mode. @@ -79,7 +82,28 @@ where } // We have to create an std listener first because the tokio listener isn't clonable - let listener = tokio::net::TcpListener::bind(self.bind).await?.into_std()?; + let listener = { + let sock = if let Some(cfg_fn) = self.create_custom_sock.as_ref() { + cfg_fn.call(self.bind)? + } else { + let mut sock = socket2::Socket::new( + match self.bind { + SocketAddr::V4(_) => socket2::Domain::IPV4, + SocketAddr::V6(_) => socket2::Domain::IPV6, + }, + socket2::Type::STREAM, + Some(socket2::Protocol::TCP), + )?; + + sock.set_nonblocking(true)?; + sock.bind(&socket2::SockAddr::from(self.bind))?; + sock.listen(128)?; + + sock + }; + + std::net::TcpListener::from(sock) + }; #[cfg(feature = "tls-rustls")] let tls_acceptor = self diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index 00162a7fd..968698472 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -82,7 +82,7 @@ pub mod service; pub use http; pub use http::Response; -pub use server::{HttpServer, HttpServerBuilder}; +pub use server::{CreateSocketCallback, HttpServer, HttpServerBuilder}; /// An incoming request. pub type IncomingRequest = http::Request; @@ -197,6 +197,15 @@ mod tests { // Wait for the server to start tokio::time::sleep(std::time::Duration::from_millis(100)).await; + test_tls_server_inner(addr, versions).await; + + handler.shutdown().await; + handle.await.expect("task failed"); + } + + #[cfg(feature = "tls-rustls")] + #[allow(dead_code)] + async fn test_tls_server_inner(addr: std::net::SocketAddr, versions: &[reqwest::Version]) { let url = format!("https://{}/", addr); for version in versions { @@ -222,16 +231,13 @@ mod tests { let resp = client .execute(request) .await - .unwrap_or_else(|_| panic!("failed to get response version {:?}", version)) + .unwrap_or_else(|e| panic!("failed to get response version {:?}: {}", version, e)) .text() .await .expect("failed to get text"); assert_eq!(resp, RESPONSE_TEXT); } - - handler.shutdown().await; - handle.await.expect("task failed"); } #[tokio::test] @@ -605,4 +611,61 @@ mod tests { test_tls_server(builder, &[reqwest::Version::HTTP_2, reqwest::Version::HTTP_3]).await; } + + #[tokio::test] + #[should_panic(expected="Address already in use")] + #[cfg(all(feature = "http2", feature = "http3", feature = "tls-rustls"))] + async fn multi_bind_no_reuseport_fails() { + struct TestBody; + + impl http_body::Body for TestBody { + type Data = bytes::Bytes; + type Error = Infallible; + + fn poll_frame( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>>> { + std::task::Poll::Ready(Some(Ok(http_body::Frame::data(bytes::Bytes::from(RESPONSE_TEXT))))) + } + } + + let addr = get_available_addr().expect("failed to get available address"); + + let addr0 = addr.clone(); + + let t0 = tokio::spawn(async move { + let builder = HttpServer::builder() + .service_factory(service_clone_factory(fn_http_service(|_req| async { + Ok::<_, Infallible>(http::Response::new(TestBody)) + }))) + .rustls_config(rustls_config()) + .enable_http3(true) + .enable_http2(true) + .bind(addr0); + builder.build().run().await.expect("") + }); + let addr1 = addr.clone(); + let t1 = tokio::spawn(async move { + // Wait for a short time to definitely create this server AFTER the other + let builder = HttpServer::builder() + .service_factory(service_clone_factory(fn_http_service(|_req| async { + Ok::<_, Infallible>(http::Response::new(TestBody)) + }))) + .rustls_config(rustls_config()) + .enable_http3(true) + .enable_http2(true) + .bind(addr1); + builder.build().run().await.expect("") + }); + + tokio::select! { + res0 = t0 => { + res0.unwrap() + } + res1 = t1 => { + res1.unwrap() + } + } + } } diff --git a/crates/http/src/server.rs b/crates/http/src/server.rs index 02c210446..a4e54b9ca 100644 --- a/crates/http/src/server.rs +++ b/crates/http/src/server.rs @@ -1,5 +1,6 @@ use std::fmt::Debug; use std::net::SocketAddr; +use std::sync::Arc; use crate::error::Error; use crate::service::{HttpService, HttpServiceFactory}; @@ -40,6 +41,16 @@ pub struct HttpServer { #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(feature = "http3")))] enable_http3: bool, + /// Callback to create a custom socket used for http1 and http2. + /// The socket should be a tcp socket which is already bound and listening. + #[cfg(any(feature = "http1", feature = "http2"))] + #[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] + create_custom_h12_sock: Option, + /// Callback to configure socket used for http3 + /// The socket should be a udp socket which is already bound (don't call listen for udp). + #[cfg(feature = "http3")] + #[cfg_attr(docsrs, doc(cfg(feature = "http3")))] + create_custom_h3_sock: Option, /// rustls config. /// /// Use this field to set the server into TLS mode. @@ -213,6 +224,7 @@ where .service_factory(self.service_factory) .bind(self.bind) .rustls_config(_rustls_config) + .maybe_create_custom_sock(self.create_custom_h3_sock.clone()) .build(); return backend.run().await; @@ -224,6 +236,7 @@ where .worker_tasks(self.worker_tasks) .service_factory(self.service_factory) .bind(self.bind) + .maybe_create_custom_sock(self.create_custom_h12_sock.clone()) .rustls_config(_rustls_config); #[cfg(feature = "http1")] @@ -241,6 +254,7 @@ where .worker_tasks(self.worker_tasks) .service_factory(self.service_factory.clone()) .bind(self.bind) + .maybe_create_custom_sock(self.create_custom_h12_sock.clone()) .rustls_config(_rustls_config.clone()); #[cfg(feature = "http1")] @@ -256,6 +270,7 @@ where .worker_tasks(self.worker_tasks) .service_factory(self.service_factory) .bind(self.bind) + .maybe_create_custom_sock(self.create_custom_h3_sock.clone()) .rustls_config(_rustls_config) .build() .run(); @@ -283,6 +298,7 @@ where .ctx(self.ctx) .worker_tasks(self.worker_tasks) .service_factory(self.service_factory) + .maybe_create_custom_sock(self.create_custom_h12_sock.clone()) .bind(self.bind); #[cfg(feature = "http1")] @@ -297,3 +313,27 @@ where Ok(()) } } + +/// A callback used to configure a socket2 instance. +/// +/// This can be used to tweak options on the TCP/UDP layer +#[derive(Clone)] +pub struct CreateSocketCallback(Arc std::io::Result + Send + Sync>); + +impl CreateSocketCallback { + /// Create a new `ConfigureSocketCallback` from the given callback function. + pub fn new std::io::Result + Send + Sync + 'static>(f: F) -> Self { + Self(Arc::new(f)) + } + + /// Create a new `ConfigureSocketCallback` from the given callback function. + pub fn call(&self, sock: SocketAddr) -> std::io::Result { + (self.0)(sock) + } +} + +impl std::fmt::Debug for CreateSocketCallback { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "CreateSocketCallback ") + } +} diff --git a/crates/workspace-hack/Cargo.toml b/crates/workspace-hack/Cargo.toml index 22e8d1d51..7d706cc78 100644 --- a/crates/workspace-hack/Cargo.toml +++ b/crates/workspace-hack/Cargo.toml @@ -139,6 +139,7 @@ quinn-proto = { version = "0.11", default-features = false, features = ["log", " quinn-udp = { version = "0.5", default-features = false, features = ["log"] } rustls = { version = "0.23", default-features = false, features = ["ring"] } rustls-webpki = { version = "0.102", default-features = false, features = ["aws_lc_rs", "ring", "std"] } +socket2 = { version = "0.5", default-features = false, features = ["all"] } stable_deref_trait = { version = "1" } tokio-rustls = { version = "0.26", default-features = false, features = ["aws_lc_rs", "logging", "ring", "tls12"] } tower = { version = "0.5", default-features = false, features = ["timeout"] } @@ -159,6 +160,7 @@ quinn-proto = { version = "0.11", default-features = false, features = ["log", " quinn-udp = { version = "0.5", default-features = false, features = ["log"] } rustls = { version = "0.23", default-features = false, features = ["ring"] } rustls-webpki = { version = "0.102", default-features = false, features = ["aws_lc_rs", "ring", "std"] } +socket2 = { version = "0.5", default-features = false, features = ["all"] } stable_deref_trait = { version = "1" } tokio-rustls = { version = "0.26", default-features = false, features = ["aws_lc_rs", "logging", "ring", "tls12"] } tower = { version = "0.5", default-features = false, features = ["timeout"] } @@ -178,6 +180,7 @@ quinn-proto = { version = "0.11", default-features = false, features = ["log", " quinn-udp = { version = "0.5", default-features = false, features = ["log"] } rustls = { version = "0.23", default-features = false, features = ["ring"] } rustls-webpki = { version = "0.102", default-features = false, features = ["aws_lc_rs", "ring", "std"] } +socket2 = { version = "0.5", default-features = false, features = ["all"] } stable_deref_trait = { version = "1" } tokio-rustls = { version = "0.26", default-features = false, features = ["aws_lc_rs", "logging", "ring", "tls12"] } tower = { version = "0.5", default-features = false, features = ["timeout"] } @@ -198,6 +201,7 @@ quinn-proto = { version = "0.11", default-features = false, features = ["log", " quinn-udp = { version = "0.5", default-features = false, features = ["log"] } rustls = { version = "0.23", default-features = false, features = ["ring"] } rustls-webpki = { version = "0.102", default-features = false, features = ["aws_lc_rs", "ring", "std"] } +socket2 = { version = "0.5", default-features = false, features = ["all"] } stable_deref_trait = { version = "1" } tokio-rustls = { version = "0.26", default-features = false, features = ["aws_lc_rs", "logging", "ring", "tls12"] } tower = { version = "0.5", default-features = false, features = ["timeout"] } @@ -218,6 +222,7 @@ quinn-proto = { version = "0.11", default-features = false, features = ["log", " quinn-udp = { version = "0.5", default-features = false, features = ["log"] } rustls = { version = "0.23", default-features = false, features = ["ring"] } rustls-webpki = { version = "0.102", default-features = false, features = ["aws_lc_rs", "ring", "std"] } +socket2 = { version = "0.5", default-features = false, features = ["all"] } stable_deref_trait = { version = "1" } tokio-rustls = { version = "0.26", default-features = false, features = ["aws_lc_rs", "logging", "ring", "tls12"] } tower = { version = "0.5", default-features = false, features = ["timeout"] } @@ -239,6 +244,7 @@ quinn-proto = { version = "0.11", default-features = false, features = ["log", " quinn-udp = { version = "0.5", default-features = false, features = ["log"] } rustls = { version = "0.23", default-features = false, features = ["ring"] } rustls-webpki = { version = "0.102", default-features = false, features = ["aws_lc_rs", "ring", "std"] } +socket2 = { version = "0.5", default-features = false, features = ["all"] } stable_deref_trait = { version = "1" } tokio-rustls = { version = "0.26", default-features = false, features = ["aws_lc_rs", "logging", "ring", "tls12"] } tower = { version = "0.5", default-features = false, features = ["timeout"] } @@ -259,6 +265,7 @@ quinn-proto = { version = "0.11", default-features = false, features = ["log", " quinn-udp = { version = "0.5", default-features = false, features = ["log"] } rustls = { version = "0.23", default-features = false, features = ["ring"] } rustls-webpki = { version = "0.102", default-features = false, features = ["aws_lc_rs", "ring", "std"] } +socket2 = { version = "0.5", default-features = false, features = ["all"] } stable_deref_trait = { version = "1" } tokio-rustls = { version = "0.26", default-features = false, features = ["aws_lc_rs", "logging", "ring", "tls12"] } tower = { version = "0.5", default-features = false, features = ["timeout"] } @@ -280,6 +287,7 @@ quinn-proto = { version = "0.11", default-features = false, features = ["log", " quinn-udp = { version = "0.5", default-features = false, features = ["log"] } rustls = { version = "0.23", default-features = false, features = ["ring"] } rustls-webpki = { version = "0.102", default-features = false, features = ["aws_lc_rs", "ring", "std"] } +socket2 = { version = "0.5", default-features = false, features = ["all"] } stable_deref_trait = { version = "1" } tokio-rustls = { version = "0.26", default-features = false, features = ["aws_lc_rs", "logging", "ring", "tls12"] } tower = { version = "0.5", default-features = false, features = ["timeout"] } @@ -295,6 +303,7 @@ quinn-proto = { version = "0.11", default-features = false, features = ["log", " quinn-udp = { version = "0.5", default-features = false, features = ["log"] } rustls = { version = "0.23", default-features = false, features = ["ring"] } rustls-webpki = { version = "0.102", default-features = false, features = ["aws_lc_rs", "ring", "std"] } +socket2 = { version = "0.5", default-features = false, features = ["all"] } tokio-rustls = { version = "0.26", default-features = false, features = ["aws_lc_rs", "logging", "ring", "tls12"] } tokio-stream = { version = "0.1", features = ["sync"] } tower = { version = "0.5", default-features = false, features = ["timeout"] } @@ -312,6 +321,7 @@ quinn-proto = { version = "0.11", default-features = false, features = ["log", " quinn-udp = { version = "0.5", default-features = false, features = ["log"] } rustls = { version = "0.23", default-features = false, features = ["ring"] } rustls-webpki = { version = "0.102", default-features = false, features = ["aws_lc_rs", "ring", "std"] } +socket2 = { version = "0.5", default-features = false, features = ["all"] } tokio-rustls = { version = "0.26", default-features = false, features = ["aws_lc_rs", "logging", "ring", "tls12"] } tokio-stream = { version = "0.1", features = ["sync"] } tower = { version = "0.5", default-features = false, features = ["timeout"] } @@ -328,6 +338,7 @@ quinn-proto = { version = "0.11", default-features = false, features = ["log", " quinn-udp = { version = "0.5", default-features = false, features = ["log"] } rustls = { version = "0.23", default-features = false, features = ["ring"] } rustls-webpki = { version = "0.102", default-features = false, features = ["aws_lc_rs", "ring", "std"] } +socket2 = { version = "0.5", default-features = false, features = ["all"] } tokio-rustls = { version = "0.26", default-features = false, features = ["aws_lc_rs", "logging", "ring", "tls12"] } tokio-stream = { version = "0.1", features = ["sync"] } tower = { version = "0.5", default-features = false, features = ["timeout"] } @@ -345,6 +356,7 @@ quinn-proto = { version = "0.11", default-features = false, features = ["log", " quinn-udp = { version = "0.5", default-features = false, features = ["log"] } rustls = { version = "0.23", default-features = false, features = ["ring"] } rustls-webpki = { version = "0.102", default-features = false, features = ["aws_lc_rs", "ring", "std"] } +socket2 = { version = "0.5", default-features = false, features = ["all"] } tokio-rustls = { version = "0.26", default-features = false, features = ["aws_lc_rs", "logging", "ring", "tls12"] } tokio-stream = { version = "0.1", features = ["sync"] } tower = { version = "0.5", default-features = false, features = ["timeout"] }