Skip to content

Commit

Permalink
Introduce dynamic TLS resolvers.
Browse files Browse the repository at this point in the history
This commit introduces the ability to dynamically select a TLS
configuration based on the client's TLS hello.

Added `Authority::set_port()`.
Various `Config` structures for listeners removed.
`UdsListener` is now `UnixListener`.
`Bindable` removed in favor of new `Bind`.
`Connection` requires `AsyncRead + AsyncWrite` again
The `Debug` impl for `Endpoint` displays the underlying address in plaintext.
`Listener` must be `Sized`.
`tls` listener moved to `tls::TlsListener`
The preview `quic` listener no longer implements `Listener`.

All built-in listeners now implement `Bind<&Rocket>`.

Clarified docs for `mtls::Certificate` guard.

No reexporitng rustls from `tls`.
Added `TlsConfig::server_config()`.

Added some future helpers: `race()` and `race_io()`.

Fix an issue where the logger wouldn't respect a configuration during error
printing.

Added Rocket::launch_with(), launch_on(), bind_launch().

Added a default client.pem to the TLS example.

Revamped the testbench.

Added tests for TLS resolvers, MTLS, listener failure output.

TODO: clippy.
TODO: UDS testing.

Resolves #2730.
Resolves #2363.
Closes #2748.
Closes #2683.
Closes #2577.
  • Loading branch information
SergioBenitez committed Apr 17, 2024
1 parent 280fda4 commit dad8312
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 86 deletions.
8 changes: 4 additions & 4 deletions core/lib/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,16 @@ impl Error {
match self.kind() {
ErrorKind::Bind(ref a, ref e) => {
if let Some(e) = e.downcast_ref::<Self>() {
e.pretty_print();
e.pretty_print()
} else {
match a {
Some(a) => error!("Binding to {} failed.", a.primary().underline()),
None => error!("Binding to network interface failed."),
}
}

info_!("{}", e);
"aborting due to bind error"
info_!("{}", e);
"aborting due to bind error"
}
}
ErrorKind::Io(ref e) => {
error!("Rocket failed to launch due to an I/O error.");
Expand Down
107 changes: 75 additions & 32 deletions core/lib/src/listener/default.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
use core::fmt;

use serde::Deserialize;
use tokio_util::either::{Either, Either::{Left, Right}};
use futures::TryFutureExt;
use tokio_util::either::Either::{Left, Right};
use either::Either;

use crate::error::ErrorKind;
use crate::{Ignite, Rocket};
use crate::listener::{Bind, Endpoint, tcp::TcpListener};

#[cfg(unix)] use crate::listener::unix::UnixListener;
#[cfg(feature = "tls")] use crate::tls::{TlsListener, TlsConfig};

mod private {
use super::{Either, TcpListener};
use tokio_util::either::Either;

#[cfg(feature = "tls")] pub type TlsListener<T> = super::TlsListener<T>;
#[cfg(not(feature = "tls"))] pub type TlsListener<T> = T;
#[cfg(unix)] pub type UnixListener = super::UnixListener;
#[cfg(not(unix))] pub type UnixListener = super::TcpListener;

pub type Listener = Either<
Either<TlsListener<TcpListener>, TlsListener<UnixListener>>,
Either<TcpListener, UnixListener>,
Either<TlsListener<super::TcpListener>, TlsListener<UnixListener>>,
Either<super::TcpListener, UnixListener>,
>;
}

Expand All @@ -33,48 +34,90 @@ struct Config {

pub type DefaultListener = private::Listener;

#[derive(Debug)]
pub enum Error {
Config(figment::Error),
Io(std::io::Error),
Unsupported(Endpoint),
#[cfg(feature = "tls")]
Tls(crate::tls::Error),
}

impl From<figment::Error> for Error {
fn from(value: figment::Error) -> Self {
Error::Config(value)
}
}

impl From<std::io::Error> for Error {
fn from(value: std::io::Error) -> Self {
Error::Io(value)
}
}

#[cfg(feature = "tls")]
impl From<crate::tls::Error> for Error {
fn from(value: crate::tls::Error) -> Self {
Error::Tls(value)
}
}

impl From<Either<figment::Error, std::io::Error>> for Error {
fn from(value: Either<figment::Error, std::io::Error>) -> Self {
value.either(Error::Config, Error::Io)
}
}

impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::Config(e) => e.fmt(f),
Error::Io(e) => e.fmt(f),
Error::Unsupported(e) => write!(f, "unsupported endpoint: {e:?}"),
#[cfg(feature = "tls")]
Error::Tls(error) => error.fmt(f),
}
}
}

impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Error::Config(e) => Some(e),
Error::Io(e) => Some(e),
Error::Unsupported(_) => None,
#[cfg(feature = "tls")]
Error::Tls(e) => Some(e),
}
}
}

impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
type Error = crate::Error;
type Error = Error;

async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
let config: Config = rocket.figment().extract()?;
match config.address {
#[cfg(feature = "tls")]
endpoint@Endpoint::Tcp(_) if config.tls.is_some() => {
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Endpoint::Tcp(_) if config.tls.is_some() => {
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket).await?;
Ok(Left(Left(listener)))
}
endpoint@Endpoint::Tcp(_) => {
let listener = <TcpListener as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Endpoint::Tcp(_) => {
let listener = <TcpListener as Bind<_>>::bind(rocket).await?;
Ok(Right(Left(listener)))
}
#[cfg(all(unix, feature = "tls"))]
endpoint@Endpoint::Unix(_) if config.tls.is_some() => {
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Endpoint::Unix(_) if config.tls.is_some() => {
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket).await?;
Ok(Left(Right(listener)))
}
#[cfg(unix)]
endpoint@Endpoint::Unix(_) => {
let listener = <UnixListener as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Endpoint::Unix(_) => {
let listener = <UnixListener as Bind<_>>::bind(rocket).await?;
Ok(Right(Right(listener)))
}
endpoint => {
let msg = format!("unsupported bind endpoint: {endpoint}");
let error = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(ErrorKind::Bind(Some(endpoint), error).into())
}
endpoint => Err(Error::Unsupported(endpoint)),
}
}

Expand Down
33 changes: 33 additions & 0 deletions core/lib/src/tls/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,39 @@ pub(crate) struct DynResolver(Arc<dyn Resolver>);
pub struct Fairing<T: ?Sized>(PhantomData<T>);

/// A dynamic TLS configuration resolver.
///
/// # Example
///
/// This is an async trait. Implement it as follows:
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// use std::sync::Arc;
/// use rocket::tls::{self, Resolver, TlsConfig, ClientHello, ServerConfig};
/// use rocket::{Rocket, Build};
///
/// struct MyResolver(Arc<ServerConfig>);
///
/// #[rocket::async_trait]
/// impl Resolver for MyResolver {
/// async fn init(rocket: &Rocket<Build>) -> tls::Result<Self> {
/// // This is equivalent to what the default resolver would do.
/// let config: TlsConfig = rocket.figment().extract_inner("tls")?;
/// let server_config = config.server_config().await?;
/// Ok(MyResolver(Arc::new(server_config)))
/// }
///
/// async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
/// // return a `ServerConfig` based on `hello`; here we ignore it
/// Some(self.0.clone())
/// }
/// }
///
/// #[launch]
/// fn rocket() -> _ {
/// rocket::build().attach(MyResolver::fairing())
/// }
/// ```
#[crate::async_trait]
pub trait Resolver: Send + Sync + 'static {
async fn init(rocket: &Rocket<Build>) -> crate::tls::Result<Self> where Self: Sized {
Expand Down
75 changes: 49 additions & 26 deletions testbench/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::process::ExitCode;
use std::time::Duration;

use rocket::listener::unix::UnixListener;
use rocket::tokio::net::TcpListener;
Expand Down Expand Up @@ -163,9 +164,7 @@ fn tls_resolver() -> Result<()> {
let server = spawn! {
#[get("/count")]
fn count(counter: &State<Arc<AtomicUsize>>) -> String {
let count = counter.load(Ordering::Acquire);
println!("{count}");
count.to_string()
counter.load(Ordering::Acquire).to_string()
}

let counter = Arc::new(AtomicUsize::new(0));
Expand Down Expand Up @@ -329,8 +328,8 @@ fn tcp_unix_listener_fail() -> Result<()> {
};

if let Err(Error::Liftoff(stdout, _)) = server {
assert!(stdout.contains("expected valid TCP"));
assert!(stdout.contains("for key default.address"));
assert!(stdout.contains("expected valid TCP (ip) or unix (path)"));
assert!(stdout.contains("default.address"));
} else {
panic!("unexpected result: {server:#?}");
}
Expand Down Expand Up @@ -361,14 +360,17 @@ fn tcp_unix_listener_fail() -> Result<()> {

macro_rules! tests {
($($f:ident),* $(,)?) => {[
$(Test { name: stringify!($f), func: $f, }),*
$(Test {
name: stringify!($f),
run: |_: ()| $f().map_err(|e| e.to_string()),
}),*
]};
}

#[derive(Copy, Clone)]
struct Test {
name: &'static str,
func: fn() -> Result<()>,
run: fn(()) -> Result<(), String>,
}

static TESTS: &[Test] = &tests![
Expand All @@ -377,37 +379,58 @@ static TESTS: &[Test] = &tests![
];

fn main() -> ExitCode {
procspawn::init();

let filter = std::env::args().nth(1).unwrap_or_default();
let filtered = TESTS.into_iter().filter(|test| test.name.contains(&filter));

println!("running {}/{} tests", filtered.clone().count(), TESTS.len());
let handles: Vec<_> = filtered
.map(|test| (test, std::thread::spawn(move || {
if let Err(e) = (test.func)() {
println!("test {} ... {}\n {e}", test.name.bold(), "fail".red());
return Err(e);
let handles = filtered.map(|test| (test, std::thread::spawn(|| {
let name = test.name;
let start = std::time::SystemTime::now();
let mut proc = procspawn::spawn((), test.run);
let result = loop {
match proc.join_timeout(Duration::from_secs(10)) {
Err(e) if e.is_timeout() => {
let elapsed = start.elapsed().unwrap().as_secs();
println!("{name} has been running for {elapsed} seconds...");

if elapsed >= 30 {
println!("{name} timeout");
break Err(e);
}
},
result => break result,
}
};

println!("test {} ... {}", test.name.bold(), "ok".green());
Ok(())
})))
.collect();

let mut failure = false;
for (test, handle) in handles {
let result = handle.join();
failure |= matches!(result, Err(_) | Ok(Err(_)));
if result.is_err() {
println!("test {} ... {}", test.name.bold(), "panic".red().underline());
match result.as_ref().map_err(|e| e.panic_info()) {
Ok(Ok(_)) => println!("test {name} ... {}", "ok".green()),
Ok(Err(e)) => println!("test {name} ... {}\n {e}", "fail".red()),
Err(Some(_)) => println!("test {name} ... {}", "panic".red().underline()),
Err(None) => println!("test {name} ... {}", "error".magenta()),
}

matches!(result, Ok(Ok(())))
})));

let mut success = true;
for (_, handle) in handles {
success &= handle.join().unwrap_or(false);
}

match failure {
true => ExitCode::FAILURE,
false => ExitCode::SUCCESS
match success {
true => ExitCode::SUCCESS,
false => {
println!("note: use `NOCAPTURE=1` to see test output");
ExitCode::FAILURE
}
}
}

// TODO: Implement an `UpdatingResolver`. Expose `SniResolver` and
// `UpdatingResolver` in a `contrib` library or as part of `rocket`.
//
// struct UpdatingResolver {
// timestamp: AtomicU64,
// config: ArcSwap<ServerConfig>
Expand Down
Loading

0 comments on commit dad8312

Please sign in to comment.