diff --git a/Cargo.toml b/Cargo.toml index bcc341b..cb3989b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,20 +24,21 @@ futures = "0.3.11" http = "1.1.0" http-body-util = "0.1.0" hyper = "1.1.0" -hyper-rustls = { version = "0.26.0", default-features = false, features = ["http1", "logging", "ring", "tls12", "webpki-tokio"], optional = true } +hyper-rustls = { version = "0.27.0", default-features = false, features = ["http1", "logging", "tls12", "webpki-tokio"], optional = true } hyper-tls = { version = "0.6.0", optional = true } -hyper-tungstenite = "0.13.0" +hyper-tungstenite = "0.15.0" hyper-util = { version="0.1.3", features = ["client-legacy", "server", "http1"] } moka = { version = "0.12.0", features = ["future"], optional = true } openssl = { version = "0.10.46", optional = true } rand = { version = "0.8.0", optional = true } rcgen = { version = "0.13.0", features = ["x509-parser"], optional = true } thiserror = "1.0.30" -time = { version = "0.3.20", optional = true } +time = { version = "0.3.35", optional = true } tokio = { version = "1.24.2", features = ["macros", "rt"] } tokio-graceful = "0.1.6" -tokio-rustls = "0.25.0" -tokio-tungstenite = "0.21.0" +tokio-native-tls = { version = "0.3.1", optional = true } +tokio-rustls = { version = "0.26.0", features = ["logging", "tls12"] } +tokio-tungstenite = "0.24.0" tokio-util = { version = "0.7.1", features = ["io"], optional = true } tracing = { version = "0.1.35", features = ["log"] } @@ -45,7 +46,7 @@ tracing = { version = "0.1.35", features = ["log"] } async-http-proxy = { version = "1.2.5", features = ["runtime-tokio"] } criterion = { version = "0.5.0", features = ["async_tokio"] } reqwest = "0.12.0" -rustls-native-certs = "0.7.0" +rustls-native-certs = "0.8.0" rustls-pemfile = "2.0.0" tokio = { version = "1.24.2", features = ["full"] } tokio-native-tls = "0.3.1" @@ -57,7 +58,7 @@ decoder = ["dep:async-compression", "dep:tokio-util", "tokio/io-util"] default = ["decoder", "rcgen-ca", "rustls-client"] full = ["decoder", "http2", "native-tls-client", "openssl-ca", "rcgen-ca", "rustls-client"] http2 = ["hyper-util/http2", "hyper-rustls?/http2"] -native-tls-client = ["dep:hyper-tls", "tokio-tungstenite/native-tls"] +native-tls-client = ["dep:hyper-tls", "dep:tokio-native-tls", "tokio-tungstenite/native-tls"] openssl-ca = ["dep:openssl", "dep:moka"] rcgen-ca = ["dep:rcgen", "dep:moka", "dep:time", "dep:rand"] rustls-client = ["dep:hyper-rustls", "tokio-tungstenite/rustls-tls-webpki-roots"] diff --git a/benches/certificate_authorities.rs b/benches/certificate_authorities.rs index 0323ad1..4f598bf 100644 --- a/benches/certificate_authorities.rs +++ b/benches/certificate_authorities.rs @@ -4,6 +4,7 @@ use hudsucker::{ certificate_authority::{CertificateAuthority, OpensslAuthority, RcgenAuthority}, openssl::{hash::MessageDigest, pkey::PKey, x509::X509}, rcgen::{CertificateParams, KeyPair}, + rustls::crypto::aws_lc_rs, }; fn runtime() -> tokio::runtime::Runtime { @@ -21,7 +22,7 @@ fn build_rcgen_ca(cache_size: u64) -> RcgenAuthority { .self_signed(&key_pair) .expect("Failed to sign CA certificate"); - RcgenAuthority::new(key_pair, ca_cert, cache_size) + RcgenAuthority::new(key_pair, ca_cert, cache_size, aws_lc_rs::default_provider()) } fn build_openssl_ca(cache_size: u64) -> OpensslAuthority { @@ -30,7 +31,13 @@ fn build_openssl_ca(cache_size: u64) -> OpensslAuthority { let private_key = PKey::private_key_from_pem(private_key).expect("Failed to parse private key"); let ca_cert = X509::from_pem(ca_cert).expect("Failed to parse CA certificate"); - OpensslAuthority::new(private_key, ca_cert, MessageDigest::sha256(), cache_size) + OpensslAuthority::new( + private_key, + ca_cert, + MessageDigest::sha256(), + cache_size, + aws_lc_rs::default_provider(), + ) } fn compare_cas(c: &mut Criterion) { diff --git a/benches/proxy.rs b/benches/proxy.rs index fe0f0bf..9b1fd1b 100644 --- a/benches/proxy.rs +++ b/benches/proxy.rs @@ -8,6 +8,7 @@ use hudsucker::{ server::conn::auto, }, rcgen::{CertificateParams, KeyPair}, + rustls::crypto::aws_lc_rs, Body, Proxy, }; use reqwest::Certificate; @@ -33,7 +34,7 @@ fn build_ca() -> RcgenAuthority { .self_signed(&key_pair) .expect("Failed to sign CA certificate"); - RcgenAuthority::new(key_pair, ca_cert, 1000) + RcgenAuthority::new(key_pair, ca_cert, 1000, aws_lc_rs::default_provider()) } async fn test_server(req: Request) -> Result, Infallible> { @@ -145,12 +146,13 @@ async fn start_proxy( let (tx, rx) = tokio::sync::oneshot::channel(); let proxy = Proxy::builder() .with_listener(listener) - .with_client(native_tls_client()) .with_ca(ca) + .with_client(native_tls_client()) .with_graceful_shutdown(async { rx.await.unwrap_or_default(); }) - .build(); + .build() + .expect("Failed to create proxy"); tokio::spawn(proxy.start()); diff --git a/examples/log.rs b/examples/log.rs index be126e6..e57f483 100644 --- a/examples/log.rs +++ b/examples/log.rs @@ -2,6 +2,7 @@ use hudsucker::{ certificate_authority::RcgenAuthority, hyper::{Request, Response}, rcgen::{CertificateParams, KeyPair}, + rustls::crypto::aws_lc_rs, tokio_tungstenite::tungstenite::Message, *, }; @@ -52,16 +53,17 @@ async fn main() { .self_signed(&key_pair) .expect("Failed to sign CA certificate"); - let ca = RcgenAuthority::new(key_pair, ca_cert, 1_000); + let ca = RcgenAuthority::new(key_pair, ca_cert, 1_000, aws_lc_rs::default_provider()); let proxy = Proxy::builder() .with_addr(SocketAddr::from(([127, 0, 0, 1], 3000))) - .with_rustls_client() .with_ca(ca) + .with_rustls_client(aws_lc_rs::default_provider()) .with_http_handler(LogHandler) .with_websocket_handler(LogHandler) .with_graceful_shutdown(shutdown_signal()) - .build(); + .build() + .expect("Failed to create proxy"); if let Err(e) = proxy.start().await { error!("{}", e); diff --git a/examples/noop.rs b/examples/noop.rs index 1367877..ab13f28 100644 --- a/examples/noop.rs +++ b/examples/noop.rs @@ -1,6 +1,7 @@ use hudsucker::{ certificate_authority::RcgenAuthority, rcgen::{CertificateParams, KeyPair}, + rustls::crypto::aws_lc_rs, *, }; use std::net::SocketAddr; @@ -24,14 +25,15 @@ async fn main() { .self_signed(&key_pair) .expect("Failed to sign CA certificate"); - let ca = RcgenAuthority::new(key_pair, ca_cert, 1_000); + let ca = RcgenAuthority::new(key_pair, ca_cert, 1_000, aws_lc_rs::default_provider()); let proxy = Proxy::builder() .with_addr(SocketAddr::from(([127, 0, 0, 1], 3000))) - .with_rustls_client() .with_ca(ca) + .with_rustls_client(aws_lc_rs::default_provider()) .with_graceful_shutdown(shutdown_signal()) - .build(); + .build() + .expect("Failed to create proxy"); if let Err(e) = proxy.start().await { error!("{}", e); diff --git a/examples/openssl.rs b/examples/openssl.rs index 86d34a4..e364c69 100644 --- a/examples/openssl.rs +++ b/examples/openssl.rs @@ -2,6 +2,7 @@ use hudsucker::{ certificate_authority::OpensslAuthority, hyper::{Request, Response}, openssl::{hash::MessageDigest, pkey::PKey, x509::X509}, + rustls::crypto::aws_lc_rs, tokio_tungstenite::tungstenite::Message, *, }; @@ -50,15 +51,22 @@ async fn main() { PKey::private_key_from_pem(private_key_bytes).expect("Failed to parse private key"); let ca_cert = X509::from_pem(ca_cert_bytes).expect("Failed to parse CA certificate"); - let ca = OpensslAuthority::new(private_key, ca_cert, MessageDigest::sha256(), 1_000); + let ca = OpensslAuthority::new( + private_key, + ca_cert, + MessageDigest::sha256(), + 1_000, + aws_lc_rs::default_provider(), + ); let proxy = Proxy::builder() .with_addr(SocketAddr::from(([127, 0, 0, 1], 3000))) - .with_rustls_client() .with_ca(ca) + .with_rustls_client(aws_lc_rs::default_provider()) .with_http_handler(LogHandler) .with_graceful_shutdown(shutdown_signal()) - .build(); + .build() + .expect("Failed to create proxy"); if let Err(e) = proxy.start().await { error!("{}", e); diff --git a/rustfmt.toml b/rustfmt.toml index efb2974..51f4713 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,2 +1,3 @@ +format_code_in_doc_comments = true imports_granularity = "Crate" newline_style = "Unix" diff --git a/src/certificate_authority/openssl_authority.rs b/src/certificate_authority/openssl_authority.rs index d04855e..7c87997 100644 --- a/src/certificate_authority/openssl_authority.rs +++ b/src/certificate_authority/openssl_authority.rs @@ -15,6 +15,7 @@ use std::{ time::{Duration, SystemTime}, }; use tokio_rustls::rustls::{ + crypto::CryptoProvider, pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}, ServerConfig, }; @@ -32,6 +33,7 @@ use tracing::debug; /// use hudsucker::{ /// certificate_authority::OpensslAuthority, /// openssl::{hash::MessageDigest, pkey::PKey, x509::X509}, +/// rustls::crypto::aws_lc_rs, /// }; /// /// let private_key_bytes: &[u8] = include_bytes!("../../examples/ca/hudsucker.key"); @@ -39,7 +41,13 @@ use tracing::debug; /// let private_key = PKey::private_key_from_pem(private_key_bytes).unwrap(); /// let ca_cert = X509::from_pem(ca_cert_bytes).unwrap(); /// -/// let ca = OpensslAuthority::new(private_key, ca_cert, MessageDigest::sha256(), 1_000); +/// let ca = OpensslAuthority::new( +/// private_key, +/// ca_cert, +/// MessageDigest::sha256(), +/// 1_000, +/// aws_lc_rs::default_provider(), +/// ); /// ``` #[cfg_attr(docsrs, doc(cfg(feature = "openssl-ca")))] pub struct OpensslAuthority { @@ -48,11 +56,18 @@ pub struct OpensslAuthority { ca_cert: X509, hash: MessageDigest, cache: Cache>, + provider: Arc, } impl OpensslAuthority { /// Creates a new openssl authority. - pub fn new(pkey: PKey, ca_cert: X509, hash: MessageDigest, cache_size: u64) -> Self { + pub fn new( + pkey: PKey, + ca_cert: X509, + hash: MessageDigest, + cache_size: u64, + provider: CryptoProvider, + ) -> Self { let private_key = PrivateKeyDer::from(PrivatePkcs8KeyDer::from( pkey.private_key_to_pkcs8() .expect("Failed to encode private key"), @@ -67,6 +82,7 @@ impl OpensslAuthority { .max_capacity(cache_size) .time_to_live(Duration::from_secs(CACHE_TTL)) .build(), + provider: Arc::new(provider), } } @@ -120,7 +136,9 @@ impl CertificateAuthority for OpensslAuthority { .gen_cert(authority) .unwrap_or_else(|_| panic!("Failed to generate certificate for {}", authority))]; - let mut server_cfg = ServerConfig::builder() + let mut server_cfg = ServerConfig::builder_with_provider(Arc::clone(&self.provider)) + .with_safe_default_protocol_versions() + .expect("Failed to specify protocol versions") .with_no_client_auth() .with_single_cert(certs, self.private_key.clone_key()) .expect("Failed to build ServerConfig"); @@ -144,6 +162,7 @@ impl CertificateAuthority for OpensslAuthority { #[cfg(test)] mod tests { use super::*; + use tokio_rustls::rustls::crypto::aws_lc_rs; fn build_ca(cache_size: u64) -> OpensslAuthority { let private_key_bytes: &[u8] = include_bytes!("../../examples/ca/hudsucker.key"); @@ -152,7 +171,13 @@ mod tests { PKey::private_key_from_pem(private_key_bytes).expect("Failed to parse private key"); let ca_cert = X509::from_pem(ca_cert_bytes).expect("Failed to parse CA certificate"); - OpensslAuthority::new(private_key, ca_cert, MessageDigest::sha256(), cache_size) + OpensslAuthority::new( + private_key, + ca_cert, + MessageDigest::sha256(), + cache_size, + aws_lc_rs::default_provider(), + ) } #[test] diff --git a/src/certificate_authority/rcgen_authority.rs b/src/certificate_authority/rcgen_authority.rs index 21ad3be..84220e7 100644 --- a/src/certificate_authority/rcgen_authority.rs +++ b/src/certificate_authority/rcgen_authority.rs @@ -8,6 +8,7 @@ use rcgen::{ use std::sync::Arc; use time::{Duration, OffsetDateTime}; use tokio_rustls::rustls::{ + crypto::CryptoProvider, pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}, ServerConfig, }; @@ -22,7 +23,7 @@ use tracing::debug; /// # Examples /// /// ```rust -/// use hudsucker::{certificate_authority::RcgenAuthority, rustls}; +/// use hudsucker::{certificate_authority::RcgenAuthority, rustls::crypto::aws_lc_rs}; /// use rcgen::{CertificateParams, KeyPair}; /// /// let key_pair = include_str!("../../examples/ca/hudsucker.key"); @@ -33,7 +34,7 @@ use tracing::debug; /// .self_signed(&key_pair) /// .expect("Failed to sign CA certificate"); /// -/// let ca = RcgenAuthority::new(key_pair, ca_cert, 1_000); +/// let ca = RcgenAuthority::new(key_pair, ca_cert, 1_000, aws_lc_rs::default_provider()); /// ``` #[cfg_attr(docsrs, doc(cfg(feature = "rcgen-ca")))] pub struct RcgenAuthority { @@ -41,11 +42,17 @@ pub struct RcgenAuthority { ca_cert: Certificate, private_key: PrivateKeyDer<'static>, cache: Cache>, + provider: Arc, } impl RcgenAuthority { /// Creates a new rcgen authority. - pub fn new(key_pair: KeyPair, ca_cert: Certificate, cache_size: u64) -> Self { + pub fn new( + key_pair: KeyPair, + ca_cert: Certificate, + cache_size: u64, + provider: CryptoProvider, + ) -> Self { let private_key = PrivateKeyDer::from(PrivatePkcs8KeyDer::from(key_pair.serialize_der())); Self { @@ -56,6 +63,7 @@ impl RcgenAuthority { .max_capacity(cache_size) .time_to_live(std::time::Duration::from_secs(CACHE_TTL)) .build(), + provider: Arc::new(provider), } } @@ -92,7 +100,9 @@ impl CertificateAuthority for RcgenAuthority { let certs = vec![self.gen_cert(authority)]; - let mut server_cfg = ServerConfig::builder() + let mut server_cfg = ServerConfig::builder_with_provider(Arc::clone(&self.provider)) + .with_safe_default_protocol_versions() + .expect("Failed to specify protocol versions") .with_no_client_auth() .with_single_cert(certs, self.private_key.clone_key()) .expect("Failed to build ServerConfig"); @@ -116,6 +126,7 @@ impl CertificateAuthority for RcgenAuthority { #[cfg(test)] mod tests { use super::*; + use tokio_rustls::rustls::crypto::aws_lc_rs; fn build_ca(cache_size: u64) -> RcgenAuthority { let key_pair = include_str!("../../examples/ca/hudsucker.key"); @@ -126,7 +137,7 @@ mod tests { .self_signed(&key_pair) .expect("Failed to sign CA certificate"); - RcgenAuthority::new(key_pair, ca_cert, cache_size) + RcgenAuthority::new(key_pair, ca_cert, cache_size, aws_lc_rs::default_provider()) } #[test] diff --git a/src/error.rs b/src/error.rs index a14098e..f511327 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,5 +1,14 @@ use thiserror::Error; +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum BuilderError { + #[error("{0}")] + NativeTls(#[from] hyper_tls::native_tls::Error), + #[error("{0}")] + Rustls(#[from] tokio_rustls::rustls::Error), +} + #[derive(Debug, Error)] #[non_exhaustive] pub enum Error { @@ -9,6 +18,8 @@ pub enum Error { Io(#[from] std::io::Error), #[error("unable to decode body")] Decode, + #[error("builder error")] + Builder(#[from] BuilderError), #[error("unknown error")] Unknown, } diff --git a/src/proxy/builder.rs b/src/proxy/builder.rs index f773122..2809e42 100644 --- a/src/proxy/builder.rs +++ b/src/proxy/builder.rs @@ -1,11 +1,7 @@ use crate::{ - certificate_authority::CertificateAuthority, Body, HttpHandler, NoopHandler, Proxy, - WebSocketHandler, + certificate_authority::CertificateAuthority, error::BuilderError, Body, Error, HttpHandler, + NoopHandler, Proxy, WebSocketHandler, }; -#[cfg(feature = "rustls-client")] -use hyper_rustls::{HttpsConnector as RustlsConnector, HttpsConnectorBuilder}; -#[cfg(feature = "native-tls-client")] -use hyper_tls::HttpsConnector as NativeTlsConnector; use hyper_util::{ client::legacy::{ connect::{Connect, HttpConnector}, @@ -20,6 +16,7 @@ use std::{ sync::Arc, }; use tokio::net::TcpListener; +use tokio_rustls::rustls::{crypto::CryptoProvider, ClientConfig}; use tokio_tungstenite::Connector; /// A builder for creating a [`Proxy`]. @@ -33,6 +30,7 @@ use tokio_tungstenite::Connector; /// # use hudsucker::{ /// # certificate_authority::RcgenAuthority, /// # rcgen::{CertificateParams, KeyPair}, +/// # rustls::crypto::aws_lc_rs, /// # }; /// # /// # let key_pair = include_str!("../../examples/ca/hudsucker.key"); @@ -43,14 +41,14 @@ use tokio_tungstenite::Connector; /// # .self_signed(&key_pair) /// # .expect("Failed to sign CA certificate"); /// # -/// # let ca = RcgenAuthority::new(key_pair, ca_cert, 1_000); +/// # let ca = RcgenAuthority::new(key_pair, ca_cert, 1_000, aws_lc_rs::default_provider()); /// /// // let ca = ...; /// /// let proxy = Proxy::builder() /// .with_addr(std::net::SocketAddr::from(([127, 0, 0, 1], 0))) -/// .with_rustls_client() /// .with_ca(ca) +/// .with_rustls_client(aws_lc_rs::default_provider()) /// .build(); /// # } /// ``` @@ -74,15 +72,15 @@ impl ProxyBuilder { } /// Set the address to listen on. - pub fn with_addr(self, addr: SocketAddr) -> ProxyBuilder { - ProxyBuilder(WantsClient { + pub fn with_addr(self, addr: SocketAddr) -> ProxyBuilder { + ProxyBuilder(WantsCa { al: AddrOrListener::Addr(addr), }) } /// Set a listener to use for the proxy server. - pub fn with_listener(self, listener: TcpListener) -> ProxyBuilder { - ProxyBuilder(WantsClient { + pub fn with_listener(self, listener: TcpListener) -> ProxyBuilder { + ProxyBuilder(WantsCa { al: AddrOrListener::Listener(listener), }) } @@ -94,19 +92,57 @@ impl Default for ProxyBuilder { } } +/// Builder state that needs a certificate authority. +#[derive(Debug)] +pub struct WantsCa { + al: AddrOrListener, +} + +impl ProxyBuilder { + /// Set the certificate authority to use. + pub fn with_ca(self, ca: CA) -> ProxyBuilder> { + ProxyBuilder(WantsClient { al: self.0.al, ca }) + } +} + /// Builder state that needs a client. #[derive(Debug)] -pub struct WantsClient { +pub struct WantsClient { al: AddrOrListener, + ca: CA, } -impl ProxyBuilder { +impl ProxyBuilder> { /// Use a hyper-rustls connector. #[cfg(feature = "rustls-client")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-client")))] - pub fn with_rustls_client(self) -> ProxyBuilder>> { - let https = HttpsConnectorBuilder::new() - .with_webpki_roots() + pub fn with_rustls_client( + self, + provider: CryptoProvider, + ) -> ProxyBuilder>> + { + use hyper_rustls::ConfigBuilderExt; + + let rustls_config = match ClientConfig::builder_with_provider(Arc::new(provider)) + .with_safe_default_protocol_versions() + { + Ok(config) => config.with_webpki_roots().with_no_client_auth(), + Err(e) => { + return ProxyBuilder(WantsHandlers { + al: self.0.al, + ca: self.0.ca, + client: Err(BuilderError::from(e)), + http_handler: NoopHandler::new(), + websocket_handler: NoopHandler::new(), + websocket_connector: None, + server: None, + graceful_shutdown: pending(), + }); + } + }; + + let https = hyper_rustls::HttpsConnectorBuilder::new() + .with_tls_config(rustls_config.clone()) .https_or_http() .enable_http1(); @@ -115,12 +151,18 @@ impl ProxyBuilder { let https = https.build(); - ProxyBuilder(WantsCa { + ProxyBuilder(WantsHandlers { al: self.0.al, - client: Client::builder(TokioExecutor::new()) + ca: self.0.ca, + client: Ok(Client::builder(TokioExecutor::new()) .http1_title_case_headers(true) .http1_preserve_header_case(true) - .build(https), + .build(https)), + http_handler: NoopHandler::new(), + websocket_handler: NoopHandler::new(), + websocket_connector: Some(Connector::Rustls(Arc::new(rustls_config))), + server: None, + graceful_shutdown: pending(), }) } @@ -129,47 +171,54 @@ impl ProxyBuilder { #[cfg_attr(docsrs, doc(cfg(feature = "native-tls-client")))] pub fn with_native_tls_client( self, - ) -> ProxyBuilder>> { - let https = NativeTlsConnector::new(); + ) -> ProxyBuilder>> + { + let tls_connector = match hyper_tls::native_tls::TlsConnector::new() { + Ok(tls_connector) => tls_connector, + Err(e) => { + return ProxyBuilder(WantsHandlers { + al: self.0.al, + ca: self.0.ca, + client: Err(BuilderError::from(e)), + http_handler: NoopHandler::new(), + websocket_handler: NoopHandler::new(), + websocket_connector: None, + server: None, + graceful_shutdown: pending(), + }) + } + }; - ProxyBuilder(WantsCa { + let tokio_tls_connector = tokio_native_tls::TlsConnector::from(tls_connector.clone()); + let https = hyper_tls::HttpsConnector::from((HttpConnector::new(), tokio_tls_connector)); + + ProxyBuilder(WantsHandlers { al: self.0.al, - client: Client::builder(TokioExecutor::new()) + ca: self.0.ca, + client: Ok(Client::builder(TokioExecutor::new()) .http1_title_case_headers(true) .http1_preserve_header_case(true) - .build(https), + .build(https)), + http_handler: NoopHandler::new(), + websocket_handler: NoopHandler::new(), + websocket_connector: Some(Connector::NativeTls(tls_connector)), + server: None, + graceful_shutdown: pending(), }) } /// Use a custom client. - pub fn with_client(self, client: Client) -> ProxyBuilder> + pub fn with_client( + self, + client: Client, + ) -> ProxyBuilder>> where C: Connect + Clone + Send + Sync + 'static, { - ProxyBuilder(WantsCa { - al: self.0.al, - client, - }) - } -} - -/// Builder state that needs a certificate authority. -#[derive(Debug)] -pub struct WantsCa { - al: AddrOrListener, - client: Client, -} - -impl ProxyBuilder> { - /// Set the certificate authority to use. - pub fn with_ca( - self, - ca: CA, - ) -> ProxyBuilder>> { ProxyBuilder(WantsHandlers { al: self.0.al, - client: self.0.client, - ca, + ca: self.0.ca, + client: Ok(client), http_handler: NoopHandler::new(), websocket_handler: NoopHandler::new(), websocket_connector: None, @@ -180,10 +229,10 @@ impl ProxyBuilder> { } /// Builder state that can take additional handlers. -pub struct WantsHandlers { +pub struct WantsHandlers { al: AddrOrListener, - client: Client, ca: CA, + client: Result, BuilderError>, http_handler: H, websocket_handler: W, websocket_connector: Option, @@ -191,16 +240,16 @@ pub struct WantsHandlers { graceful_shutdown: F, } -impl ProxyBuilder> { +impl ProxyBuilder> { /// Set the HTTP handler. pub fn with_http_handler( self, http_handler: H2, - ) -> ProxyBuilder> { + ) -> ProxyBuilder> { ProxyBuilder(WantsHandlers { al: self.0.al, - client: self.0.client, ca: self.0.ca, + client: self.0.client, http_handler, websocket_handler: self.0.websocket_handler, websocket_connector: self.0.websocket_connector, @@ -213,11 +262,11 @@ impl ProxyBuilder> { pub fn with_websocket_handler( self, websocket_handler: W2, - ) -> ProxyBuilder> { + ) -> ProxyBuilder> { ProxyBuilder(WantsHandlers { al: self.0.al, - client: self.0.client, ca: self.0.ca, + client: self.0.client, http_handler: self.0.http_handler, websocket_handler, websocket_connector: self.0.websocket_connector, @@ -246,11 +295,11 @@ impl ProxyBuilder> { pub fn with_graceful_shutdown + Send + 'static>( self, graceful_shutdown: F2, - ) -> ProxyBuilder> { + ) -> ProxyBuilder> { ProxyBuilder(WantsHandlers { al: self.0.al, - client: self.0.client, ca: self.0.ca, + client: self.0.client, http_handler: self.0.http_handler, websocket_handler: self.0.websocket_handler, websocket_connector: self.0.websocket_connector, @@ -260,16 +309,16 @@ impl ProxyBuilder> { } /// Build the proxy. - pub fn build(self) -> Proxy { - Proxy { + pub fn build(self) -> Result, Error> { + Ok(Proxy { al: self.0.al, - client: self.0.client, ca: Arc::new(self.0.ca), + client: self.0.client?, http_handler: self.0.http_handler, websocket_handler: self.0.websocket_handler, websocket_connector: self.0.websocket_connector, server: self.0.server, graceful_shutdown: self.0.graceful_shutdown, - } + }) } } diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index db1c10a..03ae5e6 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -29,6 +29,7 @@ use tracing::error; /// # use hudsucker::{ /// # certificate_authority::RcgenAuthority, /// # rcgen::{CertificateParams, KeyPair}, +/// # rustls::crypto::aws_lc_rs, /// # }; /// # /// # #[cfg(all(feature = "rcgen-ca", feature = "rustls-client"))] @@ -42,7 +43,7 @@ use tracing::error; /// # .self_signed(&key_pair) /// # .expect("Failed to sign CA certificate"); /// # -/// # let ca = RcgenAuthority::new(key_pair, ca_cert, 1_000); +/// # let ca = RcgenAuthority::new(key_pair, ca_cert, 1_000, aws_lc_rs::default_provider()); /// /// // let ca = ...; /// @@ -50,12 +51,13 @@ use tracing::error; /// /// let proxy = Proxy::builder() /// .with_addr(std::net::SocketAddr::from(([127, 0, 0, 1], 0))) -/// .with_rustls_client() /// .with_ca(ca) +/// .with_rustls_client(aws_lc_rs::default_provider()) /// .with_graceful_shutdown(async { /// done.await.unwrap_or_default(); /// }) -/// .build(); +/// .build() +/// .expect("Failed to create proxy"); /// /// tokio::spawn(proxy.start()); /// diff --git a/tests/common/mod.rs b/tests/common/mod.rs index d86d3f8..9a3fb5e 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -256,15 +256,16 @@ where let proxy = Proxy::builder() .with_listener(listener) - .with_client(client) .with_ca(ca) + .with_client(client) .with_http_handler(handler.clone()) .with_websocket_handler(handler.clone()) .with_websocket_connector(websocket_connector) .with_graceful_shutdown(async { rx.await.unwrap_or_default(); }) - .build(); + .build() + .expect("Failed to create proxy"); tokio::spawn(proxy.start()); Ok((addr, handler, tx)) @@ -279,12 +280,13 @@ pub async fn start_noop_proxy( let proxy = Proxy::builder() .with_listener(listener) - .with_client(native_tls_client()) .with_ca(ca) + .with_client(native_tls_client()) .with_graceful_shutdown(async { rx.await.unwrap_or_default(); }) - .build(); + .build() + .expect("Failed to create proxy"); tokio::spawn(proxy.start()); Ok((addr, tx)) diff --git a/tests/openssl_ca.rs b/tests/openssl_ca.rs index a0348ef..f8ed5a0 100644 --- a/tests/openssl_ca.rs +++ b/tests/openssl_ca.rs @@ -1,6 +1,7 @@ use hudsucker::{ certificate_authority::OpensslAuthority, openssl::{hash::MessageDigest, pkey::PKey, x509::X509}, + rustls::crypto::aws_lc_rs, }; use std::sync::atomic::Ordering; @@ -13,7 +14,13 @@ fn build_ca() -> OpensslAuthority { PKey::private_key_from_pem(private_key_bytes).expect("Failed to parse private key"); let ca_cert = X509::from_pem(ca_cert_bytes).expect("Failed to parse CA certificate"); - OpensslAuthority::new(private_key, ca_cert, MessageDigest::sha256(), 1_000) + OpensslAuthority::new( + private_key, + ca_cert, + MessageDigest::sha256(), + 1_000, + aws_lc_rs::default_provider(), + ) } #[tokio::test] diff --git a/tests/rcgen_ca.rs b/tests/rcgen_ca.rs index be2832b..32c7b29 100644 --- a/tests/rcgen_ca.rs +++ b/tests/rcgen_ca.rs @@ -1,6 +1,7 @@ use hudsucker::{ certificate_authority::RcgenAuthority, rcgen::{CertificateParams, KeyPair}, + rustls::crypto::aws_lc_rs, }; use std::sync::atomic::Ordering; @@ -15,7 +16,7 @@ fn build_ca() -> RcgenAuthority { .self_signed(&key_pair) .expect("Failed to sign CA certificate"); - RcgenAuthority::new(key_pair, ca_cert, 1000) + RcgenAuthority::new(key_pair, ca_cert, 1000, aws_lc_rs::default_provider()) } #[tokio::test] diff --git a/tests/websocket.rs b/tests/websocket.rs index 63ad16d..c2fe1fd 100644 --- a/tests/websocket.rs +++ b/tests/websocket.rs @@ -3,6 +3,7 @@ use futures::{SinkExt, StreamExt}; use hudsucker::{ certificate_authority::RcgenAuthority, rcgen::{CertificateParams, KeyPair}, + rustls::crypto::aws_lc_rs, tokio_tungstenite::tungstenite::Message, }; use std::sync::atomic::Ordering; @@ -20,7 +21,7 @@ fn build_ca() -> RcgenAuthority { .self_signed(&key_pair) .expect("Failed to sign CA certificate"); - RcgenAuthority::new(key_pair, ca_cert, 1000) + RcgenAuthority::new(key_pair, ca_cert, 1000, aws_lc_rs::default_provider()) } #[tokio::test]