From dad8312ff7478374b202d82e8e31beafcb2de9e2 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Tue, 16 Apr 2024 15:17:10 -0700 Subject: [PATCH] Introduce dynamic TLS resolvers. This commit introduces the ability to dynamically select a TLS configuration based on the client's TLS hello. Added `Authority::set_port()`. Various `Config` structures for listeners removed. `UdsListener` is now `UnixListener`. `Bindable` removed in favor of new `Bind`. `Connection` requires `AsyncRead + AsyncWrite` again The `Debug` impl for `Endpoint` displays the underlying address in plaintext. `Listener` must be `Sized`. `tls` listener moved to `tls::TlsListener` The preview `quic` listener no longer implements `Listener`. All built-in listeners now implement `Bind<&Rocket>`. Clarified docs for `mtls::Certificate` guard. No reexporitng rustls from `tls`. Added `TlsConfig::server_config()`. Added some future helpers: `race()` and `race_io()`. Fix an issue where the logger wouldn't respect a configuration during error printing. Added Rocket::launch_with(), launch_on(), bind_launch(). Added a default client.pem to the TLS example. Revamped the testbench. Added tests for TLS resolvers, MTLS, listener failure output. TODO: clippy. TODO: UDS testing. Resolves #2730. Resolves #2363. Closes #2748. Closes #2683. Closes #2577. --- core/lib/src/error.rs | 8 +-- core/lib/src/listener/default.rs | 107 ++++++++++++++++++++++--------- core/lib/src/tls/resolver.rs | 33 ++++++++++ testbench/src/main.rs | 75 ++++++++++++++-------- testbench/src/server.rs | 38 ++++------- 5 files changed, 175 insertions(+), 86 deletions(-) diff --git a/core/lib/src/error.rs b/core/lib/src/error.rs index 85867017dd..5802e817fb 100644 --- a/core/lib/src/error.rs +++ b/core/lib/src/error.rs @@ -179,16 +179,16 @@ impl Error { match self.kind() { ErrorKind::Bind(ref a, ref e) => { if let Some(e) = e.downcast_ref::() { - e.pretty_print(); + e.pretty_print() } else { match a { Some(a) => error!("Binding to {} failed.", a.primary().underline()), None => error!("Binding to network interface failed."), } - } - info_!("{}", e); - "aborting due to bind error" + info_!("{}", e); + "aborting due to bind error" + } } ErrorKind::Io(ref e) => { error!("Rocket failed to launch due to an I/O error."); diff --git a/core/lib/src/listener/default.rs b/core/lib/src/listener/default.rs index a82201e3dd..ae2c102b11 100644 --- a/core/lib/src/listener/default.rs +++ b/core/lib/src/listener/default.rs @@ -1,8 +1,9 @@ +use core::fmt; + use serde::Deserialize; -use tokio_util::either::{Either, Either::{Left, Right}}; -use futures::TryFutureExt; +use tokio_util::either::Either::{Left, Right}; +use either::Either; -use crate::error::ErrorKind; use crate::{Ignite, Rocket}; use crate::listener::{Bind, Endpoint, tcp::TcpListener}; @@ -10,7 +11,7 @@ use crate::listener::{Bind, Endpoint, tcp::TcpListener}; #[cfg(feature = "tls")] use crate::tls::{TlsListener, TlsConfig}; mod private { - use super::{Either, TcpListener}; + use tokio_util::either::Either; #[cfg(feature = "tls")] pub type TlsListener = super::TlsListener; #[cfg(not(feature = "tls"))] pub type TlsListener = T; @@ -18,8 +19,8 @@ mod private { #[cfg(not(unix))] pub type UnixListener = super::TcpListener; pub type Listener = Either< - Either, TlsListener>, - Either, + Either, TlsListener>, + Either, >; } @@ -33,48 +34,90 @@ struct Config { pub type DefaultListener = private::Listener; +#[derive(Debug)] +pub enum Error { + Config(figment::Error), + Io(std::io::Error), + Unsupported(Endpoint), + #[cfg(feature = "tls")] + Tls(crate::tls::Error), +} + +impl From for Error { + fn from(value: figment::Error) -> Self { + Error::Config(value) + } +} + +impl From for Error { + fn from(value: std::io::Error) -> Self { + Error::Io(value) + } +} + +#[cfg(feature = "tls")] +impl From for Error { + fn from(value: crate::tls::Error) -> Self { + Error::Tls(value) + } +} + +impl From> for Error { + fn from(value: Either) -> Self { + value.either(Error::Config, Error::Io) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::Config(e) => e.fmt(f), + Error::Io(e) => e.fmt(f), + Error::Unsupported(e) => write!(f, "unsupported endpoint: {e:?}"), + #[cfg(feature = "tls")] + Error::Tls(error) => error.fmt(f), + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Error::Config(e) => Some(e), + Error::Io(e) => Some(e), + Error::Unsupported(_) => None, + #[cfg(feature = "tls")] + Error::Tls(e) => Some(e), + } + } +} + impl<'r> Bind<&'r Rocket> for DefaultListener { - type Error = crate::Error; + type Error = Error; async fn bind(rocket: &'r Rocket) -> Result { let config: Config = rocket.figment().extract()?; match config.address { #[cfg(feature = "tls")] - endpoint@Endpoint::Tcp(_) if config.tls.is_some() => { - let listener = as Bind<_>>::bind(rocket) - .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e))) - .await?; - + Endpoint::Tcp(_) if config.tls.is_some() => { + let listener = as Bind<_>>::bind(rocket).await?; Ok(Left(Left(listener))) } - endpoint@Endpoint::Tcp(_) => { - let listener = >::bind(rocket) - .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e))) - .await?; - + Endpoint::Tcp(_) => { + let listener = >::bind(rocket).await?; Ok(Right(Left(listener))) } #[cfg(all(unix, feature = "tls"))] - endpoint@Endpoint::Unix(_) if config.tls.is_some() => { - let listener = as Bind<_>>::bind(rocket) - .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e))) - .await?; - + Endpoint::Unix(_) if config.tls.is_some() => { + let listener = as Bind<_>>::bind(rocket).await?; Ok(Left(Right(listener))) } #[cfg(unix)] - endpoint@Endpoint::Unix(_) => { - let listener = >::bind(rocket) - .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e))) - .await?; - + Endpoint::Unix(_) => { + let listener = >::bind(rocket).await?; Ok(Right(Right(listener))) } - endpoint => { - let msg = format!("unsupported bind endpoint: {endpoint}"); - let error = Box::::from(msg); - Err(ErrorKind::Bind(Some(endpoint), error).into()) - } + endpoint => Err(Error::Unsupported(endpoint)), } } diff --git a/core/lib/src/tls/resolver.rs b/core/lib/src/tls/resolver.rs index 475ec2b8e9..d7fb39677e 100644 --- a/core/lib/src/tls/resolver.rs +++ b/core/lib/src/tls/resolver.rs @@ -15,6 +15,39 @@ pub(crate) struct DynResolver(Arc); pub struct Fairing(PhantomData); /// A dynamic TLS configuration resolver. +/// +/// # Example +/// +/// This is an async trait. Implement it as follows: +/// +/// ```rust +/// # #[macro_use] extern crate rocket; +/// use std::sync::Arc; +/// use rocket::tls::{self, Resolver, TlsConfig, ClientHello, ServerConfig}; +/// use rocket::{Rocket, Build}; +/// +/// struct MyResolver(Arc); +/// +/// #[rocket::async_trait] +/// impl Resolver for MyResolver { +/// async fn init(rocket: &Rocket) -> tls::Result { +/// // This is equivalent to what the default resolver would do. +/// let config: TlsConfig = rocket.figment().extract_inner("tls")?; +/// let server_config = config.server_config().await?; +/// Ok(MyResolver(Arc::new(server_config))) +/// } +/// +/// async fn resolve(&self, hello: ClientHello<'_>) -> Option> { +/// // return a `ServerConfig` based on `hello`; here we ignore it +/// Some(self.0.clone()) +/// } +/// } +/// +/// #[launch] +/// fn rocket() -> _ { +/// rocket::build().attach(MyResolver::fairing()) +/// } +/// ``` #[crate::async_trait] pub trait Resolver: Send + Sync + 'static { async fn init(rocket: &Rocket) -> crate::tls::Result where Self: Sized { diff --git a/testbench/src/main.rs b/testbench/src/main.rs index 44b1a24fac..165d27b07d 100644 --- a/testbench/src/main.rs +++ b/testbench/src/main.rs @@ -1,4 +1,5 @@ use std::process::ExitCode; +use std::time::Duration; use rocket::listener::unix::UnixListener; use rocket::tokio::net::TcpListener; @@ -163,9 +164,7 @@ fn tls_resolver() -> Result<()> { let server = spawn! { #[get("/count")] fn count(counter: &State>) -> String { - let count = counter.load(Ordering::Acquire); - println!("{count}"); - count.to_string() + counter.load(Ordering::Acquire).to_string() } let counter = Arc::new(AtomicUsize::new(0)); @@ -329,8 +328,8 @@ fn tcp_unix_listener_fail() -> Result<()> { }; if let Err(Error::Liftoff(stdout, _)) = server { - assert!(stdout.contains("expected valid TCP")); - assert!(stdout.contains("for key default.address")); + assert!(stdout.contains("expected valid TCP (ip) or unix (path)")); + assert!(stdout.contains("default.address")); } else { panic!("unexpected result: {server:#?}"); } @@ -361,14 +360,17 @@ fn tcp_unix_listener_fail() -> Result<()> { macro_rules! tests { ($($f:ident),* $(,)?) => {[ - $(Test { name: stringify!($f), func: $f, }),* + $(Test { + name: stringify!($f), + run: |_: ()| $f().map_err(|e| e.to_string()), + }),* ]}; } #[derive(Copy, Clone)] struct Test { name: &'static str, - func: fn() -> Result<()>, + run: fn(()) -> Result<(), String>, } static TESTS: &[Test] = &tests![ @@ -377,37 +379,58 @@ static TESTS: &[Test] = &tests![ ]; fn main() -> ExitCode { + procspawn::init(); + let filter = std::env::args().nth(1).unwrap_or_default(); let filtered = TESTS.into_iter().filter(|test| test.name.contains(&filter)); println!("running {}/{} tests", filtered.clone().count(), TESTS.len()); - let handles: Vec<_> = filtered - .map(|test| (test, std::thread::spawn(move || { - if let Err(e) = (test.func)() { - println!("test {} ... {}\n {e}", test.name.bold(), "fail".red()); - return Err(e); + let handles = filtered.map(|test| (test, std::thread::spawn(|| { + let name = test.name; + let start = std::time::SystemTime::now(); + let mut proc = procspawn::spawn((), test.run); + let result = loop { + match proc.join_timeout(Duration::from_secs(10)) { + Err(e) if e.is_timeout() => { + let elapsed = start.elapsed().unwrap().as_secs(); + println!("{name} has been running for {elapsed} seconds..."); + + if elapsed >= 30 { + println!("{name} timeout"); + break Err(e); + } + }, + result => break result, } + }; - println!("test {} ... {}", test.name.bold(), "ok".green()); - Ok(()) - }))) - .collect(); - - let mut failure = false; - for (test, handle) in handles { - let result = handle.join(); - failure |= matches!(result, Err(_) | Ok(Err(_))); - if result.is_err() { - println!("test {} ... {}", test.name.bold(), "panic".red().underline()); + match result.as_ref().map_err(|e| e.panic_info()) { + Ok(Ok(_)) => println!("test {name} ... {}", "ok".green()), + Ok(Err(e)) => println!("test {name} ... {}\n {e}", "fail".red()), + Err(Some(_)) => println!("test {name} ... {}", "panic".red().underline()), + Err(None) => println!("test {name} ... {}", "error".magenta()), } + + matches!(result, Ok(Ok(()))) + }))); + + let mut success = true; + for (_, handle) in handles { + success &= handle.join().unwrap_or(false); } - match failure { - true => ExitCode::FAILURE, - false => ExitCode::SUCCESS + match success { + true => ExitCode::SUCCESS, + false => { + println!("note: use `NOCAPTURE=1` to see test output"); + ExitCode::FAILURE + } } } +// TODO: Implement an `UpdatingResolver`. Expose `SniResolver` and +// `UpdatingResolver` in a `contrib` library or as part of `rocket`. +// // struct UpdatingResolver { // timestamp: AtomicU64, // config: ArcSwap diff --git a/testbench/src/server.rs b/testbench/src/server.rs index 8a07cc14be..13b40c3e8d 100644 --- a/testbench/src/server.rs +++ b/testbench/src/server.rs @@ -42,6 +42,16 @@ fn stdio() -> Stdio { .unwrap_or_else(Stdio::piped) } +fn read(io: Option) -> Result { + if let Some(mut io) = io { + let mut string = String::new(); + io.read_to_string(&mut string)?; + return Ok(string); + } + + Ok(String::new()) +} + impl Server { pub fn spawn(ctxt: T, f: fn((Token, T)) -> Launched) -> Result where T: Serialize + DeserializeOwned @@ -62,14 +72,7 @@ impl Server { Ok(Server { proc, tls, port, _rx: rx }) }, Message::Failure => { - let stdout = proc.stdout().unwrap(); - let mut out = String::new(); - stdout.read_to_string(&mut out)?; - - let stderr = proc.stderr().unwrap(); - let mut err = String::new(); - stderr.read_to_string(&mut err)?; - Err(Error::Liftoff(out, err)) + Err(Error::Liftoff(read(proc.stdout())?, read(proc.stderr())?)) } } } @@ -80,23 +83,11 @@ impl Server { } pub fn read_stdout(&mut self) -> Result { - let Some(stdout) = self.proc.stdout() else { - return Ok(String::new()); - }; - - let mut string = String::new(); - stdout.read_to_string(&mut string)?; - Ok(string) + read(self.proc.stdout()) } pub fn read_stderr(&mut self) -> Result { - let Some(stderr) = self.proc.stderr() else { - return Ok(String::new()); - }; - - let mut string = String::new(); - stderr.read_to_string(&mut string)?; - Ok(string) + read(self.proc.stderr()) } pub fn kill(&mut self) -> Result<()> { @@ -133,8 +124,7 @@ impl Token { }))); let server = self.0.clone(); - let fut = rocket.launch_with::(); - if let Err(e) = rocket::execute(fut) { + if let Err(e) = rocket::execute(rocket.launch_with::()) { let sender = IpcSender::::connect(server).unwrap(); let _ = sender.send(Message::Failure); let _ = sender.send(Message::Failure);