Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
b3ac7a4
Implement callbacks to configure the sockets used as listeners
DrSloth Mar 13, 2025
5d488ad
Run just fmt
DrSloth Mar 14, 2025
710788a
Execute workspace hack and add feature flags to the configure socket …
DrSloth Mar 14, 2025
9dd70fb
Make ConfigureSocketCallback Send
DrSloth Mar 18, 2025
d929ed1
Make ConfigureSocketCallback also Sync
DrSloth Mar 18, 2025
a7d66e7
fix(http): feature gate
lennartkloock Mar 19, 2025
890d7ef
fix(http): udp socket
lennartkloock Mar 19, 2025
202a9d1
fix(http): `IPV6_V6ONLY` flag
lennartkloock Mar 19, 2025
f6cd705
docs: add changelog file
lennartkloock Mar 19, 2025
64e2f8f
Change shebangs in Justfile from '/bin/bash/' to '/usr/bin/env bash'
DrSloth Mar 20, 2025
e6efd14
Auto merge of https://github.com/ScuffleCloud/scuffle/pull/410 - chan…
scuffle-brawl[bot] Mar 20, 2025
0e3b2a7
Implement callbacks to configure the sockets used as listeners
DrSloth Mar 13, 2025
9609aca
Run just fmt
DrSloth Mar 14, 2025
79ce58f
Execute workspace hack and add feature flags to the configure socket …
DrSloth Mar 14, 2025
1b9ea7e
Make ConfigureSocketCallback Send
DrSloth Mar 18, 2025
97cc48e
Make ConfigureSocketCallback also Sync
DrSloth Mar 18, 2025
8b8dc97
fix(http): feature gate
lennartkloock Mar 19, 2025
c48e43f
fix(http): udp socket
lennartkloock Mar 19, 2025
0f484c8
fix(http): `IPV6_V6ONLY` flag
lennartkloock Mar 19, 2025
358c557
docs: add changelog file
lennartkloock Mar 19, 2025
0739b17
Change the ConfigureSocketCallback to a CreateSocketCallback and star…
DrSloth Mar 24, 2025
08559d9
Adapt unit test
DrSloth Mar 24, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ lint *args:

alias coverage := test
test *args:
#!/bin/bash
#!/usr/bin/env bash
set -euo pipefail

INSTA_FORCE_PASS=1 cargo +{{RUST_TOOLCHAIN}} llvm-cov clean --workspace
Expand All @@ -34,7 +34,7 @@ coverage-serve:
miniserve target/llvm-cov/html --index index.html --port 3000

grind *args:
#!/bin/bash
#!/usr/bin/env bash
set -euo pipefail

# Runs valgrind on the tests.
Expand All @@ -46,7 +46,7 @@ grind *args:

alias docs := doc
doc *args:
#!/bin/bash
#!/usr/bin/env bash
set -euo pipefail

# `--cfg docsrs` enables us to write feature hints in the form of `#[cfg_attr(docsrs, doc(cfg(feature = "some-feature")))]`
Expand Down
4 changes: 4 additions & 0 deletions changes.d/pr-407.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[[scuffle-http]]
category = "feat"
description = "add ability to configure sockets using callbacks"
authors = ["@DrSloth", "@lennartkloock"]
1 change: 1 addition & 0 deletions crates/http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ futures = { version = "0.3.31", default-features = false, features = ["alloc"]}
bon = "3.3.2"
pin-project-lite = "0.2.16"
scuffle-context.workspace = true
socket2 = "0.5.8"

# HTTP parsing
http = "1.2.0"
Expand Down
25 changes: 24 additions & 1 deletion crates/http/src/backend/h3/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use tracing::Instrument;
use utils::copy_response_body;

use crate::error::Error;
use crate::server::CreateSocketCallback;
use crate::service::{HttpService, HttpServiceFactory};

pub mod body;
Expand All @@ -36,6 +37,8 @@ pub struct Http3Backend<F> {
/// Use `[::]` for a dual-stack listener.
/// For example, use `[::]:80` to bind to port 80 on both IPv4 and IPv6.
bind: SocketAddr,
/// Callback to configure socket
create_custom_sock: Option<CreateSocketCallback>,
/// rustls config.
///
/// Use this field to set the server into TLS mode.
Expand Down Expand Up @@ -67,7 +70,27 @@ where
let server_config = h3_quinn::quinn::ServerConfig::with_crypto(Arc::new(crypto));

// Bind the UDP socket
let socket = std::net::UdpSocket::bind(self.bind)?;
let socket = {
let sock = if let Some(cfg_fn) = self.create_custom_sock.as_ref() {
cfg_fn.call(self.bind)?
} else {
let sock = socket2::Socket::new(
match self.bind {
SocketAddr::V4(_) => socket2::Domain::IPV4,
SocketAddr::V6(_) => socket2::Domain::IPV6,
},
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;

sock.set_nonblocking(true)?;
sock.bind(&socket2::SockAddr::from(self.bind))?;

sock
};

std::net::UdpSocket::from(sock)
};

// Runtime for the quinn endpoint
let runtime = h3_quinn::quinn::default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
Expand Down
26 changes: 25 additions & 1 deletion crates/http/src/backend/hyper/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use scuffle_context::ContextFutExt;
use tracing::Instrument;

use crate::error::Error;
use crate::server::CreateSocketCallback;
use crate::service::{HttpService, HttpServiceFactory};

mod handler;
Expand All @@ -33,6 +34,8 @@ pub struct HyperBackend<F> {
/// Use `[::]` for a dual-stack listener.
/// For example, use `[::]:80` to bind to port 80 on both IPv4 and IPv6.
bind: SocketAddr,
/// Callback to create a custom socket
create_custom_sock: Option<CreateSocketCallback>,
/// rustls config.
///
/// Use this field to set the server into TLS mode.
Expand Down Expand Up @@ -79,7 +82,28 @@ where
}

// We have to create an std listener first because the tokio listener isn't clonable
let listener = tokio::net::TcpListener::bind(self.bind).await?.into_std()?;
let listener = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect this might not work on windows. Previously when adding windows support we found that if the listener was constructed outside of tokio it would block the eventloop even if non-blocking was set to true. We didnt investigate this further than that, but perhaps this might be a good time to understand why this behaviour was the case when using std::net::TcpListener::bind

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have run just fmt now. I haven't had this problem on windows yet at least when using the socket2 crate but my time spent in windows is not a lot.

let sock = if let Some(cfg_fn) = self.create_custom_sock.as_ref() {
cfg_fn.call(self.bind)?
} else {
let mut sock = socket2::Socket::new(
match self.bind {
SocketAddr::V4(_) => socket2::Domain::IPV4,
SocketAddr::V6(_) => socket2::Domain::IPV6,
},
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;

sock.set_nonblocking(true)?;
sock.bind(&socket2::SockAddr::from(self.bind))?;
sock.listen(128)?;

sock
};

std::net::TcpListener::from(sock)
};

#[cfg(feature = "tls-rustls")]
let tls_acceptor = self
Expand Down
73 changes: 68 additions & 5 deletions crates/http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub mod service;

pub use http;
pub use http::Response;
pub use server::{HttpServer, HttpServerBuilder};
pub use server::{CreateSocketCallback, HttpServer, HttpServerBuilder};

/// An incoming request.
pub type IncomingRequest = http::Request<body::IncomingBody>;
Expand Down Expand Up @@ -197,6 +197,15 @@ mod tests {
// Wait for the server to start
tokio::time::sleep(std::time::Duration::from_millis(100)).await;

test_tls_server_inner(addr, versions).await;

handler.shutdown().await;
handle.await.expect("task failed");
}

#[cfg(feature = "tls-rustls")]
#[allow(dead_code)]
async fn test_tls_server_inner(addr: std::net::SocketAddr, versions: &[reqwest::Version]) {
let url = format!("https://{}/", addr);

for version in versions {
Expand All @@ -222,16 +231,13 @@ mod tests {
let resp = client
.execute(request)
.await
.unwrap_or_else(|_| panic!("failed to get response version {:?}", version))
.unwrap_or_else(|e| panic!("failed to get response version {:?}: {}", version, e))
.text()
.await
.expect("failed to get text");

assert_eq!(resp, RESPONSE_TEXT);
}

handler.shutdown().await;
handle.await.expect("task failed");
}

#[tokio::test]
Expand Down Expand Up @@ -605,4 +611,61 @@ mod tests {

test_tls_server(builder, &[reqwest::Version::HTTP_2, reqwest::Version::HTTP_3]).await;
}

#[tokio::test]
#[should_panic(expected="Address already in use")]
#[cfg(all(feature = "http2", feature = "http3", feature = "tls-rustls"))]
async fn multi_bind_no_reuseport_fails() {
struct TestBody;

impl http_body::Body for TestBody {
type Data = bytes::Bytes;
type Error = Infallible;

fn poll_frame(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
std::task::Poll::Ready(Some(Ok(http_body::Frame::data(bytes::Bytes::from(RESPONSE_TEXT)))))
}
}

let addr = get_available_addr().expect("failed to get available address");

let addr0 = addr.clone();

let t0 = tokio::spawn(async move {
let builder = HttpServer::builder()
.service_factory(service_clone_factory(fn_http_service(|_req| async {
Ok::<_, Infallible>(http::Response::new(TestBody))
})))
.rustls_config(rustls_config())
.enable_http3(true)
.enable_http2(true)
.bind(addr0);
builder.build().run().await.expect("")
});
let addr1 = addr.clone();
let t1 = tokio::spawn(async move {
// Wait for a short time to definitely create this server AFTER the other
let builder = HttpServer::builder()
.service_factory(service_clone_factory(fn_http_service(|_req| async {
Ok::<_, Infallible>(http::Response::new(TestBody))
})))
.rustls_config(rustls_config())
.enable_http3(true)
.enable_http2(true)
.bind(addr1);
builder.build().run().await.expect("")
});

tokio::select! {
res0 = t0 => {
res0.unwrap()
}
res1 = t1 => {
res1.unwrap()
}
}
}
}
40 changes: 40 additions & 0 deletions crates/http/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::fmt::Debug;
use std::net::SocketAddr;
use std::sync::Arc;

use crate::error::Error;
use crate::service::{HttpService, HttpServiceFactory};
Expand Down Expand Up @@ -40,6 +41,16 @@ pub struct HttpServer<F> {
#[cfg(feature = "http3")]
#[cfg_attr(docsrs, doc(cfg(feature = "http3")))]
enable_http3: bool,
/// Callback to create a custom socket used for http1 and http2.
/// The socket should be a tcp socket which is already bound and listening.
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
create_custom_h12_sock: Option<CreateSocketCallback>,
/// Callback to configure socket used for http3
/// The socket should be a udp socket which is already bound (don't call listen for udp).
#[cfg(feature = "http3")]
#[cfg_attr(docsrs, doc(cfg(feature = "http3")))]
create_custom_h3_sock: Option<CreateSocketCallback>,
/// rustls config.
///
/// Use this field to set the server into TLS mode.
Expand Down Expand Up @@ -213,6 +224,7 @@ where
.service_factory(self.service_factory)
.bind(self.bind)
.rustls_config(_rustls_config)
.maybe_create_custom_sock(self.create_custom_h3_sock.clone())
.build();

return backend.run().await;
Expand All @@ -224,6 +236,7 @@ where
.worker_tasks(self.worker_tasks)
.service_factory(self.service_factory)
.bind(self.bind)
.maybe_create_custom_sock(self.create_custom_h12_sock.clone())
.rustls_config(_rustls_config);

#[cfg(feature = "http1")]
Expand All @@ -241,6 +254,7 @@ where
.worker_tasks(self.worker_tasks)
.service_factory(self.service_factory.clone())
.bind(self.bind)
.maybe_create_custom_sock(self.create_custom_h12_sock.clone())
.rustls_config(_rustls_config.clone());

#[cfg(feature = "http1")]
Expand All @@ -256,6 +270,7 @@ where
.worker_tasks(self.worker_tasks)
.service_factory(self.service_factory)
.bind(self.bind)
.maybe_create_custom_sock(self.create_custom_h3_sock.clone())
.rustls_config(_rustls_config)
.build()
.run();
Expand Down Expand Up @@ -283,6 +298,7 @@ where
.ctx(self.ctx)
.worker_tasks(self.worker_tasks)
.service_factory(self.service_factory)
.maybe_create_custom_sock(self.create_custom_h12_sock.clone())
.bind(self.bind);

#[cfg(feature = "http1")]
Expand All @@ -297,3 +313,27 @@ where
Ok(())
}
}

/// A callback used to configure a socket2 instance.
///
/// This can be used to tweak options on the TCP/UDP layer
#[derive(Clone)]
pub struct CreateSocketCallback(Arc<dyn Fn(SocketAddr) -> std::io::Result<socket2::Socket> + Send + Sync>);

impl CreateSocketCallback {
/// Create a new `ConfigureSocketCallback` from the given callback function.
pub fn new<F: Fn(SocketAddr) -> std::io::Result<socket2::Socket> + Send + Sync + 'static>(f: F) -> Self {
Self(Arc::new(f))
}

/// Create a new `ConfigureSocketCallback` from the given callback function.
pub fn call(&self, sock: SocketAddr) -> std::io::Result<socket2::Socket> {
(self.0)(sock)
}
}

impl std::fmt::Debug for CreateSocketCallback {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "CreateSocketCallback ")
}
}
Loading