diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 05e73b0..1ac7d25 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -21,6 +21,7 @@ jobs: - stable - beta - nightly + - "1.75.0" steps: - uses: actions/checkout@v4 diff --git a/Cargo.toml b/Cargo.toml index 2908623..d86456d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ name = "hudsucker" version = "0.21.0" edition = "2021" +rust-version = "1.75.0" description = "MITM HTTP/S proxy" documentation = "https://docs.rs/hudsucker" readme = "README.md" @@ -18,44 +19,45 @@ rustdoc-args = ["--cfg", "docsrs"] [dependencies] async-compression = { version = "0.4.0", features = ["tokio", "brotli", "gzip", "zlib", "zstd"], optional = true } -async-trait = "0.1.67" bstr = "1.0.0" -bytes = "1.0.0" futures = "0.3.11" -http = "0.2.0" -hyper = { version = "0.14.15", features = ["client", "http1", "server", "tcp"] } -hyper-rustls = { version = "0.24.0", default-features = false, features = ["http1", "logging", "tls12", "webpki-tokio"], optional = true } -hyper-tls = { version = "0.5.0", optional = true } -hyper-tungstenite = "0.11.1" +http = "1.1.0" +http-body-util = "0.1.0" +hyper = "1.1.0" +hyper-rustls = { version = "0.26.0", default-features = false, features = ["http1", "logging", "ring", "tls12", "webpki-tokio"], optional = true } +hyper-tls = { version = "0.6.0", optional = true } +hyper-tungstenite = "0.13.0" +hyper-util = { version="0.1.3", features = ["client-legacy", "server", "http1"] } moka = { version = "0.12.0", features = ["future"], optional = true } -openssl = { version = "0.10.39", optional = true } +openssl = { version = "0.10.46", optional = true } rand = { version = "0.8.0", optional = true } rcgen = { version = "0.12.0", features = ["x509-parser"], optional = true } thiserror = "1.0.30" time = { version = "0.3.20", optional = true } -tokio = { version = "1.24.2", features = ["rt"] } -tokio-rustls = "0.24.0" -tokio-tungstenite = "0.20.0" -tokio-util = { version = "0.7.0", features = ["io"], optional = true } -tracing = { version = "0.1.23", features = ["log"] } +tokio = { version = "1.24.2", features = ["macros", "rt"] } +tokio-graceful = "0.1.6" +tokio-rustls = "0.25.0" +tokio-tungstenite = "0.21.0" +tokio-util = { version = "0.7.1", features = ["io"], optional = true } +tracing = { version = "0.1.35", features = ["log"] } [dev-dependencies] async-http-proxy = { version = "1.2.5", features = ["runtime-tokio"] } criterion = { version = "0.5.0", features = ["async_tokio"] } reqwest = "0.11.10" -rustls-native-certs = "0.6.2" +rustls-native-certs = "0.7.0" rustls-pemfile = "2.0.0" -tls-listener = { version = "0.8.0", features = ["rustls", "hyper-h1", "hyper-h2"] } +tls-listener = { version = "0.9.1", features = ["rustls"] } tokio = { version = "1.24.2", features = ["full"] } tokio-native-tls = "0.3.1" -tracing-subscriber = "0.3.0" +tracing-subscriber = "0.3.8" x509-parser = "0.16.0" [features] -decoder = ["dep:async-compression", "dep:tokio-util", "hyper/stream", "tokio/io-util"] +decoder = ["dep:async-compression", "dep:tokio-util", "tokio/io-util"] default = ["decoder", "rcgen-ca", "rustls-client"] full = ["decoder", "http2", "native-tls-client", "openssl-ca", "rcgen-ca", "rustls-client"] -http2 = ["hyper/http2", "hyper-rustls?/http2"] +http2 = ["hyper-util/http2", "hyper-rustls?/http2"] native-tls-client = ["dep:hyper-tls", "tokio-tungstenite/native-tls"] openssl-ca = ["dep:openssl", "dep:moka"] rcgen-ca = ["dep:rcgen", "dep:moka", "dep:time", "dep:rand"] diff --git a/benches/certificate_authorities.rs b/benches/certificate_authorities.rs index b9692d8..475a464 100644 --- a/benches/certificate_authorities.rs +++ b/benches/certificate_authorities.rs @@ -3,7 +3,6 @@ use http::uri::Authority; use hudsucker::{ certificate_authority::{CertificateAuthority, OpensslAuthority, RcgenAuthority}, openssl::{hash::MessageDigest, pkey::PKey, x509::X509}, - rustls, }; use rustls_pemfile as pemfile; @@ -16,21 +15,13 @@ fn runtime() -> tokio::runtime::Runtime { fn build_rcgen_ca(cache_size: u64) -> RcgenAuthority { let mut private_key_bytes: &[u8] = include_bytes!("../examples/ca/hudsucker.key"); let mut ca_cert_bytes: &[u8] = include_bytes!("../examples/ca/hudsucker.cer"); - let private_key = rustls::PrivateKey( - pemfile::pkcs8_private_keys(&mut private_key_bytes) - .next() - .unwrap() - .expect("Failed to parse private key") - .secret_pkcs8_der() - .to_vec(), - ); - let ca_cert = rustls::Certificate( - pemfile::certs(&mut ca_cert_bytes) - .next() - .unwrap() - .expect("Failed to parse CA certificate") - .to_vec(), - ); + let private_key = pemfile::private_key(&mut private_key_bytes) + .unwrap() + .expect("Failed to parse private key"); + let ca_cert = pemfile::certs(&mut ca_cert_bytes) + .next() + .unwrap() + .expect("Failed to parse CA certificate"); RcgenAuthority::new(private_key, ca_cert, cache_size) .expect("Failed to create Certificate Authority") diff --git a/benches/decoder.rs b/benches/decoder.rs index 56ce879..6c9644c 100644 --- a/benches/decoder.rs +++ b/benches/decoder.rs @@ -4,8 +4,9 @@ use hudsucker::{ decode_request, decode_response, hyper::{ header::{CONTENT_ENCODING, CONTENT_LENGTH}, - Body, Request, Response, + Request, Response, }, + Body, }; use tokio::io::BufReader; use tokio_util::io::ReaderStream; diff --git a/benches/proxy.rs b/benches/proxy.rs index 468a6bf..39fd514 100644 --- a/benches/proxy.rs +++ b/benches/proxy.rs @@ -1,21 +1,21 @@ use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use http_body_util::Empty; use hudsucker::{ certificate_authority::{CertificateAuthority, RcgenAuthority}, - hyper::{ - client::connect::HttpConnector, - service::{make_service_fn, service_fn}, - Body, Method, Request, Response, Server, + hyper::{body::Incoming, service::service_fn, Method, Request, Response}, + hyper_util::client::legacy::{connect::HttpConnector, Client}, + hyper_util::{ + rt::{TokioExecutor, TokioIo}, + server::conn::auto, }, - rustls, Proxy, + Body, Proxy, }; use reqwest::Certificate; use rustls_pemfile as pemfile; -use std::{ - convert::Infallible, - net::{SocketAddr, TcpListener}, -}; +use std::{convert::Infallible, net::SocketAddr}; use tls_listener::TlsListener; -use tokio::sync::oneshot::Sender; +use tokio::{net::TcpListener, sync::oneshot::Sender}; +use tokio_graceful::Shutdown; use tokio_native_tls::native_tls; fn runtime() -> tokio::runtime::Runtime { @@ -29,74 +29,109 @@ fn runtime() -> tokio::runtime::Runtime { fn build_ca() -> RcgenAuthority { let mut private_key_bytes: &[u8] = include_bytes!("../examples/ca/hudsucker.key"); let mut ca_cert_bytes: &[u8] = include_bytes!("../examples/ca/hudsucker.cer"); - let private_key = rustls::PrivateKey( - pemfile::pkcs8_private_keys(&mut private_key_bytes) - .next() - .unwrap() - .expect("Failed to parse private key") - .secret_pkcs8_der() - .to_vec(), - ); - let ca_cert = rustls::Certificate( - pemfile::certs(&mut ca_cert_bytes) - .next() - .unwrap() - .expect("Failed to parse CA certificate") - .to_vec(), - ); + let private_key = pemfile::private_key(&mut private_key_bytes) + .unwrap() + .expect("Failed to parse private key"); + let ca_cert = pemfile::certs(&mut ca_cert_bytes) + .next() + .unwrap() + .expect("Failed to parse CA certificate"); RcgenAuthority::new(private_key, ca_cert, 1_000) .expect("Failed to create Certificate Authority") } -async fn test_server(req: Request) -> Result, Infallible> { +async fn test_server(req: Request) -> Result, Infallible> { match (req.method(), req.uri().path()) { (&Method::GET, "/hello") => Ok(Response::new(Body::from("hello, world"))), - _ => Ok(Response::new(Body::empty())), + _ => Ok(Response::new(Body::from(Empty::new()))), } } -fn start_http_server() -> Result<(SocketAddr, Sender<()>), Box> { - let make_svc = make_service_fn(|_| async { Ok::<_, Infallible>(service_fn(test_server)) }); - - let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; +pub async fn start_http_server() -> Result<(SocketAddr, Sender<()>), Box> { + let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?; let addr = listener.local_addr()?; - let (tx, rx) = tokio::sync::oneshot::channel(); - tokio::spawn( - Server::from_tcp(listener)? - .serve(make_svc) - .with_graceful_shutdown(async { rx.await.unwrap_or_default() }), - ); + tokio::spawn(async move { + let server = auto::Builder::new(TokioExecutor::new()); + let shutdown = Shutdown::new(async { rx.await.unwrap_or_default() }); + let guard = shutdown.guard_weak(); + + loop { + tokio::select! { + res = listener.accept() => { + let Ok((tcp, _)) = res else { + continue; + }; + + let server = server.clone(); + + shutdown.spawn_task(async move { + server + .serve_connection_with_upgrades(TokioIo::new(tcp), service_fn(test_server)) + .await + .unwrap(); + }); + } + _ = guard.cancelled() => { + break; + } + } + } + + shutdown.shutdown().await; + }); Ok((addr, tx)) } -async fn start_https_server() -> Result<(SocketAddr, Sender<()>), Box> { - let make_svc = make_service_fn(|_| async { Ok::<_, Infallible>(service_fn(test_server)) }); - - let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; - listener.set_nonblocking(true)?; +pub async fn start_https_server( + ca: impl CertificateAuthority, +) -> Result<(SocketAddr, Sender<()>), Box> { + let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?; let addr = listener.local_addr()?; - let acceptor: tokio_rustls::TlsAcceptor = build_ca() - .gen_server_config(&format!("localhost:{}", addr.port()).parse().unwrap()) + let acceptor: tokio_rustls::TlsAcceptor = ca + .gen_server_config(&"localhost".parse().unwrap()) .await .into(); - let listener = TlsListener::new(acceptor, tokio::net::TcpListener::from_std(listener)?); - + let mut listener = TlsListener::new(acceptor, listener); let (tx, rx) = tokio::sync::oneshot::channel(); - tokio::spawn( - Server::builder(listener) - .serve(make_svc) - .with_graceful_shutdown(async { rx.await.unwrap_or_default() }), - ); + tokio::spawn(async move { + let server = auto::Builder::new(TokioExecutor::new()); + let shutdown = Shutdown::new(async { rx.await.unwrap_or_default() }); + let guard = shutdown.guard_weak(); + + loop { + tokio::select! { + res = listener.accept() => { + let Ok((tcp, _)) = res else { + continue; + }; + + let server = server.clone(); + + shutdown.spawn_task(async move { + server + .serve_connection_with_upgrades(TokioIo::new(tcp), service_fn(test_server)) + .await + .unwrap(); + }); + } + _ = guard.cancelled() => { + break; + } + } + } + + shutdown.shutdown().await; + }); Ok((addr, tx)) } -fn native_tls_client() -> hyper::client::Client> { +fn native_tls_client() -> Client, Body> { let mut http = HttpConnector::new(); http.enforce_http(false); let ca_cert = @@ -110,25 +145,25 @@ fn native_tls_client() -> hyper::client::Client = (http, tls).into(); - hyper::Client::builder().build(https) + Client::builder(TokioExecutor::new()).build(https) } -fn start_proxy( +async fn start_proxy( ca: impl CertificateAuthority, ) -> Result<(SocketAddr, Sender<()>), Box> { - let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; + let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?; let addr = listener.local_addr()?; let (tx, rx) = tokio::sync::oneshot::channel(); - let proxy = Proxy::builder() .with_listener(listener) .with_client(native_tls_client()) .with_ca(ca) + .with_graceful_shutdown(async { + rx.await.unwrap_or_default(); + }) .build(); - tokio::spawn(proxy.start(async { - rx.await.unwrap_or_default(); - })); + tokio::spawn(proxy.start()); Ok((addr, tx)) } @@ -157,9 +192,9 @@ fn bench_local(c: &mut Criterion) { let runtime = runtime(); let _guard = runtime.enter(); - let (proxy_addr, stop_proxy) = start_proxy(build_ca()).unwrap(); - let (http_addr, stop_http) = start_http_server().unwrap(); - let (https_addr, stop_https) = runtime.block_on(start_https_server()).unwrap(); + let (proxy_addr, stop_proxy) = runtime.block_on(start_proxy(build_ca())).unwrap(); + let (http_addr, stop_http) = runtime.block_on(start_http_server()).unwrap(); + let (https_addr, stop_https) = runtime.block_on(start_https_server(build_ca())).unwrap(); let client = build_client(); let proxied_client = build_proxied_client(&proxy_addr.to_string()); @@ -212,7 +247,7 @@ fn bench_remote(c: &mut Criterion) { let runtime = runtime(); let _guard = runtime.enter(); - let (proxy_addr, stop_proxy) = start_proxy(build_ca()).unwrap(); + let (proxy_addr, stop_proxy) = runtime.block_on(start_proxy(build_ca())).unwrap(); let client = build_client(); let proxied_client = build_proxied_client(&proxy_addr.to_string()); diff --git a/examples/log.rs b/examples/log.rs index f4c81ca..520a811 100644 --- a/examples/log.rs +++ b/examples/log.rs @@ -1,7 +1,6 @@ use hudsucker::{ - async_trait::async_trait, certificate_authority::RcgenAuthority, - hyper::{Body, Request, Response}, + hyper::{Request, Response}, tokio_tungstenite::tungstenite::Message, *, }; @@ -18,7 +17,6 @@ async fn shutdown_signal() { #[derive(Clone)] struct LogHandler; -#[async_trait] impl HttpHandler for LogHandler { async fn handle_request( &mut self, @@ -35,7 +33,6 @@ impl HttpHandler for LogHandler { } } -#[async_trait] impl WebSocketHandler for LogHandler { async fn handle_message(&mut self, _ctx: &WebSocketContext, msg: Message) -> Option { println!("{:?}", msg); @@ -49,21 +46,13 @@ async fn main() { let mut private_key_bytes: &[u8] = include_bytes!("ca/hudsucker.key"); let mut ca_cert_bytes: &[u8] = include_bytes!("ca/hudsucker.cer"); - let private_key = rustls::PrivateKey( - pemfile::pkcs8_private_keys(&mut private_key_bytes) - .next() - .unwrap() - .expect("Failed to parse private key") - .secret_pkcs8_der() - .to_vec(), - ); - let ca_cert = rustls::Certificate( - pemfile::certs(&mut ca_cert_bytes) - .next() - .unwrap() - .expect("Failed to parse CA certificate") - .to_vec(), - ); + let private_key = pemfile::private_key(&mut private_key_bytes) + .unwrap() + .expect("Failed to parse private key"); + let ca_cert = pemfile::certs(&mut ca_cert_bytes) + .next() + .unwrap() + .expect("Failed to parse CA certificate"); let ca = RcgenAuthority::new(private_key, ca_cert, 1_000) .expect("Failed to create Certificate Authority"); @@ -74,9 +63,10 @@ async fn main() { .with_ca(ca) .with_http_handler(LogHandler) .with_websocket_handler(LogHandler) + .with_graceful_shutdown(shutdown_signal()) .build(); - if let Err(e) = proxy.start(shutdown_signal()).await { + if let Err(e) = proxy.start().await { error!("{}", e); } } diff --git a/examples/noop.rs b/examples/noop.rs index 7041487..0471f3f 100644 --- a/examples/noop.rs +++ b/examples/noop.rs @@ -15,21 +15,13 @@ async fn main() { let mut private_key_bytes: &[u8] = include_bytes!("ca/hudsucker.key"); let mut ca_cert_bytes: &[u8] = include_bytes!("ca/hudsucker.cer"); - let private_key = rustls::PrivateKey( - pemfile::pkcs8_private_keys(&mut private_key_bytes) - .next() - .unwrap() - .expect("Failed to parse private key") - .secret_pkcs8_der() - .to_vec(), - ); - let ca_cert = rustls::Certificate( - pemfile::certs(&mut ca_cert_bytes) - .next() - .unwrap() - .expect("Failed to parse CA certificate") - .to_vec(), - ); + let private_key = pemfile::private_key(&mut private_key_bytes) + .unwrap() + .expect("Failed to parse private key"); + let ca_cert = pemfile::certs(&mut ca_cert_bytes) + .next() + .unwrap() + .expect("Failed to parse CA certificate"); let ca = RcgenAuthority::new(private_key, ca_cert, 1_000) .expect("Failed to create Certificate Authority"); @@ -38,9 +30,10 @@ async fn main() { .with_addr(SocketAddr::from(([127, 0, 0, 1], 3000))) .with_rustls_client() .with_ca(ca) + .with_graceful_shutdown(shutdown_signal()) .build(); - if let Err(e) = proxy.start(shutdown_signal()).await { + if let Err(e) = proxy.start().await { error!("{}", e); } } diff --git a/examples/openssl.rs b/examples/openssl.rs index c7e9af8..86d34a4 100644 --- a/examples/openssl.rs +++ b/examples/openssl.rs @@ -1,7 +1,6 @@ use hudsucker::{ - async_trait::async_trait, certificate_authority::OpensslAuthority, - hyper::{Body, Request, Response}, + hyper::{Request, Response}, openssl::{hash::MessageDigest, pkey::PKey, x509::X509}, tokio_tungstenite::tungstenite::Message, *, @@ -18,7 +17,6 @@ async fn shutdown_signal() { #[derive(Clone)] struct LogHandler; -#[async_trait] impl HttpHandler for LogHandler { async fn handle_request( &mut self, @@ -35,7 +33,6 @@ impl HttpHandler for LogHandler { } } -#[async_trait] impl WebSocketHandler for LogHandler { async fn handle_message(&mut self, _ctx: &WebSocketContext, msg: Message) -> Option { println!("{:?}", msg); @@ -60,9 +57,10 @@ async fn main() { .with_rustls_client() .with_ca(ca) .with_http_handler(LogHandler) + .with_graceful_shutdown(shutdown_signal()) .build(); - if let Err(e) = proxy.start(shutdown_signal()).await { + if let Err(e) = proxy.start().await { error!("{}", e); } } diff --git a/src/body.rs b/src/body.rs new file mode 100644 index 0000000..d11850e --- /dev/null +++ b/src/body.rs @@ -0,0 +1,151 @@ +use crate::Error; +use futures::{Stream, StreamExt}; +use http_body_util::{combinators::BoxBody, Collected, Empty, Full, StreamBody}; +use hyper::body::{Body as HyperBody, Bytes, Frame, Incoming, SizeHint}; +use std::pin::Pin; + +#[derive(Debug)] +enum Internal { + BoxBody(BoxBody), + Collected(Collected), + Empty(Empty), + Full(Full), + Incoming(Incoming), + String(String), +} + +#[derive(Debug)] +pub struct Body { + inner: Internal, +} + +impl Body { + pub fn wrap_stream(stream: S) -> Self + where + S: Stream> + Send + Sync + 'static, + O: Into, + E: Into, + { + Self { + inner: Internal::BoxBody(BoxBody::new(StreamBody::new( + stream.map(|res| res.map(Into::into).map(Frame::data).map_err(Into::into)), + ))), + } + } +} + +impl HyperBody for Body { + type Data = Bytes; + type Error = crate::Error; + + fn poll_frame( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>>> { + match &mut self.inner { + Internal::BoxBody(body) => Pin::new(body).poll_frame(cx), + Internal::Collected(body) => Pin::new(body).poll_frame(cx).map_err(|e| match e {}), + Internal::Empty(body) => Pin::new(body).poll_frame(cx).map_err(|e| match e {}), + Internal::Full(body) => Pin::new(body).poll_frame(cx).map_err(|e| match e {}), + Internal::Incoming(body) => Pin::new(body).poll_frame(cx).map_err(Error::from), + Internal::String(body) => Pin::new(body).poll_frame(cx).map_err(|e| match e {}), + } + } + + fn is_end_stream(&self) -> bool { + match &self.inner { + Internal::BoxBody(body) => body.is_end_stream(), + Internal::Collected(body) => body.is_end_stream(), + Internal::Empty(body) => body.is_end_stream(), + Internal::Full(body) => body.is_end_stream(), + Internal::Incoming(body) => body.is_end_stream(), + Internal::String(body) => body.is_end_stream(), + } + } + + fn size_hint(&self) -> SizeHint { + match &self.inner { + Internal::BoxBody(body) => body.size_hint(), + Internal::Collected(body) => body.size_hint(), + Internal::Empty(body) => body.size_hint(), + Internal::Full(body) => body.size_hint(), + Internal::Incoming(body) => body.size_hint(), + Internal::String(body) => body.size_hint(), + } + } +} + +impl From> for Body { + fn from(value: BoxBody) -> Self { + Self { + inner: Internal::BoxBody(value), + } + } +} + +impl From> for Body { + fn from(value: Collected) -> Self { + Self { + inner: Internal::Collected(value), + } + } +} + +impl From> for Body { + fn from(value: Empty) -> Self { + Self { + inner: Internal::Empty(value), + } + } +} + +impl From> for Body { + fn from(value: Full) -> Self { + Self { + inner: Internal::Full(value), + } + } +} + +impl From for Body { + fn from(value: Incoming) -> Self { + Self { + inner: Internal::Incoming(value), + } + } +} + +impl From> for Body +where + S: Stream, Error>> + Send + Sync + 'static, +{ + fn from(value: StreamBody) -> Self { + Self { + inner: Internal::BoxBody(BoxBody::new(value)), + } + } +} + +impl From for Body { + fn from(value: String) -> Self { + Self { + inner: Internal::String(value), + } + } +} + +impl From<&'static str> for Body { + fn from(value: &'static str) -> Self { + Self { + inner: Internal::Full(Full::new(Bytes::from_static(value.as_bytes()))), + } + } +} + +impl From<&'static [u8]> for Body { + fn from(value: &'static [u8]) -> Self { + Self { + inner: Internal::Full(Full::new(Bytes::from_static(value))), + } + } +} diff --git a/src/certificate_authority/mod.rs b/src/certificate_authority/mod.rs index b61818f..e0bde1f 100644 --- a/src/certificate_authority/mod.rs +++ b/src/certificate_authority/mod.rs @@ -3,8 +3,8 @@ mod openssl_authority; #[cfg(feature = "rcgen-ca")] mod rcgen_authority; -use async_trait::async_trait; use http::uri::Authority; +use std::future::Future; use std::sync::Arc; use tokio_rustls::rustls::ServerConfig; @@ -21,8 +21,10 @@ const NOT_BEFORE_OFFSET: i64 = 60; /// /// Clients should be configured to either trust the provided root certificate, or to ignore /// certificate errors. -#[async_trait] pub trait CertificateAuthority: Send + Sync + 'static { /// Generate ServerConfig for use with rustls. - async fn gen_server_config(&self, authority: &Authority) -> Arc; + fn gen_server_config( + &self, + authority: &Authority, + ) -> impl Future> + Send; } diff --git a/src/certificate_authority/openssl_authority.rs b/src/certificate_authority/openssl_authority.rs index 2df046b..1df6be8 100644 --- a/src/certificate_authority/openssl_authority.rs +++ b/src/certificate_authority/openssl_authority.rs @@ -1,5 +1,4 @@ use crate::certificate_authority::{CertificateAuthority, CACHE_TTL, NOT_BEFORE_OFFSET, TTL_SECS}; -use async_trait::async_trait; use http::uri::Authority; use moka::future::Cache; use openssl::{ @@ -15,7 +14,10 @@ use std::{ sync::Arc, time::{Duration, SystemTime}, }; -use tokio_rustls::rustls::{self, ServerConfig}; +use tokio_rustls::rustls::{ + pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}, + ServerConfig, +}; use tracing::debug; /// Issues certificates for use when communicating with clients. @@ -29,7 +31,7 @@ use tracing::debug; /// ```rust /// use hudsucker::{ /// certificate_authority::OpensslAuthority, -/// openssl::{hash::MessageDigest, pkey::PKey, x509::X509} +/// openssl::{hash::MessageDigest, pkey::PKey, x509::X509}, /// }; /// /// let private_key_bytes: &[u8] = include_bytes!("../../examples/ca/hudsucker.key"); @@ -40,10 +42,9 @@ use tracing::debug; /// let ca = OpensslAuthority::new(private_key, ca_cert, MessageDigest::sha256(), 1_000); /// ``` #[cfg_attr(docsrs, doc(cfg(feature = "openssl-ca")))] -#[derive(Clone)] pub struct OpensslAuthority { pkey: PKey, - private_key: rustls::PrivateKey, + private_key: PrivateKeyDer<'static>, ca_cert: X509, hash: MessageDigest, cache: Cache>, @@ -52,10 +53,10 @@ pub struct OpensslAuthority { impl OpensslAuthority { /// Creates a new openssl authority. pub fn new(pkey: PKey, ca_cert: X509, hash: MessageDigest, cache_size: u64) -> Self { - let private_key = rustls::PrivateKey( - pkey.private_key_to_der() + let private_key = PrivateKeyDer::from(PrivatePkcs8KeyDer::from( + pkey.private_key_to_pkcs8() .expect("Failed to encode private key"), - ); + )); Self { pkey, @@ -69,7 +70,7 @@ impl OpensslAuthority { } } - fn gen_cert(&self, authority: &Authority) -> Result { + fn gen_cert(&self, authority: &Authority) -> Result, ErrorStack> { let mut name_builder = X509NameBuilder::new()?; name_builder.append_entry_by_text("CN", authority.host())?; let name = name_builder.build(); @@ -103,11 +104,10 @@ impl OpensslAuthority { x509_builder.sign(&self.pkey, self.hash)?; let x509 = x509_builder.build(); - Ok(rustls::Certificate(x509.to_der()?)) + Ok(CertificateDer::from(x509.to_der()?)) } } -#[async_trait] impl CertificateAuthority for OpensslAuthority { async fn gen_server_config(&self, authority: &Authority) -> Arc { if let Some(server_cfg) = self.cache.get(authority).await { @@ -121,9 +121,8 @@ impl CertificateAuthority for OpensslAuthority { .unwrap_or_else(|_| panic!("Failed to generate certificate for {}", authority))]; let mut server_cfg = ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() - .with_single_cert(certs, self.private_key.clone()) + .with_single_cert(certs, self.private_key.clone_key()) .expect("Failed to build ServerConfig"); server_cfg.alpn_protocols = vec![ @@ -168,13 +167,13 @@ mod tests { let c3 = ca.gen_cert(&authority1).unwrap(); let c4 = ca.gen_cert(&authority2).unwrap(); - let (_, cert1) = x509_parser::parse_x509_certificate(&c1.0).unwrap(); - let (_, cert2) = x509_parser::parse_x509_certificate(&c2.0).unwrap(); + let (_, cert1) = x509_parser::parse_x509_certificate(&c1).unwrap(); + let (_, cert2) = x509_parser::parse_x509_certificate(&c2).unwrap(); assert_ne!(cert1.raw_serial(), cert2.raw_serial()); - let (_, cert3) = x509_parser::parse_x509_certificate(&c3.0).unwrap(); - let (_, cert4) = x509_parser::parse_x509_certificate(&c4.0).unwrap(); + let (_, cert3) = x509_parser::parse_x509_certificate(&c3).unwrap(); + let (_, cert4) = x509_parser::parse_x509_certificate(&c4).unwrap(); assert_ne!(cert3.raw_serial(), cert4.raw_serial()); diff --git a/src/certificate_authority/rcgen_authority.rs b/src/certificate_authority/rcgen_authority.rs index cc3040c..a4daa7a 100644 --- a/src/certificate_authority/rcgen_authority.rs +++ b/src/certificate_authority/rcgen_authority.rs @@ -2,14 +2,16 @@ use crate::{ certificate_authority::{CertificateAuthority, CACHE_TTL, NOT_BEFORE_OFFSET, TTL_SECS}, Error, }; -use async_trait::async_trait; use http::uri::Authority; use moka::future::Cache; use rand::{thread_rng, Rng}; use rcgen::{DistinguishedName, DnType, KeyPair, SanType}; use std::sync::Arc; use time::{Duration, OffsetDateTime}; -use tokio_rustls::rustls::{self, ServerConfig}; +use tokio_rustls::rustls::{ + pki_types::{CertificateDer, PrivateKeyDer}, + ServerConfig, +}; use tracing::debug; /// Issues certificates for use when communicating with clients. @@ -26,29 +28,20 @@ use tracing::debug; /// /// let mut private_key_bytes: &[u8] = include_bytes!("../../examples/ca/hudsucker.key"); /// let mut ca_cert_bytes: &[u8] = include_bytes!("../../examples/ca/hudsucker.cer"); -/// let private_key = rustls::PrivateKey( -/// pemfile::pkcs8_private_keys(&mut private_key_bytes) -/// .next() -/// .unwrap() -/// .expect("Failed to parse private key") -/// .secret_pkcs8_der() -/// .to_vec(), -/// ); -/// let ca_cert = rustls::Certificate( -/// pemfile::certs(&mut ca_cert_bytes) -/// .next() -/// .unwrap() -/// .expect("Failed to parse CA certificate") -/// .to_vec(), -/// ); +/// let private_key = pemfile::private_key(&mut private_key_bytes) +/// .unwrap() +/// .expect("Failed to parse private key"); +/// let ca_cert = pemfile::certs(&mut ca_cert_bytes) +/// .next() +/// .unwrap() +/// .expect("Failed to parse CA certificate"); /// /// let ca = RcgenAuthority::new(private_key, ca_cert, 1_000).unwrap(); /// ``` #[cfg_attr(docsrs, doc(cfg(feature = "rcgen-ca")))] -#[derive(Clone)] pub struct RcgenAuthority { - private_key: rustls::PrivateKey, - ca_cert: rustls::Certificate, + private_key: PrivateKeyDer<'static>, + ca_cert: CertificateDer<'static>, cache: Cache>, } @@ -60,8 +53,8 @@ impl RcgenAuthority { /// This will return an error if the provided key or certificate is invalid, or if the key does /// not match the certificate. pub fn new( - private_key: rustls::PrivateKey, - ca_cert: rustls::Certificate, + private_key: PrivateKeyDer<'static>, + ca_cert: CertificateDer<'static>, cache_size: u64, ) -> Result { let ca = Self { @@ -77,7 +70,7 @@ impl RcgenAuthority { Ok(ca) } - fn gen_cert(&self, authority: &Authority) -> rustls::Certificate { + fn gen_cert(&self, authority: &Authority) -> CertificateDer<'static> { let mut params = rcgen::CertificateParams::default(); params.serial_number = Some(thread_rng().gen::().into()); @@ -93,35 +86,36 @@ impl RcgenAuthority { .subject_alt_names .push(SanType::DnsName(authority.host().to_owned())); - let key_pair = KeyPair::from_der(&self.private_key.0).expect("Failed to parse private key"); + let key_pair = + KeyPair::from_der(self.private_key.secret_der()).expect("Failed to parse private key"); params.alg = key_pair .compatible_algs() .next() .expect("Failed to find compatible algorithm"); params.key_pair = Some(key_pair); - let key_pair = KeyPair::from_der(&self.private_key.0).expect("Failed to parse private key"); + let key_pair = + KeyPair::from_der(self.private_key.secret_der()).expect("Failed to parse private key"); - let ca_cert_params = rcgen::CertificateParams::from_ca_cert_der(&self.ca_cert.0, key_pair) + let ca_cert_params = rcgen::CertificateParams::from_ca_cert_der(&self.ca_cert, key_pair) .expect("Failed to parse CA certificate"); let ca_cert = rcgen::Certificate::from_params(ca_cert_params) .expect("Failed to generate CA certificate"); let cert = rcgen::Certificate::from_params(params).expect("Failed to generate certificate"); - rustls::Certificate( + CertificateDer::from( cert.serialize_der_with_signer(&ca_cert) .expect("Failed to serialize certificate"), ) } fn validate(&self) -> Result<(), rcgen::Error> { - let key_pair = rcgen::KeyPair::from_der(&self.private_key.0)?; - rcgen::CertificateParams::from_ca_cert_der(&self.ca_cert.0, key_pair)?; + let key_pair = rcgen::KeyPair::from_der(self.private_key.secret_der())?; + rcgen::CertificateParams::from_ca_cert_der(&self.ca_cert, key_pair)?; Ok(()) } } -#[async_trait] impl CertificateAuthority for RcgenAuthority { async fn gen_server_config(&self, authority: &Authority) -> Arc { if let Some(server_cfg) = self.cache.get(authority).await { @@ -133,9 +127,8 @@ impl CertificateAuthority for RcgenAuthority { let certs = vec![self.gen_cert(authority)]; let mut server_cfg = ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() - .with_single_cert(certs, self.private_key.clone()) + .with_single_cert(certs, self.private_key.clone_key()) .expect("Failed to build ServerConfig"); server_cfg.alpn_protocols = vec![ @@ -158,25 +151,18 @@ impl CertificateAuthority for RcgenAuthority { mod tests { use super::*; use rustls_pemfile as pemfile; + use tokio_rustls::rustls::pki_types::PrivatePkcs1KeyDer; fn init_ca(cache_size: u64) -> RcgenAuthority { let mut private_key_bytes: &[u8] = include_bytes!("../../examples/ca/hudsucker.key"); let mut ca_cert_bytes: &[u8] = include_bytes!("../../examples/ca/hudsucker.cer"); - let private_key = rustls::PrivateKey( - pemfile::pkcs8_private_keys(&mut private_key_bytes) - .next() - .unwrap() - .expect("Failed to parse private key") - .secret_pkcs8_der() - .to_vec(), - ); - let ca_cert = rustls::Certificate( - pemfile::certs(&mut ca_cert_bytes) - .next() - .unwrap() - .expect("Failed to parse CA certificate") - .to_vec(), - ); + let private_key = pemfile::private_key(&mut private_key_bytes) + .unwrap() + .expect("Failed to parse private key"); + let ca_cert = pemfile::certs(&mut ca_cert_bytes) + .next() + .unwrap() + .expect("Failed to parse CA certificate"); RcgenAuthority::new(private_key, ca_cert, cache_size).unwrap() } @@ -184,7 +170,8 @@ mod tests { #[test] fn error_for_invalid_key() { let ca = init_ca(0); - let private_key = rustls::PrivateKey(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + let private_key = + PrivateKeyDer::from(PrivatePkcs1KeyDer::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9])); let result = RcgenAuthority::new(private_key, ca.ca_cert, 0); assert!(result.is_err()); @@ -193,7 +180,7 @@ mod tests { #[test] fn error_for_invalid_ca_cert() { let ca = init_ca(0); - let ca_cert = rustls::Certificate(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + let ca_cert = CertificateDer::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); let result = RcgenAuthority::new(ca.private_key, ca_cert, 0); assert!(result.is_err()); @@ -211,13 +198,13 @@ mod tests { let c3 = ca.gen_cert(&authority1); let c4 = ca.gen_cert(&authority2); - let (_, cert1) = x509_parser::parse_x509_certificate(&c1.0).unwrap(); - let (_, cert2) = x509_parser::parse_x509_certificate(&c2.0).unwrap(); + let (_, cert1) = x509_parser::parse_x509_certificate(&c1).unwrap(); + let (_, cert2) = x509_parser::parse_x509_certificate(&c2).unwrap(); assert_ne!(cert1.raw_serial(), cert2.raw_serial()); - let (_, cert3) = x509_parser::parse_x509_certificate(&c3.0).unwrap(); - let (_, cert4) = x509_parser::parse_x509_certificate(&c4.0).unwrap(); + let (_, cert3) = x509_parser::parse_x509_certificate(&c3).unwrap(); + let (_, cert4) = x509_parser::parse_x509_certificate(&c4).unwrap(); assert_ne!(cert3.raw_serial(), cert4.raw_serial()); diff --git a/src/decoder.rs b/src/decoder.rs index cb35820..c303c3f 100644 --- a/src/decoder.rs +++ b/src/decoder.rs @@ -1,52 +1,56 @@ -use crate::Error; +use crate::{Body, Error}; use async_compression::tokio::bufread::{BrotliDecoder, GzipDecoder, ZlibDecoder, ZstdDecoder}; use bstr::ByteSlice; -use bytes::Bytes; use futures::Stream; +use http_body_util::BodyStream; use hyper::{ + body::{Bytes, Frame}, header::{HeaderMap, HeaderValue, CONTENT_ENCODING, CONTENT_LENGTH}, - Body, Error as HyperError, Request, Response, + Request, Response, }; use std::{ io, - io::Error as IoError, pin::Pin, task::{Context, Poll}, }; use tokio::io::{AsyncBufRead, AsyncRead, BufReader}; use tokio_util::io::{ReaderStream, StreamReader}; -struct IoStream> + Unpin>(T); +struct IoStream, Error>> + Unpin>(T); -impl> + Unpin> Stream for IoStream { - type Item = Result; +impl, Error>> + Unpin> Stream for IoStream { + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { match futures::ready!(Pin::new(&mut self.0).poll_next(cx)) { - Some(Ok(chunk)) => Poll::Ready(Some(Ok(chunk))), - Some(Err(err)) => Poll::Ready(Some(Err(IoError::new(io::ErrorKind::Other, err)))), + Some(Ok(chunk)) => match chunk.into_data() { + Ok(chunk) => Poll::Ready(Some(Ok(chunk))), + Err(_) => Poll::Ready(None), + }, + Some(Err(Error::Io(err))) => Poll::Ready(Some(Err(err))), + Some(Err(err)) => Poll::Ready(Some(Err(io::Error::other(err)))), None => Poll::Ready(None), } } } -enum Decoder { - Body(Body), - Decoder(Box), +enum Decoder { + Body(T), + Decoder(Box), } -impl Decoder { +impl Decoder { pub fn decode(self, encoding: &[u8]) -> Result { if encoding == b"identity" { return Ok(self); } - let reader: Box = match self { - Self::Body(body) => Box::new(StreamReader::new(IoStream(body))), + let reader: Box = match self { + Self::Body(body) => Box::new(StreamReader::new(IoStream(BodyStream::new(body)))), Self::Decoder(decoder) => Box::new(BufReader::new(decoder)), }; - let decoder: Box = match encoding { + let decoder: Box = match encoding { b"gzip" | b"x-gzip" => Box::new(GzipDecoder::new(reader)), b"deflate" => Box::new(ZlibDecoder::new(reader)), b"br" => Box::new(BrotliDecoder::new(reader)), @@ -58,8 +62,8 @@ impl Decoder { } } -impl From for Body { - fn from(decoder: Decoder) -> Body { +impl From> for Body { + fn from(decoder: Decoder) -> Body { match decoder { Decoder::Body(body) => body, Decoder::Decoder(decoder) => Body::wrap_stream(ReaderStream::new(decoder)), @@ -100,16 +104,12 @@ fn decode_body<'a>( /// /// ```rust /// use hudsucker::{ -/// async_trait::async_trait, -/// decode_request, -/// hyper::{Body, Request, Response}, -/// Error, HttpContext, HttpHandler, RequestOrResponse, +/// decode_request, hyper::Request, Body, HttpContext, HttpHandler, RequestOrResponse, /// }; /// /// #[derive(Clone)] /// pub struct MyHandler; /// -/// #[async_trait] /// impl HttpHandler for MyHandler { /// async fn handle_request( /// &mut self, @@ -159,19 +159,17 @@ pub fn decode_request(mut req: Request) -> Result, Error> { /// # Examples /// /// ```rust -/// use hudsucker::{ -/// async_trait::async_trait, -/// decode_response, -/// hyper::{Body, Request, Response}, -/// Error, HttpContext, HttpHandler, RequestOrResponse, -/// }; +/// use hudsucker::{decode_response, hyper::Response, Body, HttpContext, HttpHandler}; /// /// #[derive(Clone)] /// pub struct MyHandler; /// -/// #[async_trait] /// impl HttpHandler for MyHandler { -/// async fn handle_response(&mut self, _ctx: &HttpContext, res: Response) -> Response { +/// async fn handle_response( +/// &mut self, +/// _ctx: &HttpContext, +/// res: Response, +/// ) -> Response { /// let res = decode_response(res).unwrap(); /// /// // Do something with the response @@ -207,6 +205,7 @@ pub fn decode_response(mut res: Response) -> Result, Error> #[cfg(test)] mod tests { use super::*; + use hyper::body::Body as HyperBody; mod extract_encodings { use super::*; @@ -265,10 +264,18 @@ mod tests { } } + async fn to_bytes(body: H) -> Bytes + where + ::Error: std::fmt::Debug, + { + use http_body_util::BodyExt; + body.collect().await.unwrap().to_bytes() + } + mod decode_body { use super::*; use async_compression::tokio::bufread::{BrotliEncoder, GzipEncoder}; - use hyper::body::to_bytes; + use http_body_util::Empty; #[tokio::test] async fn no_encodings() { @@ -276,7 +283,7 @@ mod tests { let body = Body::from(content); assert_eq!( - &to_bytes(decode_body(vec![], body).unwrap()).await.unwrap()[..], + &to_bytes(decode_body(vec![], body).unwrap()).await[..], content.as_bytes() ); } @@ -287,9 +294,7 @@ mod tests { let body = Body::from(content); assert_eq!( - &to_bytes(decode_body(vec![&b"identity"[..]], body).unwrap()) - .await - .unwrap()[..], + &to_bytes(decode_body(vec![&b"identity"[..]], body).unwrap()).await[..], content.as_bytes() ); } @@ -301,9 +306,7 @@ mod tests { let body = Body::wrap_stream(ReaderStream::new(encoder)); assert_eq!( - &to_bytes(decode_body(vec![&b"gzip"[..]], body).unwrap()) - .await - .unwrap()[..], + &to_bytes(decode_body(vec![&b"gzip"[..]], body).unwrap()).await[..], content ); } @@ -316,16 +319,14 @@ mod tests { let body = Body::wrap_stream(ReaderStream::new(encoder)); assert_eq!( - &to_bytes(decode_body(vec![&b"br"[..], &b"gzip"[..]], body).unwrap()) - .await - .unwrap()[..], + &to_bytes(decode_body(vec![&b"br"[..], &b"gzip"[..]], body).unwrap()).await[..], content ); } #[test] fn invalid_encoding() { - let body = Body::empty(); + let body = Body::from(Empty::::new()); assert!(decode_body(vec![&b"invalid"[..]], body).is_err()); } @@ -334,7 +335,6 @@ mod tests { mod decode_request { use super::*; use async_compression::tokio::bufread::GzipEncoder; - use hyper::body::to_bytes; #[tokio::test] async fn decodes_request() { @@ -350,14 +350,13 @@ mod tests { assert!(!req.headers().contains_key(CONTENT_LENGTH)); assert!(!req.headers().contains_key(CONTENT_ENCODING)); - assert_eq!(&to_bytes(req.into_body()).await.unwrap()[..], content); + assert_eq!(&to_bytes(req.into_body()).await[..], content); } } mod decode_response { use super::*; use async_compression::tokio::bufread::GzipEncoder; - use hyper::body::to_bytes; #[tokio::test] async fn decodes_response() { @@ -373,7 +372,7 @@ mod tests { assert!(!res.headers().contains_key(CONTENT_LENGTH)); assert!(!res.headers().contains_key(CONTENT_ENCODING)); - assert_eq!(&to_bytes(res.into_body()).await.unwrap()[..], content); + assert_eq!(&to_bytes(res.into_body()).await[..], content); } } } diff --git a/src/error.rs b/src/error.rs index 2251e66..2e87ec5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,6 +9,8 @@ pub enum Error { Tls(#[from] rcgen::Error), #[error("network error")] Network(#[from] hyper::Error), + #[error("io error")] + Io(#[from] std::io::Error), #[error("unable to decode body")] Decode, #[error("unknown error")] diff --git a/src/lib.rs b/src/lib.rs index db842e1..b6e8337 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ //! - `rcgen-ca`: Enables [`certificate_authority::RcgenAuthority`] (enabled by default). //! - `rustls-client`: Enables [`ProxyBuilder::with_rustls_client`] (enabled by default). +mod body; #[cfg(feature = "decoder")] mod decoder; mod error; @@ -26,21 +27,23 @@ mod rewind; pub mod certificate_authority; use futures::{Sink, SinkExt, Stream, StreamExt}; -use hyper::{Body, Request, Response, StatusCode, Uri}; -use std::net::SocketAddr; +use http_body_util::Empty; +use hyper::{Request, Response, StatusCode, Uri}; +use std::{future::Future, net::SocketAddr}; use tokio_tungstenite::tungstenite::{self, Message}; use tracing::error; pub(crate) use rewind::Rewind; -pub use async_trait; pub use futures; pub use hyper; +pub use hyper_util; #[cfg(feature = "openssl-ca")] pub use openssl; pub use tokio_rustls::rustls; pub use tokio_tungstenite; +pub use body::Body; #[cfg(feature = "decoder")] pub use decoder::{decode_request, decode_response}; pub use error::Error; @@ -98,76 +101,90 @@ pub enum WebSocketContext { /// Handler for HTTP requests and responses. /// /// Each request/response pair is passed to the same instance of the handler. -#[async_trait::async_trait] pub trait HttpHandler: Clone + Send + Sync + 'static { /// This handler will be called for each HTTP request. It can either return a modified request, /// or a response. If a request is returned, it will be sent to the upstream server. If a /// response is returned, it will be sent to the client. - async fn handle_request( + fn handle_request( &mut self, _ctx: &HttpContext, req: Request, - ) -> RequestOrResponse { - req.into() + ) -> impl Future + Send { + async { req.into() } } /// This handler will be called for each HTTP response. It can modify a response before it is /// forwarded to the client. - async fn handle_response(&mut self, _ctx: &HttpContext, res: Response) -> Response { - res + fn handle_response( + &mut self, + _ctx: &HttpContext, + res: Response, + ) -> impl Future> + Send { + async { res } } /// This handler will be called if a proxy request fails. Default response is a 502 Bad Gateway. - async fn handle_error(&mut self, _ctx: &HttpContext, err: hyper::Error) -> Response { - error!("Failed to forward request: {}", err); - Response::builder() - .status(StatusCode::BAD_GATEWAY) - .body(Body::empty()) - .expect("Failed to build response") + fn handle_error( + &mut self, + _ctx: &HttpContext, + err: hyper_util::client::legacy::Error, + ) -> impl Future> + Send { + async move { + error!("Failed to forward request: {}", err); + Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body(Empty::new().into()) + .expect("Failed to build response") + } } /// Whether a CONNECT request should be intercepted. Defaults to `true` for all requests. - async fn should_intercept(&mut self, _ctx: &HttpContext, _req: &Request) -> bool { - true + fn should_intercept( + &mut self, + _ctx: &HttpContext, + _req: &Request, + ) -> impl Future + Send { + async { true } } } /// Handler for WebSocket messages. /// /// Messages sent over the same WebSocket Stream are passed to the same instance of the handler. -#[async_trait::async_trait] pub trait WebSocketHandler: Clone + Send + Sync + 'static { /// This handler is responsible for forwarding WebSocket messages from a Stream to a Sink and /// recovering from any potential errors. - async fn handle_websocket( + fn handle_websocket( mut self, ctx: WebSocketContext, mut stream: impl Stream> + Unpin + Send + 'static, mut sink: impl Sink + Unpin + Send + 'static, - ) { - while let Some(message) = stream.next().await { - match message { - Ok(message) => { - let Some(message) = self.handle_message(&ctx, message).await else { - continue; - }; - - match sink.send(message).await { - Err(tungstenite::Error::ConnectionClosed) => (), - Err(e) => error!("WebSocket send error: {}", e), - _ => (), + ) -> impl Future + Send { + async move { + while let Some(message) = stream.next().await { + match message { + Ok(message) => { + let Some(message) = self.handle_message(&ctx, message).await else { + continue; + }; + + match sink.send(message).await { + Err(tungstenite::Error::ConnectionClosed) => (), + Err(e) => error!("WebSocket send error: {}", e), + _ => (), + } } - } - Err(e) => { - error!("WebSocket message error: {}", e); + Err(e) => { + error!("WebSocket message error: {}", e); - match sink.send(Message::Close(None)).await { - Err(tungstenite::Error::ConnectionClosed) => (), - Err(e) => error!("WebSocket close error: {}", e), - _ => (), - }; + match sink.send(Message::Close(None)).await { + Err(tungstenite::Error::ConnectionClosed) => (), + Err(e) => error!("WebSocket close error: {}", e), + _ => (), + }; - break; + break; + } } } } @@ -175,11 +192,11 @@ pub trait WebSocketHandler: Clone + Send + Sync + 'static { /// This handler will be called for each WebSocket message. It can return an optional modified /// message. If None is returned the message will not be forwarded. - async fn handle_message( + fn handle_message( &mut self, _ctx: &WebSocketContext, message: Message, - ) -> Option { - Some(message) + ) -> impl Future> + Send { + async { Some(message) } } } diff --git a/src/proxy/builder.rs b/src/proxy/builder.rs index 6e74085..bc20076 100644 --- a/src/proxy/builder.rs +++ b/src/proxy/builder.rs @@ -1,18 +1,25 @@ use crate::{ - certificate_authority::CertificateAuthority, HttpHandler, NoopHandler, Proxy, WebSocketHandler, -}; -use hyper::{ - client::{connect::Connect, Client, HttpConnector}, - server::conn::AddrIncoming, + certificate_authority::CertificateAuthority, Body, HttpHandler, NoopHandler, Proxy, + WebSocketHandler, }; #[cfg(feature = "rustls-client")] use hyper_rustls::{HttpsConnector as RustlsConnector, HttpsConnectorBuilder}; #[cfg(feature = "native-tls-client")] use hyper_tls::HttpsConnector as NativeTlsConnector; +use hyper_util::{ + client::legacy::{ + connect::{Connect, HttpConnector}, + Client, + }, + rt::TokioExecutor, + server::conn::auto::Builder, +}; use std::{ - net::{SocketAddr, TcpListener}, + future::{pending, Future, Pending}, + net::SocketAddr, sync::Arc, }; +use tokio::net::TcpListener; use tokio_tungstenite::Connector; /// A builder for creating a [`Proxy`]. @@ -29,21 +36,13 @@ use tokio_tungstenite::Connector; /// # /// # let mut private_key_bytes: &[u8] = include_bytes!("../../examples/ca/hudsucker.key"); /// # let mut ca_cert_bytes: &[u8] = include_bytes!("../../examples/ca/hudsucker.cer"); -/// # let private_key = rustls::PrivateKey( -/// # pemfile::pkcs8_private_keys(&mut private_key_bytes) -/// # .next() +/// # let private_key = pemfile::private_key(&mut private_key_bytes) /// # .unwrap() -/// # .expect("Failed to parse private key") -/// # .secret_pkcs8_der() -/// # .to_vec(), -/// # ); -/// # let ca_cert = rustls::Certificate( -/// # pemfile::certs(&mut ca_cert_bytes) +/// # .expect("Failed to parse private key"); +/// # let ca_cert = pemfile::certs(&mut ca_cert_bytes) /// # .next() /// # .unwrap() -/// # .expect("Failed to parse CA certificate") -/// # .to_vec(), -/// # ); +/// # .expect("Failed to parse CA certificate"); /// # /// # let ca = RcgenAuthority::new(private_key, ca_cert, 1_000) /// # .expect("Failed to create Certificate Authority"); @@ -65,10 +64,9 @@ pub struct ProxyBuilder(T); pub struct WantsAddr(()); #[derive(Debug)] -pub(crate) enum AddrListenerServer { +pub(crate) enum AddrOrListener { Addr(SocketAddr), Listener(TcpListener), - Server(Box>), } impl ProxyBuilder { @@ -80,24 +78,14 @@ impl ProxyBuilder { /// Set the address to listen on. pub fn with_addr(self, addr: SocketAddr) -> ProxyBuilder { ProxyBuilder(WantsClient { - als: AddrListenerServer::Addr(addr), + al: AddrOrListener::Addr(addr), }) } /// Set a listener to use for the proxy server. pub fn with_listener(self, listener: TcpListener) -> ProxyBuilder { ProxyBuilder(WantsClient { - als: AddrListenerServer::Listener(listener), - }) - } - - /// Set a custom server builder to use for the proxy server. - pub fn with_server( - self, - server: hyper::server::Builder, - ) -> ProxyBuilder { - ProxyBuilder(WantsClient { - als: AddrListenerServer::Server(Box::new(server)), + al: AddrOrListener::Listener(listener), }) } } @@ -111,7 +99,7 @@ impl Default for ProxyBuilder { /// Builder state that needs a client. #[derive(Debug)] pub struct WantsClient { - als: AddrListenerServer, + al: AddrOrListener, } impl ProxyBuilder { @@ -130,8 +118,8 @@ impl ProxyBuilder { let https = https.build(); ProxyBuilder(WantsCa { - als: self.0.als, - client: Client::builder() + al: self.0.al, + client: Client::builder(TokioExecutor::new()) .http1_title_case_headers(true) .http1_preserve_header_case(true) .build(https), @@ -147,8 +135,8 @@ impl ProxyBuilder { let https = NativeTlsConnector::new(); ProxyBuilder(WantsCa { - als: self.0.als, - client: Client::builder() + al: self.0.al, + client: Client::builder(TokioExecutor::new()) .http1_title_case_headers(true) .http1_preserve_header_case(true) .build(https), @@ -156,12 +144,12 @@ impl ProxyBuilder { } /// Use a custom client. - pub fn with_client(self, client: Client) -> ProxyBuilder> + pub fn with_client(self, client: Client) -> ProxyBuilder> where C: Connect + Clone + Send + Sync + 'static, { ProxyBuilder(WantsCa { - als: self.0.als, + al: self.0.al, client, }) } @@ -170,8 +158,8 @@ impl ProxyBuilder { /// Builder state that needs a certificate authority. #[derive(Debug)] pub struct WantsCa { - als: AddrListenerServer, - client: Client, + al: AddrOrListener, + client: Client, } impl ProxyBuilder> { @@ -179,41 +167,47 @@ impl ProxyBuilder> { pub fn with_ca( self, ca: CA, - ) -> ProxyBuilder> { + ) -> ProxyBuilder>> { ProxyBuilder(WantsHandlers { - als: self.0.als, + al: self.0.al, client: self.0.client, ca, http_handler: NoopHandler::new(), websocket_handler: NoopHandler::new(), websocket_connector: None, + server: None, + graceful_shutdown: pending(), }) } } /// Builder state that can take additional handlers. -pub struct WantsHandlers { - als: AddrListenerServer, - client: Client, +pub struct WantsHandlers { + al: AddrOrListener, + client: Client, ca: CA, http_handler: H, websocket_handler: W, websocket_connector: Option, + server: Option>, + graceful_shutdown: F, } -impl ProxyBuilder> { +impl ProxyBuilder> { /// Set the HTTP handler. pub fn with_http_handler( self, http_handler: H2, - ) -> ProxyBuilder> { + ) -> ProxyBuilder> { ProxyBuilder(WantsHandlers { - als: self.0.als, + al: self.0.al, client: self.0.client, ca: self.0.ca, http_handler, websocket_handler: self.0.websocket_handler, websocket_connector: self.0.websocket_connector, + server: self.0.server, + graceful_shutdown: self.0.graceful_shutdown, }) } @@ -221,14 +215,16 @@ impl ProxyBuilder> { pub fn with_websocket_handler( self, websocket_handler: W2, - ) -> ProxyBuilder> { + ) -> ProxyBuilder> { ProxyBuilder(WantsHandlers { - als: self.0.als, + al: self.0.al, client: self.0.client, ca: self.0.ca, http_handler: self.0.http_handler, websocket_handler, websocket_connector: self.0.websocket_connector, + server: self.0.server, + graceful_shutdown: self.0.graceful_shutdown, }) } @@ -240,15 +236,42 @@ impl ProxyBuilder> { }) } + /// Set a custom server builder to use for the proxy server. + pub fn with_server(self, server: Builder) -> Self { + ProxyBuilder(WantsHandlers { + server: Some(server), + ..self.0 + }) + } + + /// Set a future that when ready will gracefully shutdown the proxy server. + pub fn with_graceful_shutdown + Send + 'static>( + self, + graceful_shutdown: F2, + ) -> ProxyBuilder> { + ProxyBuilder(WantsHandlers { + al: self.0.al, + client: self.0.client, + ca: self.0.ca, + http_handler: self.0.http_handler, + websocket_handler: self.0.websocket_handler, + websocket_connector: self.0.websocket_connector, + server: self.0.server, + graceful_shutdown, + }) + } + /// Build the proxy. - pub fn build(self) -> Proxy { + pub fn build(self) -> Proxy { Proxy { - als: self.0.als, + al: self.0.al, client: self.0.client, ca: Arc::new(self.0.ca), http_handler: self.0.http_handler, websocket_handler: self.0.websocket_handler, websocket_connector: self.0.websocket_connector, + server: self.0.server, + graceful_shutdown: self.0.graceful_shutdown, } } } diff --git a/src/proxy/internal.rs b/src/proxy/internal.rs index ae283f3..7259c34 100644 --- a/src/proxy/internal.rs +++ b/src/proxy/internal.rs @@ -1,19 +1,23 @@ use crate::{ - certificate_authority::CertificateAuthority, HttpContext, HttpHandler, RequestOrResponse, - Rewind, WebSocketContext, WebSocketHandler, + body::Body, certificate_authority::CertificateAuthority, HttpContext, HttpHandler, + RequestOrResponse, Rewind, WebSocketContext, WebSocketHandler, }; use futures::{Sink, Stream, StreamExt}; use http::uri::{Authority, Scheme}; +use http_body_util::Empty; use hyper::{ - client::connect::Connect, header::Entry, server::conn::Http, service::service_fn, - upgrade::Upgraded, Body, Client, Method, Request, Response, StatusCode, Uri, + body::{Bytes, Incoming}, + header::Entry, + service::service_fn, + upgrade::Upgraded, + Method, Request, Response, StatusCode, Uri, }; -use std::{convert::Infallible, future::Future, net::SocketAddr, sync::Arc}; -use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite}, - net::TcpStream, - task::JoinHandle, +use hyper_util::{ + client::legacy::{connect::Connect, Client}, + rt::{TokioExecutor, TokioIo}, }; +use std::{convert::Infallible, future::Future, net::SocketAddr, sync::Arc}; +use tokio::{io::AsyncReadExt, net::TcpStream, task::JoinHandle}; use tokio_rustls::TlsAcceptor; use tokio_tungstenite::{ tungstenite::{self, Message}, @@ -24,7 +28,7 @@ use tracing::{error, info_span, instrument, warn, Instrument, Span}; fn bad_request() -> Response { Response::builder() .status(StatusCode::BAD_REQUEST) - .body(Body::empty()) + .body(Empty::new().into()) .expect("Failed to build response") } @@ -37,7 +41,7 @@ fn spawn_with_trace( pub(crate) struct InternalProxy { pub ca: Arc, - pub client: Client, + pub client: Client, pub http_handler: H, pub websocket_handler: W, pub websocket_connector: Option, @@ -84,12 +88,15 @@ where client_addr = %self.client_addr, ) )] - pub(crate) async fn proxy(mut self, req: Request) -> Result, Infallible> { + pub(crate) async fn proxy( + mut self, + req: Request, + ) -> Result, Infallible> { let ctx = self.context(); let req = match self .http_handler - .handle_request(&ctx, req) + .handle_request(&ctx, req.map(Body::from)) .instrument(info_span!("handle_request")) .await { @@ -111,7 +118,7 @@ where match res { Ok(res) => Ok(self .http_handler - .handle_response(&ctx, res) + .handle_response(&ctx, res.map(Body::from)) .instrument(info_span!("handle_response")) .await), Err(err) => Ok(self @@ -129,7 +136,8 @@ where let span = info_span!("process_connect"); let fut = async move { match hyper::upgrade::on(&mut req).await { - Ok(mut upgraded) => { + Ok(upgraded) => { + let mut upgraded = TokioIo::new(upgraded); let mut buffer = [0; 4]; let bytes_read = match upgraded.read(&mut buffer).await { Ok(bytes_read) => bytes_read, @@ -141,7 +149,7 @@ where let mut upgraded = Rewind::new_buffered( upgraded, - bytes::Bytes::copy_from_slice(buffer[..bytes_read].as_ref()), + Bytes::copy_from_slice(buffer[..bytes_read].as_ref()), ); if self @@ -150,8 +158,13 @@ where .await { if buffer == *b"GET " { - if let Err(e) = - self.serve_stream(upgraded, Scheme::HTTP, authority).await + if let Err(e) = self + .serve_stream( + TokioIo::new(upgraded), + Scheme::HTTP, + authority, + ) + .await { error!("WebSocket connect error: {}", e); } @@ -168,7 +181,7 @@ where .accept(upgraded) .await { - Ok(stream) => stream, + Ok(stream) => TokioIo::new(stream), Err(e) => { error!("Failed to establish TLS connection: {}", e); return; @@ -214,7 +227,7 @@ where }; spawn_with_trace(fut, span); - Response::new(Body::empty()) + Response::new(Empty::new().into()) } None => bad_request(), } @@ -262,7 +275,7 @@ where }; spawn_with_trace(fut, span); - res + res.map(Into::into) } Err(_) => bad_request(), } @@ -271,7 +284,7 @@ where #[instrument(skip_all)] async fn handle_websocket( self, - server_socket: WebSocketStream, + server_socket: WebSocketStream>, req: Request<()>, ) -> Result<(), tungstenite::Error> { let uri = req.uri().clone(); @@ -324,9 +337,9 @@ where stream: I, scheme: Scheme, authority: Authority, - ) -> Result<(), hyper::Error> + ) -> Result<(), Box> where - I: AsyncRead + AsyncWrite + Unpin + Send + 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, { let service = service_fn(|mut req| { if req.version() == hyper::Version::HTTP_10 || req.version() == hyper::Version::HTTP_11 @@ -346,9 +359,8 @@ where self.clone().proxy(req) }); - Http::new() - .serve_connection(stream, service) - .with_upgrades() + hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(stream, service) .await } } @@ -382,23 +394,21 @@ fn normalize_request(mut req: Request) -> Request { #[cfg(test)] mod tests { use super::*; + use hyper_util::client::legacy::connect::HttpConnector; use tokio_rustls::rustls::ServerConfig; struct CA; - #[async_trait::async_trait] impl CertificateAuthority for CA { async fn gen_server_config(&self, _authority: &Authority) -> Arc { unimplemented!(); } } - fn build_proxy( - ) -> InternalProxy - { + fn build_proxy() -> InternalProxy { InternalProxy { ca: Arc::new(CA), - client: hyper::Client::new(), + client: Client::builder(TokioExecutor::new()).build(HttpConnector::new()), http_handler: crate::NoopHandler::new(), websocket_handler: crate::NoopHandler::new(), websocket_connector: None, @@ -464,7 +474,7 @@ mod tests { let req = Request::builder() .uri("/foo/bar?baz") - .body(Body::empty()) + .body(Empty::new().into()) .unwrap(); let res = proxy.process_connect(req); @@ -482,7 +492,7 @@ mod tests { let req = Request::builder() .uri("/foo/bar?baz") - .body(Body::empty()) + .body(Empty::new().into()) .unwrap(); let res = proxy.upgrade_websocket(req); @@ -496,7 +506,7 @@ mod tests { let req = Request::builder() .uri("http://example.com/foo/bar?baz") - .body(Body::empty()) + .body(Empty::new().into()) .unwrap(); let res = proxy.upgrade_websocket(req); diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 1777096..00889d6 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -2,17 +2,22 @@ mod internal; pub mod builder; -use crate::{certificate_authority::CertificateAuthority, Error, HttpHandler, WebSocketHandler}; -use builder::{AddrListenerServer, WantsAddr}; -use hyper::{ - client::connect::Connect, - server::conn::AddrStream, - service::{make_service_fn, service_fn}, - Client, Server, +use crate::{ + certificate_authority::CertificateAuthority, Body, Error, HttpHandler, WebSocketHandler, +}; +use builder::{AddrOrListener, WantsAddr}; +use hyper::service::service_fn; +use hyper_util::{ + client::legacy::{connect::Connect, Client}, + rt::{TokioExecutor, TokioIo}, + server::conn::auto::{self, Builder}, }; use internal::InternalProxy; -use std::{convert::Infallible, future::Future, sync::Arc}; +use std::{future::Future, sync::Arc}; +use tokio::net::TcpListener; +use tokio_graceful::Shutdown; use tokio_tungstenite::Connector; +use tracing::error; pub use builder::ProxyBuilder; @@ -32,38 +37,31 @@ pub use builder::ProxyBuilder; /// # /// # let mut private_key_bytes: &[u8] = include_bytes!("../../examples/ca/hudsucker.key"); /// # let mut ca_cert_bytes: &[u8] = include_bytes!("../../examples/ca/hudsucker.cer"); -/// # let private_key = rustls::PrivateKey( -/// # pemfile::pkcs8_private_keys(&mut private_key_bytes) -/// # .next() +/// # let private_key = pemfile::private_key(&mut private_key_bytes) /// # .unwrap() -/// # .expect("Failed to parse private key") -/// # .secret_pkcs8_der() -/// # .to_vec(), -/// # ); -/// # let ca_cert = rustls::Certificate( -/// # pemfile::certs(&mut ca_cert_bytes) +/// # .expect("Failed to parse private key"); +/// # let ca_cert = pemfile::certs(&mut ca_cert_bytes) /// # .next() /// # .unwrap() -/// # .expect("Failed to parse CA certificate") -/// # .to_vec(), -/// # ); +/// # .expect("Failed to parse CA certificate"); /// # /// # let ca = RcgenAuthority::new(private_key, ca_cert, 1_000) /// # .expect("Failed to create Certificate Authority"); /// /// // let ca = ...; /// +/// let (stop, done) = tokio::sync::oneshot::channel(); +/// /// let proxy = Proxy::builder() /// .with_addr(std::net::SocketAddr::from(([127, 0, 0, 1], 0))) /// .with_rustls_client() /// .with_ca(ca) +/// .with_graceful_shutdown(async { +/// done.await.unwrap_or_default(); +/// }) /// .build(); /// -/// let (stop, done) = tokio::sync::oneshot::channel(); -/// -/// tokio::spawn(proxy.start(async { -/// done.await.unwrap_or_default(); -/// })); +/// tokio::spawn(proxy.start()); /// /// // Do something else... /// @@ -73,71 +71,111 @@ pub use builder::ProxyBuilder; /// # #[cfg(not(all(feature = "rcgen-ca", feature = "rustls-client")))] /// # fn main() {} /// ``` -pub struct Proxy { - als: AddrListenerServer, +pub struct Proxy { + al: AddrOrListener, ca: Arc, - client: Client, + client: Client, http_handler: H, websocket_handler: W, websocket_connector: Option, + server: Option>, + graceful_shutdown: F, } -impl Proxy<(), (), (), ()> { +impl Proxy<(), (), (), (), ()> { /// Create a new [`ProxyBuilder`]. pub fn builder() -> ProxyBuilder { ProxyBuilder::new() } } -impl Proxy +impl Proxy where C: Connect + Clone + Send + Sync + 'static, CA: CertificateAuthority, H: HttpHandler, W: WebSocketHandler, + F: Future + Send + 'static, { /// Attempts to start the proxy server. /// /// # Errors /// /// This will return an error if the proxy server is unable to be started. - pub async fn start>(self, shutdown_signal: F) -> Result<(), Error> { - let make_service = make_service_fn(move |conn: &AddrStream| { - let client = self.client.clone(); - let ca = Arc::clone(&self.ca); - let http_handler = self.http_handler.clone(); - let websocket_handler = self.websocket_handler.clone(); - let websocket_connector = self.websocket_connector.clone(); - let client_addr = conn.remote_addr(); - async move { - Ok::<_, Infallible>(service_fn(move |req| { - InternalProxy { - ca: Arc::clone(&ca), - client: client.clone(), - http_handler: http_handler.clone(), - websocket_handler: websocket_handler.clone(), - websocket_connector: websocket_connector.clone(), - client_addr, - } - .proxy(req) - })) - } + pub async fn start(self) -> Result<(), Error> { + let server = self.server.unwrap_or_else(|| { + let mut builder = auto::Builder::new(TokioExecutor::new()); + builder + .http1() + .title_case_headers(true) + .preserve_header_case(true); + builder }); - let server_builder = match self.als { - AddrListenerServer::Addr(addr) => Server::try_bind(&addr)? - .http1_preserve_header_case(true) - .http1_title_case_headers(true), - AddrListenerServer::Listener(listener) => Server::from_tcp(listener)? - .http1_preserve_header_case(true) - .http1_title_case_headers(true), - AddrListenerServer::Server(server) => *server, + let listener = match self.al { + AddrOrListener::Addr(addr) => TcpListener::bind(addr).await?, + AddrOrListener::Listener(listener) => listener, }; - server_builder - .serve(make_service) - .with_graceful_shutdown(shutdown_signal) - .await - .map_err(Into::into) + let shutdown = Shutdown::new(self.graceful_shutdown); + let guard = shutdown.guard_weak(); + + loop { + tokio::select! { + res = listener.accept() => { + let (tcp, client_addr) = match res { + Ok((tcp, client_addr)) => (tcp, client_addr), + Err(e) => { + error!("Failed to accept incoming connection: {}", e); + continue; + } + }; + + let server = server.clone(); + let client = self.client.clone(); + let ca = Arc::clone(&self.ca); + let http_handler = self.http_handler.clone(); + let websocket_handler = self.websocket_handler.clone(); + let websocket_connector = self.websocket_connector.clone(); + + shutdown.spawn_task_fn(move |guard| async move { + let conn = server + .serve_connection_with_upgrades( + TokioIo::new(tcp), + service_fn(move |req| { + InternalProxy { + ca: Arc::clone(&ca), + client: client.clone(), + http_handler: http_handler.clone(), + websocket_handler: websocket_handler.clone(), + websocket_connector: websocket_connector.clone(), + client_addr, + } + .proxy(req) + }), + ); + + let mut conn = std::pin::pin!(conn); + + if let Err(err) = tokio::select! { + conn = conn.as_mut() => conn, + _ = guard.cancelled() => { + conn.as_mut().graceful_shutdown(); + conn.await + } + } { + error!("Error serving connection: {}", err); + } + }); + } + _ = guard.cancelled() => { + break; + } + } + } + + shutdown.shutdown().await; + + Ok(()) } } diff --git a/src/rewind.rs b/src/rewind.rs index 2f43122..ad17b2b 100644 --- a/src/rewind.rs +++ b/src/rewind.rs @@ -1,9 +1,8 @@ // adapted from https://github.com/hyperium/hyper/blob/master/src/common/io/rewind.rs -use bytes::{Buf, Bytes}; +use hyper::body::{Buf, Bytes}; use std::{ cmp, io, - marker::Unpin, pin::Pin, task::{self, Poll}, }; @@ -49,7 +48,7 @@ where // TODO: There should be a way to do following two lines cleaner... buf.put_slice(&prefix[..copy_len]); prefix.advance(copy_len); - // Put back whats left + // Put back what's left if !prefix.is_empty() { self.pre = Some(prefix); } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 38c0cfc..1cf4aa3 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,42 +1,45 @@ use async_compression::tokio::bufread::GzipEncoder; use futures::{SinkExt, StreamExt}; +use http_body_util::Empty; use hudsucker::{ - async_trait::async_trait, certificate_authority::CertificateAuthority, decode_request, decode_response, hyper::{ - client::{ + body::Incoming, header::CONTENT_ENCODING, service::service_fn, Method, Request, Response, + StatusCode, + }, + hyper_util::{ + client::legacy::{ connect::{Connect, HttpConnector}, Client, }, - header::CONTENT_ENCODING, - server::conn::AddrStream, - service::{make_service_fn, service_fn}, - Body, Method, Request, Response, Server, StatusCode, + rt::{TokioExecutor, TokioIo}, + server::conn::auto, }, rustls, tokio_tungstenite::tungstenite::Message, - HttpContext, HttpHandler, Proxy, RequestOrResponse, WebSocketContext, WebSocketHandler, + Body, HttpContext, HttpHandler, Proxy, RequestOrResponse, WebSocketContext, WebSocketHandler, }; use reqwest::tls::Certificate; use rustls_pemfile as pemfile; use std::{ convert::Infallible, - net::{SocketAddr, TcpListener}, + net::SocketAddr, sync::{ atomic::{AtomicUsize, Ordering}, Arc, }, }; use tls_listener::TlsListener; -use tokio::sync::oneshot::Sender; -use tokio_native_tls::{self, native_tls}; +use tokio::{net::TcpListener, sync::oneshot::Sender}; +use tokio_graceful::Shutdown; +use tokio_native_tls::native_tls; use tokio_util::io::ReaderStream; pub const HELLO_WORLD: &str = "Hello, World!"; pub const WORLD: &str = "world"; -async fn test_server(req: Request) -> Result, Infallible> { +async fn test_server(req: Request) -> Result, Infallible> { if hyper_tungstenite::is_upgrade_request(&req) { let (res, ws) = hyper_tungstenite::upgrade(req, None).unwrap(); @@ -52,7 +55,7 @@ async fn test_server(req: Request) -> Result, Infallible> { } }); - return Ok(res); + return Ok(res.map(Body::from)); } match (req.method(), req.uri().path()) { @@ -64,26 +67,45 @@ async fn test_server(req: Request) -> Result, Infallible> { HELLO_WORLD.as_bytes(), )))) .unwrap()), - (&Method::POST, "/echo") => Ok(Response::new(req.into_body())), - _ => Ok(Response::new(Body::empty())), + (&Method::POST, "/echo") => Ok(Response::new(req.into_body().into())), + _ => Ok(Response::new(Body::from(Empty::new()))), } } -pub fn start_http_server() -> Result<(SocketAddr, Sender<()>), Box> { - let make_svc = make_service_fn(|_conn: &AddrStream| async { - Ok::<_, Infallible>(service_fn(test_server)) - }); - - let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; +pub async fn start_http_server() -> Result<(SocketAddr, Sender<()>), Box> { + let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?; let addr = listener.local_addr()?; - let (tx, rx) = tokio::sync::oneshot::channel(); - tokio::spawn( - Server::from_tcp(listener)? - .serve(make_svc) - .with_graceful_shutdown(async { rx.await.unwrap_or_default() }), - ); + tokio::spawn(async move { + let server = auto::Builder::new(TokioExecutor::new()); + let shutdown = Shutdown::new(async { rx.await.unwrap_or_default() }); + let guard = shutdown.guard_weak(); + + loop { + tokio::select! { + res = listener.accept() => { + let Ok((tcp, _)) = res else { + continue; + }; + + let server = server.clone(); + + shutdown.spawn_task(async move { + server + .serve_connection_with_upgrades(TokioIo::new(tcp), service_fn(test_server)) + .await + .unwrap(); + }); + } + _ = guard.cancelled() => { + break; + } + } + } + + shutdown.shutdown().await; + }); Ok((addr, tx)) } @@ -91,30 +113,50 @@ pub fn start_http_server() -> Result<(SocketAddr, Sender<()>), Box Result<(SocketAddr, Sender<()>), Box> { - let make_svc = make_service_fn(|_| async { Ok::<_, Infallible>(service_fn(test_server)) }); - - let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; - listener.set_nonblocking(true)?; + let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?; let addr = listener.local_addr()?; let acceptor: tokio_rustls::TlsAcceptor = ca .gen_server_config(&"localhost".parse().unwrap()) .await .into(); - let listener = TlsListener::new(acceptor, tokio::net::TcpListener::from_std(listener)?); - + let mut listener = TlsListener::new(acceptor, listener); let (tx, rx) = tokio::sync::oneshot::channel(); - tokio::spawn( - Server::builder(listener) - .serve(make_svc) - .with_graceful_shutdown(async { rx.await.unwrap_or_default() }), - ); + tokio::spawn(async move { + let server = auto::Builder::new(TokioExecutor::new()); + let shutdown = Shutdown::new(async { rx.await.unwrap_or_default() }); + let guard = shutdown.guard_weak(); + + loop { + tokio::select! { + res = listener.accept() => { + let Ok((tcp, _)) = res else { + continue; + }; + + let server = server.clone(); + + shutdown.spawn_task(async move { + server + .serve_connection_with_upgrades(TokioIo::new(tcp), service_fn(test_server)) + .await + .unwrap(); + }); + } + _ = guard.cancelled() => { + break; + } + } + } + + shutdown.shutdown().await; + }); Ok((addr, tx)) } -pub fn http_client() -> Client { - Client::new() +pub fn http_client() -> Client { + Client::builder(TokioExecutor::new()).build_http() } pub fn plain_websocket_connector() -> tokio_tungstenite::Connector { @@ -125,23 +167,18 @@ fn rustls_client_config() -> rustls::ClientConfig { let mut roots = rustls::RootCertStore::empty(); for cert in rustls_native_certs::load_native_certs().unwrap() { - let cert = rustls::Certificate(cert.0); - roots.add(&cert).unwrap(); + roots.add(cert.clone()).unwrap(); } let mut ca_cert_bytes: &[u8] = include_bytes!("../../examples/ca/hudsucker.cer"); - let ca_cert = rustls::Certificate( - pemfile::certs(&mut ca_cert_bytes) - .next() - .unwrap() - .expect("Failed to parse CA certificate") - .to_vec(), - ); + let ca_cert = pemfile::certs(&mut ca_cert_bytes) + .next() + .unwrap() + .expect("Failed to parse CA certificate"); - roots.add(&ca_cert).unwrap(); + roots.add(ca_cert).unwrap(); rustls::ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(roots) .with_no_client_auth() } @@ -150,14 +187,14 @@ pub fn rustls_websocket_connector() -> tokio_tungstenite::Connector { tokio_tungstenite::Connector::Rustls(Arc::new(rustls_client_config())) } -pub fn rustls_client() -> Client> { +pub fn rustls_client() -> Client, Body> { let https = hyper_rustls::HttpsConnectorBuilder::new() .with_tls_config(rustls_client_config()) .https_or_http() .enable_http1() .build(); - Client::builder() + Client::builder(TokioExecutor::new()) .http1_title_case_headers(true) .http1_preserve_header_case(true) .build(https) @@ -178,48 +215,48 @@ pub fn native_tls_websocket_connector() -> tokio_tungstenite::Connector { tokio_tungstenite::Connector::NativeTls(native_tls_connector()) } -pub fn native_tls_client() -> Client> { +pub fn native_tls_client() -> Client, Body> { let mut http = HttpConnector::new(); http.enforce_http(false); let tls = native_tls_connector().into(); let https: hyper_tls::HttpsConnector = (http, tls).into(); - Client::builder().build(https) + Client::builder(TokioExecutor::new()).build(https) } -pub fn start_proxy( +pub async fn start_proxy( ca: impl CertificateAuthority, - client: Client, + client: Client, websocket_connector: tokio_tungstenite::Connector, ) -> Result<(SocketAddr, TestHandler, Sender<()>), Box> where C: Connect + Clone + Send + Sync + 'static, { - _start_proxy(ca, client, websocket_connector, true) + _start_proxy(ca, client, websocket_connector, true).await } -pub fn start_proxy_without_intercept( +pub async fn start_proxy_without_intercept( ca: impl CertificateAuthority, - client: Client, + client: Client, websocket_connector: tokio_tungstenite::Connector, ) -> Result<(SocketAddr, TestHandler, Sender<()>), Box> where C: Connect + Clone + Send + Sync + 'static, { - _start_proxy(ca, client, websocket_connector, false) + _start_proxy(ca, client, websocket_connector, false).await } -fn _start_proxy( +async fn _start_proxy( ca: impl CertificateAuthority, - client: Client, + client: Client, websocket_connector: tokio_tungstenite::Connector, should_intercept: bool, ) -> Result<(SocketAddr, TestHandler, Sender<()>), Box> where C: Connect + Clone + Send + Sync + 'static, { - let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; + let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?; let addr = listener.local_addr()?; let (tx, rx) = tokio::sync::oneshot::channel(); @@ -232,19 +269,19 @@ where .with_http_handler(handler.clone()) .with_websocket_handler(handler.clone()) .with_websocket_connector(websocket_connector) + .with_graceful_shutdown(async { + rx.await.unwrap_or_default(); + }) .build(); - tokio::spawn(proxy.start(async { - rx.await.unwrap_or_default(); - })); - + tokio::spawn(proxy.start()); Ok((addr, handler, tx)) } -pub fn start_noop_proxy( +pub async fn start_noop_proxy( ca: impl CertificateAuthority, ) -> Result<(SocketAddr, Sender<()>), Box> { - let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))?; + let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?; let addr = listener.local_addr()?; let (tx, rx) = tokio::sync::oneshot::channel(); @@ -252,12 +289,12 @@ pub fn start_noop_proxy( .with_listener(listener) .with_client(native_tls_client()) .with_ca(ca) + .with_graceful_shutdown(async { + rx.await.unwrap_or_default(); + }) .build(); - tokio::spawn(proxy.start(async { - rx.await.unwrap_or_default(); - })); - + tokio::spawn(proxy.start()); Ok((addr, tx)) } @@ -294,7 +331,6 @@ impl TestHandler { } } -#[async_trait] impl HttpHandler for TestHandler { async fn handle_request( &mut self, @@ -316,7 +352,6 @@ impl HttpHandler for TestHandler { } } -#[async_trait] impl WebSocketHandler for TestHandler { async fn handle_message(&mut self, _ctx: &WebSocketContext, msg: Message) -> Option { self.message_counter.fetch_add(1, Ordering::Relaxed); diff --git a/tests/openssl_ca.rs b/tests/openssl_ca.rs index 5532185..a0348ef 100644 --- a/tests/openssl_ca.rs +++ b/tests/openssl_ca.rs @@ -23,6 +23,7 @@ async fn https_rustls() { common::rustls_client(), common::rustls_websocket_connector(), ) + .await .unwrap(); let (server_addr, stop_server) = common::start_https_server(build_ca()).await.unwrap(); @@ -49,6 +50,7 @@ async fn https_native_tls() { common::native_tls_client(), common::native_tls_websocket_connector(), ) + .await .unwrap(); let (server_addr, stop_server) = common::start_https_server(build_ca()).await.unwrap(); @@ -75,6 +77,7 @@ async fn without_intercept() { common::http_client(), common::plain_websocket_connector(), ) + .await .unwrap(); let (server_addr, stop_server) = common::start_https_server(build_ca()).await.unwrap(); @@ -101,9 +104,10 @@ async fn decodes_response() { common::native_tls_client(), common::native_tls_websocket_connector(), ) + .await .unwrap(); - let (server_addr, stop_server) = common::start_http_server().unwrap(); + let (server_addr, stop_server) = common::start_http_server().await.unwrap(); let client = common::build_client(&proxy_addr.to_string()); let res = client @@ -121,8 +125,8 @@ async fn decodes_response() { #[tokio::test] async fn noop() { - let (proxy_addr, stop_proxy) = common::start_noop_proxy(build_ca()).unwrap(); - let (server_addr, stop_server) = common::start_http_server().unwrap(); + let (proxy_addr, stop_proxy) = common::start_noop_proxy(build_ca()).await.unwrap(); + let (server_addr, stop_server) = common::start_http_server().await.unwrap(); let client = common::build_client(&proxy_addr.to_string()); let res = client diff --git a/tests/rcgen_ca.rs b/tests/rcgen_ca.rs index a417b7b..6764173 100644 --- a/tests/rcgen_ca.rs +++ b/tests/rcgen_ca.rs @@ -1,4 +1,4 @@ -use hudsucker::{certificate_authority::RcgenAuthority, rustls}; +use hudsucker::certificate_authority::RcgenAuthority; use rustls_pemfile as pemfile; use std::sync::atomic::Ordering; @@ -7,21 +7,13 @@ mod common; fn build_ca() -> RcgenAuthority { let mut private_key_bytes: &[u8] = include_bytes!("../examples/ca/hudsucker.key"); let mut ca_cert_bytes: &[u8] = include_bytes!("../examples/ca/hudsucker.cer"); - let private_key = rustls::PrivateKey( - pemfile::pkcs8_private_keys(&mut private_key_bytes) - .next() - .unwrap() - .expect("Failed to parse private key") - .secret_pkcs8_der() - .to_vec(), - ); - let ca_cert = rustls::Certificate( - pemfile::certs(&mut ca_cert_bytes) - .next() - .unwrap() - .expect("Failed to parse CA certificate") - .to_vec(), - ); + let private_key = pemfile::private_key(&mut private_key_bytes) + .unwrap() + .expect("Failed to parse private key"); + let ca_cert = pemfile::certs(&mut ca_cert_bytes) + .next() + .unwrap() + .expect("Failed to parse CA certificate"); RcgenAuthority::new(private_key, ca_cert, 1_000) .expect("Failed to create Certificate Authority") @@ -34,6 +26,7 @@ async fn https_rustls() { common::rustls_client(), common::rustls_websocket_connector(), ) + .await .unwrap(); let (server_addr, stop_server) = common::start_https_server(build_ca()).await.unwrap(); @@ -60,6 +53,7 @@ async fn https_native_tls() { common::native_tls_client(), common::native_tls_websocket_connector(), ) + .await .unwrap(); let (server_addr, stop_server) = common::start_https_server(build_ca()).await.unwrap(); @@ -86,6 +80,7 @@ async fn without_intercept() { common::http_client(), common::plain_websocket_connector(), ) + .await .unwrap(); let (server_addr, stop_server) = common::start_https_server(build_ca()).await.unwrap(); @@ -112,9 +107,10 @@ async fn decodes_response() { common::native_tls_client(), common::native_tls_websocket_connector(), ) + .await .unwrap(); - let (server_addr, stop_server) = common::start_http_server().unwrap(); + let (server_addr, stop_server) = common::start_http_server().await.unwrap(); let client = common::build_client(&proxy_addr.to_string()); let res = client @@ -132,8 +128,8 @@ async fn decodes_response() { #[tokio::test] async fn noop() { - let (proxy_addr, stop_proxy) = common::start_noop_proxy(build_ca()).unwrap(); - let (server_addr, stop_server) = common::start_http_server().unwrap(); + let (proxy_addr, stop_proxy) = common::start_noop_proxy(build_ca()).await.unwrap(); + let (server_addr, stop_server) = common::start_http_server().await.unwrap(); let client = common::build_client(&proxy_addr.to_string()); let res = client diff --git a/tests/websocket.rs b/tests/websocket.rs index d821beb..a4abe5b 100644 --- a/tests/websocket.rs +++ b/tests/websocket.rs @@ -1,8 +1,6 @@ use async_http_proxy::http_connect_tokio; use futures::{SinkExt, StreamExt}; -use hudsucker::{ - certificate_authority::RcgenAuthority, rustls, tokio_tungstenite::tungstenite::Message, -}; +use hudsucker::{certificate_authority::RcgenAuthority, tokio_tungstenite::tungstenite::Message}; use rustls_pemfile as pemfile; use std::sync::atomic::Ordering; use tokio::net::TcpStream; @@ -13,21 +11,13 @@ mod common; fn build_ca() -> RcgenAuthority { let mut private_key_bytes: &[u8] = include_bytes!("../examples/ca/hudsucker.key"); let mut ca_cert_bytes: &[u8] = include_bytes!("../examples/ca/hudsucker.cer"); - let private_key = rustls::PrivateKey( - pemfile::pkcs8_private_keys(&mut private_key_bytes) - .next() - .unwrap() - .expect("Failed to parse private key") - .secret_pkcs8_der() - .to_vec(), - ); - let ca_cert = rustls::Certificate( - pemfile::certs(&mut ca_cert_bytes) - .next() - .unwrap() - .expect("Failed to parse CA certificate") - .to_vec(), - ); + let private_key = pemfile::private_key(&mut private_key_bytes) + .unwrap() + .expect("Failed to parse private key"); + let ca_cert = pemfile::certs(&mut ca_cert_bytes) + .next() + .unwrap() + .expect("Failed to parse CA certificate"); RcgenAuthority::new(private_key, ca_cert, 1_000) .expect("Failed to create Certificate Authority") @@ -40,9 +30,10 @@ async fn http() { common::native_tls_client(), common::native_tls_websocket_connector(), ) + .await .unwrap(); - let (server_addr, stop_server) = common::start_http_server().unwrap(); + let (server_addr, stop_server) = common::start_http_server().await.unwrap(); let mut stream = TcpStream::connect(proxy_addr).await.unwrap(); http_connect_tokio( @@ -75,6 +66,7 @@ async fn https_rustls() { common::rustls_client(), common::rustls_websocket_connector(), ) + .await .unwrap(); let (server_addr, stop_server) = common::start_https_server(build_ca()).await.unwrap(); @@ -111,6 +103,7 @@ async fn https_native_tls() { common::native_tls_client(), common::native_tls_websocket_connector(), ) + .await .unwrap(); let (server_addr, stop_server) = common::start_https_server(build_ca()).await.unwrap(); @@ -147,9 +140,10 @@ async fn without_intercept() { common::http_client(), common::plain_websocket_connector(), ) + .await .unwrap(); - let (server_addr, stop_server) = common::start_http_server().unwrap(); + let (server_addr, stop_server) = common::start_http_server().await.unwrap(); let mut stream = TcpStream::connect(proxy_addr).await.unwrap(); http_connect_tokio( @@ -177,8 +171,8 @@ async fn without_intercept() { #[tokio::test] async fn noop() { - let (proxy_addr, stop_proxy) = common::start_noop_proxy(build_ca()).unwrap(); - let (server_addr, stop_server) = common::start_http_server().unwrap(); + let (proxy_addr, stop_proxy) = common::start_noop_proxy(build_ca()).await.unwrap(); + let (server_addr, stop_server) = common::start_http_server().await.unwrap(); let mut stream = TcpStream::connect(proxy_addr).await.unwrap(); http_connect_tokio(