Skip to content

Commit

Permalink
Support custom server name resolution (#269)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel McCarney <[email protected]>
  • Loading branch information
sfackler and cpu authored Apr 10, 2024
1 parent 16d7e59 commit 7d31dc5
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 29 deletions.
92 changes: 72 additions & 20 deletions src/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub struct HttpsConnector<T> {
force_https: bool,
http: T,
tls_config: Arc<rustls::ClientConfig>,
override_server_name: Option<String>,
server_name_resolver: Arc<dyn ResolveServerName + Sync + Send>,
}

impl<T> HttpsConnector<T> {
Expand Down Expand Up @@ -90,24 +90,10 @@ where
};

let cfg = self.tls_config.clone();
let mut hostname = match self.override_server_name.as_deref() {
Some(h) => h,
None => dst.host().unwrap_or_default(),
};

// Remove square brackets around IPv6 address.
if let Some(trimmed) = hostname
.strip_prefix('[')
.and_then(|h| h.strip_suffix(']'))
{
hostname = trimmed;
}

let hostname = match ServerName::try_from(hostname) {
Ok(dns_name) => dns_name.to_owned(),
Err(_) => {
let err = io::Error::new(io::ErrorKind::Other, "invalid dnsname");
return Box::pin(async move { Err(Box::new(err).into()) });
let hostname = match self.server_name_resolver.resolve(&dst) {
Ok(hostname) => hostname,
Err(e) => {
return Box::pin(async move { Err(e) });
}
};

Expand Down Expand Up @@ -135,7 +121,7 @@ where
force_https: false,
http,
tls_config: cfg.into(),
override_server_name: None,
server_name_resolver: Arc::new(DefaultServerNameResolver::default()),
}
}
}
Expand All @@ -147,3 +133,69 @@ impl<T> fmt::Debug for HttpsConnector<T> {
.finish()
}
}

/// The default server name resolver, which uses the hostname in the URI.
#[derive(Default)]
pub struct DefaultServerNameResolver(());

impl ResolveServerName for DefaultServerNameResolver {
fn resolve(
&self,
uri: &Uri,
) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> {
let mut hostname = uri.host().unwrap_or_default();

// Remove square brackets around IPv6 address.
if let Some(trimmed) = hostname
.strip_prefix('[')
.and_then(|h| h.strip_suffix(']'))
{
hostname = trimmed;
}

ServerName::try_from(hostname.to_string()).map_err(|e| Box::new(e) as _)
}
}

/// A server name resolver which always returns the same fixed name.
pub struct FixedServerNameResolver {
name: ServerName<'static>,
}

impl FixedServerNameResolver {
/// Creates a new resolver returning the specified name.
pub fn new(name: ServerName<'static>) -> Self {
Self { name }
}
}

impl ResolveServerName for FixedServerNameResolver {
fn resolve(
&self,
_: &Uri,
) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> {
Ok(self.name.clone())
}
}

impl<F, E> ResolveServerName for F
where
F: Fn(&Uri) -> Result<ServerName<'static>, E>,
E: Into<Box<dyn std::error::Error + Sync + Send>>,
{
fn resolve(
&self,
uri: &Uri,
) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> {
self(uri).map_err(Into::into)
}
}

/// A trait implemented by types that can resolve a [`ServerName`] for a request.
pub trait ResolveServerName {
/// Maps a [`Uri`] into a [`ServerName`].
fn resolve(
&self,
uri: &Uri,
) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>>;
}
50 changes: 42 additions & 8 deletions src/connector/builder.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::sync::Arc;

use hyper_util::client::legacy::connect::HttpConnector;
#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))]
use rustls::crypto::CryptoProvider;
use rustls::ClientConfig;

use super::HttpsConnector;
use super::{DefaultServerNameResolver, HttpsConnector, ResolveServerName};
#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))]
use crate::config::ConfigBuilderExt;
use pki_types::ServerName;

/// A builder for an [`HttpsConnector`]
///
Expand Down Expand Up @@ -153,7 +156,7 @@ impl ConnectorBuilder<WantsSchemes> {
ConnectorBuilder(WantsProtocols1 {
tls_config: self.0.tls_config,
https_only: true,
override_server_name: None,
server_name_resolver: None,
})
}

Expand All @@ -165,7 +168,7 @@ impl ConnectorBuilder<WantsSchemes> {
ConnectorBuilder(WantsProtocols1 {
tls_config: self.0.tls_config,
https_only: false,
override_server_name: None,
server_name_resolver: None,
})
}
}
Expand All @@ -177,7 +180,7 @@ impl ConnectorBuilder<WantsSchemes> {
pub struct WantsProtocols1 {
tls_config: ClientConfig,
https_only: bool,
override_server_name: Option<String>,
server_name_resolver: Option<Arc<dyn ResolveServerName + Sync + Send>>,
}

impl WantsProtocols1 {
Expand All @@ -186,7 +189,9 @@ impl WantsProtocols1 {
force_https: self.https_only,
http: conn,
tls_config: std::sync::Arc::new(self.tls_config),
override_server_name: self.override_server_name,
server_name_resolver: self
.server_name_resolver
.unwrap_or_else(|| Arc::new(DefaultServerNameResolver::default())),
}
}

Expand Down Expand Up @@ -237,6 +242,22 @@ impl ConnectorBuilder<WantsProtocols1> {
})
}

/// Override server name for the TLS stack
///
/// By default, for each connection hyper-rustls will extract host portion
/// of the destination URL and verify that server certificate contains
/// this value.
///
/// If this method is called, hyper-rustls will instead use this resolver
/// to compute the value used to verify the server certificate.
pub fn with_server_name_resolver(
mut self,
resolver: impl ResolveServerName + 'static + Sync + Send,
) -> Self {
self.0.server_name_resolver = Some(Arc::new(resolver));
self
}

/// Override server name for the TLS stack
///
/// By default, for each connection hyper-rustls will extract host portion
Expand All @@ -246,9 +267,22 @@ impl ConnectorBuilder<WantsProtocols1> {
/// If this method is called, hyper-rustls will instead verify that server
/// certificate contains `override_server_name`. Domain name included in
/// the URL will not affect certificate validation.
pub fn with_server_name(mut self, override_server_name: String) -> Self {
self.0.override_server_name = Some(override_server_name);
self
#[deprecated(
since = "0.27.1",
note = "use Self::with_server_name_resolver with FixedServerNameResolver instead"
)]
pub fn with_server_name(self, mut override_server_name: String) -> Self {
// remove square brackets around IPv6 address.
if let Some(trimmed) = override_server_name
.strip_prefix('[')
.and_then(|s| s.strip_suffix(']'))
{
override_server_name = trimmed.to_string();
}

self.with_server_name_resolver(move |_: &_| {
ServerName::try_from(override_server_name.clone())
})
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ mod log {

pub use crate::config::ConfigBuilderExt;
pub use crate::connector::builder::ConnectorBuilder as HttpsConnectorBuilder;
pub use crate::connector::HttpsConnector;
pub use crate::connector::{
DefaultServerNameResolver, FixedServerNameResolver, HttpsConnector, ResolveServerName,
};
pub use crate::stream::MaybeHttpsStream;

/// The various states of the [`HttpsConnectorBuilder`]
Expand Down

0 comments on commit 7d31dc5

Please sign in to comment.