Skip to content

Commit

Permalink
simpler impl of dynamically swapping tls creds
Browse files Browse the repository at this point in the history
closes rwf2#2363
  • Loading branch information
hcldan committed Mar 12, 2024
1 parent 97992b6 commit d0a0a7f
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 15 deletions.
26 changes: 26 additions & 0 deletions core/http/src/tls/certificate_resolver.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use std::{io, sync::Arc};

use rustls::{server::ClientHello, sign::{any_supported_type, CertifiedKey}};

use crate::tls::Config;
use crate::tls::util::{load_certs, load_private_key};

pub(crate) struct CertResolver(Arc<CertifiedKey>);
impl CertResolver {
pub fn new<R>(config: &mut Config<R>) -> Result<Self, std::io::Error>
where R: io::BufRead,
{
let certs = load_certs(&mut config.cert_chain)?;
let private_key = load_private_key(&mut config.private_key)?;
let key = any_supported_type(&private_key)
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?;

Ok(Self(Arc::new(CertifiedKey::new(certs, key))))
}
}

impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
Some(self.0.clone())
}
}
22 changes: 11 additions & 11 deletions core/http/src/tls/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ use std::task::{Context, Poll};
use std::future::Future;
use std::net::SocketAddr;

use rustls::server::ResolvesServerCert;
use tokio::net::{TcpListener, TcpStream};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::{Accept, TlsAcceptor, server::TlsStream as BareTlsStream};

use crate::tls::util::{load_certs, load_private_key, load_ca_certs};
use crate::tls::util::load_ca_certs;
use crate::listener::{Connection, Listener, Certificates};
use crate::tls::CertResolver;

/// A TLS listener over TCP.
pub struct TlsListener {
Expand Down Expand Up @@ -72,18 +74,12 @@ pub struct Config<R> {
}

impl TlsListener {
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>) -> io::Result<TlsListener>
where R: io::BufRead
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>, cert_resolver: Option<&Arc<dyn ResolvesServerCert>>) -> io::Result<TlsListener>
where R: io::BufRead,
{
use rustls::server::{AllowAnyAuthenticatedClient, AllowAnyAnonymousOrAuthenticatedClient};
use rustls::server::{NoClientAuth, ServerSessionMemoryCache, ServerConfig};

let cert_chain = load_certs(&mut c.cert_chain)
.map_err(|e| io::Error::new(e.kind(), format!("bad TLS cert chain: {}", e)))?;

let key = load_private_key(&mut c.private_key)
.map_err(|e| io::Error::new(e.kind(), format!("bad TLS private key: {}", e)))?;

let client_auth = match c.ca_certs {
Some(ref mut ca_certs) => match load_ca_certs(ca_certs) {
Ok(ca) if c.mandatory_mtls => AllowAnyAuthenticatedClient::new(ca).boxed(),
Expand All @@ -93,14 +89,18 @@ impl TlsListener {
None => NoClientAuth::boxed(),
};

let cert_resolver = match cert_resolver {
Some(c) => c.clone(),
None => Arc::new(CertResolver::new(&mut c)?),
};

let mut tls_config = ServerConfig::builder()
.with_cipher_suites(&c.ciphersuites)
.with_safe_default_kx_groups()
.with_safe_default_protocol_versions()
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?
.with_client_cert_verifier(client_auth)
.with_single_cert(cert_chain, key)
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?;
.with_cert_resolver(cert_resolver);

tls_config.ignore_client_order = c.prefer_server_order;

Expand Down
3 changes: 3 additions & 0 deletions core/http/src/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ mod listener;
#[cfg(feature = "mtls")]
pub mod mtls;

pub(crate) mod certificate_resolver;

pub use rustls;
pub use listener::{TlsListener, Config};
pub(crate) use certificate_resolver::*;
pub mod util;
1 change: 1 addition & 0 deletions core/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,5 @@ version_check = "0.9.1"

[dev-dependencies]
figment = { version = "0.10", features = ["test"] }
reqwest = { version = "0.11", features = ["blocking"] }
pretty_assertions = "1"
7 changes: 5 additions & 2 deletions core/lib/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::sync::Arc;
use std::time::Duration;
use std::pin::Pin;

use yansi::Paint;
use tokio::sync::oneshot;
use yansi::Paint;
use tokio::time::sleep;
use futures::stream::StreamExt;
use futures::future::{FutureExt, Future, BoxFuture};
Expand Down Expand Up @@ -421,9 +421,12 @@ impl Rocket<Orbit> {
if self.config.tls_enabled() {
if let Some(ref config) = self.config.tls {
use crate::http::tls::TlsListener;
use crate::http::tls::rustls::server::ResolvesServerCert;

let conf = config.to_native_config().map_err(ErrorKind::Io)?;
let l = TlsListener::bind(addr, conf).await.map_err(ErrorKind::Bind)?;
let resolver = self.state::<Arc<dyn ResolvesServerCert>>();

let l = TlsListener::bind(addr, conf, resolver).await.map_err(ErrorKind::Bind)?;
addr = l.local_addr().unwrap_or(addr);
self.config.address = addr.ip();
self.config.port = addr.port();
Expand Down
72 changes: 70 additions & 2 deletions core/lib/tests/tls-config-from-source-1503.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ fn tls_config_from_source() {
use rocket::config::{Config, TlsConfig};
use rocket::figment::Figment;

let cert_path = relative!("examples/tls/private/cert.pem");
let key_path = relative!("examples/tls/private/key.pem");
let cert_path = relative!("../../examples/tls/private/cert.pem");
let key_path = relative!("../../examples/tls/private/key.pem");

let rocket_config = Config {
tls: Some(TlsConfig::from_paths(cert_path, key_path)),
Expand All @@ -24,3 +24,71 @@ fn tls_config_from_source() {
assert_eq!(tls.certs().unwrap_left(), cert_path);
assert_eq!(tls.key().unwrap_left(), key_path);
}

#[test]
fn tls_server_operation() {
use std::io::Read;

use rocket::{get, routes};
use rocket::config::{Config, TlsConfig};
use rocket::figment::Figment;

let cert_path = relative!("../../examples/tls/private/rsa_sha256_cert.pem");
let key_path = relative!("../../examples/tls/private/rsa_sha256_key.pem");
let ca_cert_path = relative!("../../examples/tls/private/ca_cert.pem");

println!("{cert_path:?}");

let port = {
let listener = std::net::TcpListener::bind(("127.0.0.1", 0)).expect("creating listener");
listener.local_addr().expect("getting listener's port").port()
};

let rocket_config = Config {
port,
tls: Some(TlsConfig::from_paths(cert_path, key_path)),
..Default::default()
};
let config: Config = Figment::from(rocket_config).extract().expect("creating config");
let (shutdown_signal_sender, mut shutdown_signal_receiver) = tokio::sync::mpsc::channel::<()>(1);

// Create a runtime in a separate thread for the server being tested
let join_handle = std::thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();

#[get("/hello")]
fn tls_test_get() -> &'static str {
"world"
}

rt.block_on(async {
let task_handle = tokio::spawn( async {
rocket::custom(config)
.mount("/", routes![tls_test_get])
.launch().await.unwrap();
});
shutdown_signal_receiver.recv().await;
task_handle.abort();
});
});

let request_url = format!("https://localhost:{}/hello", port);

// CA certificate is not loaded, so request should fail
assert!(reqwest::blocking::get(&request_url).is_err());

// Load the CA certicate for use with test client
let cert = {
let mut buf = Vec::new();
std::fs::File::open(ca_cert_path).expect("open ca_certs")
.read_to_end(&mut buf).expect("read ca_certs");
reqwest::Certificate::from_pem(&buf).expect("create certificate")
};
let client = reqwest::blocking::Client::builder().add_root_certificate(cert).build().expect("build client");

let response = client.get(&request_url).send().expect("https request");
assert_eq!(&response.text().unwrap(), "world");

shutdown_signal_sender.blocking_send(()).expect("signal shutdown");
join_handle.join().expect("join thread");
}

0 comments on commit d0a0a7f

Please sign in to comment.